昇腾社区首页
中文
注册

ParameterConstraints(TorchRec

此接口为TorchRec开源接口,非Rec SDK Torch对外接口。此章节介绍使用Rec SDK Torch时调用的TorchRec接口支持的参数范围。

功能描述

指定分表计划的查询范围。

函数原型

1
2
class ParameterConstraints:
    def __init__(**kwargs):

参数说明

参数名

类型

可选/必选

说明

sharding_type

List[str]

必选

分表的类型。

取值范围:

"row_wise":按照行号进行分表。

compute_kernels

List[str]

必选

计算的kernel类型。

取值范围:

"fused":采用合表的方式查询。

min_partition

List[int]

可选

仅支持默认值为None,不支持用户自定义。

pooling_factors

List[float]

可选

仅支持默认值为None,不支持用户自定义。

num_poolings

List[float]

可选

仅支持默认值为None,不支持用户自定义。

batch_sizes

List[int]

可选

仅支持默认值为None,不支持用户自定义。

is_weighted

bool

可选

仅支持默认值为False,不支持用户自定义。

cache_params

torchrec.distributed.types.CacheParams

可选

仅支持默认值为None,不支持用户自定义。

enforce_hbm

bool

可选

仅支持默认值为None,不支持用户自定义。

stochastic_rounding

bool

可选

仅支持默认值为None,不支持用户自定义。

bounds_check_mode

enum.IntEnum

可选

仅支持默认值为None,不支持用户自定义。

feature_names

List[str]

可选

仅支持默认值为None,不支持用户自定义。

output_dtype

Enum

可选

仅支持默认值为None,不支持用户自定义。

device_group

str

可选

仅支持默认值为None,不支持用户自定义。

key_value_params

torchrec.distributed.types.KeyValueParams

可选

仅支持默认值为None,不支持用户自定义。

使用示例

1
2
3
4
from torchrec.distributed.planner import ParameterConstraints
constraints = {
   "table0": ParameterConstraints(sharding_types=["row_wise"], compute_kernels=["fused"])
}