昇腾社区首页
中文
注册

def forward

函数功能

生成视频帧。

函数原型

def forward(self, x: torch.Tensor, timestep: torch.Tensor, y: torch.Tensor, mask: torch.Tensor = None, x_mask: torch.Tensor = None, fps: torch.Tensor = None, height: torch.Tensor = None, width: torch.Tensor = None, t_idx: int = 0, **kwargs) -> torch.Tensor:

参数说明

参数名

输入/输出

类型

说明

x

输入

torch.Tensor

噪声张量,当前仅支持5维输入。

timestep

输入

torch.Tensor

时间步张量,当前仅支持1维输入。

y

输入

torch.Tensor

文本编码后的prompts输入,当前仅支持4维输入。

mask

输入

torch.Tensor

text文本掩码,当前仅支持2维输入。

可选输入,默认为None。

x_mask

输入

torch.Tensor

空间维度掩码,当前仅支持2维输入。

可选输入,默认为None。

fps

输入

torch.Tensor

每秒生成视频的帧数,要求1维输入(B)。当输入为None时,代表每秒生成视频的帧数为8。

可选输入,默认为None。

height

输入

torch.Tensor

生成视频的高,要求1维输入(B)。当输入为None时,代表生成视频的高为720。

可选输入,默认为None。

width

输入

torch.Tensor

生成视频的宽,要求1维输入(B)。当输入为None时,代表生成视频的宽为1280。

可选输入,默认为None。

t_idx

输入

int

当前迭代步索引,当配置use_cache为True时必须输入,用户需要保证输入的值小于最大迭代步数。

可选输入,默认值为0。

kwargs

输入

-

其他参数。

返回值说明

返回生成的视频帧。