开发者
资源

GRUQAT

产品支持情况

产品

是否支持

Atlas 350 加速卡

x

Atlas A3 训练系列产品/Atlas A3 推理系列产品

x

Atlas A2 训练系列产品/Atlas A2 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

Atlas 推理系列产品

x

Atlas 训练系列产品

x

功能说明

构造GRU的QAT算子。

函数原型

  • 直接构造接口:
    1
    qat = amct_pytorch.nn.module.quantization.gru.GRUQAT(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None, config=None)
    
  • 基于原生算子构造接口:
    1
    qat = amct_pytorch.nn.module.quantization.gru.GRUQAT.from_float(mod, config)
    

参数说明

表1 直接构造接口参数说明

参数名

输入/输出

说明

input_size

输入

含义:输入x中预期特征的数量。

数据类型:int

hidden_size

输入

含义:隐藏状态的特征数h。

数据类型:int

num_layers

输入

含义:循环层数,有几层LSTM算子,单算子中限制为1。

数据类型:int

默认值:1

bias

输入

含义:是否开启偏置项参与学习。

数据类型:bool

默认值:True

batch_first

输入

含义:是否将batch_size放到第一维,如果为True,则输入和输出张量将为(batch,seq,feature),否则为(seq, batch, feature)。注意只对x输入生效,不对h和c生效。(gru中为h)。

数据类型:bool

默认值:False

dropout

输入

含义:是否在除最后一层之外的每个LSTM层的输出上引入一个Dropout层。如果非零,则引入,单算子中限制为0。

数据类型:int或者float

默认值:0

bidirectional

输入

含义:是否为双向LSTM算子。如果为True,则成为双向LSTM,单算子中限制为False。

数据类型:bool

默认值:False

device

输入

含义:运行设备。

数据类型:string

默认值:None

dtype

输入

含义:torch数值类型。

数据类型:torch数据类型, 仅支持torch.float32

config

输入

含义:量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

DEFAULT_QAT_CONF = {
    "retrain_enable": True,
    "retrain_data_config": {
        "dst_type": "INT8", # INT8
        "batch_num": 1, # 大于0
        "fixed_min": True,
        "clip_max": 1.0,
        "clip_min": -1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8", # INT8
        "weights_retrain_algo": "arq_retrain", # arq_retrain/ulq_retrain
        "channel_wise": True
    }
}

数据类型:dict

默认值:None

表2 基于原生算子构造接口

参数名

输入/输出

说明

mod

输入

含义:待量化的原生GRU算子。

数据类型:torch.nn.Module

config

输入

含义:量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

config = {
    "retrain_enable":true,
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

数据类型:dict

默认值:None

返回值说明

  • 直接构造:返回构造的QAT单算子实例。
  • 基于原生算子构造:torch.nn.Module转化后的QAT单算子。

调用示例

  • 直接构造:
    1
    2
    3
    4
    5
    from amct_pytorch.nn.module.quantization.gru import GRUQAT
    
    GRUQAT(input_size=3, hidden_size=3, num_layers=1, bias=True,
              batch_first=True, dropout=0, bidirectional=False,
              device=None, dtype=None, config=None)
    
  • 基于原生算子构造:
    1
    2
    3
    4
    5
    6
    7
    8
    import torch
    
    from amct_pytorch.nn.module.quantization.gru import GRUQAT
    
    gru_op = torch.nn.lstm(input_size=1, hidden_size=1, num_layers=1, bias=True,
                           batch_first=True, dropout=0, bidirectional=False,
                           device=None, dtype=None, config=None)
    GRUQAT.from_float(mod=lstm_op, config=None)