当前位置:AIGC资讯 > AIGC > 正文

stable-diffusion-webui中stability的sdv1.5和sdxl模型结构config对比

sdv1.5 v1-inference.yaml

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

modules/initialize.py

Thread(target=load_model).start()
load_model->shared.sd_model

modules/shared_items.py

Shared()->
sd_model()->modules.sd_models.model_data.get_sd_model()

sd_models.py

SdModelData:
get_sd_model()->load_model()

model_data = SdModelData()

sd_models.py load_model()

load_model(checkpoint_info,already_loaded_state_dict)->
state_dict = get_checkpoint_state_dict(checkpoint_info,..)
- torch.load()
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict,checkpoint_info)
# state_dict 权重已经加载上来了,类似下面这种
'model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight': tensor([0.8882, 0.9307, 0.8149, 0.8799, 0.8374, 0.8779, 0.8208, 0.7705, 0.7871,
        0.6953, 0.8354, 0.8594, 0.7881, 0.8018, 0.8442, 0.7744, 0.7969, 0.7715,

sd_models_config.py 

find_checkpoint_config(state_dict,info)
guess_model_config_from_state_dict(state_dict,info.filename)
- config_default # 根据权重的关键key从开头的config中选出来符合要求的yaml

sd_models.py load_model()

sd_config = OmegaConf.load(checkpoint_config)

Creating model from config: /root/autodl-tmp/stable-diffusion-webui/configs/v1-inference.yaml

sd_model = instantiate_from_config(sd_config.model)

简单分析下ldm下的代码:

models是串起全流程的代码,比如DDPM,modules下的是具体的模块代码

repositories/stable-diffusion-stability-ai/ldm/util.py 

get_obj_from_str(config["target"])(**config.get("params", dict()))
module:ldm.models.diffusion.ddpm,cls:LatentDiffusion
importlib.import_module(module, package=None)->
<module 'ldm.models.diffusion.ddpm' from '/root/autodl-tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py'>

sd_models.py load_model()

sd_model = instantiate_from_config(sd_config.model)
# sd_model = LatentDiffusion

repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py LatentDiffusion()

self.instantiate_first_stage()
- model = instantiate_from_config(config)
- self.first_stage_model = model.eval()

self.instantiate_cond_stage()
- model = instantiate_from_config(config)
- self.cond_stage_model = model.eval()

self.model = DiffusionWrapper(unet_config,..)
- self.diffusion_model = instantiate_from_config(diff_model_config)

sd_model.first_stage_model:

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (1): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nin_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (2): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nin_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ResnetBlock(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (3): Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
      )
    )
    (mid): Module(
      (block_1): ResnetBlock(
        (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attn_1): AttnBlock(
        (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
        (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (block_2): ResnetBlock(
        (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
    (conv_out): Conv2d(512, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (decoder): Decoder(
    (conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (mid): Module(
      (block_1): ResnetBlock(
        (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attn_1): AttnBlock(
        (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
        (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (block_2): ResnetBlock(
        (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (up): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nin_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
          )
          (1-2): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
      )
      (1): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nin_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (1-2): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (upsample): Upsample(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2-3): 2 x Module(
        (block): ModuleList(
          (0-2): 3 x ResnetBlock(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (upsample): Upsample(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
    (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (loss): Identity()
  (quant_conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
  (post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
)

sd_model.cond_stage_model:

FrozenCLIPEmbedder(
  (transformer): CLIPTextModel(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 768)
        (position_embedding): Embedding(77, 768)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=768, out_features=3072, bias=True)
              (fc2): Linear(in_features=3072, out_features=768, bias=True)
            )
            (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
  )
)

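sd_model.model -> diffusionModel(注:下面贴出的结构与上文 sd_model.cond_stage_model 的打印结果完全相同,应为粘贴错误;按前面 ddpm.py 的分析,sd_model.model 实际是 DiffusionWrapper,其 diffusion_model 为由 unet_config 实例化的 UNetModel)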

FrozenCLIPEmbedder(
  (transformer): CLIPTextModel(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 768)
        (position_embedding): Embedding(77, 768)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=768, out_features=3072, bias=True)
              (fc2): Linear(in_features=3072, out_features=768, bias=True)
            )
            (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
  )
)

sd_models.py load_model_weights

load_model_weights(sd_model,checkpoint_info,state_dict,...)->
model.is_sdxl
model.is_sd1
model.is_sd2
model.load_state_dict(state_dict,strict=False)

vae = model.first_stage_model
model.first_stage_model = None
model.half()
model.first_stage_model = vae

sd_vae.load_vae(model,vae_file,vae_source)

sd_models.py load_model

send_model_to_device(sd_model)

sd_hijack.model_hijack.hijack(sd_model)

modules/sd_hijack.py

StableDiffusionModelHijack->hijack(self,m)-> m=sd_model

type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding,self) # 49408,768
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model,self)

apply_weighted_forward(m)
self.apply_optimizations()
self.clip = m.cond_stage_model

self.layers = flatten(m)
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

modules/sd_hijack_clip.py

FrozenCLIPEmbedderWithCustomWords()->
self.tokenizer = wrapped.tokenizer
vocab = self.tokenizer.get_vocab()

sd_models.py load_model

sd_model.eval()
model_data.set_sd_model(sd_model)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
script_callbacks.model_loaded_callback(sd_model)
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

Model loaded in 3004.5s (calculate hash: 175.0s, load weights from disk: 0.2s, find config: 13.4s, create model: 0.4s, apply weights to model: 667.5s, apply half(): 298.5s, apply dtype to VAE: 15.6s, load VAE: 101.6s, load weights from state dict: 69.7s, move model to device: 21.8s, hijack: 1429.6s, load textual inversion embeddings: 114.8s, scripts callbacks: 53.8s, calculate empty prompt: 42.5s).

sdxl sd_xl_base.yaml

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

sd_models.py load_model_weights()

sd_models_xl.extend_sdxl(model)

sd_models_xl.py

model.model.conditioning_key = "crossattn"

discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()

sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding

generative-models中sgm代码结构和ldm一致,models下面是整体代码流程,modules下是具体的模块代码。

repositories/generative-models/sgm/models/diffusion.py

model = instantiate_from_config(network_config)
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(model, ...) # -> OpenAIWrapper(diffusion_model=UNetModel)

self.denoiser = instantiate_from_config(denoiser_config)
self.conditioner = instantiate_from_config(conditioner_config)
self.first_stage_model = instantiate_from_config(first_stage_config).eval()

model.conditioner

GeneralConditioner(
  (embedders): ModuleList(
    (0): FrozenCLIPEmbedder(
      (transformer): CLIPTextModel(
        (text_model): CLIPTextTransformer(
          (embeddings): CLIPTextEmbeddings(
            (token_embedding): Embedding(49408, 768)
            (position_embedding): Embedding(77, 768)
          )
          (encoder): CLIPEncoder(
            (layers): ModuleList(
              (0-11): 12 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                  (k_proj): Linear(in_features=768, out_features=768, bias=True)
                  (v_proj): Linear(in_features=768, out_features=768, bias=True)
                  (q_proj): Linear(in_features=768, out_features=768, bias=True)
                  (out_proj): Linear(in_features=768, out_features=768, bias=True)
                )
                (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                  (activation_fn): QuickGELUActivation()
                  (fc1): Linear(in_features=768, out_features=3072, bias=True)
                  (fc2): Linear(in_features=3072, out_features=768, bias=True)
                )
                (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
              )
            )
          )
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (1): FrozenOpenCLIPEmbedder2(
      (model): CLIP(
        (transformer): Transformer(
          (resblocks): ModuleList(
            (0-31): 32 x ResidualAttentionBlock(
              (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (attn): MultiheadAttention(
                (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
              )
              (ls_1): Identity()
              (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (mlp): Sequential(
                (c_fc): Linear(in_features=1280, out_features=5120, bias=True)
                (gelu): GELUHijack(approximate='none')
                (c_proj): Linear(in_features=5120, out_features=1280, bias=True)
              )
              (ls_2): Identity()
            )
          )
        )
        (token_embedding): Embedding(49408, 1280)
        (ln_final): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (2-4): 3 x ConcatTimestepEmbedderND(
      (timestep): Timestep()
    )
  )
  (wrapped): Module()
)

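model.first_stage_model:(注:下面贴出的结构与上文 model.conditioner 的打印结果完全相同,应为粘贴错误;按 sd_xl_base.yaml 中的 first_stage_config,此处实际对象应为 sgm.models.autoencoder.AutoencoderKLInferenceWrapper)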

GeneralConditioner(
  (embedders): ModuleList(
    (0): FrozenCLIPEmbedder(
      (transformer): CLIPTextModel(
        (text_model): CLIPTextTransformer(
          (embeddings): CLIPTextEmbeddings(
            (token_embedding): Embedding(49408, 768)
            (position_embedding): Embedding(77, 768)
          )
          (encoder): CLIPEncoder(
            (layers): ModuleList(
              (0-11): 12 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                  (k_proj): Linear(in_features=768, out_features=768, bias=True)
                  (v_proj): Linear(in_features=768, out_features=768, bias=True)
                  (q_proj): Linear(in_features=768, out_features=768, bias=True)
                  (out_proj): Linear(in_features=768, out_features=768, bias=True)
                )
                (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                  (activation_fn): QuickGELUActivation()
                  (fc1): Linear(in_features=768, out_features=3072, bias=True)
                  (fc2): Linear(in_features=3072, out_features=768, bias=True)
                )
                (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
              )
            )
          )
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (1): FrozenOpenCLIPEmbedder2(
      (model): CLIP(
        (transformer): Transformer(
          (resblocks): ModuleList(
            (0-31): 32 x ResidualAttentionBlock(
              (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (attn): MultiheadAttention(
                (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
              )
              (ls_1): Identity()
              (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (mlp): Sequential(
                (c_fc): Linear(in_features=1280, out_features=5120, bias=True)
                (gelu): GELUHijack(approximate='none')
                (c_proj): Linear(in_features=5120, out_features=1280, bias=True)
              )
              (ls_2): Identity()
            )
          )
        )
        (token_embedding): Embedding(49408, 1280)
        (ln_final): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (2-4): 3 x ConcatTimestepEmbedderND(
      (timestep): Timestep()
    )
  )
  (wrapped): Module()
)

model.model:

OpenAIWrapper(
  (diffusion_model): UNetModel(
    (time_embed): Sequential(
      (0): Linear(in_features=320, out_features=1280, bias=True)
      (1): SiLU()
      (2): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (label_emb): Sequential(
      (0): Sequential(
        (0): Linear(in_features=2816, out_features=1280, bias=True)
        (1): SiLU()
        (2): Linear(in_features=1280, out_features=1280, bias=True)
      )
    )
    (input_blocks): ModuleList(
      (0): TimestepEmbedSequential(
        (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1-2): 2 x TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=320, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Identity()
        )
      )
      (3): TimestepEmbedSequential(
        (0): Downsample(
          (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
      (4): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=640, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=640, out_features=640, bias=True)
          (transformer_blocks): ModuleList(
            (0-1): 2 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=640, out_features=640, bias=False)
                (to_v): Linear(in_features=640, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=640, out_features=5120, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=2560, out_features=640, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=2048, out_features=640, bias=False)
                (to_v): Linear(in_features=2048, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=640, out_features=640, bias=True)
        )
      )
      (5): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=640, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Identity()
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=640, out_features=640, bias=True)
          (transformer_blocks): ModuleList(
            (0-1): 2 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=640, out_features=640, bias=False)
                (to_v): Linear(in_features=640, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=640, out_features=5120, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=2560, out_features=640, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=2048, out_features=640, bias=False)
                (to_v): Linear(in_features=2048, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=640, out_features=640, bias=True)
        )
      )
      (6): TimestepEmbedSequential(
        (0): Downsample(
          (op): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
      (7): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
          (transformer_blocks): ModuleList(
            (0-9): 10 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=1280, out_features=1280, bias=False)
                (to_v): Linear(in_features=1280, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=1280, out_features=10240, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=5120, out_features=1280, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=2048, out_features=1280, bias=False)
                (to_v): Linear(in_features=2048, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
        )
      )
      (8): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Identity()
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
          (transformer_blocks): ModuleList(
            (0-9): 10 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=1280, out_features=1280, bias=False)
                (to_v): Linear(in_features=1280, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=1280, out_features=10240, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=5120, out_features=1280, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=2048, out_features=1280, bias=False)
                (to_v): Linear(in_features=2048, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
        )
      )
    )
    (middle_block): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (skip_connection): Identity()
      )
      (1): SpatialTransformer(
        (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
        (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
        (transformer_blocks): ModuleList(
          (0-9): 10 x BasicTransformerBlock(
            (attn1): CrossAttention(
              (to_q): Linear(in_features=1280, out_features=1280, bias=False)
              (to_k): Linear(in_features=1280, out_features=1280, bias=False)
              (to_v): Linear(in_features=1280, out_features=1280, bias=False)
              (to_out): Sequential(
                (0): Linear(in_features=1280, out_features=1280, bias=True)
                (1): Dropout(p=0.0, inplace=False)
              )
            )
            (ff): FeedForward(
              (net): Sequential(
                (0): GEGLU(
                  (proj): Linear(in_features=1280, out_features=10240, bias=True)
                )
                (1): Dropout(p=0.0, inplace=False)
                (2): Linear(in_features=5120, out_features=1280, bias=True)
              )
            )
            (attn2): CrossAttention(
              (to_q): Linear(in_features=1280, out_features=1280, bias=False)
              (to_k): Linear(in_features=2048, out_features=1280, bias=False)
              (to_v): Linear(in_features=2048, out_features=1280, bias=False)
              (to_out): Sequential(
                (0): Linear(in_features=1280, out_features=1280, bias=True)
                (1): Dropout(p=0.0, inplace=False)
              )
            )
            (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          )
        )
        (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (2): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (skip_connection): Identity()
      )
    )
    (output_blocks): ModuleList(
      (0-1): 2 x TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 2560, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
          (transformer_blocks): ModuleList(
            (0-9): 10 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=1280, out_features=1280, bias=False)
                (to_v): Linear(in_features=1280, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=1280, out_features=10240, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=5120, out_features=1280, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=2048, out_features=1280, bias=False)
                (to_v): Linear(in_features=2048, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
        )
      )
      (2): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 1920, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
          (transformer_blocks): ModuleList(
            (0-9): 10 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=1280, out_features=1280, bias=False)
                (to_v): Linear(in_features=1280, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=1280, out_features=10240, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=5120, out_features=1280, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=1280, out_features=1280, bias=False)
                (to_k): Linear(in_features=2048, out_features=1280, bias=False)
                (to_v): Linear(in_features=2048, out_features=1280, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=1280, out_features=1280, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (2): Upsample(
          (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (3): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 1920, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=640, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(1920, 640, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=640, out_features=640, bias=True)
          (transformer_blocks): ModuleList(
            (0-1): 2 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=640, out_features=640, bias=False)
                (to_v): Linear(in_features=640, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=640, out_features=5120, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=2560, out_features=640, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=2048, out_features=640, bias=False)
                (to_v): Linear(in_features=2048, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=640, out_features=640, bias=True)
        )
      )
      (4): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=640, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=640, out_features=640, bias=True)
          (transformer_blocks): ModuleList(
            (0-1): 2 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=640, out_features=640, bias=False)
                (to_v): Linear(in_features=640, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=640, out_features=5120, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=2560, out_features=640, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=2048, out_features=640, bias=False)
                (to_v): Linear(in_features=2048, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=640, out_features=640, bias=True)
        )
      )
      (5): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 960, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=640, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(960, 640, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): SpatialTransformer(
          (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=640, out_features=640, bias=True)
          (transformer_blocks): ModuleList(
            (0-1): 2 x BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=640, out_features=640, bias=False)
                (to_v): Linear(in_features=640, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): Sequential(
                  (0): GEGLU(
                    (proj): Linear(in_features=640, out_features=5120, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=2560, out_features=640, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=640, out_features=640, bias=False)
                (to_k): Linear(in_features=2048, out_features=640, bias=False)
                (to_v): Linear(in_features=2048, out_features=640, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=640, out_features=640, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=640, out_features=640, bias=True)
        )
        (2): Upsample(
          (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (6): TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 960, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=320, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (7-8): 2 x TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (h_upd): Identity()
          (x_upd): Identity()
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=1280, out_features=320, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (out): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)

更新时间:2024-03-26