
Commit b854ab4

Merge pull request HKUDS#49 from JGalego/feat/bedrock-support

feat: Amazon Bedrock support ⛰️

2 parents e7a7ff6 + 75a91d9

File tree

4 files changed, +204 -0 lines

.gitignore

Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+__pycache__
+*.egg-info
+dickens/
+book.txt
```

examples/lightrag_bedrock_demo.py

Lines changed: 41 additions & 0 deletions

```diff
@@ -0,0 +1,41 @@
+"""
+LightRAG meets Amazon Bedrock ⛰️
+"""
+
+import os
+import logging
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import bedrock_complete, bedrock_embedding
+from lightrag.utils import EmbeddingFunc
+
+logging.getLogger("aiobotocore").setLevel(logging.WARNING)
+
+WORKING_DIR = "./dickens"
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=bedrock_complete,
+    llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=bedrock_embedding
+    )
+)
+
+with open("./book.txt", 'r', encoding='utf-8') as f:
+    rag.insert(f.read())
+
+for mode in ["naive", "local", "global", "hybrid"]:
+    print("\n+-" + "-" * len(mode) + "-+")
+    print(f"| {mode.capitalize()} |")
+    print("+-" + "-" * len(mode) + "-+\n")
+    print(
+        rag.query(
+            "What are the top themes in this story?",
+            param=QueryParam(mode=mode)
+        )
+    )
```
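The demo assumes two things it never checks: AWS credentials that aioboto3 can resolve, and a plain-text corpus at ./book.txt. A minimal pre-flight sketch along these lines (my own addition, not part of the PR) can make failures more obvious; the environment variable names are the standard AWS ones, nothing LightRAG-specific:

```python
import os
import pathlib

# Illustrative only: aioboto3 also resolves shared credential files,
# SSO profiles, and IAM roles, so this check is deliberately loose.
has_env_creds = os.environ.get("AWS_ACCESS_KEY_ID") and os.environ.get("AWS_SECRET_ACCESS_KEY")
has_cred_file = pathlib.Path("~/.aws/credentials").expanduser().exists()
if not (has_env_creds or has_cred_file):
    raise SystemExit("Configure AWS credentials before running the Bedrock demo")

if not pathlib.Path("./book.txt").exists():
    raise SystemExit("The demo expects a plain-text corpus at ./book.txt")
```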

lightrag/llm.py

Lines changed: 158 additions & 0 deletions

```diff
@@ -1,4 +1,9 @@
 import os
+import copy
+import json
+import botocore
+import aioboto3
+import botocore.errorfactory
 import numpy as np
 import ollama
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
```
```diff
@@ -48,6 +53,81 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+
+class BedrockError(Exception):
+    """Generic error for issues related to Amazon Bedrock"""
+
+
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, max=60),
+    retry=retry_if_exception_type((BedrockError)),
+)
+async def bedrock_complete_if_cache(
+    model, prompt, system_prompt=None, history_messages=[],
+    aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
+) -> str:
+    os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
+    os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
+    os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
+
+    # Fix message history format
+    messages = []
+    for history_message in history_messages:
+        message = copy.copy(history_message)
+        message['content'] = [{'text': message['content']}]
+        messages.append(message)
+
+    # Add user prompt
+    messages.append({'role': "user", 'content': [{'text': prompt}]})
+
+    # Initialize Converse API arguments
+    args = {
+        'modelId': model,
+        'messages': messages
+    }
+
+    # Define system prompt
+    if system_prompt:
+        args['system'] = [{'text': system_prompt}]
+
+    # Map and set up inference parameters
+    inference_params_map = {
+        'max_tokens': "maxTokens",
+        'top_p': "topP",
+        'stop_sequences': "stopSequences"
+    }
+    if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
+        args['inferenceConfig'] = {}
+        for param in inference_params:
+            args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+
+    # Call model via Converse API
+    session = aioboto3.Session()
+    async with session.client("bedrock-runtime") as bedrock_async_client:
+
+        try:
+            response = await bedrock_async_client.converse(**args, **kwargs)
+        except Exception as e:
+            raise BedrockError(e)
+
+        if hashing_kv is not None:
+            await hashing_kv.upsert({
+                args_hash: {
+                    'return': response['output']['message']['content'][0]['text'],
+                    'model': model
+                }
+            })
+
+        return response['output']['message']['content'][0]['text']
+
 async def hf_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
```
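Worth noting in the hunk above: snake_case keyword arguments like `max_tokens` are translated to the Converse API's camelCase `inferenceConfig` keys via `inference_params_map` (while `temperature` passes through unchanged), and a supplied `hashing_kv` store short-circuits the call on a cache hit. A hedged usage sketch, assuming the asyncio wrapper is mine and using the model ID the PR defaults to:

```python
import asyncio

from lightrag.llm import bedrock_complete_if_cache


async def main():
    answer = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",
        "Summarize A Christmas Carol in one sentence.",
        system_prompt="You are a concise literary assistant.",
        max_tokens=256,    # remapped to inferenceConfig['maxTokens']
        temperature=0.2,   # already the Converse key, passed through as-is
    )
    print(answer)


asyncio.run(main())
```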
```diff
@@ -145,6 +225,19 @@ async def gpt_4o_mini_complete(
         **kwargs,
     )
 
+
+async def bedrock_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    return await bedrock_complete_if_cache(
+        "anthropic.claude-3-haiku-20240307-v1:0",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def hf_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
```
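`bedrock_complete` hard-codes Claude 3 Haiku, so pointing LightRAG at a different Bedrock chat model means writing a sibling wrapper in the same shape. A sketch under that assumption (the Sonnet model ID below is my example, not something this PR registers):

```python
from lightrag.llm import bedrock_complete_if_cache


async def bedrock_sonnet_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    # Same pattern as bedrock_complete, different Bedrock model ID
    return await bedrock_complete_if_cache(
        "anthropic.claude-3-sonnet-20240229-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
```

Passing such a wrapper as `llm_model_func` to `LightRAG` works the same way as in the demo above.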
```diff
@@ -186,6 +279,71 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small",
     return np.array([dp.embedding for dp in response.data])
 
 
+# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+# @retry(
+#     stop=stop_after_attempt(3),
+#     wait=wait_exponential(multiplier=1, min=4, max=10),
+#     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions
+# )
+async def bedrock_embedding(
+        texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
+        aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
+    os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
+    os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
+    os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
+
+    session = aioboto3.Session()
+    async with session.client("bedrock-runtime") as bedrock_async_client:
+
+        if (model_provider := model.split(".")[0]) == "amazon":
+            embed_texts = []
+            for text in texts:
+                if "v2" in model:
+                    body = json.dumps({
+                        'inputText': text,
+                        # 'dimensions': embedding_dim,
+                        'embeddingTypes': ["float"]
+                    })
+                elif "v1" in model:
+                    body = json.dumps({
+                        'inputText': text
+                    })
+                else:
+                    raise ValueError(f"Model {model} is not supported!")
+
+                response = await bedrock_async_client.invoke_model(
+                    modelId=model,
+                    body=body,
+                    accept="application/json",
+                    contentType="application/json"
+                )
+
+                response_body = await response.get('body').json()
+
+                embed_texts.append(response_body['embedding'])
+        elif model_provider == "cohere":
+            body = json.dumps({
+                'texts': texts,
+                'input_type': "search_document",
+                'truncate': "NONE"
+            })
+
+            response = await bedrock_async_client.invoke_model(
+                modelId=model,
+                body=body,
+                accept="application/json",
+                contentType="application/json"
+            )
+
+            response_body = json.loads(await response.get('body').read())
+
+            embed_texts = response_body['embeddings']
+        else:
+            raise ValueError(f"Model provider '{model_provider}' is not supported!")
+
+        return np.array(embed_texts)
+
+
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
     with torch.no_grad():
```
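`bedrock_embedding` invokes Titan models once per text but sends Cohere models the whole list in a single `invoke_model` call; either way the result comes back as one NumPy array. A quick hedged sketch of the Titan v2 default (the asyncio wrapper is mine; the 1024-dim output matches the `embedding_dim` the demo configures):

```python
import asyncio

from lightrag.llm import bedrock_embedding

# Assumes AWS credentials are already resolvable by aioboto3
vectors = asyncio.run(bedrock_embedding([
    "Marley was dead: to begin with.",
    "Scrooge knew he was dead.",
]))
print(vectors.shape)  # e.g. (2, 1024) for amazon.titan-embed-text-v2:0
```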

requirements.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,3 +1,4 @@
+aioboto3
 openai
 tiktoken
 networkx
```
