Skip to content

Commit 1300102

Browse files
committed
QwenPipeline
1 parent d042602 commit 1300102

2 files changed

Lines changed: 245 additions & 0 deletions

File tree

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using TensorStack.TextGeneration.Common;
2+
3+
namespace TensorStack.TextGeneration.Pipelines.Qwen
{
    /// <summary>
    /// Configuration for the Qwen text-generation pipeline.
    /// Inherits all tokenizer/decoder settings from <see cref="TransformerConfig"/>;
    /// currently adds no Qwen-specific options.
    /// </summary>
    public record QwenConfig : TransformerConfig
    {
    }
}
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
4+
using System;
5+
using System.IO;
6+
using System.Linq;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
using TensorStack.Common;
10+
using TensorStack.Common.Pipeline;
11+
using TensorStack.Common.Tensor;
12+
using TensorStack.TextGeneration.Cache;
13+
using TensorStack.TextGeneration.Common;
14+
using TensorStack.TextGeneration.Processing;
15+
using TensorStack.TextGeneration.Tokenizers;
16+
17+
namespace TensorStack.TextGeneration.Pipelines.Qwen
{
    /// <summary>
    /// Text-generation pipeline for Qwen decoder-only transformer models.
    /// Supports greedy search (single result) and beam search (multiple ranked results).
    /// </summary>
    public class QwenPipeline : DecoderPipeline<GenerateOptions>,
        IPipeline<GenerateResult, GenerateOptions, GenerateProgress>,
        IPipeline<GenerateResult[], SearchOptions, GenerateProgress>
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="QwenPipeline"/> class.
        /// </summary>
        /// <param name="configuration">The pipeline configuration supplying the tokenizer and decoder settings.</param>
        public QwenPipeline(QwenConfig configuration)
            : base(configuration.Tokenizer, configuration.DecoderConfig)
        {
            Configuration = configuration;
        }

        /// <summary>
        /// Gets the pipeline configuration.
        /// </summary>
        public QwenConfig Configuration { get; }


