M .gitignore => .gitignore +1 -0
@@ 6,6 6,7 @@
mega-1-fp16_v14_artifacts
mini-1_v0_artifacts
+vqgan_imagenet_f16_16384_artifacts
*.jpg
A Dockerfile => Dockerfile +17 -0
@@ 0,0 1,17 @@
+FROM python:3.11.4-slim-bookworm
+
+RUN apt-get update && apt-get install -y dumb-init git && apt-get clean
+RUN pip install wheel
+RUN pip install jax==0.3.25 jaxlib==0.3.25 orbax-checkpoint==0.1.1 git+https://github.com/patil-suraj/vqgan-jax.git dalle-mini
+
+COPY dalle_mini_terminal /app/dalle_mini_terminal
+VOLUME /app/dalle-artifacts
+VOLUME /app/vqgan-artifacts
+VOLUME /app/output
+
+RUN python -c "exec('from huggingface_hub import hf_hub_download\nhf_hub_download(\"dalle-mini/dalle-mini\", filename=\"enwiki-words-frequency.txt\")')"
+
+WORKDIR /app
+ENTRYPOINT ["dumb-init", "--", "python", "-m", "dalle_mini_terminal", "--output", "./output", "--"]
+CMD ["cats", "playing", "chess"]
+
M Makefile => Makefile +14 -31
@@ 1,39 1,22 @@
-PYTHON_BIN=python3
-PIP_BIN=$(PYTHON_BIN) -m pip
-VENV_BIN=$(PYTHON_BIN) -m venv
-PY_COMPILE_BIN=$(PYTHON_BIN) -m py_compile
-
-# see https://git.dominic-ricottone.com/~dricottone/gap
-GAP_BIN=gap
-
clean:
- rm -rf **/__pycache__ **/*.pyc
-
-uninstall:
rm -rf .venv
+ rm -rf **/__pycache__ **/*.pyc
-test:
- $(PY_COMPILE_BIN) dalle_mini_terminal/*.py
-
-.venv:
- $(VENV_BIN) .venv
-
-dalle_mini_terminal/cli.py:
- $(GAP_BIN) dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py
-
-build: dalle_mini_terminal/cli.py
+dalle_mini_terminal/cli.py: dalle_mini_terminal/cli.toml
+ gap dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py
-install: .venv dalle_mini_terminal/cli.py
- (source .venv/bin/activate; $(PIP_BIN) install jax)
- (source .venv/bin/activate; $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git)
- (source .venv/bin/activate; $(PIP_BIN) install dalle-mini)
+.venv: dalle_mini_terminal/cli.py
+ python -m venv .venv
+ (. .venv/bin/activate; pip install --upgrade pip)
+ (. .venv/bin/activate; pip install wheel)
+ (. .venv/bin/activate; pip install jax==0.3.25 jaxlib==0.3.25 orbax-checkpoint==0.1.1 git+https://github.com/patil-suraj/vqgan-jax.git dalle-mini)
-install-cuda: .venv dalle_mini_terminal/cli.py
- (source .venv/bin/activate; $(PIP_BIN) install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html)
- (source .venv/bin/activate; $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git)
- (source .venv/bin/activate; $(PIP_BIN) install dalle-mini)
+install: .venv
run:
- (source .venv/bin/activate; $(PYTHON_BIN) -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess)
+ (. .venv/bin/activate; python -m dalle_mini_terminal --dalle ./mini-1_v0_artifacts --vqgan ./vqgan_imagenet_f16_16384_artifacts -- cats playing chess)
+
+build:
+ sudo docker build -t dalle_mini_terminal .
-.PHONY: clean uninstall test build install install-cuda run
+.PHONY: clean install run build
M README.md => README.md +57 -64
@@ 1,86 1,79 @@
# DALL-E Mini in a Terminal
-Run the [DALL-E Mini model](https://github.com/borisdayma/dalle-mini) in a
-terminal.
-
-I've taken the upstream project's inference pipeline notebook and reimplemented
-as a normal Python module.
-Currently does not perform the 'optional' CLIP scoring and sorting.
-And obviously this module runs headlessly;
-images are saved locally.
-
+Generate an image from the
+[DALL-E Mini model](https://github.com/borisdayma/dalle-mini).
+All without leaving the terminal.
+A simplistic refactoring of the official project's inference pipeline notebook.
## Usage
-The project can be setup (into a virtual environment for all dependencies)
-by running `make install`.
+Download the pretrained VQGAN and DALL-E mini models from
+[here](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/e93a26e7707683d349bf5d5c41c5b0ef69b677a9)
+and
+[here](https://huggingface.co/dalle-mini/dalle-mini/tree/e0888f668d60b9009e1a00ef4b6c155ec7512610).
-Download the latest DALL-E mini model artifacts to a local `artifacts` folder.
-See [W&B for these downloads](https://wandb.ai/dalle-mini/dalle-mini/artifacts).
-The **mega** model is tagged as `mega-1-fp16`,
-while the **mini** model is tagged as `mini-1`.
-
-Try running with
-`python -m dalle-mini-terminal -a path/to/artifacts -- avocado toast`
-It will take a while though, even with the mini model.
+Run `make install` to install into a virtualenv.
```
-$ time (source .venv/bin/activate; python -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess)
-
-[...]
-
-real 79m59.554s
-user 85m35.281s
-sys 0m17.885s
+$ time (. .venv/bin/activate; python -m dalle_mini_terminal \
+> --dalle ./mini-1_v0_artifacts \
+> --vqgan ./vqgan_imagenet_f16_16384_artifacts \
+> -- your prompt should go here)
+Generating images with prompt: cats playing chess
+WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
+
+real 6m0.490s
+user 6m32.701s
+sys 0m5.739s
```
+Some notes:
+ + The first time `dalle_mini_terminal` runs, a list of words will be
+ downloaded to `~/.cache/huggingface/hub/models--dalle-mini--dalle-mini`.
+ This is an unavoidable side effect of importing `dalle_mini`.
-
-### CUDA
-
-Install the proprietary nvidia driver,
-as well as the `cuda` and `cudnn` packages.
-Likely also necessary to reboot, in order to load the kernel modules.
-
-On Arch Linux, cuda libraries and binaries install into `/opt`,
-while the cudnn libraries install into `/usr/lib`.
-Debian-based distributions often use `/usr/lib` for both cuda and cuddn.
-The underlying Python modules assume that the entire toolchain lives in
-`/usr/local/cuda-MAJOR.MINOR`.
-
-In other words, if using Arch Linux, it's also necessary to run:
+Or run `make build` to build the container image.
```
-sudo ln -s /opt/cuda /usr/local/cuda-11.7
-sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.7/include
-sudo ln -s /usr/lib/libcudnn*.so /usr/local/cuda-11.7/lib64/
-sudo ln -s /usr/lib/libcudnn*.a /usr/local/cuda-11.7/lib64/
+$ time sudo docker run --rm --interactive --tty \
+> --mount type=bind,src="$(pwd)/mini-1_v0_artifacts",dst=/app/dalle-artifacts \
+> --mount type=bind,src="$(pwd)/vqgan_imagenet_f16_16384_artifacts",dst=/app/vqgan-artifacts \
+> --mount type=bind,src="$(pwd)/output",dst=/app/output \
+> dalle_mini_terminal \
+> your prompt should go here
+The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
+Moving 0 files to the new cache system
+0it [00:00, ?it/s]
+Generating images with prompt: cats playing chess
+WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
+
+real 5m52.257s
+user 0m0.085s
+sys 0m0.067s
```
-The project can then be setup with `make install-cuda`.
-Running will eat a lot of VRAM though;
-far more than my measly GTX 950 has to offer.
+Some notes:
+ + The Dockerfile does not download the models.
+ There are terms and conditions associated with use of these models,
+ and by downloading files from the associated portals you will be accepting
+ them.
+ The *only* download that I have pre-built into the image is a list of words
+ that is unavoidably downloaded and cached as a side effect of importing
+ the `dalle_mini` package.
+ I won't be making any other exceptions.
+ + I can't seem to do anything about those warnings.
+ They shouldn't bother anyone much, since they will just go to a log unless
+ the container is running interactively (i.e. `--tty --interactive`).
+ + If for any reason you need to interact with the image beyond generating
+ images, be sure to override the entrypoint (i.e. `--entrypoint sh`).
+### CUDA
-## To-Do
-
- + Factor logic and functionality (e.g. `print_version`) out of `__main__.py`
- into an `internals.py` file.
- + Figure out how to automate downloading W&B artifacts in a Makefile.
- + Experiment with re-writing the codebase under the assumption that there is
- no GPU/TPU.
- + e.g. is there a cost to `flax.jax_utils.replicate`?
- It doesn't do anything for a CPU workload.
- + or is there any benefit to parallelism (via `jax.pmap`) when there is just
- one compute unit?
- (i.e. `jax.device_count() == 1`)
- + Figure out how to reflect flavors (in the BSD sense) in `pyproject.toml`,
- so that this project can be `pipx` installable both with and without cuda.
- + [Maybe not an option?](https://github.com/python-poetry/poetry/issues/2613)
- + Figure out if `mypy` is an option with this dependency chain.
-
+This is more work than it's worth.
+If it *just works* for you, congrats.
+It doesn't work for me either, if that's any consolation.
## Licensing
M dalle_mini_terminal/__main__.py => dalle_mini_terminal/__main__.py +7 -4
@@ 1,6 1,7 @@
#!/usr/bin/env python3
import sys
+import pathlib
from . import model
from . import cli
@@ 9,18 10,20 @@ def main():
_config, _positionals = cli.main(sys.argv[1:])
if "version" in _config.keys():
- print("dalle_mini_terminal v1.0.0")
+ print("dalle_mini_terminal v1.1.0")
sys.exit(0)
elif "help" in _config.keys():
- print("dalle_mini_terminal --artifacts path/to/artifacts -- avocado chair")
+ print("dalle_mini_terminal --dalle dalle/artifacts/dir --vqgan vqgan/artifacts/dir -- avocado chair")
sys.exit(0)
- artifacts_dir = _config.get("artifacts", "./artifacts")
+ dalle_dir = _config.get("dalle", "./dalle-artifacts")
+ vqgan_dir = _config.get("vqgan", "./vqgan-artifacts")
+ output_dir = _config.get("output", ".")
prompt = ' '.join(_positionals)
print("Generating images with prompt:", prompt)
- model.main(prompt, artifacts_dir)
+ model.main(prompt, dalle_dir, vqgan_dir, output_dir)
sys.exit(0)
if __name__ == "__main__":
M dalle_mini_terminal/cli.py => dalle_mini_terminal/cli.py +40 -16
@@ 5,7 5,7 @@ import re
def main(arguments):
config=dict()
positional=[]
- pattern=re.compile(r"(?:-(?:a|h|x|v|V)|--(?:artifacts|help|version))(?:=.*)?$")
+ pattern=re.compile(r"(?:-(?:d|h|x|o|V|v)|--(?:dalle|help|output|version|vqgan))(?:=.*)?$")
consuming,needing,wanting=None,0,0
attached_value=None
def log(*values): pass
@@ 37,14 37,14 @@ def main(arguments):
log(f'{option} has an attached value')
option,attached_value=option.split('=',1)
log(f'{option} is an option')
- if option=="artifacts":
+ if option=="dalle":
if attached_value is not None:
- config["artifacts"]=attached_value
+ config["dalle"]=attached_value
attached_value=None
consuming,needing,wanting=None,0,0
else:
- config["artifacts"]=None
- consuming,needing,wanting="artifacts",1,1
+ config["dalle"]=None
+ consuming,needing,wanting="dalle",1,1
elif option=="help":
if attached_value is not None:
message=(
@@ 53,6 53,14 @@ def main(arguments):
)
raise ValueError(message) from None
config["help"]=True
+ elif option=="output":
+ if attached_value is not None:
+ config["output"]=attached_value
+ attached_value=None
+ consuming,needing,wanting=None,0,0
+ else:
+ config["output"]=None
+ consuming,needing,wanting="output",1,1
elif option=="version":
if attached_value is not None:
message=(
@@ 61,14 69,22 @@ def main(arguments):
)
raise ValueError(message) from None
config["version"]=True
- elif option=="a":
+ elif option=="vqgan":
+ if attached_value is not None:
+ config["vqgan"]=attached_value
+ attached_value=None
+ consuming,needing,wanting=None,0,0
+ else:
+ config["vqgan"]=None
+ consuming,needing,wanting="vqgan",1,1
+ elif option=="d":
if attached_value is not None:
- config["artifacts"]=attached_value
+ config["dalle"]=attached_value
attached_value=None
consuming,needing,wanting=None,0,0
else:
- config["artifacts"]=None
- consuming,needing,wanting="artifacts",1,1
+ config["dalle"]=None
+ consuming,needing,wanting="dalle",1,1
elif option=="h":
if attached_value is not None:
message=(
@@ 85,14 101,14 @@ def main(arguments):
)
raise ValueError(message) from None
config["help"]=True
- elif option=="v":
+ elif option=="o":
if attached_value is not None:
- message=(
- 'unexpected value while parsing "version"'
- ' (expected 0 values)'
- )
- raise ValueError(message) from None
- config["version"]=True
+ config["output"]=attached_value
+ attached_value=None
+ consuming,needing,wanting=None,0,0
+ else:
+ config["output"]=None
+ consuming,needing,wanting="output",1,1
elif option=="V":
if attached_value is not None:
message=(
@@ 101,6 117,14 @@ def main(arguments):
)
raise ValueError(message) from None
config["version"]=True
+ elif option=="v":
+ if attached_value is not None:
+ config["vqgan"]=attached_value
+ attached_value=None
+ consuming,needing,wanting=None,0,0
+ else:
+ config["vqgan"]=None
+ consuming,needing,wanting="vqgan",1,1
else:
positional.append(arguments.pop(0))
if needing>0:
M dalle_mini_terminal/cli.toml => dalle_mini_terminal/cli.toml +11 -3
@@ 1,12 1,20 @@
-[artifacts]
+[dalle]
number = 1
-alternatives = ['a']
+alternatives = ['d']
[help]
number = 0
alternatives = ['h', 'x']
+[output]
+number = 1
+alternatives = ['o']
+
[version]
number = 0
-alternatives = ['v', 'V']
+alternatives = ['V']
+
+[vqgan]
+number = 1
+alternatives = ['v']
M dalle_mini_terminal/model.py => dalle_mini_terminal/model.py +17 -52
@@ 1,22 1,10 @@
#!/usr/bin/env python3
-# constants
-VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
-VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
-N_PREDICTIONS = 8
-GEN_TOP_K = None
-GEN_TOP_P = None
-TEMPERATURE = None
-COND_SCALE = 10.0
-
-
-# stdlib imports
+import os
import random
from functools import partial
import datetime
-
-# pypi imports
import numpy as np
import jax
import jax.numpy as jnp
@@ 25,38 13,16 @@ from transformers import CLIPProcessor, FlaxCLIPModel
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from PIL import Image
-
-
-# repo imports
from vqgan_jax.modeling_flax_vqgan import VQModel
-
-# functions
-def load_dalle_mini(repo: str, version: str | None):
- """Load DALL-E mini"""
- return DalleBart.from_pretrained(repo, revision=version, dtype=jnp.float16, _do_init=False)
-
-def load_vqgan(repo: str, version: str | None):
- """Load VQGAN"""
- return VQModel.from_pretrained(repo, revision=version, _do_init=False)
-
-def load_processor(repo, version):
- """Load DALL-E mini processor"""
- return DalleBartProcessor.from_pretrained(repo, revision=version)
-
-
-def main(prompt, artifacts_dir):
+def main(prompt, dalle_dir, vqgan_dir, output_dir):
prompts = [prompt]
- # check how many devices are available
- jax.local_device_count()
+ model, params = DalleBart.from_pretrained(dalle_dir, revision=None, dtype=jnp.float32, _do_init=False)
+ vqgan, vqgan_params = VQModel.from_pretrained(vqgan_dir, revision=None, _do_init=False)
+ processor = DalleBartProcessor.from_pretrained(dalle_dir, revision=None)
- print("Loading DALL-E Mini model...")
- model, params = load_dalle_mini(artifacts_dir, None)
- print("Loading VQGAN model...")
- vqgan, vqgan_params = load_vqgan(VQGAN_REPO, VQGAN_COMMIT_ID)
- print("Loading BART encoder...")
- processor = load_processor(artifacts_dir, None)
+ jax.local_device_count()
params = replicate(params)
vqgan_params = replicate(vqgan_params)
@@ 75,18 41,17 @@ def main(prompt, artifacts_dir):
def p_decode(indices, params):
return vqgan.decode_code(indices, params=params)
- for _ in range(max(N_PREDICTIONS // jax.device_count(), 1)):
- key, subkey = jax.random.split(key)
- print("Generating image(s)...")
- encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, GEN_TOP_K, GEN_TOP_P, TEMPERATURE, COND_SCALE)
- encoded_images = encoded_images.sequences[..., 1:]
+ key, subkey = jax.random.split(key)
+ encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, None, None, None, 10.0)
+ encoded_images = encoded_images.sequences[..., 1:]
+
+ decoded_images = p_decode(encoded_images, vqgan_params)
+ decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
- print("Decoding image(s)...")
- decoded_images = p_decode(encoded_images, vqgan_params)
- decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
+ if output_dir != ".":
+ os.chdir(output_dir)
- for decoded_img in decoded_images:
- print("Saving image...")
- img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
- img.save(datetime.datetime.now().strftime("%y%m%d_%H%M%S") + ".jpg", "JPEG")
+ for decoded_img in decoded_images:
+ img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
+ img.save(datetime.datetime.now().strftime("%y%m%d_%H%M%S") + ".jpg", "JPEG")