finish bedrock

This commit is contained in:
Liujian
2025-03-12 21:14:23 +08:00
parent ce50f78d81
commit 29f9ce6db8
4 changed files with 494 additions and 280 deletions

View File

@@ -0,0 +1,409 @@
package bedrock
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/eolinker/eosc/log"
"github.com/mitchellh/mapstructure"
"github.com/aws/aws-sdk-go/aws/awserr"
openai "github.com/sashabaranov/go-openai"
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/eolinker/eosc/common/bean"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"github.com/eolinker/eosc/eocontext"
ai_convert "github.com/eolinker/apinto/ai-convert"
"github.com/eolinker/eosc"
)
var (
accessConfigManager ai_convert.IModelAccessConfigManager
)
func init() {
bean.Autowired(&accessConfigManager)
ai_convert.RegisterConverterCreateFunc("bedrock", Create)
}
type Config struct {
AccessKey string `json:"aws_access_key_id"`
SecretKey string `json:"aws_secret_access_key"`
Region string `json:"aws_region"`
}
func Create(cfg string) (ai_convert.IConverter, error) {
var conf Config
err := json.Unmarshal([]byte(cfg), &conf)
if err != nil {
return nil, err
}
if conf.AccessKey == "" {
return nil, fmt.Errorf("aws_access_key_id is required")
}
if conf.SecretKey == "" {
return nil, fmt.Errorf("aws_secret_access_key is required")
}
return NewConvert(conf.AccessKey, conf.SecretKey, conf.Region), nil
}
type Convert struct {
signer *v4.Signer
region string
}
func NewConvert(ak string, sk string, region string) *Convert {
return &Convert{
signer: v4.NewSigner(credentials.NewStaticCredentials(ak, sk, "")),
region: region,
}
}
var (
currentPath = "/model/%s/converse"
streamPath = "/model/%s/converse-stream"
)
func (c *Convert) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
provider := ai_convert.GetAIProvider(ctx)
model := ai_convert.GetAIModel(ctx)
modelCfg, has := accessConfigManager.Get(fmt.Sprintf("%s$%s", provider, model))
region := ""
if has {
model = modelCfg.Config()["model"]
region = modelCfg.Config()["region"]
}
if region == "" {
region = c.region
}
base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", region)
balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0)
if err != nil {
return err
}
ctx.SetBalance(balanceHandler)
httpContext, err := http_service.Assert(ctx)
if err != nil {
return err
}
body, err := httpContext.Proxy().Body().RawBody()
if err != nil {
return err
}
chatRequest := eosc.NewBase[ai_convert.Request](extender)
err = json.Unmarshal(body, chatRequest)
if err != nil {
return fmt.Errorf("unmarshal body error: %v, body: %s", err, string(body))
}
messages := make([]Message, 0, len(chatRequest.Config.Messages))
systemMessage := make([]*Content, 0)
for _, m := range chatRequest.Config.Messages {
if m.Role == "system" {
systemMessage = append(systemMessage, &Content{Text: m.Content})
} else {
messages = append(messages, Message{
Role: m.Role,
Content: []*Content{{Text: m.Content}},
})
}
}
chatRequest.SetAppend("messages", messages)
chatRequest.SetAppend("system", systemMessage)
path := fmt.Sprintf(currentPath, model)
if chatRequest.Config.Stream {
path = fmt.Sprintf(streamPath, model)
}
uri := fmt.Sprintf("%s%s", base, path)
httpContext.Proxy().URI().SetPath(path)
body, _ = json.Marshal(chatRequest)
httpContext.Proxy().Body().SetRaw("application/json", body)
headers, err := signRequest(c.signer, region, uri, http.Header{}, string(body))
if err != nil {
return err
}
for k, v := range headers {
httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";"))
}
httpContext.Proxy().Body().SetRaw("application/json", body)
httpContext.Response().AppendStreamFunc(c.streamHandler)
return nil
}
func (c *Convert) ResponseConvert(ctx eocontext.EoContext) error {
httpContext, err := http_service.Assert(ctx)
if err != nil {
return err
}
if httpContext.Response().StatusCode() != 200 {
return nil
}
body := httpContext.Response().GetBody()
var origin BedrockResponse
err = json.Unmarshal(body, &origin)
if err != nil {
return err
}
resp := ConvertBedrockToOpenAI(ctx.RequestId(), ai_convert.GetAIModel(ctx), origin, false)
body, err = json.Marshal(resp)
if err != nil {
return err
}
httpContext.Response().SetBody(body)
return nil
}
func (c *Convert) streamHandler(ctx http_service.IHttpContext, p []byte) ([]byte, error) {
// 创建一个缓冲区来存储转换后的SSE格式数据
var sseBuffer bytes.Buffer
// 生成一个唯一的请求ID
requestID := ctx.RequestId()
model := ai_convert.GetAIModel(ctx)
response, err := EventStreamToJSON(p)
if err != nil {
log.Errorf("event stream to json error: %v", err)
return p, nil
}
for _, r := range response {
switch r.Header.EventType {
case "contentBlockDelta":
{
if r.Payload.Delta != nil {
data := openai.ChatCompletionStreamResponse{
ID: requestID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: model,
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: r.Payload.Delta.Text,
Role: "assistant",
},
},
},
}
content, _ := json.Marshal(data)
sseBuffer.WriteString(fmt.Sprintf("data: %s\n\n", string(content)))
}
}
case "messageStop":
{
usage := new(openai.Usage)
if r.Payload.Usage != nil {
usage.PromptTokens = r.Payload.Usage.InputTokens
usage.CompletionTokens = r.Payload.Usage.OutputTokens
usage.TotalTokens = r.Payload.Usage.TotalTokens
ai_convert.SetAIModelInputToken(ctx, r.Payload.Usage.InputTokens)
ai_convert.SetAIModelOutputToken(ctx, r.Payload.Usage.OutputTokens)
ai_convert.SetAIModelTotalToken(ctx, r.Payload.Usage.TotalTokens)
}
stopReason := openai.FinishReasonStop
//end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
switch r.Payload.StopReason {
case "max_tokens":
stopReason = openai.FinishReasonLength
case "content_filtered":
stopReason = openai.FinishReasonContentFilter
}
data := openai.ChatCompletionStreamResponse{
ID: ctx.RequestId(),
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: ai_convert.GetAIModel(ctx),
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
FinishReason: stopReason,
},
},
Usage: usage,
}
content, _ := json.Marshal(data)
sseBuffer.WriteString(fmt.Sprintf("data: %s\n\n", string(content)))
sseBuffer.WriteString("data: [DONE]\n\n")
return sseBuffer.Bytes(), nil
}
}
}
// 返回转换后的SSE格式数据
return sseBuffer.Bytes(), nil
//// 对响应数据进行划分
//inputToken := GetAIModelInputToken(ctx)
//outputToken := 0
//totalToken := inputToken
//scanner := bufio.NewScanner(bytes.NewReader(p))
//// Check the content encoding and convert to UTF-8 if necessary.
//encoding := ctx.Response().Headers().Get("content-encoding")
//for scanner.Scan() {
// line := scanner.Text()
// if encoding != "utf-8" && encoding != "" {
// tmp, err := encoder.ToUTF8(encoding, []byte(line))
// if err != nil {
// log.Errorf("convert to utf-8 error: %v, line: %s", err, line)
// return p, nil
// }
// if ctx.Response().StatusCode() != 200 || (o.checkErr != nil && !o.checkErr(ctx, tmp)) {
// if o.errorCallback != nil {
// o.errorCallback(ctx, tmp)
// }
// return p, nil
// }
// line = string(tmp)
// }
// line = strings.TrimPrefix(line, "data:")
// if line == "" || strings.Trim(line, " ") == "[DONE]" {
// return p, nil
// }
// var resp openai.ChatCompletionResponse
// err := json.Unmarshal([]byte(line), &resp)
// if err != nil {
// return p, nil
// }
// if len(resp.Choices) > 0 {
// outputToken += getTokens(resp.Choices[0].Message.Content)
// totalToken += outputToken
// }
//}
//if err := scanner.Err(); err != nil {
// log.Errorf("scan error: %v", err)
// return p, nil
//}
//
//SetAIModelInputToken(ctx, inputToken)
//SetAIModelOutputToken(ctx, outputToken)
//SetAIModelTotalToken(ctx, totalToken)
}
func signRequest(signer *v4.Signer, region string, uri string, headers http.Header, body string) (http.Header, error) {
request, err := http.NewRequest(http.MethodPost, uri, nil)
if err != nil {
return nil, err
}
request.Header = headers.Clone()
_, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now())
if err != nil {
return nil, err
}
return request.Header, nil
}
// EventStreamToJSON 将 Amazon EventStream 格式的数据转换为 JSON 格式
func EventStreamToJSON(eventStreamData []byte) ([]StreamResponse, error) {
// 创建一个结果数组
var result []StreamResponse
// 创建一个 EventStream 解码器
decoder := eventstream.NewDecoder(bytes.NewReader(eventStreamData))
// 循环读取所有事件
for {
// 读取下一个消息
msg, err := decoder.Decode(nil)
if err != nil {
if err == io.EOF {
break // 正常结束
}
// 处理 AWS 错误
if awsErr, ok := err.(awserr.Error); ok {
return nil, fmt.Errorf("AWS Error: %s - %s", awsErr.Code(), awsErr.Message())
}
return nil, fmt.Errorf("解析 EventStream 时出错: %v", err)
}
// 将消息转换为 map
eventMap := make(map[string]interface{})
// 处理消息头
headers := make(map[string]interface{})
for _, header := range msg.Headers {
headers[header.Name] = header.Value
}
eventMap["headers"] = headers
// 处理消息体
if len(msg.Payload) > 0 {
// 尝试将负载解析为 JSON
var payload interface{}
if err := json.Unmarshal(msg.Payload, &payload); err == nil {
eventMap["payload"] = payload
} else {
// 如果不是有效的 JSON则作为字符串处理
eventMap["payload"] = string(msg.Payload)
}
}
var streamResponse StreamResponse
err = mapstructure.Decode(eventMap, &streamResponse)
if err != nil {
return nil, err
}
// 将事件添加到结果数组
result = append(result, streamResponse)
}
return result, nil
}
// ExtractTextFromEventStream 从 EventStream 中提取文本内容
// 这个函数专门用于从 Bedrock 模型响应中提取生成的文本
func ExtractTextFromEventStream(eventStreamData []byte) (string, error) {
var fullText string
decoder := eventstream.NewDecoder(bytes.NewReader(eventStreamData))
for {
msg, err := decoder.Decode(nil)
if err != nil {
if err == io.EOF {
break
}
return "", err
}
// 解析消息负载
var response map[string]interface{}
if err := json.Unmarshal(msg.Payload, &response); err != nil {
continue // 跳过无法解析的消息
}
// 根据 Bedrock 的响应格式提取文本
// 注意:具体的字段名可能需要根据使用的模型进行调整
if completion, ok := response["completion"].(string); ok {
fullText += completion
} else if output, ok := response["output"].(map[string]interface{}); ok {
if text, ok := output["text"].(string); ok {
fullText += text
}
}
}
return fullText, nil
}