        /// <summary>
        /// Runs the GreedySearch inference
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The single generated result with its score, decoded text and tokens.</returns>
        public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);
            var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
            using (sequence)
            {
                return new GenerateResult
                {
                    Score = sequence.Score,
                    Result = Tokenizer.Decode(sequence.Tokens),
                    Tokens = sequence.Tokens,
                    LastHiddenState = sequence.LastHiddenState
                };
            }
        }


        /// <summary>
        /// Runs the BeamSearch inference
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
        /// <returns>One result per beam, each carrying its beam index, score and decoded text.</returns>
        public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);

            var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
            var results = new GenerateResult[sequences.Length];
            for (int beam = 0; beam < sequences.Length; beam++)
            {
                var sequence = sequences[beam];
                using (sequence)
                {
                    results[beam] = new GenerateResult
                    {
                        Beam = beam,
                        Score = sequence.Score,
                        PenaltyScore = sequence.PenaltyScore,
                        Result = Tokenizer.Decode(sequence.Tokens),
                        Tokens = sequence.Tokens,
                        LastHiddenState = sequence.LastHiddenState
                    };
                }
            }
            return results;
        }


        /// <summary>
        /// Tokenize the prompt
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>A Task representing the asynchronous operation.</returns>
        protected override async Task TokenizePromptAsync(GenerateOptions options)
        {
            var tokenizerResult = await Tokenizer.EncodeAsync(options.Prompt);
            // Pad ids with EOS and the mask with 0 up to MinLength so short prompts
            // still meet the minimum sequence length expected downstream.
            var inputIds = tokenizerResult.InputIds.Span.Pad(Tokenizer.EOS, options.MinLength);
            var mask = tokenizerResult.Mask.Span.Pad(0, options.MinLength);
            TokenizerOutput = new TokenizerResult(inputIds, mask);
        }


        /// <summary>
        /// Gets the token processors.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>ITokenProcessor[].</returns>
        protected override ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
        {
            return
            [
                new EOSTokenProcessor(options.MinLength, Tokenizer.EOS),
                new MaxLengthTokenProcessor(options.MaxLength)
            ];
        }


        /// <summary>
        /// Initialize the Decoder cache
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
        protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
        {
            var modelMetadata = await Decoder.LoadAsync();
            var kvCache = new KVCacheDecoder(modelMetadata, DecoderConfig.NumHeads, DecoderConfig.NumLayers, DecoderConfig.HiddenSize, DecoderConfig.NumKVHeads, options.MaxLength);
            var sequence = new Sequence(kvCache, Tokenizer.BOS);
            sequence.Initialize(0);

            var position = TokenizerOutput.Length;
            var inputIds = TokenizerOutput.InputIds;
            var positionIds = GetPositionIds(modelMetadata, 0, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            // Prefill pass to populate the KV cache; the returned logits are not needed here.
            RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, false);
            return sequence;
        }


        /// <summary>
        /// Run decoder model
        /// </summary>
        /// <param name="sequence">The sequence.</param>
        /// <returns>A Task&lt;Tensor`1&gt; representing the asynchronous operation.</returns>
        protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
        {
            var modelMetadata = await Decoder.LoadAsync();
            var position = TokenizerOutput.Length + sequence.Tokens.Count;
            // Single-token step: feed only the most recent token; past context comes from the KV cache.
            var inputIds = new Tensor<long>([1, 1], sequence.Tokens[^1]);
            var positionIds = GetPositionIds(modelMetadata, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            return RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, true);
        }


        /// <summary>
        /// Runs the decoder
        /// </summary>
        /// <param name="modelMetadata">The model metadata.</param>
        /// <param name="sequence">The sequence.</param>
        /// <param name="inputIds">The input ids.</param>
        /// <param name="positionIds">The position ids.</param>
        /// <param name="attentionMask">The attention mask.</param>
        /// <param name="useBranchCache">if set to <c>true</c> [use branch cache].</param>
        /// <returns>The logits tensor produced by the decoder.</returns>
        private Tensor<float> RunDecoderInternal(ModelMetadata modelMetadata, Sequence sequence, Tensor<long> inputIds, Tensor<long> positionIds, Tensor<long> attentionMask, bool useBranchCache)
        {
            using (var parameters = new ModelParameters(modelMetadata))
            {
                // Inputs
                parameters.AddInput(inputIds);
                parameters.AddInput(attentionMask);
                if (positionIds != null)
                    parameters.AddInput(positionIds);

                foreach (var pastKeyValue in sequence.Cache)
                    parameters.AddInput(pastKeyValue, false);

                // Outputs - one slot per model output; the loop variable itself is unused.
                foreach (var _ in modelMetadata.Outputs)
                    parameters.AddOutput();

                // Result
                var modelResult = Decoder.RunInference(parameters);
                using (var logitsResult = modelResult[0])
                {
                    // Drop the batch dimension from the logits; remaining outputs are present key/values.
                    var dimension = logitsResult.GetDimensions();
                    var logits = logitsResult.ToTensor(dimension[1..]);
                    var presentKeyValues = modelResult.ToArray()[1..];
                    sequence.UpdateCache(presentKeyValues, useBranchCache);
                    return logits;
                }
            }
        }


        /// <summary>
        /// Creates the QwenPipeline
        /// </summary>
        /// <param name="provider">The provider.</param>
        /// <param name="modelPath">The model path.</param>
        /// <param name="model">The decoder model.</param>
        /// <returns>QwenPipeline.</returns>
        public static QwenPipeline Create(ExecutionProvider provider, string modelPath, string model = "model.onnx")
        {
            // Qwen-2.5 - https://huggingface.co/onnx-community/Qwen2.5-0.5B
            var numHeads = 14;
            var numLayers = 24;
            var hiddenSize = 896;
            var numKVHeads = 2;
            var vocabSize = 151936;
            var config = new QwenConfig
            {
                Tokenizer = new BPETokenizer(new TokenizerConfig
                {
                    // Qwen2.5 uses the same id (<|endoftext|>) for BOS and EOS.
                    BOS = 151643,
                    EOS = 151643,
                    Path = modelPath
                }),
                DecoderConfig = new DecoderConfig
                {
                    Path = Path.Combine(modelPath, model),
                    VocabSize = vocabSize,
                    NumHeads = numHeads,
                    NumLayers = numLayers,
                    HiddenSize = hiddenSize,
                    NumKVHeads = numKVHeads
                }
            };

            config.DecoderConfig.SetProvider(provider);
            return new QwenPipeline(config);
        }

    }
}

0 commit comments

Comments
 (0)