import sys
import tensorstack.utils as Utils
import tensorstack.data_objects as DataObjects
import tensorstack.quantization as Quantization
Utils.redirect_output()

import torch
import numpy as np
from threading import Event
from collections.abc import Buffer
from typing import Dict, Sequence, List, Tuple, Optional, Union, Any
from transformers import Qwen3ForCausalLM
from diffusers import (
    AutoencoderKLFlux2,
    Flux2Transformer2DModel,
    Flux2KleinPipeline
)

# Globals
_pipeline = None
_processType = None
_pipeline_config = None
_quant_config_diffusers = None
_quant_config_transformers = None
_execution_device = None
_device_map = None
_control_net_path = None
_control_net_cache = None
_step_latent = None
_generator = None
_isMemoryOffload = False
_prompt_cache_key = None
_prompt_cache_value = None
_progress_tracker: Optional[Utils.ModelDownloadProgress] = None
_cancel_event = Event()
_pipelineMap = {
    "TextToImage": Flux2KleinPipeline,
    "ImageToImage": Flux2KleinPipeline,
    "ImageEdit": Flux2KleinPipeline,
}
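
# Note: all three process types map to Flux2KleinPipeline; the effective mode is
# presumably selected by which inputs (image / control tensors) are supplied at
# generate() time rather than by the pipeline class itself.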

#------------------------------------------------
# Initialize Pipeline
#------------------------------------------------
def initialize(config: DataObjects.PipelineConfig):
    global _progress_tracker, _pipeline_config, _quant_config_diffusers, _quant_config_transformers, _device_map

    _progress_tracker = Utils.ModelDownloadProgress(total_models=4 if config.control_net_path is not None else 3)
    _pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
    _quant_config_diffusers, _quant_config_transformers = Quantization.get_quantize_model_config(config.data_type, config.quant_data_type, config.memory_mode)
    _device_map = Utils.get_device_map(config)
    return create_pipeline(config)


#------------------------------------------------
# Load Pipeline
#------------------------------------------------
def load(config_args: Dict[str, Any]) -> bool:
    global _pipeline, _generator, _processType, _execution_device, _isMemoryOffload

    # Config
    config = DataObjects.PipelineConfig(**config_args)
    _processType = config.process_type

    # Initialize Pipeline
    _pipeline = initialize(config)

    # Load LoRA
    Utils.load_lora_weights(_pipeline, config)

    # Memory
    _execution_device = torch.device(f"{config.device}:{config.device_id}")
    _generator = torch.Generator(device=_execution_device)
    _isMemoryOffload = Utils.configure_pipeline_memory(_pipeline, _execution_device, config)
    Utils.trim_memory(_isMemoryOffload)
    return True


#------------------------------------------------
# Reload Pipeline - ProcessType, LoraAdapters and ControlNet are the only options that can be modified
#------------------------------------------------
def reload(config_args: Dict[str, Any]) -> bool:
    global _pipeline, _processType

    # Config
    config = DataObjects.PipelineConfig(**config_args)
    _processType = config.process_type
    _progress_tracker.Reset(total_models=4 if config.control_net_path is not None else 3)

    # Rebuild Pipeline
    _pipeline.unload_lora_weights()
    _pipeline = create_pipeline(config)

    # Load LoRA
    Utils.load_lora_weights(_pipeline, config)

    # Memory
    Utils.configure_pipeline_memory(_pipeline, _execution_device, config)
    Utils.trim_memory(_isMemoryOffload)
    return True


#------------------------------------------------
# Cancel Generation
#------------------------------------------------
def generateCancel() -> None:
    _cancel_event.set()
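
# Cancellation is cooperative: _cancel_event is polled in _progress_callback at
# each denoising step, so a cancel takes effect at the next step boundary rather
# than immediately.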


#------------------------------------------------
# Unload Pipeline
#------------------------------------------------
def unload() -> bool:
    global _pipeline, _prompt_cache_key, _prompt_cache_value
    _pipeline = None
    _prompt_cache_key = None
    _prompt_cache_value = None
    Utils.trim_memory(_isMemoryOffload)
    return True


#------------------------------------------------
# Get the log entries
#------------------------------------------------
def getLogs() -> list[str]:
    return Utils.get_output()


#------------------------------------------------
# Get the last step latent
#------------------------------------------------
def getStepLatent() -> Buffer:
    return _step_latent


#------------------------------------------------
# Diffusers pipeline callback to capture step artifacts
#------------------------------------------------
def _progress_callback(pipe, step: int, timestep: int, info: Dict):
    global _step_latent
    if _cancel_event.is_set():
        pipe._interrupt = True
        raise Exception("Operation Canceled")

    latents = info.get("latents")
    if latents is not None:
        # Keep a contiguous float32 CPU copy of the current latents
        _step_latent = np.ascontiguousarray(latents.float().cpu())

    # callback_on_step_end must return the callback_kwargs dict
    return info
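
# _step_latent always holds the latest step's latents as a CPU array, so the
# host can poll getStepLatent() between steps to render incremental previews.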


#------------------------------------------------
# Generate Image
#------------------------------------------------
def generate(
    inference_args: Dict[str, Any],
    input_tensors: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
    control_tensors: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
) -> Sequence[Buffer]:
    global _prompt_cache_key, _prompt_cache_value
    _cancel_event.clear()
    _pipeline._interrupt = False

    # Options
    options = DataObjects.PipelineOptions(**inference_args)

    # Scheduler
    _pipeline.scheduler = Utils.create_scheduler(options.scheduler, options.scheduler_options, _pipeline.scheduler.config)

    # LoRA Adapters
    Utils.set_lora_weights(_pipeline, options)

    # Input Images (control_image is currently unused; the ControlNet integration is commented out)
    image = Utils.prepare_images(input_tensors)
    control_image = Utils.prepare_images(control_tensors)

    # Prompt Cache - skip re-encoding when the prompts and the CFG flag are unchanged
    prompt_cache_key = (options.prompt, options.negative_prompt, options.guidance_scale > 1.0)
    if _prompt_cache_key != prompt_cache_key:
        with torch.no_grad():
            neg_prompt_embeds = None
            prompt_embeds, text_ids = _pipeline.encode_prompt(
                prompt=options.prompt,
                num_images_per_prompt=1,
                device=_pipeline._execution_device,
            )
            if options.guidance_scale > 1.0:
                neg_prompt_embeds, neg_text_ids = _pipeline.encode_prompt(
                    prompt=options.negative_prompt,
                    num_images_per_prompt=1,
                    device=_pipeline._execution_device,
                )

        _prompt_cache_value = (prompt_embeds, neg_prompt_embeds)
        _prompt_cache_key = prompt_cache_key
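
    # Note: the CFG flag (guidance_scale > 1.0) is part of the cache key because
    # negative embeddings are only encoded when CFG is active, so toggling it
    # must invalidate the cached embeddings.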

    # Pipeline Options
    (prompt_embeds, negative_prompt_embeds) = _prompt_cache_value
    pipeline_options = {
        "image": image,
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": options.height,
        "width": options.width,
        "generator": _generator.manual_seed(options.seed),
        "guidance_scale": options.guidance_scale,
        "num_inference_steps": options.steps,
        "output_type": "np",
        "callback_on_step_end": _progress_callback,
        "callback_on_step_end_tensor_inputs": ["latents"],
    }

    # Run Pipeline
    output = _pipeline(**pipeline_options)[0]

    # (Batch, Height, Width, Channel) -> (Batch, Channel, Height, Width)
    output = output.transpose(0, 3, 1, 2).astype(np.float32)

    # Cleanup
    Utils.trim_memory(_isMemoryOffload)
    return [np.ascontiguousarray(output)]
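
# The result is a contiguous NCHW float32 array, presumably so the host can
# consume it zero-copy through the Buffer protocol.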


#------------------------------------------------
# Create a new pipeline
#------------------------------------------------
def create_pipeline(config: DataObjects.PipelineConfig):
    pipeline_kwargs = {
        "variant": config.variant,
        "token": config.secure_token,
        "cache_dir": config.cache_directory
    }

    # Load Models
    text_encoder = load_text_encoder(config, pipeline_kwargs)
    transformer = load_transformer(config, pipeline_kwargs)
    vae = load_vae(config, pipeline_kwargs)
    # control_net = load_control_net(config, pipeline_kwargs)
    # if control_net is not None:
    #     pipeline_kwargs.update({"controlnet": control_net})

    # Build Pipeline
    _progress_tracker.Clear()
    pipeline = _pipelineMap[config.process_type]
    return pipeline.from_pretrained(
        config.base_model_path,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        torch_dtype=config.data_type,
        device_map=_device_map,
        local_files_only=True,
        **pipeline_kwargs
    )
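
# On reload(), the load_* helpers below short-circuit and return the components
# already attached to the live pipeline, so only the pipeline wrapper itself is
# rebuilt.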


#------------------------------------------------
# Load Qwen3ForCausalLM
#------------------------------------------------
def load_text_encoder(
    config: DataObjects.PipelineConfig,
    pipeline_kwargs: Dict[str, str]
):
    if _pipeline and _pipeline.text_encoder:
        print("[Reload] Loading cached TextEncoder")
        return _pipeline.text_encoder

    _progress_tracker.Initialize(0, "text_encoder")
    checkpoint_config = config.checkpoint_config
    if checkpoint_config.text_encoder_checkpoint is not None:
        text_encoder = Qwen3ForCausalLM.from_single_file(
            checkpoint_config.text_encoder_checkpoint,
            config=_pipeline_config["text_encoder"],
            torch_dtype=config.data_type,
            use_safetensors=True,
            local_files_only=True
        )
        Quantization.quantize_model(config, text_encoder)
        return text_encoder

    return Qwen3ForCausalLM.from_pretrained(
        config.base_model_path,
        subfolder="text_encoder",
        torch_dtype=config.data_type,
        quantization_config=_quant_config_transformers,
        use_safetensors=True,
        **pipeline_kwargs
    )
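
# Quantization differs by load path: from_pretrained applies a quantization
# config during load, while the single-file path loads full weights first and
# then quantizes in place via Quantization.quantize_model.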


#------------------------------------------------
# Load Flux2Transformer2DModel
#------------------------------------------------
def load_transformer(
    config: DataObjects.PipelineConfig,
    pipeline_kwargs: Dict[str, str]
):
    if _pipeline and _pipeline.transformer:
        print("[Reload] Loading cached Transformer")
        return _pipeline.transformer

    _progress_tracker.Initialize(1, "transformer")
    checkpoint_config = config.checkpoint_config
    if checkpoint_config.model_checkpoint is not None:
        transformer = Flux2Transformer2DModel.from_single_file(
            checkpoint_config.model_checkpoint,
            config=_pipeline_config["transformer"],
            torch_dtype=config.data_type,
            use_safetensors=True,
            local_files_only=True,
            quantization_config=Quantization.get_single_file_config(config)
        )
        Quantization.quantize_model(config, transformer)
        return transformer

    return Flux2Transformer2DModel.from_pretrained(
        config.base_model_path,
        subfolder="transformer",
        torch_dtype=config.data_type,
        quantization_config=_quant_config_diffusers,
        use_safetensors=True,
        **pipeline_kwargs
    )


#------------------------------------------------
# Load AutoencoderKLFlux2
#------------------------------------------------
def load_vae(
    config: DataObjects.PipelineConfig,
    pipeline_kwargs: Dict[str, str]
):
    if _pipeline and _pipeline.vae:
        print("[Reload] Loading cached VAE")
        return _pipeline.vae

    _progress_tracker.Initialize(2, "vae")
    checkpoint_config = config.checkpoint_config
    if checkpoint_config.vae_checkpoint is not None:
        return AutoencoderKLFlux2.from_single_file(
            checkpoint_config.vae_checkpoint,
            config=_pipeline_config["vae"],
            torch_dtype=config.data_type,
            use_safetensors=True,
            local_files_only=True
        )

    return AutoencoderKLFlux2.from_pretrained(
        config.base_model_path,
        subfolder="vae",
        torch_dtype=config.data_type,
        use_safetensors=True,
        **pipeline_kwargs
    )
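
# Unlike the text encoder and transformer, the VAE is loaded without a
# quantization config, presumably to preserve decode fidelity.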


# #------------------------------------------------
# # Load ControlNetModel
# #------------------------------------------------
# def load_control_net(
#     config: DataObjects.PipelineConfig,
#     pipeline_kwargs: Dict[str, str]
# ):
#     global _control_net_path, _control_net_cache

#     if _control_net_cache and _control_net_path == config.control_net_path:
#         print(f"[Reload] Loading cached ControlNet")
#         return _control_net_cache

#     if config.control_net_path is None:
#         _control_net_path = None
#         _control_net_cache = None
#         return None

#     _control_net_path = config.control_net_path
#     _progress_tracker.Initialize(3, "control_net")
#     _control_net_cache = ControlNetModel.from_pretrained(
#         _control_net_path,
#         torch_dtype=config.data_type,
#         use_safetensors=True,
#     )
#     return _control_net_cache