1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116 | #include "aclnn/aclnn_operation_base.h"
#include "utils/log.h"
/// Constructs the base operation and stores a copy of the operation name,
/// which is later returned by GetName() and used as the prefix of log messages.
///
/// @param opName Name of the operation.
AclnnBaseOperation::AclnnBaseOperation(const std::string &opName)
    : opName_(opName)
{
}
/// Destructor: only resets the cached executor handle to null.
/// NOTE(review): nothing is destroyed or released here — presumably the
/// executor is owned elsewhere; confirm against the code that creates it.
AclnnBaseOperation::~AclnnBaseOperation()
{
    aclExecutor_ = nullptr;
}
/// Returns the operation name supplied at construction time.
///
/// @return A copy of the stored operation name.
std::string AclnnBaseOperation::GetName() const
{
    return opName_;
}
/// Prepares the operation for execution and reports the workspace size it needs.
///
/// Runs two steps in order: CreateAclnnVariantPack() to build the aclnn
/// input/output tensors from the ATB variant pack, then
/// SetAclnnWorkspaceExecutor() to obtain the executor and workspace size.
/// On success the cached workspaceSize_ is copied to the caller.
///
/// @param variantPack   ATB input/output tensors describing this call.
/// @param workspaceSize Output: set to the cached workspaceSize_ on success;
///                      left untouched on failure.
/// @param context       Not used by this function.
/// @return atb::NO_ERROR on success, atb::ERROR_INVALID_PARAM if either step
///         returns a non-zero code.
atb::Status AclnnBaseOperation::Setup(
    const atb::VariantPack &variantPack, uint64_t &workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " setup start");

    // Delegate to the subclass: create the input/output tensors and store them in the variant pack.
    int ret = CreateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call CreateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }

    // Delegate to the subclass: obtain the executor and the workspace size.
    ret = SetAclnnWorkspaceExecutor();
    if (ret != 0) {
        // The message now names the function that actually failed (it was garbled before).
        LOG_ERROR(opName_ + " call SetAclnnWorkspaceExecutor fail, error: " + std::to_string(ret));
        return atb::ERROR_INVALID_PARAM;
    }

    // Hand the computed workspace size back to the caller.
    workspaceSize = workspaceSize_;
    LOG_INFO(opName_ + " setup end");
    return atb::NO_ERROR;
}
/// Launches the aclnn operator on the execute stream taken from the context.
///
/// Validates the context and its stream, refreshes the tensor device addresses
/// via UpdateAclnnVariantPack(), then calls ExecuteAclnnOp() with the supplied
/// workspace and stream.
///
/// @param variantPack   ATB input/output tensors for this call.
/// @param workspace     Workspace buffer forwarded to ExecuteAclnnOp().
/// @param workspaceSize Size supplied by the caller; only logged here next to
///                      the cached workspaceSize_.
/// @param context       ATB context; must be non-null and provide a non-null
///                      execute stream.
/// @return atb::NO_ERROR on success, atb::ERROR_INVALID_PARAM if the context or
///         its stream is null, atb::ERROR_CANN_ERROR if the address update or
///         the operator call returns a non-zero code.
atb::Status AclnnBaseOperation::Execute(
    const atb::VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize, atb::Context *context)
{
    LOG_INFO(opName_ + " execute start");
    if (!context) {
        LOG_ERROR(opName_ + " execute fail, context param is null");
        return atb::ERROR_INVALID_PARAM;
    }

    aclrtStream stream = context->GetExecuteStream();
    if (!stream) {
        LOG_ERROR(opName_ + " execute fail, execute stream in context is null");
        return atb::ERROR_INVALID_PARAM;
    }

    // Refresh the addresses of the data passed in for this call.
    int ret = UpdateAclnnVariantPack(variantPack);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call UpdateAclnnVariantPack fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }

    LOG_INFO("Input workspaceSize " + std::to_string(workspaceSize) + " localCache workspaceSize " +
             std::to_string(workspaceSize_));

    // Invoke the aclnn interface.
    ret = ExecuteAclnnOp(workspace, stream);
    if (ret != 0) {
        LOG_ERROR(opName_ + " call ExecuteAclnnOp fail, error: " + std::to_string(ret));
        return atb::ERROR_CANN_ERROR;
    }

    // Previously logged "execute start" a second time, so the log never showed completion.
    LOG_INFO(opName_ + " execute end");
    return atb::NO_ERROR;
}
/// Rebinds tensor device addresses on the cached executor for the current call.
///
/// For every cached input/output tensor whose needUpdateTensorDataPtr flag is
/// set, copies the tensor at the same index from the variant pack and passes
/// its deviceData to aclSetInputTensorAddr / aclSetOutputTensorAddr. Tensors
/// without the flag are skipped.
///
/// @param variantPack ATB input/output tensors for the current call; accessed
///                    with at(i) using the index of the cached tensor.
/// @return atb::NO_ERROR when every flagged tensor was updated,
///         atb::ERROR_CANN_ERROR on the first non-zero code from an acl call.
atb::Status AclnnBaseOperation::UpdateAclnnVariantPack(const atb::VariantPack &variantPack)
{
    // Inputs: refresh the device address of each flagged tensor.
    const size_t inputCount = aclInTensors_.size();
    for (size_t inIdx = 0; inIdx < inputCount; ++inIdx) {
        const auto &cachedIn = aclInTensors_[inIdx];
        if (cachedIn->needUpdateTensorDataPtr) {
            cachedIn->atbTensor = variantPack.inTensors.at(inIdx);
            const int status = aclSetInputTensorAddr(
                aclExecutor_, cachedIn->tensorIdx, cachedIn->tensor, cachedIn->atbTensor.deviceData);
            if (status != 0) {
                LOG_ERROR("inTensor " + std::to_string(inIdx) +
                          " call UpdateAclTensorDataPtr fail, error: " + std::to_string(status));
                return atb::ERROR_CANN_ERROR;
            }
        }
    }

    // Outputs: same procedure, using the output-side acl call.
    const size_t outputCount = aclOutTensors_.size();
    for (size_t outIdx = 0; outIdx < outputCount; ++outIdx) {
        const auto &cachedOut = aclOutTensors_[outIdx];
        if (cachedOut->needUpdateTensorDataPtr) {
            cachedOut->atbTensor = variantPack.outTensors.at(outIdx);
            const int status = aclSetOutputTensorAddr(
                aclExecutor_, cachedOut->tensorIdx, cachedOut->tensor, cachedOut->atbTensor.deviceData);
            if (status != 0) {
                LOG_ERROR("outTensor " + std::to_string(outIdx) +
                          " call UpdateAclTensorDataPtr fail, error: " + std::to_string(status));
                return atb::ERROR_CANN_ERROR;
            }
        }
    }

    return atb::NO_ERROR;
}
|