// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.

using System;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using TensorStack.Common;
using TensorStack.Common.Pipeline;
using TensorStack.Common.Tensor;
using TensorStack.TextGeneration.Cache;
using TensorStack.TextGeneration.Common;
using TensorStack.TextGeneration.Processing;
using TensorStack.TextGeneration.Tokenizers;

namespace TensorStack.TextGeneration.Pipelines.Qwen
{
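    /// <summary>
    /// Text generation pipeline for Qwen decoder models, supporting greedy search and beam search decoding.
    /// </summary>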
    public class QwenPipeline : DecoderPipeline<GenerateOptions>,
        IPipeline<GenerateResult, GenerateOptions, GenerateProgress>,
        IPipeline<GenerateResult[], SearchOptions, GenerateProgress>
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="QwenPipeline"/> class.
        /// </summary>
        /// <param name="configuration">The Qwen pipeline configuration (tokenizer and decoder settings).</param>
        public QwenPipeline(QwenConfig configuration)
            : base(configuration.Tokenizer, configuration.DecoderConfig)
        {
            Configuration = configuration;
        }

        public QwenConfig Configuration { get; }


        /// <summary>
        /// Runs greedy search inference.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The generated result.</returns>
        public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);
            var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
            using (sequence)
            {
                return new GenerateResult
                {
                    Score = sequence.Score,
                    Result = Tokenizer.Decode(sequence.Tokens),
                    Tokens = sequence.Tokens,
                    LastHiddenState = sequence.LastHiddenState
                };
            }
        }


        /// <summary>
        /// Runs beam search inference.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
        /// <returns>One generated result per beam.</returns>
        public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);

            var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
            var results = new GenerateResult[sequences.Length];
            for (int beam = 0; beam < sequences.Length; beam++)
            {
                var sequence = sequences[beam];
                using (sequence)
                {
                    results[beam] = new GenerateResult
                    {
                        Beam = beam,
                        Score = sequence.Score,
                        PenaltyScore = sequence.PenaltyScore,
                        Result = Tokenizer.Decode(sequence.Tokens),
                        Tokens = sequence.Tokens,
                        LastHiddenState = sequence.LastHiddenState
                    };
                }
            }
            return results;
        }


        /// <summary>
        /// Tokenizes the prompt.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>A Task representing the asynchronous operation.</returns>
        protected override async Task TokenizePromptAsync(GenerateOptions options)
        {
            var tokenizerResult = await Tokenizer.EncodeAsync(options.Prompt);

            // Pad the prompt to at least MinLength using the EOS token, and pad the mask with zeros to match.
            var inputIds = tokenizerResult.InputIds.Span.Pad(Tokenizer.EOS, options.MinLength);
            var mask = tokenizerResult.Mask.Span.Pad(0, options.MinLength);
            TokenizerOutput = new TokenizerResult(inputIds, mask);
        }


        /// <summary>
        /// Gets the token processors.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>ITokenProcessor[].</returns>
        protected override ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
        {
            return
            [
                new EOSTokenProcessor(options.MinLength, Tokenizer.EOS),
                new MaxLengthTokenProcessor(options.MaxLength)
            ];
        }


        /// <summary>
        /// Initializes the decoder KV cache.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
        protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
        {
            var modelMetadata = await Decoder.LoadAsync();
            var kvCache = new KVCacheDecoder(modelMetadata, DecoderConfig.NumHeads, DecoderConfig.NumLayers, DecoderConfig.HiddenSize, DecoderConfig.NumKVHeads, options.MaxLength);
            var sequence = new Sequence(kvCache, Tokenizer.BOS);
            sequence.Initialize(0);

            // Prefill: run the full prompt through the decoder once to populate the KV cache.
            var position = TokenizerOutput.Length;
            var inputIds = TokenizerOutput.InputIds;
            var positionIds = GetPositionIds(modelMetadata, 0, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, false);
            return sequence;
        }


        /// <summary>
        /// Runs the decoder model for a single generation step.
        /// </summary>
        /// <param name="sequence">The sequence.</param>
        /// <returns>A Task&lt;Tensor&lt;float&gt;&gt; representing the asynchronous operation.</returns>
        protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
        {
            var modelMetadata = await Decoder.LoadAsync();

            // Feed only the most recent token; previous positions are served from the KV cache.
            var position = TokenizerOutput.Length + sequence.Tokens.Count;
            var inputIds = new Tensor<long>([1, 1], sequence.Tokens[^1]);
            var positionIds = GetPositionIds(modelMetadata, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            return RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, true);
        }


        /// <summary>
        /// Runs the decoder.
        /// </summary>
        /// <param name="modelMetadata">The model metadata.</param>
        /// <param name="sequence">The sequence.</param>
        /// <param name="inputIds">The input ids.</param>
        /// <param name="positionIds">The position ids.</param>
        /// <param name="attentionMask">The attention mask.</param>
        /// <param name="useBranchCache">If set to <c>true</c>, update the branch cache.</param>
        /// <returns>The logits tensor.</returns>
        private Tensor<float> RunDecoderInternal(ModelMetadata modelMetadata, Sequence sequence, Tensor<long> inputIds, Tensor<long> positionIds, Tensor<long> attentionMask, bool useBranchCache)
        {
            using (var parameters = new ModelParameters(modelMetadata))
            {
                // Inputs
                parameters.AddInput(inputIds);
                parameters.AddInput(attentionMask);
                if (positionIds != null)
                    parameters.AddInput(positionIds);

                foreach (var pastKeyValue in sequence.Cache)
                    parameters.AddInput(pastKeyValue, false);

                // Outputs
                foreach (var output in modelMetadata.Outputs)
                    parameters.AddOutput();

                // Result
                var modelResult = Decoder.RunInference(parameters);
                using (var logitsResult = modelResult[0])
                {
                    var dimension = logitsResult.GetDimensions();
                    var logits = logitsResult.ToTensor(dimension[1..]);
                    var presentKeyValues = modelResult.ToArray()[1..];
                    sequence.UpdateCache(presentKeyValues, useBranchCache);
                    return logits;
                }
            }
        }


        /// <summary>
        /// Creates the QwenPipeline.
        /// </summary>
        /// <param name="provider">The execution provider.</param>
        /// <param name="modelPath">The model path.</param>
        /// <param name="model">The decoder model filename.</param>
        /// <returns>QwenPipeline.</returns>
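        /// <example>
        /// A minimal usage sketch, assuming a local copy of the onnx-community/Qwen2.5-0.5B export;
        /// the model path and the <c>provider</c> instance are placeholders supplied by the caller:
        /// <code>
        /// var pipeline = QwenPipeline.Create(provider, "models/Qwen2.5-0.5B");
        /// var options = new GenerateOptions { Prompt = "Hello", MaxLength = 128 };
        /// var result = await pipeline.RunAsync(options);
        /// Console.WriteLine(result.Result);
        /// </code>
        /// </example>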
        public static QwenPipeline Create(ExecutionProvider provider, string modelPath, string model = "model.onnx")
        {
            // Qwen-2.5 - https://huggingface.co/onnx-community/Qwen2.5-0.5B
            var numHeads = 14;
            var numLayers = 24;
            var hiddenSize = 896;
            var numKVHeads = 2;
            var vocabSize = 151936;
            var config = new QwenConfig
            {
                Tokenizer = new BPETokenizer(new TokenizerConfig
                {
                    BOS = 151643,
                    EOS = 151643,
                    Path = modelPath
                }),
                DecoderConfig = new DecoderConfig
                {
                    Path = Path.Combine(modelPath, model),
                    VocabSize = vocabSize,
                    NumHeads = numHeads,
                    NumLayers = numLayers,
                    HiddenSize = hiddenSize,
                    NumKVHeads = numKVHeads
                }
            };

            config.DecoderConfig.SetProvider(provider);
            return new QwenPipeline(config);
        }
    }
}