View File

@@ -1,223 +0,0 @@
package bedrock
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/eolinker/eosc/common/bean"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"github.com/eolinker/eosc/eocontext"
ai_convert "github.com/eolinker/apinto/ai-convert"
"github.com/eolinker/eosc"
)
var (
accessConfigManager ai_convert.IModelAccessConfigManager
)
func init() {
bean.Autowired(&accessConfigManager)
ai_convert.RegisterConverterCreateFunc("bedrock", Create)
}
type Config struct {
AccessKey string `json:"aws_access_key_id"`
SecretKey string `json:"aws_secret_access_key"`
Region string `json:"aws_region"`
}
func Create(cfg string) (ai_convert.IConverter, error) {
var conf Config
err := json.Unmarshal([]byte(cfg), &conf)
if err != nil {
return nil, err
}
if conf.AccessKey == "" {
return nil, fmt.Errorf("aws_access_key_id is required")
}
if conf.SecretKey == "" {
return nil, fmt.Errorf("aws_secret_access_key is required")
}
return NewConvert(conf.AccessKey, conf.SecretKey, conf.Region), nil
}
type Convert struct {
signer *v4.Signer
region string
}
func NewConvert(ak string, sk string, region string) *Convert {
return &Convert{
signer: v4.NewSigner(credentials.NewStaticCredentials(ak, sk, "")),
region: region,
}
}
var (
currentPath = "/model/%s/converse"
streamPath = "/model/%s/converse-stream"
)
func (c *Convert) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
provider := ai_convert.GetAIProvider(ctx)
model := ai_convert.GetAIModel(ctx)
modelCfg, has := accessConfigManager.Get(fmt.Sprintf("%s$%s", provider, model))
region := ""
if has {
model = modelCfg.Config()["model"]
region = modelCfg.Config()["region"]
}
if region == "" {
region = c.region
}
base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", region)
balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0)
if err != nil {
return err
}
ctx.SetBalance(balanceHandler)
httpContext, err := http_service.Assert(ctx)
if err != nil {
return err
}
body, err := httpContext.Proxy().Body().RawBody()
if err != nil {
return err
}
chatRequest := eosc.NewBase[ai_convert.Request](extender)
err = json.Unmarshal(body, chatRequest)
if err != nil {
return fmt.Errorf("unmarshal body error: %v, body: %s", err, string(body))
}
messages := make([]Message, 0, len(chatRequest.Config.Messages))
systemMessage := make([]*Content, 0)
for _, m := range chatRequest.Config.Messages {
if m.Role == "system" {
systemMessage = append(systemMessage, &Content{Text: m.Content})
} else {
messages = append(messages, Message{
Role: m.Role,
Content: []*Content{{Text: m.Content}},
})
}
}
chatRequest.SetAppend("messages", messages)
chatRequest.SetAppend("system", systemMessage)
path := fmt.Sprintf(currentPath, model)
if chatRequest.Config.Stream {
path = fmt.Sprintf(streamPath, model)
}
uri := fmt.Sprintf("%s%s", base, path)
httpContext.Proxy().URI().SetPath(path)
body, _ = json.Marshal(chatRequest)
httpContext.Proxy().Body().SetRaw("application/json", body)
headers, err := signRequest(c.signer, region, uri, http.Header{}, string(body))
if err != nil {
return err
}
for k, v := range headers {
httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";"))
}
httpContext.Proxy().Body().SetRaw("application/json", body)
return nil
}
func (c *Convert) ResponseConvert(ctx eocontext.EoContext) error {
//httpContext, err := http_service.Assert(ctx)
//if err != nil {
// return err
//}
//if httpContext.Response().StatusCode() != 200 {
// return nil
//}
//body := httpContext.Response().GetBody()
//data := eosc.NewBase[Response](nil)
//err = json.Unmarshal(body, data)
//if err != nil {
// return err
//}
//responseBody := &ai_convert.Response{}
//
//body, err = json.Marshal(responseBody)
//if err != nil {
// return err
//}
//httpContext.Response().AppendStreamFunc(c.streamHandler)
//httpContext.Response().SetBody(body)
return nil
}
func (c *Convert) streamHandler(ctx http_service.IHttpContext, p []byte) ([]byte, error) {
//// 对响应数据进行划分
//inputToken := GetAIModelInputToken(ctx)
//outputToken := 0
//totalToken := inputToken
//scanner := bufio.NewScanner(bytes.NewReader(p))
//// Check the content encoding and convert to UTF-8 if necessary.
//encoding := ctx.Response().Headers().Get("content-encoding")
//for scanner.Scan() {
// line := scanner.Text()
// if encoding != "utf-8" && encoding != "" {
// tmp, err := encoder.ToUTF8(encoding, []byte(line))
// if err != nil {
// log.Errorf("convert to utf-8 error: %v, line: %s", err, line)
// return p, nil
// }
// if ctx.Response().StatusCode() != 200 || (o.checkErr != nil && !o.checkErr(ctx, tmp)) {
// if o.errorCallback != nil {
// o.errorCallback(ctx, tmp)
// }
// return p, nil
// }
// line = string(tmp)
// }
// line = strings.TrimPrefix(line, "data:")
// if line == "" || strings.Trim(line, " ") == "[DONE]" {
// return p, nil
// }
// var resp openai.ChatCompletionResponse
// err := json.Unmarshal([]byte(line), &resp)
// if err != nil {
// return p, nil
// }
// if len(resp.Choices) > 0 {
// outputToken += getTokens(resp.Choices[0].Message.Content)
// totalToken += outputToken
// }
//}
//if err := scanner.Err(); err != nil {
// log.Errorf("scan error: %v", err)
// return p, nil
//}
//
//SetAIModelInputToken(ctx, inputToken)
//SetAIModelOutputToken(ctx, outputToken)
//SetAIModelTotalToken(ctx, totalToken)
return p, nil
}
func signRequest(signer *v4.Signer, region string, uri string, headers http.Header, body string) (http.Header, error) {
request, err := http.NewRequest(http.MethodPost, uri, nil)
if err != nil {
return nil, err
}
request.Header = headers.Clone()
_, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now())
if err != nil {
return nil, err
}
return request.Header, nil
}

