Commit 87aad35

Merge pull request HKUDS#266 from davidleon/fix_hf_embedding_device
fix hf embedding to support loading to different device
2 parents 3d37888 + 38e1956 commit 87aad35

File tree: 1 file changed, +6 −2 lines

lightrag/llm.py

Lines changed: 6 additions & 2 deletions
@@ -693,13 +693,17 @@ async def bedrock_embedding(
 
 
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
     input_ids = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids
+    ).input_ids.to(device)
     with torch.no_grad():
         outputs = embed_model(input_ids)
     embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.detach().numpy()
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
 
 
 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
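
The change reads the device that embed_model's parameters live on and moves the tokenized input_ids onto it, so a model loaded on CUDA no longer collides with CPU-resident inputs. The returned tensor is also moved to the CPU before .numpy(), with bfloat16 outputs upcast to float32 first, since NumPy has no bfloat16 dtype. A minimal usage sketch of the patched function follows; the model name and device selection are illustrative assumptions, not part of the commit.

# Illustrative sketch only: the model name and device choice are assumptions.
import asyncio

import torch
from transformers import AutoModel, AutoTokenizer

from lightrag.llm import hf_embedding

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # hypothetical example model
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Loading the model onto a non-CPU device is the case this fix targets:
# hf_embedding now reads the model's device and moves input_ids onto it.
embed_model = AutoModel.from_pretrained(model_name).to(device)

embeddings = asyncio.run(
    hf_embedding(["hello world", "lightrag"], tokenizer, embed_model)
)
print(embeddings.shape)  # (2, hidden_size) numpy array on the CPU

With the old code, the same call would fail for a CUDA-resident model: the tokenized input_ids stayed on the CPU, and even if they had not, calling .numpy() on a CUDA tensor raises an error.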
