昇腾社区首页
中文
注册

cache

cache算法可以基于相邻迭代采样步骤间、或相邻block间的激活相似性,复用模型局部特征,减少冗余计算,从而显著加速推理。

  1. 调用class CacheAgentclass CacheConfig接口。
    from mindiesd import CacheConfig, CacheAgent
  2. 初始化CacheConfig。
    • 若使用dit_block_cache
      config = CacheConfig(
                  method="dit_block_cache",
                  blocks_count=len(transformer.single_blocks), # 使能cache的block的个数
                  steps_count=args.infer_steps,                # 模型推理的总迭代步数
                  step_start=args.cache_start_steps,           # 开始进行cache的步数索引
                  step_interval=args.cache_interval,           # 强制重新计算的间隔步数
                  step_end=args.infer_steps-1,                 # 停止cache的步数索引
                  block_start=args.single_block_start,         # 每一步中,开始进行cache的block索引
                  block_end=args.single_block_end              # 每一步中,停止cache的block索引
              )
    • 若使用attention_cache、block_start和block_end,则可采用默认值
      config = CacheConfig(
                  method="attention_cache",
                  blocks_count=len(transformer.single_blocks), # 使能cache的block的个数
                  steps_count=args.infer_steps,                # 模型推理的总迭代步数
                  step_start=args.start_step,                  # 开始进行cache的步数索引
                  step_interval=args.attentioncache_interval,  # 强制重新计算的间隔步数
                  step_end=args.end_step                       # 停止cache的步数索引
              )
  3. 初始化CacheAgent并赋值给block。
    cache_agent = CacheAgent(config)
    • 若对dit block粒度进行cache
      hunyuan_video_sampler.pipeline.transformer.cache = cache_agent
    • 若对block里的attention部分进行cache
      for block in transformer.single_blocks:
          block.cache = cache_agent
  4. 使能cache进行推理。
    • 若使用dit_block_cache:
      x = self.cache.apply(block,
                           hidden_states=x,
                           vec=vec,
                           txt_len=txt_seq_len,
                           cu_seqlens_q=cu_seqlens_q,
                           cu_seqlens_kv=cu_seqlens_kv,
                           max_seqlen_q=max_seqlen_q,
                           max_seqlen_kv=max_seqlen_kv,
                           freqs_cis=freqs_cis)
    • 若使用attention_cache:
      attn = self.cache.apply(self.double_forward,
                              img=img, txt=txt,
                              img_mod1_shift=img_mod1_shift,
                              img_mod1_scale=img_mod1_scale,
                              txt_mod1_shift=txt_mod1_shift,
                              txt_mod1_scale=txt_mod1_scale,
                              freqs_cis=freqs_cis,
                              cu_seqlens_q=cu_seqlens_q,
                              cu_seqlens_kv=cu_seqlens_kv,
                              max_seqlen_q=max_seqlen_q,
                              max_seqlen_kv=max_seqlen_kv)