
Commit 19fffac

Merge pull request #14 from TensorStack-AI/FluxKlein
Flux2 Klein Pipeline
2 parents d9fc595 + 3d291e5 commit 19fffac

2 files changed

Lines changed: 392 additions & 0 deletions

Lines changed: 386 additions & 0 deletions
@@ -0,0 +1,386 @@
import sys
import tensorstack.utils as Utils
import tensorstack.data_objects as DataObjects
import tensorstack.quantization as Quantization
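# Capture stdout/stderr from here on; the host drains the buffer via getLogs().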
Utils.redirect_output()

import torch
import numpy as np
from threading import Event
from collections.abc import Buffer
from typing import Dict, Sequence, List, Tuple, Optional, Union, Any
from transformers import Qwen3ForCausalLM
from diffusers import (
    AutoencoderKLFlux2,
    Flux2Transformer2DModel,
    Flux2KleinPipeline
)

# Globals
_pipeline = None
_processType = None
_pipeline_config = None
_quant_config_diffusers = None
_quant_config_transformers = None
_execution_device = None
_device_map = None
_control_net_path = None
_control_net_cache = None
_step_latent = None
_generator = None
_isMemoryOffload = False
_prompt_cache_key = None
_prompt_cache_value = None
_progress_tracker: Utils.ModelDownloadProgress = None
_cancel_event = Event()
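# All process types currently map to the same Flux2KleinPipeline class.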
_pipelineMap = {
    "TextToImage": Flux2KleinPipeline,
    "ImageToImage": Flux2KleinPipeline,
    "ImageEdit": Flux2KleinPipeline,
}

#------------------------------------------------
# Initialize Pipeline
#------------------------------------------------
def initialize(config: DataObjects.PipelineConfig):
    global _progress_tracker, _pipeline_config, _quant_config_diffusers, _quant_config_transformers, _device_map

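    # Progress is tracked across text_encoder, transformer, and vae,
    # plus a fourth slot when a ControlNet path is configured.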
    _progress_tracker = Utils.ModelDownloadProgress(total_models=4 if config.control_net_path is not None else 3)
    _pipeline_config = Utils.get_pipeline_config(config.base_model_path, config.cache_directory, config.secure_token)
    _quant_config_diffusers, _quant_config_transformers = Quantization.get_quantize_model_config(config.data_type, config.quant_data_type, config.memory_mode)
    _device_map = Utils.get_device_map(config)
    return create_pipeline(config)


#------------------------------------------------
# Load Pipeline
#------------------------------------------------
def load(config_args: Dict[str, Any]) -> bool:
    global _pipeline, _generator, _processType, _execution_device, _isMemoryOffload

    # Config
    config = DataObjects.PipelineConfig(**config_args)
    _processType = config.process_type

    # Initialize Pipeline
    _pipeline = initialize(config)

    # Load Lora
    Utils.load_lora_weights(_pipeline, config)

    # Memory
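    # configure_pipeline_memory places the pipeline on the execution device and,
    # judging by the flag it returns, reports whether memory offload is active.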
    _execution_device = torch.device(f"{config.device}:{config.device_id}")
    _generator = torch.Generator(device=_execution_device)
    _isMemoryOffload = Utils.configure_pipeline_memory(_pipeline, _execution_device, config)
    Utils.trim_memory(_isMemoryOffload)
    return True


#------------------------------------------------
# Reload Pipeline - ProcessType, LoraAdapters, and ControlNet are the only options that can be modified
#------------------------------------------------
def reload(config_args: Dict[str, Any]) -> bool:
    global _pipeline, _processType

    # Config
    config = DataObjects.PipelineConfig(**config_args)
    _processType = config.process_type
    _progress_tracker.Reset(total_models=4 if config.control_net_path is not None else 3)

    # Rebuild Pipeline
    _pipeline.unload_lora_weights()
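    # create_pipeline() is cheap here: while _pipeline is still set, the
    # load_* helpers below return the already-loaded components instead of
    # reloading them from disk.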
    _pipeline = create_pipeline(config)

    # Load Lora
    Utils.load_lora_weights(_pipeline, config)

    # Memory
    Utils.configure_pipeline_memory(_pipeline, _execution_device, config)
    Utils.trim_memory(_isMemoryOffload)
    return True


#------------------------------------------------
# Cancel Generation
#------------------------------------------------
def generateCancel() -> None:
    _cancel_event.set()


#------------------------------------------------
# Unload Pipeline
#------------------------------------------------
def unload() -> bool:
    global _pipeline, _prompt_cache_key, _prompt_cache_value
    _pipeline = None
    _prompt_cache_key = None
    _prompt_cache_value = None
    Utils.trim_memory(_isMemoryOffload)
    return True


#------------------------------------------------
# Get the log entries
#------------------------------------------------
def getLogs() -> list[str]:
    return Utils.get_output()


#------------------------------------------------
# Get the last step latent
#------------------------------------------------
def getStepLatent() -> Buffer:
    return _step_latent


#------------------------------------------------
# Diffusers pipeline callback to capture step artifacts
#------------------------------------------------
def _progress_callback(pipe, step: int, timestep: int, info: Dict):
    global _step_latent
    if _cancel_event.is_set():
        pipe._interrupt = True
        raise Exception("Operation Canceled")

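    # Keep a host-readable copy of the current latents so intermediate steps
    # can be previewed via getStepLatent().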
    latents = info.get("latents")
    if latents is not None:
        _step_latent = np.ascontiguousarray(latents.float().cpu())

    return info


#------------------------------------------------
# Generate Image
#------------------------------------------------
def generate(
    inference_args: Dict[str, Any],
    input_tensors: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
    control_tensors: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
) -> Sequence[Buffer]:
    global _prompt_cache_key, _prompt_cache_value
    _cancel_event.clear()
    _pipeline._interrupt = False

    # Options
    options = DataObjects.PipelineOptions(**inference_args)

    # Scheduler
    _pipeline.scheduler = Utils.create_scheduler(options.scheduler, options.scheduler_options, _pipeline.scheduler.config)

    # Lora Adapters
    Utils.set_lora_weights(_pipeline, options)

    # Input Images
    image = Utils.prepare_images(input_tensors)
    control_image = Utils.prepare_images(control_tensors)
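
    # NOTE: control_image is prepared but not passed to the pipeline until the
    # ControlNet path (commented out below) is re-enabled.

    # Negative-prompt embeddings are only encoded when classifier-free guidance
    # is active, so the cache key tracks guidance_scale > 1.0 as well.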
    # Prompt Cache
    prompt_cache_key = (options.prompt, options.negative_prompt, options.guidance_scale > 1.0)
    if _prompt_cache_key != prompt_cache_key:
        with torch.no_grad():
            neg_prompt_embeds = None
            prompt_embeds, text_ids = _pipeline.encode_prompt(
                prompt=options.prompt,
                num_images_per_prompt=1,
                device=_pipeline._execution_device,
            )
            if options.guidance_scale > 1.0:
                neg_prompt_embeds, neg_text_ids = _pipeline.encode_prompt(
                    prompt=options.negative_prompt,
                    num_images_per_prompt=1,
                    device=_pipeline._execution_device,
                )

        _prompt_cache_value = (prompt_embeds, neg_prompt_embeds)
        _prompt_cache_key = prompt_cache_key

    # Pipeline Options
    (prompt_embeds, negative_prompt_embeds) = _prompt_cache_value
    pipeline_options = {
        "image": image,
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": options.height,
        "width": options.width,
        "generator": _generator.manual_seed(options.seed),
        "guidance_scale": options.guidance_scale,
        "num_inference_steps": options.steps,
        "output_type": "np",
        "callback_on_step_end": _progress_callback,
        "callback_on_step_end_tensor_inputs": ["latents"],
    }

    # Run Pipeline
    output = _pipeline(**pipeline_options)[0]

    # Reorder NHWC numpy output to (Batch, Channel, Height, Width)
    output = output.transpose(0, 3, 1, 2).astype(np.float32)

    # Cleanup
    Utils.trim_memory(_isMemoryOffload)
    return [np.ascontiguousarray(output)]


#------------------------------------------------
# Create a new pipeline
#------------------------------------------------
def create_pipeline(config: DataObjects.PipelineConfig):
    pipeline_kwargs = {
        "variant": config.variant,
        "token": config.secure_token,
        "cache_dir": config.cache_directory
    }

    # Load Models
    text_encoder = load_text_encoder(config, pipeline_kwargs)
    transformer = load_transformer(config, pipeline_kwargs)
    vae = load_vae(config, pipeline_kwargs)
    # control_net = load_control_net(config, pipeline_kwargs)
    # if control_net is not None:
    #     pipeline_kwargs.update({"controlnet": control_net})
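
    # The heavy components are loaded individually above so quantization and
    # single-file checkpoints can be applied per model; from_pretrained then
    # fills in the remaining parts (tokenizer, scheduler) from the base path.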
    # Build Pipeline
    _progress_tracker.Clear()
    pipeline = _pipelineMap[config.process_type]
    return pipeline.from_pretrained(
        config.base_model_path,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        torch_dtype=config.data_type,
        device_map=_device_map,
        local_files_only=True,
        **pipeline_kwargs
    )


#------------------------------------------------
# Load Qwen3ForCausalLM
#------------------------------------------------
def load_text_encoder(
    config: DataObjects.PipelineConfig,
    pipeline_kwargs: Dict[str, str]
):

    if _pipeline and _pipeline.text_encoder:
        print("[Reload] Loading cached TextEncoder")
        return _pipeline.text_encoder

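    # Prefer an explicit single-file checkpoint when one is configured;
    # otherwise load the text_encoder subfolder from the base model path.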
    _progress_tracker.Initialize(0, "text_encoder")
    checkpoint_config = config.checkpoint_config
    if checkpoint_config.text_encoder_checkpoint is not None:
        text_encoder = Qwen3ForCausalLM.from_single_file(
            checkpoint_config.text_encoder_checkpoint,
            config=_pipeline_config["text_encoder"],
            torch_dtype=config.data_type,
            use_safetensors=True,
            local_files_only=True
        )
        Quantization.quantize_model(config, text_encoder)
        return text_encoder

    return Qwen3ForCausalLM.from_pretrained(
        config.base_model_path,
        subfolder="text_encoder",
        torch_dtype=config.data_type,
        quantization_config=_quant_config_transformers,
        use_safetensors=True,
        **pipeline_kwargs
    )


#------------------------------------------------
# Load Flux2Transformer2DModel
#------------------------------------------------
def load_transformer(
    config: DataObjects.PipelineConfig,
    pipeline_kwargs: Dict[str, str]
):

    if _pipeline and _pipeline.transformer:
        print("[Reload] Loading cached Transformer")
        return _pipeline.transformer

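    # Single-file loads take a quantization config directly and are then passed
    # through Quantization.quantize_model; repo loads use the prebuilt diffusers
    # quantization config instead.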
    _progress_tracker.Initialize(1, "transformer")
    checkpoint_config = config.checkpoint_config
    if checkpoint_config.model_checkpoint is not None:
        transformer = Flux2Transformer2DModel.from_single_file(
            checkpoint_config.model_checkpoint,
            config=_pipeline_config["transformer"],
            torch_dtype=config.data_type,
            use_safetensors=True,
            local_files_only=True,
            quantization_config=Quantization.get_single_file_config(config)
        )
        Quantization.quantize_model(config, transformer)
        return transformer

    return Flux2Transformer2DModel.from_pretrained(
        config.base_model_path,
        subfolder="transformer",
        torch_dtype=config.data_type,
        quantization_config=_quant_config_diffusers,
        use_safetensors=True,
        **pipeline_kwargs
    )


#------------------------------------------------
# Load AutoencoderKLFlux2
#------------------------------------------------
def load_vae(
    config: DataObjects.PipelineConfig,
    pipeline_kwargs: Dict[str, str]
):

    if _pipeline and _pipeline.vae:
        print("[Reload] Loading cached Vae")
        return _pipeline.vae

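    # Note: unlike the text encoder and transformer, the VAE is loaded without
    # any quantization config.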
    _progress_tracker.Initialize(2, "vae")
    checkpoint_config = config.checkpoint_config
    if checkpoint_config.vae_checkpoint is not None:
        return AutoencoderKLFlux2.from_single_file(
            checkpoint_config.vae_checkpoint,
            config=_pipeline_config["vae"],
            torch_dtype=config.data_type,
            use_safetensors=True,
            local_files_only=True
        )

    return AutoencoderKLFlux2.from_pretrained(
        config.base_model_path,
        subfolder="vae",
        torch_dtype=config.data_type,
        use_safetensors=True,
        **pipeline_kwargs
    )


# #------------------------------------------------
# # Load ControlNetModel
# #------------------------------------------------
# def load_control_net(
#     config: DataObjects.PipelineConfig,
#     pipeline_kwargs: Dict[str, str]
# ):
#     global _control_net_path, _control_net_cache
#
#     if _control_net_cache and _control_net_path == config.control_net_path:
#         print("[Reload] Loading cached ControlNet")
#         return _control_net_cache
#
#     if config.control_net_path is None:
#         _control_net_path = None
#         _control_net_cache = None
#         return None
#
#     _control_net_path = config.control_net_path
#     _progress_tracker.Initialize(3, "control_net")
#     _control_net_cache = ControlNetModel.from_pretrained(
#         _control_net_path,
#         torch_dtype=config.data_type,
#         use_safetensors=True,
#     )
#     return _control_net_cache
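
# ------------------------------------------------
# Illustrative usage sketch (not part of the commit). The config and option
# keys below are inferred from how PipelineConfig and PipelineOptions are
# consumed above; actual field names, values, and model paths are assumptions.
# ------------------------------------------------
# load({
#     "process_type": "TextToImage",
#     "base_model_path": "<flux2-klein model path or repo id>",  # hypothetical
#     "device": "cuda",
#     "device_id": 0,
#     # ... data_type, quant_data_type, memory_mode, lora/checkpoint settings
# })
# images = generate({
#     "prompt": "a lighthouse at dawn",  # hypothetical
#     "negative_prompt": "",
#     "guidance_scale": 4.0,
#     "height": 1024,
#     "width": 1024,
#     "steps": 28,
#     "seed": 42,
#     # ... scheduler and scheduler_options
# })  # returns a list of contiguous float32 NCHW buffers
# unload()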
