昇腾社区首页
中文
注册

JaggedTensor(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

持有稀疏id和特征长度的类,用于查表。例如,values为[id1, id2, id3, id4],length为[1, 2, 1]。表示为id2和id3查表后的Embedding应该被pooling。

函数原型

1
2
class JaggedTensor:
 def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

values

torch.Tensor[int64]

必选

稀疏表查表ID。取值范围:[0,2^31]。

weights

torch.Tensor

可选

仅支持默认值为None,不支持用户自定义。

lengths

torch.Tensor[int64]

必选

每一个样本中的特征序列的长度。取值范围:[1,10000]。

需保证lengths的总和与values的长度相等。

须知:

Rec SDK Torch目前不支持可变batchsize,一个训练任务中的所有JaggedTensor的lengths的长度必须一致。

offsets

torch.Tensor[int64]

可选

offsets是lengths累加的结果。

offsets的第一位为0,后续位数为lengths的累加。默认值为None。

使用示例

1
2
from torchrec import JaggedTensor
JaggedTensor(values=[1, 3, 4], lengths=[1, 1, 1], offsets=[0, 1, 2, 3])

参考资源

接口调用流程及示例,参见迁移与训练