import os
import sys
import torch
import numpy as np
from threading import Event
from collections.abc import Buffer
from typing import Coroutine, Dict, Sequence, List, Tuple, Optional, Union
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline, CogVideoXVideoToVideoPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from tensorstack.utils import MemoryStdout, create_scheduler, getDataType, tensorFromInput, prepare_images, trim_memory

# Redirect stderr into an in-memory buffer so getLogs() can return its history.
sys.stderr = MemoryStdout()

# Module-level state shared by load() / generate() / unload().
_pipeline = None          # active diffusers pipeline, or None when nothing is loaded
_processType = None       # one of the _pipelineMap keys for the loaded pipeline
_step_latent = None       # latest per-step latents snapshot (numpy); see _progress_callback
_generator = None         # torch.Generator bound to the execution device
_isMemoryOffload = False  # True when CPU offloading is active (passed to trim_memory)
_cancel_event = Event()   # set by generateCancel() to abort an in-flight generate()
_pipelineMap = {
    "TextToVideo": CogVideoXPipeline,
    "ImageToVideo": CogVideoXImageToVideoPipeline,
    "VideoToVideo": CogVideoXVideoToVideoPipeline,
}
24+
def load(
    modelName: str,
    processType: str,
    controlNet: str = None,
    device: str = "cuda",
    deviceId: int = 0,
    dataType: str = "bfloat16",
    variant: str = None,
    cacheDir: str = None,
    secureToken: str = None,
    isModelOffloadEnabled: bool = False,
    isFullOffloadEnabled: bool = False,
    isVaeSlicingEnabled: bool = False,
    isVaeTilingEnabled: bool = False,
    loraAdapters: Optional[List[Tuple[str, str, str]]] = None
) -> bool:
    """Load a CogVideoX pipeline and prepare it for generation.

    Args:
        modelName: model id or local path passed to ``from_pretrained``.
        processType: "TextToVideo", "ImageToVideo" or "VideoToVideo"
            (a key of ``_pipelineMap``; any other value raises KeyError).
        controlNet: accepted for interface compatibility but currently unused.
        device / deviceId: execution device, e.g. "cuda" + 0 -> "cuda:0".
        dataType: torch dtype name resolved via ``getDataType``.
        variant / cacheDir / secureToken: forwarded to ``from_pretrained``.
        isModelOffloadEnabled / isFullOffloadEnabled: enable model-level or
            sequential (full) CPU offload; full offload takes precedence.
        isVaeSlicingEnabled / isVaeTilingEnabled: VAE memory optimizations.
        loraAdapters: optional (adapter_path, weight_name, adapter_name) triples.

    Returns:
        True on success; loading errors propagate to the caller.
    """
    global _pipeline, _generator, _processType, _isMemoryOffload

    # Clear any stale cancellation flag from a previous run.
    _reset()

    # Resolve dtype and pipeline class, then load the weights.
    dtype = getDataType(dataType)
    _processType = processType
    pipeline_cls = _pipelineMap[_processType]
    _pipeline = pipeline_cls.from_pretrained(
        modelName,
        torch_dtype=dtype,
        cache_dir=cacheDir,
        token=secureToken,
        variant=variant,
    )

    # LoRA adapters
    if loraAdapters is not None:
        for adapter_path, weight_name, adapter_name in loraAdapters:
            _pipeline.load_lora_weights(adapter_path, weight_name=weight_name, adapter_name=adapter_name)

    # Device placement.
    # BUG FIX: the original f-string contained literal spaces around the
    # colon ("cuda :0 "), which torch.device() rejects as an invalid string.
    execution_device = torch.device(f"{device}:{deviceId}")
    if isFullOffloadEnabled:
        _isMemoryOffload = True
        _pipeline.enable_sequential_cpu_offload(device=execution_device)
    elif isModelOffloadEnabled:
        _isMemoryOffload = True
        _pipeline.enable_model_cpu_offload(device=execution_device)
    else:
        _pipeline.to(execution_device)

    # VAE memory optimizations
    if isVaeSlicingEnabled:
        _pipeline.vae.enable_slicing()
    if isVaeTilingEnabled:
        _pipeline.vae.enable_tiling()

    # Seeded generator on the same device as the pipeline; generate() seeds it per call.
    _generator = torch.Generator(device=execution_device)
    return True
81+
82+
83+
def unload() -> bool:
    """Tear down the loaded pipeline and release its memory.

    Safe to call when nothing is loaded (or twice in a row): a missing
    pipeline is simply skipped. Always returns True.
    """
    global _pipeline

    # BUG FIX: the original dereferenced _pipeline unconditionally
    # (AttributeError when called before load) and then `del`-ed the global,
    # unbinding the name so later reads raised NameError instead of seeing None.
    if _pipeline is not None:
        _pipeline.remove_all_hooks()
        _pipeline.maybe_free_model_hooks()
        # Drop the heavyweight submodules explicitly before dropping the pipeline.
        for component in ("tokenizer", "text_encoder", "transformer", "vae"):
            if hasattr(_pipeline, component):
                delattr(_pipeline, component)
        _pipeline = None

    # Cleanup (empties CUDA caches etc.; behavior depends on offload mode).
    trim_memory(_isMemoryOffload)
    return True
