(beta)torch_npu.npu_dropout_with_add_softmax
接口原型
torch_npu.npu_dropout_with_add_softmax(Tensor self, Tensor x1, Scalar alpha, float prob, int dim) -> (Tensor, Tensor, Tensor)
功能描述
实现axpy_v2、softmax_v2、drop_out_domask_v3功能,即依次执行以下三步计算:
y = x1 + self × alpha
Softmax(y_i) = exp(y_i) / Σ_j exp(y_j)(在dim指定的维度上计算)
output = 根据mask舍弃softmax结果中的元素,留下来的元素乘以(1/prob)
参数说明
- Tensor self:4维张量,shape为(N, C, H, W)。
- Tensor x1:4维张量,shape为(N, C, H, W)。
- Scalar alpha:self在相加前的乘法系数,即计算 x1 + self × alpha。
- float prob:dropout的概率参数(留下来的元素乘以1/prob,具体语义请以算子实现为准)。
- int dim:执行softmax计算的维度,当前仅支持-1(最后一维)。
约束说明
- self和x1的shape相同;
- H和W的取值均须为[128, 256, 384, 512]其中之一;
- (N * C)%32结果为0;
- dim为-1。
调用示例
self = torch.rand(16, 16, 128, 128).npu()
tensor([[[[7.2556e-02, 3.0909e-01, 7.9734e-01, ..., 6.1179e-01,
6.2624e-03, 8.5186e-01],
[8.9196e-02, 3.3319e-01, 4.0780e-01, ..., 1.9144e-01,
2.2701e-01, 6.4018e-01],
[4.7275e-01, 7.4895e-01, 4.6215e-01, ..., 9.3753e-01,
6.6048e-02, 8.1877e-02],
...,
[7.9366e-01, 5.1516e-01, 5.6594e-01, ..., 1.6457e-01,
1.0640e-01, 3.4322e-03],
[1.5743e-02, 1.2893e-01, 5.8990e-01, ..., 4.1721e-01,
8.7816e-02, 6.8886e-01],
[4.2980e-01, 5.5447e-01, 3.1894e-01, ..., 9.2638e-01,
9.9324e-01, 4.6225e-01]],
[[6.2426e-01, 4.5948e-01, 1.0837e-01, ..., 8.9386e-01,
3.6932e-01, 1.2406e-01],
[9.1823e-01, 6.2311e-01, 5.1474e-01, ..., 2.1042e-01,
6.5943e-01, 3.1797e-01],
[5.2891e-01, 2.0183e-01, 2.1452e-01, ..., 9.1638e-01,
6.4109e-01, 9.4484e-01],
...,
[3.7783e-02, 1.3218e-01, 3.1192e-01, ..., 2.4931e-01,
4.8809e-01, 9.6085e-01],
[3.3197e-01, 9.1186e-02, 2.4839e-01, ..., 2.1156e-03,
6.4952e-01, 8.5996e-01],
[1.7941e-01, 5.1532e-01, 7.8133e-01, ..., 3.5526e-01,
5.3576e-01, 6.0538e-01]],
[[2.6743e-01, 7.4942e-01, 1.9146e-01, ..., 4.9179e-01,
6.3319e-01, 9.9269e-01],
[1.5163e-01, 3.7388e-01, 8.0604e-02, ..., 8.1193e-01,
1.7922e-01, 8.6578e-01],
[8.2558e-01, 9.5139e-01, 2.1313e-01, ..., 2.1722e-01,
2.8402e-01, 8.8888e-01],
...,
[1.8222e-01, 2.7645e-01, 6.7305e-01, ..., 6.8003e-01,
4.0917e-01, 7.6655e-01],
[3.1234e-01, 7.8519e-01, 8.8509e-01, ..., 7.2574e-01,
9.6134e-01, 2.2267e-01],
[4.9233e-01, 8.8407e-01, 7.4390e-01, ..., 5.2253e-02,
5.5150e-02, 4.4108e-02]],
...,
[[4.3370e-01, 2.1176e-01, 4.7512e-01, ..., 5.7611e-01,
3.2619e-01, 1.1523e-01],
[6.1469e-01, 7.4528e-01, 7.9559e-02, ..., 9.7112e-01,
1.8391e-01, 8.9883e-01],
[8.6677e-02, 3.5051e-02, 1.6875e-01, ..., 3.9833e-01,
6.7967e-01, 4.7062e-01],
...,
[7.1648e-01, 1.8378e-01, 5.3054e-01, ..., 8.4282e-01,
9.1972e-01, 7.0031e-01],
[5.9876e-01, 6.7868e-01, 6.4128e-01, ..., 4.9516e-02,
7.2571e-01, 5.8792e-01],
[7.6723e-01, 6.9527e-01, 9.3573e-01, ..., 6.3490e-02,
6.6129e-01, 2.4517e-01]],
[[5.0158e-01, 8.2565e-01, 7.5532e-01, ..., 6.9342e-01,
3.3244e-01, 5.3913e-01],
[2.3347e-01, 9.7822e-02, 1.5009e-01, ..., 5.5090e-01,
9.1813e-01, 7.9857e-01],
[7.2416e-02, 5.9086e-01, 1.2243e-01, ..., 7.8511e-01,
2.4803e-01, 5.3717e-01],
...,
[7.4899e-01, 1.5467e-02, 4.9711e-01, ..., 2.2938e-02,
1.6099e-01, 3.1928e-01],
[3.9111e-01, 1.2422e-01, 6.1795e-02, ..., 8.4212e-01,
6.1346e-01, 1.0957e-01],
[3.6311e-02, 8.9652e-01, 7.7428e-01, ..., 9.2212e-01,
4.9290e-01, 4.5609e-01]],
[[2.2052e-01, 4.4260e-01, 8.8627e-01, ..., 9.2381e-01,
7.7046e-01, 9.2057e-01],
[5.5775e-01, 8.8951e-01, 7.9238e-01, ..., 3.9209e-01,
9.6636e-01, 8.1876e-01],
[3.4709e-01, 7.8678e-01, 1.4396e-01, ..., 7.9073e-01,
3.9021e-01, 8.5285e-01],
...,
[1.4238e-01, 9.8432e-01, 2.7802e-01, ..., 5.1720e-01,
1.6290e-01, 8.2036e-01],
[2.0184e-01, 1.0635e-01, 1.9612e-01, ..., 9.7101e-01,
9.6679e-01, 7.0811e-01],
[5.8240e-01, 3.1642e-01, 9.6549e-01, ..., 5.1130e-02,
5.6725e-01, 3.5238e-01]]]], device='npu:0')
x1 = torch.rand(16, 16, 128, 128).npu()
tensor([[[[2.4353e-01, 8.5665e-01, 5.3571e-01, ..., 5.9101e-01,
4.0872e-01, 6.3873e-01],
[1.4489e-01, 8.7982e-01, 3.3114e-01, ..., 2.5155e-01,
8.4987e-01, 8.7096e-01],
[6.5837e-02, 2.2677e-02, 7.2063e-01, ..., 2.3542e-01,
9.3041e-01, 8.9596e-01],
...,
[5.1450e-01, 7.9412e-01, 8.9288e-01, ..., 3.3639e-01,
5.6086e-01, 4.8770e-02],
[4.7557e-01, 1.4793e-01, 4.9800e-01, ..., 3.9479e-01,
5.6052e-01, 9.8271e-01],
[7.4438e-01, 7.5646e-01, 2.7942e-02, ..., 3.0381e-01,
4.3703e-01, 1.4037e-02]],
[[4.0232e-01, 9.4407e-01, 6.4969e-01, ..., 3.4524e-01,
8.2647e-01, 5.4792e-01],
[1.1801e-01, 1.8281e-01, 6.1723e-01, ..., 1.9393e-01,
4.5877e-01, 8.9990e-01],
[2.6244e-01, 6.9614e-01, 3.6008e-01, ..., 5.0258e-01,
8.1919e-01, 4.6943e-01],
...,
[7.4710e-01, 5.8911e-01, 1.5292e-01, ..., 6.6590e-01,
4.0754e-01, 3.6944e-01],
[9.0501e-01, 2.7943e-01, 3.7068e-01, ..., 1.5053e-01,
7.3413e-01, 7.9626e-01],
[9.5200e-01, 7.8327e-01, 3.4033e-01, ..., 8.0892e-01,
4.0480e-01, 3.8717e-01]],
[[7.5938e-01, 2.9089e-01, 5.9916e-01, ..., 6.2526e-01,
3.9670e-01, 3.3548e-01],
[7.0733e-01, 8.1400e-01, 4.9259e-01, ..., 1.6607e-02,
6.5331e-01, 7.3150e-02],
[5.2770e-01, 7.8141e-01, 4.1904e-01, ..., 3.8917e-01,
4.1405e-01, 9.9596e-01],
...,
[4.8669e-01, 9.9948e-01, 1.2023e-01, ..., 7.0420e-01,
2.8522e-01, 6.6192e-01],
[4.9718e-01, 7.5792e-01, 6.6748e-01, ..., 9.7302e-01,
3.3443e-01, 3.6536e-01],
[7.7033e-01, 6.0550e-01, 8.2024e-01, ..., 2.9711e-01,
1.9410e-01, 6.6304e-01]],
...,
[[1.0284e-01, 6.5712e-01, 6.0831e-01, ..., 6.2622e-01,
2.0355e-01, 9.4250e-01],
[4.9053e-01, 2.0148e-01, 2.4974e-01, ..., 9.2521e-01,
1.9919e-01, 4.4700e-01],
[7.6515e-01, 8.7755e-01, 1.3500e-01, ..., 8.2136e-01,
2.0848e-01, 5.6432e-01],
...,
[3.3618e-01, 1.8585e-01, 5.3475e-01, ..., 4.9333e-01,
9.1018e-01, 9.5052e-01],
[2.1400e-01, 1.7407e-01, 5.8925e-01, ..., 7.5722e-01,
2.9850e-01, 3.9298e-01],
[6.3625e-01, 1.7168e-01, 2.9183e-01, ..., 9.9674e-01,
2.1718e-01, 5.2626e-01]],
[[1.8651e-01, 2.5385e-01, 2.0384e-01, ..., 3.4462e-01,
8.4150e-01, 4.7431e-01],
[2.4992e-01, 1.1788e-01, 1.9730e-01, ..., 4.3722e-02,
7.8943e-01, 9.9097e-01],
[1.4493e-02, 6.4856e-01, 8.3344e-01, ..., 8.6623e-01,
1.5456e-01, 7.8423e-01],
...,
[6.1458e-01, 4.4260e-01, 7.4133e-01, ..., 2.5126e-01,
2.7251e-01, 6.9784e-01],
[2.2419e-01, 3.4159e-01, 2.3232e-01, ..., 8.2850e-01,
8.2644e-02, 4.8390e-01],
[1.0171e-01, 8.7662e-01, 2.0457e-01, ..., 7.6868e-01,
7.6592e-01, 3.1254e-01]],
[[1.8866e-01, 1.5755e-01, 3.1025e-02, ..., 6.5044e-01,
7.8293e-01, 9.8030e-01],
[3.7703e-01, 5.3198e-01, 1.8633e-01, ..., 4.7398e-01,
8.3618e-01, 8.7283e-01],
[5.7119e-01, 4.3620e-01, 8.2536e-01, ..., 2.5390e-01,
5.6144e-01, 4.4044e-01],
...,
[1.3243e-01, 6.2002e-02, 7.5278e-01, ..., 7.5907e-01,
4.2472e-01, 1.7624e-01],
[4.7985e-01, 7.9769e-01, 8.1433e-01, ..., 7.3780e-01,
2.2877e-02, 4.8816e-01],
[4.5100e-01, 9.9698e-02, 7.0776e-01, ..., 9.8046e-01,
2.2372e-01, 8.6304e-01]]]], device='npu:0')
_, _, out = torch_npu.npu_dropout_with_add_softmax(self, x1, 2, 0.9, -1)
out
tensor([[[[0.0000, 0.0639, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0632, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0794, ..., 0.0000, 0.0000, 0.1571],
[0.0000, 0.0000, 0.0000, ..., 0.1270, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.1030, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.2134, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0342, 0.0000, 0.0633, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.1578, 0.1334, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.2316, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0237, 0.0000, ..., 0.0000, 0.2128, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.1421, 0.0000, 0.0000, ..., 0.0499, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0218, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.1461, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[0.0000, 0.1130, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.1976, ..., 0.0000, 0.0000, 0.0000]]]],
device='npu:0')
父主题: torch_npu