~dricottone/dalle-mini-terminal

ref: 1beb83d1b0cb623abda74fc54c56a514c6694fe2 dalle-mini-terminal/dalle_mini_terminal/model.py -rw-r--r-- 2.1 KiB
1beb83d1 · Dominic Ricottone · v1.1 Update · 1 year, 3 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python3

import os
import random
from functools import partial
import datetime

import numpy as np
import jax
import jax.numpy as jnp
from dalle_mini import DalleBart, DalleBartProcessor
from transformers import CLIPProcessor, FlaxCLIPModel
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from PIL import Image
from vqgan_jax.modeling_flax_vqgan import VQModel

def main(prompt, dalle_dir, vqgan_dir, output_dir):
    """Generate images for a text prompt and save them as JPEG files.

    The prompt is tokenized, turned into image tokens by the DalleBart model,
    decoded into pixels by the VQGAN model, and every decoded image is written
    into ``output_dir``.

    Args:
        prompt: Text prompt to generate images for.
        dalle_dir: Location passed to ``DalleBart.from_pretrained`` and
            ``DalleBartProcessor.from_pretrained``.
        vqgan_dir: Location passed to ``VQModel.from_pretrained``.
        output_dir: Directory the JPEG files are written into. The process
            working directory is not changed.

    Files are named after the local time at which saving starts
    (``%y%m%d_%H%M%S.jpg``). When more than one image is produced, the second
    and later images get a ``_1``, ``_2``, ... suffix so they do not overwrite
    each other.
    """
    prompts = [prompt]

    model, params = DalleBart.from_pretrained(
        dalle_dir, revision=None, dtype=jnp.float32, _do_init=False
    )
    vqgan, vqgan_params = VQModel.from_pretrained(
        vqgan_dir, revision=None, _do_init=False
    )
    processor = DalleBartProcessor.from_pretrained(dalle_dir, revision=None)

    # pmap expects its (non-static) inputs to carry a leading device axis.
    params = replicate(params)
    vqgan_params = replicate(vqgan_params)

    # A fresh random seed per invocation, so repeated prompts give new images.
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)

    tokenized_prompt = replicate(processor(prompts))

    # Arguments 3-6 (the sampling settings) are static: they are broadcast to
    # all devices as Python values rather than sharded arrays.
    @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
    def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
        return model.generate(
            **tokenized_prompt,
            prng_key=key,
            params=params,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            condition_scale=condition_scale,
        )

    @partial(jax.pmap, axis_name="batch")
    def p_decode(indices, params):
        return vqgan.decode_code(indices, params=params)

    key, subkey = jax.random.split(key)
    # top_k, top_p and temperature are left unset (None); condition_scale is 10.0.
    encoded_images = p_generate(
        tokenized_prompt, shard_prng_key(subkey), params, None, None, None, 10.0
    )
    # Drop the first token of every generated sequence before decoding
    # (presumably the start-of-sequence token, as in the dalle-mini
    # inference example).
    encoded_images = encoded_images.sequences[..., 1:]

    decoded_images = p_decode(encoded_images, vqgan_params)
    # Clamp to [0, 1] and flatten the leading axes into a list of 256x256 RGB images.
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    # One timestamp for the whole batch: the format only has second
    # resolution, so taking it per image would give several images the same
    # file name and silently overwrite all but the last one.
    stamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    for index, decoded_img in enumerate(decoded_images):
        # The first image keeps the plain timestamp name; later ones are numbered.
        suffix = f"_{index}" if index else ""
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        # Write into output_dir by path instead of os.chdir(), so the caller's
        # working directory is left untouched.
        img.save(os.path.join(output_dir, f"{stamp}{suffix}.jpg"), "JPEG")