Skip to content

Commit b6c8ab2

Browse files
authored
Refactor the examples to the plugin system (second-state#22)
* Refactor the TFLite example * Use WasmEdge 0.13 and wasi-nn 0.4. * Add GitHub Actions CI Part of Issue second-state#21 Signed-off-by: Michael Yuan <michael@secondstate.io> * Add trigger to all branches Signed-off-by: Michael Yuan <michael@secondstate.io> --------- Signed-off-by: Michael Yuan <michael@secondstate.io>
1 parent a33bc37 commit b6c8ab2

File tree

9 files changed

+74
-99
lines changed

9 files changed

+74
-99
lines changed

‎.github/workflows/tflite.yml‎

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
name: Build and Test TFlite examples
2+
3+
on:
4+
workflow_dispatch:
5+
inputs:
6+
logLevel:
7+
description: 'Log level'
8+
required: true
9+
default: 'info'
10+
push:
11+
branches: [ '*' ]
12+
pull_request:
13+
branches: [ '*' ]
14+
15+
jobs:
16+
build:
17+
18+
runs-on: ubuntu-20.04
19+
20+
steps:
21+
- uses: actions/checkout@v2
22+
23+
- name: Install apt-get packages
24+
run: |
25+
sudo ACCEPT_EULA=Y apt-get update
26+
sudo ACCEPT_EULA=Y apt-get upgrade
27+
sudo apt-get install wget git curl software-properties-common build-essential
28+
29+
- name: Install Rust target for wasm
30+
run: |
31+
rustup target add wasm32-wasi
32+
33+
- name: Install WasmEdge + WASI-NN + TFLite
34+
run: |
35+
VERSION=0.13.1
36+
TFVERSION=2.12.0
37+
curl -s -L -O --remote-name-all https://github.com/second-state/WasmEdge-tensorflow-deps/releases/download/TF-2.12.0-CC/WasmEdge-tensorflow-deps-TFLite-TF-$TFVERSION-CC-manylinux2014_x86_64.tar.gz
38+
tar -zxf WasmEdge-tensorflow-deps-TFLite-TF-$TFVERSION-CC-manylinux2014_x86_64.tar.gz
39+
rm -f WasmEdge-tensorflow-deps-TFLite-TF-$TFVERSION-CC-manylinux2014_x86_64.tar.gz
40+
sudo mv libtensorflowlite_c.so /usr/local/lib
41+
sudo mv libtensorflowlite_flex.so /usr/local/lib
42+
curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | sudo bash -s -- -v $VERSION --plugins wasi_nn-tensorflowlite -p /usr/local
43+
44+
- name: Example
45+
run: |
46+
cd tflite-birds_v1-image/rust
47+
cargo build --target wasm32-wasi --release
48+
cd ..
49+
wasmedge compile rust/target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm wasmedge-wasinn-example-tflite-bird-image.wasm
50+
wasmedge --dir .:. wasmedge-wasinn-example-tflite-bird-image.wasm lite-model_aiy_vision_classifier_birds_V1_3.tflite bird.jpg
51+

‎tflite-birds_v1-image/README.md‎

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,42 +10,34 @@ This crate depends on the `wasi-nn` in the `Cargo.toml`:
1010

1111
```toml
1212
[dependencies]
13-
wasi-nn = "0.3.0"
13+
wasi-nn = "0.4.0"
1414
```
1515

1616
## Build
1717

1818
Compile the application to WebAssembly:
1919

2020
```bash
21-
cd rust/tflite-bird && cargo build --target=wasm32-wasi --release
21+
cd rust && cargo build --target=wasm32-wasi --release
2222
```
2323

24-
The output WASM file will be at [`rust/tflite-bird/target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm`](wasmedge-wasinn-example-tflite-bird-image.wasm).
24+
The output WASM file will be at [`rust/target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm`](wasmedge-wasinn-example-tflite-bird-image.wasm).
2525
To speed up the image processing, we can enable the AOT mode in WasmEdge with:
2626

2727
```bash
28-
wasmedgec rust/tflite-bird/target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm wasmedge-wasinn-example-tflite-bird-image.wasm
28+
wasmedge compile rust/target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm wasmedge-wasinn-example-tflite-bird-image.wasm
2929
```
3030

3131
## Run
3232

33-
### Download fixture
33+
### Test data
3434

3535
The testing image is located at `./bird.jpg`:
3636

3737
![Aix galericulata](bird.jpg)
3838

3939
The `tflite` model is located at `./lite-model_aiy_vision_classifier_birds_V1_3.tflite`
4040

41-
### Generate Image Tensor
42-
43-
If you want to generate the [raw](birdx224x224x3.rgb) tensor, you can run:
44-
45-
```shell
46-
cd rust/image-converter/ && cargo run ../../bird.jpg ../../birdx224x224x3.rgb
47-
```
48-
4941
### Execute
5042

5143
Users should [install the WasmEdge with WASI-NN TensorFlow-Lite backend plug-in](https://wasmedge.org/book/en/write_wasm/rust/wasinn.html#get-wasmedge-with-wasi-nn-plug-in-tensorflow-lite-backend).

‎tflite-birds_v1-image/birdx224x224x3.rgb‎

Lines changed: 0 additions & 11 deletions
This file was deleted.

‎tflite-birds_v1-image/rust/tflite-bird/Cargo.toml‎ renamed to ‎tflite-birds_v1-image/rust/Cargo.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ publish = false
88

99
[dependencies]
1010
image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "pnm", "tga", "tiff", "webp", "bmp", "hdr", "dxt", "dds", "farbfeld"] }
11-
wasi-nn = { version = "0.3.0" }
11+
wasi-nn = "0.4.0"
1212

