(beta)torch_npu.npu_gru
须知:[object Object] 该接口计划废弃,可以使用
torch.gru接口进行替换。
产品支持情况
| 产品 | 是否支持 |
|---|---|
| [object Object]Atlas A3 训练系列产品[object Object] | √ |
| [object Object]Atlas A2 训练系列产品[object Object] | √ |
| [object Object]Atlas 训练系列产品[object Object] | √ |
| [object Object]Atlas 推理系列产品[object Object] | √ |
功能说明
计算DynamicGRUV2。
函数原型
[object Object]
参数说明
- input(
Tensor):数据类型支持float16;格式支持FRACTAL_NZ。 - hx(
Tensor):数据类型支持float16,float32;格式支持FRACTAL_NZ。 - weight_input(
Tensor):数据类型支持float16;格式支持FRACTAL_Z。 - weight_hidden(
Tensor):数据类型支持float16;格式支持FRACTAL_Z。 - bias_input(
Tensor):数据类型支持float16,float32;格式支持ND。 - bias_hidden(
Tensor):数据类型支持float16,float32;格式支持ND。 - seq_length(
Tensor):数据类型支持int32;格式支持ND。 - has_biases(
bool):默认值为True。 - num_layers(
int):层数。 - dropout(
float):丢弃概率。 - train(
bool):训练是否在op进行,默认值为True。 - bidirectional(
bool):默认值为True。 - batch_first(
bool):默认值为True。
返回值说明
- y(
Tensor):数据类型支持float16,float32;格式支持FRACTAL_NZ。 - output_h(
Tensor):数据类型支持float16,float32;格式支持FRACTAL_NZ。 - update(
Tensor):数据类型支持float16,float32;格式支持FRACTAL_NZ。 - reset(
Tensor):数据类型支持float16,float32;格式支持FRACTAL_NZ。 - new(
Tensor):数据类型支持float16,float32;格式支持FRACTAL_NZ。 - hidden_new(
Tensor):数据类型支持float16,float32;格式支持FRACTAL_NZ。
约束说明
接口暂不支持jit_compile=False,需要在该模式下使用时请将"DynamicGRUV2"添加至"NPU_FUZZY_COMPILE_BLACKLIST"选项内,具体操作可参考。