aclnnAlltoAllvGroupedMatMul
Supported Product Models
This interface is not supported in the current version.
Function Prototype
Each operator uses a two-stage interface: you must first call "aclnnAlltoAllvGroupedMatMulGetWorkspaceSize" to obtain the input parameters and compute the workspace size required by the computation flow, and then call "aclnnAlltoAllvGroupedMatMul" to execute the computation, as sketched after the prototypes below.
aclnnStatus aclnnAlltoAllvGroupedMatMulGetWorkspaceSize(const aclTensor* gmmX, const aclTensor* gmmWeight, const aclTensor* sendCountsTensorOptional, const aclTensor* recvCountsTensorOptional, const aclTensor* mmXOptional, const aclTensor* mmWeightOptional, const char* group, int64_t epWorldSize, const aclIntArray* sendCounts, const aclIntArray* recvCounts, bool transGmmWeight, bool transMmWeight, bool permuteOutFlag, aclTensor* gmmY, aclTensor* mmYOptional, aclTensor* permuteOutOptional, uint64_t* workspaceSize, aclOpExecutor** executor)
aclnnStatus aclnnAlltoAllvGroupedMatMul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
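The typical calling sequence is sketched below. This is a minimal outline under stated assumptions: the tensors (gmmX, gmmWeight, mmX, mmW, gmmY, mmY, permuteOut), the group name, the count arrays, and the stream are placeholders assumed to have been created as in the full example at the end of this page.
uint64_t workspaceSize = 0;
aclOpExecutor *executor = nullptr;
// Stage 1: validate the inputs and query the required workspace size.
aclnnStatus ret = aclnnAlltoAllvGroupedMatMulGetWorkspaceSize(
    gmmX, gmmWeight, nullptr, nullptr, mmX, mmW, groupName, epWorldSize,
    sendCounts, recvCounts, false, false, true, gmmY, mmY, permuteOut,
    &workspaceSize, &executor);
// Stage 2: allocate the workspace (if any) and launch the computation on a stream.
void *workspaceAddr = nullptr;
if (ret == ACL_SUCCESS && workspaceSize > 0) {
    ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
}
if (ret == ACL_SUCCESS) {
    ret = aclnnAlltoAllvGroupedMatMul(workspaceAddr, workspaceSize, executor, stream);
}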
Function Description
Operator function: fuses the routed-expert GroupedMatMul and AlltoAllv communication with the shared-expert Matmul computation, performing communication first and computation second.
Formulas:
Routed experts: gmmY = GroupedMatMul(Permute(AlltoAllv(gmmX)), gmmWeight), where the permuted intermediate may be returned as permuteOutOptional.
Shared experts: mmYOptional = mmXOptional × mmWeightOptional.
aclnnAlltoAllvGroupedMatMulGetWorkspaceSize
Parameters:
- gmmX (aclTensor*, compute input): this input undergoes AlltoAllv communication, and the communication result serves as the left matrix of the GroupedMatMul computation. Supported data types: FLOAT16, BFLOAT16; 2-D, with shape (BSK, H1); supported data format: ND.
- gmmWeight (aclTensor*, compute input): the right matrix of the GroupedMatMul computation. Its data type must match gmmX; 3-D, with shape (e, H1, N1); supported data format: ND.
- sendCountsTensorOptional (aclTensor*, compute input): optional input. Supported data types: INT32, INT64; shape (e * epWorldSize,); supported data format: ND. Not supported in the current version; pass nullptr.
- recvCountsTensorOptional (aclTensor*, compute input): optional input. Supported data types: INT32, INT64; shape (e * epWorldSize,); supported data format: ND. Not supported in the current version; pass nullptr.
- mmXOptional (aclTensor*, compute input): optional input; the left matrix of the shared-expert Matmul computation. Required when the shared-expert matrix computation is fused; 2-D, with shape (BS, H2).
- mmWeightOptional (aclTensor*, compute input): optional input; the right matrix of the shared-expert Matmul computation. Required when the shared-expert matrix computation is fused; 2-D, with shape (H2, N2).
- group (char*, compute input): name of the communication domain for expert parallelism; the string length must be in the range (0, 128).
- epWorldSize (int64_t, compute input): size of the ep communication domain; supported values are 8, 16, 32, and 64.
- sendCounts (aclIntArray*, compute input): the number of tokens sent to the other cards. Supported data type: INT64; the array size is e * epWorldSize, at most 256 (see the sketch after this parameter list).
- recvCounts (aclIntArray*, compute input): the number of tokens received from the other cards. Supported data type: INT64; the array size is e * epWorldSize, at most 256.
- transGmmWeight (bool, compute input): whether the right matrix of the GroupedMatMul is transposed; true means transposed, false means not transposed.
- transMmWeight (bool, compute input): whether the right matrix of the shared-expert Matmul is transposed; true means transposed, false means not transposed.
- permuteOutFlag (bool, compute input): whether permuteOutOptional is output; true means it is output, false means it is not.
- gmmY (aclTensor*, compute output): the final computation result. Its data type matches gmmX; 2-D, with shape (A, N1).
- mmYOptional (aclTensor*, compute output): output of the shared-expert Matmul. Its data type matches mmXOptional; 2-D, with shape (BS, N2). Output only when both mmXOptional and mmWeightOptional are passed in.
- permuteOutOptional (aclTensor*, compute output): the output after the permute. Its data type matches gmmX.
- workspaceSize (uint64_t*, output): the workspace size that must be allocated on the device.
- executor (aclOpExecutor**, output): the op executor, which encapsulates the operator computation flow.
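For illustration, the sketch below prepares sendCounts and recvCounts for a uniform token distribution; the uniform split across the e * epWorldSize (expert, rank) pairs and the concrete sizes are assumptions chosen for this sketch, matching the full example at the end of this page.
// Hypothetical sizes, for illustration only.
const int64_t e = 4, epWorldSize = 8, tokensPerPair = 128;
std::vector<int64_t> sendCountsList(e * epWorldSize, tokensPerPair);
std::vector<int64_t> recvCountsList(e * epWorldSize, tokensPerPair);
aclIntArray *sendCounts = aclCreateIntArray(sendCountsList.data(), sendCountsList.size());
aclIntArray *recvCounts = aclCreateIntArray(recvCountsList.data(), recvCountsList.size());
// The first dimension A of gmmY equals the sum of the recvCounts entries.
int64_t A = 0;
for (auto c : recvCountsList) {
    A += c;
}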
Return value:
aclnnStatus: a status code; see the aclnn return codes for details.
The first-stage interface validates the input parameters and reports an error in the following cases:
- 161001 (ACLNN_ERR_PARAM_NULLPTR): a required input, output, or attribute was passed as a null pointer.
- 161002 (ACLNN_ERR_PARAM_INVALID): the data type, data format, or dimensionality of gmmX, gmmWeight, sendCountsTensorOptional, recvCountsTensorOptional, mmXOptional, mmWeightOptional, group, epWorldSize, sendCounts, or recvCounts is outside the supported range.
aclnnAlltoAllvGroupedMatMul
Parameters:
- workspace (void*, input): address of the workspace memory allocated on the device.
- workspaceSize (uint64_t, input): size of the workspace allocated on the device, obtained from the first-stage interface aclnnAlltoAllvGroupedMatMulGetWorkspaceSize.
- executor (aclOpExecutor*, input): the op executor, which encapsulates the operator computation flow.
- stream (aclrtStream, input): the AscendCL stream on which the task executes.
Return value:
Returns an aclnnStatus status code; see the aclnn return codes for details.
Constraints
Variables used in the shapes in the parameter descriptions:
- BSK: BS * K, the total number of tokens after each of the BS tokens is dispatched to its K selected experts.
- H1: hidden size; value range (0, 65536).
- H2: hidden size; value range (0, 65536).
- e: number of experts on a single card; e <= 32, and e * epWorldSize is at most 256.
- N1: value range (0, 65536).
- N2: value range (0, 65536).
- BS: batch sequence size; value range (0, 52428800).
- K: number of TopK experts selected per token; K is in [2, 8].
- A: the sum of the recvCounts entries.
The communication volume per card must be between 2MB and 100MB, as checked in the sketch below.
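As a pre-flight sanity check, the sketch below validates the documented ranges before calling the first-stage interface. Estimating the per-card communication volume as (total sent tokens) × H1 × element size is our assumption for this sketch and is not defined on this page.
// Returns true if the documented constraints appear to be satisfied.
bool CheckConstraints(const std::vector<int64_t> &sendCountsList, int64_t e,
                      int64_t epWorldSize, int64_t H1, int64_t elemBytes)
{
    if (e > 32 || e * epWorldSize > 256) {
        return false;
    }
    if (H1 <= 0 || H1 >= 65536) {
        return false;
    }
    int64_t sentTokens = 0;
    for (auto c : sendCountsList) {
        sentTokens += c;
    }
    // Assumed volume estimate: tokens sent from this card * hidden size * bytes per element.
    int64_t volume = sentTokens * H1 * elemBytes;
    return volume >= (2LL << 20) && volume <= (100LL << 20);
}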
Example
The sample code below is for reference only; for the concrete compilation and execution procedure, see "Compiling and Running a Sample".
#include <thread>
#include <memory>
#include <iostream>
#include <string>
#include <vector>
#include "acl/acl.h"
#include "hccl/hccl.h"
#include "aclnnop/aclnn_allto_allv_grouped_mat_mul.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
int64_t shape_size = 1;
for (auto i : shape) {
shape_size *= i;
}
return shape_size;
}
template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
aclDataType dataType, aclTensor **tensor)
{
auto size = GetShapeSize(shape) * sizeof(T);
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
*tensor = aclCreateTensor(shape.data(),
shape.size(),
dataType,
strides.data(),
0,
aclFormat::ACL_FORMAT_ND,
shape.data(),
shape.size(),
*deviceAddr);
return 0;
}
struct Args {
int rankId;
HcclComm hcclComm;
aclrtStream stream;
aclrtContext context;
};
// Basic shape information
constexpr int64_t EP_WORLD_SIZE = 8;
constexpr int64_t BS = 4096;
constexpr int64_t K = 1;
constexpr int64_t H = 7168;
constexpr int64_t e = 4;
constexpr int64_t N1 = 4096;
constexpr int64_t N2 = 4096;
constexpr int64_t A = BS * K;
std::vector<int16_t> pPermuteData(A * H, 0);
std::vector<int16_t> pGmmyData(A * N1, 0);
std::vector<int16_t> pmmXData(BS * H, 0);
std::vector<int16_t> pmmWData(H * N2, 0);
std::vector<int16_t> pmmYData(BS * N2, 0);
int LaunchOneThreadAlltoAllvGmm(Args &args)
{
int ret = aclrtSetCurrentContext(args.context);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed. ret: %d\n", ret); return ret);
char hcomName[128] = {0};
ret = HcclGetCommName(args.hcclComm, hcomName);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed. ret: %d\n", ret); return -1);
std::vector<int64_t> gmmXShape = {BS * K, H};
std::vector<int64_t> gmmWShape = {e, H, N1};
std::vector<int64_t> gmmYShape = {A, N1};
std::vector<int64_t> permuteShape = {A, H};
std::vector<int64_t> mmXShape = {BS, H};
std::vector<int64_t> mmWShape = {H, N2};
std::vector<int64_t> mmYShape = {BS, N2};
std::vector<int64_t> sendCountsShape = {EP_WORLD_SIZE * e};
std::vector<int64_t> recvCountsShape = {EP_WORLD_SIZE * e};
std::vector<int64_t> sendCountsList(EP_WORLD_SIZE * e, BS * K / (EP_WORLD_SIZE * e));
std::vector<int64_t> recvCountsList(EP_WORLD_SIZE * e, BS * K / (EP_WORLD_SIZE * e));
void *gmmXDeviceAddr = nullptr;
void *gmmWDeviceAddr = nullptr;
void *gmmYDeviceAddr = nullptr;
void *permuteDeviceAddr = nullptr;
void *mmXDeviceAddr = nullptr;
void *mmWDeviceAddr = nullptr;
void *mmYDeviceAddr = nullptr;
aclTensor *gmmX = nullptr;
aclTensor *gmmW = nullptr;
aclTensor *gmmY = nullptr;
aclTensor *mmX = nullptr;
aclTensor *mmW = nullptr;
aclTensor *mmY = nullptr;
aclTensor *permute = nullptr;
aclTensor *sendCountsTensor = nullptr;
aclTensor *recvCountsTensor = nullptr;
uint64_t workspaceSize = 0;
aclOpExecutor *executor = nullptr;
void *workspaceAddr = nullptr;
long long gmmXShapeSize = GetShapeSize(gmmXShape);
long long gmmWShapeSize = GetShapeSize(gmmWShape);
long long gmmYShapeSize = GetShapeSize(gmmYShape);
long long permuteShapeSize = GetShapeSize(permuteShape);
long long mmXShapeSize = GetShapeSize(mmXShape);
long long mmWShapeSize = GetShapeSize(mmWShape);
long long mmYShapeSize = GetShapeSize(mmYShape);
std::vector<uint16_t> gmmXHostData(gmmXShapeSize, (args.rankId + 1) * 1024); // BF16, FP16
std::vector<uint16_t> gmmWHostData(gmmWShapeSize, (args.rankId + 1) * 512);
std::vector<uint16_t> gmmYHostData(gmmYShapeSize, 65535);
std::vector<uint16_t> permuteHostData(permuteShapeSize, 65535);
std::vector<uint16_t> mmXHostData(mmXShapeSize, (args.rankId + 1) * 1024); // BF16, FP16
std::vector<uint16_t> mmWHostData(mmWShapeSize, (args.rankId + 1) * 512);
std::vector<uint16_t> mmYHostData(mmYShapeSize, 0);
ret = CreateAclTensor(gmmXHostData, gmmXShape, &gmmXDeviceAddr, aclDataType::ACL_FLOAT16, &gmmX);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gmmWHostData, gmmWShape, &gmmWDeviceAddr, aclDataType::ACL_FLOAT16, &gmmW);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gmmYHostData, gmmYShape, &gmmYDeviceAddr, aclDataType::ACL_FLOAT16, &gmmY);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(permuteHostData, permuteShape, &permuteDeviceAddr, aclDataType::ACL_FLOAT16, &permute);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// Create the shared-expert tensors so that mmY is produced (see mmYOptional above).
ret = CreateAclTensor(mmXHostData, mmXShape, &mmXDeviceAddr, aclDataType::ACL_FLOAT16, &mmX);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(mmWHostData, mmWShape, &mmWDeviceAddr, aclDataType::ACL_FLOAT16, &mmW);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(mmYHostData, mmYShape, &mmYDeviceAddr, aclDataType::ACL_FLOAT16, &mmY);
CHECK_RET(ret == ACL_SUCCESS, return ret);
aclIntArray *sendCounts = aclCreateIntArray(sendCountsList.data(), sendCountsList.size());
aclIntArray *recvCounts = aclCreateIntArray(recvCountsList.data(), recvCountsList.size());
// Call the first-stage interface
ret = aclnnAlltoAllvGroupedMatMulGetWorkspaceSize(gmmX,
gmmW,
sendCountsTensor,
recvCountsTensor,
mmX,
mmW,
hcomName,
EP_WORLD_SIZE,
sendCounts,
recvCounts,
false,
false,
true,
gmmY,
mmY,
permute,
&workspaceSize,
&executor);
CHECK_RET(
ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnAlltoAllvGroupedMatMulGetWorkspaceSize failed. ret = %d \n", ret);
return ret);
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
}
// Call the second-stage interface
ret = aclnnAlltoAllvGroupedMatMul(workspaceAddr, workspaceSize, executor, args.stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnAlltoAllvGroupedMatMul failed. ret = %d \n", ret);
return ret);
// (Fixed pattern) Synchronize the stream and wait for the task to finish
ret = aclrtSynchronizeStreamWithTimeout(args.stream, 10000000);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
return ret);
// Copy the results back to the host on rank 0 for inspection
if (args.rankId == 0) {
size_t size = A * H * sizeof(int16_t);
aclrtMemcpy(pPermuteData.data(), size, permuteDeviceAddr, size, ACL_MEMCPY_DEVICE_TO_HOST);
}
if (args.rankId == 0) {
size_t size = A * N1 * sizeof(int16_t);
aclrtMemcpy(pGmmyData.data(), size, gmmYDeviceAddr, size, ACL_MEMCPY_DEVICE_TO_HOST);
}
// Free device resources; adjust to the specific API being called
if (gmmX != nullptr) {
aclDestroyTensor(gmmX);
}
if (gmmW != nullptr) {
aclDestroyTensor(gmmW);
}
if (gmmY != nullptr) {
aclDestroyTensor(gmmY);
}
if (mmX != nullptr) {
aclDestroyTensor(mmX);
}
if (mmW != nullptr) {
aclDestroyTensor(mmW);
}
if (mmY != nullptr) {
aclDestroyTensor(mmY);
}
if (permute != nullptr) {
aclDestroyTensor(permute);
}
if (sendCounts != nullptr) {
aclDestroyIntArray(sendCounts);
}
if (recvCounts != nullptr) {
aclDestroyIntArray(recvCounts);
}
if (gmmXDeviceAddr != nullptr) {
aclrtFree(gmmXDeviceAddr);
}
if (gmmWDeviceAddr != nullptr) {
aclrtFree(gmmWDeviceAddr);
}
if (gmmYDeviceAddr != nullptr) {
aclrtFree(gmmYDeviceAddr);
}
if (mmXDeviceAddr != nullptr) {
aclrtFree(mmXDeviceAddr);
}
if (mmWDeviceAddr != nullptr) {
aclrtFree(mmWDeviceAddr);
}
if (mmYDeviceAddr != nullptr) {
aclrtFree(mmYDeviceAddr);
}
if (permuteDeviceAddr != nullptr) {
aclrtFree(permuteDeviceAddr);
}
if (workspaceSize > 0) {
aclrtFree(workspaceAddr);
}
ret = aclrtDestroyStream(args.stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtDestroyStream failed. ret = %d \n", ret); return ret);
ret = aclrtResetDevice(args.rankId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtResetDevice failed. ret = %d \n", ret); return ret);
return 0;
}
int main(int argc, char *argv[])
{
int ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed. ret = %d \n", ret); return ret);
aclrtStream stream[EP_WORLD_SIZE];
aclrtContext context[EP_WORLD_SIZE];
for (uint32_t rankId = 0; rankId < EP_WORLD_SIZE; rankId++) {
ret = aclrtSetDevice(rankId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed. ret = %d \n", ret); return ret);
ret = aclrtCreateContext(&context[rankId], rankId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed. ret = %d \n", ret); return ret);
ret = aclrtCreateStream(&stream[rankId]);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed. ret = %d \n", ret); return ret);
}
int32_t devices[EP_WORLD_SIZE];
for (int i = 0; i < EP_WORLD_SIZE; i++) {
devices[i] = i;
}
// Initialize the collective communication domain
HcclComm comms[EP_WORLD_SIZE];
ret = HcclCommInitAll(EP_WORLD_SIZE, devices, comms);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll failed. ret = %d \n", ret); return ret);
Args args[EP_WORLD_SIZE];
// Launch one worker thread per rank
std::vector<std::unique_ptr<std::thread>> threads(EP_WORLD_SIZE);
for (uint32_t rankId = 0; rankId < EP_WORLD_SIZE; rankId++) {
args[rankId].rankId = rankId;
args[rankId].hcclComm = comms[rankId];
args[rankId].stream = stream[rankId];
args[rankId].context = context[rankId];
threads[rankId].reset(new std::thread(&LaunchOneThreadAlltoAllvGmm, std::ref(args[rankId])));
}
for (uint32_t rankId = 0; rankId < EP_WORLD_SIZE; rankId++) {
threads[rankId]->join();
}
for (uint32_t rankId = 0; rankId < EP_WORLD_SIZE; rankId++) {
auto hcclRet = HcclCommDestroy(comms[rankId]);
CHECK_RET(hcclRet == HCCL_SUCCESS, LOG_PRINT("[ERROR] HcclCommDestroy failed. ret = %d \n", hcclRet); return -1);
}
aclFinalize();
return 0;
}