Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions coderd/apidoc/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions coderd/apidoc/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions coderd/database/db2sdk/db2sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,9 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
if interception.EndedAt.Valid {
intc.EndedAt = &interception.EndedAt.Time
}
if interception.Client.Valid {
intc.Client = &interception.Client.String
}
return intc
}

Expand Down
229 changes: 229 additions & 0 deletions coderd/database/db2sdk/db2sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/database"
Expand Down Expand Up @@ -206,3 +207,231 @@ func TestTemplateVersionParameter_BadDescription(t *testing.T) {
req.NoError(err)
req.NotEmpty(sdk.DescriptionPlaintext, "broke the markdown parser with %v", desc)
}

func TestAIBridgeInterception(t *testing.T) {
t.Parallel()

now := dbtime.Now()
interceptionID := uuid.New()
initiatorID := uuid.New()

cases := []struct {
name string
interception database.AIBridgeInterception
initiator database.VisibleUser
tokenUsages []database.AIBridgeTokenUsage
userPrompts []database.AIBridgeUserPrompt
toolUsages []database.AIBridgeToolUsage
expected codersdk.AIBridgeInterception
}{
{
name: "all_optional_values_set",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding these 👍

interception: database.AIBridgeInterception{
ID: interceptionID,
InitiatorID: initiatorID,
Provider: "anthropic",
Model: "claude-3-opus",
StartedAt: now,
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"key":"value"}`),
Valid: true,
},
EndedAt: sql.NullTime{
Time: now.Add(time.Minute),
Valid: true,
},
APIKeyID: sql.NullString{
String: "api-key-123",
Valid: true,
},
Client: sql.NullString{
String: "claude-code/1.0.0",
Valid: true,
},
},
initiator: database.VisibleUser{
ID: initiatorID,
Username: "testuser",
Name: "Test User",
AvatarURL: "https://example.com/avatar.png",
},
tokenUsages: []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: "resp-123",
InputTokens: 100,
OutputTokens: 200,
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"cache":"hit"}`),
Valid: true,
},
CreatedAt: now.Add(10 * time.Second),
},
},
userPrompts: []database.AIBridgeUserPrompt{
{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: "resp-123",
Prompt: "Hello, world!",
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"role":"user"}`),
Valid: true,
},
CreatedAt: now.Add(5 * time.Second),
},
},
toolUsages: []database.AIBridgeToolUsage{
{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: "resp-123",
ServerUrl: sql.NullString{
String: "https://mcp.example.com",
Valid: true,
},
Tool: "read_file",
Input: `{"path":"/tmp/test.txt"}`,
Injected: true,
InvocationError: sql.NullString{
String: "file not found",
Valid: true,
},
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"duration_ms":50}`),
Valid: true,
},
CreatedAt: now.Add(15 * time.Second),
},
},
expected: codersdk.AIBridgeInterception{
ID: interceptionID,
Initiator: codersdk.MinimalUser{
ID: initiatorID,
Username: "testuser",
Name: "Test User",
AvatarURL: "https://example.com/avatar.png",
},
Provider: "anthropic",
Model: "claude-3-opus",
Metadata: map[string]any{"key": "value"},
StartedAt: now,
},
},
{
name: "no_optional_values_set",
interception: database.AIBridgeInterception{
ID: interceptionID,
InitiatorID: initiatorID,
Provider: "openai",
Model: "gpt-4",
StartedAt: now,
Metadata: pqtype.NullRawMessage{Valid: false},
EndedAt: sql.NullTime{Valid: false},
APIKeyID: sql.NullString{Valid: false},
Client: sql.NullString{Valid: false},
},
initiator: database.VisibleUser{
ID: initiatorID,
Username: "minimaluser",
Name: "",
AvatarURL: "",
},
tokenUsages: nil,
userPrompts: nil,
toolUsages: nil,
expected: codersdk.AIBridgeInterception{
ID: interceptionID,
Initiator: codersdk.MinimalUser{
ID: initiatorID,
Username: "minimaluser",
Name: "",
AvatarURL: "",
},
Provider: "openai",
Model: "gpt-4",
Metadata: nil,
StartedAt: now,
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

result := db2sdk.AIBridgeInterception(
tc.interception,
tc.initiator,
tc.tokenUsages,
tc.userPrompts,
tc.toolUsages,
)

// Check basic fields.
require.Equal(t, tc.expected.ID, result.ID)
require.Equal(t, tc.expected.Initiator, result.Initiator)
require.Equal(t, tc.expected.Provider, result.Provider)
require.Equal(t, tc.expected.Model, result.Model)
require.Equal(t, tc.expected.StartedAt.UTC(), result.StartedAt.UTC())
require.Equal(t, tc.expected.Metadata, result.Metadata)

// Check optional pointer fields.
if tc.interception.APIKeyID.Valid {
require.NotNil(t, result.APIKeyID)
require.Equal(t, tc.interception.APIKeyID.String, *result.APIKeyID)
} else {
require.Nil(t, result.APIKeyID)
}

if tc.interception.EndedAt.Valid {
require.NotNil(t, result.EndedAt)
require.Equal(t, tc.interception.EndedAt.Time.UTC(), result.EndedAt.UTC())
} else {
require.Nil(t, result.EndedAt)
}

if tc.interception.Client.Valid {
require.NotNil(t, result.Client)
require.Equal(t, tc.interception.Client.String, *result.Client)
} else {
require.Nil(t, result.Client)
}

// Check slices.
require.Len(t, result.TokenUsages, len(tc.tokenUsages))
require.Len(t, result.UserPrompts, len(tc.userPrompts))
require.Len(t, result.ToolUsages, len(tc.toolUsages))

// Verify token usages are converted correctly.
for i, tu := range tc.tokenUsages {
require.Equal(t, tu.ID, result.TokenUsages[i].ID)
require.Equal(t, tu.InterceptionID, result.TokenUsages[i].InterceptionID)
require.Equal(t, tu.ProviderResponseID, result.TokenUsages[i].ProviderResponseID)
require.Equal(t, tu.InputTokens, result.TokenUsages[i].InputTokens)
require.Equal(t, tu.OutputTokens, result.TokenUsages[i].OutputTokens)
}

// Verify user prompts are converted correctly.
for i, up := range tc.userPrompts {
require.Equal(t, up.ID, result.UserPrompts[i].ID)
require.Equal(t, up.InterceptionID, result.UserPrompts[i].InterceptionID)
require.Equal(t, up.ProviderResponseID, result.UserPrompts[i].ProviderResponseID)
require.Equal(t, up.Prompt, result.UserPrompts[i].Prompt)
}

// Verify tool usages are converted correctly.
for i, toolUsage := range tc.toolUsages {
require.Equal(t, toolUsage.ID, result.ToolUsages[i].ID)
require.Equal(t, toolUsage.InterceptionID, result.ToolUsages[i].InterceptionID)
require.Equal(t, toolUsage.ProviderResponseID, result.ToolUsages[i].ProviderResponseID)
require.Equal(t, toolUsage.ServerUrl.String, result.ToolUsages[i].ServerURL)
require.Equal(t, toolUsage.Tool, result.ToolUsages[i].Tool)
require.Equal(t, toolUsage.Input, result.ToolUsages[i].Input)
require.Equal(t, toolUsage.Injected, result.ToolUsages[i].Injected)
require.Equal(t, toolUsage.InvocationError.String, result.ToolUsages[i].InvocationError)
}
})
}
}
1 change: 1 addition & 0 deletions codersdk/aibridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type AIBridgeInterception struct {
Initiator MinimalUser `json:"initiator"`
Provider string `json:"provider"`
Model string `json:"model"`
Client *string `json:"client"`
Metadata map[string]any `json:"metadata"`
StartedAt time.Time `json:"started_at" format:"date-time"`
EndedAt *time.Time `json:"ended_at" format:"date-time"`
Expand Down
1 change: 1 addition & 0 deletions docs/reference/api/aibridge.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions docs/reference/api/schemas.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions enterprise/aibridged/aibridged_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/aibridged"
"github.com/coder/coder/v2/enterprise/aibridgedserver"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/testutil"
)
Expand Down Expand Up @@ -226,9 +227,11 @@ func TestIntegration(t *testing.T) {
}
]
}`))
userAgent := "codex_cli_rs/0.87.0"
require.NoError(t, err, "make request to test server")
req.Header.Add("Authorization", "Bearer "+apiKey.Key)
req.Header.Add("Accept", "application/json")
req.Header.Add("User-Agent", userAgent)

// When: aibridged handles the request.
rec := httptest.NewRecorder()
Expand All @@ -251,6 +254,11 @@ func TestIntegration(t *testing.T) {
require.True(t, intc0.EndedAt.Valid)
require.False(t, intc0.EndedAt.Time.Before(intc0.StartedAt), "EndedAt should not be before StartedAt")
require.Less(t, intc0.EndedAt.Time.Sub(intc0.StartedAt), 5*time.Second)
require.True(t, intc0.Client.Valid)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test expects "unknown" as the client value, but the test setup sends "userAgent123" as the User-Agent header. This suggests there's user agent parsing logic that should be tested more explicitly.

Consider:

  1. Testing with a real user agent string that would parse to a known client (e.g., "claude-code/1.0.0")
  2. Adding a test case that verifies the parsing logic works correctly
  3. Documenting what the expected format is and what "unknown" means

This would make the test more meaningful and verify the actual user agent parsing behavior.

require.Equal(t, aibridge.ClientCodex, intc0.Client.String)

intc0Metadata := gjson.GetBytes(intc0.Metadata.RawMessage, aibridgedserver.MetadataUserAgentKey)
require.Equal(t, userAgent, intc0Metadata.String(), "interception metadata user agent should match request user agent")

prompts, err := db.GetAIBridgeUserPromptsByInterceptionID(ctx, interceptions[0].ID)
require.NoError(t, err)
Expand Down
Loading
Loading