Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit bfcdb9f

Browse files
committed
vLLM Provider
1 parent b549eab commit bfcdb9f

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

src/codegate/config.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313

1414
logger = setup_logging()
1515

16+
# Default provider URLs
17+
DEFAULT_PROVIDER_URLS = {
18+
"openai": "https://api.openai.com/v1",
19+
"anthropic": "https://api.anthropic.com/v1",
20+
"vllm": "http://localhost:8000", # Base URL without /v1 path
21+
}
22+
1623

1724
@dataclass
1825
class Config:
@@ -33,13 +40,7 @@ class Config:
3340
chat_model_n_gpu_layers: int = -1
3441

3542
# Provider URLs with defaults
36-
provider_urls: Dict[str, str] = field(
37-
default_factory=lambda: {
38-
"openai": "https://api.openai.com/v1",
39-
"anthropic": "https://api.anthropic.com/v1",
40-
"vllm": "http://localhost:8000", # Base URL without /v1 path
41-
}
42-
)
43+
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
4344

4445
def __post_init__(self) -> None:
4546
"""Validate configuration after initialization."""
@@ -105,7 +106,7 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config":
105106
prompts_config = PromptConfig.from_file(prompts_path)
106107

107108
# Get provider URLs from config
108-
provider_urls = cls.provider_urls.copy()
109+
provider_urls = DEFAULT_PROVIDER_URLS.copy()
109110
if "provider_urls" in config_data:
110111
provider_urls.update(config_data.pop("provider_urls"))
111112

@@ -152,7 +153,7 @@ def from_env(cls) -> "Config":
152153
) # noqa: E501
153154

154155
# Load provider URLs from environment variables
155-
for provider in config.provider_urls.keys():
156+
for provider in DEFAULT_PROVIDER_URLS.keys():
156157
env_var = f"CODEGATE_PROVIDER_{provider.upper()}_URL"
157158
if env_var in os.environ:
158159
config.provider_urls[provider] = os.environ[env_var]

tests/test_cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from codegate.cli import cli
1111
from codegate.codegate_logging import LogFormat, LogLevel
12+
from codegate.config import DEFAULT_PROVIDER_URLS
1213

1314

1415
@pytest.fixture
@@ -63,6 +64,7 @@ def test_serve_default_options(cli_runner: CliRunner, mock_logging: Any) -> None
6364
"log_level": "INFO",
6465
"log_format": "JSON",
6566
"prompts_loaded": 7, # Default prompts are loaded
67+
"provider_urls": DEFAULT_PROVIDER_URLS,
6668
},
6769
)
6870
mock_run.assert_called_once()
@@ -98,6 +100,7 @@ def test_serve_custom_options(cli_runner: CliRunner, mock_logging: Any) -> None:
98100
"log_level": "DEBUG",
99101
"log_format": "TEXT",
100102
"prompts_loaded": 7, # Default prompts are loaded
103+
"provider_urls": DEFAULT_PROVIDER_URLS,
101104
},
102105
)
103106
mock_run.assert_called_once()
@@ -136,6 +139,7 @@ def test_serve_with_config_file(
136139
"log_level": "DEBUG",
137140
"log_format": "JSON",
138141
"prompts_loaded": 7, # Default prompts are loaded
142+
"provider_urls": DEFAULT_PROVIDER_URLS,
139143
},
140144
)
141145
mock_run.assert_called_once()
@@ -182,6 +186,7 @@ def test_serve_priority_resolution(
182186
"log_level": "ERROR",
183187
"log_format": "TEXT",
184188
"prompts_loaded": 7, # Default prompts are loaded
189+
"provider_urls": DEFAULT_PROVIDER_URLS,
185190
},
186191
)
187192
mock_run.assert_called_once()

tests/test_config.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
import yaml
88

9-
from codegate.config import Config, ConfigurationError, LogFormat, LogLevel
9+
from codegate.config import DEFAULT_PROVIDER_URLS, Config, ConfigurationError, LogFormat, LogLevel
1010