1313
[workspace]

‎tflite-birds_v1-image/rust/image-converter/Cargo.toml‎

Lines changed: 0 additions & 9 deletions
This file was deleted.

‎tflite-birds_v1-image/rust/image-converter/src/main.rs‎

Lines changed: 0 additions & 28 deletions
This file was deleted.

‎tflite-birds_v1-image/rust/tflite-bird/src/main.rs‎ renamed to ‎tflite-birds_v1-image/rust/src/main.rs‎

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,38 @@
11
use image::io::Reader;
22
use image::DynamicImage;
3-
use std::convert::TryInto;
43
use std::env;
54
use std::fs;
6-
use wasi_nn;
5+
use std::error::Error;
6+
use wasi_nn::{GraphBuilder, GraphEncoding, ExecutionTarget, TensorType};
77
mod imagenet_classes;
88

9-
pub fn main() {
9+
pub fn main() -> Result<(), Box<dyn Error>> {
1010
let args: Vec<String> = env::args().collect();
1111
let model_bin_name: &str = &args[1];
1212
let image_name: &str = &args[2];
1313

14-
let weights = fs::read(model_bin_name).unwrap();
14+
let weights = fs::read(model_bin_name)?;
1515
println!("Read graph weights, size in bytes: {}", weights.len());
1616

17-
let graph = unsafe {
18-
wasi_nn::load(
19-
&[&weights],
20-
wasi_nn::GRAPH_ENCODING_TENSORFLOWLITE,
21-
wasi_nn::EXECUTION_TARGET_CPU,
22-
)
23-
.unwrap()
24-
};
17+
let graph = GraphBuilder::new(GraphEncoding::TensorflowLite, ExecutionTarget::CPU).build_from_bytes(&[&weights])?;
18+
let mut ctx = graph.init_execution_context()?;
2519
println!("Loaded graph into wasi-nn with ID: {}", graph);
2620

27-
let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
28-
println!("Created wasi-nn execution context with ID: {}", context);
29-
3021
// Load a tensor that precisely matches the graph input tensor (see
3122
let tensor_data = image_to_tensor(image_name.to_string(), 224, 224);
3223
println!("Read input tensor, size in bytes: {}", tensor_data.len());
33-
let tensor = wasi_nn::Tensor {
34-
dimensions: &[1, 224, 224, 3],
35-
type_: wasi_nn::TENSOR_TYPE_U8,
36-
data: &tensor_data,
37-
};
38-
unsafe {
39-
wasi_nn::set_input(context, 0, tensor).unwrap();
40-
}
24+
25+
// Pass tensor data into the TFLite runtime
26+
ctx.set_input(0, TensorType::U8, &[1, 224, 224, 3], &tensor_data)?;
27+
4128
// Execute the inference.
42-
unsafe {
43-
wasi_nn::compute(context).unwrap();
44-
}
45-
println!("Executed graph inference");
29+
ctx.compute()?;
30+
4631
// Retrieve the output.
47-
let mut output_buffer = vec![0u8; 965];
48-
unsafe {
49-
wasi_nn::get_output(
50-
context,
51-
0,
52-
&mut output_buffer[..] as *mut [u8] as *mut u8,
53-
output_buffer.len().try_into().unwrap(),
54-
)
55-
.unwrap();
56-
}
32+
let mut output_buffer = vec![0u8; imagenet_classes::AIY_BIRDS_V1.len()];
33+
_ = ctx.get_output(0, &mut output_buffer)?;
5734

35+
// Sort the result with the highest probability result first
5836
let results = sort_results(&output_buffer);
5937
for i in 0..5 {
6038
println!(
@@ -65,6 +43,8 @@ pub fn main() {
6543
imagenet_classes::AIY_BIRDS_V1[results[i].0]
6644
);
6745
}
46+
47+
Ok(())
6848
}
6949

7050
// Sort the buffer of probabilities. The graph places the match probability for each class at the
Binary file not shown.

0 commit comments

Comments
 (0)