Host侧接口,Power接口通用tiling函数,用于获取Power接口能完成计算所需最小的临时空间大小,此空间为预留空间,即需要保证预留有足够的物理空间,用于执行计算。
内部根据srcShape1、srcShape2判断输入类型,判断接口为Power(dstTensor, srcTensor1, srcTensor2)、Power(dstTensor, srcTensor1, scalarValue) 或 Power(dstTensor, scalarValue, srcTensor2),进而返回对应临时空间大小。
inline uint32_t GetPowerMinTmpSize(const ge::Shape srcShape1, const ge::Shape srcShape2, const uint32_t typeSize, const bool isReuseSource)
接口 |
输入/输出 |
功能 |
---|---|---|
srcShape1 |
输入 |
输入1的shape信息。 |
srcShape2 |
输入 |
输入2的shape信息。 |
typeSize |
输入 |
计算数据类型的单位字节数,比如half类型为2,float类型为4。 |
isReuseSource |
输入 |
是否允许修改源操作数。 |
返回Power接口能完成计算所需最小临时空间大小。
Atlas A2训练系列产品
完整的调用样例请参考更多样例。
// srcTensor1、srcTensor2输入shape信息均为512;算子输入的数据类型为float;不允许修改源操作数 std::vector<int64_t> shape_dims = { 512 }; auto powShape = ge::Shape(shape_dims); uint32_t tmp_size = GetPowerMinTmpSize(powShape, powShape, 4, false);
输出数据: 12288
// srcTensor1输入shape信息为128*128,scalarValue的shape为1;算子输入的数据类型为half;不允许修改源操作数 std::vector<int64> shape1_vec = {128,128}; std::vector<int64> shape2_vec = {1}; ge::Shape shape1(shape1_vec); ge::Shape shape2(shape2_vec); auto tmp_size = GetPowerMinTmpSize(shape1, shape2, 2, false);
输出数据: 28672
//scalarValue的shape为1,srcTensor2输入shape信息为128*128;算子输入的数据类型为float;不允许修改源操作数 std::vector<int64> shape1_vec = {1}; std::vector<int64> shape2_vec = {128,128}; ge::Shape shape1(shape1_vec); ge::Shape shape2(shape2_vec); auto tmp_size = GetPowerMinTmpSize(shape1, shape2, 4, false);
输出数据: 16384