
Commit ed96fcd

[Example] ggml: add test for set input twice (second-state#121)
* [Example] ggml: show metadata on CI for llama, llava examples

  Signed-off-by: dm4 <dm4@secondstate.io>

* [Example] ggml: add test for set input twice

  Signed-off-by: dm4 <dm4@secondstate.io>

---------

Signed-off-by: dm4 <dm4@secondstate.io>
1 parent 26ab476 commit ed96fcd

9 files changed: 193 additions, 24 deletions

‎.github/workflows/llama.yml‎

Lines changed: 13 additions & 0 deletions
@@ -158,6 +158,19 @@ jobs:
             default \
             $'[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you do not know the answer to a question, please do not share false information.\n<</SYS>>\nWhat is the capital of Japan?[/INST]'
 
+      - name: Set Input Twice
+        run: |
+          test -f ~/.wasmedge/env && source ~/.wasmedge/env
+          cd wasmedge-ggml/test/set-input-twice
+          curl -LO https://huggingface.co/second-state/Gemma-2b-it-GGUF/resolve/main/gemma-2b-it-Q5_K_M.gguf
+          cargo build --target wasm32-wasi --release
+          time wasmedge --dir .:. \
+            --env n_gpu_layers="$NGL" \
+            --nn-preload default:GGML:AUTO:gemma-2b-it-Q5_K_M.gguf \
+            target/wasm32-wasi/release/wasmedge-ggml-set-input-twice.wasm \
+            default \
+            '<start_of_turn>user Where is the capital of Japan? <end_of_turn><start_of_turn>model'
+
       - name: Build llama-stream
         run: |
           cd wasmedge-ggml/llama-stream

‎wasmedge-ggml/llama/src/main.rs‎

Lines changed: 26 additions & 24 deletions
@@ -66,7 +66,6 @@ fn get_output_from_context(context: &GraphExecutionContext) -> String {
     get_data_from_context(context, 0)
 }
 
-#[allow(dead_code)]
 fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
     serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
 }
@@ -103,15 +102,41 @@
     // This is mainly for the CI workflow.
     if args.len() >= 3 {
         let prompt = &args[2];
+        // Set the prompt.
         println!("Prompt:\n{}", prompt);
         let tensor_data = prompt.as_bytes().to_vec();
         context
             .set_input(0, TensorType::U8, &[1], &tensor_data)
             .expect("Failed to set input");
         println!("Response:");
+
+        // Get the number of input tokens and llama.cpp versions.
+        let input_metadata = get_metadata_from_context(&context);
+        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
+        println!(
+            "[INFO] llama_build_number: {}",
+            input_metadata["llama_build_number"]
+        );
+        println!(
+            "[INFO] Number of input tokens: {}",
+            input_metadata["input_tokens"]
+        );
+
+        // Get the output.
         context.compute().expect("Failed to compute");
         let output = get_output_from_context(&context);
         println!("{}", output.trim());
+
+        // Retrieve the output metadata.
+        let metadata = get_metadata_from_context(&context);
+        println!(
+            "[INFO] Number of input tokens: {}",
+            metadata["input_tokens"]
+        );
+        println!(
+            "[INFO] Number of output tokens: {}",
+            metadata["output_tokens"]
+        );
         std::process::exit(0);
     }
 
@@ -134,18 +159,6 @@ fn main() {
         set_data_to_context(&mut context, saved_prompt.as_bytes().to_vec())
             .expect("Failed to set input");
 
-        // Get the number of input tokens and llama.cpp versions.
-        // let input_metadata = get_metadata_from_context(&context);
-        // println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
-        // println!(
-        //     "[INFO] llama_build_number: {}",
-        //     input_metadata["llama_build_number"]
-        // );
-        // println!(
-        //     "[INFO] Number of input tokens: {}",
-        //     input_metadata["input_tokens"]
-        // );
-
         // Execute the inference.
         let mut reset_prompt = false;
         match context.compute() {
@@ -174,16 +187,5 @@ fn main() {
             output = output.trim().to_string();
            saved_prompt = format!("{} {}", saved_prompt, output);
        }
-
-        // Retrieve the output metadata.
-        // let metadata = get_metadata_from_context(&context);
-        // println!(
-        //     "[INFO] Number of input tokens: {}",
-        //     metadata["input_tokens"]
-        // );
-        // println!(
-        //     "[INFO] Number of output tokens: {}",
-        //     metadata["output_tokens"]
-        // );
     }
 }
Binary file (16.7 KB) not shown.

‎wasmedge-ggml/llava/src/main.rs‎

Lines changed: 30 additions & 0 deletions
@@ -75,6 +75,10 @@ fn get_output_from_context(context: &GraphExecutionContext) -> String {
     get_data_from_context(context, 0)
 }
 
+fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
+    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
+}
+
 fn main() {
     let args: Vec<String> = env::args().collect();
     let model_name: &str = &args[1];
@@ -98,15 +102,41 @@
     // This is mainly for the CI workflow.
     if args.len() >= 3 {
         let prompt = &args[2];
+        // Set the prompt.
         println!("Prompt:\n{}", prompt);
         let tensor_data = prompt.as_bytes().to_vec();
         context
             .set_input(0, TensorType::U8, &[1], &tensor_data)
             .expect("Failed to set input");
         println!("Response:");
+
+        // Get the number of input tokens and llama.cpp versions.
+        let input_metadata = get_metadata_from_context(&context);
+        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
+        println!(
+            "[INFO] llama_build_number: {}",
+            input_metadata["llama_build_number"]
+        );
+        println!(
+            "[INFO] Number of input tokens: {}",
+            input_metadata["input_tokens"]
+        );
+
+        // Get the output.
         context.compute().expect("Failed to compute");
         let output = get_output_from_context(&context);
         println!("{}", output.trim());
+
+        // Retrieve the output metadata.
+        let metadata = get_metadata_from_context(&context);
+        println!(
+            "[INFO] Number of input tokens: {}",
+            metadata["input_tokens"]
+        );
+        println!(
+            "[INFO] Number of output tokens: {}",
+            metadata["output_tokens"]
+        );
         std::process::exit(0);
     }

Binary file (16.9 KB) not shown.

wasmedge-ggml/test/set-input-twice/Cargo.toml

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+[package]
+name = "wasmedge-ggml-set-input-twice"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+serde_json = "1.0"
+wasmedge-wasi-nn = "0.7.0"

wasmedge-ggml/test/set-input-twice/README.md

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# `set-input-twice`
+
+Ensure that we get the same result from executing `set_input` twice.
+
+## Execute
+
+```console
+$ curl -LO https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
+$ wasmedge --dir .:. \
+    --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
+    wasmedge-ggml-set-input-twice.wasm default '<PROMPT>'
+```
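
Note: the Execute step above assumes `wasmedge-ggml-set-input-twice.wasm` is already present next to the downloaded model. The CI step in `llama.yml` builds it with cargo first; a minimal local sketch following that workflow is shown below (the final copy step is an assumption for convenience, not part of this commit):

```console
$ cd wasmedge-ggml/test/set-input-twice
$ cargo build --target wasm32-wasi --release
# Assumed step: place the built module next to the model so the Execute command above finds it.
$ cp target/wasm32-wasi/release/wasmedge-ggml-set-input-twice.wasm .
```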

wasmedge-ggml/test/set-input-twice/src/main.rs

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+use serde_json::json;
+use serde_json::Value;
+use std::env;
+use wasmedge_wasi_nn::{
+    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
+};
+
+fn get_options_from_env() -> Value {
+    let mut options = json!({});
+    if let Ok(val) = env::var("enable_log") {
+        options["enable-log"] = serde_json::from_str(val.as_str())
+            .expect("invalid value for enable-log option (true/false)")
+    } else {
+        options["enable-log"] = serde_json::from_str("false").unwrap()
+    }
+    if let Ok(val) = env::var("n_gpu_layers") {
+        options["n-gpu-layers"] =
+            serde_json::from_str(val.as_str()).expect("invalid ngl value (unsigned integer)")
+    } else {
+        options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
+    }
+    options["ctx-size"] = serde_json::from_str("1024").unwrap();
+
+    options
+}
+
+fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
+    // Preserve for 4096 tokens with average token length 6
+    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
+    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
+    let mut output_size = context
+        .get_output(index, &mut output_buffer)
+        .expect("Failed to get output");
+    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
+
+    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
+}
+
+fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
+    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
+}
+
+fn main() {
+    let args: Vec<String> = env::args().collect();
+    let model_name: &str = &args[1];
+
+    // Set options for the graph. Check our README for more details:
+    // https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
+    let options = get_options_from_env();
+
+    // Create graph and initialize context.
+    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
+        .config(serde_json::to_string(&options).expect("Failed to serialize options"))
+        .build_from_cache(model_name)
+        .expect("Failed to build graph");
+    let mut context = graph
+        .init_execution_context()
+        .expect("Failed to init context");
+
+    // If there is a third argument, use it as the prompt and enter non-interactive mode.
+    // This is mainly for the CI workflow.
+    if args.len() < 3 {
+        println!("Usage: {} <model_name> <prompt>", args[0]);
+    } else {
+        let prompt = &args[2];
+
+        // Set the prompt.
+        println!("Prompt:\n{}", prompt);
+        let tensor_data = prompt.as_bytes().to_vec();
+        context
+            .set_input(0, TensorType::U8, &[1], &tensor_data)
+            .expect("Failed to set input");
+        println!("Response:");
+
+        // Get the number of input tokens and llama.cpp versions.
+        let input_metadata = get_metadata_from_context(&context);
+        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
+        println!(
+            "[INFO] llama_build_number: {}",
+            input_metadata["llama_build_number"]
+        );
+        println!(
+            "[INFO] Number of input tokens: {}",
+            input_metadata["input_tokens"]
+        );
+
+        // Set the prompt, twice.
+        context
+            .set_input(0, TensorType::U8, &[1], &tensor_data)
+            .expect("Failed to set input");
+
+        // Get the number of input tokens and llama.cpp versions.
+        let input_metadata_after = get_metadata_from_context(&context);
+        println!(
+            "[INFO] Number of input tokens: {}",
+            input_metadata_after["input_tokens"]
+        );
+
+        // Check if the numbers of input_tokens are the same.
+        if input_metadata["input_tokens"] != input_metadata_after["input_tokens"] {
+            panic!("The number of input tokens is different after setting the input twice.");
+        }
+    }
+}
Binary file not shown.
