torch

API名称

是否支持

限制与说明

torch.is_tensor

  

torch.is_storage

  

torch.is_complex

支持判断,但当前硬件限制不支持复数。

torch.is_conj

  

torch.is_floating_point

  

torch.is_nonzero

  

torch.set_default_dtype

  

torch.get_default_dtype

  

torch.set_default_tensor_type

  

torch.numel

  

torch.set_printoptions

  

torch.set_flush_denormal

  

torch.tensor

  

torch.sparse_coo_tensor

  

torch.asarray

  

torch.as_tensor

  

torch.as_strided

  

torch.from_numpy

  

torch.frombuffer

  

torch.zeros

  

torch.zeros_like

  

torch.ones

  

torch.ones_like

  

torch.arange

  

torch.range

  

torch.linspace

  

torch.logspace

  

torch.eye

  

torch.empty

  

torch.empty_like

  

torch.empty_strided

  

torch.full

  

torch.full_like

  

torch.quantize_per_tensor

  

torch.quantize_per_channel

  

torch.dequantize(tensor) ->Tensor

  

torch.dequantize(tensors) ->sequence of Tensors

  

torch.complex

  

torch.polar

  

torch.heaviside

  

torch.adjoint

  

torch.argwhere

  

torch.cat

  

torch.concat

不支持float64,不支持8D输入。

torch.conj

  

torch.chunk

  

torch.dsplit

只支持float16,float32,float64。

torch.column_stack

  

torch.dstack

  

torch.gather

  

torch.hsplit

  

torch.hstack

  

torch.index_add

  

torch.index_select

  

torch.masked_select

  

torch.movedim

  

torch.moveaxis

  

torch.narrow

  

torch.nonzero

  

torch.permute

  

torch.reshape

  

torch.row_stack

  

torch.select

  

torch.scatter

  

torch.diagonal_scatter

  

torch.select_scatter

  

torch.slice_scatter

  

torch.scatter_add

  

torch.scatter_reduce

  

torch.split

  

torch.squeeze

  

torch.stack

  

torch.swapaxes

  

torch.swapdims

  

torch.t

  

torch.take

  

torch.take_along_dim

不支持float64与int。

torch.tensor_split

  

torch.tile

  

torch.transpose

  

torch.unbind

  

torch.unsqueeze

  

torch.vsplit

只支持float16,float32,float64。

torch.vstack

  

torch.where(condition, x, y) ->Tensor

  

torch.where(condition) ->tuple of LongTensor

  

torch.Generator

  

torch.Generator.get_state

  

torch.Generator.initial_seed

  

torch.Generator.manual_seed

  

torch.Generator.seed

  

torch.Generator.set_state

  

torch.seed

  

torch.manual_seed

  

torch.initial_seed

  

torch.get_rng_state

  

torch.set_rng_state

  

torch.bernoulli

  

torch.multinomial

  

torch.normal(mean, std, *, generator=None, out=None) ->Tensor

  

torch.normal(mean=0.0, std, *, out=None) ->Tensor

  

torch.normal(mean, std=1.0, *, out=None) ->Tensor

  

torch.normal(mean, std, size, *, out=None) ->Tensor

  

torch.poisson

可以在CPU实现。

torch.rand

  

torch.rand_like

不支持int64。

torch.randint

  

torch.randint_like

  

torch.randn

  

torch.randn_like

  

torch.randperm

  

torch.quasirandom.SobolEngine

  

torch.quasirandom.SobolEngine.draw

  

torch.quasirandom.SobolEngine.draw_base2

  

torch.quasirandom.SobolEngine.fast_forward

  

torch.quasirandom.SobolEngine.reset

  

torch.save

  

torch.load

  

torch.get_num_threads

  

torch.set_num_threads

  

torch.get_num_interop_threads

  

torch.set_num_interop_threads

  

torch.no_grad

  

torch.enable_grad

  

torch.set_grad_enabled

  

torch.is_grad_enabled

  

torch.inference_mode

  

torch.is_inference_mode_enabled

  

torch.abs

  

torch.absolute

  

torch.acos

不支持int32。

torch.arccos

不支持int64。

torch.acosh

不支持int64。

torch.arccosh

不支持int64。

torch.add

  

torch.addcdiv

  

torch.addcmul

  

torch.angle

  

torch.asin

  

torch.arcsin

不支持int64。

torch.asinh

  

torch.arcsinh

不支持int64。

torch.atan

  

torch.arctan

不支持int64。

torch.atanh

不支持int64。

torch.arctanh

  

torch.atan2

  

torch.arctan2

  

torch.bitwise_not

  

torch.bitwise_and

  

torch.bitwise_or

  

torch.bitwise_xor

  

