
Commit decafab

CogVideoX, LTXVideo pipelines
1 parent d16c46a commit decafab

4 files changed

Lines changed: 413 additions & 1 deletion


TensorStack.Python/Common/PipelineOptions.cs

Lines changed: 1 addition & 1 deletion
The change fixes the InputControlImage setter, which guarded on the wrong collection: it now checks InputControlImages rather than InputImages before adding the value.

@@ -49,7 +49,7 @@ public ImageTensor InputControlImage
         get { return InputControlImages.FirstOrDefault(); }
         set
         {
-            if (InputImages.Count == 0)
+            if (InputControlImages.Count == 0)
             {
                 InputControlImages.Add(value);
             }
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
import os
import sys
import torch
import numpy as np
from threading import Event
from collections.abc import Buffer
from typing import Coroutine, Dict, Sequence, List, Tuple, Optional, Union
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline, CogVideoXVideoToVideoPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from tensorstack.utils import MemoryStdout, create_scheduler, getDataType, tensorFromInput, prepare_images, trim_memory

sys.stderr = MemoryStdout()

# Globals
_pipeline = None
_processType = None
_step_latent = None
_generator = None
_isMemoryOffload = False
_cancel_event = Event()
_pipelineMap = {
    "TextToVideo": CogVideoXPipeline,
    "ImageToVideo": CogVideoXImageToVideoPipeline,
    "VideoToVideo": CogVideoXVideoToVideoPipeline,
}

def load(
    modelName: str,
    processType: str,
    controlNet: str = None,
    device: str = "cuda",
    deviceId: int = 0,
    dataType: str = "bfloat16",
    variant: str = None,
    cacheDir: str = None,
    secureToken: str = None,
    isModelOffloadEnabled: bool = False,
    isFullOffloadEnabled: bool = False,
    isVaeSlicingEnabled: bool = False,
    isVaeTilingEnabled: bool = False,
    loraAdapters: Optional[List[Tuple[str, str, str]]] = None
) -> bool:
    global _pipeline, _generator, _processType, _isMemoryOffload

    # Reset
    _reset()

    # Load Pipeline
    dtype = getDataType(dataType)
    _processType = processType
    pipeline = _pipelineMap[_processType]
    _pipeline = pipeline.from_pretrained(
        modelName,
        torch_dtype=dtype,
        cache_dir=cacheDir,
        token=secureToken,
        variant=variant
    )

    # Lora Adapters
    if loraAdapters is not None:
        for adapter_path, weight_name, adapter_name in loraAdapters:
            _pipeline.load_lora_weights(adapter_path, weight_name=weight_name, adapter_name=adapter_name)

    # Device
    execution_device = torch.device(f"{device}:{deviceId}")
    if isFullOffloadEnabled:
        _isMemoryOffload = True
        _pipeline.enable_sequential_cpu_offload(device=execution_device)
    elif isModelOffloadEnabled:
        _isMemoryOffload = True
        _pipeline.enable_model_cpu_offload(device=execution_device)
    else:
        _pipeline.to(execution_device)

    # Memory
    if isVaeSlicingEnabled:
        _pipeline.vae.enable_slicing()
    if isVaeTilingEnabled:
        _pipeline.vae.enable_tiling()
    _generator = torch.Generator(device=execution_device)
    return True
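
As a usage sketch (not part of the commit), a host might load the text-to-video variant as below; the cogvideox module alias and the model id are illustrative assumptions.

# Hypothetical host-side usage; "cogvideox" stands in for however this
# module is imported, and the model id is illustrative.
import cogvideox as cvx

cvx.load(
    modelName="THUDM/CogVideoX-2b",  # illustrative Hugging Face model id
    processType="TextToVideo",       # selects CogVideoXPipeline from _pipelineMap
    device="cuda",
    deviceId=0,
    dataType="bfloat16",
    isModelOffloadEnabled=True,      # enable_model_cpu_offload to reduce VRAM use
)
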
def unload() -> bool:
    global _pipeline
    _pipeline.remove_all_hooks()
    _pipeline.maybe_free_model_hooks()
    if hasattr(_pipeline, "tokenizer"):
        del _pipeline.tokenizer
    if hasattr(_pipeline, "text_encoder"):
        del _pipeline.text_encoder
    if hasattr(_pipeline, "transformer"):
        del _pipeline.transformer
    if hasattr(_pipeline, "vae"):
        del _pipeline.vae
    del _pipeline

    # Cleanup
    trim_memory(_isMemoryOffload)
    return True


def generateCancel() -> None:
    _cancel_event.set()

def generate(
    prompt: str,
    negativePrompt: str,
    guidanceScale: float,
    guidanceScale2: float,
    steps: int,
    steps2: int,
    height: int,
    width: int,
    seed: int,
    scheduler: str,
    numFrames: int,
    shift: float,
    strength: float,
    controlScale: float,
    loraOptions: Optional[Dict[str, float]] = None,
    inputData: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
    controlNetData: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
) -> Buffer:

    # Reset
    _reset()

    # Scheduler
    if scheduler == "ddim":
        _pipeline.scheduler = CogVideoXDDIMScheduler.from_config(_pipeline.scheduler.config)
    if scheduler == "ddpm":
        _pipeline.scheduler = CogVideoXDPMScheduler.from_config(_pipeline.scheduler.config)

    # Lora Adapters
    if loraOptions is not None:
        names = list(loraOptions.keys())
        weights = list(loraOptions.values())
        _pipeline.set_adapters(names, adapter_weights=weights)

    # Input Images
    image = prepare_images(inputData)
    control_image = prepare_images(controlNetData)

    # Pipeline Options
    options = {
        "prompt": prompt,
        "negative_prompt": negativePrompt,
        "height": height,
        "width": width,
        "generator": _generator.manual_seed(seed),
        "guidance_scale": guidanceScale,
        "num_inference_steps": steps,
        "num_frames": numFrames,
        "num_videos_per_prompt": 1,
        "output_type": "np",
        "callback_on_step_end": _progress_callback,
        "callback_on_step_end_tensor_inputs": ["latents"],
    }
    if _processType == "ImageToVideo":
        options.update({ "image": image, "use_dynamic_cfg": True })

    # Run Pipeline
    output = _pipeline(**options)[0]

    # (Frames, Channel, Height, Width)
    output = output.transpose(0, 1, 4, 2, 3).squeeze(axis=0).astype(np.float32)

    # Cleanup
    trim_memory(_isMemoryOffload)
    return np.ascontiguousarray(output)
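
A minimal call sketch, continuing the illustrative cvx alias from above. The returned buffer is a contiguous float32 array shaped (numFrames, 3, height, width); guidanceScale2, steps2, shift, strength and controlScale are accepted by the signature but are not forwarded to the CogVideoX pipeline options.

video = np.asarray(cvx.generate(
    prompt="a red panda walking through fresh snow",
    negativePrompt="",
    guidanceScale=6.0,
    guidanceScale2=0.0,  # accepted but unused on this path
    steps=50,
    steps2=0,            # accepted but unused on this path
    height=480,
    width=720,
    seed=42,
    scheduler="ddim",
    numFrames=49,
    shift=0.0,           # accepted but unused on this path
    strength=1.0,        # accepted but unused on this path
    controlScale=1.0,    # accepted but unused on this path
))
# video.shape == (49, 3, 480, 720), dtype float32
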
def getLogs() -> list[str]:
    return sys.stderr.get_log_history()


def getStepLatent() -> Buffer:
    return _step_latent


def _reset():
    _cancel_event.clear()


def _log(message: str):
    sys.stderr.write(message)


def _progress_callback(pipe, step: int, total_steps: int, info: Dict):
    global _step_latent
    if _cancel_event.is_set():
        pipe._interrupt = True
        raise Exception("Operation Canceled")

    latents = info.get("latents")
    if latents is not None:
        _step_latent = np.ascontiguousarray(latents.float().cpu())

    return info
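
Because generate() blocks until the pipeline finishes, cancellation and per-step latent previews presumably come from a second thread; a minimal sketch under that assumption, reusing the illustrative cvx alias.

# Hypothetical cancellation flow: run generate() on a worker thread and
# cancel it from elsewhere via generateCancel().
import threading
import time

def _run():
    try:
        cvx.generate(prompt="a red panda", negativePrompt="", guidanceScale=6.0,
                     guidanceScale2=0.0, steps=50, steps2=0, height=480, width=720,
                     seed=42, scheduler="ddim", numFrames=49, shift=0.0,
                     strength=1.0, controlScale=1.0)
    except Exception:
        pass  # _progress_callback raises "Operation Canceled" after a cancel

worker = threading.Thread(target=_run)
worker.start()
time.sleep(5.0)
preview = cvx.getStepLatent()  # latest per-step latents captured by the callback, or None
cvx.generateCancel()           # sets _cancel_event; the next step callback aborts
worker.join()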
