finish bedrock
This commit is contained in:
409
drivers/ai-provider/bedrock/bedrock.go
Normal file
409
drivers/ai-provider/bedrock/bedrock.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
|
||||
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/eolinker/eosc/common/bean"
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
var (
|
||||
accessConfigManager ai_convert.IModelAccessConfigManager
|
||||
)
|
||||
|
||||
func init() {
|
||||
bean.Autowired(&accessConfigManager)
|
||||
ai_convert.RegisterConverterCreateFunc("bedrock", Create)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
AccessKey string `json:"aws_access_key_id"`
|
||||
SecretKey string `json:"aws_secret_access_key"`
|
||||
Region string `json:"aws_region"`
|
||||
}
|
||||
|
||||
func Create(cfg string) (ai_convert.IConverter, error) {
|
||||
var conf Config
|
||||
err := json.Unmarshal([]byte(cfg), &conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf.AccessKey == "" {
|
||||
return nil, fmt.Errorf("aws_access_key_id is required")
|
||||
}
|
||||
if conf.SecretKey == "" {
|
||||
return nil, fmt.Errorf("aws_secret_access_key is required")
|
||||
}
|
||||
return NewConvert(conf.AccessKey, conf.SecretKey, conf.Region), nil
|
||||
}
|
||||
|
||||
type Convert struct {
|
||||
signer *v4.Signer
|
||||
region string
|
||||
}
|
||||
|
||||
func NewConvert(ak string, sk string, region string) *Convert {
|
||||
return &Convert{
|
||||
signer: v4.NewSigner(credentials.NewStaticCredentials(ak, sk, "")),
|
||||
region: region,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
currentPath = "/model/%s/converse"
|
||||
streamPath = "/model/%s/converse-stream"
|
||||
)
|
||||
|
||||
func (c *Convert) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
|
||||
provider := ai_convert.GetAIProvider(ctx)
|
||||
model := ai_convert.GetAIModel(ctx)
|
||||
modelCfg, has := accessConfigManager.Get(fmt.Sprintf("%s$%s", provider, model))
|
||||
region := ""
|
||||
if has {
|
||||
model = modelCfg.Config()["model"]
|
||||
region = modelCfg.Config()["region"]
|
||||
}
|
||||
if region == "" {
|
||||
region = c.region
|
||||
}
|
||||
base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", region)
|
||||
|
||||
balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.SetBalance(balanceHandler)
|
||||
httpContext, err := http_service.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := httpContext.Proxy().Body().RawBody()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatRequest := eosc.NewBase[ai_convert.Request](extender)
|
||||
err = json.Unmarshal(body, chatRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal body error: %v, body: %s", err, string(body))
|
||||
}
|
||||
messages := make([]Message, 0, len(chatRequest.Config.Messages))
|
||||
systemMessage := make([]*Content, 0)
|
||||
for _, m := range chatRequest.Config.Messages {
|
||||
if m.Role == "system" {
|
||||
systemMessage = append(systemMessage, &Content{Text: m.Content})
|
||||
} else {
|
||||
messages = append(messages, Message{
|
||||
Role: m.Role,
|
||||
Content: []*Content{{Text: m.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
chatRequest.SetAppend("messages", messages)
|
||||
chatRequest.SetAppend("system", systemMessage)
|
||||
path := fmt.Sprintf(currentPath, model)
|
||||
if chatRequest.Config.Stream {
|
||||
path = fmt.Sprintf(streamPath, model)
|
||||
}
|
||||
uri := fmt.Sprintf("%s%s", base, path)
|
||||
httpContext.Proxy().URI().SetPath(path)
|
||||
|
||||
body, _ = json.Marshal(chatRequest)
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
headers, err := signRequest(c.signer, region, uri, http.Header{}, string(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range headers {
|
||||
httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";"))
|
||||
}
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
httpContext.Response().AppendStreamFunc(c.streamHandler)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Convert) ResponseConvert(ctx eocontext.EoContext) error {
|
||||
httpContext, err := http_service.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if httpContext.Response().StatusCode() != 200 {
|
||||
return nil
|
||||
}
|
||||
body := httpContext.Response().GetBody()
|
||||
var origin BedrockResponse
|
||||
|
||||
err = json.Unmarshal(body, &origin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp := ConvertBedrockToOpenAI(ctx.RequestId(), ai_convert.GetAIModel(ctx), origin, false)
|
||||
|
||||
body, err = json.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpContext.Response().SetBody(body)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Convert) streamHandler(ctx http_service.IHttpContext, p []byte) ([]byte, error) {
|
||||
// 创建一个缓冲区来存储转换后的SSE格式数据
|
||||
var sseBuffer bytes.Buffer
|
||||
|
||||
// 生成一个唯一的请求ID
|
||||
requestID := ctx.RequestId()
|
||||
model := ai_convert.GetAIModel(ctx)
|
||||
response, err := EventStreamToJSON(p)
|
||||
if err != nil {
|
||||
log.Errorf("event stream to json error: %v", err)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
for _, r := range response {
|
||||
switch r.Header.EventType {
|
||||
case "contentBlockDelta":
|
||||
{
|
||||
if r.Payload.Delta != nil {
|
||||
data := openai.ChatCompletionStreamResponse{
|
||||
ID: requestID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: model,
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: openai.ChatCompletionStreamChoiceDelta{
|
||||
Content: r.Payload.Delta.Text,
|
||||
Role: "assistant",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
content, _ := json.Marshal(data)
|
||||
sseBuffer.WriteString(fmt.Sprintf("data: %s\n\n", string(content)))
|
||||
}
|
||||
}
|
||||
case "messageStop":
|
||||
{
|
||||
usage := new(openai.Usage)
|
||||
if r.Payload.Usage != nil {
|
||||
usage.PromptTokens = r.Payload.Usage.InputTokens
|
||||
usage.CompletionTokens = r.Payload.Usage.OutputTokens
|
||||
usage.TotalTokens = r.Payload.Usage.TotalTokens
|
||||
ai_convert.SetAIModelInputToken(ctx, r.Payload.Usage.InputTokens)
|
||||
ai_convert.SetAIModelOutputToken(ctx, r.Payload.Usage.OutputTokens)
|
||||
ai_convert.SetAIModelTotalToken(ctx, r.Payload.Usage.TotalTokens)
|
||||
}
|
||||
stopReason := openai.FinishReasonStop
|
||||
//end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
|
||||
switch r.Payload.StopReason {
|
||||
case "max_tokens":
|
||||
stopReason = openai.FinishReasonLength
|
||||
case "content_filtered":
|
||||
stopReason = openai.FinishReasonContentFilter
|
||||
}
|
||||
data := openai.ChatCompletionStreamResponse{
|
||||
ID: ctx.RequestId(),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: ai_convert.GetAIModel(ctx),
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: stopReason,
|
||||
},
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
content, _ := json.Marshal(data)
|
||||
sseBuffer.WriteString(fmt.Sprintf("data: %s\n\n", string(content)))
|
||||
sseBuffer.WriteString("data: [DONE]\n\n")
|
||||
return sseBuffer.Bytes(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 返回转换后的SSE格式数据
|
||||
return sseBuffer.Bytes(), nil
|
||||
//// 对响应数据进行划分
|
||||
//inputToken := GetAIModelInputToken(ctx)
|
||||
//outputToken := 0
|
||||
//totalToken := inputToken
|
||||
//scanner := bufio.NewScanner(bytes.NewReader(p))
|
||||
//// Check the content encoding and convert to UTF-8 if necessary.
|
||||
//encoding := ctx.Response().Headers().Get("content-encoding")
|
||||
//for scanner.Scan() {
|
||||
// line := scanner.Text()
|
||||
// if encoding != "utf-8" && encoding != "" {
|
||||
// tmp, err := encoder.ToUTF8(encoding, []byte(line))
|
||||
// if err != nil {
|
||||
// log.Errorf("convert to utf-8 error: %v, line: %s", err, line)
|
||||
// return p, nil
|
||||
// }
|
||||
// if ctx.Response().StatusCode() != 200 || (o.checkErr != nil && !o.checkErr(ctx, tmp)) {
|
||||
// if o.errorCallback != nil {
|
||||
// o.errorCallback(ctx, tmp)
|
||||
// }
|
||||
// return p, nil
|
||||
// }
|
||||
// line = string(tmp)
|
||||
// }
|
||||
// line = strings.TrimPrefix(line, "data:")
|
||||
// if line == "" || strings.Trim(line, " ") == "[DONE]" {
|
||||
// return p, nil
|
||||
// }
|
||||
// var resp openai.ChatCompletionResponse
|
||||
// err := json.Unmarshal([]byte(line), &resp)
|
||||
// if err != nil {
|
||||
// return p, nil
|
||||
// }
|
||||
// if len(resp.Choices) > 0 {
|
||||
// outputToken += getTokens(resp.Choices[0].Message.Content)
|
||||
// totalToken += outputToken
|
||||
// }
|
||||
//}
|
||||
//if err := scanner.Err(); err != nil {
|
||||
// log.Errorf("scan error: %v", err)
|
||||
// return p, nil
|
||||
//}
|
||||
//
|
||||
//SetAIModelInputToken(ctx, inputToken)
|
||||
//SetAIModelOutputToken(ctx, outputToken)
|
||||
//SetAIModelTotalToken(ctx, totalToken)
|
||||
}
|
||||
|
||||
func signRequest(signer *v4.Signer, region string, uri string, headers http.Header, body string) (http.Header, error) {
|
||||
request, err := http.NewRequest(http.MethodPost, uri, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header = headers.Clone()
|
||||
|
||||
_, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request.Header, nil
|
||||
|
||||
}
|
||||
|
||||
// EventStreamToJSON 将 Amazon EventStream 格式的数据转换为 JSON 格式
|
||||
func EventStreamToJSON(eventStreamData []byte) ([]StreamResponse, error) {
|
||||
// 创建一个结果数组
|
||||
var result []StreamResponse
|
||||
|
||||
// 创建一个 EventStream 解码器
|
||||
decoder := eventstream.NewDecoder(bytes.NewReader(eventStreamData))
|
||||
|
||||
// 循环读取所有事件
|
||||
for {
|
||||
|
||||
// 读取下一个消息
|
||||
msg, err := decoder.Decode(nil)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break // 正常结束
|
||||
}
|
||||
|
||||
// 处理 AWS 错误
|
||||
if awsErr, ok := err.(awserr.Error); ok {
|
||||
return nil, fmt.Errorf("AWS Error: %s - %s", awsErr.Code(), awsErr.Message())
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("解析 EventStream 时出错: %v", err)
|
||||
}
|
||||
|
||||
// 将消息转换为 map
|
||||
eventMap := make(map[string]interface{})
|
||||
|
||||
// 处理消息头
|
||||
headers := make(map[string]interface{})
|
||||
for _, header := range msg.Headers {
|
||||
headers[header.Name] = header.Value
|
||||
}
|
||||
|
||||
eventMap["headers"] = headers
|
||||
|
||||
// 处理消息体
|
||||
if len(msg.Payload) > 0 {
|
||||
// 尝试将负载解析为 JSON
|
||||
var payload interface{}
|
||||
if err := json.Unmarshal(msg.Payload, &payload); err == nil {
|
||||
eventMap["payload"] = payload
|
||||
} else {
|
||||
// 如果不是有效的 JSON,则作为字符串处理
|
||||
eventMap["payload"] = string(msg.Payload)
|
||||
}
|
||||
}
|
||||
var streamResponse StreamResponse
|
||||
err = mapstructure.Decode(eventMap, &streamResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将事件添加到结果数组
|
||||
result = append(result, streamResponse)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExtractTextFromEventStream 从 EventStream 中提取文本内容
|
||||
// 这个函数专门用于从 Bedrock 模型响应中提取生成的文本
|
||||
func ExtractTextFromEventStream(eventStreamData []byte) (string, error) {
|
||||
var fullText string
|
||||
|
||||
decoder := eventstream.NewDecoder(bytes.NewReader(eventStreamData))
|
||||
|
||||
for {
|
||||
msg, err := decoder.Decode(nil)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 解析消息负载
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(msg.Payload, &response); err != nil {
|
||||
continue // 跳过无法解析的消息
|
||||
}
|
||||
|
||||
// 根据 Bedrock 的响应格式提取文本
|
||||
// 注意:具体的字段名可能需要根据使用的模型进行调整
|
||||
if completion, ok := response["completion"].(string); ok {
|
||||
fullText += completion
|
||||
} else if output, ok := response["output"].(map[string]interface{}); ok {
|
||||
if text, ok := output["text"].(string); ok {
|
||||
fullText += text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fullText, nil
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
|
||||
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/eolinker/eosc/common/bean"
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
var (
|
||||
accessConfigManager ai_convert.IModelAccessConfigManager
|
||||
)
|
||||
|
||||
func init() {
|
||||
bean.Autowired(&accessConfigManager)
|
||||
ai_convert.RegisterConverterCreateFunc("bedrock", Create)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
AccessKey string `json:"aws_access_key_id"`
|
||||
SecretKey string `json:"aws_secret_access_key"`
|
||||
Region string `json:"aws_region"`
|
||||
}
|
||||
|
||||
func Create(cfg string) (ai_convert.IConverter, error) {
|
||||
var conf Config
|
||||
err := json.Unmarshal([]byte(cfg), &conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf.AccessKey == "" {
|
||||
return nil, fmt.Errorf("aws_access_key_id is required")
|
||||
}
|
||||
if conf.SecretKey == "" {
|
||||
return nil, fmt.Errorf("aws_secret_access_key is required")
|
||||
}
|
||||
return NewConvert(conf.AccessKey, conf.SecretKey, conf.Region), nil
|
||||
}
|
||||
|
||||
type Convert struct {
|
||||
signer *v4.Signer
|
||||
region string
|
||||
}
|
||||
|
||||
func NewConvert(ak string, sk string, region string) *Convert {
|
||||
return &Convert{
|
||||
signer: v4.NewSigner(credentials.NewStaticCredentials(ak, sk, "")),
|
||||
region: region,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
currentPath = "/model/%s/converse"
|
||||
streamPath = "/model/%s/converse-stream"
|
||||
)
|
||||
|
||||
func (c *Convert) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
|
||||
provider := ai_convert.GetAIProvider(ctx)
|
||||
model := ai_convert.GetAIModel(ctx)
|
||||
modelCfg, has := accessConfigManager.Get(fmt.Sprintf("%s$%s", provider, model))
|
||||
region := ""
|
||||
if has {
|
||||
model = modelCfg.Config()["model"]
|
||||
region = modelCfg.Config()["region"]
|
||||
}
|
||||
if region == "" {
|
||||
region = c.region
|
||||
}
|
||||
base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", region)
|
||||
|
||||
balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.SetBalance(balanceHandler)
|
||||
httpContext, err := http_service.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := httpContext.Proxy().Body().RawBody()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatRequest := eosc.NewBase[ai_convert.Request](extender)
|
||||
err = json.Unmarshal(body, chatRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal body error: %v, body: %s", err, string(body))
|
||||
}
|
||||
messages := make([]Message, 0, len(chatRequest.Config.Messages))
|
||||
systemMessage := make([]*Content, 0)
|
||||
for _, m := range chatRequest.Config.Messages {
|
||||
if m.Role == "system" {
|
||||
systemMessage = append(systemMessage, &Content{Text: m.Content})
|
||||
} else {
|
||||
messages = append(messages, Message{
|
||||
Role: m.Role,
|
||||
Content: []*Content{{Text: m.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
chatRequest.SetAppend("messages", messages)
|
||||
chatRequest.SetAppend("system", systemMessage)
|
||||
path := fmt.Sprintf(currentPath, model)
|
||||
if chatRequest.Config.Stream {
|
||||
path = fmt.Sprintf(streamPath, model)
|
||||
}
|
||||
uri := fmt.Sprintf("%s%s", base, path)
|
||||
httpContext.Proxy().URI().SetPath(path)
|
||||
|
||||
body, _ = json.Marshal(chatRequest)
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
headers, err := signRequest(c.signer, region, uri, http.Header{}, string(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range headers {
|
||||
httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";"))
|
||||
}
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Convert) ResponseConvert(ctx eocontext.EoContext) error {
|
||||
//httpContext, err := http_service.Assert(ctx)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
//if httpContext.Response().StatusCode() != 200 {
|
||||
// return nil
|
||||
//}
|
||||
//body := httpContext.Response().GetBody()
|
||||
//data := eosc.NewBase[Response](nil)
|
||||
//err = json.Unmarshal(body, data)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
//responseBody := &ai_convert.Response{}
|
||||
//
|
||||
//body, err = json.Marshal(responseBody)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
//httpContext.Response().AppendStreamFunc(c.streamHandler)
|
||||
//httpContext.Response().SetBody(body)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Convert) streamHandler(ctx http_service.IHttpContext, p []byte) ([]byte, error) {
|
||||
//// 对响应数据进行划分
|
||||
//inputToken := GetAIModelInputToken(ctx)
|
||||
//outputToken := 0
|
||||
//totalToken := inputToken
|
||||
//scanner := bufio.NewScanner(bytes.NewReader(p))
|
||||
//// Check the content encoding and convert to UTF-8 if necessary.
|
||||
//encoding := ctx.Response().Headers().Get("content-encoding")
|
||||
//for scanner.Scan() {
|
||||
// line := scanner.Text()
|
||||
// if encoding != "utf-8" && encoding != "" {
|
||||
// tmp, err := encoder.ToUTF8(encoding, []byte(line))
|
||||
// if err != nil {
|
||||
// log.Errorf("convert to utf-8 error: %v, line: %s", err, line)
|
||||
// return p, nil
|
||||
// }
|
||||
// if ctx.Response().StatusCode() != 200 || (o.checkErr != nil && !o.checkErr(ctx, tmp)) {
|
||||
// if o.errorCallback != nil {
|
||||
// o.errorCallback(ctx, tmp)
|
||||
// }
|
||||
// return p, nil
|
||||
// }
|
||||
// line = string(tmp)
|
||||
// }
|
||||
// line = strings.TrimPrefix(line, "data:")
|
||||
// if line == "" || strings.Trim(line, " ") == "[DONE]" {
|
||||
// return p, nil
|
||||
// }
|
||||
// var resp openai.ChatCompletionResponse
|
||||
// err := json.Unmarshal([]byte(line), &resp)
|
||||
// if err != nil {
|
||||
// return p, nil
|
||||
// }
|
||||
// if len(resp.Choices) > 0 {
|
||||
// outputToken += getTokens(resp.Choices[0].Message.Content)
|
||||
// totalToken += outputToken
|
||||
// }
|
||||
//}
|
||||
//if err := scanner.Err(); err != nil {
|
||||
// log.Errorf("scan error: %v", err)
|
||||
// return p, nil
|
||||
//}
|
||||
//
|
||||
//SetAIModelInputToken(ctx, inputToken)
|
||||
//SetAIModelOutputToken(ctx, outputToken)
|
||||
//SetAIModelTotalToken(ctx, totalToken)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func signRequest(signer *v4.Signer, region string, uri string, headers http.Header, body string) (http.Header, error) {
|
||||
request, err := http.NewRequest(http.MethodPost, uri, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header = headers.Clone()
|
||||
|
||||
_, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request.Header, nil
|
||||
|
||||
}
|
||||
@@ -1,5 +1,11 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type ClientRequest struct {
|
||||
Messages []*Message `json:"message,omitempty"`
|
||||
System *Content `json:"system,omitempty"`
|
||||
@@ -21,15 +27,6 @@ type InferenceConfig struct {
|
||||
TopP float64 `json:"topP"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Output Output `json:"output"`
|
||||
StopReason string `json:"stopReason"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Message *Message `json:"message"`
|
||||
}
|
||||
|
||||
// BedrockResponse 代表 Amazon Bedrock 的 JSON 响应格式
|
||||
type BedrockResponse struct {
|
||||
Metrics struct {
|
||||
@@ -51,50 +48,77 @@ type BedrockResponse struct {
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
//
|
||||
//// ConvertBedrockToOpenAI 通用转换方法
|
||||
//func ConvertBedrockToOpenAI(requestId string, model string, bedrockResp BedrockResponse) openai.ChatCompletionResponse {
|
||||
// // 提取文本内容
|
||||
// textContent := ""
|
||||
// if len(bedrockResp.Output.Message.Content) > 0 {
|
||||
// textContent = bedrockResp.Output.Message.Content[0].Text
|
||||
// }
|
||||
// //const (
|
||||
// // FinishReasonStop FinishReason = "stop"
|
||||
// // FinishReasonLength FinishReason = "length"
|
||||
// // FinishReasonFunctionCall FinishReason = "function_call"
|
||||
// // FinishReasonToolCalls FinishReason = "tool_calls"
|
||||
// // FinishReasonContentFilter FinishReason = "content_filter"
|
||||
// // FinishReasonNull FinishReason = "null"
|
||||
// //)
|
||||
// stopReason := openai.FinishReasonStop
|
||||
// //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
|
||||
// switch bedrockResp.StopReason {
|
||||
// case "max_tokens":
|
||||
// stopReason = openai.FinishReasonLength
|
||||
// case "content_filtered":
|
||||
// stopReason = openai.FinishReasonContentFilter
|
||||
// }
|
||||
// openai.FinishReason(bedrockResp.StopReason)
|
||||
// return openai.ChatCompletionResponse{
|
||||
// ID: requestId,
|
||||
// Object: "",
|
||||
// Created: 0, // 这里可以替换为实际时间戳
|
||||
// Model: model,
|
||||
// Choices: []openai.ChatCompletionChoice{
|
||||
// {
|
||||
//
|
||||
// FinishReason: stopReason,
|
||||
// },
|
||||
// },
|
||||
// Usage: struct {
|
||||
// PromptTokens int `json:"prompt_tokens"`
|
||||
// CompletionTokens int `json:"completion_tokens"`
|
||||
// TotalTokens int `json:"total_tokens"`
|
||||
// }{
|
||||
// PromptTokens: bedrockResp.Usage.InputTokens,
|
||||
// CompletionTokens: bedrockResp.Usage.OutputTokens,
|
||||
// TotalTokens: bedrockResp.Usage.TotalTokens,
|
||||
// },
|
||||
// }
|
||||
//}
|
||||
// ConvertBedrockToOpenAI 通用转换方法
|
||||
func ConvertBedrockToOpenAI(requestId string, model string, bedrockResp BedrockResponse, isStream bool) openai.ChatCompletionResponse {
|
||||
// 提取文本内容
|
||||
textContent := ""
|
||||
if len(bedrockResp.Output.Message.Content) > 0 {
|
||||
textContent = bedrockResp.Output.Message.Content[0].Text
|
||||
}
|
||||
|
||||
stopReason := openai.FinishReasonStop
|
||||
//end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
|
||||
switch bedrockResp.StopReason {
|
||||
case "max_tokens":
|
||||
stopReason = openai.FinishReasonLength
|
||||
case "content_filtered":
|
||||
stopReason = openai.FinishReasonContentFilter
|
||||
}
|
||||
oj := "chat.completion"
|
||||
if isStream {
|
||||
oj = "chat.completion.chunk"
|
||||
}
|
||||
return openai.ChatCompletionResponse{
|
||||
ID: requestId,
|
||||
Object: oj,
|
||||
Created: time.Now().Unix(), // 这里可以替换为实际时间戳
|
||||
Model: model,
|
||||
Choices: []openai.ChatCompletionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: textContent,
|
||||
},
|
||||
FinishReason: stopReason,
|
||||
},
|
||||
},
|
||||
Usage: openai.Usage{
|
||||
PromptTokens: bedrockResp.Usage.InputTokens,
|
||||
CompletionTokens: bedrockResp.Usage.OutputTokens,
|
||||
TotalTokens: bedrockResp.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type StreamResponse struct {
|
||||
Header Header `json:"headers" mapstructure:"headers"`
|
||||
Payload Payload `json:"payload" mapstructure:"payload"`
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
ContentType string `json:":content-type" mapstructure:":content-type"`
|
||||
EventType string `json:":event-type" mapstructure:":event-type"`
|
||||
MessageType string `json:":message-type" mapstructure:":message-type"`
|
||||
}
|
||||
|
||||
type Payload struct {
|
||||
ContentBlockIndex int `json:"contentBlockIndex" mapstructure:"contentBlockIndex"`
|
||||
Delta *Delta `json:"delta" mapstructure:"delta"`
|
||||
P string `json:"p" mapstructure:"p"`
|
||||
StopReason string `json:"stopReason,omitempty" mapstructure:"stopReason"`
|
||||
Metrics struct {
|
||||
LatencyMs int `json:"latencyMs" mapstructure:"latencyMs"`
|
||||
} `json:"metrics,omitempty" mapstructure:"metrics"`
|
||||
Usage *Usage `json:"usage,omitempty" mapstructure:"usage"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"inputTokens" mapstructure:"inputTokens"`
|
||||
OutputTokens int `json:"outputTokens" mapstructure:"outputTokens"`
|
||||
TotalTokens int `json:"totalTokens" mapstructure:"totalTokens"`
|
||||
}
|
||||
|
||||
type Delta struct {
|
||||
Text string `json:"text" mapstructure:"text"`
|
||||
}
|
||||
|
||||
@@ -75,7 +75,10 @@ func (e *executor) doConverter(ctx http_context.IHttpContext, next eocontext.ICh
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.Response().IsBodyStream() {
|
||||
ctx.Response().SetHeader("Content-Type", "text/event-stream")
|
||||
return nil
|
||||
}
|
||||
if err := resource.ResponseConvert(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -198,6 +201,7 @@ func (e *executor) processKeyPool(ctx http_context.IHttpContext, provider string
|
||||
}
|
||||
}
|
||||
if ctx.Response().IsBodyStream() {
|
||||
ctx.Response().SetHeader("Content-Type", "text/event-stream")
|
||||
return nil
|
||||
}
|
||||
if err = r.ResponseConvert(ctx); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user