定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | struct LinearParallelParam { enum ParallelType : int { UNDEFINED = -1, LINEAR_ALL_REDUCE = 0, LINEAR_REDUCE_SCATTER = 1, ALL_GATHER_LINEAR = 2, PURE_LINEAR = 3, ALL_GATHER_LINEAR_REDUCE_SCATTER = 4, ALLTOALLVC_ALL_GATHER_GMM = 5, GMM_REDUCE_SCATTER_ALLTOALLVC = 6, MAX = 7, }; enum QuantType : int { QUANT_TYPE_UNDEFINED = -1, QUANT_TYPE_UNQUANT = -1, QUANT_TYPE_PER_TENSOR = 0, QUANT_TYPE_PER_CHANNEL = 1, QUANT_TYPE_PER_GROUP = 2, QUANT_TYPE_PER_TOKEN = 3, QUANT_TYPE_MAX = 4, }; struct MoeInfo { int16_t localExpertNums = 1; int8_t epSize = 1; int8_t tpSize = 1; }; bool transWeight = true; int rank = 0; int rankSize = 0; int rankRoot = 0; bool hasResidual = false; std::string backend = "hccl"; HcclComm hcclComm = nullptr; CommMode commMode = COMM_MULTI_PROCESS; std::string rankTableFile; ParallelType type = LINEAR_ALL_REDUCE; bool keepIntermediate = false; QuantType quantType = QUANT_TYPE_UNDEFINED; int32_t quantGroupSize = 0; aclDataType outDataType = ACL_DT_UNDEFINED; std::string commDomain; struct TwoDimTPInfo { uint16_t agDim = 0; uint16_t rsDim = 0; uint8_t innerDimIsAg = 1; uint8_t rsv[3] = {0}; }; TwoDimTPInfo twoDimTPInfo; MoeInfo moeInfo; uint8_t rsv[52] = {0}; }; |