torch.bitwise_left_shift

  

torch.bitwise_right_shift

  

torch.ceil

  

torch.clamp

  

torch.clip

  

torch.conj_physical

  

torch.copysign

  

torch.cos

  

torch.cosh

  

torch.deg2rad

  

torch.div

  

torch.divide

  

torch.digamma

可以在CPU实现。

torch.erf

  

torch.erfc

  

torch.erfinv

  

torch.exp

  

torch.exp2

  

torch.expm1

  

torch.fake_quantize_per_channel_affine

  

torch.fake_quantize_per_tensor_affine

  

torch.fix

  

torch.float_power

不支持float64。

torch.floor

  

torch.floor_divide

  

torch.fmod

  

torch.frac

  

torch.frexp

  

torch.gradient

不支持8D输入。

torch.imag

  

torch.ldexp

在int64带out场景下,out场景也必须是int64类型。

torch.lerp

  

torch.lgamma

可以走CPU实现。

torch.log

  

torch.log10

  

torch.log1p

  

torch.log2

  

torch.logaddexp

  

torch.logaddexp2

  

torch.logical_and

  

torch.logical_not

  

torch.logical_or

  

torch.logical_xor

  

torch.logit

可以走CPU实现。

torch.hypot

  

torch.i0

  

torch.igamma

  

torch.igammac

  

torch.mul

  

torch.multiply

  

torch.mvlgamma

可以走CPU实现。

torch.nan_to_num

  

torch.neg

  

torch.negative

  

torch.nextafter

  

torch.polygamma

  

torch.positive

  

torch.pow(input, exponent, *, out=None) ->Tensor

  

torch.pow(self, exponent, *, out=None) ->Tensor

  

torch.quantized_batch_norm

  

torch.quantized_max_pool1d

  

torch.quantized_max_pool2d

  

torch.rad2deg

不支持int64。

torch.real

  

torch.reciprocal

  

torch.remainder

  

torch.round

  

torch.rsqrt

  

torch.sigmoid

  

torch.sign

  

torch.sgn

  

torch.signbit

  

torch.sin

  

torch.sinc

  

torch.sinh

  

torch.sqrt

  

torch.square

  

torch.sub

  

torch.subtract

  

torch.tan

  

torch.tanh

  

torch.true_divide

  

torch.trunc

  

torch.xlogy

  

torch.argmax(input) ->LongTensor

  

torch.argmax(input, dim, keepdim=False) ->LongTensor

  

torch.argmin

  

torch.amax

  

torch.amin

  

torch.aminmax

  

torch.all(input) ->Tensor

  

torch.all(input, dim, keepdim=False, *, out=None) ->Tensor

  

torch.any(input) ->Tensor

  

torch.any(input, dim, keepdim=False, *, out=None) ->Tensor

  

torch.max(input) ->Tensor

  

torch.max(input, dim, keepdim=False, *, out=None)

  

torch.max(input, other, *, out=None) ->Tensor

  

torch.min(input) ->Tensor

  

torch.min(input, dim, keepdim=False, *, out=None)

  

torch.min(input, other, *, out=None) ->Tensor

  

torch.dist

  

torch.logsumexp

  

torch.mean(input, *, dtype=None) ->Tensor

  

torch.mean(input, dim, keepdim=False, *, dtype=None, out=None) ->Tensor

  

torch.nanmean

  

torch.median(input) ->Tensor

  

torch.median(input, dim=- 1, keepdim=False, *, out=None)

  

torch.nanmedian(input) ->Tensor

可以走CPU实现。

torch.nanmedian(input, dim=- 1, keepdim=False, *, out=None)

可以走CPU实现。

torch.mode

可以走CPU实现。

torch.norm

  

torch.nansum(input, *, dtype=None) ->Tensor

可以走CPU实现。

torch.nansum(input, dim, keepdim=False, *, dtype=None) ->Tensor

可以走CPU实现。

torch.prod(input, *, dtype=None) ->Tensor

  

torch.prod(input, dim, keepdim=False, *, dtype=None) ->Tensor

  

torch.quantile

  

torch.nanquantile

  

torch.std(input, dim, unbiased, keepdim=False, *, out=None) ->Tensor

如果输入tensor元素值相同,会产生精度误差。

torch.std(input, unbiased) ->Tensor

如果输入tensor元素值相同,会产生精度误差。

torch.std_mean(input, dim, unbiased, keepdim=False, *, out=None)

只支持float16,float32。

torch.std_mean(input, unbiased)

只支持float16,float32。

torch.sum(input, *, dtype=None) ->Tensor

  

torch.sum(input, dim, keepdim=False, *, dtype=None) ->Tensor

  

torch.unique

  

