用于获取FasterGelu 接口能完成计算所需最大或者最小的临时空间大小,此空间为预留空间,即需要保证预留有足够的物理空间,用于执行计算。
inline uint32_t GetGeluMaxTmpSize(const ge::Shape srcShape, const uint32_t typeSize);
inline uint32_t GetGeluMinTmpSize(const ge::Shape srcShape, const uint32_t typeSize);
inline void GetGeluMaxMinTmpSize(const ge::Shape srcShape, const uint32_t typeSize, uint32_t& maxValue, uint32_t& minValue);
接口 |
输入/输出 |
功能 |
---|---|---|
srcShape |
输入 |
输入的shape信息。 |
typeSize |
输入 |
算子输入的数据类型大小,单位为字节。比如算子输入的数据类型为half,此处应传入2,即sizeof(half)。 |
maxValue |
输出 |
输出FasterGelu接口所需的tiling信息(最大临时空间大小)。 |
minValue |
输出 |
输出FasterGelu接口所需的tiling信息(最小临时空间大小)。 |
GetGeluMaxTmpSize返回FasterGelu接口能完成计算所需最大临时空间大小。
GetGeluMinTmpSize返回FasterGelu接口能完成计算所需最小临时空间大小。
GetGeluMaxMinTmpSize返回FasterGelu接口能完成计算所需最大和最小临时空间大小。
Atlas A2训练系列产品
// 输入shape信息为1024;算子输入的数据类型为half; std::vector<int64> shape_vec = {1024}; ge::Shape srcShape(shape_vec); uint32_t typeSize = sizeof(half); uint32_t maxValue = 0; uint32_t minValue = 0; GetGeluMaxMinTmpSize(srcShape, typeSize, maxValue, minValue);
输出数据: maxValue: 2048 minValue: 512