昇腾社区首页
中文
注册

Linear

  • 原始代码:
    qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    qkv = qkv(x)
  • 调用class Linear优化后的代码:
    from mindiesd import Linear
    qkv = Linear(dim, dim * 3, bias=qkv_bias, op_type="matmulv3")  # 初始化时新增op_type参数,值可选:{"matmulv3", "batchmatmulv3", "matmulv2", "batchmatmulv2"}。默认使用"matmulv2"
    qkv = qkv(x)