forked from socketteer/loom
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgpt.py
More file actions
85 lines (66 loc) · 1.96 KB
/
gpt.py
File metadata and controls
85 lines (66 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import time
import traceback
from pprint import pprint
from celery import Celery
import openai
from util.util import retry
#################################
# Janus
#################################
# Redis URL used for both the Celery broker and the result backend.
# Stays None when the JANUS_REDIS environment variable is not set.
redis_url = os.getenv("JANUS_REDIS")

# Celery application through which Janus tasks are sent. No main name is
# passed (the 'janus' name was left disabled in the original call).
app = Celery(broker=redis_url, backend=redis_url)

# Name of the task dispatched by janus_generate. Per the original note, the
# worker-side signature is:
#     get_gpt_response(prompt, memory, retry=True) -> result, error
janus_task = "janus.my_celery.tasks.get_gpt_response"
def janus_generate(prompt, memory=""):
    """Send a generation request to the Janus Celery worker and wait for it.

    The task named by the module-level ``janus_task`` is dispatched through
    the module-level Celery ``app`` with ``[prompt, memory]`` as positional
    arguments. The call then blocks until the task result is available.

    Args:
        prompt: Text passed as the first task argument; must be a ``str``.
        memory: Text passed as the second task argument; must be a ``str``.
            Defaults to the empty string.

    Returns:
        The ``(result, error)`` pair unpacked from the task's return value.

    Raises:
        TypeError: If ``prompt`` or ``memory`` is not a ``str``.
    """
    # Validate with explicit raises instead of ``assert``: assertions are
    # removed under ``python -O``, which would let non-string arguments be
    # sent to the worker unchecked.
    if not isinstance(prompt, str):
        raise TypeError(
            "prompt must be a str, got {}".format(type(prompt).__name__)
        )
    if not isinstance(memory, str):
        raise TypeError(
            "memory must be a str, got {}".format(type(memory).__name__)
        )

    celery_task = app.send_task(janus_task, args=[prompt, memory])
    print("Sent to janus")
    # Blocks until the worker has produced a result for this task.
    result, error = celery_task.get()
    return result, error
#################################
# OpenAI
#################################
# The OpenAI API key is read from the environment at import time; it is None
# when OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Known OpenAI engine ids. Presumably a snapshot of the output of the
# following call (confirm before relying on it being current):
#     pprint([d["id"] for d in openai.Engine.list()["data"]])
POSSIBLE_MODELS = [
    "ada",
    "babbage",
    "content-filter-alpha-c4",
    "content-filter-dev",
    "curie",
    "cursing-filter-v6",
    "davinci",
    "instruct-curie-beta",
    "instruct-davinci-beta",
]
@retry(n_tries=3, delay=1, backoff=2, on_failure=lambda *args, **kwargs: "")
def api_generate(
    prompt,
    length=150,
    num_continuations=1,
    logprobs=10,
    temperature=0.8,
    top_p=1,
    stop=None,
    engine='davinci',
    **kwargs
):
    """Request completions for ``prompt`` from ``openai.Completion.create``.

    Args:
        prompt: Prompt forwarded to the API.
        length: Forwarded as ``max_tokens``.
        num_continuations: Forwarded as ``n``.
        logprobs: Forwarded as ``logprobs``.
        temperature: Forwarded as ``temperature``.
        top_p: Forwarded as ``top_p``.
        stop: Forwarded as ``stop``.
        engine: Forwarded as ``engine``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        A ``(response, None)`` tuple, where ``response`` is whatever
        ``openai.Completion.create`` returned.

    NOTE(review): the ``retry`` decorator comes from ``util.util`` and is not
    visible here. Presumably it re-invokes this function on failure and, once
    it gives up, returns the ``on_failure`` result (an empty string) in place
    of the tuple. Confirm against ``retry`` before relying on either.
    """
    request = {
        "engine": engine,
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": length,
        "top_p": top_p,
        "logprobs": logprobs,
        "n": num_continuations,
        "stop": stop,
    }
    # Two separate ** unpackings (not a merged dict) so that a key in
    # ``kwargs`` that repeats one of the names above still raises TypeError
    # instead of silently overriding it.
    completion = openai.Completion.create(**request, **kwargs)
    return completion, None
def search(query, documents, engine="curie"):
    """Search ``documents`` for ``query`` using an OpenAI engine.

    Args:
        query: Passed to the engine's ``search`` call as ``query``.
        documents: Passed to the engine's ``search`` call as ``documents``.
        engine: Engine id used to construct ``openai.Engine``; defaults to
            ``"curie"``.

    Returns:
        Whatever ``openai.Engine(engine).search(...)`` returns, unmodified.
    """
    search_engine = openai.Engine(engine)
    return search_engine.search(query=query, documents=documents)
if __name__ == "__main__":
    # Manual smoke test: send a trivial prompt through Janus and print the
    # (result, error) pair that comes back.
    janus_reply = janus_generate("test")
    print(janus_reply)
    # Alternative manual checks, kept disabled:
    # print(os.environ["OPENAI_API_KEY"])
    # print(api_generate("test"))