@@ -29,7 +29,7 @@ def init_vectorstore():
2929 pprint (dict (info ))
3030
3131
32- async def run_shell ():
32+ def get_law_chain ():
3333 llm = get_llm ()
3434 law_vs = get_vectorstore (config .LAW_VS_COLLECTION_NAME )
3535 web_vs = get_vectorstore (config .WEB_VS_COLLECTION_NAME )
@@ -47,6 +47,12 @@ async def run_shell():
4747 return_source_documents = True ,
4848 )
4949
50+ return chain
51+
52+
53+ async def run_shell ():
54+ chain = get_law_chain ()
55+
5056 while True :
5157 query = input ("\n 用户:" )
5258 if query .strip () == "stop" :
@@ -55,8 +61,8 @@ async def run_shell():
5561 callback = AsyncIteratorCallbackHandler ()
5662 task = asyncio .create_task (
5763 chain .ainvoke ({"query" : query }, config = {"callbacks" : [callback ]}))
58- async for t in callback .aiter ():
59- print (t , end = "" , flush = True )
64+ async for new_token in callback .aiter ():
65+ print (new_token , end = "" , flush = True )
6066
6167 print ("\n " )
6268 res = await task
@@ -65,6 +71,34 @@ async def run_shell():
6571 print (f"{ source_text (docs )} " )
6672
6773
74+ async def run_web ():
75+ import gradio as gr
76+
77+ chain = get_law_chain ()
78+
79+ async def chat (message , history ):
80+ callback = AsyncIteratorCallbackHandler ()
81+ task = asyncio .create_task (
82+ chain .ainvoke ({"query" : message }, config = {"callbacks" : [callback ]}))
83+
84+ response = ""
85+ async for new_token in callback .aiter ():
86+ response += new_token
87+ yield response
88+
89+ res = await task
90+ _ , docs = res ['result' ], res ['source_documents' ]
91+
92+ response += "\n " + source_text (docs )
93+ yield response
94+
95+ demo = gr .ChatInterface (
96+ fn = chat , examples = ["故意杀了一个人,会判几年?" , "杀人自首会减刑吗?" ], title = "法律AI小助手" )
97+
98+ demo .queue ()
99+ demo .launch (server_name = config .WEB_HOST , server_port = config .WEB_PORT )
100+
101+
68102if __name__ == '__main__' :
69103 import argparse
70104 parser = argparse .ArgumentParser (
@@ -85,6 +119,15 @@ async def run_shell():
85119 run shell
86120 ''' )
87121 )
122+ parser .add_argument (
123+ "-w" ,
124+ "--web" ,
125+ action = "store_true" ,
126+ help = ('''
127+ run web
128+ ''' )
129+ )
130+
88131
89132 if len (sys .argv ) <= 1 :
90133 parser .print_help ()
@@ -95,3 +138,5 @@ async def run_shell():
95138 init_vectorstore ()
96139 if args .shell :
97140 asyncio .run (run_shell ())
141+ if args .web :
142+ asyncio .run (run_web ())
0 commit comments