Skip to content

Commit 4dffec2

Browse files
dm4hydai
authored and committed
[Example] ggml: use wasmedge-wasi-nn crate (second-state#108)
Signed-off-by: dm4 <dm4@secondstate.io>
1 parent 6a9a245 commit 4dffec2

21 files changed

+108
-115
lines changed

‎wasmedge-ggml/chatml/Cargo.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ edition = "2021"
55

66
[dependencies]
77
serde_json = "1.0"
8-
wasi-nn = { git = "https://github.com/second-state/wasmedge-wasi-nn", branch = "ggml" }
8+
wasmedge-wasi-nn = "0.7.0"

‎wasmedge-ggml/chatml/src/main.rs‎

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use serde_json::Value;
22
use std::collections::HashMap;
33
use std::env;
44
use std::io;
5-
use wasi_nn::{self, GraphExecutionContext};
5+
use wasmedge_wasi_nn::{
6+
self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
7+
TensorType,
8+
};
69

710
fn read_input() -> String {
811
loop {
@@ -16,19 +19,16 @@ fn read_input() -> String {
1619
}
1720
}
1821

19-
fn set_data_to_context(
20-
context: &mut GraphExecutionContext,
21-
data: Vec<u8>,
22-
) -> Result<(), wasi_nn::Error> {
23-
context.set_input(0, wasi_nn::TensorType::U8, &[1], &data)
22+
fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
23+
context.set_input(0, TensorType::U8, &[1], &data)
2424
}
2525

2626
#[allow(dead_code)]
2727
fn set_metadata_to_context(
2828
context: &mut GraphExecutionContext,
2929
data: Vec<u8>,
30-
) -> Result<(), wasi_nn::Error> {
31-
context.set_input(1, wasi_nn::TensorType::U8, &[1], &data)
30+
) -> Result<(), Error> {
31+
context.set_input(1, TensorType::U8, &[1], &data)
3232
}
3333

3434
fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
@@ -63,11 +63,10 @@ fn main() {
6363
options.insert("ctx-size", Value::from(512));
6464

6565
// Create graph and initialize context.
66-
let graph =
67-
wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Ggml, wasi_nn::ExecutionTarget::AUTO)
68-
.config(serde_json::to_string(&options).expect("Failed to serialize options"))
69-
.build_from_cache(model_name)
70-
.expect("Failed to build graph");
66+
let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
67+
.config(serde_json::to_string(&options).expect("Failed to serialize options"))
68+
.build_from_cache(model_name)
69+
.expect("Failed to build graph");
7170
let mut context = graph
7271
.init_execution_context()
7372
.expect("Failed to init context");
@@ -118,11 +117,11 @@ fn main() {
118117
let mut reset_prompt = false;
119118
match context.compute() {
120119
Ok(_) => (),
121-
Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::ContextFull)) => {
120+
Err(Error::BackendError(BackendError::ContextFull)) => {
122121
println!("\n[INFO] Context full, we'll reset the context and continue.");
123122
reset_prompt = true;
124123
}
125-
Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::PromptTooLong)) => {
124+
Err(Error::BackendError(BackendError::PromptTooLong)) => {
126125
println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
127126
reset_prompt = true;
128127
}
-361 Bytes
Binary file not shown.

‎wasmedge-ggml/embedding/Cargo.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ edition = "2021"
55

66
[dependencies]
77
serde_json = "1.0"
8-
wasi-nn = { git = "https://github.com/second-state/wasmedge-wasi-nn", branch = "ggml" }
8+
wasmedge-wasi-nn = "0.7.0"

‎wasmedge-ggml/embedding/src/main.rs‎

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use serde_json::{json, Value};
22
use std::env;
33
use std::io::{self};
4-
use wasi_nn::{self, GraphExecutionContext};
4+
use wasmedge_wasi_nn::{
5+
self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
6+
TensorType,
7+
};
58

69
fn read_input() -> String {
710
loop {
@@ -33,19 +36,16 @@ fn get_options_from_env() -> Value {
3336
options
3437
}
3538

36-
fn set_data_to_context(
37-
context: &mut GraphExecutionContext,
38-
data: Vec<u8>,
39-
) -> Result<(), wasi_nn::Error> {
40-
context.set_input(0, wasi_nn::TensorType::U8, &[1], &data)
39+
fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
40+
context.set_input(0, TensorType::U8, &[1], &data)
4141
}
4242

4343
#[allow(dead_code)]
4444
fn set_metadata_to_context(
4545
context: &mut GraphExecutionContext,
4646
data: Vec<u8>,
47-
) -> Result<(), wasi_nn::Error> {
48-
context.set_input(1, wasi_nn::TensorType::U8, &[1], &data)
47+
) -> Result<(), Error> {
48+
context.set_input(1, TensorType::U8, &[1], &data)
4949
}
5050

5151
fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
@@ -77,11 +77,10 @@ fn main() {
7777
options["embedding"] = serde_json::Value::Bool(true);
7878

7979
// Create graph and initialize context.
80-
let graph =
81-
wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Ggml, wasi_nn::ExecutionTarget::AUTO)
82-
.config(options.to_string())
83-
.build_from_cache(model_name)
84-
.expect("Create GraphBuilder Failed, please check the model name or options");
80+
let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
81+
.config(options.to_string())
82+
.build_from_cache(model_name)
83+
.expect("Create GraphBuilder Failed, please check the model name or options");
8584
let mut context = graph
8685
.init_execution_context()
8786
.expect("Init Context Failed, please check the model");
@@ -98,7 +97,7 @@ fn main() {
9897
println!("Prompt:\n{}", prompt);
9998
let tensor_data = prompt.as_bytes().to_vec();
10099
context
101-
.set_input(0, wasi_nn::TensorType::U8, &[1], &tensor_data)
100+
.set_input(0, TensorType::U8, &[1], &tensor_data)
102101
.unwrap();
103102
println!("Raw Embedding Output:");
104103
context.compute().unwrap();
@@ -139,10 +138,10 @@ fn main() {
139138

140139
match context.compute() {
141140
Ok(_) => (),
142-
Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::ContextFull)) => {
141+
Err(Error::BackendError(BackendError::ContextFull)) => {
143142
println!("\n[INFO] Context full");
144143
}
145-
Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::PromptTooLong)) => {
144+
Err(Error::BackendError(BackendError::PromptTooLong)) => {
146145
println!("\n[INFO] Prompt too long");
147146
}
148147
Err(err) => {
137 KB
Binary file not shown.

‎wasmedge-ggml/gemma/Cargo.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ edition = "2021"
55

66
[dependencies]
77
serde_json = "1.0"
8-
wasi-nn = { git = "https://github.com/second-state/wasmedge-wasi-nn", branch = "ggml" }
8+
wasmedge-wasi-nn = "0.7.0"

‎wasmedge-ggml/gemma/src/main.rs‎

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use serde_json::json;
22
use serde_json::Value;
33
use std::env;
44
use std::io;
5-
use wasi_nn::{self, GraphExecutionContext};
5+
use wasmedge_wasi_nn::{
6+
self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
7+
TensorType,
8+
};
69

710
fn read_input() -> String {
811
loop {
@@ -41,19 +44,16 @@ fn get_options_from_env() -> Value {
4144
options
4245
}
4346

44-
fn set_data_to_context(
45-
context: &mut GraphExecutionContext,
46-
data: Vec<u8>,
47-
) -> Result<(), wasi_nn::Error> {
48-
context.set_input(0, wasi_nn::TensorType::U8, &[1], &data)
47+
fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
48+
context.set_input(0, TensorType::U8, &[1], &data)
4949
}
5050

5151
#[allow(dead_code)]
5252
fn set_metadata_to_context(
5353
context: &mut GraphExecutionContext,
5454
data: Vec<u8>,
55-
) -> Result<(), wasi_nn::Error> {
56-
context.set_input(1, wasi_nn::TensorType::U8, &[1], &data)
55+
) -> Result<(), Error> {
56+
context.set_input(1, TensorType::U8, &[1], &data)
5757
}
5858

5959
fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
@@ -86,11 +86,10 @@ fn main() {
8686
let options = get_options_from_env();
8787

8888
// Create graph and initialize context.
89-
let graph =
90-
wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Ggml, wasi_nn::ExecutionTarget::AUTO)
91-
.config(serde_json::to_string(&options).expect("Failed to serialize options"))
92-
.build_from_cache(model_name)
93-
.expect("Failed to build graph");
89+
let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
90+
.config(serde_json::to_string(&options).expect("Failed to serialize options"))
91+
.build_from_cache(model_name)
92+
.expect("Failed to build graph");
9493
let mut context = graph
9594
.init_execution_context()
9695
.expect("Failed to init context");
@@ -113,7 +112,7 @@ fn main() {
113112
println!("Prompt:\n{}", prompt);
114113
let tensor_data = prompt.as_bytes().to_vec();
115114
context
116-
.set_input(0, wasi_nn::TensorType::U8, &[1], &tensor_data)
115+
.set_input(0, TensorType::U8, &[1], &tensor_data)
117116
.expect("Failed to set input");
118117
println!("Response:");
119118
context.compute().expect("Failed to compute");
@@ -160,11 +159,11 @@ fn main() {
160159
println!("ASSISTANT:");
161160
match context.compute() {
162161
Ok(_) => (),
163-
Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::ContextFull)) => {
162+
Err(Error::BackendError(BackendError::ContextFull)) => {
164163
println!("\n[INFO] Context full, we'll reset the context and continue.");
165164
reset_prompt = true;
166165
}
167-
Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::PromptTooLong)) => {
166+
Err(Error::BackendError(BackendError::PromptTooLong)) => {
168167
println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
169168
reset_prompt = true;
170169
}
-5.2 KB
Binary file not shown.

‎wasmedge-ggml/llama-stream/Cargo.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ edition = "2021"
55

66
[dependencies]
77
serde_json = "1.0"
8-
wasi-nn = { git = "https://github.com/second-state/wasmedge-wasi-nn", branch = "ggml" }
8+
wasmedge-wasi-nn = "0.7.0"

0 commit comments

Comments (0)