~dricottone/dalle-mini-terminal

1beb83d1b0cb623abda74fc54c56a514c6694fe2 — Dominic Ricottone 1 year, 3 months ago 3d1904e dev
v1.1 Update

Dependencies are pinned to ensure that they continue to work. There have
been breaking updates in some upstream projects, while others are
beginning to bit rot. For now I am preferring the path of least
resistance that will keep the module useful.

The module now generates a single image only. This is not configurable.

Renamed the -a/--artifacts CLI option to -d/--dalle. This is a breaking
change.

Dropped support for -v as a short-form of the CLI --version. This is a
minor breaking change.

Dropped all support for GPU-acceleration.

Added the -v/--vqgan and -o/--output CLI options. The pre-trained VQGAN
model should now be downloaded separately and provided on the --vqgan CLI
option. This is a breaking change.

The README demonstrates where pre-trained models can be downloaded from,
how to run the module, and how to build a Docker image and container to
run the module.
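
The renames above can be summarized as a small argument rewrite. This is only a sketch: `migrate_argv` is a hypothetical helper, not part of the module, and it cannot supply the VQGAN path that v1.1 expects on `--vqgan`.

```python
def migrate_argv(argv):
    """Rewrite v1.0-style options to their v1.1 names (hypothetical helper)."""
    migrated = []
    for i, arg in enumerate(argv):
        if arg == "--":
            # Everything after the separator is the prompt; copy it unchanged.
            migrated.extend(argv[i:])
            break
        name, sep, value = arg.partition("=")
        if name in ("-a", "--artifacts"):
            # -a/--artifacts was renamed to -d/--dalle.
            name = "-d" if name == "-a" else "--dalle"
        elif name == "-v":
            # -v used to mean --version; it now means --vqgan, so map it to -V.
            name = "-V"
        migrated.append(name + sep + value)
    return migrated


old = ["-a", "./mini-1_v0_artifacts", "--", "cats", "playing", "chess"]
print(migrate_argv(old))
```

The v1.1 defaults in `__main__.py` (`./dalle-artifacts`, `./vqgan-artifacts`, and `.` for output) still apply to any option that is left out.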
M .gitignore => .gitignore +1 -0
@@ 6,6 6,7 @@

mega-1-fp16_v14_artifacts
mini-1_v0_artifacts
vqgan_imagenet_f16_16384_artifacts

*.jpg


A Dockerfile => Dockerfile +17 -0
@@ 0,0 1,17 @@
FROM python:3.11.4-slim-bookworm

RUN apt-get update && apt-get install -y dumb-init git && apt-get clean
RUN pip install wheel
RUN pip install jax==0.3.25 jaxlib==0.3.25 orbax-checkpoint==0.1.1 git+https://github.com/patil-suraj/vqgan-jax.git dalle-mini

COPY dalle_mini_terminal /app/dalle_mini_terminal
VOLUME /app/dalle-artifacts
VOLUME /app/vqgan-artifacts
VOLUME /app/output

RUN python -c "exec('from huggingface_hub import hf_hub_download\nhf_hub_download(\"dalle-mini/dalle-mini\", filename=\"enwiki-words-frequency.txt\")')"

WORKDIR /app
ENTRYPOINT ["dumb-init", "--", "python", "-m", "dalle_mini_terminal", "--output", "./output", "--"]
CMD ["cats", "playing", "chess"]


M Makefile => Makefile +14 -31
@@ 1,39 1,22 @@
PYTHON_BIN=python3
PIP_BIN=$(PYTHON_BIN) -m pip
VENV_BIN=$(PYTHON_BIN) -m venv
PY_COMPILE_BIN=$(PYTHON_BIN) -m py_compile

# see https://git.dominic-ricottone.com/~dricottone/gap
GAP_BIN=gap

clean:
	rm -rf **/__pycache__ **/*.pyc

uninstall:
	rm -rf .venv
	rm -rf **/__pycache__ **/*.pyc

test:
	$(PY_COMPILE_BIN) dalle_mini_terminal/*.py

.venv:
	$(VENV_BIN) .venv

dalle_mini_terminal/cli.py:
	$(GAP_BIN) dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py

build: dalle_mini_terminal/cli.py
dalle_mini_terminal/cli.py: dalle_mini_terminal/cli.toml
	gap dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py

install: .venv dalle_mini_terminal/cli.py
	(source .venv/bin/activate; $(PIP_BIN) install jax)
	(source .venv/bin/activate; $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git)
	(source .venv/bin/activate; $(PIP_BIN) install dalle-mini)
.venv: dalle_mini_terminal/cli.py
	python -m venv .venv
	(. .venv/bin/activate; pip install --upgrade pip)
	(. .venv/bin/activate; pip install wheel)
	(. .venv/bin/activate; pip install jax==0.3.25 jaxlib==0.3.25 orbax-checkpoint==0.1.1 git+https://github.com/patil-suraj/vqgan-jax.git dalle-mini)

install-cuda: .venv dalle_mini_terminal/cli.py
	(source .venv/bin/activate; $(PIP_BIN) install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html)
	(source .venv/bin/activate; $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git)
	(source .venv/bin/activate; $(PIP_BIN) install dalle-mini)
install: .venv

run:
	(source .venv/bin/activate; $(PYTHON_BIN) -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess)
	(. .venv/bin/activate; python -m dalle_mini_terminal --dalle ./mini-1_v0_artifacts --vqgan ./vqgan_imagenet_f16_16384_artifacts -- cats playing chess)

build:
	sudo docker build -t dalle_mini_terminal .

.PHONY: clean uninstall test build install install-cuda run
.PHONY: clean install run build

M README.md => README.md +57 -64
@@ 1,86 1,79 @@
# DALL-E Mini in a Terminal

Run the [DALL-E Mini model](https://github.com/borisdayma/dalle-mini) in a
terminal.

I've taken the upstream project's inference pipeline notebook and reimplemented
it as a normal Python module.
Currently does not perform the 'optional' CLIP scoring and sorting.
And obviously this module runs headlessly;
images are saved locally.

Generate an image from the
[DALL-E Mini model](https://github.com/borisdayma/dalle-mini).
All without leaving the terminal.
A simplistic refactoring of the official project's inference pipeline notebook.


## Usage

The project can be set up (into a virtual environment for all dependencies)
by running `make install`.
Download the pretrained VQGAN and DALL-E mini models from
[here](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/e93a26e7707683d349bf5d5c41c5b0ef69b677a9)
and
[here](https://huggingface.co/dalle-mini/dalle-mini/tree/e0888f668d60b9009e1a00ef4b6c155ec7512610).

Download the latest DALL-E mini model artifacts to a local `artifacts` folder.
See [W&B for these downloads](https://wandb.ai/dalle-mini/dalle-mini/artifacts).
The **mega** model is tagged as `mega-1-fp16`,
while the **mini** model is tagged as `mini-1`.

Try running with
`python -m dalle_mini_terminal -a path/to/artifacts -- avocado toast`.
It will take a while though, even with the mini model.
Run `make install` to install into a virtualenv.

```
$ time (source .venv/bin/activate; python -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess)

[...]

real    79m59.554s
user    85m35.281s
sys     0m17.885s
$ time (. .venv/bin/activate; python -m dalle_mini_terminal \
> --dalle ./mini-1_v0_artifacts \
> --vqgan ./vqgan_imagenet_f16_16384_artifacts \
> -- your prompt should go here)
Generating images with prompt: cats playing chess
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

real    6m0.490s
user    6m32.701s
sys     0m5.739s
```

Some notes:
 + The first time `dalle_mini_terminal` runs, a list of words will be
   downloaded to `~/.cache/huggingface/hub/models--dalle-mini--dalle-mini`.
   This is an unavoidable side-effect from importing `dalle_mini`.


### CUDA

Install the proprietary nvidia driver,
as well as the `cuda` and `cudnn` packages.
Likely also necessary to reboot, in order to load the kernel modules.

On Arch Linux, cuda libraries and binaries install into `/opt`,
while the cudnn libraries install into `/usr/lib`.
Debian-based distributions often use `/usr/lib` for both cuda and cudnn.
The underlying Python modules assume that the entire toolchain lives in
`/usr/local/cuda-MAJOR.MINOR`.

In other words, if using Arch Linux, it's also necessary to run:
Or run `make build` to build the container image.

```
sudo ln -s /opt/cuda /usr/local/cuda-11.7
sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.7/include
sudo ln -s /usr/lib/libcudnn*.so /usr/local/cuda-11.7/lib64/
sudo ln -s /usr/lib/libcudnn*.a /usr/local/cuda-11.7/lib64/
$ time sudo docker run --rm --interactive --tty \
> --mount type=bind,src="$(pwd)/mini-1_v0_artifacts",dst=/app/dalle-artifacts \
> --mount type=bind,src="$(pwd)/vqgan_imagenet_f16_16384_artifacts",dst=/app/vqgan-artifacts \
> --mount type=bind,src="$(pwd)/output",dst=/app/output \
> dalle_mini_terminal \
> your prompt should go here
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
Moving 0 files to the new cache system
0it [00:00, ?it/s]
Generating images with prompt: cats playing chess
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

real    5m52.257s
user    0m0.085s
sys     0m0.067s
```

The project can then be set up with `make install-cuda`.
Running will eat a lot of VRAM though;
far more than my measly GTX 950 has to offer.
Some notes:
 + The Dockerfile does not download the models.
   There are terms and conditions associated with use of these models,
   and by downloading files from the associated portals you will be accepting
   them.
   The *only* download that I have pre-built into the image is a list of words
   that is unavoidably downloaded and cached as a side effect of importing
   the `dalle_mini` package.
   I won't be making any other exceptions.
 + I can't seem to do anything about those warnings.
   They shouldn't bother anyone much, since they will just go to a log unless
   the container is running interactively (i.e. `--tty --interactive`).
 + If for any reason you need to interact with the image beyond generating
   images, be sure to override the entrypoint (i.e. `--entrypoint sh`).
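
As a rough sketch (not part of this project), the bind mounts can be built programmatically. It assumes the mounts should target the `VOLUME` paths declared in the Dockerfile (`/app/dalle-artifacts`, `/app/vqgan-artifacts`, `/app/output`) and that the image is tagged `dalle_mini_terminal` as in `make build`; `docker_run_argv` is a hypothetical helper.

```python
import os


def docker_run_argv(dalle_dir, vqgan_dir, output_dir, prompt_words):
    """Build a `docker run` argument list (hypothetical helper)."""
    # Host directory -> container path, taken from the Dockerfile's VOLUME lines.
    mounts = [
        (dalle_dir, "/app/dalle-artifacts"),
        (vqgan_dir, "/app/vqgan-artifacts"),
        (output_dir, "/app/output"),
    ]
    argv = ["docker", "run", "--rm"]
    for src, dst in mounts:
        # Bind mounts need an absolute source path.
        argv += ["--mount", f"type=bind,src={os.path.abspath(src)},dst={dst}"]
    argv.append("dalle_mini_terminal")
    # The entrypoint already ends with `--`, so the prompt words follow directly.
    argv += list(prompt_words)
    return argv


print(docker_run_argv("./mini-1_v0_artifacts",
                      "./vqgan_imagenet_f16_16384_artifacts",
                      "./output",
                      ["cats", "playing", "chess"]))
```

The resulting list can be passed to `subprocess.run`, prefixed with `sudo` if the Docker daemon requires it.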


### CUDA

## To-Do

 + Factor logic and functionality (e.g. `print_version`) out of `__main__.py`
   into an `internals.py` file.
 + Figure out how to automate downloading W&B artifacts in a Makefile.
 + Experiment with re-writing the codebase under the assumption that there is
   no GPU/TPU.
   + e.g. is there a cost to `flax.jax_utils.replicate`?
     It doesn't do anything for a CPU workload.
   + or is there any benefit to parallelism (via `jax.pmap`) when there is just
     one compute unit?
     (i.e. `jax.device_count() == 1`)
 + Figure out how to reflect flavors (in the BSD sense) in `pyproject.toml`,
   so that this project can be `pipx` installable both with and without cuda.
   + [Maybe not an option?](https://github.com/python-poetry/poetry/issues/2613)
 + Figure out if `mypy` is an option with this dependency chain.

This is more work than it's worth.
If it *just works* for you, congrats.
It doesn't work for me either, if that's any consolation.


## Licensing

M dalle_mini_terminal/__main__.py => dalle_mini_terminal/__main__.py +7 -4
@@ 1,6 1,7 @@
#!/usr/bin/env python3

import sys
import pathlib

from . import model
from . import cli


@@ 9,18 10,20 @@ def main():
    _config, _positionals = cli.main(sys.argv[1:])

    if "version" in _config.keys():
        print("dalle_mini_terminal v1.0.0")
        print("dalle_mini_terminal v1.1.0")
        sys.exit(0)
    elif "help" in _config.keys():
        print("dalle_mini_terminal --artifacts path/to/artifacts -- avocado chair")
        print("dalle_mini_terminal --dalle dalle/artifacts/dir --vqgan vqgan/artifacts/dir -- avocado chair")
        sys.exit(0)

    artifacts_dir = _config.get("artifacts", "./artifacts")
    dalle_dir = _config.get("dalle", "./dalle-artifacts")
    vqgan_dir = _config.get("vqgan", "./vqgan-artifacts")
    output_dir = _config.get("output", ".")

    prompt = ' '.join(_positionals)
    print("Generating images with prompt:", prompt)

    model.main(prompt, artifacts_dir)
    model.main(prompt, dalle_dir, vqgan_dir, output_dir)
    sys.exit(0)

if __name__ == "__main__":

M dalle_mini_terminal/cli.py => dalle_mini_terminal/cli.py +40 -16
@@ 5,7 5,7 @@ import re
def main(arguments):
	config=dict()
	positional=[]
	pattern=re.compile(r"(?:-(?:a|h|x|v|V)|--(?:artifacts|help|version))(?:=.*)?$")
	pattern=re.compile(r"(?:-(?:d|h|x|o|V|v)|--(?:dalle|help|output|version|vqgan))(?:=.*)?$")
	consuming,needing,wanting=None,0,0
	attached_value=None
	def log(*values): pass


@@ 37,14 37,14 @@ def main(arguments):
				log(f'{option} has an attached value')
				option,attached_value=option.split('=',1)
			log(f'{option} is an option')
			if option=="artifacts":
			if option=="dalle":
				if attached_value is not None:
					config["artifacts"]=attached_value
					config["dalle"]=attached_value
					attached_value=None
					consuming,needing,wanting=None,0,0
				else:
					config["artifacts"]=None
					consuming,needing,wanting="artifacts",1,1
					config["dalle"]=None
					consuming,needing,wanting="dalle",1,1
			elif option=="help":
				if attached_value is not None:
					message=(


@@ 53,6 53,14 @@ def main(arguments):
					)
					raise ValueError(message) from None
				config["help"]=True
			elif option=="output":
				if attached_value is not None:
					config["output"]=attached_value
					attached_value=None
					consuming,needing,wanting=None,0,0
				else:
					config["output"]=None
					consuming,needing,wanting="output",1,1
			elif option=="version":
				if attached_value is not None:
					message=(


@@ 61,14 69,22 @@ def main(arguments):
					)
					raise ValueError(message) from None
				config["version"]=True
			elif option=="a":
			elif option=="vqgan":
				if attached_value is not None:
					config["vqgan"]=attached_value
					attached_value=None
					consuming,needing,wanting=None,0,0
				else:
					config["vqgan"]=None
					consuming,needing,wanting="vqgan",1,1
			elif option=="d":
				if attached_value is not None:
					config["artifacts"]=attached_value
					config["dalle"]=attached_value
					attached_value=None
					consuming,needing,wanting=None,0,0
				else:
					config["artifacts"]=None
					consuming,needing,wanting="artifacts",1,1
					config["dalle"]=None
					consuming,needing,wanting="dalle",1,1
			elif option=="h":
				if attached_value is not None:
					message=(


@@ 85,14 101,14 @@ def main(arguments):
					)
					raise ValueError(message) from None
				config["help"]=True
			elif option=="v":
			elif option=="o":
				if attached_value is not None:
					message=(
						'unexpected value while parsing "version"'
						' (expected 0 values)'
					)
					raise ValueError(message) from None
				config["version"]=True
					config["output"]=attached_value
					attached_value=None
					consuming,needing,wanting=None,0,0
				else:
					config["output"]=None
					consuming,needing,wanting="output",1,1
			elif option=="V":
				if attached_value is not None:
					message=(


@@ 101,6 117,14 @@ def main(arguments):
					)
					raise ValueError(message) from None
				config["version"]=True
			elif option=="v":
				if attached_value is not None:
					config["vqgan"]=attached_value
					attached_value=None
					consuming,needing,wanting=None,0,0
				else:
					config["vqgan"]=None
					consuming,needing,wanting="vqgan",1,1
		else:
			positional.append(arguments.pop(0))
	if needing>0:

M dalle_mini_terminal/cli.toml => dalle_mini_terminal/cli.toml +11 -3
@@ 1,12 1,20 @@
[artifacts]
[dalle]
number = 1
alternatives = ['a']
alternatives = ['d']

[help]
number = 0
alternatives = ['h', 'x']

[output]
number = 1
alternatives = ['o']

[version]
number = 0
alternatives = ['v', 'V']
alternatives = ['V']

[vqgan]
number = 1
alternatives = ['v']


M dalle_mini_terminal/model.py => dalle_mini_terminal/model.py +17 -52
@@ 1,22 1,10 @@
#!/usr/bin/env python3

# constants
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
N_PREDICTIONS = 8
GEN_TOP_K = None
GEN_TOP_P = None
TEMPERATURE = None
COND_SCALE = 10.0


# stdlib imports
import os
import random
from functools import partial
import datetime


# pypi imports
import numpy as np
import jax
import jax.numpy as jnp


@@ 25,38 13,16 @@ from transformers import CLIPProcessor, FlaxCLIPModel
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from PIL import Image


# repo imports
from vqgan_jax.modeling_flax_vqgan import VQModel


# functions
def load_dalle_mini(repo: str, version: str | None):
    """Load DALL-E mini"""
    return DalleBart.from_pretrained(repo, revision=version, dtype=jnp.float16, _do_init=False)

def load_vqgan(repo: str, version: str | None):
    """Load VQGAN"""
    return VQModel.from_pretrained(repo, revision=version, _do_init=False)

def load_processor(repo, version):
    """Load DALL-E mini processor"""
    return DalleBartProcessor.from_pretrained(repo, revision=version)


def main(prompt, artifacts_dir):
def main(prompt, dalle_dir, vqgan_dir, output_dir):
    prompts = [prompt]

    # check how many devices are available
    jax.local_device_count()
    model, params = DalleBart.from_pretrained(dalle_dir, revision=None, dtype=jnp.float32, _do_init=False)
    vqgan, vqgan_params = VQModel.from_pretrained(vqgan_dir, revision=None, _do_init=False)
    processor = DalleBartProcessor.from_pretrained(dalle_dir, revision=None)

    print("Loading DALL-E Mini model...")
    model, params = load_dalle_mini(artifacts_dir, None)
    print("Loading VQGAN model...")
    vqgan, vqgan_params = load_vqgan(VQGAN_REPO, VQGAN_COMMIT_ID)
    print("Loading BART encoder...")
    processor = load_processor(artifacts_dir, None)
    jax.local_device_count()

    params = replicate(params)
    vqgan_params = replicate(vqgan_params)


@@ 75,18 41,17 @@ def main(prompt, artifacts_dir):
    def p_decode(indices, params):
        return vqgan.decode_code(indices, params=params)

    for _ in range(max(N_PREDICTIONS // jax.device_count(), 1)):
        key, subkey = jax.random.split(key)
        print("Generating image(s)...")
        encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, GEN_TOP_K, GEN_TOP_P, TEMPERATURE, COND_SCALE)
        encoded_images = encoded_images.sequences[..., 1:]
    key, subkey = jax.random.split(key)
    encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, None, None, None, 10.0)
    encoded_images = encoded_images.sequences[..., 1:]

    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

        print("Decoding image(s)...")
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    if output_dir != ".":
        os.chdir(output_dir)

        for decoded_img in decoded_images:
            print("Saving image...")
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            img.save(datetime.datetime.now().strftime("%y%m%d_%H%M%S") + ".jpg", "JPEG")
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        img.save(datetime.datetime.now().strftime("%y%m%d_%H%M%S") + ".jpg", "JPEG")
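
The save step above switches into `output_dir` with `os.chdir` and then writes a timestamped file name. An alternative sketch, which is not what this commit does, builds the full path instead and leaves the working directory alone; `output_path` is a hypothetical helper.

```python
import datetime
import pathlib


def output_path(output_dir, now=None):
    """Return the timestamped JPEG path inside output_dir (hypothetical helper)."""
    # Same file name format as the commit: yymmdd_HHMMSS.jpg
    now = now or datetime.datetime.now()
    return pathlib.Path(output_dir) / (now.strftime("%y%m%d_%H%M%S") + ".jpg")


print(output_path("./output", datetime.datetime(2023, 7, 1, 12, 30, 5)))
```

With a helper like this, the loop body would call `img.save(output_path(output_dir), "JPEG")`. Because the name has one-second resolution, it only stays unique while a run saves at most one image per second, which holds now that a single image is generated.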