1111

1212
def test_default_config(default_config: Config) -> None:
@@ -15,6 +15,7 @@ def test_default_config(default_config: Config) -> None:
1515
assert default_config.host == "localhost"
1616
assert default_config.log_level == LogLevel.INFO
1717
assert default_config.log_format == LogFormat.JSON
18+
assert default_config.provider_urls == DEFAULT_PROVIDER_URLS
1819

1920

2021
def test_config_from_file(temp_config_file: Path) -> None:
@@ -24,6 +25,7 @@ def test_config_from_file(temp_config_file: Path) -> None:
2425
assert config.host == "localhost"
2526
assert config.log_level == LogLevel.DEBUG
2627
assert config.log_format == LogFormat.JSON
28+
assert config.provider_urls == DEFAULT_PROVIDER_URLS
2729

2830

2931
def test_config_from_invalid_file(tmp_path: Path) -> None:
@@ -49,6 +51,7 @@ def test_config_from_env(env_vars: None) -> None:
4951
assert config.host == "localhost"
5052
assert config.log_level == LogLevel.WARNING
5153
assert config.log_format == LogFormat.TEXT
54+
assert config.provider_urls == DEFAULT_PROVIDER_URLS
5255

5356

5457
def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> None:
@@ -60,18 +63,21 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N
6063
cli_host="example.com",
6164
cli_log_level="WARNING",
6265
cli_log_format="TEXT",
66+
cli_provider_urls={"vllm": "https://custom.vllm.server"},
6367
)
6468
assert config.port == 8080
6569
assert config.host == "example.com"
6670
assert config.log_level == LogLevel.WARNING
6771
assert config.log_format == LogFormat.TEXT
72+
assert config.provider_urls["vllm"] == "https://custom.vllm.server"
6873

6974
# Env vars should override config file
7075
config = Config.load(config_path=temp_config_file)
7176
assert config.port == 8989 # from env
7277
assert config.host == "localhost" # from env
7378
assert config.log_level == LogLevel.WARNING # from env
7479
assert config.log_format == LogFormat.TEXT # from env
80+
assert config.provider_urls == DEFAULT_PROVIDER_URLS # no env override
7581

7682
# Config file should override defaults
7783
os.environ.clear() # Remove env vars
@@ -80,6 +86,34 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N
8086
assert config.host == "localhost" # from file
8187
assert config.log_level == LogLevel.DEBUG # from file
8288
assert config.log_format == LogFormat.JSON # from file
89+
assert config.provider_urls == DEFAULT_PROVIDER_URLS # default values
90+
91+
92+
def test_provider_urls_from_config(tmp_path: Path) -> None:
93+
"""Test loading provider URLs from config file."""
94+
config_file = tmp_path / "config.yaml"
95+
custom_urls = {
96+
"vllm": "https://custom.vllm.server",
97+
"openai": "https://custom.openai.server",
98+
}
99+
with open(config_file, "w") as f:
100+
yaml.dump({"provider_urls": custom_urls}, f)
101+
102+
config = Config.from_file(config_file)
103+
assert config.provider_urls["vllm"] == custom_urls["vllm"]
104+
assert config.provider_urls["openai"] == custom_urls["openai"]
105+
assert config.provider_urls["anthropic"] == DEFAULT_PROVIDER_URLS["anthropic"]
106+
107+
108+
def test_provider_urls_from_env() -> None:
109+
"""Test loading provider URLs from environment variables."""
110+
os.environ["CODEGATE_PROVIDER_VLLM_URL"] = "https://custom.vllm.server"
111+
try:
112+
config = Config.from_env()
113+
assert config.provider_urls["vllm"] == "https://custom.vllm.server"
114+
assert config.provider_urls["openai"] == DEFAULT_PROVIDER_URLS["openai"]
115+
finally:
116+
del os.environ["CODEGATE_PROVIDER_VLLM_URL"]
83117

84118

85119
def test_invalid_log_level() -> None:

0 commit comments

Comments
 (0)