Description
torch: 2.1.2
flash-attn: 2.5.6
GPU: A6000 48 GB
Script: the quick_start.py you provide:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device_map = "cuda:0" if torch.cuda.is_available() else "auto"
model = AutoModelForCausalLM.from_pretrained('./pretrain', device_map=device_map, torch_dtype=torch.float16,
                                             load_in_8bit=True, trust_remote_code=True, use_flash_attention_2=True)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained('./pretrain', use_fast=False)
tokenizer.pad_token = tokenizer.eos_token

while True:
    text = input("Enter your question: ")
    prompt = f'<s>Human: {text}\n</s><s>Assistant: '.replace('Enter your question: ', '')
    input_ids = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).input_ids
    if torch.cuda.is_available():
        input_ids = input_ids.to('cuda')
    generate_input = {
        "input_ids": input_ids,
        "max_new_tokens": 512,
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.3,
        "repetition_penalty": 1.3,
        "eos_token_id": tokenizer.eos_token_id,
        "bos_token_id": tokenizer.bos_token_id,
        "pad_token_id": tokenizer.pad_token_id
    }
    generate_ids = model.generate(**generate_input)
    text = tokenizer.decode(generate_ids[0])
    print(text)
A simple question like "Please introduce China" (请介绍一下中国) takes roughly 1–2 minutes to return a result. Monitoring confirms the GPU is in use, so why is it this slow? Even without vLLM acceleration it should not take this long. What could the cause be?
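To narrow this down, it helps to measure tokens/second rather than total wall-clock time: 512 `max_new_tokens` at a low decode rate (the `load_in_8bit=True` bitsandbytes int8 path is often much slower than plain fp16 on consumer/workstation GPUs) can easily add up to minutes. Below is a minimal, hypothetical timing helper (`benchmark_generate` is not part of transformers) that wraps any generate-like callable and reports throughput; here it is exercised with a stub instead of a real model so the measurement logic itself is clear:

```python
import time

def benchmark_generate(generate_fn, n_prompt_tokens):
    """Time a generate-like callable and compute decode throughput.

    generate_fn should return a flat sequence of token ids for one sample,
    e.g. lambda: model.generate(**generate_input)[0] in the script above.
    """
    start = time.perf_counter()
    output_ids = generate_fn()
    elapsed = time.perf_counter() - start
    n_new = len(output_ids) - n_prompt_tokens  # newly generated tokens only
    tps = n_new / elapsed if elapsed > 0 else float("inf")
    return n_new, elapsed, tps

# Stand-in for model.generate: pretends a 10-token prompt produced 512 new tokens.
fake_output = list(range(10 + 512))
n_new, elapsed, tps = benchmark_generate(lambda: fake_output, n_prompt_tokens=10)
print(f"{n_new} new tokens in {elapsed:.3f}s -> {tps:.1f} tok/s")
```

With a real model, comparing the tok/s figure with `load_in_8bit=True` against a plain `torch_dtype=torch.float16` load of the same checkpoint should show quickly whether the 8-bit path is the bottleneck.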