From 8b0e3d063aef8a267e7d87f7630a2a367008e153 Mon Sep 17 00:00:00 2001
From: LarryTsai
Date: Mon, 21 Oct 2024 10:51:54 +0800
Subject: [PATCH] Update text_encoder/config.json

---
 scheduler/scheduler_config.json               |    2 +-
 text_encoder/config.json                      |    4 +-
 text_encoder/model-00001-of-00004.safetensors |    3 -
 text_encoder/model-00002-of-00004.safetensors |    3 -
 text_encoder/model-00003-of-00004.safetensors |    3 -
 text_encoder/model-00004-of-00004.safetensors |    3 -
 text_encoder/model.safetensors.index.json     |  226 ---
 transformer/config.json                       |    2 +-
 ...n_pytorch_model-00001-of-00002.safetensors |    3 -
 ...n_pytorch_model-00002-of-00002.safetensors |    3 -
 ...usion_pytorch_model.safetensors.index.json |  694 ------
 transformer/transformer_3d_allegro.py         | 1776 -----------------
 vae/config.json                               |    2 +-
 vae/vae_allegro.py                            |  978 ---------
 14 files changed, 4 insertions(+), 3698 deletions(-)
 delete mode 100644 text_encoder/model-00001-of-00004.safetensors
 delete mode 100644 text_encoder/model-00002-of-00004.safetensors
 delete mode 100644 text_encoder/model-00003-of-00004.safetensors
 delete mode 100644 text_encoder/model-00004-of-00004.safetensors
 delete mode 100644 text_encoder/model.safetensors.index.json
 delete mode 100644 transformer/diffusion_pytorch_model-00001-of-00002.safetensors
 delete mode 100644 transformer/diffusion_pytorch_model-00002-of-00002.safetensors
 delete mode 100644 transformer/diffusion_pytorch_model.safetensors.index.json
 delete mode 100644 transformer/transformer_3d_allegro.py
 delete mode 100644 vae/vae_allegro.py

diff --git a/scheduler/scheduler_config.json b/scheduler/scheduler_config.json
index cb332de..da642e9 100644
--- a/scheduler/scheduler_config.json
+++ b/scheduler/scheduler_config.json
@@ -1,6 +1,6 @@
 {
   "_class_name": "EulerAncestralDiscreteScheduler",
-  "_diffusers_version": "0.30.3",
+  "_diffusers_version": "0.28.0",
   "beta_end": 0.02,
   "beta_schedule": "linear",
   "beta_start": 0.0001,
diff --git a/text_encoder/config.json b/text_encoder/config.json
index 5f53503..c631139 100644
--- a/text_encoder/config.json
+++ b/text_encoder/config.json
@@ -1,9 +1,7 @@
 {
-  "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro/text_encoder",
   "architectures": [
     "T5EncoderModel"
   ],
-  "classifier_dropout": 0.0,
   "d_ff": 10240,
   "d_kv": 64,
   "d_model": 4096,
@@ -26,7 +24,7 @@
   "relative_attention_num_buckets": 32,
   "tie_word_embeddings": false,
   "torch_dtype": "float32",
-  "transformers_version": "4.40.1",
+  "transformers_version": "4.21.1",
   "use_cache": true,
   "vocab_size": 32128
 }
diff --git a/text_encoder/model-00001-of-00004.safetensors b/text_encoder/model-00001-of-00004.safetensors
deleted file mode 100644
index 18c5aab..0000000
--- a/text_encoder/model-00001-of-00004.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7a68b2c8c080696a10109612a649bc69330991ecfea65930ccfdfbdb011f2686
-size 4989319680
diff --git a/text_encoder/model-00002-of-00004.safetensors b/text_encoder/model-00002-of-00004.safetensors
deleted file mode 100644
index 99a5b98..0000000
--- a/text_encoder/model-00002-of-00004.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b8ed6556d7507e38af5b428c605fb2a6f2bdb7e80bd481308b865f7a40c551ca
-size 4999830656
diff --git a/text_encoder/model-00003-of-00004.safetensors b/text_encoder/model-00003-of-00004.safetensors
deleted file mode 100644
index 03e0700..0000000
--- a/text_encoder/model-00003-of-00004.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c831635f83041f83faf0024b39c6ecb21b45d70dd38a63ea5bac6c7c6e5e558c
-size 4865612720
diff --git a/text_encoder/model-00004-of-00004.safetensors b/text_encoder/model-00004-of-00004.safetensors
deleted file mode 100644
index 2c6513c..0000000
--- a/text_encoder/model-00004-of-00004.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:02a5f2d69205be92ad48fe5d712d38c2ff55627969116aeffc58bd75a28da468
-size 4194506688
diff --git a/text_encoder/model.safetensors.index.json b/text_encoder/model.safetensors.index.json
deleted file mode 100644
index 0900b8e..0000000
--- a/text_encoder/model.safetensors.index.json
+++ /dev/null
@@ -1,226 +0,0 @@
-{
-  "metadata": {
-    "total_size": 19049242624
-  },
-  "weight_map": {
-    "encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors",
-    "encoder.block.10.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.10.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.11.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.11.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors",
-    "encoder.block.11.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors",
"encoder.block.11.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors", - "encoder.block.11.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors", - "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors", - "encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors", - "encoder.block.11.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.12.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors", - "encoder.block.12.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors", - "encoder.block.12.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors", - "encoder.block.12.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors", - "encoder.block.12.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors", - "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors", - "encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors", - "encoder.block.12.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors", - "encoder.block.13.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors", - "encoder.block.14.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": 
"model-00003-of-00004.safetensors", - "encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors", - "encoder.block.15.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors", - "encoder.block.16.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00003-of-00004.safetensors", - "encoder.block.17.layer.1.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.18.layer.0.SelfAttention.k.weight": "model-00003-of-00004.safetensors", - "encoder.block.18.layer.0.SelfAttention.o.weight": "model-00003-of-00004.safetensors", - "encoder.block.18.layer.0.SelfAttention.q.weight": "model-00003-of-00004.safetensors", - "encoder.block.18.layer.0.SelfAttention.v.weight": "model-00003-of-00004.safetensors", - "encoder.block.18.layer.0.layer_norm.weight": "model-00003-of-00004.safetensors", - "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00003-of-00004.safetensors", - "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors", - "encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors", - "encoder.block.18.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors", - "encoder.block.19.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors", - "encoder.block.2.layer.0.SelfAttention.o.weight": 
"model-00001-of-00004.safetensors", - "encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors", - "encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors", - "encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors", - "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors", - "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors", - "encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors", - "encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors", - "encoder.block.20.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors", - "encoder.block.20.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors", - "encoder.block.21.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors", - "encoder.block.22.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.0.SelfAttention.k.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.0.SelfAttention.o.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.0.SelfAttention.q.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.0.SelfAttention.v.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.0.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": 
"model-00004-of-00004.safetensors", - "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00004-of-00004.safetensors", - "encoder.block.23.layer.1.layer_norm.weight": "model-00004-of-00004.safetensors", - "encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors", - "encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00004.safetensors", - "encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00004.safetensors", - "encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors", - "encoder.block.5.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors", - "encoder.block.6.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.0.SelfAttention.k.weight": 
"model-00002-of-00004.safetensors", - "encoder.block.7.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors", - "encoder.block.7.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors", - "encoder.block.8.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.0.SelfAttention.k.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.0.SelfAttention.o.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.0.SelfAttention.q.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.0.SelfAttention.v.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.0.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00002-of-00004.safetensors", - "encoder.block.9.layer.1.layer_norm.weight": "model-00002-of-00004.safetensors", - "encoder.final_layer_norm.weight": "model-00004-of-00004.safetensors", - "shared.weight": "model-00001-of-00004.safetensors" - } -} diff --git a/transformer/config.json b/transformer/config.json index 488c632..e7d8fa6 100644 --- a/transformer/config.json +++ b/transformer/config.json @@ -1,6 +1,6 @@ { "_class_name": "AllegroTransformer3DModel", - "_diffusers_version": "0.30.3", + "_diffusers_version": "0.28.0", "activation_fn": "gelu-approximate", "attention_bias": true, "attention_head_dim": 96, diff --git a/transformer/diffusion_pytorch_model-00001-of-00002.safetensors b/transformer/diffusion_pytorch_model-00001-of-00002.safetensors deleted file mode 100644 index 16efbab..0000000 --- a/transformer/diffusion_pytorch_model-00001-of-00002.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:566c682d40b99cdf07b351a4ec57a01f5469dc6344dd1eca38939314d5f635bc -size 9985256872 diff --git a/transformer/diffusion_pytorch_model-00002-of-00002.safetensors b/transformer/diffusion_pytorch_model-00002-of-00002.safetensors deleted file mode 100644 index e43e5e1..0000000 --- a/transformer/diffusion_pytorch_model-00002-of-00002.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version 
-oid sha256:3b1cffec0f067b2fb0bc0b6f228cafddd87d04e23772c8d5a4c9f40f8c9719eb
-size 1102452560
diff --git a/transformer/diffusion_pytorch_model.safetensors.index.json b/transformer/diffusion_pytorch_model.safetensors.index.json
deleted file mode 100644
index 55a5afd..0000000
--- a/transformer/diffusion_pytorch_model.safetensors.index.json
+++ /dev/null
@@ -1,694 +0,0 @@
-{
-  "metadata": {
-    "total_size": 11087631424
-  },
-  "weight_map": {
-    "adaln_single.emb.timestep_embedder.linear_1.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "adaln_single.emb.timestep_embedder.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "adaln_single.emb.timestep_embedder.linear_2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "adaln_single.emb.timestep_embedder.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "adaln_single.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "adaln_single.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "caption_projection.linear_1.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "caption_projection.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "caption_projection.linear_2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "caption_projection.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "pos_embed.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "pos_embed.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "proj_out.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "proj_out.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
-    "scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
-    "transformer_blocks.0.ff.net.0.proj.weight":
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.0.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.0.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.0.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.1.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_out.0.weight": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.10.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.11.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - 
"transformer_blocks.12.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.12.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.ff.net.2.weight": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.13.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.14.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - 
"transformer_blocks.15.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.15.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.16.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_v.bias": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.17.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.18.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - 
"transformer_blocks.19.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.19.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - 
"transformer_blocks.2.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.2.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.20.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - 
"transformer_blocks.21.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.21.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.22.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_out.0.bias": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.23.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - 
"transformer_blocks.24.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.24.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.25.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_out.0.bias": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.26.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.27.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - 
"transformer_blocks.28.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.28.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn2.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.attn2.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.29.attn2.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.attn2.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.ff.net.0.proj.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.ff.net.0.proj.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.ff.net.2.bias": 
"diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.ff.net.2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.29.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.3.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.30.attn1.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn1.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn1.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn1.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn1.to_q.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn1.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn1.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn1.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_q.bias": 
"diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.attn2.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.ff.net.0.proj.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.ff.net.0.proj.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.ff.net.2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.ff.net.2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.30.scale_shift_table": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_q.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn1.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_k.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_out.0.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_q.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_v.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.attn2.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.ff.net.0.proj.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.ff.net.0.proj.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.ff.net.2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.ff.net.2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.31.scale_shift_table": "diffusion_pytorch_model-00002-of-00002.safetensors", - "transformer_blocks.4.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn1.to_v.bias": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.4.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.5.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_k.bias": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.6.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.attn2.to_v.weight": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.7.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.8.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn1.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_k.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_k.weight": 
"diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_out.0.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_q.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_v.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.attn2.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.ff.net.0.proj.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.ff.net.0.proj.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.ff.net.2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.ff.net.2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors", - "transformer_blocks.9.scale_shift_table": "diffusion_pytorch_model-00001-of-00002.safetensors" - } -} diff --git a/transformer/transformer_3d_allegro.py b/transformer/transformer_3d_allegro.py deleted file mode 100644 index efc358d..0000000 --- a/transformer/transformer_3d_allegro.py +++ /dev/null @@ -1,1776 +0,0 @@ -# Adapted from Open-Sora-Plan - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan -# -------------------------------------------------------- - - -import json -import os -from dataclasses import dataclass -from functools import partial -from importlib import import_module -from typing import Any, Callable, Dict, Optional, Tuple - -import numpy as np -import torch -import collections -import torch.nn.functional as F -from torch.nn.attention import SDPBackend, sdpa_kernel -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.activations import GEGLU, GELU, ApproximateGELU -from diffusers.models.attention_processor import ( - AttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - AttnProcessor, - CustomDiffusionAttnProcessor, - CustomDiffusionAttnProcessor2_0, - CustomDiffusionXFormersAttnProcessor, - LoRAAttnAddedKVProcessor, - LoRAAttnProcessor, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - SlicedAttnAddedKVProcessor, - SlicedAttnProcessor, - SpatialNorm, - XFormersAttnAddedKVProcessor, - XFormersAttnProcessor, -) -from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available -from diffusers.utils.torch_utils import maybe_allow_in_graph -from einops import rearrange, repeat -from torch import nn -from diffusers.models.embeddings import PixArtAlphaTextProjection - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - -from diffusers.utils import logging - -logger = logging.get_logger(__name__) - - -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - -class CombinedTimestepSizeEmbeddings(nn.Module): - """ - For PixArt-Alpha. 
- - Reference: - https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 - """ - - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): - super().__init__() - - self.outdim = size_emb_dim - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.use_additional_conditions = use_additional_conditions - if use_additional_conditions: - self.use_additional_conditions = True - self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - - def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module): - if size.ndim == 1: - size = size[:, None] - - if size.shape[0] != batch_size: - size = size.repeat(batch_size // size.shape[0], 1) - if size.shape[0] != batch_size: - raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.") - - current_batch_size, dims = size.shape[0], size.shape[1] - size = size.reshape(-1) - size_freq = self.additional_condition_proj(size).to(size.dtype) - - size_emb = embedder(size_freq) - size_emb = size_emb.reshape(current_batch_size, dims * self.outdim) - return size_emb - - def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - - if self.use_additional_conditions: - resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder) - aspect_ratio = self.apply_condition( - aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder - ) - conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1) - else: - conditioning = timesteps_emb - - return conditioning - - -class PositionGetter3D(object): - """ return positions of patches """ - - def __init__(self, ): - self.cache_positions = {} - - def __call__(self, b, t, h, w, device): - if not (b, t,h,w) in self.cache_positions: - x = torch.arange(w, device=device) - y = torch.arange(h, device=device) - z = torch.arange(t, device=device) - pos = torch.cartesian_prod(z, y, x) - - pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone() - poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) - max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) - - self.cache_positions[b, t, h, w] = (poses, max_poses) - pos = self.cache_positions[b, t, h, w] - - return pos - - -class RoPE3D(torch.nn.Module): - - def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): - super().__init__() - self.base = freq - self.F0 = F0 - self.interpolation_scale_t = interpolation_scale_thw[0] - self.interpolation_scale_h = interpolation_scale_thw[1] - self.interpolation_scale_w = interpolation_scale_thw[2] - self.cache = {} - - def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): - if (D, seq_len, device, dtype) not in self.cache: - inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / 
interpolation_scale - freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) - freqs = torch.cat((freqs, freqs), dim=-1) - cos = freqs.cos() # (Seq, Dim) - sin = freqs.sin() - self.cache[D, seq_len, device, dtype] = (cos, sin) - return self.cache[D, seq_len, device, dtype] - - @staticmethod - def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - def apply_rope1d(self, tokens, pos1d, cos, sin): - assert pos1d.ndim == 2 - - # for (batch_size x ntokens x nheads x dim) - cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] - sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] - return (tokens * cos) + (self.rotate_half(tokens) * sin) - - def forward(self, tokens, positions): - """ - input: - * tokens: batch_size x nheads x ntokens x dim - * positions: batch_size x ntokens x 3 (t, y and x position of each token) - output: - * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim) - """ - assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" - D = tokens.size(3) // 3 - poses, max_poses = positions - assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3 - cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) - cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) - cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) - # split features into three along the feature dimension, and apply rope1d on each half - t, y, x = tokens.chunk(3, dim=-1) - t = self.apply_rope1d(t, poses[0], cos_t, sin_t) - y = self.apply_rope1d(y, poses[1], cos_y, sin_y) - x = self.apply_rope1d(x, poses[2], cos_x, sin_x) - tokens = torch.cat((t, y, x), dim=-1) - return tokens - -class PatchEmbed2D(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - num_frames=1, - height=224, - width=224, - patch_size_t=1, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=(1, 1), - interpolation_scale_t=1, - use_abs_pos=False, - ): - super().__init__() - self.use_abs_pos = use_abs_pos - self.flatten = flatten - self.layer_norm = layer_norm - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - - self.patch_size_t = patch_size_t - self.patch_size = patch_size - - def forward(self, latent): - b, _, _, _, _ = latent.shape - video_latent = None - - latent = rearrange(latent, 'b c t h w -> (b t) c h w') - - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C - if self.layer_norm: - latent = self.norm(latent) - - latent = rearrange(latent, '(b t) n c -> b (t n) c', b=b) - video_latent = latent - - return video_latent - - -@maybe_allow_in_graph -class Attention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): - The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): - The number of heads to use for multi-head attention. 
- dim_head (`int`, *optional*, defaults to 64): - The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - upcast_attention (`bool`, *optional*, defaults to False): - Set to `True` to upcast the attention computation to `float32`. - upcast_softmax (`bool`, *optional*, defaults to False): - Set to `True` to upcast the softmax computation to `float32`. - cross_attention_norm (`str`, *optional*, defaults to `None`): - The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. - cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the group norm in the cross attention. - added_kv_proj_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the added key and value projections. If `None`, no projection is used. - norm_num_groups (`int`, *optional*, defaults to `None`): - The number of groups to use for the group norm in the attention. - spatial_norm_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the spatial normalization. - out_bias (`bool`, *optional*, defaults to `True`): - Set to `True` to use a bias in the output linear layer. - scale_qk (`bool`, *optional*, defaults to `True`): - Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. - only_cross_attention (`bool`, *optional*, defaults to `False`): - Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if - `added_kv_proj_dim` is not `None`. - eps (`float`, *optional*, defaults to 1e-5): - An additional value added to the denominator in group normalization that is used for numerical stability. - rescale_output_factor (`float`, *optional*, defaults to 1.0): - A factor to rescale the output by dividing it with this value. - residual_connection (`bool`, *optional*, defaults to `False`): - Set to `True` to add the residual connection to the output. - _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): - Set to `True` if the attention block is loaded from a deprecated state dict. - processor (`AttnProcessor`, *optional*, defaults to `None`): - The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and - `AttnProcessor` otherwise. 
- """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - attention_mode: str = "xformers", - use_rope: bool = False, - interpolation_scale_thw=None, - ): - super().__init__() - self.inner_dim = dim_head * heads - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.use_rope = use_rope - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) - else: - self.group_norm = None - - if spatial_norm_dim is not None: - self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) - else: - self.spatial_norm = None - - if cross_attention_norm is None: - self.norm_cross = None - elif cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(self.cross_attention_dim) - elif cross_attention_norm == "group_norm": - if self.added_kv_proj_dim is not None: - # The given `encoder_hidden_states` are initially of shape - # (batch_size, seq_len, added_kv_proj_dim) before being projected - # to (batch_size, seq_len, cross_attention_dim). The norm is applied - # before the projection, so we need to use `added_kv_proj_dim` as - # the number of channels for the group norm. - norm_cross_num_channels = added_kv_proj_dim - else: - norm_cross_num_channels = self.cross_attention_dim - - self.norm_cross = nn.GroupNorm( - num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True - ) - else: - raise ValueError( - f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" - ) - - linear_cls = nn.Linear - - - self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - if self.added_kv_proj_dim is not None: - self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - if processor is None: - processor = ( - AttnProcessor2_0( - attention_mode, - use_rope, - interpolation_scale_thw=interpolation_scale_thw, - ) - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) - self.set_processor(processor) - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ) -> None: - r""" - Set whether to use memory efficient attention from `xformers` or not. - - Args: - use_memory_efficient_attention_xformers (`bool`): - Whether to use memory efficient attention from `xformers` or not. - attention_op (`Callable`, *optional*): - The attention operation to use. Defaults to `None` which uses the default attention operation from - `xformers`. - """ - is_lora = hasattr(self, "processor") - is_custom_diffusion = hasattr(self, "processor") and isinstance( - self.processor, - (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), - ) - is_added_kv_processor = hasattr(self, "processor") and isinstance( - self.processor, - ( - AttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - SlicedAttnAddedKVProcessor, - XFormersAttnAddedKVProcessor, - LoRAAttnAddedKVProcessor, - ), - ) - - if use_memory_efficient_attention_xformers: - if is_added_kv_processor and (is_lora or is_custom_diffusion): - raise NotImplementedError( - f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" - ) - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - - if is_lora: - # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers - # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? 
- processor = LoRAXFormersAttnProcessor( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - processor.to(self.processor.to_q_lora.up.weight.device) - elif is_custom_diffusion: - processor = CustomDiffusionXFormersAttnProcessor( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - elif is_added_kv_processor: - # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # which uses this type of cross attention ONLY because the attention mask of format - # [0, ..., -10.000, ..., 0, ...,] is not supported - # throw warning - logger.info( - "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." - ) - processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) - else: - processor = XFormersAttnProcessor(attention_op=attention_op) - else: - if is_lora: - attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - processor = attn_processor_class( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - ) - processor.load_state_dict(self.processor.state_dict()) - processor.to(self.processor.to_q_lora.up.weight.device) - elif is_custom_diffusion: - attn_processor_class = ( - CustomDiffusionAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else CustomDiffusionAttnProcessor - ) - processor = attn_processor_class( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - else: - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessor2_0() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) - - self.set_processor(processor) - - def set_attention_slice(self, slice_size: int) -> None: - r""" - Set the slice size for attention computation. - - Args: - slice_size (`int`): - The slice size for attention computation. 
- """ - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = AttnAddedKVProcessor() - else: - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() - ) - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: - r""" - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - _remove_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to remove LoRA layers from the model. - """ - if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: - deprecate( - "set_processor to offload LoRA", - "0.26.0", - "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", - ) - # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete - # We need to remove all LoRA layers - # Don't forget to remove ALL `_remove_lora` from the codebase - for module in self.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def get_processor(self, return_deprecated_lora: bool = False): - r""" - Get the attention processor in use. - - Args: - return_deprecated_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to return the deprecated LoRA attention processor. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - if not return_deprecated_lora: - return self.processor - - # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible - # serialization format for LoRA Attention Processors. It should be deleted once the integration - # with PEFT is completed. - is_lora_activated = { - name: module.lora_layer is not None - for name, module in self.named_modules() - if hasattr(module, "lora_layer") - } - - # 1. if no layer has a LoRA activated we can return the processor as usual - if not any(is_lora_activated.values()): - return self.processor - - # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` - is_lora_activated.pop("add_k_proj", None) - is_lora_activated.pop("add_v_proj", None) - # 2. 
else it is not posssible that only some layers have LoRA activated - if not all(is_lora_activated.values()): - raise ValueError( - f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" - ) - - # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor - non_lora_processor_cls_name = self.processor.__class__.__name__ - lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) - - hidden_size = self.inner_dim - - # now create a LoRA attention processor from the LoRA layers - if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: - kwargs = { - "cross_attention_dim": self.cross_attention_dim, - "rank": self.to_q.lora_layer.rank, - "network_alpha": self.to_q.lora_layer.network_alpha, - "q_rank": self.to_q.lora_layer.rank, - "q_hidden_size": self.to_q.lora_layer.out_features, - "k_rank": self.to_k.lora_layer.rank, - "k_hidden_size": self.to_k.lora_layer.out_features, - "v_rank": self.to_v.lora_layer.rank, - "v_hidden_size": self.to_v.lora_layer.out_features, - "out_rank": self.to_out[0].lora_layer.rank, - "out_hidden_size": self.to_out[0].lora_layer.out_features, - } - - if hasattr(self.processor, "attention_op"): - kwargs["attention_op"] = self.processor.attention_op - - lora_processor = lora_processor_cls(hidden_size, **kwargs) - lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) - lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) - lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) - lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) - elif lora_processor_cls == LoRAAttnAddedKVProcessor: - lora_processor = lora_processor_cls( - hidden_size, - cross_attention_dim=self.add_k_proj.weight.shape[0], - rank=self.to_q.lora_layer.rank, - network_alpha=self.to_q.lora_layer.network_alpha, - ) - lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) - lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) - lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) - lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) - - # only save if used - if self.add_k_proj.lora_layer is not None: - lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) - lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) - else: - lora_processor.add_k_proj_lora = None - lora_processor.add_v_proj_lora = None - else: - raise ValueError(f"{lora_processor_cls} does not exist.") - - return lora_processor - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - **cross_attention_kwargs, - ) -> torch.Tensor: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. 
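        Example (an illustrative sketch -- the tensor shapes are assumptions, not
        requirements of the class):

            # self-attention over a [batch, seq_len, query_dim] tensor
            out = attn(hidden_states)
            # cross-attention against text embeddings of shape
            # [batch, text_len, cross_attention_dim]; processor-specific kwargs
            # (e.g. frame/height/width for the RoPE processor) are forwarded
            # through **cross_attention_kwargs
            out = attn(hidden_states, encoder_hidden_states=text_embeds)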
- """ - # The `Attention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` - is the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is - the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is - reshaped to `[batch_size * heads, seq_len, dim // heads]`. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - - if out_dim == 3: - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - - return tensor - - def get_attention_scores( - self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None - ) -> torch.Tensor: - r""" - Compute the attention scores. - - Args: - query (`torch.Tensor`): The query tensor. - key (`torch.Tensor`): The key tensor. - attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. - - Returns: - `torch.Tensor`: The attention probabilities/scores. - """ - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device - ) - beta = 0 - else: - baddbmm_input = attention_mask - beta = 1 - - attention_scores = torch.baddbmm( - baddbmm_input, - query, - key.transpose(-1, -2), - beta=beta, - alpha=self.scale, - ) - del baddbmm_input - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - del attention_scores - - attention_probs = attention_probs.to(dtype) - - return attention_probs - - def prepare_attention_mask( - self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None, - ) -> torch.Tensor: - r""" - Prepare the attention mask for the attention computation. - - Args: - attention_mask (`torch.Tensor`): - The attention mask to prepare. - target_length (`int`): - The target length of the attention mask. This is the length of the attention mask after padding. 
- batch_size (`int`): - The batch size, which is used to repeat the attention mask. - out_dim (`int`, *optional*, defaults to `3`): - The output dimension of the attention mask. Can be either `3` or `4`. - - Returns: - `torch.Tensor`: The prepared attention mask. - """ - head_size = head_size if head_size is not None else self.heads - if attention_mask is None: - return attention_mask - - current_length: int = attention_mask.shape[-1] - if current_length != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. - padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) - padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: - # we want to instead pad by (0, remaining_length), where remaining_length is: - # remaining_length: int = target_length - current_length - # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - elif out_dim == 4: - attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave(head_size, dim=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: - r""" - Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the - `Attention` class. - - Args: - encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. - - Returns: - `torch.Tensor`: The normalized encoder hidden states. - """ - assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - # Group norm norms along the channels dimension and expects - # input to be in the shape of (N, C, *). In this case, we want - # to norm along the hidden dimension, so we need to move - # (batch_size, sequence_length, hidden_size) -> - # (batch_size, hidden_size, sequence_length) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - else: - assert False - - return encoder_hidden_states - - def _init_compress(self): - self.sr.bias.data.zero_() - self.norm = nn.LayerNorm(self.inner_dim) - - -class AttnProcessor2_0(nn.Module): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
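        This variant additionally exposes an `attention_mode` switch ("flash" or
        "xformers", both dispatched through `torch.nn.functional.scaled_dot_product_attention`)
        and optional 3D RoPE. A minimal construction sketch (the argument values are
        illustrative assumptions, not required defaults):

            processor = AttnProcessor2_0(
                attention_mode="xformers",          # this mode honours an attention mask
                use_rope=True,                      # enable 3D rotary position embeddings
                interpolation_scale_thw=(1, 1, 1),
            )
            attn.set_processor(processor)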
- """ - - def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=None): - super().__init__() - self.attention_mode = attention_mode - self.use_rope = use_rope - self.interpolation_scale_thw = interpolation_scale_thw - - if self.use_rope: - self._init_rope(interpolation_scale_thw) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def _init_rope(self, interpolation_scale_thw): - self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw) - self.position_getter = PositionGetter3D() - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - frame: int = 8, - height: int = 16, - width: int = 16, - ) -> torch.FloatTensor: - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None and self.attention_mode == 'xformers': - attention_heads = attn.heads - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, head_size=attention_heads) - attention_mask = attention_mask.view(batch_size, attention_heads, -1, attention_mask.shape[-1]) - else: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - - - attn_heads = attn.heads - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn_heads - - query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - - - if self.use_rope: - # require the shape of (batch_size x nheads x ntokens x dim) - pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device) - query = self.rope(query, pos_thw) - key = self.rope(key, pos_thw) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - if self.attention_mode == 'flash': - # assert attention_mask is None, 'flash-attn do not support attention_mask' - with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False - ) - elif self.attention_mode == 'xformers': - with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): - 
hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - linear_cls = nn.Linear - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(linear_cls(inner_dim, dim_out)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -@maybe_allow_in_graph -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. 
In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' - norm_eps: float = 1e-5, - final_dropout: bool = False, - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - sa_attention_mode: str = "flash", - ca_attention_mode: str = "xformers", - use_rope: bool = False, - interpolation_scale_thw: Tuple[int] = (1, 1, 1), - block_idx: Optional[int] = None, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - attention_mode=sa_attention_mode, - use_rope=use_rope, - interpolation_scale_thw=interpolation_scale_thw, - ) - - # 2. 
Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - attention_mode=ca_attention_mode, # only xformers support attention_mask - use_rope=False, # do not position in cross attention - interpolation_scale_thw=interpolation_scale_thw, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - - if not self.use_ada_layer_norm_single: - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - ) - - # 5. Scale-shift for PixArt-Alpha. - if self.use_ada_layer_norm_single: - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - frame: int = None, - height: int = None, - width: int = None, - ) -> torch.FloatTensor: - # Notice that normalization is always applied before the real computation in the following blocks. - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - # 0. 
Self-Attention - batch_size = hidden_states.shape[0] - - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.use_layer_norm: - norm_hidden_states = self.norm1(hidden_states) - elif self.use_ada_layer_norm_single: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - norm_hidden_states = norm_hidden_states.squeeze(1) - else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - frame=frame, - height=height, - width=width, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.use_ada_layer_norm_single: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1. Cross-Attention - if self.attn2 is not None: - - if self.use_ada_layer_norm: - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.use_ada_layer_norm_zero or self.use_layer_norm: - norm_hidden_states = self.norm2(hidden_states) - elif self.use_ada_layer_norm_single: - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.use_ada_layer_norm_single is False: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - - # 2. Feed-forward - if not self.use_ada_layer_norm_single: - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self.use_ada_layer_norm_single: - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - elif self.use_ada_layer_norm_single: - ff_output = gate_mlp * ff_output - - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - return hidden_states - - -class AdaLayerNormSingle(nn.Module): - r""" - Norm layer adaptive layer norm single (adaLN-single). - - As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - use_additional_conditions (`bool`): To use additional conditions for normalization or not. 
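        Example (an illustrative sketch; the embedding dimension and dtype are
        assumptions):

            adaln = AdaLayerNormSingle(embedding_dim=1152)
            # `mod` packs the 6 * embedding_dim modulation parameters consumed by the
            # transformer blocks, `embedded_t` is the underlying timestep embedding
            mod, embedded_t = adaln(timestep, batch_size=timestep.shape[0],
                                    hidden_dtype=torch.float16)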
- """ - - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): - super().__init__() - - self.emb = CombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions - ) - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - - def forward( - self, - timestep: torch.Tensor, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - batch_size: int = None, - hidden_dtype: Optional[torch.dtype] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # No modulation happening here. - embedded_timestep = self.emb( - timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None - ) - return self.linear(self.silu(embedded_timestep)), embedded_timestep - - -@dataclass -class Transformer3DModelOutput(BaseOutput): - """ - The output of [`Transformer2DModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - """ - - sample: torch.FloatTensor - - -class AllegroTransformer3DModel(ModelMixin, ConfigMixin): - _supports_gradient_checkpointing = True - - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. 
- """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - sample_size_t: Optional[int] = None, - patch_size: Optional[int] = None, - patch_size_t: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "ada_norm", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - caption_channels: int = None, - interpolation_scale_h: float = None, - interpolation_scale_w: float = None, - interpolation_scale_t: float = None, - use_additional_conditions: Optional[bool] = None, - sa_attention_mode: str = "flash", - ca_attention_mode: str = 'xformers', - downsampler: str = None, - use_rope: bool = False, - model_max_length: int = 300, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.interpolation_scale_t = interpolation_scale_t - self.interpolation_scale_h = interpolation_scale_h - self.interpolation_scale_w = interpolation_scale_w - self.downsampler = downsampler - self.caption_channels = caption_channels - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.inner_dim = inner_dim - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.use_rope = use_rope - self.model_max_length = model_max_length - self.num_layers = num_layers - self.config.hidden_size = inner_dim - - - # 1. Transformer3DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - assert in_channels is not None and patch_size is not None - - # 2. Initialize the right blocks. - # Initialize the output blocks and other projection blocks when necessary. 
- - assert self.config.sample_size_t is not None, "AllegroTransformer3DModel over patched input must provide sample_size_t" - assert self.config.sample_size is not None, "AllegroTransformer3DModel over patched input must provide sample_size" - #assert not (self.config.sample_size_t == 1 and self.config.patch_size_t == 2), "Image do not need patchfy in t-dim" - - self.num_frames = self.config.sample_size_t - self.config.sample_size = to_2tuple(self.config.sample_size) - self.height = self.config.sample_size[0] - self.width = self.config.sample_size[1] - self.patch_size_t = self.config.patch_size_t - self.patch_size = self.config.patch_size - interpolation_scale_t = ((self.config.sample_size_t - 1) // 16 + 1) if self.config.sample_size_t % 2 == 1 else self.config.sample_size_t / 16 - interpolation_scale_t = ( - self.config.interpolation_scale_t if self.config.interpolation_scale_t is not None else interpolation_scale_t - ) - interpolation_scale = ( - self.config.interpolation_scale_h if self.config.interpolation_scale_h is not None else self.config.sample_size[0] / 30, - self.config.interpolation_scale_w if self.config.interpolation_scale_w is not None else self.config.sample_size[1] / 40, - ) - self.pos_embed = PatchEmbed2D( - num_frames=self.config.sample_size_t, - height=self.config.sample_size[0], - width=self.config.sample_size[1], - patch_size_t=self.config.patch_size_t, - patch_size=self.config.patch_size, - in_channels=self.in_channels, - embed_dim=self.inner_dim, - interpolation_scale=interpolation_scale, - interpolation_scale_t=interpolation_scale_t, - use_abs_pos=not self.config.use_rope, - ) - interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale) - - # 3. Define transformers blocks, spatial attention - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - sa_attention_mode=sa_attention_mode, - ca_attention_mode=ca_attention_mode, - use_rope=use_rope, - interpolation_scale_thw=interpolation_scale_thw, - block_idx=d, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - - if norm_type != "ada_norm_single": - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - elif norm_type == "ada_norm_single": - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - # 5. PixArt-Alpha blocks. 
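        # With `norm_type == "ada_norm_single"` the modulation parameters are produced
        # once per timestep by `AdaLayerNormSingle` (a 6 * inner_dim vector) and shared
        # across all BasicTransformerBlocks via their `scale_shift_table`, instead of
        # each block owning its own adaptive-norm embedding.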
- self.adaln_single = None - self.use_additional_conditions = False - if norm_type == "ada_norm_single": - # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 - # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use - # additional conditions until we find better name - self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) - - self.caption_projection = None - if caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=inner_dim - ) - - self.gradient_checkpointing = False - - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - - def forward( - self, - hidden_states: torch.Tensor, - timestep: Optional[torch.LongTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): - Input `hidden_states`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - added_cond_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AdaLayerNormSingle` - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - attention_mask ( `torch.Tensor`, *optional*): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. 
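        A minimal call sketch (the latent and text-embedding shapes are illustrative
        assumptions):

            # hidden_states:          [batch, in_channels, frames, height, width]
            # encoder_attention_mask: [batch, 1, text_len]  (1 = keep, 0 = discard)
            out = transformer(
                hidden_states,
                timestep=t,
                encoder_hidden_states=text_embeds,
                encoder_attention_mask=text_mask,
            ).sample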
- - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - batch_size, c, frame, h, w = hidden_states.shape - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None - if attention_mask is not None and attention_mask.ndim == 4: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - # b, frame+use_image_num, h, w -> a video with images - # b, 1, h, w -> only images - attention_mask = attention_mask.to(self.dtype) - attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w - - if attention_mask_vid.numel() > 0: - attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w - attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.patch_size_t, self.patch_size, self.patch_size), - stride=(self.patch_size_t, self.patch_size, self.patch_size)) - attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)') - - attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: - # b, 1+use_image_num, l -> a video with images - # b, 1, l -> only images - encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 - encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None - - # 1. Input - frame = frame // self.patch_size_t # patchfy - # print('frame', frame) - height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size - - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs - hidden_states, encoder_hidden_states_vid, \ - timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, - ) - - - for _, block in enumerate(self.transformer_blocks): - hidden_states = block( - hidden_states, - attention_mask_vid, - encoder_hidden_states_vid, - encoder_attention_mask_vid, - timestep_vid, - cross_attention_kwargs, - class_labels, - frame=frame, - height=height, - width=width, - ) - - # 3. 
Output - output = None - if hidden_states is not None: - output = self._get_output_for_patched_inputs( - hidden_states=hidden_states, - timestep=timestep_vid, - class_labels=class_labels, - embedded_timestep=embedded_timestep_vid, - num_frames=frame, - height=height, - width=width, - ) # b c t h w - - if not return_dict: - return (output,) - - return Transformer3DModelOutput(sample=output) - - def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size): - # batch_size = hidden_states.shape[0] - hidden_states_vid = self.pos_embed(hidden_states.to(self.dtype)) - timestep_vid = None - embedded_timestep_vid = None - encoder_hidden_states_vid = None - - if self.adaln_single is not None: - if self.use_additional_conditions and added_cond_kwargs is None: - raise ValueError( - "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." - ) - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype - ) # b 6d, b d - - timestep_vid = timestep - embedded_timestep_vid = embedded_timestep - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1+use_image_num, l, d or b, 1, l, d - encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d') - - return hidden_states_vid, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid - - def _get_output_for_patched_inputs( - self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None - ): - # import ipdb;ipdb.set_trace() - if self.config.norm_type != "ada_norm_single": - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=self.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - elif self.config.norm_type == "ada_norm_single": - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.squeeze(1) - - # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, num_frames, height, width, self.patch_size_t, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, num_frames * self.patch_size_t, height * self.patch_size, width * self.patch_size) - ) - return output diff --git a/vae/config.json b/vae/config.json index 820c364..9b540d4 100644 --- a/vae/config.json +++ b/vae/config.json @@ -1,6 +1,6 @@ { "_class_name": "AllegroAutoencoderKL3D", - "_diffusers_version": "0.30.3", + "_diffusers_version": "0.28.0", "act_fn": "silu", "block_out_channels": [ 128, diff --git a/vae/vae_allegro.py b/vae/vae_allegro.py deleted file mode 100644 index a70baa7..0000000 --- a/vae/vae_allegro.py +++ /dev/null @@ -1,978 +0,0 @@ -import math -from dataclasses import dataclass -import os -from typing import Dict, Optional, Tuple, Union -from einops import rearrange - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from 
diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.modeling_outputs import AutoencoderKLOutput -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution -from diffusers.models.attention_processor import Attention -from diffusers.models.resnet import ResnetBlock2D -from diffusers.models.upsampling import Upsample2D -from diffusers.models.downsampling import Downsample2D -from diffusers.models.attention_processor import SpatialNorm - - -class TemporalConvBlock(nn.Module): - """ - Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: - https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 - """ - - def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1): - super().__init__() - out_dim = out_dim or in_dim - self.in_dim = in_dim - self.out_dim = out_dim - spa_pad = int((spa_stride-1)*0.5) - temp_pad = 0 - self.temp_pad = temp_pad - - if down_sample: - self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), - nn.SiLU(), - nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2,1,1), padding=(0, spa_pad, spa_pad)) - ) - elif up_sample: - self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), - nn.SiLU(), - nn.Conv3d(in_dim, out_dim*2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad)) - ) - else: - self.conv1 = nn.Sequential( - nn.GroupNorm(32, in_dim), - nn.SiLU(), - nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)) - ) - self.conv2 = nn.Sequential( - nn.GroupNorm(32, out_dim), - nn.SiLU(), - nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), - ) - self.conv3 = nn.Sequential( - nn.GroupNorm(32, out_dim), - nn.SiLU(), - nn.Dropout(dropout), - nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), - ) - self.conv4 = nn.Sequential( - nn.GroupNorm(32, out_dim), - nn.SiLU(), - nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)), - ) - - # zero out the last layer params,so the conv block is identity - nn.init.zeros_(self.conv4[-1].weight) - nn.init.zeros_(self.conv4[-1].bias) - - self.down_sample = down_sample - self.up_sample = up_sample - - - def forward(self, hidden_states): - identity = hidden_states - - if self.down_sample: - identity = identity[:,:,::2] - elif self.up_sample: - hidden_states_new = torch.cat((hidden_states,hidden_states),dim=2) - hidden_states_new[:, :, 0::2] = hidden_states - hidden_states_new[:, :, 1::2] = hidden_states - identity = hidden_states_new - del hidden_states_new - - if self.down_sample or self.up_sample: - hidden_states = self.conv1(hidden_states) - else: - hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) - hidden_states = self.conv1(hidden_states) - - - if self.up_sample: - hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2) - - hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) - hidden_states = self.conv2(hidden_states) - hidden_states = 
torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) - hidden_states = self.conv3(hidden_states) - hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2) - hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2) - hidden_states = self.conv4(hidden_states) - - hidden_states = identity + hidden_states - - return hidden_states - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - add_temp_downsample=False, - downsample_padding=1, - ): - super().__init__() - resnets = [] - temp_convs = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - temp_convs.append( - TemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.temp_convs = nn.ModuleList(temp_convs) - - if add_temp_downsample: - self.temp_convs_down = TemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - down_sample=True, - spa_stride=3 - ) - self.add_temp_downsample = add_temp_downsample - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def _set_partial_grad(self): - for temp_conv in self.temp_convs: - temp_conv.requires_grad_(True) - if self.downsamplers: - for down_layer in self.downsamplers: - down_layer.requires_grad_(True) - - def forward(self, hidden_states): - bz = hidden_states.shape[0] - - for resnet, temp_conv in zip(self.resnets, self.temp_convs): - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - hidden_states = resnet(hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) - hidden_states = temp_conv(hidden_states) - if self.add_temp_downsample: - hidden_states = self.temp_convs_down(hidden_states) - - if self.downsamplers is not None: - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - for upsampler in self.downsamplers: - hidden_states = upsampler(hidden_states) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) - return hidden_states - - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - add_temp_upsample=False, - temb_channels=None, - ): - super().__init__() - self.add_upsample = add_upsample - - resnets = [] - temp_convs = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - 
ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - temp_convs.append( - TemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.temp_convs = nn.ModuleList(temp_convs) - - self.add_temp_upsample = add_temp_upsample - if add_temp_upsample: - self.temp_conv_up = TemporalConvBlock( - out_channels, - out_channels, - dropout=0.1, - up_sample=True, - spa_stride=3 - ) - - - if self.add_upsample: - # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)]) - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def _set_partial_grad(self): - for temp_conv in self.temp_convs: - temp_conv.requires_grad_(True) - if self.add_upsample: - self.upsamplers.requires_grad_(True) - - def forward(self, hidden_states): - bz = hidden_states.shape[0] - - for resnet, temp_conv in zip(self.resnets, self.temp_convs): - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - hidden_states = resnet(hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) - hidden_states = temp_conv(hidden_states) - if self.add_temp_upsample: - hidden_states = self.temp_conv_up(hidden_states) - - if self.upsamplers is not None: - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) - return hidden_states - - -class UNetMidBlock3DConv(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attention_head_dim=1, - output_scale_factor=1.0, - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - temp_convs = [ - TemporalConvBlock( - in_channels, - in_channels, - dropout=0.1, - ) - ] - attentions = [] - - if attention_head_dim is None: - attention_head_dim = in_channels - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - Attention( - in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, - spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) 
- else: - attentions.append(None) - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - temp_convs.append( - TemporalConvBlock( - in_channels, - in_channels, - dropout=0.1, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.temp_convs = nn.ModuleList(temp_convs) - self.attentions = nn.ModuleList(attentions) - - def _set_partial_grad(self): - for temp_conv in self.temp_convs: - temp_conv.requires_grad_(True) - - def forward( - self, - hidden_states, - ): - bz = hidden_states.shape[0] - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - - hidden_states = self.resnets[0](hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) - hidden_states = self.temp_convs[0](hidden_states) - hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w') - - for attn, resnet, temp_conv in zip( - self.attentions, self.resnets[1:], self.temp_convs[1:] - ): - hidden_states = attn(hidden_states) - hidden_states = resnet(hidden_states, temb=None) - hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz) - hidden_states = temp_conv(hidden_states) - return hidden_states - - -class Encoder3D(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - num_blocks=4, - blocks_temp_li=[False, False, False, False], - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - double_z=True, - ): - super().__init__() - self.layers_per_block = layers_per_block - self.blocks_temp_li = blocks_temp_li - - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - ) - - self.temp_conv_in = nn.Conv3d( - block_out_channels[0], - block_out_channels[0], - (3,1,1), - padding = (1, 0, 0) - ) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i in range(num_blocks): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = DownEncoderBlock3D( - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - add_temp_downsample=blocks_temp_li[i], - resnet_eps=1e-6, - downsample_padding=0, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock3DConv( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - ) - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - - self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3,1,1), padding = (1, 0, 0)) - - self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) - - nn.init.zeros_(self.temp_conv_in.weight) - nn.init.zeros_(self.temp_conv_in.bias) - nn.init.zeros_(self.temp_conv_out.weight) - 
nn.init.zeros_(self.temp_conv_out.bias) - - self.gradient_checkpointing = False - - def forward(self, x): - ''' - x: [b, c, (tb f), h, w] - ''' - bz = x.shape[0] - sample = rearrange(x, 'b c n h w -> (b n) c h w') - sample = self.conv_in(sample) - - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - temp_sample = sample - sample = self.temp_conv_in(sample) - sample = sample+temp_sample - # down - for b_id, down_block in enumerate(self.down_blocks): - sample = down_block(sample) - # middle - sample = self.mid_block(sample) - - # post-process - sample = rearrange(sample, 'b c n h w -> (b n) c h w') - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - - temp_sample = sample - sample = self.temp_conv_out(sample) - sample = sample+temp_sample - sample = rearrange(sample, 'b c n h w -> (b n) c h w') - - sample = self.conv_out(sample) - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - return sample - -class Decoder3D(nn.Module): - def __init__( - self, - in_channels=4, - out_channels=3, - num_blocks=4, - blocks_temp_li=[False, False, False, False], - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - norm_type="group", # group, spatial - ): - super().__init__() - self.layers_per_block = layers_per_block - self.blocks_temp_li = blocks_temp_li - - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.temp_conv_in = nn.Conv3d( - block_out_channels[-1], - block_out_channels[-1], - (3,1,1), - padding = (1, 0, 0) - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock3DConv( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(num_blocks): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = UpDecoderBlock3D( - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block, - add_temp_upsample=blocks_temp_li[i], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - resnet_time_scale_shift=norm_type, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - - self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0)) - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - nn.init.zeros_(self.temp_conv_in.weight) - nn.init.zeros_(self.temp_conv_in.bias) - nn.init.zeros_(self.temp_conv_out.weight) - nn.init.zeros_(self.temp_conv_out.bias) - - self.gradient_checkpointing = False - - def forward(self, z): - bz = z.shape[0] 
- sample = rearrange(z, 'b c n h w -> (b n) c h w') - sample = self.conv_in(sample) - - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - temp_sample = sample - sample = self.temp_conv_in(sample) - sample = sample+temp_sample - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # middle - sample = self.mid_block(sample) - sample = sample.to(upscale_dtype) - - # up - for b_id, up_block in enumerate(self.up_blocks): - sample = up_block(sample) - - # post-process - sample = rearrange(sample, 'b c n h w -> (b n) c h w') - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - temp_sample = sample - sample = self.temp_conv_out(sample) - sample = sample+temp_sample - sample = rearrange(sample, 'b c n h w -> (b n) c h w') - - sample = self.conv_out(sample) - sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz) - return sample - - - -class AllegroAutoencoderKL3D(ModelMixin, ConfigMixin): - r""" - A VAE model with KL loss for encoding images into latents and decoding latent representations into images. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented - for all models (such as downloading or saving). - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_num (`int`, *optional*, defaults to 4): - Number of downsample blocks in the encoder. - up_block_num (`int`, *optional*, defaults to 4): - Number of upsample blocks in the decoder. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 512, 512)`): - Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `320`): Spatial tiling size. - tile_overlap (`tuple`, *optional*, defaults to `(120, 80)`): Spatial overlap while tiling, as (height, width). - chunk_len (`int`, *optional*, defaults to `24`): Temporal tiling size. - t_over (`int`, *optional*, defaults to `8`): Temporal overlap while tiling. - scale_factor (`float`, *optional*, defaults to 0.13235): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scale_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scale_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - force_upcast (`bool`, *optional*, defaults to `True`): - If enabled, it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL.
VAE - can be fine-tuned / trained to a lower range without losing too much precision, in which case - `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - blocks_tempdown_li (`List`, *optional*, defaults to `[True, True, False, False]`): Each item indicates whether each TemporalBlock in the Encoder performs temporal downsampling. - blocks_tempup_li (`List`, *optional*, defaults to `[False, True, True, False]`): Each item indicates whether each TemporalBlock in the Decoder performs temporal upsampling. - load_mode (`str`, *optional*, defaults to `full`): Load mode for the model. Can be one of `full`, `encoder_only`, or `decoder_only`, which correspond to loading the full model state dicts, only the encoder state dicts, or only the decoder state dicts. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_num: int = 4, - up_block_num: int = 4, - block_out_channels: Tuple[int] = (128,256,512,512), - layers_per_block: int = 2, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - sample_size: int = 320, - tile_overlap: tuple = (120, 80), - force_upcast: bool = True, - chunk_len: int = 24, - t_over: int = 8, - scale_factor: float = 0.13235, - blocks_tempdown_li=[True, True, False, False], - blocks_tempup_li=[False, True, True, False], - load_mode = 'full', - ): - super().__init__() - - self.blocks_tempdown_li = blocks_tempdown_li - self.blocks_tempup_li = blocks_tempup_li - # pass init params to Encoder - self.load_mode = load_mode - if load_mode in ['full', 'encoder_only']: - self.encoder = Encoder3D( - in_channels=in_channels, - out_channels=latent_channels, - num_blocks=down_block_num, - blocks_temp_li=blocks_tempdown_li, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - ) - self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) - - if load_mode in ['full', 'decoder_only']: - # pass init params to Decoder - self.decoder = Decoder3D( - in_channels=latent_channels, - out_channels=out_channels, - num_blocks=up_block_num, - blocks_temp_li=blocks_tempup_li, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - ) - self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) - - - # only relevant if vae tiling is enabled - sample_size = ( - sample_size[0] - if isinstance(sample_size, (list, tuple)) - else sample_size - ) - self.tile_overlap = tile_overlap - self.vae_scale_factor=[4, 8, 8] - self.scale_factor = scale_factor - self.sample_size = sample_size - self.chunk_len = chunk_len - self.t_over = t_over - - self.latent_chunk_len = self.chunk_len//4 - self.latent_t_over = self.t_over//4 - self.kernel = (self.chunk_len, self.sample_size, self.sample_size) # (24, 320, 320) with the default config - self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 200, 240) with the default config - - - def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: - KERNEL = self.kernel - STRIDE = self.stride - LOCAL_BS = local_batch_size - OUT_C = 8 - - B, C, N, H, W = input_imgs.shape - - - out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1 - out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1 - out_w =
math.floor((W - KERNEL[2]) / STRIDE[2]) + 1 - - ## cut video into overlapped small cubes and batch forward - num = 0 - - out_latent = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8), device=input_imgs.device, dtype=input_imgs.dtype) - vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) - - for i in range(out_n): - for j in range(out_h): - for k in range(out_w): - n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] - h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] - w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] - video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end] - vae_batch_input[num%LOCAL_BS] = video_cube - - if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: - latent = self.encoder(vae_batch_input) - - if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: - out_latent[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] - else: - out_latent[num-LOCAL_BS+1:num+1] = latent - vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype) - num+=1 - - ## flatten the batched out latent to videos and suppress the overlapped parts - B, C, N, H, W = input_imgs.shape - - out_video_cube = torch.zeros((B, OUT_C, N//4, H//8, W//8), device=input_imgs.device, dtype=input_imgs.dtype) - OUT_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 - OUT_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 - OVERLAP = OUT_KERNEL[0]-OUT_STRIDE[0], OUT_KERNEL[1]-OUT_STRIDE[1], OUT_KERNEL[2]-OUT_STRIDE[2] - - for i in range(out_n): - n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0] - for j in range(out_h): - h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1] - for k in range(out_w): - w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2] - latent_mean_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), out_latent[i*out_h*out_w+j*out_w+k].unsqueeze(0)) - out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend - - ## final conv - out_video_cube = rearrange(out_video_cube, 'b c n h w -> (b n) c h w') - out_video_cube = self.quant_conv(out_video_cube) - out_video_cube = rearrange(out_video_cube, '(b n) c h w -> b c n h w', b=B) - - posterior = DiagonalGaussianDistribution(out_video_cube) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - - def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]: - KERNEL = self.kernel - STRIDE = self.stride - - LOCAL_BS = local_batch_size - OUT_C = 3 - IN_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8 - IN_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8 - - B, C, N, H, W = input_latents.shape - - ## post quant conv (a mapping) - input_latents = rearrange(input_latents, 'b c n h w -> (b n) c h w') - input_latents = self.post_quant_conv(input_latents) - input_latents = rearrange(input_latents, '(b n) c h w -> b c n h w', b=B) - - ## out tensor shape - out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1 - out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1 - out_w = math.floor((W - IN_KERNEL[2]) / IN_STRIDE[2]) + 1 - - ## cut latent into overlapped small cubes and batch forward - num = 0 - decoded_cube = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]),
device=input_latents.device, dtype=input_latents.dtype) - vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) - for i in range(out_n): - for j in range(out_h): - for k in range(out_w): - n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0] - h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1] - w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2] - latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end] - vae_batch_input[num%LOCAL_BS] = latent_cube - if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1: - - latent = self.decoder(vae_batch_input) - - if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1: - decoded_cube[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1] - else: - decoded_cube[num-LOCAL_BS+1:num+1] = latent - vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype) - num+=1 - B, C, N, H, W = input_latents.shape - - out_video = torch.zeros((B, OUT_C, N*4, H*8, W*8), device=input_latents.device, dtype=input_latents.dtype) - OVERLAP = KERNEL[0]-STRIDE[0], KERNEL[1]-STRIDE[1], KERNEL[2]-STRIDE[2] - for i in range(out_n): - n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0] - for j in range(out_h): - h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1] - for k in range(out_w): - w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2] - out_video_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), decoded_cube[i*out_h*out_w+j*out_w+k].unsqueeze(0)) - out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend - - out_video = rearrange(out_video, 'b c t h w -> b t c h w').contiguous() - - decoded = out_video - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - encoder_local_batch_size: int = 2, - decoder_local_batch_size: int = 2, - ) -> Union[DecoderOutput, torch.Tensor]: - r""" - Args: - sample (`torch.Tensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - generator (`torch.Generator`, *optional*): - PyTorch random number generator. - encoder_local_batch_size (`int`, *optional*, defaults to 2): - Local batch size for the encoder's batch inference. - decoder_local_batch_size (`int`, *optional*, defaults to 2): - Local batch size for the decoder's batch inference. 
- """ - x = sample - posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - kwargs["torch_type"] = torch.float32 - return super().from_pretrained(pretrained_model_name_or_path, **kwargs) - - -def prepare_for_blend(n_param, h_param, w_param, x): - n, n_max, overlap_n = n_param - h, h_max, overlap_h = h_param - w, w_max, overlap_w = w_param - if overlap_n > 0: - if n > 0: # the head overlap part decays from 0 to 1 - x[:,:,0:overlap_n,:,:] = x[:,:,0:overlap_n,:,:] * (torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) - if n < n_max-1: # the tail overlap part decays from 1 to 0 - x[:,:,-overlap_n:,:,:] = x[:,:,-overlap_n:,:,:] * (1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1) - if h > 0: - x[:,:,:,0:overlap_h,:] = x[:,:,:,0:overlap_h,:] * (torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) - if h < h_max-1: - x[:,:,:,-overlap_h:,:] = x[:,:,:,-overlap_h:,:] * (1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1) - if w > 0: - x[:,:,:,:,0:overlap_w] = x[:,:,:,:,0:overlap_w] * (torch.arange(0, overlap_w).float().to(x.device) / overlap_w) - if w < w_max-1: - x[:,:,:,:,-overlap_w:] = x[:,:,:,:,-overlap_w:] * (1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w) - return x