-use serde_json::Value;
 use serde_json::json;
+use serde_json::Value;
 use std::env;
 use std::io;
 use wasi_nn::{self, GraphExecutionContext};
@@ -67,8 +67,7 @@ fn get_output_from_context(context: &GraphExecutionContext) -> String {
 
 #[allow(dead_code)]
 fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
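     // Editor's note: output index 1 is read here as a JSON metadata string,
     // while index 0 (used by get_output_from_context) carries the generated
     // text; this split is an assumption inferred from how the helpers index
     // get_data_from_context.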
-    serde_json::from_str(&get_data_from_context(context, 1))
-        .expect("Failed to get metadata")
+    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
 }
 
 fn main() {
@@ -119,15 +118,18 @@ fn main() {
     let mut saved_prompt = String::new();
 
     loop {
-        println!("Question:");
+        println!("USER:");
         let input = read_input();
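         // Build the running prompt with the model's chat markers: each user
         // turn is wrapped in <start_of_turn>user ... <end_of_turn>, and the
         // trailing <start_of_turn>model cues the reply (a Gemma-style chat
         // template, judging by these tags).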
         if saved_prompt.is_empty() {
             saved_prompt = format!(
                 "<start_of_turn>user {} <end_of_turn><start_of_turn>model",
                 input
             );
         } else {
-            saved_prompt = format!("{} <start_of_turn>user {} <end_of_turn><start_of_turn>model", saved_prompt, input);
+            saved_prompt = format!(
+                "{} <start_of_turn>user {} <end_of_turn><start_of_turn>model",
+                saved_prompt, input
+            );
         }
 
         // Set prompt to the input tensor.
@@ -148,6 +150,7 @@ fn main() {
 
         // Execute the inference.
         let mut reset_prompt = false;
+        println!("ASSISTANT:");
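         // reset_prompt is presumably set by the ContextFull arm below (its
         // body falls outside this hunk) so that the accumulated history is
         // dropped once the model's context window fills up.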
         match context.compute() {
             Ok(_) => (),
             Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::ContextFull)) => {
@@ -165,7 +168,11 @@ fn main() {
 
         // Retrieve the output.
         let mut output = get_output_from_context(&context);
-        println!("Answer:\n{}", output.trim());
+        if let Some(true) = options["stream-stdout"].as_bool() {
+            println!("");
+        } else {
+            println!("{}", output.trim());
+        }
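+        // With "stream-stdout" enabled, the backend presumably echoes tokens
+        // as they are generated during compute(), so only a newline is needed
+        // to finish the line; otherwise the full answer is printed here.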
 
         // Update the saved prompt.
         if reset_prompt {