@dataclass
class BuildDataArgs:
    model_type_ = None
    model_provider_ = None
    train_valid_test_datasets_provider_ = None
def set_build_data_args(model_type, model_provider, train_valid_test_datasets_provider):
    BuildDataArgs.model_type_ = model_type
    BuildDataArgs.model_provider_ = model_provider
    BuildDataArgs.train_valid_test_datasets_provider_ = train_valid_test_datasets_provider
Taking Megatron as an example, the code above captures the arguments the Megatron framework needs to build the model, optimizer, and datasets. Customers should decide which arguments to capture based on their own framework.
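For reference, a minimal usage sketch follows. It assumes a Megatron-style pretrain entry point; train_entry, model_provider, train_valid_test_datasets_provider, and the choice of ModelType.encoder_or_decoder are illustrative placeholders supplied by the caller's own training script, not part of MindIO itself.
from megatron.core.enums import ModelType  # assumption: Megatron-LM is the training framework

def train_entry(model_provider, train_valid_test_datasets_provider):
    # Record the providers once before training starts so that the repair
    # callbacks can later rebuild the model, optimizer, and data iterators.
    set_build_data_args(ModelType.encoder_or_decoder,
                        model_provider,
                        train_valid_test_datasets_provider)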
def save_callback(step: int, save_info: list, train_args=None, ctx=None):
    model = train_args[TRAIN_PARAM][MODEL_INDEX]
    optimizer = train_args[TRAIN_PARAM][OPTIM_INDEX]
    opt_param_scheduler = train_args[TRAIN_PARAM][SCHEDULER_INDEX]
    global_args = get_args()
    # Update learning rate.
    if global_args.train_samples is None:
        global_args.consumed_train_samples = step * global_args.global_batch_size
    if train_args[TRAIN_PARAM][SCHEDULER_INDEX].num_steps != global_args.consumed_train_samples:
        train_args[TRAIN_PARAM][SCHEDULER_INDEX].step(global_args.global_batch_size)
    if hasattr(optimizer, 'optim_nums'):
        for _, info_dict in enumerate(save_info):
            optim_idx = info_dict.get("type", 0)
            rank_list = info_dict.get("ranks", None)
            save_rank = rank_list[0]
            optimizer.set_dump_args(optim_idx, save_rank, step, rank_list)
    else:
        rank_list = save_info[0].get("ranks", None)
        save_rank = rank_list[0]
        optimizer.set_dump_args(save_rank, step, rank_list)
    save_checkpoint(step, model, optimizer, opt_param_scheduler,
                    global_args.num_floating_point_operations_so_far)
def rename_callback(step: int, ctx=None):
    iteration = step
    rank = torch.distributed.get_rank()
    args = get_args()
    tmp_dir = 'iter_{:07d}_tmp'.format(iteration)
    fin_dir = 'iter_{:07d}'.format(iteration)
    src_path = os.path.join(args.save, tmp_dir)
    dst_path = os.path.join(args.save, fin_dir)
    try:
        src_check_ret, err_msg, src_abs_path = FileUtils.regular_file_path(src_path, args.save, False)
        dst_check_ret, err_msg, dst_abs_path = FileUtils.regular_file_path(dst_path, args.save, False)
        if (not src_check_ret) or (not dst_check_ret):
            raise FileNotFoundError(f"rank:{rank} rename path error. {err_msg}")
        os.rename(src_abs_path, dst_abs_path)
    except FileNotFoundError as e:
        print(f"{e}")
        return False
    # And update the latest iteration
    tracker_filename = get_checkpoint_tracker_filename(args.save)
    is_path_valid, err_msg, tracker_filename = FileUtils.regular_file_path(tracker_filename, "/", False)
    if not is_path_valid:
        print_rank_0(err_msg)
        raise Exception("tracker_filename is not valid")
    with open(tracker_filename, 'w') as f:
        f.write(str(iteration))
    return True
def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        pipeline_parallel=None, tensor_rank=None, pipeline_rank=None,
                        expert_parallel=None, expert_rank=None,
                        return_base_dir=False, tmp=False):
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    if tmp:
        directory = directory + "_tmp"
    if return_base_dir:
        common_path = os.path.join(checkpoints_path, directory)
        return common_path
    # Use both the tensor and pipeline MP rank.
    if pipeline_parallel is None:
        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
    if expert_parallel is None:
        expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1)
    if expert_rank is None:
        expert_rank = mpu.get_expert_model_parallel_rank()
    # Use both the tensor and pipeline MP rank. If using the distributed
    # optimizer, then the optimizer's path must additionally include the
    # data parallel rank.
    if not pipeline_parallel:
        common_path = os.path.join(checkpoints_path, directory,
                                   f'mp_rank_{tensor_rank:02d}')
    else:
        common_path = os.path.join(checkpoints_path, directory,
                                   f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
    if expert_parallel:
        common_path = common_path + f'_{expert_rank:03d}'
    return os.path.join(common_path, "model_optim_rng.pt")
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
                    num_floating_point_operations_so_far):
    """Save a model checkpoint."""
    args = get_args()
    # Only rank zero of the data parallel writes to the disk.
    model = unwrap_model(model)
    ckpt_format = args.dist_ckpt_format if args.use_dist_ckpt else 'torch'
    print('rank {} is saving checkpoint at iteration {:7d} to {} in {} format'
          .format(args.rank, iteration, args.save, ckpt_format))
    # Collect rng state across data parallel ranks.
    rng_state = get_rng_state(args.use_dist_ckpt)
    tmp = optimizer.get_error_dump()
    # Checkpoint name.
    checkpoint_name = get_checkpoint_name(args.save, iteration,
                                          return_base_dir=args.use_dist_ckpt, tmp=tmp)
    check_ret, err_msg, checkpoint_name = FileUtils.regular_file_path(checkpoint_name, '/', False)
    if not check_ret:
        raise Exception(f"rank {args.rank} get checkpoint name error, {err_msg}")
    # Save distributed optimizer's custom parameter state.
    if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None \
            and not args.use_dist_ckpt:
        optim_checkpoint_name = \
            get_distributed_optimizer_checkpoint_name(checkpoint_name)
        ensure_directory_exists(optim_checkpoint_name)
        optimizer.save_parameter_state(optim_checkpoint_name)
    if not args.use_distributed_optimizer and (args.fp16 or args.bf16):
        if not hasattr(optimizer.config, 'reuse_fp32_param'):
            optimizer._copy_main_params_to_model_params()
    cur_rank = torch.distributed.get_rank()
    save_flag = optimizer.need_write_file()
    # Collect args, model, RNG.
    if not torch.distributed.is_initialized() \
            or save_flag \
            or args.use_dist_ckpt:
        optim_sd_kwargs = {}
        if args.use_dist_ckpt and args.use_distributed_optimizer:
            optim_sd_kwargs['sharding_type'] = ('fully_sharded_bucket_space'
                                                if args.ckpt_fully_parallel_save
                                                else 'dp_zero_gather_scatter')
            print_rank_0(f'Storing distributed optimizer sharded state of type {optim_sd_kwargs["sharding_type"]}')
        state_dict = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state,
                                         args.use_dist_ckpt, iteration, optim_sd_kwargs=optim_sd_kwargs)
        state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
        # Save.
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)
The MindIO TTP feature reuses the main checkpoint-saving logic of the Megatron framework. However, because the workers that participate in saving the checkpoint are no longer ordered, the checkpoint file naming and the logic that decides whether a rank saves must be modified. After an end-of-life save, the file name may differ from the one Megatron would normally produce, so rename_callback renames the files to match the framework's naming convention.
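As an illustration (assuming a hypothetical save directory of /ckpt and iteration 100), the end-of-life save with tmp=True writes into the temporary directory, and rename_callback later promotes it to the name Megatron expects:
# Hypothetical example: args.save == "/ckpt", iteration == 100.
tmp_dir = 'iter_{:07d}_tmp'.format(100)   # 'iter_0000100_tmp', written by save_callback
fin_dir = 'iter_{:07d}'.format(100)       # 'iter_0000100', produced by rename_callback
# rename_callback(100) renames /ckpt/iter_0000100_tmp to /ckpt/iter_0000100 and
# writes "100" into the checkpoint tracker file so Megatron can resume from it.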
def stop_callback(train_args=None, ctx=None):
# stop and clean device
device = get_current_device()
torch.npu.set_device(device)
torch_npu.npu.stop_device(device)
def clean_callback(is_uce_error=False, train_args=None, ctx=None):
"""
this function do:
1) get UCE check result from torch_npu
2) do some clear before rebuild (avoid OOM) when the check result is UCE_HIGH_LEVEL
3) HCCL resume and restart device
"""
device = get_device()
rank = torch.distributed.get_rank()
torch.npu.set_device(device)
ret = RET_OK
if is_uce_error:
check_memory_result = torch_npu.npu.check_uce_in_memory(device)
if check_memory_result == UCE_LOW_LEVEL: # no need rebuild
ret = RET_NO_REBUILD
elif check_memory_result == UCE_HIGH_LEVEL: # need rebuild
clear_memory(train_args[TRAIN_PARAM][MODEL_INDEX],
train_args[TRAIN_PARAM][OPTIM_INDEX],
train_args[TRAIN_PARAM][CONFIG_INDEX])
train_args[TRAIN_PARAM][MODEL_INDEX] = None
train_args[TRAIN_PARAM][OPTIM_INDEX] = None
train_args[TRAIN_PARAM][SCHEDULER_INDEX] = None
train_args[TRAIN_PARAM][CONFIG_INDEX] = None
ret = RET_OK
else: # exit
ret = RET_ERROR
if hasattr(torch.distributed, 'reinit_process_group'):
torch.distributed.reinit_process_group(group=None, rebuild_link=False)
torch.npu.restart_device(device)
return ret
def repair_callback(step: int, need_rebuild: bool, error_ranks: list, repair_info: dict,
train_args=None, ctx=None):
from mindio_ttp.framework_ttp import RepairType, OptimizerType
torch.npu.set_device(get_current_device())
optim_idxs = repair_info.get("type", OptimizerType.ATTENTION.value)
repair_type = repair_info.get("repair_type", None)
src_ranks = repair_info.get("src", [])
dest_ranks = repair_info.get("dst", [])
rank_list = repair_info.get("rank_list", [])
build_repair_group(rank_list)
if repair_type == RepairType.RT_SEND.value: # send
return send_rank_repair(src_ranks, dest_ranks, optim_idxs, step, rank_list, train_args[TRAIN_PARAM])
elif repair_type in [RepairType.RT_UCE_HIGHLEVEL.value,
RepairType.RT_UCE_LOWLEVEL.value,
RepairType.RT_RECV_REPAIR.value]:
return recv_rank_repair(src_ranks, dest_ranks, optim_idxs, step,
need_rebuild, rank_list, train_args[TRAIN_PARAM])
else:
return RET_ERROR
def rollback_callback(step: int, train_args=None, ctx=None):
args = get_args()
torch.npu.set_device(get_current_device())
global temp_memory_ckpt
rank = torch.distributed.get_rank()
# Update consumed train samples and learning rate.
if args.train_samples is None:
args.consumed_train_samples = step * args.global_batch_size
if train_args[TRAIN_PARAM][SCHEDULER_INDEX].num_steps != args.consumed_train_samples:
train_args[TRAIN_PARAM][SCHEDULER_INDEX].step(args.global_batch_size)
feature_rollback()
gather_model_params_from_optimizer(train_args[TRAIN_PARAM][OPTIM_INDEX], step)
rebuild_dataset(train_args)
set_memory_ckpt(None)
temp_memory_ckpt = None
destroy_repair_group()
return RET_OK
def feature_rollback():
    # Other features re-initialize or roll back their data after the fault is rectified.
    args = get_args()
    if hasattr(args, "num_experts") and args.num_experts:
        mpu._MOE_AUX_LOSSES_LOGGING_TRACKER = {}  # Re-initialize the global tensor
def send_rank_repair(src_ranks: list, dest_ranks: list, optim_idxs: list,
step: int, rank_list: list, train_args):
rank_list.sort()
rank_ = torch.distributed.get_rank()
repair_group = get_repair_group()
for idx, src_rank in enumerate(src_ranks):
dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
if src_rank != rank_:
return RET_ERROR
if src_rank == dest_rank:
return RET_ERROR
save_and_send_ckpt(src_rank, dest_rank, step, optim_idx, train_args)
for idx, _ in enumerate(src_ranks):
dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
train_args[OPTIM_INDEX].send_optim_param_state(dest_rank, repair_group, step, optim_idx)
    # Repair other global tensors: send the local copy to the faulty peer.
    # (Placeholder: global_tensor stands for the actual tensor being repaired.)
    for idx, src_rank in enumerate(src_ranks):
        dest_rank = dest_ranks[idx]
        torch.distributed.send(global_tensor, dst=dest_rank, group=repair_group)
return RET_OK
def save_and_send_ckpt(src_rank, dest_rank, step, optim_idx, train_args):
"""
Save memory checkpoint and send to dest rank.
"""
repair_group = get_repair_group()
state_dict = save_memory_ckpt(train_args[OPTIM_INDEX], train_args[SCHEDULER_INDEX],
step, optim_idx)
buffer = io.BytesIO()
torch.save(state_dict, buffer)
state_dict_bytes = buffer.getvalue()
state_dict_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(state_dict_bytes)).to('npu')
# Send tensor size first
size_tensor = torch.tensor([state_dict_tensor.numel()], dtype=torch.long).to('npu')
torch.distributed.send(size_tensor, dst=dest_rank, group=repair_group)
# Send the serialized state_dict tensor
torch.distributed.send(state_dict_tensor, dst=dest_rank, group=repair_group)
return RET_OK
def recv_rank_repair(src_ranks: list, dest_ranks: list, optim_idxs: list,
step: int, need_rebuild: bool, rank_list: list, train_args):
# low level and only rollback case
if src_ranks == dest_ranks:
return RET_OK
rank_ = torch.distributed.get_rank()
rank_list.sort()
if need_rebuild:
build_local_embedding_group()
model, optimizer, lr_scheduler, config = rebuild_model_and_optimizer(
BuildDataArgs.model_provider_, BuildDataArgs.model_type_)
train_args[MODEL_INDEX] = model
train_args[OPTIM_INDEX] = optimizer
train_args[SCHEDULER_INDEX] = lr_scheduler
train_args[CONFIG_INDEX] = config
for idx, src_rank in enumerate(src_ranks):
dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
if dest_rank != rank_:
return RET_ERROR
if src_rank != dest_rank:
recv_ckpt_from_peer(src_rank, dest_rank, step, rank_list)
    # Combine the state_dicts and load them once to avoid precision problems.
state_dict = get_memory_ckpt()
load_memory_ckpt(train_args[MODEL_INDEX], train_args[OPTIM_INDEX], train_args[SCHEDULER_INDEX],
state_dict, None)
repair_group = get_repair_group()
for idx, src_rank in enumerate(src_ranks):
dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
train_args[OPTIM_INDEX].recv_and_load_optim_param_state(src_rank, repair_group, step, optim_idx)
    # Repair other global tensors: receive the healthy replica's copy and overwrite the local one.
    # (Placeholder: size, dtype and global_tensor stand for the actual tensor being repaired.)
    for idx, src_rank in enumerate(src_ranks):
        recv_tensor = torch.empty(size, dtype=dtype, device="npu")
        torch.distributed.recv(recv_tensor, src=src_rank, group=repair_group)
        global_tensor.data.copy_(recv_tensor)
return RET_OK
def recv_ckpt_from_peer(src_rank, dest_rank, step, rank_list: list):
"""
receive memory checkpoint and repair train() param.
"""
repair_group = get_repair_group()
# Receive tensor size first
size_tensor = torch.tensor([0], dtype=torch.long, device='npu')
torch.distributed.recv(size_tensor, src=src_rank, group=repair_group)
size = size_tensor.item()
# Receive the serialized state_dict tensor
state_dict_tensor = torch.empty(size, dtype=torch.uint8, device='npu')
torch.distributed.recv(state_dict_tensor, src=src_rank, group=repair_group)
# Deserialize the state_dict
state_dict_bytes = state_dict_tensor.cpu().numpy().tobytes()
buffer = io.BytesIO(state_dict_bytes)
device_count = torch.npu.device_count()
if device_count == 0:
raise ValueError(f"torch.npu.device_count return 0!")
map_location = {
'npu:' + str(src_rank % device_count): 'npu:' + str(dest_rank % device_count)
}
loaded_state_dict = torch.load(buffer, map_location=map_location)
set_memory_ckpt(loaded_state_dict)
def rebuild_model_and_optimizer(model_provider, model_type):
args = get_args()
load = args.load
args.load = None
ori_embedding_group = mpu._EMBEDDING_GROUP
mpu._EMBEDDING_GROUP = get_local_embedding_group()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
mpu._EMBEDDING_GROUP = ori_embedding_group
args.load = load
config = get_model_config(model[0])
return model, optimizer, lr_scheduler, config
def rebuild_dataset(args):
# repair data iterator
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_data_iterator(args[TRAIN_PARAM][MODEL_INDEX])
args[TRAIN_PARAM][TRAIN_DATA_INDEX] = train_data_iterator
args[TRAIN_PARAM][VALID_DATA_INDEX] = valid_data_iterator
args[TEST_DATA_ITER][0] = test_data_iterator
return RET_OK
def build_data_iterator(model):
args = get_args()
if args.virtual_pipeline_model_parallel_size is not None:
train_data_iterator = []
valid_data_iterator = []
test_data_iterator = []
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
iterators = build_train_valid_test_data_iterators(
BuildDataArgs.train_valid_test_datasets_provider_)
train_data_iterator.append(iterators[0])
valid_data_iterator.append(iterators[1])
test_data_iterator.append(iterators[2])
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
BuildDataArgs.train_valid_test_datasets_provider_)
return train_data_iterator, valid_data_iterator, test_data_iterator
def update_arf_reboot_flag(new_state):  # For nodes incrementally restarted by MindIO ARF, use the optimizer data received during repair rather than loading the checkpoint saved before the fault
global ARF_REBOOT_FLAG, LOAD, PRETRAINED_CHECKPOINT
args = get_args()
if new_state:
LOAD = args.load
PRETRAINED_CHECKPOINT = args.pretrained_checkpoint
args.load = None
args.pretrained_checkpoint = None
elif ARF_REBOOT_FLAG:
args.load = LOAD
args.pretrained_checkpoint = PRETRAINED_CHECKPOINT
ARF_REBOOT_FLAG = new_state
def arf_rebuild_process_group_callback(fault_ranks: list, train_args, ctx):
models, optimizer = train_args[TRAIN_PARAM][MODEL_INDEX], train_args[TRAIN_PARAM][OPTIM_INDEX]
args = get_args()
update_arf_reboot_flag(False)
# 1.1 destroy_all_process_group
torch.distributed.destroy_process_group()
# 1.2 init_all_process_group
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
timeout=timedelta(minutes=args.distributed_timeout_minutes),
)
#2.1 destroy mpu group
delete_mpu_group(mpu)
# 2.2 destroy global memory buffer
mpu.destroy_global_memory_buffer()
# 3.1 init_mpu
mpu.initialize_model_parallel(
context_parallel_size=args.context_parallel_size,
tensor_model_parallel_size=args.tensor_model_parallel_size,
expert_model_parallel_size=args.expert_model_parallel_size,
distributed_timeout_minutes=args.distributed_timeout_minutes,
pipeline_model_parallel_size=args.pipeline_model_parallel_size,
nccl_communicator_config_path=args.nccl_communicator_config_path,
pipeline_model_parallel_split_rank=args.pipeline_model_parallel_split_rank,
virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
)
update_model_and_optim_related_group(models, optimizer)
def delete_mpu_group(mpu_module):
group_space = [mpu_group for mpu_group in dir(mpu_module) if 'GROUP' in mpu_group]
for group_name in group_space:
setattr(mpu_module, group_name, None)
def update_model_and_optim_related_group(models, optimizer):
for model in models:
# dense
for buffer in model.buffers:
# fix ParamAndGradBuffer attributes
buffer.data_parallel_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
buffer.data_parallel_world_size = torch.distributed.get_world_size(group=mpu._DATA_PARALLEL_GROUP_WITH_CP)
for bucket in buffer.buckets:
# fix Bucket attributes
bucket.data_parallel_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
bucket.data_parallel_world_size = torch.distributed.get_world_size(
group=mpu._DATA_PARALLEL_GROUP_WITH_CP)
bucket.data_parallel_rank = torch.distributed.get_rank(mpu._DATA_PARALLEL_GROUP_WITH_CP)
# moe
for expert_buffer in model.expert_parallel_buffers:
# fix ParamAndGradBuffer attributes
expert_buffer.data_parallel_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
expert_buffer.data_parallel_world_size = \
torch.distributed.get_world_size(group=mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
for bucket in expert_buffer.buckets:
# fix Bucket attributes
bucket.data_parallel_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
bucket.data_parallel_world_size = torch.distributed.get_world_size(
group=mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
bucket.data_parallel_rank = torch.distributed.get_rank(mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
    if hasattr(optimizer, 'optim_nums'):  # Update the communication groups held by the optimizer as required
optimizer.chained_optimizers[0].ori_dp_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
optimizer.chained_optimizers[1].ori_dp_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
else:
optimizer.ori_dp_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
    if not get_args().use_distributed_optimizer:  # A non-distributed optimizer does not hold replica communication groups
return
# fix optimizer attributes
    if hasattr(optimizer, 'optim_nums'):  # The distributed optimizer holds replica groups for replica-optimizer communication
optimizer.chained_optimizers[0].data_parallel_group = ttp_get_dp_cp_replica_group()
optimizer.chained_optimizers[0].data_parallel_group_gloo = ttp_get_dp_cp_replica_group_gloo()
optimizer.chained_optimizers[0].ori_dp_list = torch.distributed.get_process_group_ranks(
mpu._DATA_PARALLEL_GROUP_WITH_CP)
optimizer.chained_optimizers[1].data_parallel_group = ttp_get_dp_ep_replica_group()
optimizer.chained_optimizers[1].data_parallel_group_gloo = ttp_get_dp_ep_replica_group_gloo()
optimizer.chained_optimizers[1].ori_dp_list = torch.distributed.get_process_group_ranks(
mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
else:
optimizer.data_parallel_group = ttp_get_dp_cp_replica_group()
optimizer.data_parallel_group_gloo = ttp_get_dp_cp_replica_group_gloo()
optimizer.ori_dp_list = torch.distributed.get_process_group_ranks(mpu._DATA_PARALLEL_GROUP_WITH_CP)
return
# (Optional) If training involves stream switching, reset the streams used in training to ensure correct scheduling.
class CustomFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
torch.cuda.set_stream(torch.cuda.default_stream())
return input
@staticmethod
def backward(ctx, grad):
torch.cuda.set_stream(torch.cuda.default_stream())
return grad
def recover_set_stream():
input_tensor = torch.empty(1, dtype=torch.float32, device="npu", requires_grad=True)
grad_tensor = torch.empty(1, dtype=torch.float32, device="npu", requires_grad=True)
output_tensor = CustomFunction.apply(input_tensor)
output_tensor.backward(grad_tensor)
The stop and clean callbacks are required for both MindIO UCE and MindIO ARF. stop_callback uses the torch_npu.npu.stop_device interface to pause training on the other workers; once repair completes, all workers re-enter training together. In clean_callback, if the current fault is a UCE fault, the UCE fault type is checked through the torch_npu.npu.check_uce_in_memory interface. If the type is UCE_HIGH_LEVEL, this repair must rebuild the model and optimizer so that memory is re-allocated and the tensor state is refreshed, fully repairing the UCE fault. Afterwards, the device is restarted through the torch.npu.restart_device interface.
repair_callback implements data repair between workers that are replicas of each other: the faulty worker rebuilds its model and optimizer, then receives and loads checkpoint data from the healthy replica worker, while the healthy replica worker saves and sends its checkpoint data.
After the repair completes, all workers enter rollback_callback together. In the rollback stage, the parameters held by the optimizer are gathered back into the model, and other training state, such as the learning rate and the data iterators, is rolled back to the correct position.
The MindIO ARF feature relies on the data-transfer capability used in MindIO UCE repair. arf_rebuild_process_group_callback is used by MindIO ARF to rebuild the communication groups after a worker is incrementally restarted and to replace the communication groups held by the model and optimizer, ensuring that training after the repair uses the rebuilt groups.
In addition, for other global tensors that need repair: if a tensor has no dependency on the training iteration, it can simply be re-initialized (and its old storage released) in the rollback stage; if it does depend on the training iteration and its mapping is consistent with the replica optimizer, it must be rebuilt in the repair stage and its data repaired through point-to-point communication.
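The placeholder send/recv loops in send_rank_repair and recv_rank_repair can be filled in per tensor. The sketch below shows one possible shape for such a repair; my_global_tensor is a hypothetical module-level NPU tensor whose replica mapping matches the replica optimizer, and repair_group, src_rank, and dest_rank are the values passed into repair_callback.
def send_global_tensor(dest_rank, repair_group):
    # Healthy replica: send its copy of the tensor to the faulty peer.
    # my_global_tensor is a hypothetical global tensor that needs repair.
    torch.distributed.send(my_global_tensor, dst=dest_rank, group=repair_group)
def recv_global_tensor(src_rank, repair_group):
    # Faulty rank: receive into a buffer of the same shape/dtype, then copy
    # in place so existing references to the tensor remain valid.
    recv_buf = torch.empty_like(my_global_tensor)
    torch.distributed.recv(recv_buf, src=src_rank, group=repair_group)
    my_global_tensor.data.copy_(recv_buf)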