Inference with the llama3 chinese version you provide is extremely slow #363

@fujingnan

Description


torch: 2.1.2
flash-attn: 2.5.6
GPU: A6000 48G
Script: the quick_start.py you provide:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Put the model on the first GPU if one is available
device_map = "cuda:0" if torch.cuda.is_available() else "auto"
# Load the checkpoint in 8-bit (bitsandbytes) with fp16 compute and FlashAttention-2
model = AutoModelForCausalLM.from_pretrained('./pretrain', device_map=device_map, torch_dtype=torch.float16,
                                             load_in_8bit=True, trust_remote_code=True, use_flash_attention_2=True)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained('./pretrain', use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
while True:
    text = input("请输入问题:")
    prompt = f'<s>Human: {text}\n</s><s>Assistant: '.replace('请输入问题:', '')
    input_ids = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).input_ids
    if torch.cuda.is_available():
        input_ids = input_ids.to('cuda')
    generate_input = {
        "input_ids": input_ids,
        "max_new_tokens": 512,
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.3,
        "repetition_penalty": 1.3,
        "eos_token_id": tokenizer.eos_token_id,
        "bos_token_id": tokenizer.bos_token_id,
        "pad_token_id": tokenizer.pad_token_id
    }
    generate_ids = model.generate(**generate_input)
    text = tokenizer.decode(generate_ids[0])
    print(text)

For a simple question like "请介绍一下中国" ("Please introduce China"), it takes roughly 1-2 minutes to return a result, and monitoring shows the GPU is being used. Why is it this slow? Even without vLLM acceleration it should not take this long. What could be the cause?
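One way to narrow this down is to measure tokens per second for a single generate() call and compare the same checkpoint loaded with and without load_in_8bit, since bitsandbytes 8-bit inference is often noticeably slower at generation than plain fp16. Below is a minimal sketch (not part of the original script) that assumes the same ./pretrain checkpoint, a single CUDA GPU, and the transformers/bitsandbytes versions listed above:

import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_PATH = './pretrain'  # same local path as quick_start.py
PROMPT = '<s>Human: 请介绍一下中国\n</s><s>Assistant: '

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
input_ids = tokenizer([PROMPT], return_tensors="pt", add_special_tokens=False).input_ids.to('cuda')

def benchmark(load_in_8bit: bool) -> float:
    """Load the model, generate 256 new tokens greedily, and return tokens per second."""
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="cuda:0",
        torch_dtype=torch.float16,
        load_in_8bit=load_in_8bit,
        trust_remote_code=True,
    ).eval()
    torch.cuda.synchronize()
    start = time.time()
    # Greedy decoding keeps the timing deterministic; sampling settings change the cost very little.
    out = model.generate(input_ids, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    new_tokens = out.shape[1] - input_ids.shape[1]
    # Free GPU memory before the next load
    del model
    torch.cuda.empty_cache()
    return new_tokens / elapsed

print(f"8-bit: {benchmark(True):.1f} tokens/s")
print(f"fp16 : {benchmark(False):.1f} tokens/s")

If the fp16 run is several times faster, the load_in_8bit=True setting is the most likely cause of the slowdown; otherwise the bottleneck is probably elsewhere (for example, the model generating all 512 new tokens before hitting an EOS).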
