use bytemuck::{cast_slice, cast_slice_mut};
use std::collections::HashMap;
use std::env;
use std::fs::File;
use std::io::{self, BufRead, Read, Write};
use std::process;
use thiserror::Error;
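
// A minimal chat loop over a TensorFlow Lite model via wasi-nn: read a line
// from stdin, tokenize it into fixed-shape i32 tensors, run inference, and
// detokenize the output tokens as the bot's reply.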

// Define a custom error type to handle multiple error sources
#[derive(Error, Debug)]
pub enum ChatbotError {
    #[error("IO Error: {0}")]
    Io(#[from] std::io::Error),

    #[error("WASI NN Error: {0}")]
    WasiNn(String),

    #[error("Other Error: {0}")]
    Other(String),
}

impl From<wasi_nn::Error> for ChatbotError {
    fn from(error: wasi_nn::Error) -> Self {
        ChatbotError::WasiNn(error.to_string())
    }
}

type Result<T> = std::result::Result<T, ChatbotError>;

// Tokenizer Struct
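// A toy whitespace tokenizer: known words map to fixed ids, and unseen words
// are assigned fresh ids on the fly (a stand-in for a real subword tokenizer).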
struct Tokenizer {
    vocab: HashMap<String, i32>,
    vocab_reverse: HashMap<i32, String>,
    next_id: i32,
}

impl Tokenizer {
    fn new(initial_vocab: Vec<(&str, i32)>) -> Self {
        let mut vocab = HashMap::new();
        let mut vocab_reverse = HashMap::new();
        let mut next_id = 1;

        for (word, id) in initial_vocab {
            vocab.insert(word.to_string(), id);
            vocab_reverse.insert(id, word.to_string());
            if id >= next_id {
                next_id = id + 1;
            }
        }

        // Add special tokens
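        // <UNK> (id 0) is rendered for unknown ids on detokenize;
        // <PAD> (id -1) fills out fixed-length sequences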
        vocab.insert("<UNK>".to_string(), 0);
        vocab_reverse.insert(0, "<UNK>".to_string());
        vocab.insert("<PAD>".to_string(), -1);
        vocab_reverse.insert(-1, "<PAD>".to_string());

        Tokenizer {
            vocab,
            vocab_reverse,
            next_id,
        }
    }

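    // Split on whitespace; each unseen word is assigned the next free id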
    fn tokenize(&mut self, input: &str) -> Vec<i32> {
        input
            .split_whitespace()
            .map(|word| {
                self.vocab.get(word).cloned().unwrap_or_else(|| {
                    let id = self.next_id;
                    self.vocab.insert(word.to_string(), id);
                    self.vocab_reverse.insert(id, word.to_string());
                    self.next_id += 1;
                    id
                })
            })
            .collect()
    }

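    // Force the sequence to exactly max_length tokens to match the model's
    // fixed input shape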
    fn tokenize_with_fixed_length(&mut self, input: &str, max_length: usize) -> Vec<i32> {
        let mut tokens = self.tokenize(input);
        // Truncate if too long, pad with the <PAD> id (-1) if too short
        tokens.resize(max_length, -1);
        tokens
    }

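    // Map ids back to words; ids with no vocab entry render as <UNK>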
    fn detokenize(&self, tokens: &[i32]) -> String {
        tokens
            .iter()
            .map(|&token| {
                self.vocab_reverse
                    .get(&token)
                    .map_or("<UNK>", |v| v.as_str())
            })
            .collect::<Vec<&str>>()
            .join(" ")
    }
}

// Function to load the TFLite model
fn load_model(model_path: &str) -> Result<wasi_nn::Graph> {
    let mut file = File::open(model_path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;
    let model_segments = &[&buffer[..]];
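    // Encoding 4 is the wasi-nn graph-encoding value for TensorFlow Lite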
    Ok(unsafe { wasi_nn::load(model_segments, 4, wasi_nn::EXECUTION_TARGET_CPU)? })
}

// Function to initialize the execution context
fn init_context(graph: wasi_nn::Graph) -> Result<wasi_nn::GraphExecutionContext> {
    Ok(unsafe { wasi_nn::init_execution_context(graph)? })
}

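// Note (assumption): this is meant to be compiled for wasm32-wasi and run under
// a runtime that provides wasi-nn with a TFLite backend; build flags and the
// runtime invocation vary, so none are pinned here.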
fn main() -> Result<()> {
    // Parse command-line arguments
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: {} <model_file>", args[0]);
        process::exit(1);
    }
    let model_file = &args[1];

    // Load the model
    let graph = load_model(model_file)?;

    // Initialize execution context
    let ctx = init_context(graph)?;

    let mut stdout = io::stdout();
    let stdin = io::stdin();

    // Initialize the tokenizer once, outside the loop, so ids assigned to new
    // words in earlier turns stay valid later (<UNK>/<PAD> are added by new())
    let initial_vocab = vec![
        ("hello", 1),
        ("world", 2),
        ("this", 3),
        ("is", 4),
        ("a", 5),
        ("test", 6),
    ];
    let mut tokenizer = Tokenizer::new(initial_vocab);

    println!("Chatbot is ready! Type your messages below:");

    for line in stdin.lock().lines() {
        let user_input = line?;
        if user_input.trim().is_empty() {
            continue;
        }
        if user_input.trim().to_lowercase() == "exit" {
            break;
        }

        // Tokenize with fixed length
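        // max_length matches the KV-cache element count below; whether it also
        // equals the model's real input width is an assumption about this model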
        let max_length = 655360;
        let tokens = tokenizer.tokenize_with_fixed_length(&user_input, max_length);
        let tokens_dims = &[1u32, max_length as u32];
        let tokens_tensor = wasi_nn::Tensor {
            dimensions: tokens_dims,
            r#type: wasi_nn::TENSOR_TYPE_I32,
            data: cast_slice(&tokens),
        };

        // Create input_pos tensor
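        // One position index per token slot: 0, 1, ..., max_length - 1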
        let input_pos: Vec<i32> = (0..max_length as i32).collect();
        let input_pos_dims = &[1u32, max_length as u32];
        let input_pos_tensor = wasi_nn::Tensor {
            dimensions: input_pos_dims,
            r#type: wasi_nn::TENSOR_TYPE_I32,
            data: cast_slice(&input_pos),
        };

        // Create kv tensor (ensure kv_data has the correct size)
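        // 32 * 2 * 1 * 16 * 10 * 64 = 655,360 elements, so kv_data's length
        // matches the kv_dims product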
        let kv_data = vec![0.0_f32; max_length]; // zero-initialized KV cache
        let kv_dims = &[32u32, 2u32, 1u32, 16u32, 10u32, 64u32];
        let kv_tensor = wasi_nn::Tensor {
            dimensions: kv_dims,
            r#type: wasi_nn::TENSOR_TYPE_F32,
            data: cast_slice(&kv_data),
        };

        // Set inputs
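        // Indices 0, 1, 2 must line up with the model's expected input order:
        // tokens, positions, KV cache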
        unsafe {
            wasi_nn::set_input(ctx, 0, tokens_tensor)?;
            wasi_nn::set_input(ctx, 1, input_pos_tensor)?;
            wasi_nn::set_input(ctx, 2, kv_tensor)?;
        }

        // Run inference
        run_inference(&ctx)?;

        // Get output
        let output = get_model_output(&ctx, 0, max_length)?;

        // Detokenize output
        let response = tokenizer.detokenize(&output);

        // Display response
        writeln!(stdout, "Bot: {}", response)?;
    }

    println!("Chatbot session ended.");
    Ok(())
}

// Function to run inference
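// compute() is synchronous in wasi-nn: it returns once inference has finished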
fn run_inference(ctx: &wasi_nn::GraphExecutionContext) -> Result<()> {
    unsafe { Ok(wasi_nn::compute(*ctx)?) }
}

// Function to get model output
fn get_model_output(
    ctx: &wasi_nn::GraphExecutionContext,
    index: u32,
    size: usize,
) -> Result<Vec<i32>> {
    let mut buffer = vec![0i32; size];
    let buffer_ptr = cast_slice_mut(&mut buffer).as_mut_ptr();
    let byte_len = (size * std::mem::size_of::<i32>()) as u32;
    // get_output reports the number of bytes written; keep only that many i32 values
    let bytes_written = unsafe { wasi_nn::get_output(*ctx, index, buffer_ptr, byte_len)? };
    buffer.truncate(bytes_written as usize / std::mem::size_of::<i32>());
    Ok(buffer)
}