Skip to content

Commit b383b55

Browse files
committed
Model download endpoints
1 parent 2b18467 commit b383b55

15 files changed

Lines changed: 71 additions & 43 deletions

TensorStack.Python/Pipelines/ChromaPipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def initialize(config: DataObjects.PipelineConfig):
5858
#------------------------------------------------
5959
# Download Pipeline
6060
#------------------------------------------------
61-
def download(config: DataObjects.PipelineConfig):
62-
global _progress_tracker, _pipeline_config
63-
61+
def download(config_args: Dict[str, Any]):
62+
global _progress_tracker, _pipeline_config, _device_map
63+
64+
_device_map = "meta"
65+
config = DataObjects.PipelineConfig(**config_args)
6466
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
6567
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
6668
create_pipeline(config, True)

TensorStack.Python/Pipelines/CogVideoXPipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ def initialize(config: DataObjects.PipelineConfig):
6060
#------------------------------------------------
6161
# Download Pipeline
6262
#------------------------------------------------
63-
def download(config: DataObjects.PipelineConfig):
64-
global _progress_tracker, _pipeline_config
65-
63+
def download(config_args: Dict[str, Any]):
64+
global _progress_tracker, _pipeline_config, _device_map
65+
66+
_device_map = "meta"
67+
config = DataObjects.PipelineConfig(**config_args)
6668
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
6769
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
6870
create_pipeline(config, True)

TensorStack.Python/Pipelines/Flux2KleinPipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def initialize(config: DataObjects.PipelineConfig):
5858
#------------------------------------------------
5959
# Download Pipeline
6060
#------------------------------------------------
61-
def download(config: DataObjects.PipelineConfig):
62-
global _progress_tracker, _pipeline_config
63-
61+
def download(config_args: Dict[str, Any]):
62+
global _progress_tracker, _pipeline_config, _device_map
63+
64+
_device_map = "meta"
65+
config = DataObjects.PipelineConfig(**config_args)
6466
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
6567
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
6668
create_pipeline(config, True)

TensorStack.Python/Pipelines/Flux2Pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ def initialize(config: DataObjects.PipelineConfig):
5757
#------------------------------------------------
5858
# Download Pipeline
5959
#------------------------------------------------
60-
def download(config: DataObjects.PipelineConfig):
61-
global _progress_tracker, _pipeline_config
62-
60+
def download(config_args: Dict[str, Any]):
61+
global _progress_tracker, _pipeline_config, _device_map
62+
63+
_device_map = "meta"
64+
config = DataObjects.PipelineConfig(**config_args)
6365
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
6466
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
6567
create_pipeline(config, True)

TensorStack.Python/Pipelines/FluxPipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ def initialize(config: DataObjects.PipelineConfig):
6565
#------------------------------------------------
6666
# Download Pipeline
6767
#------------------------------------------------
68-
def download(config: DataObjects.PipelineConfig):
69-
global _progress_tracker, _pipeline_config
70-
68+
def download(config_args: Dict[str, Any]):
69+
global _progress_tracker, _pipeline_config, _device_map
70+
71+
_device_map = "meta"
72+
config = DataObjects.PipelineConfig(**config_args)
7173
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
7274
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
7375
create_pipeline(config, True)

TensorStack.Python/Pipelines/Kandinsky5Pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ def initialize(config: DataObjects.PipelineConfig):
6464
#------------------------------------------------
6565
# Download Pipeline
6666
#------------------------------------------------
67-
def download(config: DataObjects.PipelineConfig):
68-
global _progress_tracker, _pipeline_config
69-
67+
def download(config_args: Dict[str, Any]):
68+
global _progress_tracker, _pipeline_config, _device_map
69+
70+
_device_map = "meta"
71+
config = DataObjects.PipelineConfig(**config_args)
7072
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
7173
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
7274
create_pipeline(config, True)

TensorStack.Python/Pipelines/LTX2Pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ def initialize(config: DataObjects.PipelineConfig):
6262
#------------------------------------------------
6363
# Download Pipeline
6464
#------------------------------------------------
65-
def download(config: DataObjects.PipelineConfig):
66-
global _progress_tracker, _pipeline_config
67-
65+
def download(config_args: Dict[str, Any]):
66+
global _progress_tracker, _pipeline_config, _device_map
67+
68+
_device_map = "meta"
69+
config = DataObjects.PipelineConfig(**config_args)
6870
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
6971
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
7072
create_pipeline(config, True)

TensorStack.Python/Pipelines/LTXPipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def initialize(config: DataObjects.PipelineConfig):
5858
#------------------------------------------------
5959
# Download Pipeline
6060
#------------------------------------------------
61-
def download(config: DataObjects.PipelineConfig):
62-
global _progress_tracker, _pipeline_config
63-
61+
def download(config_args: Dict[str, Any]):
62+
global _progress_tracker, _pipeline_config, _device_map
63+
64+
_device_map = "meta"
65+
config = DataObjects.PipelineConfig(**config_args)
6466
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
6567
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
6668
create_pipeline(config, True)

TensorStack.Python/Pipelines/QwenImagePipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ def initialize(config: DataObjects.PipelineConfig):
6565
#------------------------------------------------
6666
# Download Pipeline
6767
#------------------------------------------------
68-
def download(config: DataObjects.PipelineConfig):
69-
global _progress_tracker, _pipeline_config
70-
68+
def download(config_args: Dict[str, Any]):
69+
global _progress_tracker, _pipeline_config, _device_map
70+
71+
_device_map = "meta"
72+
config = DataObjects.PipelineConfig(**config_args)
7173
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
7274
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
7375
create_pipeline(config, True)

TensorStack.Python/Pipelines/SkyReelsV2Pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ def initialize(config: DataObjects.PipelineConfig):
6060
#------------------------------------------------
6161
# Download Pipeline
6262
#------------------------------------------------
63-
def download(config: DataObjects.PipelineConfig):
64-
global _progress_tracker, _pipeline_config
65-
63+
def download(config_args: Dict[str, Any]):
64+
global _progress_tracker, _pipeline_config, _device_map
65+
66+
_device_map = "meta"
67+
config = DataObjects.PipelineConfig(**config_args)
6668
_progress_tracker = Utils.ModelDownloadProgress(total_models=get_model_count(config))
6769
_pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
6870
create_pipeline(config, True)

0 commit comments

Comments
 (0)