101+
102+
def generateCancel() -> None:
    """Request cancellation of the in-flight generate() call.

    Sets the shared cancel event; _progress_callback checks it at every
    denoising step and aborts the pipeline when it is set.
    """
    _cancel_event.set()
105+
106+
def generate(
    prompt: str,
    negativePrompt: str,
    guidanceScale: float,
    guidanceScale2: float,
    steps: int,
    steps2: int,
    height: int,
    width: int,
    seed: int,
    scheduler: str,
    numFrames: int,
    shift: float,
    strength: float,
    controlScale: float,
    loraOptions: Optional[Dict[str, float]] = None,
    inputData: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
    controlNetData: Optional[List[Tuple[Sequence[float], Sequence[int]]]] = None,
) -> Buffer:
    """Run the loaded pipeline and return frames as a contiguous numpy buffer.

    Args:
        prompt / negativePrompt: text conditioning.
        guidanceScale, steps, numFrames, height, width, seed: standard
            diffusion parameters; the generator is re-seeded with ``seed``.
        scheduler: "ddim" or "ddpm" swaps in the matching CogVideoX scheduler;
            any other value keeps the current one.
        strength: denoise strength, forwarded only for VideoToVideo.
        guidanceScale2, steps2, shift, controlScale: accepted for interface
            compatibility but currently unused by this backend.
        loraOptions: adapter_name -> weight mapping applied via set_adapters.
        inputData / controlNetData: raw (data, shape) tuples decoded by
            prepare_images; controlNetData is decoded but currently unused.

    Returns:
        float32 numpy array shaped (frames, channels, height, width),
        C-contiguous.

    Raises:
        Exception("Operation Canceled") if generateCancel() fires mid-run.
    """
    # Clear any stale cancellation flag.
    _reset()

    # Optional scheduler override (mutually exclusive values -> elif).
    if scheduler == "ddim":
        _pipeline.scheduler = CogVideoXDDIMScheduler.from_config(_pipeline.scheduler.config)
    elif scheduler == "ddpm":
        _pipeline.scheduler = CogVideoXDPMScheduler.from_config(_pipeline.scheduler.config)

    # LoRA adapters
    if loraOptions is not None:
        names = list(loraOptions.keys())
        weights = list(loraOptions.values())
        _pipeline.set_adapters(names, adapter_weights=weights)

    # Input / control images
    image = prepare_images(inputData)
    control_image = prepare_images(controlNetData)

    # Pipeline options
    options = {
        "prompt": prompt,
        "negative_prompt": negativePrompt,
        "height": height,
        "width": width,
        "generator": _generator.manual_seed(seed),
        "guidance_scale": guidanceScale,
        "num_inference_steps": steps,
        "num_frames": numFrames,
        "num_videos_per_prompt": 1,
        "output_type": "np",
        "callback_on_step_end": _progress_callback,
        "callback_on_step_end_tensor_inputs": ["latents"],
    }
    if _processType == "ImageToVideo":
        options.update({"image": image, "use_dynamic_cfg": True})
    elif _processType == "VideoToVideo":
        # BUG FIX: VideoToVideo previously forwarded neither its source video
        # nor the strength, so the pipeline call failed for that process type.
        # Assumes prepare_images(inputData) output is accepted as the `video`
        # argument — TODO confirm against prepare_images' return type.
        options.update({"video": image, "strength": strength})

    # Run Pipeline; [0] extracts the frames array from the pipeline output.
    output = _pipeline(**options)[0]

    # (batch, frames, H, W, C) -> (frames, C, H, W), float32.
    output = output.transpose(0, 1, 4, 2, 3).squeeze(axis=0).astype(np.float32)

    # Cleanup
    trim_memory(_isMemoryOffload)
    return np.ascontiguousarray(output)
173+
174+
175+
def getLogs() -> list[str]:
    """Return the log lines captured by the MemoryStdout bound to sys.stderr."""
    log_sink = sys.stderr
    return log_sink.get_log_history()
178+
179+
def getStepLatent() -> Buffer:
    """Return the most recent per-step latents snapshot.

    None until _progress_callback has stored at least one snapshot.
    """
    return _step_latent
182+
183+
def _reset():
    """Clear the cancellation flag so the next operation starts uncancelled."""
    _cancel_event.clear()
186+
187+
188+ def _log (message : str ):
189+ sys .stderr .write (message )
190+
191+
def _progress_callback(pipe, step: int, timestep: int, info: Dict):
    """Per-step callback registered via ``callback_on_step_end``.

    FIX: diffusers invokes this positionally as
    ``callback(pipe, step_index, timestep, callback_kwargs)``; the third
    parameter was previously misnamed ``total_steps`` — it receives the
    scheduler timestep, not a step count. Renamed (positional call, so
    this is caller-safe).

    Side effects:
        - On cancellation, sets ``pipe._interrupt`` and raises
          Exception("Operation Canceled") to abort the run.
        - Otherwise snapshots the current latents into ``_step_latent``
          (contiguous float32 numpy on CPU) for getStepLatent().

    Returns ``info`` unchanged, as diffusers expects.
    """
    global _step_latent
    if _cancel_event.is_set():
        pipe._interrupt = True
        raise Exception("Operation Canceled")

    # Expose intermediate latents; "latents" is requested via
    # callback_on_step_end_tensor_inputs in generate().
    latents = info.get("latents")
    if latents is not None:
        _step_latent = np.ascontiguousarray(latents.float().cpu())

    return info