Skip to content

Commit 6c19ca8

Browse files
author
johnson
committed
2024-12-27_17:08
1 parent b48a418 commit 6c19ca8

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

‎examples/LightRAG_utils.py‎

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,42 @@
1616
from functools import wraps
1717
from lightrag.utils import xml_to_json
1818
from neo4j import GraphDatabase
19+
from lingua import Language, LanguageDetectorBuilder #pip install lingua-language-detector
1920
from firecrawl import FirecrawlApp #pip install firecrawl-py
2021
import fitz # PyMuPDF
2122
import tika
2223
from tika import parser as tikaParser
24+
# Local filesystem path of the bundled Tika server jar, used only for the
# existence check below.
tika_path = "/media/wac/backup/john/johnson/LightRAG/examples/tika-server.jar"
# file:// URL form of the same jar; the tika client library reads this from
# the TIKA_SERVER_JAR environment variable.
TIKA_SERVER_JAR = "file:////media/wac/backup/john/johnson/LightRAG/examples/tika-server.jar"
if not os.path.exists(tika_path):
    # Fallback to a development-machine location when the primary jar is absent.
    # NOTE(review): the fallback URL is not itself existence-checked — confirm.
    TIKA_SERVER_JAR = "file:////Users/admin/git/tika/tika-server-standard-2.9.0-bin/tika-server.jar"
os.environ['TIKA_SERVER_JAR'] = TIKA_SERVER_JAR
2729

30+
def detect_language(content):
    """Detect the language of *content* and return its lowercase English name.

    Supported languages: English, French, German, Spanish, Chinese,
    Japanese, Korean.

    Args:
        content: Text to analyze.

    Returns:
        str: One of "english", "french", "german", "spanish", "chinese",
        "japanese", "korean". Falls back to "english" when detection fails
        (lingua returns None) or yields a language outside the supported set.
    """
    language_pair = {
        Language.ENGLISH: "english",
        Language.FRENCH: "french",
        Language.GERMAN: "german",
        Language.SPANISH: "spanish",
        Language.CHINESE: "chinese",
        Language.JAPANESE: "japanese",
        Language.KOREAN: "korean",
    }
    # Building a lingua detector is expensive; cache it on the function object
    # so repeated calls reuse one instance instead of rebuilding every time.
    detector = getattr(detect_language, "_detector", None)
    if detector is None:
        # Candidate languages are derived from the mapping keys so the two
        # can never drift apart (the original kept a duplicate hand-written list).
        detector = LanguageDetectorBuilder.from_languages(*language_pair).build()
        detect_language._detector = detector
    language = detector.detect_language_of(content)
    if language not in language_pair:
        # Warn before falling back to English. (Message restored from
        # mojibake in the original source.)
        print(f"输入数据{content}被检测成未知的语言,请修改language_pair进行兼容: {language}")
    return language_pair.get(language, "english")
54+
2855
class MyFirecrawl():
2956
def __init__(self, api_key="EXAMPLE", api_url="http://127.0.0.1:3002"):
3057
"""

‎lightrag/lightrag.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,8 @@ def query(self, query: str, param: QueryParam = QueryParam()):
343343
loop = always_get_an_event_loop()
344344
return loop.run_until_complete(self.aquery(query, param))
345345

346-
async def aquery(self, query: str, param: QueryParam = QueryParam()):
346+
async def aquery(self, query: str, param: QueryParam = QueryParam(), history:list = []):
347+
# history: 历史聊天对话
347348
if param.mode == "local":
348349
response = await local_query(
349350
query,
@@ -353,6 +354,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
353354
self.text_chunks,
354355
param,
355356
asdict(self),
357+
history
356358
)
357359
elif param.mode == "global":
358360
response = await global_query(
@@ -363,6 +365,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
363365
self.text_chunks,
364366
param,
365367
asdict(self),
368+
history
366369
)
367370
elif param.mode == "hybrid":
368371
response = await hybrid_query(
@@ -373,6 +376,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
373376
self.text_chunks,
374377
param,
375378
asdict(self),
379+
history
376380
)
377381
elif param.mode == "naive":
378382
response = await naive_query(
@@ -381,6 +385,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
381385
self.text_chunks,
382386
param,
383387
asdict(self),
388+
history
384389
)
385390
else:
386391
raise ValueError(f"Unknown mode {param.mode}")

‎lightrag/operate.py‎

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ async def local_query(
397397
text_chunks_db: BaseKVStorage[TextChunkSchema],
398398
query_param: QueryParam,
399399
global_config: dict,
400+
history: list[dict] = [],
400401
) -> str:
401402
context = None
402403
use_model_func = global_config["llm_model_func"]
@@ -446,6 +447,7 @@ async def local_query(
446447
response = await use_model_func(
447448
query,
448449
system_prompt=sys_prompt,
450+
history_messages=history
449451
)
450452
if len(response) > len(sys_prompt):
451453
response = (
@@ -670,6 +672,7 @@ async def global_query(
670672
text_chunks_db: BaseKVStorage[TextChunkSchema],
671673
query_param: QueryParam,
672674
global_config: dict,
675+
history: list[dict] = [],
673676
) -> str:
674677
context = None
675678
use_model_func = global_config["llm_model_func"]
@@ -723,6 +726,7 @@ async def global_query(
723726
response = await use_model_func(
724727
query,
725728
system_prompt=sys_prompt,
729+
history_messages=history
726730
)
727731
if len(response) > len(sys_prompt):
728732
response = (
@@ -916,6 +920,7 @@ async def hybrid_query(
916920
text_chunks_db: BaseKVStorage[TextChunkSchema],
917921
query_param: QueryParam,
918922
global_config: dict,
923+
history: list[dict] = [],
919924
) -> str:
920925
low_level_context = None
921926
high_level_context = None
@@ -984,6 +989,7 @@ async def hybrid_query(
984989
response = await use_model_func(
985990
query,
986991
system_prompt=sys_prompt,
992+
history_messages=history
987993
)
988994
if len(response) > len(sys_prompt):
989995
response = (
@@ -1070,6 +1076,7 @@ async def naive_query(
10701076
text_chunks_db: BaseKVStorage[TextChunkSchema],
10711077
query_param: QueryParam,
10721078
global_config: dict,
1079+
history: list[dict] = [],
10731080
):
10741081
use_model_func = global_config["llm_model_func"]
10751082
results = await chunks_vdb.query(query, top_k=query_param.top_k)
@@ -1094,6 +1101,7 @@ async def naive_query(
10941101
response = await use_model_func(
10951102
query,
10961103
system_prompt=sys_prompt,
1104+
history_messages=history
10971105
)
10981106

10991107
if len(response) > len(sys_prompt):

0 commit comments

Comments
 (0)