Skip to content

Commit fb57c3d

Browse files
committed
fix(model): 兼容硅基流动 sub_type
1 parent 5e42469 commit fb57c3d

File tree

1 file changed

+27
-55
lines changed

1 file changed

+27
-55
lines changed

backend/internal/model/usecase/model.go

Lines changed: 27 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,20 @@ func (m *ModelUsecase) InitModel(ctx context.Context) error {
255255
return m.repo.InitModel(ctx, m.cfg.InitModel.Name, m.cfg.InitModel.Key, m.cfg.InitModel.URL)
256256
}
257257

258+
func (m *ModelUsecase) getQuery(req *domain.GetProviderModelListReq) request.Query {
259+
q := make(request.Query, 0)
260+
if req.Provider != consts.ModelProviderBaiZhiCloud && req.Provider != consts.ModelProviderSiliconFlow {
261+
return q
262+
}
263+
q["type"] = "text"
264+
q["sub_type"] = string(req.Type)
265+
// 硅基流动不支持coder sub_type
266+
if req.Provider == consts.ModelProviderSiliconFlow && req.Type == consts.ModelTypeCoder {
267+
q["sub_type"] = "chat"
268+
}
269+
return q
270+
}
271+
258272
func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.GetProviderModelListReq) (*domain.GetProviderModelListResp, error) {
259273
switch req.Provider {
260274
case consts.ModelProviderAzureOpenAI,
@@ -266,18 +280,25 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
266280
consts.ModelProviderHunyuan,
267281
consts.ModelProviderMoonshot,
268282
consts.ModelProviderDeepSeek,
283+
consts.ModelProviderSiliconFlow,
284+
consts.ModelProviderBaiZhiCloud,
269285
consts.ModelProviderBaiLian:
270286
u, err := url.Parse(req.BaseURL)
271287
if err != nil {
272288
return nil, err
273289
}
274290
u.Path = path.Join(u.Path, "/models")
275291
client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client))
276-
resp, err := request.Get[domain.OpenAIResp](client, u.Path, request.WithHeader(
277-
request.Header{
278-
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
279-
},
280-
))
292+
query := m.getQuery(req)
293+
resp, err := request.Get[domain.OpenAIResp](
294+
client, u.Path,
295+
request.WithHeader(
296+
request.Header{
297+
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
298+
},
299+
),
300+
request.WithQuery(query),
301+
)
281302
if err != nil {
282303
return nil, err
283304
}
@@ -289,6 +310,7 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
289310
}
290311
}),
291312
}, nil
313+
292314
case consts.ModelProviderOllama:
293315
// get from ollama http://10.10.16.24:11434/api/tags
294316
u, err := url.Parse(req.BaseURL)
@@ -306,56 +328,6 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
306328

307329
return request.Get[domain.GetProviderModelListResp](client, u.Path, request.WithHeader(h))
308330

309-
case consts.ModelProviderSiliconFlow, consts.ModelProviderBaiZhiCloud:
310-
if req.Type == consts.ModelTypeEmbedding || req.Type == consts.ModelTypeReranker {
311-
if req.Provider == consts.ModelProviderBaiZhiCloud {
312-
if req.Type == consts.ModelTypeEmbedding {
313-
return &domain.GetProviderModelListResp{
314-
Models: []domain.ProviderModelListItem{
315-
{
316-
Model: "bge-m3",
317-
},
318-
},
319-
}, nil
320-
} else {
321-
return &domain.GetProviderModelListResp{
322-
Models: []domain.ProviderModelListItem{
323-
{
324-
Model: "bge-reranker-v2-m3",
325-
},
326-
},
327-
}, nil
328-
}
329-
}
330-
}
331-
u, err := url.Parse(req.BaseURL)
332-
if err != nil {
333-
return nil, err
334-
}
335-
st := string(req.Type)
336-
if req.Type == consts.ModelTypeLLM {
337-
st = "chat"
338-
}
339-
client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client))
340-
resp, err := request.Get[domain.OpenAIResp](client, "/v1/models", request.WithHeader(
341-
request.Header{
342-
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
343-
},
344-
), request.WithQuery(request.Query{
345-
"type": "text",
346-
"sub_type": st,
347-
}))
348-
if err != nil {
349-
return nil, err
350-
}
351-
352-
return &domain.GetProviderModelListResp{
353-
Models: cvt.Iter(resp.Data, func(_ int, e *domain.OpenAIData) domain.ProviderModelListItem {
354-
return domain.ProviderModelListItem{
355-
Model: e.ID,
356-
}
357-
}),
358-
}, nil
359331
default:
360332
return nil, fmt.Errorf("invalid provider: %s", req.Provider)
361333
}

0 commit comments

Comments
 (0)