JaggedTensor(TorchRec)

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。
功能描述
持有稀疏id和特征长度的类,用于查表。例如,values为[id1, id2, id3, id4],length为[1, 2, 1]。表示为id2和id3查表后的Embedding应该被pooling。
函数原型
1 2 | class JaggedTensor: def __init__(**kwargs): |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
values |
torch.Tensor[int64] |
必选 |
稀疏表查表ID。取值范围:[0,2^31]。 |
weights |
torch.Tensor |
可选 |
仅支持默认值为None,不支持用户自定义。 |
lengths |
torch.Tensor[int64] |
必选 |
每一个样本中的特征序列的长度。取值范围:[1,10000]。 需保证lengths的总和与values的长度相等。 须知:
Rec SDK Torch目前不支持可变batchsize,一个训练任务中的所有JaggedTensor的lengths的长度必须一致。 |
offsets |
torch.Tensor[int64] |
可选 |
offsets是lengths累加的结果。 offsets的第一位为0,后续位数为lengths的累加。默认值为None。 |
使用示例
1 2 | from torchrec import JaggedTensor JaggedTensor(values=[1, 3, 4], lengths=[1, 1, 1], offsets=[0, 1, 2, 3]) |
参考资源
接口调用流程及示例,参见迁移与训练。
父主题: 数据接口