
activation

  • Original code:
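    import torch.nn as nn

    def get_activation_layer(act_type):
        # Map an activation name to a zero-argument callable that builds the layer.
        if act_type == "gelu":
            return lambda: nn.GELU()  # exact (erf-based) GELU
        elif act_type == "gelu_tanh":
            return lambda: nn.GELU(approximate="tanh")  # tanh approximation of GELU
        elif act_type == "relu":
            return nn.ReLU  # the class itself; calling it builds the layer
        elif act_type == "silu":
            return nn.SiLU
        else:
            raise ValueError(f"Unknown activation type: {act_type}")

    act_func = get_activation_layer("silu")  # act_func() returns an nn.SiLU module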
    
  • Optimized code, calling get_activation_layer from mindiesd:
    from mindiesd import get_activation_layer
    act_func = get_activation_layer("silu")
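The `from mindiesd import get_activation_layer` line fails on machines where MindIE SD is not installed. A minimal sketch of a guarded import follows, assuming that `mindiesd.get_activation_layer` accepts the same act_type strings and returns a layer factory used the same way as the original function (the document only shows the "silu" call). The fallback is the original dispatch; the `_SUPPORTED` tuple is a hypothetical helper, and `torch.nn` is imported inside the function so that an unsupported name is rejected without needing torch.

```python
# Prefer MindIE SD's implementation when it is installed; otherwise fall back
# to a local copy of the original dispatch. Assumption: both versions take the
# same act_type strings and return a zero-argument layer factory.
try:
    from mindiesd import get_activation_layer
except ImportError:
    _SUPPORTED = ("gelu", "gelu_tanh", "relu", "silu")  # hypothetical helper

    def get_activation_layer(act_type):
        # Validate first so unsupported names fail with the original error.
        if act_type not in _SUPPORTED:
            raise ValueError(f"Unknown activation type: {act_type}")
        import torch.nn as nn  # deferred import: only needed to build a layer
        if act_type == "gelu":
            return lambda: nn.GELU()
        if act_type == "gelu_tanh":
            return lambda: nn.GELU(approximate="tanh")
        if act_type == "relu":
            return nn.ReLU
        return nn.SiLU
```

Calling code stays unchanged: `act_func = get_activation_layer("silu")` works with either branch, under the assumption above.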
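For reference, here is a plain-Python sketch of what each act_type name accepted by get_activation_layer computes for a scalar input. It uses the standard formulas behind `nn.GELU()`, `nn.GELU(approximate="tanh")`, `nn.ReLU` and `nn.SiLU`. The helper names (`gelu`, `gelu_tanh`, `relu`, `silu`, `reference_activation`) are hypothetical and belong to neither MindIE SD nor PyTorch. This is a check on the math, not a replacement for the layers.

```python
import math

# Hypothetical scalar reference formulas for the four supported act_type names.

def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Tanh approximation of GELU (what approximate="tanh" selects).
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

def relu(x: float) -> float:
    return max(0.0, x)

def silu(x: float) -> float:
    # SiLU: x * sigmoid(x), with sigmoid(x) written as 0.5 * (1 + tanh(x / 2))
    # so that large negative inputs do not overflow math.exp.
    return x * 0.5 * (1.0 + math.tanh(x / 2.0))

_REFERENCE = {"gelu": gelu, "gelu_tanh": gelu_tanh, "relu": relu, "silu": silu}

def reference_activation(act_type: str):
    # Same names and same error message as get_activation_layer,
    # but returns a plain function instead of a layer factory.
    try:
        return _REFERENCE[act_type]
    except KeyError:
        raise ValueError(f"Unknown activation type: {act_type}") from None

for name in ("gelu", "gelu_tanh", "relu", "silu"):
    print(name, round(reference_activation(name)(1.0), 4))
```

Comparing "gelu" and "gelu_tanh" at the same input shows how small the approximation error is for moderate values.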