昇腾社区首页
中文
注册

def __init__

函数功能

类初始化函数。

函数原型

def __init__(self, input_size: Tuple[int, int, int] = (None, None, None), in_channels: int = 4, caption_channels: int = 4096, enable_flash_attn: bool = True, enable_sequence_parallelism: bool = False, use_cache: bool = True, cache_interval: int = 2, cache_start: int = 3, num_cache_layer: int = 13, cache_start_steps: int = 5)

参数说明

参数名

输入/输出

类型

说明

input_size

输入

Tuple[int, int, int]

STDiT的输入latent_size大小,表示输入数据的尺寸,三元组(T, H, W)分别表示时间维度、高度和宽度。

in_channels

输入

int

每个像素的RGBA通道数,默认值为4。

可选范围为:[1, 4]。

caption_channels

输入

int

TextEncoder模型文本编码维度,默认值为4096。

可选范围为:[1, 4096]。

enable_flash_attn

输入

bool

控制是否使用Flash Attention技术来加速注意力计算,默认值为True。

enable_sequence_parallelism

输入

bool

控制是否在模型中使用序列并行化技术来加速计算,设置为True时需要使用多卡并行推理。默认值为False。

use_cache

输入

bool

是否启用缓存机制,默认值为True。

cache_interval

输入

int

缓存数据的间隔步数。不建议用户修改默认值,如需修改,需要保证不要超过迭代的最大步数。默认值为2。

cache_start

输入

int

开始缓存的Block层数。不建议用户修改默认值,如需修改,需要保证不要超过模型最大Block层数。默认值为3。

num_cache_layer

输入

int

缓存的Block层数。不建议用户修改默认值,如需修改,需要保证不要超过模型最大Block层数。默认值为13。

cache_start_steps

输入

int

开始缓存的迭代步数。不建议用户修改默认值,如需修改,需要保证不要超过迭代的最大步数。默认值为5。

返回值说明