DALL·E mini - Inference pipeline
This notebook illustrates the DALL·E mini inference pipeline.
Just want to play? Use the DALL·E mini app directly.
For a better understanding of the model, refer to the report.
!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
We load required models:
- DALL·E mini to encode text prompts into image tokens
- VQGAN to decode image tokens into images
- CLIP to score the predictions
# dalle-mega
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be a wandb artifact, a 🤗 Hub repo, a local folder or a Google bucket
DALLE_COMMIT_ID = None
# if the notebook crashes too often, you can use dalle-mini instead by uncommenting the line below
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
import jax
import jax.numpy as jnp
# check how many devices are available
jax.local_device_count()
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)
Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate
params = replicate(params)
vqgan_params = replicate(vqgan_params)
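As a quick sanity check (not part of the original notebook), we can inspect one parameter leaf to confirm that replicate added a leading device axis.
# Sketch: each parameter leaf should now start with a device dimension
first_leaf = jax.tree_util.tree_leaves(params)[0]
print(first_leaf.shape)  # leading dimension equals jax.local_device_count()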
Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial
# model inference
# top_k, top_p, temperature and condition_scale are static arguments
# (static_broadcasted_argnums), so changing them triggers a recompilation
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)
A random key is passed to the model on each device so that each device generates different images.
import random
# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
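As an illustrative sketch (not in the original notebook), this is how the single key is later spread across devices before being passed to p_generate.
from flax.training.common_utils import shard_prng_key
# shard_prng_key splits the key so each pmap replica gets its own PRNG state
device_keys = shard_prng_key(key)
print(len(device_keys) == jax.local_device_count())  # True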
Our model requires prompts to be preprocessed.
from dalle_mini import DalleBartProcessor
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
Let's define some text prompts.
prompts = ["sunset over a lake in the mountains", "the Eiffel tower landing on the moon"]
Note: we could use the same prompt multiple times for faster inference.
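For example (a hypothetical variation, not used below), the batch could be filled with a single prompt.
# Hypothetical variation: sample the same prompt several times per device
# prompts = ["sunset over a lake in the mountains"] * 8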
tokenized_prompts = processor(prompts)
Finally we replicate the prompts onto each device.
tokenized_prompt = replicate(tokenized_prompts)
n_predictions = 8
# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0
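The defaults above sample from the full distribution; as a hypothetical variation, sampling can be constrained instead.
# Hypothetical values; any of these can be passed through to p_generate
# gen_top_k = 50      # keep only the 50 most likely image tokens at each step
# gen_top_p = 0.95    # nucleus sampling
# temperature = 0.85  # lower values give more conservative samples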
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange
print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()
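Optionally (not part of the original notebook), the generated images can be saved to disk.
# Optional: save each generated image under a hypothetical filename pattern
for n, generated in enumerate(images):
    generated.save(f"generation_{n}.png")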
CLIP_REPO = "openai/clip-vit-base-patch32"
CLIP_COMMIT_ID = None
# Load CLIP
clip, clip_params = FlaxCLIPModel.from_pretrained(
    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)
clip_params = replicate(clip_params)
# score images
@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
    logits = clip(params=params, **inputs).logits_per_image
    return logits
from flax.training.common_utils import shard
# get clip scores
clip_inputs = clip_processor(
    text=prompts * jax.device_count(),
    images=images,
    return_tensors="np",
    padding="max_length",
    max_length=77,
    truncation=True,
).data
logits = p_clip(shard(clip_inputs), clip_params)
# organize scores per prompt
p = len(prompts)
# logits has shape (devices, images_per_device, texts_per_device);
# within each device's batch, images for prompt i sit at positions i, i + p, i + 2p, ...
# so we pick the matching (image, text) entries and stack one row of scores per prompt
logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()
logits.shape  # (number of prompts, images per prompt)
Let's now display images ranked by CLIP score.
for i, prompt in enumerate(prompts):
    print(f"Prompt: {prompt}\n")
    for idx in logits[i].argsort()[::-1]:
        display(images[idx * p + i])
        print(f"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\n")
    print()