From 6efb49b27e07f4eeacc1001a6591deae5433fa38 Mon Sep 17 00:00:00 2001 From: Dominic Ricottone Date: Thu, 16 Jun 2022 15:45:20 -0500 Subject: [PATCH] Initial commit --- .gitignore | 11 ++ LICENSE.txt | 201 ++++++++++++++++++++++++++++++++ Makefile | 63 ++++++++++ README.md | 83 +++++++++++++ dalle_mini_terminal/__init__.py | 0 dalle_mini_terminal/__main__.py | 28 +++++ dalle_mini_terminal/cli.py | 127 ++++++++++++++++++++ dalle_mini_terminal/cli.toml | 12 ++ dalle_mini_terminal/model.py | 92 +++++++++++++++ 9 files changed, 617 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE.txt create mode 100644 Makefile create mode 100644 README.md create mode 100644 dalle_mini_terminal/__init__.py create mode 100644 dalle_mini_terminal/__main__.py create mode 100644 dalle_mini_terminal/cli.py create mode 100644 dalle_mini_terminal/cli.toml create mode 100644 dalle_mini_terminal/model.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f84ec5e --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +.venv + +**/__pycache__ +**/__mypycache__ +**/*.pyc + +mega-1-fp16_v14_artifacts +mini-1_v0_artifacts + +*.jpg + diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..fc7419b --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2021 The DALL·E mini Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..cc0a650 --- /dev/null +++ b/Makefile @@ -0,0 +1,63 @@ +#PYTHON_BIN=python3 +PYTHON_BIN=python + +#PIP_BIN=$(PYTHON_BIN) -m pip +PIP_BIN=pip + +# NOTE: `pipx` not currently used +#PIPX_BIN=$(PYTHON_BIN) -m pipx +PIPX_BIN=pipx + +VENV_BIN=$(PYTHON_BIN) -m venv + +PY_COMPILE_BIN=$(PYTHON_BIN) -m py_compile + +# NOTE: `pyproject-build` not currently used +#PYPROJECT_BUILD_BIN=$(PYTHON_BIN) -m build +PYPROJECT_BUILD_BIN=pyproject-build + +# NOTE: `unittest` not currently used +#UNITTEST_BIN=$(PYTHON_BIN) -m unittest +UNITTEST_BIN=unittest --color + +# NOTE: `mypy` not currently used +#MYPY_BIN=$(PYTHON_BIN) -m mypy +MYPY_BIN=MYPY_CACHE_DIR=dalle_mini_terminal/__mypycache__ mypy + +# see https://git.dominic-ricottone.com/gap.git/about +#GAP_BIN=$(PYTHON_BIN) -m gap +GAP_BIN=gap + +.PHONY: clean test install install-cuda uninstall run + +clean: + rm -rf **/__pycache__ **/*.pyc + #rm -rf **/__mypycache__ + #rm -rf build + #rm -rf *.egg-info + +test: + $(PY_COMPILE_BIN) dalle_mini_terminal/*.py + #$(UNITTEST_BIN) --working-directory . 
tests --verbose
+	#$(MYPY_BIN) -p dalle_mini_terminal
+
+install:
+	$(VENV_BIN) .venv
+	(. .venv/bin/activate && $(PIP_BIN) install jax)
+	(. .venv/bin/activate && $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git)
+	(. .venv/bin/activate && $(PIP_BIN) install dalle-mini)
+	$(GAP_BIN) dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py
+
+install-cuda:
+	$(VENV_BIN) .venv
+	(. .venv/bin/activate && $(PIP_BIN) install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html)
+	(. .venv/bin/activate && $(PIP_BIN) install git+https://github.com/patil-suraj/vqgan-jax.git)
+	(. .venv/bin/activate && $(PIP_BIN) install dalle-mini)
+	$(GAP_BIN) dalle_mini_terminal/cli.toml -o dalle_mini_terminal/cli.py
+
+uninstall:
+	rm -rf .venv
+
+run:
+	(. .venv/bin/activate && $(PYTHON_BIN) -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess)
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..86cee7a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,83 @@
+# DALL-E Mini in the Terminal
+
+Run the [DALL-E Mini model](https://github.com/borisdayma/dalle-mini) in the terminal.
+
+
+
+## Usage
+
+Download the latest DALL-E mini model artifacts to a local `artifacts` folder.
+See [W&B for these downloads](https://wandb.ai/dalle-mini/dalle-mini/artifacts).
+The **mega** model is tagged as `mega-1-fp16`,
+while the **mini** model is tagged as `mini-1`.
+
+Run `python -m dalle_mini_terminal -a path/to/artifacts -- avocado toast`.
+It will take a while though, even with the mini model.
+
+```
+$ time (source .venv/bin/activate; python -m dalle_mini_terminal --artifacts ./mini-1_v0_artifacts -- cats playing chess)
+
+[...]
+
+real 79m59.554s
+user 85m35.281s
+sys 0m17.885s
+```
+
+
+
+### CUDA
+
+Install `cuda` and `cudnn` packages.
+Also, probably need to reboot to load the kernel modules.
+
+Arch Linux packages install the headers and binaries under `/opt`.
+Debian-based distributions often use `/usr/lib`.
+The underlying Python libraries assume that the cuda toolchain lives in
+`/usr/local/cuda-MAJOR.MINOR`.
+
+So if using Arch Linux, you also need to run:
+
+```
+sudo ln -s /opt/cuda /usr/local/cuda-11.7
+sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.7/include
+sudo ln -s /usr/lib/libcudnn*.so /usr/local/cuda-11.7/lib64/
+sudo ln -s /usr/lib/libcudnn*.a /usr/local/cuda-11.7/lib64/
+```
+
+It will eat a lot of VRAM though;
+far more than my measly 950 has to offer.
+
+
+
+## Licensing
+
+This is all derivative of the iPython/Jupyter notebook hosted at
+<https://github.com/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb>.
+As such, I have reproduced the original license in this repository
+(see LICENSE.txt).
+The work is licensed under Apache 2.
+
+See a list of the model's authors
+[here](https://github.com/borisdayma/dalle-mini#authors--contributors).
+
+Cite the model as:
+
+```
+@misc{Dayma_DALL·E_Mini_2021,
+  author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
+  doi = {10.5281/zenodo.5146400},
+  month = {7},
+  title = {DALL·E Mini},
+  url = {https://github.com/borisdayma/dalle-mini},
+  year = {2021}
+}
+```
+
+Images generated by the model are one of:
+
+ 1. Public domain
+ 2. Property of the AI model
+ 3. Licensed as a derivative work of the model,
+    which itself is licensed under Apache 2 (see above)
+
diff --git a/dalle_mini_terminal/__init__.py b/dalle_mini_terminal/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dalle_mini_terminal/__main__.py b/dalle_mini_terminal/__main__.py
new file mode 100644
index 0000000..e73edf7
--- /dev/null
+++ b/dalle_mini_terminal/__main__.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+
+import sys
+
+from . import model
+from . import cli
+
+def main():
+    """Parse the command line, then generate images for the given prompt."""
+    _config, _positionals = cli.main(sys.argv[1:])
+
+    if "version" in _config:
+        print("dalle_mini_terminal v1.0.0")
+        sys.exit(0)
+    elif "help" in _config:
+        print("dalle_mini_terminal --artifacts path/to/artifacts -- avocado chair")
+        sys.exit(0)
+
+    artifacts_dir = _config.get("artifacts", "./artifacts")
+
+    prompt = ' '.join(_positionals)
+    if not prompt.strip():
+        # an empty prompt would still load the models and run generation,
+        # so reject it before any of that work starts
+        print("error: no prompt given", file=sys.stderr)
+        print("usage: dalle_mini_terminal --artifacts path/to/artifacts -- avocado chair", file=sys.stderr)
+        sys.exit(2)
+    print("Generating images with prompt:", prompt)
+
+    model.main(prompt, artifacts_dir)
+    sys.exit(0)
+
+if __name__ == "__main__":
+    main()
+
diff --git a/dalle_mini_terminal/cli.py b/dalle_mini_terminal/cli.py
new file mode 100644
index 0000000..d6a50df
--- /dev/null
+++ b/dalle_mini_terminal/cli.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+
+import re
+
+def main(arguments):
+    config=dict()
+    positional=[]
+    pattern=re.compile(r"(?:-(?:a|h|x|v|V)|--(?:artifacts|help|version))(?:=.*)?$")
+    consuming,needing,wanting=None,0,0
+    attached_value=None
+    def log(*values): pass
+
+    if "--debug-gap-behavior" in arguments:
+        def log(*values): print(*values)
+    while len(arguments) and arguments[0]!="--":
+        if arguments[0]=="--debug-gap-behavior":
+            arguments.pop(0)
+            continue
+        log(f'processing {arguments[0]}...')
+        if consuming is not None:
+            log(f'option {consuming} is consuming')
+            if config[consuming] is None:
+                config[consuming]=arguments.pop(0)
+                log(f'option {consuming} = {config[consuming]}')
+            else:
+                config[consuming].append(arguments.pop(0))
+                log(f'option {consuming} = {config[consuming]}')
+            needing-=1
+            wanting-=1
+            if wanting==0:
+                log(f'option {consuming} is no longer consuming')
+                consuming,needing,wanting=None,0,0
+        elif pattern.match(arguments[0]):
+            log(f'{arguments[0]} matched an option')
+            option = arguments.pop(0).lstrip('-')
+            if '=' in option:
+                log(f'{option} has an attached value')
+                option,attached_value=option.split('=',1)
+            log(f'{option} is an option')
+            if option=="artifacts":
+                if attached_value is not None:
+                    config["artifacts"]=attached_value
+                    attached_value=None
+                    consuming,needing,wanting=None,0,0
+                else:
+                    config["artifacts"]=None
+                    consuming,needing,wanting="artifacts",1,1
+            elif option=="help":
+                if attached_value is not None:
+                    message=(
+                        'unexpected value while parsing "help"'
+                        ' (expected 0 values)'
+                    )
+                    raise ValueError(message) from None
+                config["help"]=True
+            elif option=="version":
+                if attached_value is not None:
+                    message=(
+                        'unexpected value while parsing "version"'
+                        ' (expected 0 values)'
+                    )
+                    raise ValueError(message) from None
+                config["version"]=True
+            elif option=="a":
+                if attached_value is not None:
+                    config["artifacts"]=attached_value
+                    attached_value=None
+                    consuming,needing,wanting=None,0,0
+                else:
+                    config["artifacts"]=None
+                    consuming,needing,wanting="artifacts",1,1
+            elif option=="h":
+                if attached_value is not None:
+                    message=(
+                        'unexpected value while parsing "help"'
+                        ' (expected 0 values)'
+                    )
+                    raise ValueError(message) from None
+                config["help"]=True
+            elif option=="x":
+                if attached_value is not None:
+                    message=(
+                        'unexpected value while parsing "help"'
+                        ' (expected 0 values)'
+                    )
+                    raise ValueError(message) from None
+                config["help"]=True
+            elif option=="v":
+                if attached_value is not None:
+                    message=(
+                        'unexpected value while parsing "version"'
+                        ' (expected 0 values)'
+                    )
+                    raise ValueError(message) from None
+                config["version"]=True
+            elif option=="V":
+                if attached_value is not None:
+                    message=(
+                        'unexpected value while parsing "version"'
+                        ' (expected 0 values)'
+                    )
+                    raise ValueError(message) from None
+                config["version"]=True
+        else:
+            positional.append(arguments.pop(0))
+    if needing>0:
+        message=(
+            f'unexpected end while parsing "{consuming}"'
+            f' (expected {needing} values)'
+        )
+        raise ValueError(message) from None
+    for argument in arguments[1:]:
+        if argument=="--debug-gap-behavior":
+            continue
+        positional.append(argument)
+    return config,positional
+
+if __name__=="__main__":
+    import sys
+    cfg,pos = main(sys.argv[1:])
+    cfg = {k:v for k,v in cfg.items() if v is not None}
+    if len(cfg):
+        print("Options:")
+        for k,v in cfg.items():
+            print(f"{k:20} = {v}")
+    if len(pos):
+        print("Positional arguments:", ", ".join(pos))
diff --git a/dalle_mini_terminal/cli.toml b/dalle_mini_terminal/cli.toml
new file mode 100644
index 0000000..8c3d1fe
--- /dev/null
+++ b/dalle_mini_terminal/cli.toml
@@ -0,0 +1,12 @@
+[artifacts]
+number = 1
+alternatives = ['a']
+
+[help]
+number = 0
+alternatives = ['h', 'x']
+
+[version]
+number = 0
+alternatives = ['v', 'V']
+
diff --git a/dalle_mini_terminal/model.py b/dalle_mini_terminal/model.py
new file mode 100644
index 0000000..f4ee36a
--- /dev/null
+++ b/dalle_mini_terminal/model.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+# constants
+VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
+VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
+N_PREDICTIONS = 8
+GEN_TOP_K = None
+GEN_TOP_P = None
+TEMPERATURE = None
+COND_SCALE = 10.0
+
+
+# stdlib imports
+import random
+from functools import partial
+import datetime
+
+
+# pypi imports
+import numpy as np
+import jax
+import jax.numpy as jnp
+from dalle_mini import DalleBart, DalleBartProcessor
+from transformers import CLIPProcessor, FlaxCLIPModel
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard_prng_key
+from PIL import Image
+
+
+# repo imports
+from vqgan_jax.modeling_flax_vqgan import VQModel
+
+
+# functions
+def load_dalle_mini(repo: str, version: str | None):
+    """Load DALL-E mini"""
+    return DalleBart.from_pretrained(repo, revision=version, dtype=jnp.float16, _do_init=False)
+
+def load_vqgan(repo: str, version: str | None):
+    """Load VQGAN"""
+    return VQModel.from_pretrained(repo, revision=version, _do_init=False)
+
+def load_processor(repo: str, version: str | None):
+    """Load DALL-E mini processor"""
+    return DalleBartProcessor.from_pretrained(repo, revision=version)
+
+
+def main(prompt, artifacts_dir):
+    """Generate images for `prompt` and save each one as a JPEG file.
+
+    The DALL-E mini model and its processor are loaded from `artifacts_dir`;
+    the VQGAN model is loaded from VQGAN_REPO at VQGAN_COMMIT_ID.
+    Files are written to the current working directory.
+    """
+    prompts = [prompt]
+
+    # report how many devices are available
+    print("Available devices:", jax.local_device_count())
+
+    print("Loading DALL-E Mini model...")
+    model, params = load_dalle_mini(artifacts_dir, None)
+    print("Loading VQGAN model...")
+    vqgan, vqgan_params = load_vqgan(VQGAN_REPO, VQGAN_COMMIT_ID)
+    print("Loading BART encoder...")
+    processor = load_processor(artifacts_dir, None)
+
+    params = replicate(params)
+    vqgan_params = replicate(vqgan_params)
+
+    seed = random.randint(0, 2**32 - 1)
+    key = jax.random.PRNGKey(seed)
+
+    tokenized_prompts = processor(prompts)
+    tokenized_prompt = replicate(tokenized_prompts)
+
+    @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
+    def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
+        return model.generate(**tokenized_prompt, prng_key=key, params=params, top_k=top_k, top_p=top_p, temperature=temperature, condition_scale=condition_scale)
+
+    @partial(jax.pmap, axis_name="batch")
+    def p_decode(indices, params):
+        return vqgan.decode_code(indices, params=params)
+
+    # running count of saved images; keeps file names unique
+    saved = 0
+    for _ in range(max(N_PREDICTIONS // jax.device_count(), 1)):
+        key, subkey = jax.random.split(key)
+        print("Generating image(s)...")
+        encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, GEN_TOP_K, GEN_TOP_P, TEMPERATURE, COND_SCALE)
+        encoded_images = encoded_images.sequences[..., 1:]
+
+        print("Decoding image(s)...")
+        decoded_images = p_decode(encoded_images, vqgan_params)
+        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
+
+        for decoded_img in decoded_images:
+            print("Saving image...")
+            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
+            # the timestamp only has one-second resolution, so images saved
+            # within the same second would overwrite each other without the counter
+            saved += 1
+            stamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
+            img.save(f"{stamp}_{saved:02d}.jpg", "JPEG")
+
-- 
2.45.2