The adaptation work for tensor parallelism consists of the following parts. The weight-splitting logic should be read together with the parallelism scheme described for the large-model acceleration framework:
Only the weight splitting of the TransformerBlock is shown here. It splits the matrix-multiplication weights of the SelfAttention and FFN layers, first "by column" and then "by row"; see "pytorch/examples/llama13b_parallel/cut_model_util.py" for details:
import torch

def cut_weights(model, world_size,
                cut_row_keys=['q_proj', 'k_proj', 'v_proj', 'gate_proj', 'up_proj'],
                cut_col_keys=['o_proj', 'down_proj']):
    state_dict_list = [{} for i in range(world_size)]
    for key, tensor in model.state_dict().items():
        key_short = key.split('.')[-2]
        if key_short in cut_row_keys:
            # split by column: chunk dim 0 of the weight (the output dimension)
            cut_tensor_list = torch.chunk(tensor, world_size, dim=0)
        elif key_short in cut_col_keys:
            # split by row: chunk dim 1 of the weight (the input dimension)
            cut_tensor_list = torch.chunk(tensor, world_size, dim=1)
        else:
            # weights that are not split are replicated on every rank
            cut_tensor_list = [tensor] * world_size
        for i in range(world_size):
            state_dict_list[i][key] = cut_tensor_list[i]
    return state_dict_list
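A minimal usage sketch of the splitting step is shown below; loading the model with transformers' AutoModelForCausalLM and the per-rank output layout are assumptions for illustration, not the exact flow of cut_model_util.py:

# Sketch only: the model source and per-rank output paths are hypothetical.
import os
import torch
from transformers import AutoModelForCausalLM

world_size = 2  # one shard per chip
model = AutoModelForCausalLM.from_pretrained("/path/to/llama-13b")   # hypothetical path
state_dict_list = cut_weights(model, world_size)
for rank, state_dict in enumerate(state_dict_list):
    save_dir = os.path.join("/path/to/output", str(rank))            # hypothetical layout
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))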
After a row-parallel matrix multiplication each rank holds only a partial sum, so the result is combined across devices with an all_reduce, for example after the attention output projection:

if self.world_size >= 2:
    torch.distributed.all_reduce(attn_output, op=torch.distributed.ReduceOp.SUM)
In the acceleration-library graph construction, the MLP down-projection is created as a LinearParallel node with parallelType "RowParallel" and the hccl backend:

atb::Status LlamaLayerEncoderParallelOperation(const LlamaLayerEncoderParallelParam &param,
                                               atb::Operation **operation)
{
    ...  // Graph initialization and node construction, identical to the non-parallel model

    // Create the parallel LinearParallel node
    atb::infer::LinearParallelParam mlpLinearParallelParam;
    mlpLinearParallelParam.transWeight = false;
    // Parallelism-related parameters
    mlpLinearParallelParam.rank = param.rank;
    mlpLinearParallelParam.rankSize = param.rankSize;
    mlpLinearParallelParam.rankRoot = 0;
    mlpLinearParallelParam.bias = "None";
    mlpLinearParallelParam.parallelType = "RowParallel";
    mlpLinearParallelParam.backend = "hccl";
    // Operation creation and tensor wiring are the same as in the original graph
    CreateOperation(mlpLinearParallelParam, &mlpLinearParallelNode.operation);
    mlpLinearParallelNode.inTensorIds = {INTERMIDATE_MLPOUT, IN_MLPDOWNWEIGHT};
    mlpLinearParallelNode.outTensorIds = {INTERMIDATE_MLPLINEARPARALLELOUT};

    ...  // Build the remaining nodes and the inferShapeFunc

    atb::CreateOperation(opGraph, operation);  // Build the layer
    return atb::NO_ERROR;
}
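For intuition, the following is a minimal Python sketch of what a RowParallel linear computes on each rank: a matmul against the local weight shard followed by an all_reduce that sums the partial results. It is illustrative only, not how the fused LinearParallel operation is implemented:

import torch

def row_parallel_linear(x_local, weight_shard):
    # x_local: [..., hidden_size / world_size], the local slice of the input
    # weight_shard: [out_features, hidden_size / world_size], cut along dim=1 above
    partial = torch.nn.functional.linear(x_local, weight_shard)
    # Sum the partial products from all ranks to obtain the full output
    torch.distributed.all_reduce(partial, op=torch.distributed.ReduceOp.SUM)
    return partial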
Dual-chip deployment involves multi-process inference; refer to the script "pytorch/examples/llama13b_parallel/cut_model_and_run_llama.sh" in the inference model repository. The core code is as follows:
export LOCAL_RANK=$RANK_ID
export WORLD_SIZE=$WORLD_SIZE
bind=${map["$RANK_ID"]}
echo "Device ID: $RANK_ID, bind to NUMA node: $bind"
numactl --cpunodebind=$bind --membind $bind \
    python3 run_llama_half_parallel_loadPartModel.py --load_path $output_dir &  # load_path is the directory containing the split weights/config files
Inside the Python inference script, each process initializes the distributed environment before loading its weight shard:

import os
import torch
import torch_npu

def setup_model_parallel():
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12345"
    # Initialize HCCL and obtain the corresponding local_rank/world_size values
    local_rank = int(os.getenv("LOCAL_RANK", '0'))
    world_size = int(os.getenv("WORLD_SIZE", '0'))
    torch_npu.npu.set_device(local_rank)
    torch.distributed.init_process_group(
        backend='hccl', world_size=world_size, rank=local_rank)
    # The seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size
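A minimal sketch of how the returned values could be used to load the per-rank shard; the directory layout under --load_path and the use of transformers' AutoModelForCausalLM are assumptions for illustration:

# Sketch only: the per-rank directory layout under --load_path is hypothetical.
import os
import torch_npu
from transformers import AutoModelForCausalLM

load_path = "/path/to/split/weights"  # hypothetical; normally passed via --load_path
local_rank, world_size = setup_model_parallel()
part_dir = os.path.join(load_path, str(local_rank))  # this rank's weights and config
model = AutoModelForCausalLM.from_pretrained(part_dir).half().npu()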