From 1beb83d1b0cb623abda74fc54c56a514c6694fe2 Mon Sep 17 00:00:00 2001 From: Dominic Ricottone Date: Fri, 30 Jun 2023 23:33:13 -0500 Subject: [PATCH] v1.1 Update Dependencies are pinned to ensure that they continue to work. There have been breaking updates in some upstream projects, while others are beginning to bit rot. For now I am preferring the path of least resistance that will keep the module useful. The module now generates a single image only. This is not configurable. Renamed the -a/--artifacts CLI option to -d/--dalle. This is a breaking change. Dropped support for -v as a short-form of the CLI --version. This is a minor breaking change. Dropped all support for GPU-acceleration. Added the -v/--vqgan and -o/--output CLI options. The pre-trained VQGAN model should now be downloaded separately and provided on the --vqgan CLI option. This is a breaking change. The README demonstrates where pre-trained models can be downloaded from, how to run the module, and how to build a Docker image and container to run the module. --- .gitignore | 1 + Dockerfile | 17 +++++ Makefile | 45 ++++-------- README.md | 121 +++++++++++++++----------------- dalle_mini_terminal/__main__.py | 11 +-- dalle_mini_terminal/cli.py | 56 ++++++++++----- dalle_mini_terminal/cli.toml | 14 +++- dalle_mini_terminal/model.py | 69 +++++------------- 8 files changed, 164 insertions(+), 170 deletions(-) create mode 100644 Dockerfile diff --git a/.gitignore b/.gitignore index f84ec5e..f480f85 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ mega-1-fp16_v14_artifacts mini-1_v0_artifacts +vqgan_imagenet_f16_16384_artifacts *.jpg diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e8daf83 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.11.4-slim-bookworm + +RUN apt-get update && apt-get install -y dumb-init git && apt-get clean +RUN pip install wheel +RUN pip install jax==0.3.25 jaxlib==0.3.25 orbax-checkpoint==0.1.1 git+https://github.com/patil-suraj/vqgan-jax.git dalle-mini + +COPY dalle_mini_terminal /app/dalle_mini_terminal +VOLUME /app/dalle-artifacts +VOLUME /app/vqgan-artifacts +VOLUME /app/output + +RUN python -c "exec('from huggingface_hub import hf_hub_download\nhf_hub_download(\"dalle-mini/dalle-mini\", filename=\"enwiki-words-frequency.txt\")')" + +WORKDIR /app +ENTRYPOINT ["dumb-init", "--", "python", "-m", "dalle_mini_terminal", "--output", "./output", "--"] +CMD ["cats", "playing", "chess"] + diff --git a/Makefile b/Makefile index 1730ba3..e391870 100644 --- a/Makefile +++ b/Makefile @@ -1,39 +1,22 @@ -PYTHON_BIN=python3 -PIP_BIN=$(PYTHON_BIN) -m pip -VENV_BIN=$(PYTHON_BIN) -m venv -PY_COMPILE_BIN=$(PYTHON_BIN) -m py_compile - -# see https://git.dominic-ricottone.com/~dricottone/gap -GAP_BIN=gap - clean: - rm -rf **/__pycache__ **/*.pyc - -uninstall: rm -rf .venv + rm -rf **/__pycache__ **/*.pyc -test: - $(PY_COMPILE_BIN) dalle_mini_terminal/*.py - -.venv: - $(VENV_BIN) .venv - -dalle_mini_terminal/cli.py: - $(GAP_BIN) dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py - -build: dalle_mini_terminal/cli.py +dalle_mini_terminal/cli.py: dalle_mini_terminal/cli.toml + gap dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py -install: .venv dalle_mini_terminal/cli.py - (source .venv/bin/activate; $(PIP_BIN) install jax) - (source .venv/bin/activate; $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git) - (source .venv/bin/activate; $(PIP_BIN) install dalle-mini) +.venv: dalle_mini_terminal/cli.py + python -m venv .venv + (. .venv/bin/activate; pip install --upgrade pip) + (. .venv/bin/activate; pip install wheel) + (. .venv/bin/activate; pip install jax==0.3.25 jaxlib==0.3.25 orbax-checkpoint==0.1.1 git+https://github.com/patil-suraj/vqgan-jax.git dalle-mini) -install-cuda: .venv dalle_mini_terminal/cli.py - (source .venv/bin/activate; $(PIP_BIN) install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html) - (source .venv/bin/activate; $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git) - (source .venv/bin/activate; $(PIP_BIN) install dalle-mini) +install: .venv run: - (source .venv/bin/activate; $(PYTHON_BIN) -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess) + (. .venv/bin/activate; python -m dalle_mini_terminal --dalle ./mini-1_v0_artifacts --vqgan ./vqgan_imagenet_f16_16384_artifacts -- cats playing chess) + +build: + sudo docker build -t dalle_mini_terminal . -.PHONY: clean uninstall test build install install-cuda run +.PHONY: clean install run build diff --git a/README.md b/README.md index 68c75c2..9e55a9a 100644 --- a/README.md +++ b/README.md @@ -1,86 +1,79 @@ # DALL-E Mini in a Terminal -Run the [DALL-E Mini model](https://github.com/borisdayma/dalle-mini) in a -terminal. - -I've taken the upstream project's inference pipeline notebook and reimplemented -as a normal Python module. -Currently does not perform the 'optional' CLIP scoring and sorting. -And obviously this module runs headlessly; -images are saved locally. - +Generate an image from the +[DALL-E Mini model](https://github.com/borisdayma/dalle-mini). +All without leaving the terminal. +A simplistic refactoring of the official project's inference pipeline notebook. ## Usage -The project can be setup (into a virtual environment for all dependencies) -by running `make install`. +Download the pretrained VQGAN and DALL-E mini models from +[here](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/e93a26e7707683d349bf5d5c41c5b0ef69b677a9). +and +[here](https://huggingface.co/dalle-mini/dalle-mini/tree/e0888f668d60b9009e1a00ef4b6c155ec7512610). -Download the latest DALL-E mini model artifacts to a local `artifacts` folder. -See [W&B for these downloads](https://wandb.ai/dalle-mini/dalle-mini/artifacts). -The **mega** model is tagged as `mega-1-fp16`, -while the **mini** model is tagged as `mini-1`. - -Try running with -`python -m dalle-mini-terminal -a path/to/artifacts -- avocado toast` -It will take a while though, even with the mini model. +Run `make install` to install into a virtualenv. ``` -$ time (source .venv/bin/activate; python -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess) - -[...] - -real 79m59.554s -user 85m35.281s -sys 0m17.885s +$ time (. .venv/bin/activate; python -m dalle_mini_terminal \ +> --dalle ./mini-1_v0_artifacts \ +> --vqgan ./vqgan_imagenet_f16_16384_artifacts \ +> -- your prompt should go here) +Generating images with prompt: cats playing chess +WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) + +real 6m0.490s +user 6m32.701s +sys 0m5.739s ``` +Some notes: + + The first time `dalle_mini_terminal` runs, a list of words will be + downloaded to `~/.cache/huggingface/hub/models--dalle-mini--dalle-mini`. + This is an unavoidable side-effect from importing `dalle_mini`. - -### CUDA - -Install the proprietary nvidia driver, -as well as the `cuda` and `cudnn` packages. -Likely also necessary to reboot, in order to load the kernel modules. - -On Arch Linux, cuda libraries and binaries install into `/opt`, -while the cudnn libraries install into `/usr/lib`. -Debian-based distributions often use `/usr/lib` for both cuda and cuddn. -The underlying Python modules assume that the entire toolchain lives in -`/usr/local/cuda-MAJOR.MINOR`. - -In other words, if using Arch Linux, it's also necessary to run: +Or run `make build` to build the container image. ``` -sudo ln -s /opt/cuda /usr/local/cuda-11.7 -sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.7/include -sudo ln -s /usr/lib/libcudnn*.so /usr/local/cuda-11.7/lib64/ -sudo ln -s /usr/lib/libcudnn*.a /usr/local/cuda-11.7/lib64/ +$ time sudo docker run --rm --interactive --tty \ +> --mount type=bind,src="$(pwd)/mini-1_v0_artifacts",dst=/dalle-artifacts \ +> --mount type=bind,src="$(pwd)/vqgan_imagenet_f16_16384_artifacts",dst=/vqgan_artifacts \ +> --mount type=bind,src="$(pwd)/output",dst=/output \ +> dalle_mini_terminal \ +> your prompt should go here +The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`. +Moving 0 files to the new cache system +0it [00:00, ?it/s] +Generating images with prompt: cats playing chess +WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) + +real 5m52.257s +user 0m0.085s +sys 0m0.067s ``` -The project can then be setup with `make install-cuda`. -Running will eat a lot of VRAM though; -far more than my measly GTX 950 has to offer. +Some notes: + + The Dockerfile does not download the models. + There are terms and conditions associated with use of these models, + and by downloading files from the associate portals you will be accepting + them. + The *only* download that I have pre-built into the image is a list of words + that is unavoidably downloaded and cached as a side effect of importing + the `dalle_mini` package. + I won't be making any other exceptions. + + I can't seem to do anything about those warnings. + They shouldn't bother anyone much, since they will just go to a log unless + the container is running interactively (i.e. `--tty --interactive`). + + If for any reason you need to interact with the image beyond generating + images, be sure to override the entrypoint (i.e. `--entrypoint sh`). +### CUDA -## To-Do - - + Factor logic and functionality (e.g. `print_version`) out of `__main__.py` - into an `internals.py` file. - + Figure out how to automate downloading W&B artifacts in a Makefile. - + Experiment with re-writing the codebase under the assumption that there is - no GPU/TPU. - + e.g. is there a cost to `flax.jax_utils.replicate`? - It doesn't do anything for a CPU workload. - + or is there any benefit to parallelism (via `jax.pmap`) when there is just - one compute unit? - (i.e. `jax.device_count() == 1`) - + Figure out how to reflect flavors (in the BSD sense) in `pyproject.toml`, - so that this project can be `pipx` installable both with and without cuda. - + [Maybe not an option?](https://github.com/python-poetry/poetry/issues/2613) - + Figure out if `mypy` is an option with this dependency chain. - +This is more work than it's worth. +If it *just works* for you, congrats. +It doesn't work for me either, if that's any consolation. ## Licensing diff --git a/dalle_mini_terminal/__main__.py b/dalle_mini_terminal/__main__.py index e73edf7..ba7f418 100644 --- a/dalle_mini_terminal/__main__.py +++ b/dalle_mini_terminal/__main__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import sys +import pathlib from . import model from . import cli @@ -9,18 +10,20 @@ def main(): _config, _positionals = cli.main(sys.argv[1:]) if "version" in _config.keys(): - print("dalle_mini_terminal v1.0.0") + print("dalle_mini_terminal v1.1.0") sys.exit(0) elif "help" in _config.keys(): - print("dalle_mini_terminal --artifacts path/to/artifacts -- avocado chair") + print("dalle_mini_terminal --dalle dalle/artifacts/dir --vqgan vqgan/artifacts/dir -- avocado chair") sys.exit(0) - artifacts_dir = _config.get("artifacts", "./artifacts") + dalle_dir = _config.get("dalle", "./dalle-artifacts") + vqgan_dir = _config.get("vqgan", "./vqgan-artifacts") + output_dir = _config.get("output", ".") prompt = ' '.join(_positionals) print("Generating images with prompt:", prompt) - model.main(prompt, artifacts_dir) + model.main(prompt, dalle_dir, vqgan_dir, output_dir) sys.exit(0) if __name__ == "__main__": diff --git a/dalle_mini_terminal/cli.py b/dalle_mini_terminal/cli.py index d6a50df..887148b 100644 --- a/dalle_mini_terminal/cli.py +++ b/dalle_mini_terminal/cli.py @@ -5,7 +5,7 @@ import re def main(arguments): config=dict() positional=[] - pattern=re.compile(r"(?:-(?:a|h|x|v|V)|--(?:artifacts|help|version))(?:=.*)?$") + pattern=re.compile(r"(?:-(?:d|h|x|o|V|v)|--(?:dalle|help|output|version|vqgan))(?:=.*)?$") consuming,needing,wanting=None,0,0 attached_value=None def log(*values): pass @@ -37,14 +37,14 @@ def main(arguments): log(f'{option} has an attached value') option,attached_value=option.split('=',1) log(f'{option} is an option') - if option=="artifacts": + if option=="dalle": if attached_value is not None: - config["artifacts"]=attached_value + config["dalle"]=attached_value attached_value=None consuming,needing,wanting=None,0,0 else: - config["artifacts"]=None - consuming,needing,wanting="artifacts",1,1 + config["dalle"]=None + consuming,needing,wanting="dalle",1,1 elif option=="help": if attached_value is not None: message=( @@ -53,6 +53,14 @@ def main(arguments): ) raise ValueError(message) from None config["help"]=True + elif option=="output": + if attached_value is not None: + config["output"]=attached_value + attached_value=None + consuming,needing,wanting=None,0,0 + else: + config["output"]=None + consuming,needing,wanting="output",1,1 elif option=="version": if attached_value is not None: message=( @@ -61,14 +69,22 @@ def main(arguments): ) raise ValueError(message) from None config["version"]=True - elif option=="a": + elif option=="vqgan": + if attached_value is not None: + config["vqgan"]=attached_value + attached_value=None + consuming,needing,wanting=None,0,0 + else: + config["vqgan"]=None + consuming,needing,wanting="vqgan",1,1 + elif option=="d": if attached_value is not None: - config["artifacts"]=attached_value + config["dalle"]=attached_value attached_value=None consuming,needing,wanting=None,0,0 else: - config["artifacts"]=None - consuming,needing,wanting="artifacts",1,1 + config["dalle"]=None + consuming,needing,wanting="dalle",1,1 elif option=="h": if attached_value is not None: message=( @@ -85,14 +101,14 @@ def main(arguments): ) raise ValueError(message) from None config["help"]=True - elif option=="v": + elif option=="o": if attached_value is not None: - message=( - 'unexpected value while parsing "version"' - ' (expected 0 values)' - ) - raise ValueError(message) from None - config["version"]=True + config["output"]=attached_value + attached_value=None + consuming,needing,wanting=None,0,0 + else: + config["output"]=None + consuming,needing,wanting="output",1,1 elif option=="V": if attached_value is not None: message=( @@ -101,6 +117,14 @@ def main(arguments): ) raise ValueError(message) from None config["version"]=True + elif option=="v": + if attached_value is not None: + config["vqgan"]=attached_value + attached_value=None + consuming,needing,wanting=None,0,0 + else: + config["vqgan"]=None + consuming,needing,wanting="vqgan",1,1 else: positional.append(arguments.pop(0)) if needing>0: diff --git a/dalle_mini_terminal/cli.toml b/dalle_mini_terminal/cli.toml index 8c3d1fe..449f68f 100644 --- a/dalle_mini_terminal/cli.toml +++ b/dalle_mini_terminal/cli.toml @@ -1,12 +1,20 @@ -[artifacts] +[dalle] number = 1 -alternatives = ['a'] +alternatives = ['d'] [help] number = 0 alternatives = ['h', 'x'] +[output] +number = 1 +alternatives = ['o'] + [version] number = 0 -alternatives = ['v', 'V'] +alternatives = ['V'] + +[vqgan] +number = 1 +alternatives = ['v'] diff --git a/dalle_mini_terminal/model.py b/dalle_mini_terminal/model.py index f4ee36a..2f71961 100644 --- a/dalle_mini_terminal/model.py +++ b/dalle_mini_terminal/model.py @@ -1,22 +1,10 @@ #!/usr/bin/env python3 -# constants -VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384" -VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9" -N_PREDICTIONS = 8 -GEN_TOP_K = None -GEN_TOP_P = None -TEMPERATURE = None -COND_SCALE = 10.0 - - -# stdlib imports +import os import random from functools import partial import datetime - -# pypi imports import numpy as np import jax import jax.numpy as jnp @@ -25,38 +13,16 @@ from transformers import CLIPProcessor, FlaxCLIPModel from flax.jax_utils import replicate from flax.training.common_utils import shard_prng_key from PIL import Image - - -# repo imports from vqgan_jax.modeling_flax_vqgan import VQModel - -# functions -def load_dalle_mini(repo: str, version: str | None): - """Load DALL-E mini""" - return DalleBart.from_pretrained(repo, revision=version, dtype=jnp.float16, _do_init=False) - -def load_vqgan(repo: str, version: str | None): - """Load VQGAN""" - return VQModel.from_pretrained(repo, revision=version, _do_init=False) - -def load_processor(repo, version): - """Load DALL-E mini processor""" - return DalleBartProcessor.from_pretrained(repo, revision=version) - - -def main(prompt, artifacts_dir): +def main(prompt, dalle_dir, vqgan_dir, output_dir): prompts = [prompt] - # check how many devices are available - jax.local_device_count() + model, params = DalleBart.from_pretrained(dalle_dir, revision=None, dtype=jnp.float32, _do_init=False) + vqgan, vqgan_params = VQModel.from_pretrained(vqgan_dir, revision=None, _do_init=False) + processor = DalleBartProcessor.from_pretrained(dalle_dir, revision=None) - print("Loading DALL-E Mini model...") - model, params = load_dalle_mini(artifacts_dir, None) - print("Loading VQGAN model...") - vqgan, vqgan_params = load_vqgan(VQGAN_REPO, VQGAN_COMMIT_ID) - print("Loading BART encoder...") - processor = load_processor(artifacts_dir, None) + jax.local_device_count() params = replicate(params) vqgan_params = replicate(vqgan_params) @@ -75,18 +41,17 @@ def main(prompt, artifacts_dir): def p_decode(indices, params): return vqgan.decode_code(indices, params=params) - for _ in range(max(N_PREDICTIONS // jax.device_count(), 1)): - key, subkey = jax.random.split(key) - print("Generating image(s)...") - encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, GEN_TOP_K, GEN_TOP_P, TEMPERATURE, COND_SCALE) - encoded_images = encoded_images.sequences[..., 1:] + key, subkey = jax.random.split(key) + encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, None, None, None, 10.0) + encoded_images = encoded_images.sequences[..., 1:] + + decoded_images = p_decode(encoded_images, vqgan_params) + decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3)) - print("Decoding image(s)...") - decoded_images = p_decode(encoded_images, vqgan_params) - decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3)) + if output_dir != ".": + os.chdir(output_dir) - for decoded_img in decoded_images: - print("Saving image...") - img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8)) - img.save(datetime.datetime.now().strftime("%y%m%d_%H%M%S") + ".jpg", "JPEG") + for decoded_img in decoded_images: + img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8)) + img.save(datetime.datetime.now().strftime("%y%m%d_%H%M%S") + ".jpg", "JPEG") -- 2.45.2