File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -693,13 +693,17 @@ async def bedrock_embedding(
693693
694694
695695async def hf_embedding (texts : list [str ], tokenizer , embed_model ) -> np .ndarray :
696+ device = next (embed_model .parameters ()).device
696697 input_ids = tokenizer (
697698 texts , return_tensors = "pt" , padding = True , truncation = True
698- ).input_ids
699+ ).input_ids . to ( device )
699700 with torch .no_grad ():
700701 outputs = embed_model (input_ids )
701702 embeddings = outputs .last_hidden_state .mean (dim = 1 )
702- return embeddings .detach ().numpy ()
703+ if embeddings .dtype == torch .bfloat16 :
704+ return embeddings .detach ().to (torch .float32 ).cpu ().numpy ()
705+ else :
706+ return embeddings .detach ().cpu ().numpy ()
703707
704708
705709async def ollama_embedding (texts : list [str ], embed_model , ** kwargs ) -> np .ndarray :
You can’t perform that action at this time.
0 commit comments