def get_torch_tensor(self, mode="max_shape", device="cpu")

函数功能

生成一个与Input中shape相等的元素值全为1的torch.tensor(若为动态shape可以指定输出shape等于max_shape或者min_shape)。

函数原型

def get_torch_tensor(self, mode="max_shape", device="cpu")

参数说明

参数名

参数类型

输入/输出

说明

mode

str

输入

对于动态shape场景,指定输出的tensor的shape对齐min_shape或max_shape;可选。

默认值:"max_shape"。

device

str

输入

指定生成的torch.tensor的device属性;可选。

默认值:"cpu"。

返回值说明

返回生成的数据全为1的torch.tensor。