This repository was archived by the owner on Apr 15, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodels.py
More file actions
82 lines (69 loc) · 2.9 KB
/
models.py
File metadata and controls
82 lines (69 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Copyright 2024 Tecnotree, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
----
Module to interact with the Sensa model registry
"""
import os
from typing import Optional, Dict, Any, Union
from cortex.serviceconnector import _Client
from cortex.utils import generate_token
try:
import mlflow
except ImportError:
mlflow = None
def check_installed():
    """
    Ensure the optional models SDK dependency (``mlflow``) is available.

    ``mlflow`` is set to ``None`` at import time when it cannot be imported;
    this guard turns that into an explicit error at the point of use.

    :raises NotImplementedError: if ``mlflow`` could not be imported
    """
    if mlflow is not None:
        return
    message = ('Models SDK extra not installed, please run '
               '`pip install cortex-python[models_dev]` to install')
    raise NotImplementedError(message)
class ModelClient(_Client):
    """
    Client for the model registry, backed by MLflow.

    This class requires the ``mlflow`` package, which is provided by the
    ``models_dev`` extra (``pip install cortex-python[models_dev]``).
    """

    def _setup_model_client(self, verify_ssl_cert=True, ttl="2h"):
        """
        Point MLflow at the service connector URL and authenticate with a fresh JWT.

        :param verify_ssl_cert: whether TLS certificates are verified, both when
            generating the token and for MLflow's own tracking requests
        :param ttl: token validity, forwarded to ``generate_token`` as ``validity``
        """
        # Generate a JWT. NOTE(review): the original comment says this call also
        # stores the JWT in `_serviceconnector.jwt`; confirm against generate_token.
        token = generate_token(
            self._serviceconnector._config,  # pylint: disable=protected-access
            verify_ssl_cert=verify_ssl_cert,
            validity=ttl,
        )
        url = self._serviceconnector.url
        mlflow.set_tracking_uri(url)
        # Also export the settings via the environment so that code reading MLflow
        # configuration from environment variables (e.g. child processes) sees them.
        os.environ['MLFLOW_TRACKING_URI'] = url
        os.environ['MLFLOW_TRACKING_TOKEN'] = token
        if not verify_ssl_cert:
            # Mirror the token request's TLS setting for MLflow itself; otherwise MLflow
            # keeps verifying certificates and fails against self-signed/invalid certs.
            os.environ['MLFLOW_TRACKING_INSECURE_TLS'] = 'true'
        # TODO: set MLFLOW_TRACKING_USERNAME once an API exposes the server-side user id.

    def login(self, ttl: Optional[Union[str, int]] = '2h', verify_ssl_cert: bool = True):
        """
        Configure connection settings for model registry.

        :param ttl: Time to live of the generated token, DEFAULT: 2h
        :param verify_ssl_cert: verify TLS certificates of the registry, DEFAULT: True
        :raises NotImplementedError: if the models SDK extra is not installed
        """
        check_installed()
        self._setup_model_client(verify_ssl_cert=verify_ssl_cert, ttl=ttl)
        print("Configuring connection for model registry")

    def create_experiment(self,
                          name: str,
                          tags: Optional[Dict[str, Any]] = None,
                          ) -> str:
        """
        Create an MLFlow experiment with default tags.

        If ``tags`` has no ``project`` entry (or it is ``None``), the client's
        project is added. The caller's ``tags`` mapping is never modified.

        :param name: experiment name, must be unique
        :param tags: optional experiment tags
        :return: the id returned by ``mlflow.create_experiment``
        :raises NotImplementedError: if the models SDK extra is not installed
        """
        check_installed()
        # Work on a copy so the default project tag is not written back into the
        # dict the caller passed in.
        experiment_tags = dict(tags) if tags else {}
        # default to client project if project tag isn't specified
        if experiment_tags.get('project') is None:
            experiment_tags['project'] = self._project()
        return mlflow.create_experiment(name, tags=experiment_tags)