昇腾社区首页
中文
注册

activation

  • 原始代码:
    def get_activation_layer(act_type):
        """Return a factory that builds the activation module named *act_type*.

        Args:
            act_type: One of ``"gelu"``, ``"gelu_tanh"``, ``"relu"`` or ``"silu"``.

        Returns:
            A callable that constructs a new ``nn`` activation module. Every
            branch returns either the module class itself or a
            ``functools.partial`` of it. The result can therefore be called
            with no arguments, also accepts extra constructor keyword
            arguments, and can be pickled (a ``lambda`` could not be).

        Raises:
            ValueError: If *act_type* is not a supported name.
        """
        from functools import partial

        if act_type == "gelu":
            # nn.GELU() with no arguments is what the previous lambda built.
            return nn.GELU
        elif act_type == "gelu_tanh":
            # Pre-bind the tanh approximation while keeping a picklable factory.
            return partial(nn.GELU, approximate="tanh")
        elif act_type == "relu":
            return nn.ReLU
        elif act_type == "silu":
            return nn.SiLU
        else:
            raise ValueError(f"Unknown activation type: {act_type}")
    # act_func is the SiLU factory; call act_func() to build a module instance.
    act_func = get_activation_layer("silu")
  • 调用 mindiesd 提供的 get_activation_layer 接口优化后的代码:
    # Use the get_activation_layer implementation shipped in mindiesd instead
    # of defining it locally.
    from mindiesd import get_activation_layer
    # NOTE(review): assumed to return an activation factory with the same call
    # convention as a local implementation (act_func() builds the module) —
    # confirm against the mindiesd documentation.
    act_func = get_activation_layer("silu")