Skip to content

Commit 1987af7

Browse files
authored
Merge pull request chaitin#322 from yokowu/feat-multi-model
feat: 支持插件选择模型
2 parents d15e980 + 4b357a0 commit 1987af7

File tree

14 files changed

+434
-67
lines changed

14 files changed

+434
-67
lines changed

backend/cmd/server/wire_gen.go

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/consts/model.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package consts
33
type ModelStatus string
44

55
const (
6+
ModelStatusDefault ModelStatus = "default"
67
ModelStatusActive ModelStatus = "active"
78
ModelStatusInactive ModelStatus = "inactive"
89
)

backend/domain/model.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type ModelUsecase interface {
2020
}
2121

2222
type ModelRepo interface {
23-
GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error)
23+
GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error)
2424
List(ctx context.Context) (*AllModelResp, error)
2525
Create(ctx context.Context, m *CreateModelReq) (*db.Model, error)
2626
Update(ctx context.Context, id string, fn func(tx *db.Tx, old *db.Model, up *db.ModelUpdateOne) error) (*db.Model, error)
@@ -181,7 +181,7 @@ func (m *Model) From(e *db.Model) *Model {
181181
m.ModelType = e.ModelType
182182
m.Status = e.Status
183183
m.IsInternal = e.IsInternal
184-
m.IsActive = e.Status == consts.ModelStatusActive
184+
m.IsActive = e.Status == consts.ModelStatusActive || e.Status == consts.ModelStatusDefault
185185
if p := e.Parameters; p != nil {
186186
m.Param = ModelParam{
187187
R1Enabled: p.R1Enabled,

backend/domain/openai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ type ConfigReq struct {
6868

6969
type ConfigResp struct {
7070
Type consts.ConfigType `json:"type"`
71-
Content string `json:"content"`
71+
Content any `json:"content"`
7272
}
7373
type OpenAIResp struct {
7474
Object string `json:"object"`

backend/domain/plugin.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package domain
2+
3+
type PluginConfig struct {
4+
ProviderProfiles ProviderProfiles `json:"providerProfiles"`
5+
CtcodeTabCompletions CtcodeTabCompletions `json:"ctcodeTabCompletions"`
6+
GlobalSettings GlobalSettings `json:"globalSettings"`
7+
}
8+
9+
type ProviderProfiles struct {
10+
CurrentApiConfigName string `json:"currentApiConfigName"`
11+
ApiConfigs map[string]ApiConfig `json:"apiConfigs"`
12+
ModeApiConfigs map[string]string `json:"modeApiConfigs"`
13+
Migrations Migrations `json:"migrations"`
14+
}
15+
16+
type ApiConfig struct {
17+
ApiProvider string `json:"apiProvider"`
18+
ApiModelId string `json:"apiModelId"`
19+
OpenAiBaseUrl string `json:"openAiBaseUrl"`
20+
OpenAiApiKey string `json:"openAiApiKey"`
21+
OpenAiModelId string `json:"openAiModelId"`
22+
OpenAiR1FormatEnabled bool `json:"openAiR1FormatEnabled"`
23+
OpenAiCustomModelInfo OpenAiCustomModelInfo `json:"openAiCustomModelInfo"`
24+
Id string `json:"id"`
25+
}
26+
27+
type OpenAiCustomModelInfo struct {
28+
MaxTokens int `json:"maxTokens"`
29+
ContextWindow int `json:"contextWindow"`
30+
SupportsImages bool `json:"supportsImages"`
31+
SupportsComputerUse bool `json:"supportsComputerUse"`
32+
SupportsPromptCache bool `json:"supportsPromptCache"`
33+
}
34+
35+
type Migrations struct {
36+
RateLimitSecondsMigrated bool `json:"rateLimitSecondsMigrated"`
37+
DiffSettingsMigrated bool `json:"diffSettingsMigrated"`
38+
}
39+
40+
type CtcodeTabCompletions struct {
41+
Enabled bool `json:"enabled"`
42+
ApiProvider string `json:"apiProvider"`
43+
OpenAiBaseUrl string `json:"openAiBaseUrl"`
44+
OpenAiApiKey string `json:"openAiApiKey"`
45+
OpenAiModelId string `json:"openAiModelId"`
46+
}
47+
48+
type GlobalSettings struct {
49+
AllowedCommands []string `json:"allowedCommands"`
50+
Mode string `json:"mode"`
51+
CustomModes []string `json:"customModes"`
52+
}

backend/internal/middleware/proxy.go

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package middleware
22

33
import (
44
"context"
5+
"encoding/json"
6+
"log/slog"
57
"net/http"
68
"strings"
79

810
"github.com/labstack/echo/v4"
11+
"github.com/redis/go-redis/v9"
912

1013
"github.com/chaitin/MonkeyCode/backend/domain"
1114
"github.com/chaitin/MonkeyCode/backend/ent/rule"
@@ -16,15 +19,23 @@ const (
1619
ApiContextKey = "session:apikey"
1720
)
1821

22+
type proxyModelKey struct{}
23+
1924
type ProxyMiddleware struct {
2025
usecase domain.ProxyUsecase
26+
redis *redis.Client
27+
logger *slog.Logger
2128
}
2229

2330
func NewProxyMiddleware(
2431
usecase domain.ProxyUsecase,
32+
redis *redis.Client,
33+
logger *slog.Logger,
2534
) *ProxyMiddleware {
2635
return &ProxyMiddleware{
2736
usecase: usecase,
37+
redis: redis,
38+
logger: logger.With("module", "ProxyMiddleware"),
2839
}
2940
}
3041

@@ -39,21 +50,54 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
3950
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
4051
}
4152

42-
key, err := p.usecase.ValidateApiKey(c.Request().Context(), apiKey)
43-
if err != nil {
44-
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
53+
ctx := c.Request().Context()
54+
p.logger.With("apiKey", apiKey).DebugContext(ctx, "v1 auth")
55+
if strings.Contains(apiKey, ".") {
56+
s, err := p.redis.Get(ctx, apiKey).Result()
57+
if err != nil {
58+
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to get api key from redis")
59+
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
60+
}
61+
var model *domain.Model
62+
if err := json.Unmarshal([]byte(s), &model); err != nil {
63+
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to unmarshal model from redis")
64+
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
65+
}
66+
parts := strings.Split(apiKey, ".")
67+
if len(parts) != 2 {
68+
p.logger.With("fn", "Auth").With("apiKey", apiKey).ErrorContext(ctx, "invalid api key")
69+
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
70+
}
71+
ctx = context.WithValue(ctx, proxyModelKey{}, model)
72+
ctx = context.WithValue(ctx, logger.UserIDKey{}, parts[0])
73+
c.Set(ApiContextKey, &domain.ApiKey{
74+
UserID: parts[0],
75+
Key: apiKey,
76+
})
77+
} else {
78+
key, err := p.usecase.ValidateApiKey(ctx, apiKey)
79+
if err != nil {
80+
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
81+
}
82+
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
83+
c.Set(ApiContextKey, key)
4584
}
4685

47-
ctx := c.Request().Context()
48-
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
4986
ctx = rule.SkipPermission(ctx)
5087
c.SetRequest(c.Request().WithContext(ctx))
51-
c.Set(ApiContextKey, key)
5288
return next(c)
5389
}
5490
}
5591
}
5692

93+
func GetProxyModel(ctx context.Context) *domain.Model {
94+
m := ctx.Value(proxyModelKey{})
95+
if m == nil {
96+
return nil
97+
}
98+
return m.(*domain.Model)
99+
}
100+
57101
func GetApiKey(c echo.Context) *domain.ApiKey {
58102
i := c.Get(ApiContextKey)
59103
if i == nil {

backend/internal/model/repo/model.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,16 @@ func NewModelRepo(db *db.Client) domain.ModelRepo {
3030
return &ModelRepo{db: db, cache: cache}
3131
}
3232

33-
func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error) {
33+
func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error) {
3434
if v, ok := r.cache.Get(string(modelType)); ok {
35-
return v.(*db.Model), nil
35+
return v.([]*db.Model), nil
3636
}
3737

3838
m, err := r.db.Model.Query().
3939
Where(model.ModelType(modelType)).
40-
Where(model.Status(consts.ModelStatusActive)).
41-
Only(ctx)
40+
Where(model.StatusIn(consts.ModelStatusActive, consts.ModelStatusDefault)).
41+
Order(ByStatusOrder()).
42+
All(ctx)
4243
if err != nil {
4344
return nil, err
4445
}
@@ -47,14 +48,22 @@ func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType
4748
return m, nil
4849
}
4950

51+
func ByStatusOrder() func(s *sql.Selector) {
52+
return func(s *sql.Selector) {
53+
s.OrderExprFunc(func(b *sql.Builder) {
54+
b.WriteString("case when status = 'default' then 3 when status = 'active' then 2 else 1 end desc")
55+
})
56+
}
57+
}
58+
5059
func (r *ModelRepo) Create(ctx context.Context, m *domain.CreateModelReq) (*db.Model, error) {
5160
n, err := r.db.Model.Query().Where(model.ModelType(m.ModelType)).Count(ctx)
5261
if err != nil {
5362
return nil, err
5463
}
55-
status := consts.ModelStatusInactive
64+
status := consts.ModelStatusActive
5665
if n == 0 {
57-
status = consts.ModelStatusActive
66+
status = consts.ModelStatusDefault
5867
}
5968

6069
r.cache.Delete(string(m.ModelType))

backend/internal/model/usecase/model.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,27 @@ func (m *ModelUsecase) Update(ctx context.Context, req *domain.UpdateModelReq) (
108108
up.SetShowName(*req.ShowName)
109109
}
110110
if req.Status != nil {
111-
if *req.Status == consts.ModelStatusActive {
111+
if *req.Status == consts.ModelStatusDefault {
112112
if err := tx.Model.Update().
113+
Where(model.Status(consts.ModelStatusDefault)).
113114
Where(model.ModelType(old.ModelType)).
114-
SetStatus(consts.ModelStatusInactive).
115+
SetStatus(consts.ModelStatusActive).
115116
Exec(ctx); err != nil {
116117
return err
117118
}
118119
}
120+
if *req.Status == consts.ModelStatusActive {
121+
n, err := tx.Model.Query().
122+
Where(model.Status(consts.ModelStatusDefault)).
123+
Where(model.ModelType(old.ModelType)).
124+
Count(ctx)
125+
if err != nil {
126+
return err
127+
}
128+
if n == 0 {
129+
*req.Status = consts.ModelStatusDefault
130+
}
131+
}
119132
up.SetStatus(*req.Status)
120133
}
121134
if req.Param != nil {

0 commit comments

Comments
 (0)