forked from oobabooga/textgen
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllamacpp_model.py
More file actions
91 lines (74 loc) · 2.94 KB
/
llamacpp_model.py
File metadata and controls
91 lines (74 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
'''
Based on
https://github.com/abetlen/llama-cpp-python
Documentation:
https://abetlen.github.io/llama-cpp-python/
'''
import re
from llama_cpp import Llama, LlamaCache
from modules import shared
from modules.callbacks import Iteratorize
from modules.logging_colors import logger
class LlamaCppModel:
    """Wrapper around a ``llama_cpp.Llama`` instance.

    The same object is used as both the model and the tokenizer (see
    ``from_pretrained``), so it exposes ``encode`` alongside the generation
    methods. The underlying ``Llama`` object is stored on ``self.model`` once
    ``from_pretrained`` has run.
    """

    def __init__(self):
        self.initialized = False

    def __del__(self):
        # The wrapper may be finalized before a model was ever attached (for
        # example when construction in ``from_pretrained`` fails), so only
        # forward the call when there is something to release.
        model = getattr(self, 'model', None)
        if model is not None:
            model.__del__()

    @staticmethod
    def _parse_cache_capacity(value):
        """Convert a cache capacity setting into a number of bytes.

        Args:
            value: ``None``, or a string such as ``"2GiB"``, ``"2000MiB"`` or a
                plain integer number of bytes.

        Returns:
            The capacity in bytes as an ``int``; ``0`` when ``value`` is ``None``.

        Raises:
            ValueError: if the numeric part cannot be parsed as an integer.
        """
        if value is None:
            return 0

        # NOTE(review): the suffixes are binary unit names but the multipliers
        # are decimal (powers of 1000). Kept as-is so existing configurations
        # keep producing the same capacity.
        if 'GiB' in value:
            return int(re.sub('[a-zA-Z]', '', value)) * 1000 * 1000 * 1000
        if 'MiB' in value:
            return int(re.sub('[a-zA-Z]', '', value)) * 1000 * 1000
        return int(value)

    @classmethod
    def from_pretrained(cls, path):
        """Build a wrapper and load the model located at ``path``.

        Loading parameters are read from ``shared.args``. A ``LlamaCache`` is
        attached when the configured cache capacity is greater than zero.

        Args:
            path: location of the model file; converted with ``str()``.

        Returns:
            A ``(model, tokenizer)`` tuple in which both items are the same
            ``LlamaCppModel`` instance.
        """
        result = cls()

        cache_capacity = cls._parse_cache_capacity(shared.args.cache_capacity)
        logger.info("Cache capacity is " + str(cache_capacity) + " bytes")

        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
            'seed': int(shared.args.llama_cpp_seed),
            'n_threads': shared.args.threads or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'n_gpu_layers': shared.args.n_gpu_layers
        }

        # Bind the model to the new instance. Assigning through the classmethod's
        # first argument would store it on the class and share one model between
        # every instance.
        result.model = Llama(**params)
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        """Tokenize ``string`` with the underlying model.

        Args:
            string: text to tokenize; a ``str`` is encoded to bytes first,
                anything else is passed through unchanged.

        Returns:
            Whatever ``self.model.tokenize`` returns for the byte string.
        """
        if isinstance(string, str):
            string = string.encode()

        return self.model.tokenize(string)

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
        """Generate a completion for ``context`` and return the full text.

        Args:
            context: prompt as ``str``; other values are decoded with
                ``.decode()`` first.
            token_count: forwarded as ``max_tokens``.
            temperature, top_p, top_k: sampling settings, forwarded unchanged.
            repetition_penalty: forwarded as ``repeat_penalty``.
            mirostat_mode: forwarded after conversion with ``int()``.
            mirostat_tau, mirostat_eta: forwarded unchanged.
            callback: optional callable invoked with each text chunk as it
                arrives.

        Returns:
            The concatenation of every streamed text chunk.
        """
        if not isinstance(context, str):
            context = context.decode()

        completion_chunks = self.model.create_completion(
            prompt=context,
            max_tokens=token_count,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repetition_penalty,
            mirostat_mode=int(mirostat_mode),
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            stream=True
        )

        pieces = []
        for completion_chunk in completion_chunks:
            text = completion_chunk['choices'][0]['text']
            pieces.append(text)
            if callback:
                callback(text)

        return "".join(pieces)

    def generate_with_streaming(self, **kwargs):
        """Yield the cumulative reply as generation progresses.

        ``self.generate`` is run through ``Iteratorize`` with ``kwargs``; each
        item it produces is appended to the reply, and the reply so far is
        yielded after every item.
        """
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply