
Commit 9cf37f2

dm4 authored and hydai committed
[Example] ggml: add multimodel example with CI
Signed-off-by: dm4 <dm4@secondstate.io>
1 parent c573af5 commit 9cf37f2


5 files changed: +276 -0 lines changed


.github/workflows/llama.yml

Lines changed: 18 additions & 0 deletions
@@ -114,6 +114,24 @@ jobs:
           default \
           'def print_hello_world():'
 
+      - name: Multiple Models Example
+        run: |
+          test -f ~/.wasmedge/env && source ~/.wasmedge/env
+          cd wasmedge-ggml/multimodel
+          curl -LO https://huggingface.co/second-state/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
+          curl -LO https://huggingface.co/cmp-nct/llava-1.6-gguf/resolve/main/vicuna-7b-q5_k.gguf
+          curl -LO https://huggingface.co/cmp-nct/llava-1.6-gguf/resolve/main/mmproj-vicuna7b-f16.gguf
+          curl -LO https://llava-vl.github.io/static/images/monalisa.jpg
+          cargo build --target wasm32-wasi --release
+          time wasmedge --dir .:. \
+            --env n_gpu_layers="$NGL" \
+            --env image=monalisa.jpg \
+            --env mmproj=mmproj-vicuna7b-f16.gguf \
+            --nn-preload llama2:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
+            --nn-preload llava:GGML:AUTO:vicuna-7b-q5_k.gguf \
+            target/wasm32-wasi/release/wasmedge-ggml-multimodel.wasm \
+            'describe this picture please'
+
       - name: Build llama-stream
         run: |
           cd wasmedge-ggml/llama-stream
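
The `--env` values passed in this CI step are read back inside the guest and forwarded to the GGML plugin as a JSON config. The following is a condensed sketch of that mapping, based on the `get_options_from_env` helper added in this commit (the full version also handles `enable_log` and exits on missing required values):

```rust
use serde_json::{json, Value};
use std::env;

// Sketch: the host's `--env image=...`, `--env mmproj=...`, and `--env n_gpu_layers=...`
// become fields of the JSON options string handed to the GGML backend.
fn options_from_env() -> Value {
    let mut options = json!({});
    options["mmproj"] = Value::from(env::var("mmproj").expect("mmproj is required"));
    options["image"] = Value::from(env::var("image").expect("image is required"));
    options["n-gpu-layers"] = env::var("n_gpu_layers")
        .ok()
        .and_then(|v| v.parse::<u64>().ok())
        .map(Value::from)
        .unwrap_or_else(|| Value::from(0));
    options
}
```
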
wasmedge-ggml/multimodel/Cargo.toml

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
[package]
name = "wasmedge-ggml-multimodel"
version = "0.1.0"
edition = "2021"

[dependencies]
serde_json = "1.0"
wasmedge-wasi-nn = "0.7.0"

wasmedge-ggml/multimodel/README.md

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# Multiple Models Example For WASI-NN with GGML Backend

> [!NOTE]
> Please refer to the [wasmedge-ggml/README.md](../README.md) for the general introduction and the setup of the WASI-NN plugin with the GGML backend. This document focuses on chaining the results of multiple models.

In this example, we ask the `Llava` model a question about an image and then pass its answer to the `Llama2` model for a follow-up response. The example demonstrates how to use the WasmEdge WASI-NN plugin to link two or more models together, as sketched below.

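The following is a condensed sketch of that chaining flow, adapted from the `main.rs` added in this commit. The `mmproj`, `image`, and `ctx-size` options that the full program passes via `.config(...)`, as well as its handling of `ContextFull`/`PromptTooLong` errors, are omitted here for brevity.

```rust
use wasmedge_wasi_nn::{
    Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
};

// Read one output string from an execution context (same approach as
// get_data_from_context in the full main.rs).
fn output_of(ctx: &GraphExecutionContext) -> String {
    let mut buf = vec![0u8; 4096 * 6];
    let n = ctx.get_output(0, &mut buf).expect("Failed to get output");
    String::from_utf8_lossy(&buf[..n.min(buf.len())]).to_string()
}

// Ask llava about the image, then hand its answer to llama2 for a summary.
fn chain_models(system_prompt: &str, question: &str) -> Result<(), Error> {
    // Both graphs are preloaded by the host via --nn-preload <name>:GGML:AUTO:<file>.
    let llava = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .build_from_cache("llava")?;
    let llama2 = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .build_from_cache("llama2")?;
    let mut llava_ctx = llava.init_execution_context()?;
    let mut llama2_ctx = llama2.init_execution_context()?;

    // 1. Llava answers the question about the image; the <image> placeholder is
    //    resolved from the `image` option by the GGML backend.
    let llava_prompt = format!("{}\nUSER:<image>\n{}\nASSISTANT:", system_prompt, question);
    llava_ctx.set_input(0, TensorType::U8, &[1], llava_prompt.as_bytes())?;
    llava_ctx.compute()?;
    let answer = output_of(&llava_ctx);

    // 2. Llama2 post-processes llava's answer (here, a one-sentence summary).
    let llama2_prompt = format!(
        "[INST] <<SYS>> {} <</SYS>> Summarize the following text in 1 sentence: {} [/INST]",
        system_prompt,
        answer.trim()
    );
    llama2_ctx.set_input(0, TensorType::U8, &[1], llama2_prompt.as_bytes())?;
    llama2_ctx.compute()?;
    println!("ASSISTANT (llama2):\n{}", output_of(&llama2_ctx).trim());
    Ok(())
}
```
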
## Get the Model

This example uses the `Llama2` model and `Llava` model. You can download the models from the following links:

```bash
curl -LO https://huggingface.co/second-state/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
curl -LO https://huggingface.co/cmp-nct/llava-1.6-gguf/resolve/main/vicuna-7b-q5_k.gguf
curl -LO https://huggingface.co/cmp-nct/llava-1.6-gguf/resolve/main/mmproj-vicuna7b-f16.gguf
```

## Parameters

> [!NOTE]
> Please check the parameters section of [wasmedge-ggml/README.md](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters) first.

Download the image for the Llava model:

```bash
curl -LO https://llava-vl.github.io/static/images/monalisa.jpg
```

## Execute

Execute the WASM with `wasmedge`, using the named model feature to preload the two large models:

```console
$ wasmedge --dir .:. \
  --env image=monalisa.jpg \
  --env mmproj=mmproj-vicuna7b-f16.gguf \
  --nn-preload llama2:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
  --nn-preload llava:GGML:AUTO:vicuna-7b-q5_k.gguf \
  wasmedge-ggml-multimodel.wasm

USER:
describe this picture please
ASSISTANT (llava):
The image you've provided appears to be a painting of the Mona Lisa, one of Leonardo da Vinci's most famous works. It is a portrait of a woman with a serene and enigmatic expression, looking directly at the viewer. Her hair is styled in an updo, and she wears a dark dress that drapes elegantly around her shoulders. The background features a landscape with rolling hills and a river, which adds depth to the composition. The painting is renowned for its subtle changes in expression and the enigmatic smile on the subject's face, which has intrigued viewers for centuries.
ASSISTANT (llama2):
The image provided is a painting of the Mona Lisa, one of Leonardo da Vinci's most famous works, depicting a woman with a serene and enigmatic expression, styled in an updo with a dark dress draped elegantly around her shoulders, set against a landscape background with rolling hills and a river.
```
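
The names `llama2` and `llava` given to `--nn-preload` are the handles the guest program uses to look the preloaded models up again; in this commit that lookup is done with `build_from_cache`. A minimal sketch of the mapping (it assumes the crate's `Graph` type as the return value; the option values are the ones used above):

```rust
use wasmedge_wasi_nn::{Error, ExecutionTarget, Graph, GraphBuilder, GraphEncoding};

// "llava" must match the name given on the host side:
//   --nn-preload llava:GGML:AUTO:vicuna-7b-q5_k.gguf
fn load_preloaded_llava() -> Result<Graph, Error> {
    let options = serde_json::json!({
        "mmproj": "mmproj-vicuna7b-f16.gguf",
        "image": "monalisa.jpg",
    });
    GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .config(options.to_string())
        .build_from_cache("llava")
}
```
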

wasmedge-ggml/multimodel/src/main.rs

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
1+
use serde_json::json;
2+
use serde_json::Value;
3+
use std::env;
4+
use std::io;
5+
use wasmedge_wasi_nn::{
6+
self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
7+
TensorType,
8+
};
9+
10+
fn read_input() -> String {
11+
loop {
12+
let mut answer = String::new();
13+
io::stdin()
14+
.read_line(&mut answer)
15+
.expect("Failed to read line");
16+
if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
17+
return answer.trim().to_string();
18+
}
19+
}
20+
}
21+
22+
fn get_options_from_env() -> Value {
23+
let mut options = json!({});
24+
25+
// Required parameters for llava
26+
if let Ok(val) = env::var("mmproj") {
27+
options["mmproj"] = Value::from(val.as_str());
28+
} else {
29+
eprintln!("Failed to get mmproj model.");
30+
std::process::exit(1);
31+
}
32+
if let Ok(val) = env::var("image") {
33+
options["image"] = Value::from(val.as_str());
34+
} else {
35+
eprintln!("Failed to get the target image.");
36+
std::process::exit(1);
37+
}
38+
39+
// Optional parameters
40+
if let Ok(val) = env::var("enable_log") {
41+
options["enable-log"] = serde_json::from_str(val.as_str())
42+
.expect("invalid value for enable-log option (true/false)")
43+
} else {
44+
options["enable-log"] = serde_json::from_str("false").unwrap()
45+
}
46+
if let Ok(val) = env::var("n_gpu_layers") {
47+
options["n-gpu-layers"] =
48+
serde_json::from_str(val.as_str()).expect("invalid ngl value (unsigned integer")
49+
} else {
50+
options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
51+
}
52+
53+
options
54+
}
55+
56+
fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
57+
context.set_input(0, TensorType::U8, &[1], &data)
58+
}
59+
60+
#[allow(dead_code)]
61+
fn set_metadata_to_context(
62+
context: &mut GraphExecutionContext,
63+
data: Vec<u8>,
64+
) -> Result<(), Error> {
65+
context.set_input(1, TensorType::U8, &[1], &data)
66+
}
67+
68+
fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
69+
// Preserve for 4096 tokens with average token length 6
70+
const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
71+
let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
72+
let mut output_size = context
73+
.get_output(index, &mut output_buffer)
74+
.expect("Failed to get output");
75+
output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
76+
77+
return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
78+
}
79+
80+
fn get_output_from_context(context: &GraphExecutionContext) -> String {
81+
get_data_from_context(context, 0)
82+
}
83+
84+
#[allow(dead_code)]
85+
fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
86+
serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
87+
}
88+
89+
fn main() {
90+
let args: Vec<String> = env::args().collect();
91+
92+
// Set options for the graph. Check our README for more details:
93+
// https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
94+
let mut options = get_options_from_env();
95+
// We set the temperature to 0.1 for more consistent results.
96+
options["temp"] = Value::from(0.1);
97+
// Set the context size to 4096 tokens for the llava 1.6 model.
98+
options["ctx-size"] = Value::from(4096);
99+
100+
// Create the llava model.
101+
let mut graphs = Vec::new();
102+
graphs.push(
103+
GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
104+
.config(serde_json::to_string(&options).expect("Failed to serialize options"))
105+
.build_from_cache("llava")
106+
.expect("Failed to build graph"),
107+
);
108+
109+
// Remove unnecessary options for the llama2 model.
110+
options
111+
.as_object_mut()
112+
.expect("Failed to get jsons object")
113+
.remove("mmproj");
114+
options
115+
.as_object_mut()
116+
.expect("Failed to get json object")
117+
.remove("image");
118+
// Create the llama2 model.
119+
graphs.push(
120+
GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
121+
.config(serde_json::to_string(&options).expect("Failed to serialize options"))
122+
.build_from_cache("llama2")
123+
.expect("Failed to build graph"),
124+
);
125+
126+
// Initilize the execution contexts.
127+
let mut contexts = Vec::new();
128+
contexts.push(
129+
graphs[0]
130+
.init_execution_context()
131+
.expect("Failed to init context"),
132+
);
133+
contexts.push(
134+
graphs[1]
135+
.init_execution_context()
136+
.expect("Failed to init context"),
137+
);
138+
139+
let system_prompt = String::from("You are a helpful, respectful and honest assistant.");
140+
let mut input = String::from("");
141+
142+
// If the user provides a prompt, use it.
143+
println!("USER:");
144+
if args.len() >= 2 {
145+
input += &args[1];
146+
println!("{}", input);
147+
} else {
148+
input = read_input();
149+
}
150+
151+
// Llava inference.
152+
let image_placeholder = "<image>";
153+
let mut saved_prompt = format!(
154+
"{}\nUSER:{}\n{}\nASSISTANT:",
155+
system_prompt, image_placeholder, input
156+
);
157+
set_data_to_context(&mut contexts[0], saved_prompt.as_bytes().to_vec())
158+
.expect("Failed to set input");
159+
match contexts[0].compute() {
160+
Ok(_) => (),
161+
Err(Error::BackendError(BackendError::ContextFull)) => {
162+
println!("\n[INFO] Context full, we'll reset the context and continue.");
163+
}
164+
Err(Error::BackendError(BackendError::PromptTooLong)) => {
165+
println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
166+
}
167+
Err(err) => {
168+
println!("\n[ERROR] {}", err);
169+
}
170+
}
171+
172+
// Retrieve the llava output.
173+
let mut output = get_output_from_context(&contexts[0]);
174+
println!("ASSISTANT (llava):\n{}", output.trim());
175+
176+
// Llama2 inference.
177+
let llama2_prompt = "Summarize the following text in 1 sentence:";
178+
saved_prompt = format!(
179+
"[INST] <<SYS>> {} <</SYS>> {} {} [/INST]",
180+
system_prompt,
181+
llama2_prompt,
182+
output.trim()
183+
);
184+
set_data_to_context(&mut contexts[1], saved_prompt.as_bytes().to_vec())
185+
.expect("Failed to set input");
186+
match contexts[1].compute() {
187+
Ok(_) => (),
188+
Err(Error::BackendError(BackendError::ContextFull)) => {
189+
println!("\n[INFO] Context full, we'll reset the context and continue.");
190+
}
191+
Err(Error::BackendError(BackendError::PromptTooLong)) => {
192+
println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
193+
}
194+
Err(err) => {
195+
println!("\n[ERROR] {}", err);
196+
}
197+
}
198+
199+
// Retrieve the llama2 output.
200+
output = get_output_from_context(&contexts[1]);
201+
println!("ASSISTANT (llama2):\n{}", output.trim());
202+
}
Binary file not shown (2.16 MB).

0 commit comments
