
Commit 56d497a

[Example] Give an example of llama tflite (second-state#165)
* add llama
* add llama 2
* fix encoding
* fix error
* Update README.md
* Update Cargo.toml
1 parent d75c738 commit 56d497a

File tree

7 files changed (+307, -4 lines changed)


openvino-mobilenet-image/rust/Cargo.toml (1 addition, 1 deletion)

```diff
@@ -8,6 +8,6 @@ publish = false
 
 [dependencies]
 image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
-wasi-nn = { version = "0.4.0" }
+wasi-nn = { version = "0.6.0" }
 
 [workspace]
```

openvino-mobilenet-raw/rust/Cargo.toml (1 addition, 1 deletion)

```diff
@@ -7,6 +7,6 @@ edition = "2021"
 publish = false
 
 [dependencies]
-wasi-nn = { version = "0.4.0" }
+wasi-nn = { version = "0.6.0" }
 
 [workspace]
```

openvino-road-segmentation-adas/openvino-road-seg-adas/Cargo.toml (1 addition, 1 deletion)

```diff
@@ -5,5 +5,5 @@ name = "openvino-road-seg-adas"
 version = "0.2.0"
 
 [dependencies]
-wasi-nn = "0.4.0"
+wasi-nn = "0.6.0"
 image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
```

tflite-birds_v1-image/rust/Cargo.toml (1 addition, 1 deletion)

```diff
@@ -8,6 +8,6 @@ publish = false
 
 [dependencies]
 image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
-wasi-nn = "0.4.0"
+wasi-nn = "0.6.0"
 
 [workspace]
```

wasmedge-tf-llama/README.md (new file, 59 additions)
# Llama Example For WasmEdge-TensorFlow plug-in

This package is a high-level Rust binding example for the [WasmEdge-TensorFlow plug-in](https://wasmedge.org/docs/develop/rust/tensorflow), running a Llama model in TFLite format.

## Dependencies

This crate depends on `wasmedge_tensorflow_interface` in the `Cargo.toml`:

```toml
[dependencies]
wasmedge_tensorflow_interface = "0.3.0"
wasi-nn = "0.1.0" # Ensure you use the latest version
thiserror = "1.0"
bytemuck = "1.13.1"
log = "0.4.19"
env_logger = "0.10.0"
anyhow = "1.0.79"
```

## Build

Compile the application to WebAssembly:

```bash
cd rust && cargo build --target=wasm32-wasip1 --release
```

The output WASM file will be at [`rust/target/wasm32-wasip1/release/wasmedge-tf-example-llama.wasm`](wasmedge-tf-example-llama.wasm).
To speed up execution, we can enable the AOT mode in WasmEdge with:

```bash
wasmedge compile rust/target/wasm32-wasip1/release/wasmedge-tf-example-llama.wasm wasmedge-tf-example-llama_aot.wasm
```

## Run

The frozen `tflite` model should be translated from a HuggingFace checkpoint through `ai_edge_torch`; a hypothetical sketch of that step follows.
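
The commit does not include the conversion script itself, so the following is a minimal, untested sketch of the translation step, assuming the basic `ai_edge_torch.convert` API. The checkpoint name, sequence length, and output filename are placeholders rather than the exact settings behind `llama_1b_q8_ekv1280.tflite`; a production Llama conversion would go through `ai_edge_torch`'s generative re-authoring and quantization utilities.

```python
# Hypothetical sketch: translate a HuggingFace PyTorch checkpoint to TFLite.
# The checkpoint name, sequence length, and output path are placeholders.
import torch
import ai_edge_torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model.eval()

# Trace with a fixed-shape token-id tensor, mirroring the fixed-length
# input the Rust example feeds to the model at inference time.
sample_args = (torch.zeros((1, 128), dtype=torch.long),)

edge_model = ai_edge_torch.convert(model, sample_args)
edge_model.export("llama_1b.tflite")
```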

### Execute

Users should [install WasmEdge with the WasmEdge-TensorFlow and WasmEdge-Image plug-ins](https://wasmedge.org/docs/start/install#wasmedge-tensorflow-plug-in):

```bash
curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- --plugins wasmedge_tensorflow wasmedge_image
```

Execute the WASM with `wasmedge` with TensorFlow Lite support:

```bash
wasmedge --dir .:. wasmedge-tf-example-llama.wasm ./llama_1b_q8_ekv1280.tflite
```

You will get the output:

```console
Input the Chatbot:
Hello world
Hello world!
```

wasmedge-tf-llama/rust/Cargo.toml (new file, 17 additions)
```toml
[package]
name = "wasmedge-tf-example-llama"
version = "0.1.0"
authors = ["victoryang00"]
readme = "README.md"
edition = "2021"
publish = false

[dependencies]
wasmedge_tensorflow_interface = "0.3.0"
wasi-nn = "0.1.0" # Ensure you use the latest version
thiserror = "1.0"
bytemuck = "1.13.1"
log = "0.4.19"
env_logger = "0.10.0"
anyhow = "1.0.79"

[workspace]
```

wasmedge-tf-llama/rust/src/main.rs (new file, 227 additions)
```rust
use bytemuck::{cast_slice, cast_slice_mut};
use std::collections::HashMap;
use std::env;
use std::fs::File;
use std::io::Read;
use std::io::{self, BufRead, Write};
use std::process;
use thiserror::Error;

// Define a custom error type to handle multiple error sources
#[derive(Error, Debug)]
pub enum ChatbotError {
    #[error("IO Error: {0}")]
    Io(#[from] std::io::Error),

    #[error("WASI NN Error: {0}")]
    WasiNn(String),

    #[error("Other Error: {0}")]
    Other(String),
}

impl From<wasi_nn::Error> for ChatbotError {
    fn from(error: wasi_nn::Error) -> Self {
        ChatbotError::WasiNn(error.to_string())
    }
}

type Result<T> = std::result::Result<T, ChatbotError>;

// Tokenizer struct: a toy whitespace tokenizer that grows its vocabulary
// as it encounters new words.
struct Tokenizer {
    vocab: HashMap<String, i32>,
    vocab_reverse: HashMap<i32, String>,
    next_id: i32,
}

impl Tokenizer {
    fn new(initial_vocab: Vec<(&str, i32)>) -> Self {
        let mut vocab = HashMap::new();
        let mut vocab_reverse = HashMap::new();
        let mut next_id = 1;

        for (word, id) in initial_vocab {
            vocab.insert(word.to_string(), id);
            vocab_reverse.insert(id, word.to_string());
            if id >= next_id {
                next_id = id + 1;
            }
        }

        // Add special tokens
        vocab.insert("<UNK>".to_string(), 0);
        vocab_reverse.insert(0, "<UNK>".to_string());
        vocab.insert("<PAD>".to_string(), -1);
        vocab_reverse.insert(-1, "<PAD>".to_string());

        Tokenizer {
            vocab,
            vocab_reverse,
            next_id,
        }
    }

    fn tokenize(&mut self, input: &str) -> Vec<i32> {
        input
            .split_whitespace()
            .map(|word| {
                self.vocab.get(word).cloned().unwrap_or_else(|| {
                    let id = self.next_id;
                    self.vocab.insert(word.to_string(), id);
                    self.vocab_reverse.insert(id, word.to_string());
                    self.next_id += 1;
                    id
                })
            })
            .collect()
    }

    fn tokenize_with_fixed_length(&mut self, input: &str, max_length: usize) -> Vec<i32> {
        let mut tokens = self.tokenize(input);

        if tokens.len() > max_length {
            tokens.truncate(max_length);
        } else if tokens.len() < max_length {
            tokens.extend(vec![-1; max_length - tokens.len()]); // -1 is the <PAD> token
        }

        tokens
    }

    fn detokenize(&self, tokens: &[i32]) -> String {
        tokens
            .iter()
            .map(|&token| {
                self.vocab_reverse
                    .get(&token)
                    .map_or("<UNK>", |v| v.as_str())
            })
            .collect::<Vec<&str>>()
            .join(" ")
    }
}

// Function to load the TFLite model
fn load_model(model_path: &str) -> Result<wasi_nn::Graph> {
    let mut file = File::open(model_path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;
    let model_segments = &[&buffer[..]];
    // Encoding 4 is TensorFlow Lite in the wasi-nn graph-encoding enumeration.
    Ok(unsafe { wasi_nn::load(model_segments, 4, wasi_nn::EXECUTION_TARGET_CPU)? })
}

// Function to initialize the execution context
fn init_context(graph: wasi_nn::Graph) -> Result<wasi_nn::GraphExecutionContext> {
    Ok(unsafe { wasi_nn::init_execution_context(graph)? })
}

fn main() -> Result<()> {
    // Parse command-line arguments
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: {} <model_file>", args[0]);
        process::exit(1);
    }
    let model_file = &args[1];

    // Load the model
    let graph = load_model(model_file)?;

    // Initialize execution context
    let ctx = init_context(graph)?;

    let mut stdout = io::stdout();
    let stdin = io::stdin();

    println!("Chatbot is ready! Type your messages below:");

    for line in stdin.lock().lines() {
        let user_input = line?;
        if user_input.trim().is_empty() {
            continue;
        }
        if user_input.to_lowercase() == "exit" {
            break;
        }
        // Initialize tokenizer with a small seed vocabulary
        let initial_vocab = vec![
            ("hello", 1),
            ("world", 2),
            ("this", 3),
            ("is", 4),
            ("a", 5),
            ("test", 6),
            ("<PAD>", -1),
        ];

        let mut tokenizer = Tokenizer::new(initial_vocab);

        // Tokenize with fixed length
        let max_length = 655360;
        let tokens = tokenizer.tokenize_with_fixed_length(&user_input, max_length);
        let tokens_dims = &[1u32, max_length as u32];
        let tokens_tensor = wasi_nn::Tensor {
            dimensions: tokens_dims,
            r#type: wasi_nn::TENSOR_TYPE_I32,
            data: cast_slice(&tokens),
        };

        // Create input_pos tensor: one position index per token slot
        let input_pos: Vec<i32> = (0..max_length as i32).collect();
        let input_pos_dims = &[1u32, max_length as u32];
        let input_pos_tensor = wasi_nn::Tensor {
            dimensions: input_pos_dims,
            r#type: wasi_nn::TENSOR_TYPE_I32,
            data: cast_slice(&input_pos),
        };

        // Create the KV-cache tensor; 32 * 2 * 1 * 16 * 10 * 64 = 655,360
        // elements, so kv_data matches kv_dims below.
        let kv_data = vec![0.0_f32; max_length]; // zero-initialized cache
        let kv_dims = &[32u32, 2u32, 1u32, 16u32, 10u32, 64u32];
        let kv_tensor = wasi_nn::Tensor {
            dimensions: kv_dims,
            r#type: wasi_nn::TENSOR_TYPE_F32,
            data: cast_slice(&kv_data),
        };

        // Set inputs
        unsafe {
            wasi_nn::set_input(ctx, 0, tokens_tensor)?;
            wasi_nn::set_input(ctx, 1, input_pos_tensor)?;
            wasi_nn::set_input(ctx, 2, kv_tensor)?;
        }

        // Run inference
        run_inference(&ctx)?;
        // Get output
        let output = get_model_output(&ctx, 0, 655360)?;

        // Detokenize output
        let response = tokenizer.detokenize(&output);

        // Display response
        writeln!(stdout, "Bot: {}", response)?;
    }

    println!("Chatbot session ended.");
    Ok(())
}

// Function to run inference
fn run_inference(ctx: &wasi_nn::GraphExecutionContext) -> Result<()> {
    unsafe { Ok(wasi_nn::compute(*ctx)?) }
}

// Function to get model output
fn get_model_output(
    ctx: &wasi_nn::GraphExecutionContext,
    index: u32,
    size: usize,
) -> Result<Vec<i32>> {
    let mut buffer = vec![0i32; size];
    // View the i32 buffer as raw bytes for the wasi-nn ABI.
    let buffer_ptr = cast_slice_mut::<i32, u8>(&mut buffer).as_mut_ptr();
    let byte_len = (size * std::mem::size_of::<i32>()) as u32;
    unsafe { wasi_nn::get_output(*ctx, index, buffer_ptr, byte_len)? };
    Ok(buffer)
}
```
