
Commit 26ab476

[Example] ggml: add nnrpc example for RPC usage (second-state#119)
* [Example] ggml: add nnrpc example for RPC usage
* [CI] llama: add tests for nnrpc example

Signed-off-by: dm4 <dm4@secondstate.io>
1 parent 1be5d43 commit 26ab476

File tree

5 files changed (+253, -0 lines)

.github/workflows/llama.yml

Lines changed: 13 additions & 0 deletions
@@ -145,6 +145,19 @@ jobs:
           default \
           'hello world'
 
+    - name: RPC Example
+      run: |
+        test -f ~/.wasmedge/env && source ~/.wasmedge/env
+        cd wasmedge-ggml/nnrpc
+        curl -LO https://huggingface.co/second-state/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
+        cargo build --target wasm32-wasi --release
+        time wasmedge --dir .:. \
+          --env n_gpu_layers="$NGL" \
+          --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
+          target/wasm32-wasi/release/wasmedge-ggml-nnrpc.wasm \
+          default \
+          $'[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you do not know the answer to a question, please do not share false information.\n<</SYS>>\nWhat is the capital of Japan?[/INST]'
+
     - name: Build llama-stream
       run: |
         cd wasmedge-ggml/llama-stream

wasmedge-ggml/nnrpc/Cargo.toml

Lines changed: 8 additions & 0 deletions
[package]
name = "wasmedge-ggml-nnrpc"
version = "0.1.0"
edition = "2021"

[dependencies]
serde_json = "1.0"
wasmedge-wasi-nn = "0.7.0"

wasmedge-ggml/nnrpc/README.md

Lines changed: 44 additions & 0 deletions
# RPC Example For WASI-NN with GGML Backend

> [!NOTE]
> Please refer to the [wasmedge-ggml/README.md](../README.md) for the general introduction and the setup of the WASI-NN plugin with the GGML backend. This document focuses on the WASI-NN RPC usage.

## Parameters

> [!NOTE]
> Please check the parameters section of [wasmedge-ggml/README.md](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters) first.

For GPU offloading, set the `n-gpu-layers` option to the number of layers you want to offload to the GPU.

```rust
options.insert("n-gpu-layers", Value::from(...));
```

For llava inference, we recommend a `ctx-size` of at least `2048` when using llava-v1.5 and at least `4096` when using llava-v1.6 for better results.

```rust
options.insert("ctx-size", Value::from(4096));
```

## Execute

```console
# Run the RPC server.
$ wasi_nn_rpcserver --nn-rpc-uri unix://$PWD/nn_server.sock \
    --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf

# Run wasmedge and perform inference through the RPC server.
$ wasmedge \
    --nn-rpc-uri unix://$PWD/nn_server.sock \
    wasmedge-ggml-nnrpc.wasm default

USER:
What's the capital of the United States?
ASSISTANT:
The capital of the United States is Washington, D.C. (District of Columbia).
USER:
How about France?
ASSISTANT:
The capital of France is Paris.
```
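
The `wasmedge-ggml-nnrpc.wasm` module used above can be built from this directory with the same command the CI workflow uses (this assumes the Rust `wasm32-wasi` target is installed):

```console
# Build the example; the module is written to
# target/wasm32-wasi/release/wasmedge-ggml-nnrpc.wasm.
$ cargo build --target wasm32-wasi --release
```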

wasmedge-ggml/nnrpc/src/main.rs

Lines changed: 188 additions & 0 deletions
use serde_json::json;
use serde_json::Value;
use std::env;
use std::io;
use wasmedge_wasi_nn::{
    self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
    TensorType,
};

fn read_input() -> String {
    loop {
        let mut answer = String::new();
        io::stdin()
            .read_line(&mut answer)
            .expect("Failed to read line");
        if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
            return answer.trim().to_string();
        }
    }
}

fn get_options_from_env() -> Value {
    let mut options = json!({});
    if let Ok(val) = env::var("enable_log") {
        options["enable-log"] = serde_json::from_str(val.as_str())
            .expect("invalid value for enable-log option (true/false)")
    } else {
        options["enable-log"] = serde_json::from_str("false").unwrap()
    }
    if let Ok(val) = env::var("n_gpu_layers") {
        options["n-gpu-layers"] =
            serde_json::from_str(val.as_str()).expect("invalid ngl value (unsigned integer)")
    } else {
        options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
    }
    options["ctx-size"] = serde_json::from_str("1024").unwrap();

    options
}

fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
    context.set_input(0, TensorType::U8, &[1], &data)
}

#[allow(dead_code)]
fn set_metadata_to_context(
    context: &mut GraphExecutionContext,
    data: Vec<u8>,
) -> Result<(), Error> {
    context.set_input(1, TensorType::U8, &[1], &data)
}

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Reserve space for 4096 tokens with an average token length of 6.
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

#[allow(dead_code)]
fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
}

fn main() {
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];

    // Set options for the graph. Check our README for more details:
    // https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
    let options = get_options_from_env();

    // Create graph and initialize context.
    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .build_from_cache(model_name)
        .expect("Failed to build graph");
    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");

    // We also support setting the options via the input tensor with index 1.
    // Check our README for more details.
    set_metadata_to_context(
        &mut context,
        serde_json::to_string(&options)
            .expect("Failed to serialize options")
            .as_bytes()
            .to_vec(),
    )
    .expect("Failed to set metadata");

    // If there is a third argument, use it as the prompt and enter non-interactive mode.
    // This is mainly for the CI workflow.
    if args.len() >= 3 {
        // Set the prompt.
        let prompt = &args[2];
        println!("Prompt:\n{}", prompt);
        let tensor_data = prompt.as_bytes().to_vec();
        context
            .set_input(0, TensorType::U8, &[1], &tensor_data)
            .expect("Failed to set input");
        // Get the number of input tokens and llama.cpp versions.
        let input_metadata = get_metadata_from_context(&context);
        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
        println!(
            "[INFO] llama_build_number: {}",
            input_metadata["llama_build_number"]
        );
        println!(
            "[INFO] Number of input tokens: {}",
            input_metadata["input_tokens"]
        );
        // Get the response.
        println!("Response:");
        context.compute().expect("Failed to compute");
        let output = get_output_from_context(&context);
        println!("{}", output.trim());
        // Retrieve the output metadata.
        let metadata = get_metadata_from_context(&context);
        println!(
            "[INFO] Number of input tokens: {}",
            metadata["input_tokens"]
        );
        println!(
            "[INFO] Number of output tokens: {}",
            metadata["output_tokens"]
        );
        std::process::exit(0);
    }

    let mut saved_prompt = String::new();
    let system_prompt = String::from("You are a helpful, respectful and honest assistant. Always answer as short as possible, while being safe.");

    loop {
        println!("USER:");
        let input = read_input();
        if saved_prompt.is_empty() {
            saved_prompt = format!(
                "[INST] <<SYS>> {} <</SYS>> {} [/INST]",
                system_prompt, input
            );
        } else {
            saved_prompt = format!("{} [INST] {} [/INST]", saved_prompt, input);
        }

        // Set the prompt to the input tensor.
        set_data_to_context(&mut context, saved_prompt.as_bytes().to_vec())
            .expect("Failed to set input");

        // Execute the inference.
        let mut reset_prompt = false;
        match context.compute() {
            Ok(_) => (),
            Err(Error::BackendError(BackendError::ContextFull)) => {
                println!("\n[INFO] Context full, we'll reset the context and continue.");
                reset_prompt = true;
            }
            Err(Error::BackendError(BackendError::PromptTooLong)) => {
                println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
                reset_prompt = true;
            }
            Err(err) => {
                println!("\n[ERROR] {}", err);
            }
        }

        // Retrieve the output.
        let mut output = get_output_from_context(&context);
        println!("ASSISTANT:\n{}", output.trim());

        // Update the saved prompt.
        if reset_prompt {
            saved_prompt.clear();
        } else {
            output = output.trim().to_string();
            saved_prompt = format!("{} {}", saved_prompt, output);
        }
    }
}
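
As a usage note, `get_options_from_env` reads its options from environment variables, which the CI step passes with wasmedge's `--env` flag. The sketch below combines that flag with the RPC invocation from the README's Execute section; it is not part of the commit, and the `enable_log` and `n_gpu_layers` values are placeholders:

```console
# Start the RPC server with the model preloaded (as in the Execute section).
$ wasi_nn_rpcserver --nn-rpc-uri unix://$PWD/nn_server.sock \
    --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf

# Run the example interactively; enable_log and n_gpu_layers are read by
# get_options_from_env() inside the module (values here are placeholders).
$ wasmedge \
    --env enable_log=true \
    --env n_gpu_layers=35 \
    --nn-rpc-uri unix://$PWD/nn_server.sock \
    wasmedge-ggml-nnrpc.wasm default
```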
2.15 MB binary file not shown.

0 commit comments