torch.unique_consecutive

传参时必须使用关键字,否则精度不达标。return_inverse=return_inverse,return_counts=return_counts,dim=dim。

torch.var(input, dim, unbiased, keepdim=False, *, out=None) ->Tensor

  

torch.var(input, unbiased) ->Tensor

  

torch.var_mean(input, dim, unbiased, keepdim=False, *, out=None)

  

torch.var_mean(input, unbiased)

  

torch.count_nonzero

  

torch.allclose

  

torch.argsort

  

torch.eq

  

torch.equal

  

torch.ge

  

torch.greater_equal

  

torch.gt

  

torch.greater

  

torch.isclose

  

torch.isfinite

  

torch.isin

  

torch.isinf

  

torch.isposinf

  

torch.isneginf

  

torch.isnan

  

torch.isreal

  

torch.kthvalue

  

torch.le

  

torch.less_equal

  

torch.lt

  

torch.less

  

torch.maximum

  

torch.minimum

  

torch.fmax

  

torch.fmin

  

torch.ne

  

torch.not_equal

  

torch.sort

  

torch.topk

  

torch.msort

  

torch.stft

  

torch.istft

  

torch.bartlett_window

  

torch.blackman_window

  

torch.hamming_window

  

torch.hann_window

  

torch.kaiser_window

  

torch.atleast_1d

  

torch.atleast_2d

  

torch.atleast_3d

  

torch.bincount

  

torch.block_diag

  

torch.broadcast_tensors

  

torch.broadcast_to

  

torch.broadcast_shapes

  

torch.bucketize

可以走CPU实现。

torch.cartesian_prod

  

torch.cdist

  

torch.clone

  

torch.combinations(input, r=2, with_replacement=False) ->seq

r不能大于8。

torch.corrcoef

不支持float64。

torch.cov

不支持bool。

torch.cross

  

torch.cummax

  

torch.cummin

  

torch.cumprod

  

torch.cumsum

  

torch.diag

仅支持diagonal=0场景。

torch.diag_embed

不支持复数。

torch.diagflat

  

torch.diagonal

  

torch.diff

  

torch.einsum

  

torch.flatten

  

torch.flip

  

torch.fliplr

  

torch.flipud

  

torch.kron

不支持5维度及以上输入。

torch.rot90

  

torch.gcd

  

torch.histc

  

torch.histogram

  

torch.histogramdd

  

torch.meshgrid

  

torch.lcm

  

torch.logcumsumexp

  

torch.ravel

  

torch.renorm

  

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) ->Tensor

  

torch.repeat_interleave(repeats, *, output_size=None) ->Tensor

  

torch.roll

  

torch.searchsorted

  

torch.tensordot

  

torch.trace

  

torch.tril

  

torch.tril_indices

  

torch.triu

  

torch.triu_indices

  

torch.vander

  

torch.view_as_real

  

torch.view_as_complex

  

torch.resolve_conj

  

torch.resolve_neg

  

torch.addbmm

  

torch.addmm

  

torch.addmv

  

torch.addr

  

torch.baddbmm

  

torch.bmm

  

torch.chain_matmul

  

torch.cholesky

  

torch.cholesky_inverse

  

torch.cholesky_solve

  

torch.dot

  

torch.eig

  

torch.geqrf

  

torch.ger

  

torch.inner

不支持int64,out场景没适配。

torch.inverse

  

torch.det

  

torch.logdet

  

torch.slogdet

  

torch.lstsq

  

torch.lu

  

torch.lu_solve

  

torch.lu_unpack

  

torch.matmul

  

torch.matrix_power

  

torch.matrix_rank

不支持symmetric=True,out场景没适配。

torch.matrix_exp

  

torch.mm

  

torch.mv

  

torch.orgqr

  

torch.ormqr

  

torch.outer

  

torch.pinverse

  

torch.qr

  

torch.solve

  

torch.svd

  

torch.svd_lowrank

可以走CPU实现。

torch.pca_lowrank

可以走CPU实现。

torch.symeig

  

torch.lobpcg

  

torch.trapz

  

torch.trapezoid

  

torch.cumulative_trapezoid

只支持float16,float32。

torch.triangular_solve

  

torch.vdot

可以走CPU实现。

torch.compiled_with_cxx11_abi

  

torch.result_type

  

torch.can_cast

  

torch.promote_types

  

torch.use_deterministic_algorithms

  

torch.are_deterministic_algorithms_enabled

  

torch.is_deterministic_algorithms_warn_only_enabled

  

torch.set_deterministic_debug_mode

  

torch.get_deterministic_debug_mode

  

torch.set_warn_always

  

torch.is_warn_always_enabled

  

torch._assert