~dricottone/dalle-mini-terminal

ref: 3d1904e546e4a28d3c5d3daacfe2bb874754a868 dalle-mini-terminal/dalle_mini_terminal/model.py -rw-r--r-- 3.0 KiB
3d1904e5Dominic Ricottone Refactored Makefile 1 year, 8 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3

# constants
# Repository identifier and revision handed to load_vqgan() by main().
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
# main() runs max(N_PREDICTIONS // jax.device_count(), 1) generation rounds.
N_PREDICTIONS = 8
# Forwarded by main() to model.generate() as top_k, top_p, temperature and
# condition_scale respectively; None is passed through unchanged
# (presumably selecting the library defaults -- TODO confirm).
GEN_TOP_K = None
GEN_TOP_P = None
TEMPERATURE = None
COND_SCALE = 10.0


# stdlib imports
import random
from functools import partial
import datetime


# pypi imports
import numpy as np
import jax
import jax.numpy as jnp
from dalle_mini import DalleBart, DalleBartProcessor
from transformers import CLIPProcessor, FlaxCLIPModel
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from PIL import Image


# repo imports
from vqgan_jax.modeling_flax_vqgan import VQModel


# functions
def load_dalle_mini(repo: str, version: str | None):
    """Load DALL-E mini"""
    return DalleBart.from_pretrained(repo, revision=version, dtype=jnp.float16, _do_init=False)

def load_vqgan(repo: str, version: str | None):
    """Load VQGAN"""
    return VQModel.from_pretrained(repo, revision=version, _do_init=False)

def load_processor(repo, version):
    """Load DALL-E mini processor"""
    return DalleBartProcessor.from_pretrained(repo, revision=version)


def main(prompt, artifacts_dir):
    """Generate images for a text prompt and save them as JPEG files.

    Loads the DALL-E mini model and its processor from ``artifacts_dir`` and
    the VQGAN model from ``VQGAN_REPO`` at ``VQGAN_COMMIT_ID``. It then runs
    ``max(N_PREDICTIONS // jax.device_count(), 1)`` rounds of generation and
    decoding, writing every decoded image to the current working directory.

    Files are named ``<yymmdd_HHMMSS>_<NNN>.jpg``. ``NNN`` is a running index
    over all images saved by this call, so that images written within the
    same second do not overwrite each other.

    Args:
        prompt: Text prompt; wrapped in a one-element list for the processor.
        artifacts_dir: Passed as ``repo`` to ``load_dalle_mini`` and
            ``load_processor`` (with no revision).
    """
    prompts = [prompt]

    # check how many devices are available
    # NOTE: the return value is discarded; the count is not used below.
    jax.local_device_count()

    print("Loading DALL-E Mini model...")
    model, params = load_dalle_mini(artifacts_dir, None)
    print("Loading VQGAN model...")
    vqgan, vqgan_params = load_vqgan(VQGAN_REPO, VQGAN_COMMIT_ID)
    print("Loading BART encoder...")
    processor = load_processor(artifacts_dir, None)

    # pmap expects its mapped arguments replicated across devices.
    params = replicate(params)
    vqgan_params = replicate(vqgan_params)

    # A fresh random seed on every run, so repeated runs differ.
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)

    tokenized_prompts = processor(prompts)
    tokenized_prompt = replicate(tokenized_prompts)

    # Arguments 3-6 (the sampling settings) are broadcast as static values.
    @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
    def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
        return model.generate(
            **tokenized_prompt,
            prng_key=key,
            params=params,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            condition_scale=condition_scale,
        )

    @partial(jax.pmap, axis_name="batch")
    def p_decode(indices, params):
        return vqgan.decode_code(indices, params=params)

    # Running index over every image written by this call; it makes each
    # filename unique even when several images share a timestamp.
    saved = 0

    for _ in range(max(N_PREDICTIONS // jax.device_count(), 1)):
        # Split so that every round samples with a different subkey.
        key, subkey = jax.random.split(key)
        print("Generating image(s)...")
        encoded_images = p_generate(
            tokenized_prompt,
            shard_prng_key(subkey),
            params,
            GEN_TOP_K,
            GEN_TOP_P,
            TEMPERATURE,
            COND_SCALE,
        )
        # Drop the first position of each sequence
        # (presumably a start-of-sequence token -- TODO confirm).
        encoded_images = encoded_images.sequences[..., 1:]

        print("Decoding image(s)...")
        decoded_images = p_decode(encoded_images, vqgan_params)
        # Clamp to [0, 1] and flatten the leading axes into a list of
        # 256x256 RGB images.
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

        for decoded_img in decoded_images:
            print("Saving image...")
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            stamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
            img.save(f"{stamp}_{saved:03d}.jpg", "JPEG")
            saved += 1