昇腾社区首页
中文
注册

精度验证

  • 精度验证可以基于问答判断文本生成功能。
    prompt = ["Common sense questions and answers\n\nQuestion: What is the capital of France\nFactual answer:"]
    inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=SEQ_LEN_IN)
    with torch.no_grad():
        generate_ids = model.generate(inputs.input_ids.npu(), attention_mask=inputs.attention_mask.npu(), max_new_tokens=SEQ_LEN_OUT)
    res = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    for item in res:
        print(item)
  • 比较Pytorch的输出结果和加速库Model的输出结果。
    outputs = self.decoder_layer(self.inputs)  # 原始Pytorch输出结果
    outputs_acl = self.acl_decoder_operation.execute(self.inputs_acl, self.param)  # 加速库Model输出结果
    compare_result = torch.allclose(outputs_acl[0], outputs[0], atol=0.02, rtol=0.02)
  • 基于数据集进行端到端精度验证,如基于C-Eval数据集的精度验证,可参考“pytorch/examples/chatglm2_6b/evaluate_ceval.py”