昇腾社区首页
中文
注册
开发者
下载

JaggedTensorWithTimestamp

功能描述

该接口是一个扩展自JaggedTensor的类,用于表示带有时间戳信息的Jagged Tensor。该类在JaggedTensor的基础上增加了一个_timestamps属性,存储与values对应的时间戳信息。用于特征淘汰时计算时间。

函数原型

1
2
3
4
5
6
7
8
9
 class JaggedTensorWithTimestamp(ExtendedJaggedTensor):
    def __init__(
        self,
        values: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        timestamps: Optional[torch.Tensor] = None,
    ) -> None: 

参数说明

参数名

类型

可选/必选

说明

values

torch.Tensor

必选

表示Jagged Tensor的值。

weights

torch.Tensor

可选

表示每个值的权重。默认为None。

lengths

torch.Tensor

可选

表示每个样本的长度。默认为None。

offsets

torch.Tensor

可选

表示每个样本的起始偏移量。默认为None。

timestamps

torch.Tensor

可选

表示与values对应的时间戳信息。默认为None。

使用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
import torch
from torchrec_embcache.sparse.jagged_tensor_with_timestamp import
JaggedTensorWithTimestamp
 
# 创建一个JaggedTensorWithTimestamp实例
jt = JaggedTensorWithTimestamp(
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 3]),
    timestamps=torch.tensor([100, 200, 300, 400, 500])
)
 
# 访问属性
print("Values:", jt.values)         # 输出: tensor([1, 2, 3, 4, 5])
print("Timestamps:", jt.timestamps) # 输出: tensor([100, 200, 300, 400, 500])
print("Lengths:", jt.lengths)       # 输出: tensor([2, 3])