KeyedJaggedTensorWithTimestamp
功能描述
该接口是一个扩展自KeyedJaggedTensor的类,用于表示带有时间戳信息的Keyed Jagged Tensor。该类在KeyedJaggedTensor的基础上增加了一个_timestamps属性,存储与values对应的时间戳信息。用于特征淘汰时计算时间。
函数原型
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | class KeyedJaggedTensorWithTimestamp(KeyedExtendedJaggedTensor[JaggedTensorWithTimestamp]): def __init__( self, keys: List[str], values: torch.Tensor, timestamps: Optional[torch.Tensor] = None, weights: Optional[torch.Tensor] = None, lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, stride_per_key: Optional[List[int]] = None, length_per_key: Optional[List[int]] = None, lengths_offset_per_key: Optional[List[int]] = None, offset_per_key: Optional[List[int]] = None, index_per_key: Optional[Dict[str, int]] = None, jt_dict: Optional[Dict[str, JaggedTensor]] = None, inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, extra: Optional[torch.Tensor] = None, ) -> None: @staticmethod def from_jt_dict( jt_dict: Dict[str, JaggedTensorWithTimestamp], ) -> "KeyedJaggedTensorWithTimestamp": |
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
|---|---|---|---|
keys |
List[str] |
必选 |
表示键的列表,用于标识不同的Jagged Tensor。 |
values |
torch.Tensor |
必选 |
表示Keyed Jagged Tensor的值。 |
timestamps |
torch.Tensor |
可选 |
表示与values对应的时间戳信息。默认为None。 |
weights |
torch.Tensor |
可选 |
表示每个值的权重。默认为None。 |
lengths |
torch.Tensor |
可选 |
表示每个样本的长度。默认为None。 |
offsets |
torch.Tensor |
可选 |
表示每个样本的起始偏移量。默认为None。 |
stride |
int |
可选 |
表示步长。默认为None。 |
stride_per_key_per_rank |
List[List[int]] |
可选 |
表示每个键在每个秩上的步长。默认为None。 |
stride_per_key |
List[int] |
可选 |
表示每个键的步长。默认为None。 |
length_per_key |
List[int] |
可选 |
表示每个键的长度。默认为None。 |
lengths_offset_per_key |
List[int] |
可选 |
表示每个键的长度偏移量。默认为None。 |
offset_per_key |
List[int] |
可选 |
表示每个键的偏移量。默认为None。 |
index_per_key |
Dict[str, int] |
可选 |
表示键到索引的映射。默认为None。 |
jt_dict |
Dict[str, JaggedTensor] |
可选 |
表示包含JaggedTensor实例的字典。默认为None。 |
inverse_indices |
Tuple[List[str], torch.Tensor] |
可选 |
表示反向索引。默认为None。 |
extra |
torch.Tensor |
可选 |
用于兼容重构后的基类,通常为None。 |
from_jt_dict方法说明:
- 功能:从一个包含 JaggedTensorWithTimestamp 实例的字典构造一个新的 KeyedJaggedTensorWithTimestamp 实例。
- 参数:jt_dict:一个字典,键为字符串,值为 JaggedTensorWithTimestamp 实例。
- 返回值:一个新的 KeyedJaggedTensorWithTimestamp 实例,合并了字典中所有 JaggedTensorWithTimestamp 实例的数据。
使用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | import torch from torchrec_embcache.sparse.jagged_tensor_with_timestamp import KeyedJaggedTensorWithTimestamp, JaggedTensorWithTimestamp # 创建两个JaggedTensorWithTimestamp实例 jt1 = JaggedTensorWithTimestamp( values=torch.tensor([1, 2, 3]), lengths=torch.tensor([3]), timestamps=torch.tensor([100, 200, 300]) ) jt2 = JaggedTensorWithTimestamp( values=torch.tensor([4, 5]), lengths=torch.tensor([2]), timestamps=torch.tensor([400, 500]) ) # 创建一个字典 jt_dict = { "tensor1": jt1, "tensor2": jt2 } # 使用from_jt_dict方法构造新的KeyedJaggedTensorWithTimestamp实例 new_kjt = KeyedJaggedTensorWithTimestamp.from_jt_dict(jt_dict) # 访问属性 print("Keys:", new_kjt.keys) # 输出: ['tensor1', 'tensor2'] print("Values:", new_kjt.values) # 输出: tensor([1, 2, 3, 4, 5]) print("Timestamps:", new_kjt.timestamps) # 输出: tensor([100, 200, 300, 400, 500]) print("Lengths:", new_kjt.lengths) # 输出: tensor([3, 2]) # 直接初始化KeyedJaggedTensorWithTimestamp实例 kjt = KeyedJaggedTensorWithTimestamp( keys=["tensorA", "tensorB"], values=torch.tensor([10, 20, 30, 40]), lengths=torch.tensor([2, 2]), timestamps=torch.tensor([1000, 2000, 3000, 4000]) ) # 访问属性 print("Keys:", kjt.keys) # 输出: ['tensorA', 'tensorB'] print("Values:", kjt.values) # 输出: tensor([10, 20, 30, 40]) print("Timestamps:", kjt.timestamps) # 输出: tensor([1000, 2000, 3000, 4000]) print("Lengths:", kjt.lengths) # 输出: tensor([2, 2]) |