
Commit 613dcae

[Example] ggml: add phi-3-mini test (second-state#136)
Signed-off-by: dm4 <dm4@secondstate.io>
1 parent 7f67945 commit 613dcae

File tree

5 files changed: +130 -0 lines changed

.github/workflows/llama.yml

Lines changed: 12 additions & 0 deletions
@@ -259,6 +259,18 @@ jobs:
             default \
             $'[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you do not know the answer to a question, please do not share false information.\n<</SYS>>\nWhat is the capital of Japan?[/INST]'
 
+      - name: Phi 3 Mini
+        run: |
+          test -f ~/.wasmedge/env && source ~/.wasmedge/env
+          cd wasmedge-ggml/test/phi-3-mini
+          curl -LO https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf
+          cargo build --target wasm32-wasi --release
+          time wasmedge --dir .:. \
+            --nn-preload default:GGML:AUTO:Phi-3-mini-4k-instruct-q4.gguf \
+            target/wasm32-wasi/release/wasmedge-ggml-phi-3-mini.wasm \
+            default \
+            $'<|user|>\nWhat is the capital of Japan?<|end|>\n<|assistant|>'
+
       - name: Build llama-stream
         run: |
           cd wasmedge-ggml/llama-stream
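The `default` argument passed to the wasm module in this step is the model alias registered with `--nn-preload`; the example binary looks the model up by that name. A minimal sketch of just that lookup, assuming the same `wasmedge-wasi-nn` 0.7 API used by the new `main.rs` below (this snippet is illustrative and not part of the commit):

```rust
use wasmedge_wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};

fn main() {
    // "default" must match the alias given to `--nn-preload default:GGML:AUTO:...`;
    // build_from_cache() resolves the preloaded model by that name.
    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .build_from_cache("default")
        .expect("model alias was not preloaded");
    println!("Graph {} loaded.", graph);
}
```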
wasmedge-ggml/test/phi-3-mini/Cargo.toml

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
[package]
name = "wasmedge-ggml-phi-3-mini"
version = "0.1.0"
edition = "2021"

[dependencies]
serde_json = "1.0"
wasmedge-wasi-nn = "0.7.1"
wasmedge-ggml/test/phi-3-mini/README.md

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# `phi-3-mini`

Ensure that we can use the `phi-3-mini` model.

## Execute

```console
$ curl -LO https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf
$ wasmedge --dir .:. \
    --nn-preload default:GGML:AUTO:Phi-3-mini-4k-instruct-q4.gguf \
    wasmedge-ggml-phi-3-mini.wasm default \
    <prompt>
```
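The `<prompt>` placeholder is passed verbatim to the model, so it should follow the Phi-3 chat template that the CI step uses (`<|user|>\n...<|end|>\n<|assistant|>`). A small sketch of building such a prompt string; the `phi3_prompt` helper is hypothetical and not part of this commit:

```rust
// Hypothetical helper: wrap a single user turn in the Phi-3 chat template
// that the CI step passes on the command line.
fn phi3_prompt(user_message: &str) -> String {
    format!("<|user|>\n{}<|end|>\n<|assistant|>", user_message)
}

fn main() {
    // Prints the same prompt the workflow uses for its smoke test.
    println!("{}", phi3_prompt("What is the capital of Japan?"));
}
```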
wasmedge-ggml/test/phi-3-mini/src/main.rs

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
use serde_json::json;
use serde_json::Value;
use std::env;
use wasmedge_wasi_nn::{
    self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
    TensorType,
};

fn get_options_from_env() -> Value {
    let mut options = json!({});
    if let Ok(val) = env::var("enable_log") {
        options["enable-log"] = serde_json::from_str(val.as_str())
            .expect("invalid value for enable-log option (true/false)")
    } else {
        options["enable-log"] = serde_json::from_str("false").unwrap()
    }
    if let Ok(val) = env::var("n_gpu_layers") {
        options["n-gpu-layers"] =
            serde_json::from_str(val.as_str()).expect("invalid ngl value (unsigned integer)")
    } else {
        options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
    }

    options
}

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Reserve room for 4096 tokens with an average token length of 6 bytes.
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    String::from_utf8_lossy(&output_buffer[..output_size]).to_string()
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

fn main() {
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];

    // Set options for the graph. Check our README for more details:
    // https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
    let options = get_options_from_env();

    // This is mainly for the CI workflow. Only support the prompt from the command line.
    if args.len() < 3 {
        println!("Usage: {} <model_name> <prompt>", args[0]);
        std::process::exit(1);
    }
    let prompt = &args[2];

    // Create graph and initialize context.
    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .config(serde_json::to_string(&options).expect("Failed to serialize options"))
        .build_from_cache(model_name)
        .expect("Failed to build graph");
    println!("Graph {} loaded.", graph);
    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");

    // Set the prompt.
    println!("Prompt:\n{}", prompt);
    let tensor_data = prompt.as_bytes().to_vec();
    context
        .set_input(0, TensorType::U8, &[1], &tensor_data)
        .expect("Failed to set input");
    println!("Response:");

    // Execute the inference.
    match context.compute() {
        Ok(_) => (),
        Err(Error::BackendError(BackendError::ContextFull)) => {
            println!("\n[INFO] Context full, we'll reset the context and continue.");
        }
        Err(Error::BackendError(BackendError::PromptTooLong)) => {
            println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
        }
        Err(err) => {
            println!("\n[ERROR] {}", err);
            std::process::exit(1);
        }
    }

    // Retrieve the output.
    let output = get_output_from_context(&context);
    println!("{}", output.trim());

    // Unload.
    graph.unload().expect("Failed to unload graph");
}
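For reference, the `enable_log` and `n_gpu_layers` environment variables read above end up as the metadata JSON handed to `GraphBuilder::config()`. A standalone sketch of the equivalent JSON, assuming the same option names (`enable-log`, `n-gpu-layers`) and example values; not part of this commit:

```rust
use serde_json::json;

fn main() {
    // What get_options_from_env() produces when the example is run with
    // `enable_log=true n_gpu_layers=35 wasmedge ...`.
    let options = json!({
        "enable-log": true,
        "n-gpu-layers": 35
    });
    // This serialized string is what .config() receives before build_from_cache().
    println!("{}", serde_json::to_string(&options).unwrap());
}
```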
wasmedge-ggml/test/phi-3-mini/wasmedge-ggml-phi-3-mini.wasm (2.13 MB binary file, not shown)
