def forward
函数功能
生成视频帧。
函数原型
def forward(self, x: torch.Tensor, timestep: torch.Tensor, y: torch.Tensor, mask: torch.Tensor = None, x_mask: torch.Tensor = None, fps: torch.Tensor = None, height: torch.Tensor = None, width: torch.Tensor = None, t_idx: int = 0, **kwargs) -> torch.Tensor:
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
---|---|---|---|
x |
输入 |
torch.Tensor |
噪声张量,当前仅支持5维输入。 |
timestep |
输入 |
torch.Tensor |
时间步张量,当前仅支持1维输入。 |
y |
输入 |
torch.Tensor |
文本编码后的prompts输入,当前仅支持4维输入。 |
mask |
输入 |
torch.Tensor |
text文本掩码,当前仅支持2维输入。 可选输入,默认为None。 |
x_mask |
输入 |
torch.Tensor |
空间维度掩码,当前仅支持2维输入。 可选输入,默认为None。 |
fps |
输入 |
torch.Tensor |
每秒生成视频的帧数,要求1维输入(B)。当输入为None时,代表每秒生成视频的帧数为8。 可选输入,默认为None。 |
height |
输入 |
torch.Tensor |
生成视频的高,要求1维输入(B)。当输入为None时,代表生成视频的高为720。 可选输入,默认为None。 |
width |
输入 |
torch.Tensor |
生成视频的宽,要求1维输入(B)。当输入为None时,代表生成视频的宽为1280。 可选输入,默认为None。 |
t_idx |
输入 |
int |
当前迭代步索引,当配置use_cache为True时必须输入,用户需要保证输入的值小于最大迭代步数。 可选输入,默认值为0。 |
kwargs |
输入 |
- |
其他参数。 |
返回值说明
返回生成的视频帧。