"""
NL2Shell Demo — HuggingFace Spaces Gradio App
==============================================
Converts natural language descriptions into shell commands using
the AryaYT/nl2shell-0.8b model (Qwen3.5-0.8B fine-tuned with QLoRA).
Deploy: Push this file + requirements.txt to a Gradio HF Space.
ZeroGPU: Decorate the generate() fn with @spaces.GPU if hosting
on a ZeroGPU-enabled Space (PRO account required).
"""
import time
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MODEL_ID = "AryaYT/nl2shell-0.8b"
SYSTEM_PROMPT = (
    "You are an expert shell programmer. Given a natural language request, "
    "output ONLY the corresponding shell command. No explanations."
)
EXAMPLES = [
    "list all files in the current directory sorted by size",
    "kill the process using port 3000",
    "find all Python files modified in the last 24 hours",
    "show disk usage of the current directory",
    "count lines of code in all TypeScript files",
    "check SSL certificate expiry of example.com",
    "compress the src directory into a tar.gz",
    "show git log as one-line summaries for the last 20 commits",
    "find and delete all node_modules directories recursively",
    "show all running Docker containers",
    "list all open network connections",
    "show which process is using the most CPU",
    "generate a random 32-character password",
    "create a new git branch called feature-auth",
    "watch a directory for file changes",
]
# ---------------------------------------------------------------------------
# ZeroGPU support (no-op when not in a ZeroGPU Space)
# ---------------------------------------------------------------------------
# If you are on a ZeroGPU Space (HF PRO account), add `spaces` to
# requirements.txt, import it with the other imports above, and place the
# decorator directly above the generate() function below to get free H200
# access:
#
#     import spaces
#
#     @spaces.GPU(duration=30)
#     def generate(...):
#         ...
#
# The @spaces.GPU decorator is a no-op in non-ZeroGPU environments, so it's
# safe to add proactively if you have a PRO account.
# ---------------------------------------------------------------------------
# Device detection
# ---------------------------------------------------------------------------
def get_device() -> str:
    """Return the best available device string."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
DEVICE = get_device()
IS_GPU = DEVICE in ("cuda", "mps")
# ---------------------------------------------------------------------------
# Model loading — happens once at startup
# ---------------------------------------------------------------------------
print(f"[NL2Shell] Loading model on {DEVICE}...")
_load_start = time.time()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# On CPU (free HF Spaces tier): load in float32, no device_map needed.
# On GPU: use bfloat16 + device_map=auto for speed and memory efficiency.
if IS_GPU:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model = model.to(DEVICE)
model.eval()
_load_time = time.time() - _load_start
print(f"[NL2Shell] Model loaded in {_load_time:.1f}s on {DEVICE}")
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def build_prompt(user_request: str) -> str:
    """Build a ChatML-formatted prompt for NL2Shell."""
    return (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{user_request.strip()}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
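# A minimal alternative sketch (assuming the tokenizer ships a ChatML chat
# template, as Qwen tokenizers normally do): the same prompt can be produced
# by the tokenizer's own template instead of hand-written markers.
#
#     messages = [
#         {"role": "system", "content": SYSTEM_PROMPT},
#         {"role": "user", "content": user_request.strip()},
#     ]
#     prompt = tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#
# The explicit string above is kept so the exact training-time format stays
# visible at a glance.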
def extract_command(full_output: str) -> str:
    """Extract the assistant's shell command from a full ChatML output string."""
    marker = "<|im_start|>assistant\n"
    if marker in full_output:
        cmd = full_output.split(marker)[-1]
        # Strip the closing im_end token if present
        cmd = cmd.split("<|im_end|>")[0]
        return cmd.strip()
    return full_output.strip()
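# Illustrative example (hypothetical model output, not a recorded run):
# for a decoded string ending in
#     "<|im_start|>assistant\nls -lS<|im_end|>"
# extract_command() returns "ls -lS".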
def generate(
    user_request: str, temperature: float = 0.1, max_new_tokens: int = 128
) -> tuple[str, str]:
    """
    Run inference and return (command, metadata_markdown).

    Returns a tuple so Gradio can display command and meta info separately.
    """
    if not user_request or not user_request.strip():
        return "", "Enter a description above and click Generate."

    prompt = build_prompt(user_request)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
    input_len = input_ids.shape[1]

    # Greedy decoding at temperature 0. Only pass `temperature` when sampling,
    # since transformers warns if it is set alongside do_sample=False.
    do_sample = temperature > 0.0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.convert_tokens_to_ids("<|im_end|>"),
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    t0 = time.time()
    with torch.no_grad():
        output_ids = model.generate(input_ids=input_ids, **gen_kwargs)
    elapsed = time.time() - t0

    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    command = extract_command(full_output)

    tokens_generated = output_ids.shape[1] - input_len
    tok_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
    meta = (
        f"**Device:** {DEVICE.upper()} | "
        f"**Tokens generated:** {tokens_generated} | "
        f"**Speed:** {tok_per_sec:.1f} tok/s | "
        f"**Time:** {elapsed:.2f}s"
    )
    return command, meta
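# Optional streaming sketch (an assumption, not wired into the UI below):
# transformers' TextIteratorStreamer yields decoded text chunk by chunk,
# which pairs naturally with a Gradio generator function.
#
#     from threading import Thread
#     from transformers import TextIteratorStreamer
#
#     def generate_stream(user_request: str):
#         inputs = tokenizer(build_prompt(user_request), return_tensors="pt").to(DEVICE)
#         streamer = TextIteratorStreamer(
#             tokenizer, skip_prompt=True, skip_special_tokens=True
#         )
#         Thread(
#             target=model.generate,
#             kwargs=dict(**inputs, max_new_tokens=128, streamer=streamer),
#         ).start()
#         partial = ""
#         for chunk in streamer:
#             partial += chunk
#             yield partial.strip()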
# ---------------------------------------------------------------------------
# UI helpers
# ---------------------------------------------------------------------------
def run_example(example_text: str) -> tuple[str, str, str]:
    """Used by the Examples component — populate input and run inference."""
    cmd, meta = generate(example_text)
    return example_text, cmd, meta
# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
CSS = """
#header {
text-align: center;
margin-bottom: 8px;
}
#header h1 {
font-size: 2rem;
font-weight: 700;
margin-bottom: 4px;
}
#header p {
color: #6b7280;
font-size: 0.95rem;
}
#command-output textarea {
font-family: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace;
font-size: 1.05rem;
background: #0d1117;
color: #58d68d;
border: 1px solid #30363d;
border-radius: 6px;
}
#meta-row {
font-size: 0.8rem;
color: #6b7280;
}
.gr-button-primary {
background: #2ea44f !important;
border: none !important;
}
"""
DESCRIPTION_MD = """\
**[AryaYT/nl2shell-0.8b](https://huggingface.co/AryaYT/nl2shell-0.8b)** \
is a 0.8B-parameter model fine-tuned from Qwen3.5-0.8B with QLoRA on the \
[NL2Bash](https://huggingface.co/datasets/jiacheng-ye/nl2bash) benchmark \
(~10,000 pairs) plus 40 hand-crafted macOS command pairs. \
It outputs **only the shell command** — no markdown, no explanation.
"""
INFO_MD = """\
| Property | Value |
|---|---|
| Base model | Qwen/Qwen3.5-0.8B |
| Fine-tuning | QLoRA (4-bit, rank 16, alpha 32) |
| Training data | NL2Bash + 40 macOS synthetic pairs |
| Format | ChatML |
| Max output | 128 tokens |
| GGUF (edge) | q4_k_m ~400 MB, q8_0 ~650 MB |
| License | MIT |

**Run locally with Ollama:**

```bash
ollama run hf.co/AryaYT/nl2shell-0.8b
```
"""
HARDWARE_NOTE = (
    "Running on **GPU** — fast inference."
    if IS_GPU
    else (
        "Running on **CPU** (free HF Spaces tier). "
        "Generation takes 5-20 seconds. "
        "For instant results, duplicate this Space and choose a GPU tier, "
        "or run the model locally via Ollama."
    )
)
def build_ui() -> gr.Blocks:
    with gr.Blocks(css=CSS, title="NL2Shell — Natural Language to Shell Commands") as demo:
        # Header
        with gr.Column(elem_id="header"):
            gr.Markdown("# NL2Shell")
            gr.Markdown(
                "Convert natural language descriptions into shell commands. "
                "No explanations — just the command."
            )
        gr.Markdown(DESCRIPTION_MD)
        gr.Markdown(f"> {HARDWARE_NOTE}")

        # Main input/output
        with gr.Row():
            with gr.Column(scale=3):
                user_input = gr.Textbox(
                    label="Describe what you want to do",
                    placeholder="e.g. find all Python files modified in the last 24 hours",
                    lines=2,
                    max_lines=4,
                )
                with gr.Row():
                    submit_btn = gr.Button("Generate Command", variant="primary")
                    clear_btn = gr.Button("Clear")

                # Advanced settings (collapsed by default)
                with gr.Accordion("Advanced settings", open=False):
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.1,
                        step=0.05,
                        label="Temperature",
                        info="Lower = more deterministic. 0.1 recommended for shell commands.",
                    )
                    max_tokens = gr.Slider(
                        minimum=16,
                        maximum=256,
                        value=128,
                        step=16,
                        label="Max new tokens",
                    )

            with gr.Column(scale=2):
                command_output = gr.Textbox(
                    label="Generated Shell Command",
                    elem_id="command-output",
                    lines=3,
                    interactive=False,
                    show_copy_button=True,
                )
                meta_output = gr.Markdown(
                    value="",
                    elem_id="meta-row",
                )

        # Examples
        gr.Markdown("### Example Prompts")
        gr.Examples(
            examples=EXAMPLES,
            inputs=user_input,
            outputs=[user_input, command_output, meta_output],
            fn=run_example,
            cache_examples=False,
            run_on_click=True,  # actually run inference on click, not just fill the box
            label="Click an example to run it",
        )

        # Model info
        with gr.Accordion("Model Info & Local Usage", open=False):
            gr.Markdown(INFO_MD)

        # Warning footer
        gr.Markdown(
            "> **Note:** Always review shell commands before running them, "
            "especially those involving `rm`, `kill`, `sudo`, or network operations. "
            "This model is a research demo — treat output as a starting point, not ground truth."
        )

        # Wire up events
        submit_btn.click(
            fn=generate,
            inputs=[user_input, temperature, max_tokens],
            outputs=[command_output, meta_output],
        )
        user_input.submit(
            fn=generate,
            inputs=[user_input, temperature, max_tokens],
            outputs=[command_output, meta_output],
        )
        clear_btn.click(
            fn=lambda: ("", "", ""),
            outputs=[user_input, command_output, meta_output],
        )

    return demo
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    app = build_ui()
    app.launch()