Upload folder using huggingface_hub

This commit is contained in:
YuxuanCai 2024-10-20 11:46:03 +00:00 committed by system
parent 3f99bb766a
commit 547b60eab3
No known key found for this signature in database
GPG Key ID: 6A528E38E0733467
17 changed files with 5145 additions and 0 deletions

27
model_index.json Normal file

@ -0,0 +1,27 @@
{
"_class_name": [
"pipeline_allegro",
"AllegroPipeline"
],
"_diffusers_version": "0.28.0",
"scheduler": [
"diffusers",
"EulerAncestralDiscreteScheduler"
],
"text_encoder": [
"transformers",
"T5EncoderModel"
],
"tokenizer": [
"transformers",
"T5Tokenizer"
],
"transformer": [
"transformer_3d_allegro",
"AllegroTransformer3DModel"
],
"vae": [
"vae_allegro",
"AllegroAutoencoderKL3D"
]
}

832
pipeline_allegro.py Normal file

@ -0,0 +1,832 @@
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import html
import inspect
import math
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
from einops import rearrange
import ftfy
import torch
from dataclasses import dataclass
import tqdm
from bs4 import BeautifulSoup
from diffusers import DiffusionPipeline, ModelMixin
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from diffusers.utils import (
BACKENDS_MAPPING,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
BaseOutput
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer
logger = logging.get_logger(__name__)
# from transformer_3d_allegro import AllegroTransformer3DModel
# from vae_allegro import AllegroAutoencoderKL3D
@dataclass
class AllegroPipelineOutput(BaseOutput):
r"""
Output class for Allegro pipelines.
Args:
video (`torch.Tensor`):
Torch tensor with shape `(batch_size, num_frames, channels, height, width)`.
"""
video: torch.Tensor
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> # You can replace the your_path_to_model with your own path.
>>> pipe = AllegroPipeline.from_pretrained(your_path_to_model, torch_dtype=torch.float16, trust_remote_code=True)
>>> prompt = "A small cactus with a happy face in the Sahara desert."
>>> image = pipe(prompt).video[0]
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class AllegroPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Allegro.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AllegroAutoEncoderKL3D`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. PixArt-Alpha uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`AllegroTransformer3DModel`]):
A text conditioned `AllegroTransformer3DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
tokenizer: Optional[T5Tokenizer] = None,
text_encoder: Optional[T5EncoderModel] = None,
vae: Optional[ModelMixin] = None,
transformer: Optional[ModelMixin] = None,
scheduler: Optional[EulerAncestralDiscreteScheduler] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float16,
):
super().__init__()
# # init
# if tokenizer is None:
# tokenizer = T5Tokenizer.from_pretrained(tokenizer)
# if text_encoder is None:
# text_encoder = T5EncoderModel.from_pretrained(text_encoder, torch_dtype=torch.float16)
# if vae is None:
# vae = AllegroAutoencoderKL3D.from_pretrained(vae).to(dtype=torch.float32)
# if transformer is None:
# transformer = AllegroTransformer3DModel.from_pretrained(transformer, torch_dtype=dtype)
# if scheduler is None:
# scheduler = EulerAncestralDiscreteScheduler()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 120,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
string.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
"""
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = max_sequence_length
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=negative_prompt_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
num_frames,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if num_frames <= 0:
raise ValueError(f"`num_frames` have to be positive but is {num_frames}.")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
# noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
# noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
# caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
# noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
num_channels_latents,
(math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1)
if int(num_frames) % 2 == 1
else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]),
math.ceil(int(height) / self.vae.vae_scale_factor[1]),
math.ceil(int(width) / self.vae.vae_scale_factor[2]),
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 100,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
max_sequence_length: int = 512,
verbose: bool = True,
) -> Union[AllegroPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
num_frames: (`int`, *optional*, defaults to 88):
The number controls the generated video frames.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
# 1. Check inputs. Raise error if not correct
num_frames = num_frames or self.transformer.config.sample_size_t * self.vae.vae_scale_factor[0]
height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1]
width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2]
self.check_inputs(
prompt,
num_frames,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
progress_wrap = tqdm.tqdm if verbose else (lambda x: x)
for i, t in progress_wrap(list(enumerate(timesteps))):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
if prompt_attention_mask.ndim == 2:
prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l
# prepare attention_mask.
# b c t h w -> b t h w
attention_mask = torch.ones_like(latent_model_input)[:, 0]
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
attention_mask=attention_mask,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latents":
video = self.decode_latents(latents)
video = video[:, :num_frames, :height, :width]
else:
video = latents
return AllegroPipelineOutput(video=video)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AllegroPipelineOutput(video=video)
def decode_latents(self, latents):
video = self.vae.decode(latents.to(self.vae.dtype) / self.vae.scale_factor).sample
# b t c h w -> b t h w c
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
return video

@ -0,0 +1,13 @@
{
"_class_name": "EulerAncestralDiscreteScheduler",
"_diffusers_version": "0.28.0",
"beta_end": 0.02,
"beta_schedule": "linear",
"beta_start": 0.0001,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"rescale_betas_zero_snr": false,
"steps_offset": 0,
"timestep_spacing": "linspace",
"trained_betas": null
}

30
text_encoder/config.json Normal file

@ -0,0 +1,30 @@
{
"architectures": [
"T5EncoderModel"
],
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"torch_dtype": "float32",
"transformers_version": "4.21.1",
"use_cache": true,
"vocab_size": 32128
}

BIN
text_encoder/pytorch_model-00001-of-00002.bin (Stored with Git LFS) Normal file

Binary file not shown.

BIN
text_encoder/pytorch_model-00002-of-00002.bin (Stored with Git LFS) Normal file

Binary file not shown.

@ -0,0 +1,227 @@
{
"metadata": {
"total_size": 19575627776
},
"weight_map": {
"encoder.block.0.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.0.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.1.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.10.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.11.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.11.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.11.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.11.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.11.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.11.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.11.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.12.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.13.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.14.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.15.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.16.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.17.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.18.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.19.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.2.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.2.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.20.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.20.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.21.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.22.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.0.SelfAttention.k.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.0.SelfAttention.o.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.0.SelfAttention.q.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.0.SelfAttention.v.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.0.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.1.DenseReluDense.wo.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.23.layer.1.layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"encoder.block.3.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.3.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.4.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.5.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.6.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.7.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.8.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.0.SelfAttention.k.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.0.SelfAttention.o.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.0.SelfAttention.q.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.0.SelfAttention.v.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.0.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.1.DenseReluDense.wo.weight": "pytorch_model-00001-of-00002.bin",
"encoder.block.9.layer.1.layer_norm.weight": "pytorch_model-00001-of-00002.bin",
"encoder.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
"encoder.final_layer_norm.weight": "pytorch_model-00002-of-00002.bin",
"shared.weight": "pytorch_model-00001-of-00002.bin"
}
}

102
tokenizer/added_tokens.json Normal file

@ -0,0 +1,102 @@
{
"<extra_id_0>": 32099,
"<extra_id_10>": 32089,
"<extra_id_11>": 32088,
"<extra_id_12>": 32087,
"<extra_id_13>": 32086,
"<extra_id_14>": 32085,
"<extra_id_15>": 32084,
"<extra_id_16>": 32083,
"<extra_id_17>": 32082,
"<extra_id_18>": 32081,
"<extra_id_19>": 32080,
"<extra_id_1>": 32098,
"<extra_id_20>": 32079,
"<extra_id_21>": 32078,
"<extra_id_22>": 32077,
"<extra_id_23>": 32076,
"<extra_id_24>": 32075,
"<extra_id_25>": 32074,
"<extra_id_26>": 32073,
"<extra_id_27>": 32072,
"<extra_id_28>": 32071,
"<extra_id_29>": 32070,
"<extra_id_2>": 32097,
"<extra_id_30>": 32069,
"<extra_id_31>": 32068,
"<extra_id_32>": 32067,
"<extra_id_33>": 32066,
"<extra_id_34>": 32065,
"<extra_id_35>": 32064,
"<extra_id_36>": 32063,
"<extra_id_37>": 32062,
"<extra_id_38>": 32061,
"<extra_id_39>": 32060,
"<extra_id_3>": 32096,
"<extra_id_40>": 32059,
"<extra_id_41>": 32058,
"<extra_id_42>": 32057,
"<extra_id_43>": 32056,
"<extra_id_44>": 32055,
"<extra_id_45>": 32054,
"<extra_id_46>": 32053,
"<extra_id_47>": 32052,
"<extra_id_48>": 32051,
"<extra_id_49>": 32050,
"<extra_id_4>": 32095,
"<extra_id_50>": 32049,
"<extra_id_51>": 32048,
"<extra_id_52>": 32047,
"<extra_id_53>": 32046,
"<extra_id_54>": 32045,
"<extra_id_55>": 32044,
"<extra_id_56>": 32043,
"<extra_id_57>": 32042,
"<extra_id_58>": 32041,
"<extra_id_59>": 32040,
"<extra_id_5>": 32094,
"<extra_id_60>": 32039,
"<extra_id_61>": 32038,
"<extra_id_62>": 32037,
"<extra_id_63>": 32036,
"<extra_id_64>": 32035,
"<extra_id_65>": 32034,
"<extra_id_66>": 32033,
"<extra_id_67>": 32032,
"<extra_id_68>": 32031,
"<extra_id_69>": 32030,
"<extra_id_6>": 32093,
"<extra_id_70>": 32029,
"<extra_id_71>": 32028,
"<extra_id_72>": 32027,
"<extra_id_73>": 32026,
"<extra_id_74>": 32025,
"<extra_id_75>": 32024,
"<extra_id_76>": 32023,
"<extra_id_77>": 32022,
"<extra_id_78>": 32021,
"<extra_id_79>": 32020,
"<extra_id_7>": 32092,
"<extra_id_80>": 32019,
"<extra_id_81>": 32018,
"<extra_id_82>": 32017,
"<extra_id_83>": 32016,
"<extra_id_84>": 32015,
"<extra_id_85>": 32014,
"<extra_id_86>": 32013,
"<extra_id_87>": 32012,
"<extra_id_88>": 32011,
"<extra_id_89>": 32010,
"<extra_id_8>": 32091,
"<extra_id_90>": 32009,
"<extra_id_91>": 32008,
"<extra_id_92>": 32007,
"<extra_id_93>": 32006,
"<extra_id_94>": 32005,
"<extra_id_95>": 32004,
"<extra_id_96>": 32003,
"<extra_id_97>": 32002,
"<extra_id_98>": 32001,
"<extra_id_99>": 32000,
"<extra_id_9>": 32090
}

@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

BIN
tokenizer/spiece.model (Stored with Git LFS) Normal file

Binary file not shown.

@ -0,0 +1,940 @@
{
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": true,
"model_max_length": 512,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}

39
transformer/config.json Normal file

@ -0,0 +1,39 @@
{
"_class_name": "AllegroTransformer3DModel",
"_diffusers_version": "0.28.0",
"_name_or_path": "/cpfs/data/user/yanghuan/expr/rsora/RSoraT2V_L32AH24AD96_122_20240918_88x720x1280_fps15_t5/checkpoint-38000/model",
"activation_fn": "gelu-approximate",
"attention_bias": true,
"attention_head_dim": 96,
"ca_attention_mode": "xformers",
"caption_channels": 4096,
"cross_attention_dim": 2304,
"double_self_attention": false,
"downsampler": null,
"dropout": 0.0,
"in_channels": 4,
"interpolation_scale_h": 2.0,
"interpolation_scale_t": 2.2,
"interpolation_scale_w": 2.0,
"model_max_length": 300,
"norm_elementwise_affine": false,
"norm_eps": 1e-06,
"norm_type": "ada_norm_single",
"num_attention_heads": 24,
"num_embeds_ada_norm": 1000,
"num_layers": 32,
"only_cross_attention": false,
"out_channels": 4,
"patch_size": 2,
"patch_size_t": 1,
"sa_attention_mode": "flash",
"sample_size": [
90,
160
],
"sample_size_t": 22,
"upcast_attention": false,
"use_additional_conditions": null,
"use_linear_projection": false,
"use_rope": true
}

BIN
transformer/diffusion_pytorch_model.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

41
vae/config.json Normal file

@ -0,0 +1,41 @@
{
"_class_name": "AllegroAutoencoderKL3D",
"_diffusers_version": "0.28.0",
"_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro_pipeline/vae",
"act_fn": "silu",
"block_out_channels": [
128,
256,
512,
512
],
"blocks_tempdown_li": [
true,
true,
false,
false
],
"blocks_tempup_li": [
false,
true,
true,
false
],
"chunk_len": 24,
"down_block_num": 4,
"force_upcast": true,
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"load_mode": "full",
"norm_num_groups": 32,
"out_channels": 3,
"sample_size": 320,
"scale_factor": 0.13,
"t_over": 8,
"tile_overlap": [
120,
80
],
"up_block_num": 4
}

BIN
vae/diffusion_pytorch_model.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

978
vae/vae_allegro.py Normal file

@ -0,0 +1,978 @@
import math
from dataclasses import dataclass
import os
from typing import Dict, Optional, Tuple, Union
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.attention_processor import Attention
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.upsampling import Upsample2D
from diffusers.models.downsampling import Downsample2D
from diffusers.models.attention_processor import SpatialNorm
class TemporalConvBlock(nn.Module):
"""
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
"""
def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
self.out_dim = out_dim
spa_pad = int((spa_stride-1)*0.5)
temp_pad = 0
self.temp_pad = temp_pad
if down_sample:
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (2, spa_stride, spa_stride), stride=(2,1,1), padding=(0, spa_pad, spa_pad))
)
elif up_sample:
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim*2, (1, spa_stride, spa_stride), padding=(0, spa_pad, spa_pad))
)
else:
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad))
)
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
)
self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
)
self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Conv3d(out_dim, in_dim, (3, spa_stride, spa_stride), padding=(temp_pad, spa_pad, spa_pad)),
)
# zero out the last layer params,so the conv block is identity
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
self.down_sample = down_sample
self.up_sample = up_sample
def forward(self, hidden_states):
identity = hidden_states
if self.down_sample:
identity = identity[:,:,::2]
elif self.up_sample:
hidden_states_new = torch.cat((hidden_states,hidden_states),dim=2)
hidden_states_new[:, :, 0::2] = hidden_states
hidden_states_new[:, :, 1::2] = hidden_states
identity = hidden_states_new
del hidden_states_new
if self.down_sample or self.up_sample:
hidden_states = self.conv1(hidden_states)
else:
hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
hidden_states = self.conv1(hidden_states)
if self.up_sample:
hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2)
hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
hidden_states = self.conv2(hidden_states)
hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
hidden_states = self.conv3(hidden_states)
hidden_states = torch.cat((hidden_states[:,:,0:1], hidden_states), dim=2)
hidden_states = torch.cat((hidden_states,hidden_states[:,:,-1:]), dim=2)
hidden_states = self.conv4(hidden_states)
hidden_states = identity + hidden_states
return hidden_states
class DownEncoderBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
add_temp_downsample=False,
downsample_padding=1,
):
super().__init__()
resnets = []
temp_convs = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
temp_convs.append(
TemporalConvBlock(
out_channels,
out_channels,
dropout=0.1,
)
)
self.resnets = nn.ModuleList(resnets)
self.temp_convs = nn.ModuleList(temp_convs)
if add_temp_downsample:
self.temp_convs_down = TemporalConvBlock(
out_channels,
out_channels,
dropout=0.1,
down_sample=True,
spa_stride=3
)
self.add_temp_downsample = add_temp_downsample
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
def _set_partial_grad(self):
for temp_conv in self.temp_convs:
temp_conv.requires_grad_(True)
if self.downsamplers:
for down_layer in self.downsamplers:
down_layer.requires_grad_(True)
def forward(self, hidden_states):
bz = hidden_states.shape[0]
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
hidden_states = resnet(hidden_states, temb=None)
hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
hidden_states = temp_conv(hidden_states)
if self.add_temp_downsample:
hidden_states = self.temp_convs_down(hidden_states)
if self.downsamplers is not None:
hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
for upsampler in self.downsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
return hidden_states
class UpDecoderBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
add_temp_upsample=False,
temb_channels=None,
):
super().__init__()
self.add_upsample = add_upsample
resnets = []
temp_convs = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
temp_convs.append(
TemporalConvBlock(
out_channels,
out_channels,
dropout=0.1,
)
)
self.resnets = nn.ModuleList(resnets)
self.temp_convs = nn.ModuleList(temp_convs)
self.add_temp_upsample = add_temp_upsample
if add_temp_upsample:
self.temp_conv_up = TemporalConvBlock(
out_channels,
out_channels,
dropout=0.1,
up_sample=True,
spa_stride=3
)
if self.add_upsample:
# self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)])
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
def _set_partial_grad(self):
for temp_conv in self.temp_convs:
temp_conv.requires_grad_(True)
if self.add_upsample:
self.upsamplers.requires_grad_(True)
def forward(self, hidden_states):
bz = hidden_states.shape[0]
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
hidden_states = resnet(hidden_states, temb=None)
hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
hidden_states = temp_conv(hidden_states)
if self.add_temp_upsample:
hidden_states = self.temp_conv_up(hidden_states)
if self.upsamplers is not None:
hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
return hidden_states
class UNetMidBlock3DConv(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim=1,
output_scale_factor=1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
temp_convs = [
TemporalConvBlock(
in_channels,
in_channels,
dropout=0.1,
)
]
attentions = []
if attention_head_dim is None:
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
temp_convs.append(
TemporalConvBlock(
in_channels,
in_channels,
dropout=0.1,
)
)
self.resnets = nn.ModuleList(resnets)
self.temp_convs = nn.ModuleList(temp_convs)
self.attentions = nn.ModuleList(attentions)
def _set_partial_grad(self):
for temp_conv in self.temp_convs:
temp_conv.requires_grad_(True)
def forward(
self,
hidden_states,
):
bz = hidden_states.shape[0]
hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
hidden_states = self.resnets[0](hidden_states, temb=None)
hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
hidden_states = self.temp_convs[0](hidden_states)
hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
for attn, resnet, temp_conv in zip(
self.attentions, self.resnets[1:], self.temp_convs[1:]
):
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb=None)
hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
hidden_states = temp_conv(hidden_states)
return hidden_states
class Encoder3D(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
num_blocks=4,
blocks_temp_li=[False, False, False, False],
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.blocks_temp_li = blocks_temp_li
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
stride=1,
padding=1,
)
self.temp_conv_in = nn.Conv3d(
block_out_channels[0],
block_out_channels[0],
(3,1,1),
padding = (1, 0, 0)
)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i in range(num_blocks):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock3D(
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
add_temp_downsample=blocks_temp_li[i],
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock3DConv(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3,1,1), padding = (1, 0, 0))
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
nn.init.zeros_(self.temp_conv_in.weight)
nn.init.zeros_(self.temp_conv_in.bias)
nn.init.zeros_(self.temp_conv_out.weight)
nn.init.zeros_(self.temp_conv_out.bias)
self.gradient_checkpointing = False
def forward(self, x):
'''
x: [b, c, (tb f), h, w]
'''
bz = x.shape[0]
sample = rearrange(x, 'b c n h w -> (b n) c h w')
sample = self.conv_in(sample)
sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
temp_sample = sample
sample = self.temp_conv_in(sample)
sample = sample+temp_sample
# down
for b_id, down_block in enumerate(self.down_blocks):
sample = down_block(sample)
# middle
sample = self.mid_block(sample)
# post-process
sample = rearrange(sample, 'b c n h w -> (b n) c h w')
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
temp_sample = sample
sample = self.temp_conv_out(sample)
sample = sample+temp_sample
sample = rearrange(sample, 'b c n h w -> (b n) c h w')
sample = self.conv_out(sample)
sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
return sample
class Decoder3D(nn.Module):
def __init__(
self,
in_channels=4,
out_channels=3,
num_blocks=4,
blocks_temp_li=[False, False, False, False],
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group", # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
self.blocks_temp_li = blocks_temp_li
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.temp_conv_in = nn.Conv3d(
block_out_channels[-1],
block_out_channels[-1],
(3,1,1),
padding = (1, 0, 0)
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock3DConv(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(num_blocks):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
add_temp_upsample=blocks_temp_li[i],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3,1,1), padding = (1, 0, 0))
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
nn.init.zeros_(self.temp_conv_in.weight)
nn.init.zeros_(self.temp_conv_in.bias)
nn.init.zeros_(self.temp_conv_out.weight)
nn.init.zeros_(self.temp_conv_out.bias)
self.gradient_checkpointing = False
def forward(self, z):
bz = z.shape[0]
sample = rearrange(z, 'b c n h w -> (b n) c h w')
sample = self.conv_in(sample)
sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
temp_sample = sample
sample = self.temp_conv_in(sample)
sample = sample+temp_sample
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
# middle
sample = self.mid_block(sample)
sample = sample.to(upscale_dtype)
# up
for b_id, up_block in enumerate(self.up_blocks):
sample = up_block(sample)
# post-process
sample = rearrange(sample, 'b c n h w -> (b n) c h w')
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
temp_sample = sample
sample = self.temp_conv_out(sample)
sample = sample+temp_sample
sample = rearrange(sample, 'b c n h w -> (b n) c h w')
sample = self.conv_out(sample)
sample = rearrange(sample, '(b n) c h w -> b c n h w', b=bz)
return sample
class AllegroAutoencoderKL3D(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `256`): Spatial Tiling Size.
tile_overlap (`tuple`, *optional*, defaults to `(120, 80`): Spatial overlapping size while tiling (height, width)
chunk_len (`int`, *optional*, defaults to `24`): Temporal Tiling Size.
t_over (`int`, *optional*, defaults to `8`): Temporal overlapping size while tiling
scaling_factor (`float`, *optional*, defaults to 0.13235):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
blocks_tempdown_li (`List`, *optional*, defaults to `[True, True, False, False]`): Each item indicates whether each TemporalBlock in the Encoder performs temporal downsampling.
blocks_tempup_li (`List`, *optional*, defaults to `[False, True, True, False]`): Each item indicates whether each TemporalBlock in the Decoder performs temporal upsampling.
load_mode (`str`, *optional*, defaults to `full`): Load mode for the model. Can be one of `full`, `encoder_only`, `decoder_only`. which corresponds to loading the full model state dicts, only the encoder state dicts, or only the decoder state dicts.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_num: int = 4,
up_block_num: int = 4,
block_out_channels: Tuple[int] = (128,256,512,512),
layers_per_block: int = 2,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 320,
tile_overlap: tuple = (120, 80),
force_upcast: bool = True,
chunk_len: int = 24,
t_over: int = 8,
scale_factor: float = 0.13235,
blocks_tempdown_li=[True, True, False, False],
blocks_tempup_li=[False, True, True, False],
load_mode = 'full',
):
super().__init__()
self.blocks_tempdown_li = blocks_tempdown_li
self.blocks_tempup_li = blocks_tempup_li
# pass init params to Encoder
self.load_mode = load_mode
if load_mode in ['full', 'encoder_only']:
self.encoder = Encoder3D(
in_channels=in_channels,
out_channels=latent_channels,
num_blocks=down_block_num,
blocks_temp_li=blocks_tempdown_li,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
if load_mode in ['full', 'decoder_only']:
# pass init params to Decoder
self.decoder = Decoder3D(
in_channels=latent_channels,
out_channels=out_channels,
num_blocks=up_block_num,
blocks_temp_li=blocks_tempup_li,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
# only relevant if vae tiling is enabled
sample_size = (
sample_size[0]
if isinstance(sample_size, (list, tuple))
else sample_size
)
self.tile_overlap = tile_overlap
self.vae_scale_factor=[4, 8, 8]
self.scale_factor = scale_factor
self.sample_size = sample_size
self.chunk_len = chunk_len
self.t_over = t_over
self.latent_chunk_len = self.chunk_len//4
self.latent_t_over = self.t_over//4
self.kernel = (self.chunk_len, self.sample_size, self.sample_size) #(24, 256, 256)
self.stride = (self.chunk_len - self.t_over, self.sample_size-self.tile_overlap[0], self.sample_size-self.tile_overlap[1]) # (16, 112, 192)
def encode(self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
KERNEL = self.kernel
STRIDE = self.stride
LOCAL_BS = local_batch_size
OUT_C = 8
B, C, N, H, W = input_imgs.shape
out_n = math.floor((N - KERNEL[0]) / STRIDE[0]) + 1
out_h = math.floor((H - KERNEL[1]) / STRIDE[1]) + 1
out_w = math.floor((W - KERNEL[2]) / STRIDE[2]) + 1
## cut video into overlapped small cubes and batch forward
num = 0
out_latent = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8), device=input_imgs.device, dtype=input_imgs.dtype)
vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype)
for i in range(out_n):
for j in range(out_h):
for k in range(out_w):
n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1]
w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2]
video_cube = input_imgs[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
vae_batch_input[num%LOCAL_BS] = video_cube
if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1:
latent = self.encoder(vae_batch_input)
if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1:
out_latent[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1]
else:
out_latent[num-LOCAL_BS+1:num+1] = latent
vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_imgs.device, dtype=input_imgs.dtype)
num+=1
## flatten the batched out latent to videos and supress the overlapped parts
B, C, N, H, W = input_imgs.shape
out_video_cube = torch.zeros((B, OUT_C, N//4, H//8, W//8), device=input_imgs.device, dtype=input_imgs.dtype)
OUT_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8
OUT_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8
OVERLAP = OUT_KERNEL[0]-OUT_STRIDE[0], OUT_KERNEL[1]-OUT_STRIDE[1], OUT_KERNEL[2]-OUT_STRIDE[2]
for i in range(out_n):
n_start, n_end = i * OUT_STRIDE[0], i * OUT_STRIDE[0] + OUT_KERNEL[0]
for j in range(out_h):
h_start, h_end = j * OUT_STRIDE[1], j * OUT_STRIDE[1] + OUT_KERNEL[1]
for k in range(out_w):
w_start, w_end = k * OUT_STRIDE[2], k * OUT_STRIDE[2] + OUT_KERNEL[2]
latent_mean_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), out_latent[i*out_h*out_w+j*out_w+k].unsqueeze(0))
out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend
## final conv
out_video_cube = rearrange(out_video_cube, 'b c n h w -> (b n) c h w')
out_video_cube = self.quant_conv(out_video_cube)
out_video_cube = rearrange(out_video_cube, '(b n) c h w -> b c n h w', b=B)
posterior = DiagonalGaussianDistribution(out_video_cube)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def decode(self, input_latents: torch.Tensor, return_dict: bool = True, local_batch_size=1) -> Union[DecoderOutput, torch.Tensor]:
KERNEL = self.kernel
STRIDE = self.stride
LOCAL_BS = local_batch_size
OUT_C = 3
IN_KERNEL = KERNEL[0]//4, KERNEL[1]//8, KERNEL[2]//8
IN_STRIDE = STRIDE[0]//4, STRIDE[1]//8, STRIDE[2]//8
B, C, N, H, W = input_latents.shape
## post quant conv (a mapping)
input_latents = rearrange(input_latents, 'b c n h w -> (b n) c h w')
input_latents = self.post_quant_conv(input_latents)
input_latents = rearrange(input_latents, '(b n) c h w -> b c n h w', b=B)
## out tensor shape
out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1
out_h = math.floor((H - IN_KERNEL[1]) / IN_STRIDE[1]) + 1
out_w = math.floor((W - IN_KERNEL[2]) / IN_STRIDE[2]) + 1
## cut latent into overlapped small cubes and batch forward
num = 0
decoded_cube = torch.zeros((out_n*out_h*out_w, OUT_C, KERNEL[0], KERNEL[1], KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
for i in range(out_n):
for j in range(out_h):
for k in range(out_w):
n_start, n_end = i * IN_STRIDE[0], i * IN_STRIDE[0] + IN_KERNEL[0]
h_start, h_end = j * IN_STRIDE[1], j * IN_STRIDE[1] + IN_KERNEL[1]
w_start, w_end = k * IN_STRIDE[2], k * IN_STRIDE[2] + IN_KERNEL[2]
latent_cube = input_latents[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
vae_batch_input[num%LOCAL_BS] = latent_cube
if num%LOCAL_BS == LOCAL_BS-1 or num == out_n*out_h*out_w-1:
latent = self.decoder(vae_batch_input)
if num == out_n*out_h*out_w-1 and num%LOCAL_BS != LOCAL_BS-1:
decoded_cube[num-num%LOCAL_BS:] = latent[:num%LOCAL_BS+1]
else:
decoded_cube[num-LOCAL_BS+1:num+1] = latent
vae_batch_input = torch.zeros((LOCAL_BS, C, IN_KERNEL[0], IN_KERNEL[1], IN_KERNEL[2]), device=input_latents.device, dtype=input_latents.dtype)
num+=1
B, C, N, H, W = input_latents.shape
out_video = torch.zeros((B, OUT_C, N*4, H*8, W*8), device=input_latents.device, dtype=input_latents.dtype)
OVERLAP = KERNEL[0]-STRIDE[0], KERNEL[1]-STRIDE[1], KERNEL[2]-STRIDE[2]
for i in range(out_n):
n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
for j in range(out_h):
h_start, h_end = j * STRIDE[1], j * STRIDE[1] + KERNEL[1]
for k in range(out_w):
w_start, w_end = k * STRIDE[2], k * STRIDE[2] + KERNEL[2]
out_video_blend = prepare_for_blend((i, out_n, OVERLAP[0]), (j, out_h, OVERLAP[1]), (k, out_w, OVERLAP[2]), decoded_cube[i*out_h*out_w+j*out_w+k].unsqueeze(0))
out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
out_video = rearrange(out_video, 'b c t h w -> b t c h w').contiguous()
decoded = out_video
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
encoder_local_batch_size: int = 2,
decoder_local_batch_size: int = 2,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*):
PyTorch random number generator.
encoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the encoder's batch inference.
decoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the decoder's batch inference.
"""
x = sample
posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
kwargs["torch_type"] = torch.float32
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
def prepare_for_blend(n_param, h_param, w_param, x):
n, n_max, overlap_n = n_param
h, h_max, overlap_h = h_param
w, w_max, overlap_w = w_param
if overlap_n > 0:
if n > 0: # the head overlap part decays from 0 to 1
x[:,:,0:overlap_n,:,:] = x[:,:,0:overlap_n,:,:] * (torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1)
if n < n_max-1: # the tail overlap part decays from 1 to 0
x[:,:,-overlap_n:,:,:] = x[:,:,-overlap_n:,:,:] * (1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n).reshape(overlap_n,1,1)
if h > 0:
x[:,:,:,0:overlap_h,:] = x[:,:,:,0:overlap_h,:] * (torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1)
if h < h_max-1:
x[:,:,:,-overlap_h:,:] = x[:,:,:,-overlap_h:,:] * (1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h).reshape(overlap_h,1)
if w > 0:
x[:,:,:,:,0:overlap_w] = x[:,:,:,:,0:overlap_w] * (torch.arange(0, overlap_w).float().to(x.device) / overlap_w)
if w < w_max-1:
x[:,:,:,:,-overlap_w:] = x[:,:,:,:,-overlap_w:] * (1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w)
return x