JaggedTensorWithTimestamp
功能描述
该接口是一个扩展自JaggedTensor的类,用于表示带有时间戳信息的Jagged Tensor。该类在JaggedTensor的基础上增加了一个_timestamps属性,存储与values对应的时间戳信息。用于特征淘汰时计算时间。
函数原型
1 2 3 4 5 6 7 8 9 | class JaggedTensorWithTimestamp(ExtendedJaggedTensor): def __init__( self, values: torch.Tensor, weights: Optional[torch.Tensor] = None, lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, timestamps: Optional[torch.Tensor] = None, ) -> None: |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
values |
torch.Tensor |
必选 |
表示Jagged Tensor的值。 |
weights |
torch.Tensor |
可选 |
表示每个值的权重。默认为None。 |
lengths |
torch.Tensor |
可选 |
表示每个样本的长度。默认为None。 |
offsets |
torch.Tensor |
可选 |
表示每个样本的起始偏移量。默认为None。 |
timestamps |
torch.Tensor |
可选 |
表示与values对应的时间戳信息。默认为None。 |
使用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | import torch from torchrec_embcache.sparse.jagged_tensor_with_timestamp import JaggedTensorWithTimestamp # 创建一个JaggedTensorWithTimestamp实例 jt = JaggedTensorWithTimestamp( values=torch.tensor([1, 2, 3, 4, 5]), lengths=torch.tensor([2, 3]), timestamps=torch.tensor([100, 200, 300, 400, 500]) ) # 访问属性 print("Values:", jt.values) # 输出: tensor([1, 2, 3, 4, 5]) print("Timestamps:", jt.timestamps) # 输出: tensor([100, 200, 300, 400, 500]) print("Lengths:", jt.lengths) # 输出: tensor([2, 3]) |
父主题: 准入淘汰管理接口