View File

@@ -1,5 +1,11 @@
package bedrock
import (
"time"
openai "github.com/sashabaranov/go-openai"
)
type ClientRequest struct {
Messages []*Message `json:"message,omitempty"`
System *Content `json:"system,omitempty"`
@@ -21,15 +27,6 @@ type InferenceConfig struct {
TopP float64 `json:"topP"`
}
type Response struct {
Output Output `json:"output"`
StopReason string `json:"stopReason"`
}
type Output struct {
Message *Message `json:"message"`
}
// BedrockResponse 代表 Amazon Bedrock 的 JSON 响应格式
type BedrockResponse struct {
Metrics struct {
@@ -51,50 +48,77 @@ type BedrockResponse struct {
} `json:"usage"`
}
//
//// ConvertBedrockToOpenAI 通用转换方法
//func ConvertBedrockToOpenAI(requestId string, model string, bedrockResp BedrockResponse) openai.ChatCompletionResponse {
// // 提取文本内容
// textContent := ""
// if len(bedrockResp.Output.Message.Content) > 0 {
// textContent = bedrockResp.Output.Message.Content[0].Text
// }
// //const (
// // FinishReasonStop FinishReason = "stop"
// // FinishReasonLength FinishReason = "length"
// // FinishReasonFunctionCall FinishReason = "function_call"
// // FinishReasonToolCalls FinishReason = "tool_calls"
// // FinishReasonContentFilter FinishReason = "content_filter"
// // FinishReasonNull FinishReason = "null"
// //)
// stopReason := openai.FinishReasonStop
// //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
// switch bedrockResp.StopReason {
// case "max_tokens":
// stopReason = openai.FinishReasonLength
// case "content_filtered":
// stopReason = openai.FinishReasonContentFilter
// }
// openai.FinishReason(bedrockResp.StopReason)
// return openai.ChatCompletionResponse{
// ID: requestId,
// Object: "",
// Created: 0, // 这里可以替换为实际时间戳
// Model: model,
// Choices: []openai.ChatCompletionChoice{
// {
//
// FinishReason: stopReason,
// },
// },
// Usage: struct {
// PromptTokens int `json:"prompt_tokens"`
// CompletionTokens int `json:"completion_tokens"`
// TotalTokens int `json:"total_tokens"`
// }{
// PromptTokens: bedrockResp.Usage.InputTokens,
// CompletionTokens: bedrockResp.Usage.OutputTokens,
// TotalTokens: bedrockResp.Usage.TotalTokens,
// },
// }
//}
// ConvertBedrockToOpenAI 通用转换方法
func ConvertBedrockToOpenAI(requestId string, model string, bedrockResp BedrockResponse, isStream bool) openai.ChatCompletionResponse {
// 提取文本内容
textContent := ""
if len(bedrockResp.Output.Message.Content) > 0 {
textContent = bedrockResp.Output.Message.Content[0].Text
}
stopReason := openai.FinishReasonStop
//end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
switch bedrockResp.StopReason {
case "max_tokens":
stopReason = openai.FinishReasonLength
case "content_filtered":
stopReason = openai.FinishReasonContentFilter
}
oj := "chat.completion"
if isStream {
oj = "chat.completion.chunk"
}
return openai.ChatCompletionResponse{
ID: requestId,
Object: oj,
Created: time.Now().Unix(), // 这里可以替换为实际时间戳
Model: model,
Choices: []openai.ChatCompletionChoice{
{
Index: 0,
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: textContent,
},
FinishReason: stopReason,
},
},
Usage: openai.Usage{
PromptTokens: bedrockResp.Usage.InputTokens,
CompletionTokens: bedrockResp.Usage.OutputTokens,
TotalTokens: bedrockResp.Usage.TotalTokens,
},
}
}
type StreamResponse struct {
Header Header `json:"headers" mapstructure:"headers"`
Payload Payload `json:"payload" mapstructure:"payload"`
}
type Header struct {
ContentType string `json:":content-type" mapstructure:":content-type"`
EventType string `json:":event-type" mapstructure:":event-type"`
MessageType string `json:":message-type" mapstructure:":message-type"`
}
type Payload struct {
ContentBlockIndex int `json:"contentBlockIndex" mapstructure:"contentBlockIndex"`
Delta *Delta `json:"delta" mapstructure:"delta"`
P string `json:"p" mapstructure:"p"`
StopReason string `json:"stopReason,omitempty" mapstructure:"stopReason"`
Metrics struct {
LatencyMs int `json:"latencyMs" mapstructure:"latencyMs"`
} `json:"metrics,omitempty" mapstructure:"metrics"`
Usage *Usage `json:"usage,omitempty" mapstructure:"usage"`
}
type Usage struct {
InputTokens int `json:"inputTokens" mapstructure:"inputTokens"`
OutputTokens int `json:"outputTokens" mapstructure:"outputTokens"`
TotalTokens int `json:"totalTokens" mapstructure:"totalTokens"`
}
type Delta struct {
Text string `json:"text" mapstructure:"text"`
}

View File

@@ -75,7 +75,10 @@ func (e *executor) doConverter(ctx http_context.IHttpContext, next eocontext.ICh
return err
}
}
if ctx.Response().IsBodyStream() {
ctx.Response().SetHeader("Content-Type", "text/event-stream")
return nil
}
if err := resource.ResponseConvert(ctx); err != nil {
return err
}
@@ -198,6 +201,7 @@ func (e *executor) processKeyPool(ctx http_context.IHttpContext, provider string
}
}
if ctx.Response().IsBodyStream() {
ctx.Response().SetHeader("Content-Type", "text/event-stream")
return nil
}
if err = r.ResponseConvert(ctx); err != nil {