import os
import shutil
import asyncio
import logging

import numpy as np
import aiohttp
from dotenv import load_dotenv

from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc

logging.basicConfig(level=logging.INFO)

load_dotenv()

# Azure OpenAI chat-completion settings
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")

# Azure OpenAI embedding settings
AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")

WORKING_DIR = "./dickens"

# Start from a clean working directory on every run
if os.path.exists(WORKING_DIR):
    shutil.rmtree(WORKING_DIR)

os.mkdir(WORKING_DIR)


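# LightRAG accepts any async completion function with this signature; here the
# Azure OpenAI chat completions REST API is called directly with aiohttp
# instead of going through the openai SDK.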
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    headers = {
        "Content-Type": "application/json",
        "api-key": AZURE_OPENAI_API_KEY,
    }
    # AZURE_OPENAI_ENDPOINT is expected to end with a trailing slash
    endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"

    # Assemble the transcript: optional system prompt, prior turns, then the new prompt
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    payload = {
        "messages": messages,
        "temperature": kwargs.get("temperature", 0),
        "top_p": kwargs.get("top_p", 1),
        "n": kwargs.get("n", 1),
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=headers, json=payload) as response:
            if response.status != 200:
                raise ValueError(
                    f"Request failed with status {response.status}: {await response.text()}"
                )
            result = await response.json()
            return result["choices"][0]["message"]["content"]


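# Any async embedding function works as long as it returns one vector per input
# text; this one sends the batch to the Azure embeddings endpoint and returns
# an (n_texts, embedding_dim) array.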
async def embedding_func(texts: list[str]) -> np.ndarray:
    headers = {
        "Content-Type": "application/json",
        "api-key": AZURE_OPENAI_API_KEY,
    }
    endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"

    payload = {"input": texts}

    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=headers, json=payload) as response:
            if response.status != 200:
                raise ValueError(
                    f"Request failed with status {response.status}: {await response.text()}"
                )
            result = await response.json()
            embeddings = [item["embedding"] for item in result["data"]]
            return np.array(embeddings)


async def test_funcs():
    # Smoke-test both functions before wiring them into LightRAG
    result = await llm_model_func("How are you?")
    print("llm_model_func response: ", result)

    result = await embedding_func(["How are you?"])
    print("embedding_func result shape: ", result.shape)
    print("Embedding dimension: ", result.shape[1])


asyncio.run(test_funcs())

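# The embedding dimension must match what the deployed model actually returns
# (3072 is the output size of text-embedding-3-large). A sketch of deriving it
# from a probe call instead of hardcoding, assuming the endpoint is reachable
# at this point in the script:
#
#     probe = asyncio.run(embedding_func(["probe"]))
#     embedding_dimension = probe.shape[1]
#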
embedding_dimension = 3072

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=embedding_dimension,
        max_token_size=8192,
        func=embedding_func,
    ),
)

# Read both source texts and index them; the context managers close the files
with open("./book_1.txt", encoding="utf-8") as book1, open("./book_2.txt", encoding="utf-8") as book2:
    rag.insert([book1.read(), book2.read()])

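# Run the same question through LightRAG's four query modes: naive (basic
# vector search over chunks), local (context-dependent, entity-level
# information), global (graph-wide relationship knowledge), and hybrid
# (combines the local and global strategies).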
query_text = "What are the main themes?"

print("Result (Naive):")
print(rag.query(query_text, param=QueryParam(mode="naive")))

print("\nResult (Local):")
print(rag.query(query_text, param=QueryParam(mode="local")))

print("\nResult (Global):")
print(rag.query(query_text, param=QueryParam(mode="global")))

print("\nResult (Hybrid):")
print(rag.query(query_text, param=QueryParam(mode="hybrid")))