昇腾社区首页
中文
注册
开发者
下载

KeyedJaggedTensorWithTimestamp

功能描述

该接口是一个扩展自KeyedJaggedTensor的类,用于表示带有时间戳信息的Keyed Jagged Tensor。该类在KeyedJaggedTensor的基础上增加了一个_timestamps属性,存储与values对应的时间戳信息。用于特征淘汰时计算时间。

函数原型

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class KeyedJaggedTensorWithTimestamp(KeyedExtendedJaggedTensor[JaggedTensorWithTimestamp]):
    def __init__(
        self,
        keys: List[str],
        values: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        weights: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        stride: Optional[int] = None,
        stride_per_key_per_rank: Optional[List[List[int]]] = None,
        stride_per_key: Optional[List[int]] = None,
        length_per_key: Optional[List[int]] = None,
        lengths_offset_per_key: Optional[List[int]] = None,
        offset_per_key: Optional[List[int]] = None,
        index_per_key: Optional[Dict[str, int]] = None,
        jt_dict: Optional[Dict[str, JaggedTensor]] = None,
        inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
        extra: Optional[torch.Tensor] = None,
    ) -> None:
 
    @staticmethod
    def from_jt_dict(
        jt_dict: Dict[str, JaggedTensorWithTimestamp],
    ) -> "KeyedJaggedTensorWithTimestamp":

参数说明

参数名

类型

可选/必选

说明

keys

List[str]

必选

表示键的列表,用于标识不同的Jagged Tensor。

values

torch.Tensor

必选

表示Keyed Jagged Tensor的值。

timestamps

torch.Tensor

可选

表示与values对应的时间戳信息。默认为None。

weights

torch.Tensor

可选

表示每个值的权重。默认为None。

lengths

torch.Tensor

可选

表示每个样本的长度。默认为None。

offsets

torch.Tensor

可选

表示每个样本的起始偏移量。默认为None。

stride

int

可选

表示步长。默认为None。

stride_per_key_per_rank

List[List[int]]

可选

表示每个键在每个秩上的步长。默认为None。

stride_per_key

List[int]

可选

表示每个键的步长。默认为None。

length_per_key

List[int]

可选

表示每个键的长度。默认为None。

lengths_offset_per_key

List[int]

可选

表示每个键的长度偏移量。默认为None。

offset_per_key

List[int]

可选

表示每个键的偏移量。默认为None。

index_per_key

Dict[str, int]

可选

表示键到索引的映射。默认为None。

jt_dict

Dict[str, JaggedTensor]

可选

表示包含JaggedTensor实例的字典。默认为None。

inverse_indices

Tuple[List[str], torch.Tensor]

可选

表示反向索引。默认为None。

extra

torch.Tensor

可选

用于兼容重构后的基类,通常为None。

from_jt_dict方法说明:

  • 功能:从一个包含 JaggedTensorWithTimestamp 实例的字典构造一个新的 KeyedJaggedTensorWithTimestamp 实例。
  • 参数:jt_dict:一个字典,键为字符串,值为 JaggedTensorWithTimestamp 实例。
  • 返回值:一个新的 KeyedJaggedTensorWithTimestamp 实例,合并了字典中所有 JaggedTensorWithTimestamp 实例的数据。

使用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from torchrec_embcache.sparse.jagged_tensor_with_timestamp import KeyedJaggedTensorWithTimestamp, JaggedTensorWithTimestamp
 
# 创建两个JaggedTensorWithTimestamp实例
jt1 = JaggedTensorWithTimestamp(
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([3]),
    timestamps=torch.tensor([100, 200, 300])
)
 
jt2 = JaggedTensorWithTimestamp(
    values=torch.tensor([4, 5]),
    lengths=torch.tensor([2]),
    timestamps=torch.tensor([400, 500])
)
 
# 创建一个字典
jt_dict = {
    "tensor1": jt1,
    "tensor2": jt2
}
 
# 使用from_jt_dict方法构造新的KeyedJaggedTensorWithTimestamp实例
new_kjt = KeyedJaggedTensorWithTimestamp.from_jt_dict(jt_dict)
 
# 访问属性
print("Keys:", new_kjt.keys)         # 输出: ['tensor1', 'tensor2']
print("Values:", new_kjt.values)     # 输出: tensor([1, 2, 3, 4, 5])
print("Timestamps:", new_kjt.timestamps) # 输出: tensor([100, 200, 300, 400, 500])
print("Lengths:", new_kjt.lengths)   # 输出: tensor([3, 2])
 
# 直接初始化KeyedJaggedTensorWithTimestamp实例
kjt = KeyedJaggedTensorWithTimestamp(
    keys=["tensorA", "tensorB"],
    values=torch.tensor([10, 20, 30, 40]),
    lengths=torch.tensor([2, 2]),
    timestamps=torch.tensor([1000, 2000, 3000, 4000])
)
 
# 访问属性
print("Keys:", kjt.keys)         # 输出: ['tensorA', 'tensorB']
print("Values:", kjt.values)     # 输出: tensor([10, 20, 30, 40])
print("Timestamps:", kjt.timestamps) # 输出: tensor([1000, 2000, 3000, 4000])
print("Lengths:", kjt.lengths)   # 输出: tensor([2, 2])