commit 50ecd13a884abf3f92af4cbae87893abfc9f81d4 Author: MrTornado24 Date: Wed Dec 13 00:17:53 2023 +0800 chores: rebase commits diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..afc0794 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +root = true + +[*.py] +charset = utf-8 +trim_trailing_whitespace = true +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 4 + +[*.md] +trim_trailing_whitespace = false diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7bb5a2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,195 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python +# Edit at https://www.toptal.com/developers/gitignore?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# End of https://www.toptal.com/developers/gitignore/api/python + +.vscode/ +.threestudio_cache/ +outputs/ +outputs-gradio/ + +# pretrained model weights +*.ckpt +*.pt +*.pth + +# wandb +wandb/ + +load/tets/256_tets.npz + +# dataset +dataset/ +load/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3bf22ab --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,34 @@ +default_language_version: + python: python3 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: check-ast + - id: check-merge-conflict + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + language_version: python3 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + exclude: README.md + args: ["--profile", "black"] + + # temporarily disable static type checking + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.2.0 + # hooks: + # - id: mypy + # args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bbd9352 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 deepseek-ai + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LICENSE-CODE b/LICENSE-CODE new file mode 100644 index 0000000..d84f527 --- /dev/null +++ b/LICENSE-CODE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LICENSE-MODEL b/LICENSE-MODEL new file mode 100644 index 0000000..3faa7a8 --- /dev/null +++ b/LICENSE-MODEL @@ -0,0 +1,91 @@ +DEEPSEEK LICENSE AGREEMENT + +Version 1.0, 23 October 2023 + +Copyright (c) 2023 DeepSeek + +Section I: PREAMBLE + +Large generative models are being widely adopted and used, and have the potential to transform the way individuals conceive and benefit from AI or ML technologies. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and DeepSeek agree as follows: + +1. Definitions +"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +"Output" means the results of operating a Model as embodied in informational content resulting therefrom. +"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates. +"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc. +"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. + +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by DeepSeek that are necessarily infringed by its contribution(s). If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or works shall terminate as of the date such litigation is asserted or filed. + + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +c. You must cause any modified files to carry prominent notices stating that You changed the files; +d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. – for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. + +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). + +6. The Output You Generate. Except as set forth herein, DeepSeek claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, DeepSeek reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License. + +8. Trademarks and related. Nothing in this License permits You to make use of DeepSeek’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by DeepSeek. + +9. Personal information, IP rights and related. This Model may contain personal information and works with IP rights. You commit to complying with applicable laws and regulations in the handling of personal information and the use of such works. Please note that DeepSeek's license granted to you to use the Model does not imply that you have obtained a legitimate basis for processing the related information or works. As an independent personal information processor and IP rights user, you need to ensure full compliance with relevant legal and regulatory requirements when handling personal information and works with IP rights that may be contained in the Model, and are willing to assume solely any risks and consequences that may arise from that. + +10. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, DeepSeek provides the Model and the Complementary Material on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. + +11. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall DeepSeek be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if DeepSeek has been advised of the possibility of such damages. + +12. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of DeepSeek, and only if You agree to indemnify, defend, and hold DeepSeek harmless for any liability incurred by, or claims asserted against, DeepSeek by reason of your accepting any such warranty or additional liability. + +13. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +14. Governing Law and Jurisdiction. This agreement will be governed and construed under PRC laws without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this agreement. The courts located in the domicile of Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. shall have exclusive jurisdiction of any dispute arising out of this agreement. + +END OF TERMS AND CONDITIONS + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: + +- In any way that violates any applicable national or international law or regulation or infringes upon the lawful rights and interests of any third party; +- For military use in any way; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate inappropriate content subject to applicable regulatory requirements; +- To generate or disseminate personal identifiable information without due authorization or for unreasonable use; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories. diff --git a/README.md b/README.md new file mode 100644 index 0000000..585c8ef --- /dev/null +++ b/README.md @@ -0,0 +1,175 @@ +# DreamCraft3D + +[**Paper**](https://arxiv.org/abs/2310.16818) | [**Project Page**](https://mrtornado24.github.io/DreamCraft3D/) | [**Youtube video**](https://www.youtube.com/watch?v=0FazXENkQms) + +Official implementation of DreamCraft3D: Hierarchical 3D Generation with Bootstrapped Diffusion Prior + +[Jingxiang Sun](https://mrtornado24.github.io/), [Bo Zhang](https://bo-zhang.me/), [Ruizhi Shao](https://dsaurus.github.io/saurus/), [Lizhen Wang](https://lizhenwangt.github.io/), [Wen Liu](https://github.com/StevenLiuWen), [Zhenda Xie](https://zdaxie.github.io/), [Yebin Liu](https://liuyebin.com/) + + +Abstract: *We present DreamCraft3D, a hierarchical 3D content generation method that produces high-fidelity and coherent 3D objects. We tackle the problem by leveraging a 2D reference image to guide the stages of geometry sculpting and texture boosting. A central focus of this work is to address the consistency issue that existing +works encounter. To sculpt geometries that render coherently, we perform score +distillation sampling via a view-dependent diffusion model. This 3D prior, alongside several training strategies, prioritizes the geometry consistency but compromises the texture fidelity. We further propose **Bootstrapped Score Distillation** to +specifically boost the texture. We train a personalized diffusion model, Dreambooth, on the augmented renderings of the scene, imbuing it with 3D knowledge +of the scene being optimized. The score distillation from this 3D-aware diffusion prior provides view-consistent guidance for the scene. Notably, through an +alternating optimization of the diffusion prior and 3D scene representation, we +achieve mutually reinforcing improvements: the optimized 3D scene aids in training the scene-specific diffusion model, which offers increasingly view-consistent +guidance for 3D optimization. The optimization is thus bootstrapped and leads +to substantial texture boosting. With tailored 3D priors throughout the hierarchical generation, DreamCraft3D generates coherent 3D objects with photorealistic +renderings, advancing the state-of-the-art in 3D content generation.* + +

+ +

+ + +## Method Overview +

+ +

+ + + + +## Installation +### Install threestudio + +**This part is the same as original threestudio. Skip it if you already have installed the environment.** + +See [installation.md](docs/installation.md) for additional information, including installation via Docker. + +- You must have an NVIDIA graphics card with at least 20GB VRAM and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed. +- Install `Python >= 3.8`. +- (Optional, Recommended) Create a virtual environment: + +```sh +python3 -m virtualenv venv +. venv/bin/activate + +# Newer pip versions, e.g. pip-23.x, can be much faster than old versions, e.g. pip-20.x. +# For instance, it caches the wheels of git packages to avoid unnecessarily rebuilding them later. +python3 -m pip install --upgrade pip +``` + +- Install `PyTorch >= 1.12`. We have tested on `torch1.12.1+cu113` and `torch2.0.0+cu118`, but other versions should also work fine. + +```sh +# torch1.12.1+cu113 +pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 +# or torch2.0.0+cu118 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 +``` + +- (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions: + +```sh +pip install ninja +``` + +- Install dependencies: + +```sh +pip install -r requirements.txt +``` +## Quickstart +Our model is trained in multiple stages. You can run it by +```sh +prompt="a brightly colored mushroom growing on a log" +image_path="load/images/mushroom_log_rgba.png" + +# --------- Stage 1 (NeRF & NeuS) --------- # +python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" + +ckpt=outputs/dreamcraft3d-coarse-nerf/$prompt@LAST/ckpts/last.ckpt +python launch.py --config configs/dreamcraft3d-coarse-neus.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.weights="$ckpt" + +# --------- Stage 2 (Geometry Refinement) --------- # +ckpt=outputs/dreamcraft3d-coarse-neus/$prompt@LAST/ckpts/last.ckpt +python launch.py --config configs/dreamcraft3d-geometry.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt" + + +# --------- Stage 3 (Texture Refinement) --------- # +ckpt=outputs/dreamcraft3d-geometry/$prompt@LAST/ckpts/last.ckpt +python launch.py --config configs/dreamcraft3d-texture.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt" +``` + +
+[Optional] If the "Janus problem" arises in Stage 1, consider training a custom Text2Image model. + +First, generate multi-view images from a single reference image by Zero123++. + +```sh +python threestudio/scripts/img_to_mv.py --image_path 'load/mushroom.png' --save_path '.cache/temp' --prompt 'a photo of mushroom' --superres +``` +Train a personalized DeepFloyd model by DreamBooth Lora. Please check if the generated mv images above are reasonable. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" +export INSTANCE_DIR=".cache/temp" +export OUTPUT_DIR=".cache/if_dreambooth_mushroom" + +accelerate launch threestudio/scripts/train_dreambooth_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks mushroom" \ + --resolution=64 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --scale_lr \ + --max_train_steps=1200 \ + --checkpointing_steps=600 \ + --pre_compute_text_embeddings \ + --tokenizer_max_length=77 \ + --text_encoder_use_attention_mask +``` + +The personalized DeepFloyd model lora is save at `.cache/if_dreambooth_mushroom`. Now you can replace the guidance the training scripts by + +```sh +# --------- Stage 1 (NeRF & NeuS) --------- # +python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.guidance.lora_weights_path=".cache/if_dreambooth_mushroom" +``` +
+ +## Tips +- **Memory Usage**. We run the default configs on 40G A100 GPUs. For reducing memory usage, you can reduce the rendering resolution of NeuS by ```data.height=128 data.width=128 data.random_camera.height=128 data.random_camera.width=128```. You can also reduce resolution for other stages in the same way. + + +## Todo + +- [x] Release the reorganized code. +- [ ] Clean the original dreambooth training code. +- [ ] Provide some running results and checkpoints. + +## Credits +This code is built on the amazing open-source [threestudio-project](https://github.com/threestudio-project/threestudio). + +## Related links + +- [DreamFusion](https://dreamfusion3d.github.io/) +- [Magic3D](https://research.nvidia.com/labs/dir/magic3d/) +- [Make-it-3D](https://make-it-3d.github.io/) +- [Magic123](https://guochengqian.github.io/project/magic123/) +- [ProlificDreamer](https://ml.cs.tsinghua.edu.cn/prolificdreamer/) +- [DreamBooth](https://dreambooth.github.io/) + +## BibTeX + +```bibtex +@article{sun2023dreamcraft3d, + title={Dreamcraft3d: Hierarchical 3d generation with bootstrapped diffusion prior}, + author={Sun, Jingxiang and Zhang, Bo and Shao, Ruizhi and Wang, Lizhen and Liu, Wen and Xie, Zhenda and Liu, Yebin}, + journal={arXiv preprint arXiv:2310.16818}, + year={2023} +} +``` diff --git a/assets/diagram-1.png b/assets/diagram-1.png new file mode 100644 index 0000000..a4ca178 Binary files /dev/null and b/assets/diagram-1.png differ diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000..d4f5598 Binary files /dev/null and b/assets/logo.png differ diff --git a/assets/repo_demo_0.mp4 b/assets/repo_demo_0.mp4 new file mode 100644 index 0000000..92b2d75 Binary files /dev/null and b/assets/repo_demo_0.mp4 differ diff --git a/assets/repo_demo_01.mp4 b/assets/repo_demo_01.mp4 new file mode 100644 index 0000000..8334b69 Binary files /dev/null and b/assets/repo_demo_01.mp4 differ diff --git a/assets/repo_demo_02.mp4 b/assets/repo_demo_02.mp4 new file mode 100644 index 0000000..8cc94b2 Binary files /dev/null and b/assets/repo_demo_02.mp4 differ diff --git a/assets/repo_static_v2.png b/assets/repo_static_v2.png new file mode 100644 index 0000000..30d8d30 Binary files /dev/null and b/assets/repo_static_v2.png differ diff --git a/assets/result_mushroom.mp4 b/assets/result_mushroom.mp4 new file mode 100644 index 0000000..1d323ba Binary files /dev/null and b/assets/result_mushroom.mp4 differ diff --git a/configs/dreamcraft3d-coarse-nerf.yaml b/configs/dreamcraft3d-coarse-nerf.yaml new file mode 100644 index 0000000..5ff4c8e --- /dev/null +++ b/configs/dreamcraft3d-coarse-nerf.yaml @@ -0,0 +1,159 @@ +name: "dreamcraft3d-coarse-nerf" +tag: "${rmspace:${system.prompt_processor.prompt},_}" +exp_root_dir: "outputs" +seed: 0 + +data_type: "single-image-datamodule" +data: + image_path: ./load/images/hamburger_rgba.png + height: [128, 384] + width: [128, 384] + resolution_milestones: [3000] + default_elevation_deg: 0.0 + default_azimuth_deg: 0.0 + default_camera_distance: 3.8 + default_fovy_deg: 20.0 + requires_depth: true + requires_normal: ${cmaxgt0:${system.loss.lambda_normal}} + random_camera: + height: [128, 384] + width: [128, 384] + batch_size: [1, 1] + resolution_milestones: [3000] + eval_height: 512 + eval_width: 512 + eval_batch_size: 1 + elevation_range: [-10, 45] + azimuth_range: [-180, 180] + camera_distance_range: [3.8, 3.8] + fovy_range: [20.0, 20.0] # Zero123 has fixed fovy + progressive_until: 200 + camera_perturb: 0.0 + center_perturb: 0.0 + up_perturb: 0.0 + eval_elevation_deg: ${data.default_elevation_deg} + eval_camera_distance: ${data.default_camera_distance} + eval_fovy_deg: ${data.default_fovy_deg} + batch_uniform_azimuth: false + n_val_views: 40 + n_test_views: 120 + +system_type: "dreamcraft3d-system" +system: + stage: coarse + geometry_type: "implicit-volume" + geometry: + radius: 2.0 + normal_type: "finite_difference" + + # the density initialization proposed in the DreamFusion paper + # does not work very well + # density_bias: "blob_dreamfusion" + # density_activation: exp + # density_blob_scale: 5. + # density_blob_std: 0.2 + + # use Magic3D density initialization instead + density_bias: "blob_magic3d" + density_activation: softplus + density_blob_scale: 10. + density_blob_std: 0.5 + + # coarse to fine hash grid encoding + # to ensure smooth analytic normals + pos_encoding_config: + otype: ProgressiveBandHashGrid + n_levels: 16 + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 16 + per_level_scale: 1.447269237440378 # max resolution 4096 + start_level: 8 # resolution ~200 + start_step: 2000 + update_steps: 500 + + material_type: "no-material" + material: + requires_normal: true + + background_type: "solid-color-background" + + renderer_type: "nerf-volume-renderer" + renderer: + radius: ${system.geometry.radius} + num_samples_per_ray: 512 + return_normal_perturb: true + return_comp_normal: ${cmaxgt0:${system.loss.lambda_normal_smooth}} + + prompt_processor_type: "deep-floyd-prompt-processor" + prompt_processor: + pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" + prompt: ??? + use_perp_neg: true + + guidance_type: "deep-floyd-guidance" + guidance: + pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" + guidance_scale: 20 + min_step_percent: [0, 0.7, 0.2, 200] + max_step_percent: [0, 0.85, 0.5, 200] + + guidance_3d_type: "stable-zero123-guidance" + guidance_3d: + pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" + pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" + cond_image_path: ${data.image_path} + cond_elevation_deg: ${data.default_elevation_deg} + cond_azimuth_deg: ${data.default_azimuth_deg} + cond_camera_distance: ${data.default_camera_distance} + guidance_scale: 5.0 + min_step_percent: [0, 0.7, 0.2, 200] # (start_iter, start_val, end_val, end_iter) + max_step_percent: [0, 0.85, 0.5, 200] + + freq: + n_ref: 2 + ref_only_steps: 0 + ref_or_guidance: "alternate" + no_diff_steps: 0 + guidance_eval: 0 + + loggers: + wandb: + enable: false + project: "threestudio" + + loss: + lambda_sd: 0.1 + lambda_3d_sd: 0.1 + lambda_rgb: 1000.0 + lambda_mask: 100.0 + lambda_mask_binary: 0.0 + lambda_depth: 0.0 + lambda_depth_rel: 0.05 + lambda_normal: 0.0 + lambda_normal_smooth: 1.0 + lambda_3d_normal_smooth: [2000, 5., 1., 2001] + lambda_orient: [2000, 1., 10., 2001] + lambda_sparsity: [2000, 0.1, 10., 2001] + lambda_opaque: [2000, 0.1, 10., 2001] + lambda_clip: 0.0 + + optimizer: + name: Adam + args: + lr: 0.01 + betas: [0.9, 0.99] + eps: 1.e-8 + +trainer: + max_steps: 5000 + log_every_n_steps: 1 + num_sanity_val_steps: 0 + val_check_interval: 200 + enable_progress_bar: true + precision: 16-mixed + +checkpoint: + save_last: true + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} \ No newline at end of file diff --git a/configs/dreamcraft3d-coarse-neus.yaml b/configs/dreamcraft3d-coarse-neus.yaml new file mode 100644 index 0000000..d2e6136 --- /dev/null +++ b/configs/dreamcraft3d-coarse-neus.yaml @@ -0,0 +1,155 @@ +name: "dreamcraft3d-coarse-neus" +tag: "${rmspace:${system.prompt_processor.prompt},_}" +exp_root_dir: "outputs" +seed: 0 + +data_type: "single-image-datamodule" +data: + image_path: ./load/images/hamburger_rgba.png + height: 256 + width: 256 + default_elevation_deg: 0.0 + default_azimuth_deg: 0.0 + default_camera_distance: 3.8 + default_fovy_deg: 20.0 + requires_depth: true + requires_normal: ${cmaxgt0:${system.loss.lambda_normal}} + random_camera: + height: 256 + width: 256 + batch_size: 1 + eval_height: 512 + eval_width: 512 + eval_batch_size: 1 + elevation_range: [-10, 45] + azimuth_range: [-180, 180] + camera_distance_range: [3.8, 3.8] + fovy_range: [20.0, 20.0] # Zero123 has fixed fovy + progressive_until: 0 + camera_perturb: 0.0 + center_perturb: 0.0 + up_perturb: 0.0 + eval_elevation_deg: ${data.default_elevation_deg} + eval_camera_distance: ${data.default_camera_distance} + eval_fovy_deg: ${data.default_fovy_deg} + batch_uniform_azimuth: false + n_val_views: 40 + n_test_views: 120 + +system_type: "dreamcraft3d-system" +system: + stage: coarse + geometry_type: "implicit-sdf" + geometry: + radius: 2.0 + normal_type: "finite_difference" + + sdf_bias: sphere + sdf_bias_params: 0.5 + + # coarse to fine hash grid encoding + pos_encoding_config: + otype: HashGrid + n_levels: 16 + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 16 + per_level_scale: 1.447269237440378 # max resolution 4096 + start_level: 8 # resolution ~200 + start_step: 2000 + update_steps: 500 + + material_type: "no-material" + material: + requires_normal: true + + background_type: "solid-color-background" + + renderer_type: "neus-volume-renderer" + renderer: + radius: ${system.geometry.radius} + num_samples_per_ray: 512 + cos_anneal_end_steps: ${trainer.max_steps} + eval_chunk_size: 8192 + + prompt_processor_type: "deep-floyd-prompt-processor" + prompt_processor: + pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" + prompt: ??? + use_perp_neg: true + + guidance_type: "deep-floyd-guidance" + guidance: + pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" + guidance_scale: 20 + min_step_percent: 0.2 + max_step_percent: 0.5 + + guidance_3d_type: "stable-zero123-guidance" + guidance_3d: + pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" + pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" + cond_image_path: ${data.image_path} + cond_elevation_deg: ${data.default_elevation_deg} + cond_azimuth_deg: ${data.default_azimuth_deg} + cond_camera_distance: ${data.default_camera_distance} + guidance_scale: 5.0 + min_step_percent: 0.2 + max_step_percent: 0.5 + + freq: + n_ref: 2 + ref_only_steps: 0 + ref_or_guidance: "alternate" + no_diff_steps: 0 + guidance_eval: 0 + + loggers: + wandb: + enable: false + project: "threestudio" + + loss: + lambda_sd: 0.1 + lambda_3d_sd: 0.1 + lambda_rgb: 1000.0 + lambda_mask: 100.0 + lambda_mask_binary: 0.0 + lambda_depth: 0.0 + lambda_depth_rel: 0.05 + lambda_normal: 0.0 + lambda_normal_smooth: 0.0 + lambda_3d_normal_smooth: 0.0 + lambda_orient: 10.0 + lambda_sparsity: 0.1 + lambda_opaque: 0.1 + lambda_clip: 0.0 + lambda_eikonal: 0.0 + + optimizer: + name: Adam + args: + betas: [0.9, 0.99] + eps: 1.e-15 + params: + geometry.encoding: + lr: 0.01 + geometry.sdf_network: + lr: 0.001 + geometry.feature_network: + lr: 0.001 + renderer: + lr: 0.001 + +trainer: + max_steps: 5000 + log_every_n_steps: 1 + num_sanity_val_steps: 0 + val_check_interval: 200 + enable_progress_bar: true + precision: 16-mixed + +checkpoint: + save_last: true + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} \ No newline at end of file diff --git a/configs/dreamcraft3d-geometry.yaml b/configs/dreamcraft3d-geometry.yaml new file mode 100644 index 0000000..2ceb7ae --- /dev/null +++ b/configs/dreamcraft3d-geometry.yaml @@ -0,0 +1,133 @@ +name: "dreamcraft3d-geometry" +tag: "${rmspace:${system.prompt_processor.prompt},_}" +exp_root_dir: "outputs" +seed: 0 + +data_type: "single-image-datamodule" +data: + image_path: ./load/images/hamburger_rgba.png + height: 1024 + width: 1024 + default_elevation_deg: 0.0 + default_azimuth_deg: 0.0 + default_camera_distance: 3.8 + default_fovy_deg: 20.0 + requires_depth: ${cmaxgt0orcmaxgt0:${system.loss.lambda_depth},${system.loss.lambda_depth_rel}} + requires_normal: ${cmaxgt0:${system.loss.lambda_normal}} + use_mixed_camera_config: false + random_camera: + height: 1024 + width: 1024 + batch_size: 1 + eval_height: 1024 + eval_width: 1024 + eval_batch_size: 1 + elevation_range: [-10, 45] + azimuth_range: [-180, 180] + camera_distance_range: [3.8, 3.8] + fovy_range: [20.0, 20.0] # Zero123 has fixed fovy + progressive_until: 0 + camera_perturb: 0.0 + center_perturb: 0.0 + up_perturb: 0.0 + eval_elevation_deg: ${data.default_elevation_deg} + eval_camera_distance: ${data.default_camera_distance} + eval_fovy_deg: ${data.default_fovy_deg} + batch_uniform_azimuth: false + n_val_views: 40 + n_test_views: 120 + +system_type: "dreamcraft3d-system" +system: + stage: geometry + use_mixed_camera_config: ${data.use_mixed_camera_config} + geometry_convert_from: ??? + geometry_convert_inherit_texture: true + geometry_type: "tetrahedra-sdf-grid" + geometry: + radius: 2.0 # consistent with coarse + isosurface_resolution: 128 + isosurface_deformable_grid: true + + material_type: "no-material" + material: + n_output_dims: 3 + + background_type: "solid-color-background" + + renderer_type: "nvdiff-rasterizer" + renderer: + context_type: cuda + + prompt_processor_type: "deep-floyd-prompt-processor" + prompt_processor: + pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" + prompt: ??? + use_perp_neg: true + + guidance_type: "deep-floyd-guidance" + guidance: + pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" + guidance_scale: 20 + min_step_percent: 0.02 + max_step_percent: 0.5 + + guidance_3d_type: "stable-zero123-guidance" + guidance_3d: + pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" + pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" + cond_image_path: ${data.image_path} + cond_elevation_deg: ${data.default_elevation_deg} + cond_azimuth_deg: ${data.default_azimuth_deg} + cond_camera_distance: ${data.default_camera_distance} + guidance_scale: 5.0 + min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter) + max_step_percent: 0.5 + + freq: + n_ref: 2 + ref_only_steps: 0 + ref_or_guidance: "accumulate" + no_diff_steps: 0 + guidance_eval: 0 + n_rgb: 4 + + loggers: + wandb: + enable: false + project: "threestudio" + + loss: + lambda_sd: 0.1 + lambda_3d_sd: 0.1 + lambda_rgb: 1000.0 + lambda_mask: 100.0 + lambda_mask_binary: 0.0 + lambda_depth: 0.0 + lambda_depth_rel: 0.0 + lambda_normal: 0.0 + lambda_normal_smooth: 0. + lambda_3d_normal_smooth: 0. + lambda_normal_consistency: [1000,10.0,1,2000] + lambda_laplacian_smoothness: 0.0 + + optimizer: + name: Adam + args: + lr: 0.005 + betas: [0.9, 0.99] + eps: 1.e-15 + +trainer: + max_steps: 5000 + log_every_n_steps: 1 + num_sanity_val_steps: 0 + val_check_interval: 200 + enable_progress_bar: true + precision: 32 + strategy: "ddp_find_unused_parameters_true" + +checkpoint: + save_last: true + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} \ No newline at end of file diff --git a/configs/dreamcraft3d-texture.yaml b/configs/dreamcraft3d-texture.yaml new file mode 100644 index 0000000..7868f01 --- /dev/null +++ b/configs/dreamcraft3d-texture.yaml @@ -0,0 +1,166 @@ +name: "dreamcraft3d-texture" +tag: "${rmspace:${system.prompt_processor.prompt},_}" +exp_root_dir: "outputs" +seed: 0 + +data_type: "single-image-datamodule" +data: + image_path: ./load/images/hamburger_rgba.png + height: 1024 + width: 1024 + default_elevation_deg: 0.0 + default_azimuth_deg: 0.0 + default_camera_distance: 3.8 + default_fovy_deg: 20.0 + requires_depth: false + requires_normal: false + use_mixed_camera_config: false + random_camera: + height: 1024 + width: 1024 + batch_size: 1 + eval_height: 1024 + eval_width: 1024 + eval_batch_size: 1 + elevation_range: [-10, 45] + azimuth_range: [-180, 180] + camera_distance_range: [3.8, 3.8] + fovy_range: [20.0, 20.0] # Zero123 has fixed fovy + progressive_until: 0 + camera_perturb: 0.0 + center_perturb: 0.0 + up_perturb: 0.0 + eval_elevation_deg: ${data.default_elevation_deg} + eval_camera_distance: ${data.default_camera_distance} + eval_fovy_deg: ${data.default_fovy_deg} + batch_uniform_azimuth: false + n_val_views: 40 + n_test_views: 120 + +system_type: "dreamcraft3d-system" +system: + stage: texture + use_mixed_camera_config: ${data.use_mixed_camera_config} + geometry_convert_from: ??? + geometry_convert_inherit_texture: true + geometry_type: "tetrahedra-sdf-grid" + geometry: + radius: 2.0 # consistent with coarse + isosurface_resolution: 128 + isosurface_deformable_grid: true + isosurface_remove_outliers: true + pos_encoding_config: + otype: HashGrid + n_levels: 16 + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 16 + per_level_scale: 1.447269237440378 # max resolution 4096 + fix_geometry: true + + material_type: "no-material" + material: + n_output_dims: 3 + + background_type: "solid-color-background" + + renderer_type: "nvdiff-rasterizer" + renderer: + context_type: cuda + + prompt_processor_type: "stable-diffusion-prompt-processor" + prompt_processor: + pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" + prompt: ??? + front_threshold: 30. + back_threshold: 30. + + guidance_type: "stable-diffusion-bsd-guidance" + guidance: + pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" + pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1-base" + # pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1" + guidance_scale: 2.0 + min_step_percent: 0.05 + max_step_percent: [0, 0.5, 0.2, 5000] + only_pretrain_step: 1000 + + # guidance_3d_type: "stable-zero123-guidance" + # guidance_3d: + # pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" + # pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" + # cond_image_path: ${data.image_path} + # cond_elevation_deg: ${data.default_elevation_deg} + # cond_azimuth_deg: ${data.default_azimuth_deg} + # cond_camera_distance: ${data.default_camera_distance} + # guidance_scale: 5.0 + # min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter) + # max_step_percent: 0.5 + + # control_guidance_type: "stable-diffusion-controlnet-reg-guidance" + # control_guidance: + # min_step_percent: 0.1 + # max_step_percent: 0.5 + # control_prompt_processor_type: "stable-diffusion-prompt-processor" + # control_prompt_processor: + # pretrained_model_name_or_path: "SG161222/Realistic_Vision_V2.0" + # prompt: ${system.prompt_processor.prompt} + # front_threshold: 30. + # back_threshold: 30. + + freq: + n_ref: 2 + ref_only_steps: 0 + ref_or_guidance: "alternate" + no_diff_steps: -1 + guidance_eval: 0 + + loggers: + wandb: + enable: false + project: "threestudio" + + loss: + lambda_sd: 0.01 + lambda_lora: 0.1 + lambda_pretrain: 0.1 + lambda_3d_sd: 0.0 + lambda_rgb: 1000. + lambda_mask: 100. + lambda_mask_binary: 0.0 + lambda_depth: 0.0 + lambda_depth_rel: 0.0 + lambda_normal: 0.0 + lambda_normal_smooth: 0.0 + lambda_3d_normal_smooth: 0.0 + lambda_z_variance: 0.0 + lambda_reg: 0.0 + + optimizer: + name: AdamW + args: + betas: [0.9, 0.99] + eps: 1.e-4 + params: + geometry.encoding: + lr: 0.01 + geometry.feature_network: + lr: 0.001 + guidance.train_unet: + lr: 0.00001 + guidance.train_unet_lora: + lr: 0.00001 + +trainer: + max_steps: 5000 + log_every_n_steps: 1 + num_sanity_val_steps: 0 + val_check_interval: 200 + enable_progress_bar: true + precision: 32 + strategy: "ddp_find_unused_parameters_true" + +checkpoint: + save_last: true + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..c4a8940 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,60 @@ +# Reference: +# https://github.com/cvpaperchallenge/Ascender +# https://github.com/nerfstudio-project/nerfstudio + +FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 + +ARG USER_NAME=dreamer +ARG GROUP_NAME=dreamers +ARG UID=1000 +ARG GID=1000 + +# Set compute capability for nerfacc and tiny-cuda-nn +# See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build +ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX" +ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60 +# Speed-up build for RTX 30xx +# ENV TORCH_CUDA_ARCH_LIST="8.6" +# ENV TCNN_CUDA_ARCHITECTURES=86 +# Speed-up build for RTX 40xx +# ENV TORCH_CUDA_ARCH_LIST="8.9" +# ENV TCNN_CUDA_ARCHITECTURES=89 + +ENV CUDA_HOME=/usr/local/cuda +ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH} +ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} +ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH} + +# apt install by root user +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libegl1-mesa-dev \ + libgl1-mesa-dev \ + libgles2-mesa-dev \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrender1 \ + python-is-python3 \ + python3.10-dev \ + python3-pip \ + wget \ + && rm -rf /var/lib/apt/lists/* + +# Change user to non-root user +RUN groupadd -g ${GID} ${GROUP_NAME} \ + && useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME} +USER ${USER_NAME} + +RUN pip install --upgrade pip setuptools ninja +RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +# Install nerfacc and tiny-cuda-nn before installing requirements.txt +# because these two installations are time consuming and error prone +RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2 +RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch + +COPY requirements.txt /tmp +RUN cd /tmp && pip install -r requirements.txt +WORKDIR /home/${USER_NAME}/threestudio diff --git a/docker/compose.yaml b/docker/compose.yaml new file mode 100644 index 0000000..b15559a --- /dev/null +++ b/docker/compose.yaml @@ -0,0 +1,23 @@ +services: + threestudio: + build: + context: ../ + dockerfile: docker/Dockerfile + args: + # you can set environment variables, otherwise default values will be used + USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER + GROUP_NAME: ${HOST_GROUP_NAME:-dreamers} + UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u) + GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g) + shm_size: '4gb' + environment: + NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error` + tty: true + volumes: + - ../:/home/${HOST_USER_NAME:-dreamer}/threestudio + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 0000000..177221c --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,59 @@ +# Installation + +## Prerequisite + +- NVIDIA GPU with at least 6GB VRAM. The more memory you have, the more methods and higher resolutions you can try. +- [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) whose version is higher than the [Minimum Required Driver Version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html) of CUDA Toolkit you want to use. + +## Install CUDA Toolkit + +You can skip this step if you have installed sufficiently new version or you use Docker. + +Install [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive). + +- Example for Ubuntu 22.04: + - Run [command for CUDA 11.8 Ubuntu 22.04](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local) +- Example for Ubuntu on WSL2: + - `sudo apt-key del 7fa2af80` + - Run [command for CUDA 11.8 WSL-Ubuntu](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local) + +## Git Clone + +```bash +git clone https://github.com/threestudio-project/threestudio.git +cd threestudio/ +``` + +## Install threestudio via Docker + +1. [Install Docker Engine](https://docs.docker.com/engine/install/). + This document assumes you [install Docker Engine on Ubuntu](https://docs.docker.com/engine/install/ubuntu/). +2. [Create `docker` group](https://docs.docker.com/engine/install/linux-postinstall/). + Otherwise, you need to type `sudo docker` instead of `docker`. +3. [Install NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#setting-up-nvidia-container-toolkit). +4. If you use WSL2, [enable systemd](https://learn.microsoft.com/en-us/windows/wsl/wsl-config#systemd-support). +5. Edit [Dockerfile](../docker/Dockerfile) for your GPU to speed-up build. + The default Dockerfile takes into account many types of GPUs. +6. Run Docker via `docker compose`. + +```bash +cd docker/ +docker compose build # build Docker image +docker compose up -d # create and start a container in background +docker compose exec threestudio bash # run bash in the container + +# Enjoy threestudio! + +exit # or Ctrl+D +docker compose stop # stop the container +docker compose start # start the container +docker compose down # stop and remove the container +``` + +Note: The current Dockerfile will cause errors when using the OpenGL-based rasterizer of nvdiffrast. +You can use the CUDA-based rasterizer by adding commands or editing configs. + +- `system.renderer.context_type=cuda` for training +- `system.exporter.context_type=cuda` for exporting meshes + +[This comment by the nvdiffrast author](https://github.com/NVlabs/nvdiffrast/issues/94#issuecomment-1288566038) could be a guide to resolve this limitation. diff --git a/extern/MVDream b/extern/MVDream new file mode 160000 index 0000000..853c51b --- /dev/null +++ b/extern/MVDream @@ -0,0 +1 @@ +Subproject commit 853c51b5575e179b25d3aef3d9dbdff950e922ee diff --git a/extern/One-2-3-45 b/extern/One-2-3-45 new file mode 160000 index 0000000..ea88568 --- /dev/null +++ b/extern/One-2-3-45 @@ -0,0 +1 @@ +Subproject commit ea885683ee1a5ad93ba369057dc3d71b7a5ae061 diff --git a/extern/__init__.py b/extern/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/extern/ldm_zero123/extras.py b/extern/ldm_zero123/extras.py new file mode 100755 index 0000000..1646d46 --- /dev/null +++ b/extern/ldm_zero123/extras.py @@ -0,0 +1,78 @@ +import logging +from contextlib import contextmanager +from pathlib import Path + +import torch +from omegaconf import OmegaConf + +from extern.ldm_zero123.util import instantiate_from_config + + +@contextmanager +def all_logging_disabled(highest_level=logging.CRITICAL): + """ + A context manager that will prevent any logging messages + triggered during the body from being processed. + + :param highest_level: the maximum logging level in use. + This would only need to be changed if a custom level greater than CRITICAL + is defined. + + https://gist.github.com/simon-weber/7853144 + """ + # two kind-of hacks here: + # * can't get the highest logging level in effect => delegate to the user + # * can't get the current module-level override => use an undocumented + # (but non-private!) interface + + previous_level = logging.root.manager.disable + + logging.disable(highest_level) + + try: + yield + finally: + logging.disable(previous_level) + + +def load_training_dir(train_dir, device, epoch="last"): + """Load a checkpoint and config from training directory""" + train_dir = Path(train_dir) + ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) + assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" + config = list(train_dir.rglob(f"*-project.yaml")) + assert len(ckpt) > 0, f"didn't find any config in {train_dir}" + if len(config) > 1: + print(f"found {len(config)} matching config files") + config = sorted(config)[-1] + print(f"selecting {config}") + else: + config = config[0] + + config = OmegaConf.load(config) + return load_model_from_config(config, ckpt[0], device) + + +def load_model_from_config(config, ckpt, device="cpu", verbose=False): + """Loads a model from config and a ckpt + if config is a path will use omegaconf to load + """ + if isinstance(config, (str, Path)): + config = OmegaConf.load(config) + + with all_logging_disabled(): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + model.to(device) + model.eval() + model.cond_stage_model.device = device + return model diff --git a/extern/ldm_zero123/guidance.py b/extern/ldm_zero123/guidance.py new file mode 100755 index 0000000..a52a755 --- /dev/null +++ b/extern/ldm_zero123/guidance.py @@ -0,0 +1,110 @@ +import abc +from typing import List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import torch +from IPython.display import clear_output +from scipy import interpolate + + +class GuideModel(torch.nn.Module, abc.ABC): + def __init__(self) -> None: + super().__init__() + + @abc.abstractmethod + def preprocess(self, x_img): + pass + + @abc.abstractmethod + def compute_loss(self, inp): + pass + + +class Guider(torch.nn.Module): + def __init__(self, sampler, guide_model, scale=1.0, verbose=False): + """Apply classifier guidance + + Specify a guidance scale as either a scalar + Or a schedule as a list of tuples t = 0->1 and scale, e.g. + [(0, 10), (0.5, 20), (1, 50)] + """ + super().__init__() + self.sampler = sampler + self.index = 0 + self.show = verbose + self.guide_model = guide_model + self.history = [] + + if isinstance(scale, (Tuple, List)): + times = np.array([x[0] for x in scale]) + values = np.array([x[1] for x in scale]) + self.scale_schedule = {"times": times, "values": values} + else: + self.scale_schedule = float(scale) + + self.ddim_timesteps = sampler.ddim_timesteps + self.ddpm_num_timesteps = sampler.ddpm_num_timesteps + + def get_scales(self): + if isinstance(self.scale_schedule, float): + return len(self.ddim_timesteps) * [self.scale_schedule] + + interpolater = interpolate.interp1d( + self.scale_schedule["times"], self.scale_schedule["values"] + ) + fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps + return interpolater(fractional_steps) + + def modify_score(self, model, e_t, x, t, c): + # TODO look up index by t + scale = self.get_scales()[self.index] + + if scale == 0: + return e_t + + sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) + x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0) + + inp = self.guide_model.preprocess(x_img) + loss = self.guide_model.compute_loss(inp) + grads = torch.autograd.grad(loss.sum(), x_in)[0] + correction = grads * scale + + if self.show: + clear_output(wait=True) + print( + loss.item(), + scale, + correction.abs().max().item(), + e_t.abs().max().item(), + ) + self.history.append( + [ + loss.item(), + scale, + correction.min().item(), + correction.max().item(), + ] + ) + plt.imshow( + (inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2 + ) + plt.axis("off") + plt.show() + plt.imshow(correction[0][0].detach().cpu()) + plt.axis("off") + plt.show() + + e_t_mod = e_t - sqrt_1ma * correction + if self.show: + fig, axs = plt.subplots(1, 3) + axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) + axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) + axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) + plt.show() + self.index += 1 + return e_t_mod diff --git a/extern/ldm_zero123/lr_scheduler.py b/extern/ldm_zero123/lr_scheduler.py new file mode 100755 index 0000000..b2f4d38 --- /dev/null +++ b/extern/ldm_zero123/lr_scheduler.py @@ -0,0 +1,135 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + + def __init__( + self, + warm_up_steps, + lr_min, + lr_max, + lr_start, + max_decay_steps, + verbosity_interval=0, + ): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0.0 + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = ( + self.lr_max - self.lr_start + ) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / ( + self.lr_max_decay_steps - self.lr_warm_up_steps + ) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi) + ) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + + def __init__( + self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 + ): + assert ( + len(warm_up_steps) + == len(f_min) + == len(f_max) + == len(f_start) + == len(cycle_lengths) + ) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi) + ) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( + self.cycle_lengths[cycle] - n + ) / (self.cycle_lengths[cycle]) + self.last_f = f + return f diff --git a/extern/ldm_zero123/models/autoencoder.py b/extern/ldm_zero123/models/autoencoder.py new file mode 100755 index 0000000..a6c16b3 --- /dev/null +++ b/extern/ldm_zero123/models/autoencoder.py @@ -0,0 +1,551 @@ +from contextlib import contextmanager + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from extern.ldm_zero123.modules.diffusionmodules.model import Decoder, Encoder +from extern.ldm_zero123.modules.distributions.distributions import ( + DiagonalGaussianDistribution, +) +from extern.ldm_zero123.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer( + n_embed, + embed_dim, + beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape, + ) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print( + f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}." + ) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_, _, ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice( + np.arange(lower_size, upper_size + 16, 16) + ) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + predicted_indices=ind, + ) + + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True + ) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True + ) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=ind, + ) + + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=ind, + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log( + f"val{suffix}/rec_loss", + rec_loss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + self.log( + f"val{suffix}/aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + if version.parse(pl.__version__) >= version.parse("1.4.0"): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor * self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr_g, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9) + ) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + }, + { + "scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val", + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val", + ) + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/extern/ldm_zero123/models/diffusion/__init__.py b/extern/ldm_zero123/models/diffusion/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/extern/ldm_zero123/models/diffusion/classifier.py b/extern/ldm_zero123/models/diffusion/classifier.py new file mode 100755 index 0000000..40467e9 --- /dev/null +++ b/extern/ldm_zero123/models/diffusion/classifier.py @@ -0,0 +1,319 @@ +import os +from copy import deepcopy +from glob import glob + +import pytorch_lightning as pl +import torch +from einops import rearrange +from natsort import natsorted +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR + +from extern.ldm_zero123.modules.diffusionmodules.openaimodel import ( + EncoderUNetModel, + UNetModel, +) +from extern.ldm_zero123.util import ( + default, + instantiate_from_config, + ismap, + log_txt_as_img, +) + +__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + def __init__( + self, + diffusion_path, + num_classes, + ckpt_path=None, + pool="attention", + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.0e-2, + log_steps=10, + monitor="val/loss", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted( + glob(os.path.join(diffusion_path, "configs", "*-project.yaml")) + )[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = ( + label_key + if not hasattr(self.diffusion_model, "cond_stage_key") + else self.diffusion_model.cond_stage_key + ) + + assert ( + self.label_key is not None + ), "label_key neither in diffusion model nor in model.params" + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = ( + self.diffusion_config.params.unet_config.params.out_channels + ) + model_config.out_channels = self.num_classes + if self.label_key == "class_label": + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print( + "#####################################################################" + ) + print(f'load from ckpt "{ckpt_path}"') + print( + "#####################################################################" + ) + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = ( + self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + ) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample( + x_start=x, + t=t, + noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod, + ) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, "b h w c -> b c h w") + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, "Needs to provide label key" + + targets = batch[k].to(self.device) + + if self.label_key == "segmentation": + targets = rearrange(targets, "b h w c -> b c h w") + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest") + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to("cpu") + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = "train" if self.training else "val" + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict( + log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True + ) + self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log( + "global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True + ) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input( + batch, k=self.diffusion_model.first_stage_key + ) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint( + 0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device + ).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction="none") + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = { + t: {"acc@1": [], "acc@5": []} + for t in range( + 0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t + ) + } + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]["acc@1"].append( + self.compute_top_k(logits, targets, k=1, reduction="mean") + ) + self.noisy_acc[t]["acc@5"].append( + self.compute_top_k(logits, targets, k=5, reduction="mean") + ) + + return loss + + def configure_optimizers(self): + optimizer = AdamW( + self.model.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay, + ) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } + ] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log["inputs"] = x + + y = self.get_conditioning(batch) + + if self.label_key == "class_label": + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log["labels"] = y + + if ismap(y): + log["labels"] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f"inputs@t{current_time}"] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, "b h w c -> b c h w") + + log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/extern/ldm_zero123/models/diffusion/ddim.py b/extern/ldm_zero123/models/diffusion/ddim.py new file mode 100755 index 0000000..a609ee8 --- /dev/null +++ b/extern/ldm_zero123/models/diffusion/ddim.py @@ -0,0 +1,488 @@ +"""SAMPLING ONLY.""" + +from functools import partial + +import numpy as np +import torch +from tqdm import tqdm + +from extern.ldm_zero123.models.diffusion.sampling_util import ( + norm_thresholding, + renorm_thresholding, + spatial_norm_thresholding, +) +from extern.ldm_zero123.modules.diffusionmodules.util import ( + extract_into_tensor, + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def to(self, device): + """Same as to in torch module + Don't really underestand why this isn't a module in the first place""" + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + new_v = getattr(self, k).to(device) + setattr(self, k, new_v) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + t_start=-1, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + timesteps = timesteps[:t_start] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + img, pred_x0 = outs + if callback: + img = callback(i, img, pred_x0) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + print(t, sqrt_one_minus_at, a_t) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + num_reference_steps = ( + self.ddpm_num_timesteps + if use_original_steps + else self.ddim_timesteps.shape[0] + ) + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc="Encoding Image"): + t = torch.full( + (x0.shape[0],), i, device=self.model.device, dtype=torch.long + ) + if unconditional_guidance_scale == 1.0: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model( + torch.cat((x_next, x_next)), + torch.cat((t, t)), + torch.cat((unconditional_conditioning, c)), + ), + 2, + ) + noise_pred = e_t_uncond + unconditional_guidance_scale * ( + noise_pred - e_t_uncond + ) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = ( + alphas_next[i].sqrt() + * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) + * noise_pred + ) + x_next = xt_weighted + weighted_noise_pred + if ( + return_intermediates + and i % (num_steps // return_intermediates) == 0 + and i < num_steps - 1 + ): + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} + if return_intermediates: + out.update({"intermediates": intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return x_dec diff --git a/extern/ldm_zero123/models/diffusion/ddpm.py b/extern/ldm_zero123/models/diffusion/ddpm.py new file mode 100755 index 0000000..590b692 --- /dev/null +++ b/extern/ldm_zero123/models/diffusion/ddpm.py @@ -0,0 +1,2689 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import itertools +from contextlib import contextmanager, nullcontext +from functools import partial + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import ListConfig +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from torch.optim.lr_scheduler import LambdaLR +from torchvision.utils import make_grid +from tqdm import tqdm + +from extern.ldm_zero123.models.autoencoder import ( + AutoencoderKL, + IdentityFirstStage, + VQModelInterface, +) +from extern.ldm_zero123.models.diffusion.ddim import DDIMSampler +from extern.ldm_zero123.modules.attention import CrossAttention +from extern.ldm_zero123.modules.diffusionmodules.util import ( + extract_into_tensor, + make_beta_schedule, + noise_like, +) +from extern.ldm_zero123.modules.distributions.distributions import ( + DiagonalGaussianDistribution, + normal_kl, +) +from extern.ldm_zero123.modules.ema import LitEma +from extern.ldm_zero123.util import ( + count_params, + default, + exists, + instantiate_from_config, + isimage, + ismap, + log_txt_as_img, + mean_flat, +) + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + make_it_fit=False, + ucg_training=None, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print( + f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + ) + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if ckpt_path is not None: + self.init_from_ckpt( + ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet + ) + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + + if self.make_it_fit: + n_params = len( + [ + name + for name, _ in itertools.chain( + self.named_parameters(), self.named_buffers() + ) + ] + ) + for name, param in tqdm( + itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params, + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[ + i % old_shape[0], j % old_shape[1] + ] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="Sampling t", + total=self.num_timesteps, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates, + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == "l1": + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == "l2": + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError( + f"Paramterization {self.parameterization} not yet supported" + ) + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = "train" if self.training else "val" + + loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f"{log_prefix}/loss": loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, "b h w c -> b c h w") + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + + self.log( + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict( + loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True + ) + self.log_dict( + loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True + ) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample( + batch_size=N, return_intermediates=True + ) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + unet_trainable=True, + *args, + **kwargs, + ): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs["timesteps"] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__": + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.unet_trainable = unet_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + + # construct linear projection layer for concatenating image CLIP embedding and RT + self.cc_projection = nn.Linear(772, 768) + nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768]) + nn.init.zeros_(list(self.cc_projection.parameters())[1]) + self.cc_projection.requires_grad_(True) + + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if ( + self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + assert ( + self.scale_factor == 1.0 + ), "rather not use custom rescaling and std-rescaling simultaneously" + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer("scale_factor", 1.0 / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != "__is_first_stage__" + assert config != "__is_unconditional__" + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list( + self, samples, desc="", force_no_decoder_quantization=False + ): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append( + self.decode_first_stage( + zd.to(self.device), force_not_quantize=force_no_decoder_quantization + ) + ) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, "encode") and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min( + torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 + )[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip( + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], + ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip( + L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"], + ) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold( + self, x, kernel_size, stride, uf=1, df=1 + ): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting( + kernel_size[0], kernel_size[1], Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 + ) + + weighting = self.get_weighting( + kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h * uf, w * uf + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) + ) + + elif df > 1 and uf == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2 + ) + + weighting = self.get_weighting( + kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h // df, w // df + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) + ) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + uncond=0.05, + ): + x = super().get_input(batch, k) + T = batch["T"].to(memory_format=torch.contiguous_format).float() + + if bs is not None: + x = x[:bs] + T = T[:bs].to(self.device) + + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = super().get_input(batch, cond_key).to(self.device) + if bs is not None: + xc = xc[:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. + random = torch.rand(x.size(0), device=x.device) + prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + input_mask = 1 - rearrange( + (random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1" + ) + null_prompt = self.get_learned_conditioning([""]) + + # z.shape: [8, 4, 64, 64]; c.shape: [8, 1, 768] + # print('=========== xc shape ===========', xc.shape) + with torch.enable_grad(): + clip_emb = self.get_learned_conditioning(xc).detach() + null_prompt = self.get_learned_conditioning([""]).detach() + cond["c_crossattn"] = [ + self.cc_projection( + torch.cat( + [ + torch.where(prompt_mask, null_prompt, clip_emb), + T[:, None, :], + ], + dim=-1, + ) + ) + ] + cond["c_concat"] = [ + input_mask * self.encode_first_stage((xc.to(self.device))).mode().detach() + ] + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + # @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf + ) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize, + ) + for i in range(z.shape[-1]) + ] + else: + output_list = [ + self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize + ) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode( + z, force_not_quantize=predict_cids or force_not_quantize + ) + else: + return self.first_stage_model.decode(z) + + # @torch.no_grad() # wasted two hours to find this bug... why no grad here! + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params["original_image_size"] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold( + x, ks, stride, df=df + ) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) + + output_list = [ + self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + if self.model.conditioning_key is not None: + assert c is not None + # if self.cond_stage_trainable: + # c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = ( + "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" + ) + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold( + x_noisy, ks, stride + ) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if ( + self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"] + and self.model.conditioning_key + ): # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert len(c) == 1 # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view( + (c.shape[0], -1, ks[0], ks[1], c.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == "coordinates_bbox": + assert ( + "original_image_size" in self.split_input_params + ), "BoudingBoxRescaling is missing original_image_size" + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params["original_image_size"] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [ + ( + rescale_latent + * stride[0] + * (patch_nr % n_patches_per_row) + / full_img_w, + rescale_latent + * stride[1] + * (patch_nr // n_patches_per_row) + / full_img_h, + ) + for patch_nr in range(z.shape[-1]) + ] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [ + ( + x_tl, + y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h, + ) + for x_tl, y_tl in tl_patch_coordinates + ] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [ + torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to( + self.device + ) + for bbox in patch_limits + ] # list of length l with tensors of shape (1, 2) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), "cond must be dict to be fed into model" + cut_cond = cond["c_crossattn"][0][..., :-2].to(self.device) + + adapted_cond = torch.stack( + [torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd] + ) + adapted_cond = rearrange(adapted_cond, "l b n -> (l b) n") + adapted_cond = self.get_learned_conditioning(adapted_cond) + adapted_cond = rearrange( + adapted_cond, "(l b) n d -> l b n d", l=z.shape[-1] + ) + + cond_list = [{"c_crossattn": [e]} for e in adapted_cond] + + else: + cond_list = [ + cond for i in range(z.shape[-1]) + ] # Todo make this more efficient + + # apply model by loop over crops + output_list = [ + self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1]) + ] + assert not isinstance( + output_list[0], tuple + ) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = "train" if self.training else "val" + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) + loss_dict.update({"logvar": self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f"{prefix}/loss": loss}) + + return loss, loss_dict + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance + ).exp() * noise, logits.argmax(dim=1) + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc="Progressive Generation", + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample( + ddim_steps, batch_size, shape, cond, verbose=False, **kwargs + ) + + else: + samples, intermediates = self.sample( + cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs + ) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning( + self, batch_size, null_label=None, image_size=512 + ): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + cond = {} + cond["c_crossattn"] = [c] + cond["c_concat"] = [ + torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to( + self.device + ) + ] + return cond + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key == "class_label": + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + quantize_denoised=True, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning( + N, unconditional_guidance_label, image_size=x.shape[-1] + ) + # uc = torch.zeros_like(c) + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask, + ) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1.0 - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask, + ) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising( + c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N, + ) + prog_row = self._get_denoise_row_from_list( + progressives, desc="Progressive Generation" + ) + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = [] + if self.unet_trainable == "attn": + print("Training only unet attention layers") + for n, m in self.model.named_modules(): + if isinstance(m, CrossAttention) and n.endswith("attn2"): + params.extend(m.parameters()) + if self.unet_trainable == "conv_in": + print("Training only unet input conv layers") + params = list(self.model.diffusion_model.input_blocks[0][0].parameters()) + elif self.unet_trainable is True or self.unet_trainable == "all": + print("Training the full unet") + params = list(self.model.parameters()) + else: + raise ValueError( + f"Unrecognised setting for unet_trainable: {self.unet_trainable}" + ) + + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print("Diffusion model optimizing logvar") + params.append(self.logvar) + + if self.cc_projection is not None: + params = params + list(self.cc_projection.parameters()) + print("========== optimizing for cc projection weight ==========") + + opt = torch.optim.AdamW( + [ + {"params": self.model.parameters(), "lr": lr}, + {"params": self.cc_projection.parameters(), "lr": 10.0 * lr}, + ], + lr=lr, + ) + if self.use_scheduler: + assert "target" in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } + ] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, + "concat", + "crossattn", + "hybrid", + "adm", + "hybrid-adm", + ] + + def forward( + self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None + ): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == "concat": + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == "crossattn": + # c_crossattn dimension: torch.Size([8, 1, 768]) 1 + # cc dimension: torch.Size([8, 1, 768] + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == "hybrid": + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == "hybrid-adm": + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == "adm": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, "b h w c -> b c h w") + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + # import pudb; pu.db + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[ + f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}" + ] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key == "class_label": + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning( + N, unconditional_guidance_label + ) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising( + c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N, + ) + prog_row = self._get_denoise_row_from_list( + progressives, desc="Progressive Generation" + ) + log["progressive_row"] = prog_row + + return log + + +class LatentInpaintDiffusion(LatentDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__( + self, + finetune_keys=( + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ), + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs, + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): + assert exists(ckpt_path), "can only finetune from a given checkpoint" + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" + ) + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), "did not find matching parameter to modify" + new_entry[:, : self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = ( + rearrange(batch[ck], "b h w c -> b c h w") + .to(memory_format=torch.contiguous_format) + .float() + ) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input( + batch, self.first_stage_key, bs=N, return_first_stage_outputs=True + ) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key == "class_label": + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage( + c_cat[:, self.c_concat_log_start : self.c_concat_log_end] + ) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning( + N, unconditional_guidance_label + ) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + + log["masked_image"] = ( + rearrange(batch["masked_image"], "b h w c -> b c h w") + .to(memory_format=torch.contiguous_format) + .float() + ) + return log + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert ( + cond_stage_key == "coordinates_bbox" + ), 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = "train" if self.training else "validation" + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs["bbox_image"] = cond_img + return logs + + +class SimpleUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.low_scale_key = low_scale_key + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, "b h w c -> b c h w") + x_low = x_low.to(memory_format=torch.contiguous_format).float() + + encoder_posterior = self.encode_first_stage(x_low) + zx = self.get_first_stage_encoding(encoder_posterior).detach() + all_conds = {"c_concat": [zx], "c_crossattn": [c]} + + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + return z, all_conds, x, xrec, xc, x_low + return z, all_conds + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key == "class_label": + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning( + N, unconditional_guidance_label + ) + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + return log + + +class MultiCatFrameDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.low_scale_key = low_scale_key + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + n = 2 + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + cat_conds = batch[self.low_scale_key][:bs] + cats = [] + for i in range(n): + x_low = cat_conds[:, :, :, 3 * i : 3 * (i + 1)] + x_low = rearrange(x_low, "b h w c -> b c h w") + x_low = x_low.to(memory_format=torch.contiguous_format).float() + encoder_posterior = self.encode_first_stage(x_low) + zx = self.get_first_stage_encoding(encoder_posterior).detach() + cats.append(zx) + + all_conds = {"c_concat": [torch.cat(cats, dim=1)], "c_crossattn": [c]} + + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + return z, all_conds, x, xrec, xc, x_low + return z, all_conds + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch[self.cond_stage_key], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif self.cond_stage_key == "class_label": + xc = log_txt_as_img( + (x.shape[2], x.shape[3]), + batch["human_label"], + size=x.shape[2] // 25, + ) + log["conditioning"] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning( + N, unconditional_guidance_label + ) + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + return log diff --git a/extern/ldm_zero123/models/diffusion/plms.py b/extern/ldm_zero123/models/diffusion/plms.py new file mode 100755 index 0000000..4e61886 --- /dev/null +++ b/extern/ldm_zero123/models/diffusion/plms.py @@ -0,0 +1,383 @@ +"""SAMPLING ONLY.""" + +from functools import partial + +import numpy as np +import torch +from tqdm import tqdm + +from extern.ldm_zero123.models.diffusion.sampling_util import norm_thresholding +from extern.ldm_zero123.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + if ddim_eta != 0: + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + list(reversed(range(0, timesteps))) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full( + (b,), + time_range[min(i + 1, len(time_range) - 1)], + device=device, + dtype=torch.long, + ) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + dynamic_threshold=dynamic_threshold, + ) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + dynamic_threshold=None, + ): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = ( + 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] + ) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/extern/ldm_zero123/models/diffusion/sampling_util.py b/extern/ldm_zero123/models/diffusion/sampling_util.py new file mode 100755 index 0000000..1d0df15 --- /dev/null +++ b/extern/ldm_zero123/models/diffusion/sampling_util.py @@ -0,0 +1,51 @@ +import numpy as np +import torch + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +def renorm_thresholding(x0, value): + # renorm + pred_max = x0.max() + pred_min = x0.min() + pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 + pred_x0 = 2 * pred_x0 - 1.0 # -1 ... 1 + + s = torch.quantile(rearrange(pred_x0, "b ... -> b (...)").abs(), value, dim=-1) + s.clamp_(min=1.0) + s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) + + # clip by threshold + # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max + + # temporary hack: numpy on cpu + pred_x0 = ( + np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) + / s.cpu().numpy() + ) + pred_x0 = torch.tensor(pred_x0).to(self.model.device) + + # re.renorm + pred_x0 = (pred_x0 + 1.0) / 2.0 # 0 ... 1 + pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range + return pred_x0 + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) diff --git a/extern/ldm_zero123/modules/attention.py b/extern/ldm_zero123/modules/attention.py new file mode 100755 index 0000000..c69ea71 --- /dev/null +++ b/extern/ldm_zero123/modules/attention.py @@ -0,0 +1,364 @@ +import math +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn + +from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + self.lora = False + self.query_dim = query_dim + self.inner_dim = inner_dim + self.context_dim = context_dim + + def setup_lora(self, rank=4, network_alpha=None): + self.lora = True + self.rank = rank + self.to_q_lora = LoRALinearLayer(self.query_dim, self.inner_dim, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(self.inner_dim, self.query_dim, rank, network_alpha) + self.lora_layers = nn.ModuleList() + self.lora_layers.append(self.to_q_lora) + self.lora_layers.append(self.to_k_lora) + self.lora_layers.append(self.to_v_lora) + self.lora_layers.append(self.to_out_lora) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if self.lora: + q += self.to_q_lora(x) + k += self.to_k_lora(context) + v += self.to_v_lora(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + # return self.to_out(out) + + # linear proj + o = self.to_out[0](out) + if self.lora: + o += self.to_out_lora(out) + # dropout + out = self.to_out[1](o) + + return out + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + ): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + # return checkpoint( + # self._forward, (x, context), self.parameters(), self.checkpoint + # ) + return self._forward(x, context) + + def _forward(self, x, context=None): + x = ( + self.attn1( + self.norm1(x), context=context if self.disable_self_attn else None + ) + + x + ) + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + ): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + disable_self_attn=disable_self_attn, + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + x = self.proj_out(x) + return x + x_in diff --git a/extern/ldm_zero123/modules/attention_ori.py b/extern/ldm_zero123/modules/attention_ori.py new file mode 100755 index 0000000..e2a1c9e --- /dev/null +++ b/extern/ldm_zero123/modules/attention_ori.py @@ -0,0 +1,301 @@ +import math +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn + +from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + ): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = ( + self.attn1( + self.norm1(x), context=context if self.disable_self_attn else None + ) + + x + ) + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + ): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + disable_self_attn=disable_self_attn, + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + x = self.proj_out(x) + return x + x_in diff --git a/extern/ldm_zero123/modules/diffusionmodules/__init__.py b/extern/ldm_zero123/modules/diffusionmodules/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/extern/ldm_zero123/modules/diffusionmodules/model.py b/extern/ldm_zero123/modules/diffusionmodules/model.py new file mode 100755 index 0000000..5aaefb4 --- /dev/null +++ b/extern/ldm_zero123/modules/diffusionmodules/model.py @@ -0,0 +1,1009 @@ +# pytorch_diffusion + derived encoder decoder +import math + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange + +from extern.ldm_zero123.modules.attention import LinearAttention +from extern.ldm_zero123.util import instantiate_from_config + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z diff --git a/extern/ldm_zero123/modules/diffusionmodules/openaimodel.py b/extern/ldm_zero123/modules/diffusionmodules/openaimodel.py new file mode 100755 index 0000000..8c7ec90 --- /dev/null +++ b/extern/ldm_zero123/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1062 @@ +import math +from abc import abstractmethod +from functools import partial +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from extern.ldm_zero123.modules.attention import SpatialTransformer +from extern.ldm_zero123.modules.diffusionmodules.util import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from extern.ldm_zero123.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + # return checkpoint( + # self._forward, (x, emb), self.parameters(), self.use_checkpoint + # ) + return self._forward(x, emb) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + # return checkpoint( + # self._forward, (x,), self.parameters(), True + # ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # # return pt_checkpoint(self._forward, x) # pytorch + return self._forward(x) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + ): + super().__init__() + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/extern/ldm_zero123/modules/diffusionmodules/util.py b/extern/ldm_zero123/modules/diffusionmodules/util.py new file mode 100755 index 0000000..250aa3e --- /dev/null +++ b/extern/ldm_zero123/modules/diffusionmodules/util.py @@ -0,0 +1,296 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import math +import os + +import numpy as np +import torch +import torch.nn as nn +from einops import repeat + +from extern.ldm_zero123.util import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/extern/ldm_zero123/modules/distributions/__init__.py b/extern/ldm_zero123/modules/distributions/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/extern/ldm_zero123/modules/distributions/distributions.py b/extern/ldm_zero123/modules/distributions/distributions.py new file mode 100755 index 0000000..016be35 --- /dev/null +++ b/extern/ldm_zero123/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import numpy as np +import torch + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/extern/ldm_zero123/modules/ema.py b/extern/ldm_zero123/modules/ema.py new file mode 100755 index 0000000..880ca3d --- /dev/null +++ b/extern/ldm_zero123/modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/extern/ldm_zero123/modules/encoders/__init__.py b/extern/ldm_zero123/modules/encoders/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/extern/ldm_zero123/modules/encoders/modules.py b/extern/ldm_zero123/modules/encoders/modules.py new file mode 100755 index 0000000..d4f99fe --- /dev/null +++ b/extern/ldm_zero123/modules/encoders/modules.py @@ -0,0 +1,712 @@ +from functools import partial + +import clip +import kornia +import numpy as np +import torch +import torch.nn as nn + +from extern.ldm_zero123.modules.x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + Encoder, + TransformerWrapper, +) +from extern.ldm_zero123.util import default + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + def encode(self, x): + return x + + +class FaceClipEncoder(AbstractEncoder): + def __init__(self, augment=True, retreival_key=None): + super().__init__() + self.encoder = FrozenCLIPImageEmbedder() + self.augment = augment + self.retreival_key = retreival_key + + def forward(self, img): + encodings = [] + with torch.no_grad(): + x_offset = 125 + if self.retreival_key: + # Assumes retrieved image are packed into the second half of channels + face = img[:, 3:, 190:440, x_offset : (512 - x_offset)] + other = img[:, :3, ...].clone() + else: + face = img[:, :, 190:440, x_offset : (512 - x_offset)] + other = img.clone() + + if self.augment: + face = K.RandomHorizontalFlip()(face) + + other[:, :, 190:440, x_offset : (512 - x_offset)] *= 0 + encodings = [ + self.encoder.encode(face), + self.encoder.encode(other), + ] + + return torch.cat(encodings, dim=1) + + def encode(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros( + (1, 2, 768), device=self.encoder.model.visual.conv1.weight.device + ) + + return self(img) + + +class FaceIdClipEncoder(AbstractEncoder): + def __init__(self): + super().__init__() + self.encoder = FrozenCLIPImageEmbedder() + for p in self.encoder.parameters(): + p.requires_grad = False + self.id = FrozenFaceEncoder( + "/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True + ) + + def forward(self, img): + encodings = [] + with torch.no_grad(): + face = kornia.geometry.resize( + img, (256, 256), interpolation="bilinear", align_corners=True + ) + + other = img.clone() + other[:, :, 184:452, 122:396] *= 0 + encodings = [ + self.id.encode(face), + self.encoder.encode(other), + ] + + return torch.cat(encodings, dim=1) + + def encode(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros( + (1, 2, 768), device=self.encoder.model.visual.conv1.weight.device + ) + + return self(img) + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key="class"): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper( + num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + ) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + + def __init__( + self, + n_embed, + n_layer, + vocab_size=30522, + max_seq_len=77, + device="cuda", + use_tokenizer=True, + embedding_dropout=0.0, + ): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper( + num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout, + ) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text) # .to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77 + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +import kornia.augmentation as K + +from extern.ldm_zero123.thirdp.psp.id_loss import IDFeatures + + +class FrozenFaceEncoder(AbstractEncoder): + def __init__(self, model_path, augment=False): + super().__init__() + self.loss_fn = IDFeatures(model_path) + # face encoder is frozen + for p in self.loss_fn.parameters(): + p.requires_grad = False + # Mapper is trainable + self.mapper = torch.nn.Linear(512, 768) + p = 0.25 + if augment: + self.augment = K.AugmentationSequential( + K.RandomHorizontalFlip(p=0.5), + K.RandomEqualize(p=p), + # K.RandomPlanckianJitter(p=p), + # K.RandomPlasmaBrightness(p=p), + # K.RandomPlasmaContrast(p=p), + # K.ColorJiggle(0.02, 0.2, 0.2, p=p), + ) + else: + self.augment = False + + def forward(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros((1, 1, 768), device=self.mapper.weight.device) + + if self.augment is not None: + # Transforms require 0-1 + img = self.augment((img + 1) / 2) + img = 2 * img - 1 + + feat = self.loss_fn(img, crop=True) + feat = self.mapper(feat.unsqueeze(1)) + return feat + + def encode(self, img): + return self(img) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + def __init__( + self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77 + ): # clip-vit-base-patch32 + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +import torch.nn.functional as F +from transformers import CLIPVisionModel + + +class ClipImageProjector(AbstractEncoder): + """ + Uses the CLIP image encoder. + """ + + def __init__( + self, version="openai/clip-vit-large-patch14", max_length=77 + ): # clip-vit-base-patch32 + super().__init__() + self.model = CLIPVisionModel.from_pretrained(version) + self.model.train() + self.max_length = max_length # TODO: typical value? + self.antialias = True + self.mapper = torch.nn.Linear(1024, 768) + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + null_cond = self.get_null_cond(version, max_length) + self.register_buffer("null_cond", null_cond) + + @torch.no_grad() + def get_null_cond(self, version, max_length): + device = self.mean.device + embedder = FrozenCLIPEmbedder( + version=version, device=device, max_length=max_length + ) + null_cond = embedder([""]) + return null_cond + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + if isinstance(x, list): + return self.null_cond + # x is assumed to be in range [-1,1] + x = self.preprocess(x) + outputs = self.model(pixel_values=x) + last_hidden_state = outputs.last_hidden_state + last_hidden_state = self.mapper(last_hidden_state) + return F.pad( + last_hidden_state, + [0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0], + ) + + def encode(self, im): + return self(im) + + +class ProjectedFrozenCLIPEmbedder(AbstractEncoder): + def __init__( + self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77 + ): # clip-vit-base-patch32 + super().__init__() + self.embedder = FrozenCLIPEmbedder( + version=version, device=device, max_length=max_length + ) + self.projection = torch.nn.Linear(768, 768) + + def forward(self, text): + z = self.embedder(text) + return self.projection(z) + + def encode(self, text): + return self(text) + + +class FrozenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the CLIP image encoder. + Not actually frozen... If you want that set cond_stage_trainable=False in cfg + """ + + def __init__( + self, + model="ViT-L/14", + jit=False, + device="cpu", + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit, download_root=None) + # We don't use the text part so delete it + del self.model.transformer + self.antialias = antialias + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + if isinstance(x, list): + # [""] denotes condition dropout for ucg + device = self.model.visual.conv1.weight.device + return torch.zeros(1, 768, device=device) + return self.model.encode_image(self.preprocess(x)).float() + + def encode(self, im): + return self(im).unsqueeze(1) + + +import random + +from torchvision import transforms + + +class FrozenCLIPImageMutliEmbedder(AbstractEncoder): + """ + Uses the CLIP image encoder. + Not actually frozen... If you want that set cond_stage_trainable=False in cfg + """ + + def __init__( + self, + model="ViT-L/14", + jit=False, + device="cpu", + antialias=True, + max_crops=5, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + # We don't use the text part so delete it + del self.model.transformer + self.antialias = antialias + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.max_crops = max_crops + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1, 1)) + max_crops = self.max_crops + patches = [] + crops = [randcrop(x) for _ in range(max_crops)] + patches.extend(crops) + x = torch.cat(patches, dim=0) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + if isinstance(x, list): + # [""] denotes condition dropout for ucg + device = self.model.visual.conv1.weight.device + return torch.zeros(1, self.max_crops, 768, device=device) + batch_tokens = [] + for im in x: + patches = self.preprocess(im.unsqueeze(0)) + tokens = self.model.encode_image(patches).float() + for t in tokens: + if random.random() < 0.1: + t *= 0 + batch_tokens.append(tokens.unsqueeze(0)) + + return torch.cat(batch_tokens, dim=0) + + def encode(self, im): + return self(im) + + +class SpatialRescaler(nn.Module): + def __init__( + self, + n_stages=1, + method="bilinear", + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False, + ): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in [ + "nearest", + "linear", + "bilinear", + "trilinear", + "bicubic", + "area", + ] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print( + f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." + ) + self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +from extern.ldm_zero123.modules.diffusionmodules.util import ( + extract_into_tensor, + make_beta_schedule, + noise_like, +) +from extern.ldm_zero123.util import instantiate_from_config + + +class LowScaleEncoder(nn.Module): + def __init__( + self, + model_config, + linear_start, + linear_end, + timesteps=1000, + max_noise_level=250, + output_size=64, + scale_factor=1.0, + ): + super().__init__() + self.max_noise_level = max_noise_level + self.model = instantiate_from_config(model_config) + self.augmentation_schedule = self.register_schedule( + timesteps=timesteps, linear_start=linear_start, linear_end=linear_end + ) + self.out_size = output_size + self.scale_factor = scale_factor + + def register_schedule( + self, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def forward(self, x): + z = self.model.encode(x).sample() + z = z * self.scale_factor + noise_level = torch.randint( + 0, self.max_noise_level, (x.shape[0],), device=x.device + ).long() + z = self.q_sample(z, noise_level) + if self.out_size is not None: + z = torch.nn.functional.interpolate( + z, size=self.out_size, mode="nearest" + ) # TODO: experiment with mode + # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) + return z, noise_level + + def decode(self, z): + z = z / self.scale_factor + return self.model.decode(z) + + +if __name__ == "__main__": + from extern.ldm_zero123.util import count_params + + sentences = [ + "a hedgehog drinking a whiskey", + "der mond ist aufgegangen", + "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'", + ] + model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda() + count_params(model, True) + z = model(sentences) + print(z.shape) + + model = FrozenCLIPEmbedder().cuda() + count_params(model, True) + z = model(sentences) + print(z.shape) + + print("done.") diff --git a/extern/ldm_zero123/modules/evaluate/adm_evaluator.py b/extern/ldm_zero123/modules/evaluate/adm_evaluator.py new file mode 100755 index 0000000..6b70eda --- /dev/null +++ b/extern/ldm_zero123/modules/evaluate/adm_evaluator.py @@ -0,0 +1,703 @@ +import argparse +import io +import os +import random +import warnings +import zipfile +from abc import ABC, abstractmethod +from contextlib import contextmanager +from functools import partial +from multiprocessing import cpu_count +from multiprocessing.pool import ThreadPool +from typing import Iterable, Optional, Tuple + +import numpy as np +import requests +import tensorflow.compat.v1 as tf +import yaml +from scipy import linalg +from tqdm.auto import tqdm + +INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" +INCEPTION_V3_PATH = "classify_image_graph_def.pb" + +FID_POOL_NAME = "pool_3:0" +FID_SPATIAL_NAME = "mixed_6/conv:0" + +REQUIREMENTS = ( + f"This script has the following requirements: \n" + "tensorflow-gpu>=2.0" + "\n" + "scipy" + "\n" + "requests" + "\n" + "tqdm" +) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--ref_batch", help="path to reference batch npz file") + parser.add_argument("--sample_batch", help="path to sample batch npz file") + args = parser.parse_args() + + config = tf.ConfigProto( + allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph + ) + config.gpu_options.allow_growth = True + evaluator = Evaluator(tf.Session(config=config)) + + print("warming up TensorFlow...") + # This will cause TF to print a bunch of verbose stuff now rather + # than after the next print(), to help prevent confusion. + evaluator.warmup() + + print("computing reference batch activations...") + ref_acts = evaluator.read_activations(args.ref_batch) + print("computing/reading reference batch statistics...") + ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) + + print("computing sample batch activations...") + sample_acts = evaluator.read_activations(args.sample_batch) + print("computing/reading sample batch statistics...") + sample_stats, sample_stats_spatial = evaluator.read_statistics( + args.sample_batch, sample_acts + ) + + print("Computing evaluations...") + is_ = evaluator.compute_inception_score(sample_acts[0]) + print("Inception Score:", is_) + fid = sample_stats.frechet_distance(ref_stats) + print("FID:", fid) + sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial) + print("sFID:", sfid) + prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) + print("Precision:", prec) + print("Recall:", recall) + + savepath = "/".join(args.sample_batch.split("/")[:-1]) + results_file = os.path.join(savepath, "evaluation_metrics.yaml") + print(f'Saving evaluation results to "{results_file}"') + + results = { + "IS": is_, + "FID": fid, + "sFID": sfid, + "Precision:": prec, + "Recall": recall, + } + + with open(results_file, "w") as f: + yaml.dump(results, f, default_flow_style=False) + + +class InvalidFIDException(Exception): + pass + + +class FIDStatistics: + def __init__(self, mu: np.ndarray, sigma: np.ndarray): + self.mu = mu + self.sigma = sigma + + def frechet_distance(self, other, eps=1e-6): + """ + Compute the Frechet distance between two sets of statistics. + """ + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 + mu1, sigma1 = self.mu, self.sigma + mu2, sigma2 = other.mu, other.sigma + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" + assert ( + sigma1.shape == sigma2.shape + ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; adding %s to diagonal of cov estimates" + % eps + ) + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +class Evaluator: + def __init__( + self, + session, + batch_size=64, + softmax_batch_size=512, + ): + self.sess = session + self.batch_size = batch_size + self.softmax_batch_size = softmax_batch_size + self.manifold_estimator = ManifoldEstimator(session) + with self.sess.graph.as_default(): + self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) + self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) + self.pool_features, self.spatial_features = _create_feature_graph( + self.image_input + ) + self.softmax = _create_softmax_graph(self.softmax_input) + + def warmup(self): + self.compute_activations(np.zeros([1, 8, 64, 64, 3])) + + def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: + with open_npz_array(npz_path, "arr_0") as reader: + return self.compute_activations(reader.read_batches(self.batch_size)) + + def compute_activations( + self, batches: Iterable[np.ndarray], silent=False + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute image features for downstream evals. + + :param batches: a iterator over NHWC numpy arrays in [0, 255]. + :return: a tuple of numpy arrays of shape [N x X], where X is a feature + dimension. The tuple is (pool_3, spatial). + """ + preds = [] + spatial_preds = [] + it = batches if silent else tqdm(batches) + for batch in it: + batch = batch.astype(np.float32) + pred, spatial_pred = self.sess.run( + [self.pool_features, self.spatial_features], {self.image_input: batch} + ) + preds.append(pred.reshape([pred.shape[0], -1])) + spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) + return ( + np.concatenate(preds, axis=0), + np.concatenate(spatial_preds, axis=0), + ) + + def read_statistics( + self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] + ) -> Tuple[FIDStatistics, FIDStatistics]: + obj = np.load(npz_path) + if "mu" in list(obj.keys()): + return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( + obj["mu_s"], obj["sigma_s"] + ) + return tuple(self.compute_statistics(x) for x in activations) + + def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return FIDStatistics(mu, sigma) + + def compute_inception_score( + self, activations: np.ndarray, split_size: int = 5000 + ) -> float: + softmax_out = [] + for i in range(0, len(activations), self.softmax_batch_size): + acts = activations[i : i + self.softmax_batch_size] + softmax_out.append( + self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}) + ) + preds = np.concatenate(softmax_out, axis=0) + # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 + scores = [] + for i in range(0, len(preds), split_size): + part = preds[i : i + split_size] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)) + + def compute_prec_recall( + self, activations_ref: np.ndarray, activations_sample: np.ndarray + ) -> Tuple[float, float]: + radii_1 = self.manifold_estimator.manifold_radii(activations_ref) + radii_2 = self.manifold_estimator.manifold_radii(activations_sample) + pr = self.manifold_estimator.evaluate_pr( + activations_ref, radii_1, activations_sample, radii_2 + ) + return (float(pr[0][0]), float(pr[1][0])) + + +class ManifoldEstimator: + """ + A helper for comparing manifolds of feature vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 + """ + + def __init__( + self, + session, + row_batch_size=10000, + col_batch_size=10000, + nhood_sizes=(3,), + clamp_to_percentile=None, + eps=1e-5, + ): + """ + Estimate the manifold of given feature vectors. + + :param session: the TensorFlow session. + :param row_batch_size: row batch size to compute pairwise distances + (parameter to trade-off between memory usage and performance). + :param col_batch_size: column batch size to compute pairwise distances. + :param nhood_sizes: number of neighbors used to estimate the manifold. + :param clamp_to_percentile: prune hyperspheres that have radius larger than + the given percentile. + :param eps: small number for numerical stability. + """ + self.distance_block = DistanceBlock(session) + self.row_batch_size = row_batch_size + self.col_batch_size = col_batch_size + self.nhood_sizes = nhood_sizes + self.num_nhoods = len(nhood_sizes) + self.clamp_to_percentile = clamp_to_percentile + self.eps = eps + + def warmup(self): + feats, radii = ( + np.zeros([1, 2048], dtype=np.float32), + np.zeros([1, 1], dtype=np.float32), + ) + self.evaluate_pr(feats, radii, feats, radii) + + def manifold_radii(self, features: np.ndarray) -> np.ndarray: + num_images = len(features) + + # Estimate manifold of features by calculating distances to k-NN of each sample. + radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) + distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) + seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) + + for begin1 in range(0, num_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_images) + row_batch = features[begin1:end1] + + for begin2 in range(0, num_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_images) + col_batch = features[begin2:end2] + + # Compute distances between batches. + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(row_batch, col_batch) + + # Find the k-nearest neighbor from the current batch. + radii[begin1:end1, :] = np.concatenate( + [ + x[:, self.nhood_sizes] + for x in _numpy_partition( + distance_batch[0 : end1 - begin1, :], seq, axis=1 + ) + ], + axis=0, + ) + + if self.clamp_to_percentile is not None: + max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) + radii[radii > max_distances] = 0 + return radii + + def evaluate( + self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray + ): + """ + Evaluate if new feature vectors are at the manifold. + """ + num_eval_images = eval_features.shape[0] + num_ref_images = radii.shape[0] + distance_batch = np.zeros( + [self.row_batch_size, num_ref_images], dtype=np.float32 + ) + batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) + max_realism_score = np.zeros([num_eval_images], dtype=np.float32) + nearest_indices = np.zeros([num_eval_images], dtype=np.int32) + + for begin1 in range(0, num_eval_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_eval_images) + feature_batch = eval_features[begin1:end1] + + for begin2 in range(0, num_ref_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_ref_images) + ref_batch = features[begin2:end2] + + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(feature_batch, ref_batch) + + # From the minibatch of new feature vectors, determine if they are in the estimated manifold. + # If a feature vector is inside a hypersphere of some reference sample, then + # the new sample lies at the estimated manifold. + # The radii of the hyperspheres are determined from distances of neighborhood size k. + samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii + batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype( + np.int32 + ) + + max_realism_score[begin1:end1] = np.max( + radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 + ) + nearest_indices[begin1:end1] = np.argmin( + distance_batch[0 : end1 - begin1, :], axis=1 + ) + + return { + "fraction": float(np.mean(batch_predictions)), + "batch_predictions": batch_predictions, + "max_realisim_score": max_realism_score, + "nearest_indices": nearest_indices, + } + + def evaluate_pr( + self, + features_1: np.ndarray, + radii_1: np.ndarray, + features_2: np.ndarray, + radii_2: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Evaluate precision and recall efficiently. + + :param features_1: [N1 x D] feature vectors for reference batch. + :param radii_1: [N1 x K1] radii for reference vectors. + :param features_2: [N2 x D] feature vectors for the other batch. + :param radii_2: [N x K2] radii for other vectors. + :return: a tuple of arrays for (precision, recall): + - precision: an np.ndarray of length K1 + - recall: an np.ndarray of length K2 + """ + features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool) + features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool) + for begin_1 in range(0, len(features_1), self.row_batch_size): + end_1 = begin_1 + self.row_batch_size + batch_1 = features_1[begin_1:end_1] + for begin_2 in range(0, len(features_2), self.col_batch_size): + end_2 = begin_2 + self.col_batch_size + batch_2 = features_2[begin_2:end_2] + batch_1_in, batch_2_in = self.distance_block.less_thans( + batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] + ) + features_1_status[begin_1:end_1] |= batch_1_in + features_2_status[begin_2:end_2] |= batch_2_in + return ( + np.mean(features_2_status.astype(np.float64), axis=0), + np.mean(features_1_status.astype(np.float64), axis=0), + ) + + +class DistanceBlock: + """ + Calculate pairwise distances between vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 + """ + + def __init__(self, session): + self.session = session + + # Initialize TF graph to calculate pairwise distances. + with session.graph.as_default(): + self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) + self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) + distance_block_16 = _batch_pairwise_distances( + tf.cast(self._features_batch1, tf.float16), + tf.cast(self._features_batch2, tf.float16), + ) + self.distance_block = tf.cond( + tf.reduce_all(tf.math.is_finite(distance_block_16)), + lambda: tf.cast(distance_block_16, tf.float32), + lambda: _batch_pairwise_distances( + self._features_batch1, self._features_batch2 + ), + ) + + # Extra logic for less thans. + self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) + self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) + dist32 = tf.cast(self.distance_block, tf.float32)[..., None] + self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) + self._batch_2_in = tf.math.reduce_any( + dist32 <= self._radii1[:, None], axis=0 + ) + + def pairwise_distances(self, U, V): + """ + Evaluate pairwise distances between two batches of feature vectors. + """ + return self.session.run( + self.distance_block, + feed_dict={self._features_batch1: U, self._features_batch2: V}, + ) + + def less_thans(self, batch_1, radii_1, batch_2, radii_2): + return self.session.run( + [self._batch_1_in, self._batch_2_in], + feed_dict={ + self._features_batch1: batch_1, + self._features_batch2: batch_2, + self._radii1: radii_1, + self._radii2: radii_2, + }, + ) + + +def _batch_pairwise_distances(U, V): + """ + Compute pairwise distances between two batches of feature vectors. + """ + with tf.variable_scope("pairwise_dist_block"): + # Squared norms of each row in U and V. + norm_u = tf.reduce_sum(tf.square(U), 1) + norm_v = tf.reduce_sum(tf.square(V), 1) + + # norm_u as a column and norm_v as a row vectors. + norm_u = tf.reshape(norm_u, [-1, 1]) + norm_v = tf.reshape(norm_v, [1, -1]) + + # Pairwise squared Euclidean distances. + D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) + + return D + + +class NpzArrayReader(ABC): + @abstractmethod + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + pass + + @abstractmethod + def remaining(self) -> int: + pass + + def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: + def gen_fn(): + while True: + batch = self.read_batch(batch_size) + if batch is None: + break + yield batch + + rem = self.remaining() + num_batches = rem // batch_size + int(rem % batch_size != 0) + return BatchIterator(gen_fn, num_batches) + + +class BatchIterator: + def __init__(self, gen_fn, length): + self.gen_fn = gen_fn + self.length = length + + def __len__(self): + return self.length + + def __iter__(self): + return self.gen_fn() + + +class StreamingNpzArrayReader(NpzArrayReader): + def __init__(self, arr_f, shape, dtype): + self.arr_f = arr_f + self.shape = shape + self.dtype = dtype + self.idx = 0 + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.shape[0]: + return None + + bs = min(batch_size, self.shape[0] - self.idx) + self.idx += bs + + if self.dtype.itemsize == 0: + return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) + + read_count = bs * np.prod(self.shape[1:]) + read_size = int(read_count * self.dtype.itemsize) + data = _read_bytes(self.arr_f, read_size, "array data") + return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) + + def remaining(self) -> int: + return max(0, self.shape[0] - self.idx) + + +class MemoryNpzArrayReader(NpzArrayReader): + def __init__(self, arr): + self.arr = arr + self.idx = 0 + + @classmethod + def load(cls, path: str, arr_name: str): + with open(path, "rb") as f: + arr = np.load(f)[arr_name] + return cls(arr) + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.arr.shape[0]: + return None + + res = self.arr[self.idx : self.idx + batch_size] + self.idx += batch_size + return res + + def remaining(self) -> int: + return max(0, self.arr.shape[0] - self.idx) + + +@contextmanager +def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: + with _open_npy_file(path, arr_name) as arr_f: + version = np.lib.format.read_magic(arr_f) + if version == (1, 0): + header = np.lib.format.read_array_header_1_0(arr_f) + elif version == (2, 0): + header = np.lib.format.read_array_header_2_0(arr_f) + else: + yield MemoryNpzArrayReader.load(path, arr_name) + return + shape, fortran, dtype = header + if fortran or dtype.hasobject: + yield MemoryNpzArrayReader.load(path, arr_name) + else: + yield StreamingNpzArrayReader(arr_f, shape, dtype) + + +def _read_bytes(fp, size, error_template="ran out of data"): + """ + Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 + + Read from file-like object until size bytes are read. + Raises ValueError if not EOF is encountered before size bytes are read. + Non-blocking objects only supported if they derive from io objects. + Required as e.g. ZipExtFile in python 2.6 can return less data than + requested. + """ + data = bytes() + while True: + # io files (default in python3) return None or raise on + # would-block, python2 file will truncate, probably nothing can be + # done about that. note that regular files can't be non-blocking + try: + r = fp.read(size - len(data)) + data += r + if len(r) == 0 or len(data) == size: + break + except io.BlockingIOError: + pass + if len(data) != size: + msg = "EOF: reading %s, expected %d bytes got %d" + raise ValueError(msg % (error_template, size, len(data))) + else: + return data + + +@contextmanager +def _open_npy_file(path: str, arr_name: str): + with open(path, "rb") as f: + with zipfile.ZipFile(f, "r") as zip_f: + if f"{arr_name}.npy" not in zip_f.namelist(): + raise ValueError(f"missing {arr_name} in npz file") + with zip_f.open(f"{arr_name}.npy", "r") as arr_f: + yield arr_f + + +def _download_inception_model(): + if os.path.exists(INCEPTION_V3_PATH): + return + print("downloading InceptionV3 model...") + with requests.get(INCEPTION_V3_URL, stream=True) as r: + r.raise_for_status() + tmp_path = INCEPTION_V3_PATH + ".tmp" + with open(tmp_path, "wb") as f: + for chunk in tqdm(r.iter_content(chunk_size=8192)): + f.write(chunk) + os.rename(tmp_path, INCEPTION_V3_PATH) + + +def _create_feature_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + pool3, spatial = tf.import_graph_def( + graph_def, + input_map={f"ExpandDims:0": input_batch}, + return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], + name=prefix, + ) + _update_shapes(pool3) + spatial = spatial[..., :7] + return pool3, spatial + + +def _create_softmax_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + (matmul,) = tf.import_graph_def( + graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix + ) + w = matmul.inputs[1] + logits = tf.matmul(input_batch, w) + return tf.nn.softmax(logits) + + +def _update_shapes(pool3): + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 + ops = pool3.graph.get_operations() + for op in ops: + for o in op.outputs: + shape = o.get_shape() + if shape._dims is not None: # pylint: disable=protected-access + # shape = [s.value for s in shape] TF 1.x + shape = [s for s in shape] # TF 2.x + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__["_shape_val"] = tf.TensorShape(new_shape) + return pool3 + + +def _numpy_partition(arr, kth, **kwargs): + num_workers = min(cpu_count(), len(arr)) + chunk_size = len(arr) // num_workers + extra = len(arr) % num_workers + + start_idx = 0 + batches = [] + for i in range(num_workers): + size = chunk_size + (1 if i < extra else 0) + batches.append(arr[start_idx : start_idx + size]) + start_idx += size + + with ThreadPool(num_workers) as pool: + return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) + + +if __name__ == "__main__": + print(REQUIREMENTS) + main() diff --git a/extern/ldm_zero123/modules/evaluate/evaluate_perceptualsim.py b/extern/ldm_zero123/modules/evaluate/evaluate_perceptualsim.py new file mode 100755 index 0000000..023f8c3 --- /dev/null +++ b/extern/ldm_zero123/modules/evaluate/evaluate_perceptualsim.py @@ -0,0 +1,606 @@ +import argparse +import glob +import os +from collections import namedtuple + +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from torchvision import models +from tqdm import tqdm + +from extern.ldm_zero123.modules.evaluate.ssim import ssim + +transform = transforms.Compose([transforms.ToTensor()]) + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view( + in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3] + ) + return in_feat / (norm_factor.expand_as(in_feat) + eps) + + +def cos_sim(in0, in1): + in0_norm = normalize_tensor(in0) + in1_norm = normalize_tensor(in1) + N = in0.size()[0] + X = in0.size()[2] + Y = in0.size()[3] + + return torch.mean( + torch.mean(torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2).view( + N, 1, 1, Y + ), + dim=3, + ).view(N) + + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = models.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple( + "SqueezeOutputs", + ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"], + ) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple( + "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] + ) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", + ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"], + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if num == 18: + self.net = models.resnet18(pretrained=pretrained) + elif num == 34: + self.net = models.resnet34(pretrained=pretrained) + elif num == 50: + self.net = models.resnet50(pretrained=pretrained) + elif num == 101: + self.net = models.resnet101(pretrained=pretrained) + elif num == 152: + self.net = models.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out + + +# Off-the-shelf deep network +class PNet(torch.nn.Module): + """Pre-trained network with all channels equally weighted by default""" + + def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True): + super(PNet, self).__init__() + + self.use_gpu = use_gpu + + self.pnet_type = pnet_type + self.pnet_rand = pnet_rand + + self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1) + self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1) + + if self.pnet_type in ["vgg", "vgg16"]: + self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False) + elif self.pnet_type == "alex": + self.net = alexnet(pretrained=not self.pnet_rand, requires_grad=False) + elif self.pnet_type[:-2] == "resnet": + self.net = resnet( + pretrained=not self.pnet_rand, + requires_grad=False, + num=int(self.pnet_type[-2:]), + ) + elif self.pnet_type == "squeeze": + self.net = squeezenet(pretrained=not self.pnet_rand, requires_grad=False) + + self.L = self.net.N_slices + + if use_gpu: + self.net.cuda() + self.shift = self.shift.cuda() + self.scale = self.scale.cuda() + + def forward(self, in0, in1, retPerLayer=False): + in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) + in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) + + outs0 = self.net.forward(in0_sc) + outs1 = self.net.forward(in1_sc) + + if retPerLayer: + all_scores = [] + for kk, out0 in enumerate(outs0): + cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk]) + if kk == 0: + val = 1.0 * cur_score + else: + val = val + cur_score + if retPerLayer: + all_scores += [cur_score] + + if retPerLayer: + return (val, all_scores) + else: + return val + + +# The SSIM metric +def ssim_metric(img1, img2, mask=None): + return ssim(img1, img2, mask=mask, size_average=False) + + +# The PSNR metric +def psnr(img1, img2, mask=None, reshape=False): + b = img1.size(0) + if not (mask is None): + b = img1.size(0) + mse_err = (img1 - img2).pow(2) * mask + if reshape: + mse_err = mse_err.reshape(b, -1).sum(dim=1) / ( + 3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1) + ) + else: + mse_err = mse_err.view(b, -1).sum(dim=1) / ( + 3 * mask.view(b, -1).sum(dim=1).clamp(min=1) + ) + else: + if reshape: + mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1) + else: + mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1) + + psnr = 10 * (1 / mse_err).log10() + return psnr + + +# The perceptual similarity metric +def perceptual_sim(img1, img2, vgg16): + # First extract features + dist = vgg16(img1 * 2 - 1, img2 * 2 - 1) + + return dist + + +def load_img(img_name, size=None): + try: + img = Image.open(img_name) + + if type(size) == int: + img = img.resize((size, size)) + elif size is not None: + img = img.resize((size[1], size[0])) + + img = transform(img).cuda() + img = img.unsqueeze(0) + except Exception as e: + print("Failed at loading %s " % img_name) + print(e) + img = torch.zeros(1, 3, 256, 256).cuda() + raise + return img + + +def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other): + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + folders = os.listdir(folder) + for i, f in tqdm(enumerate(sorted(folders))): + pred_imgs = glob.glob(folder + f + "/" + pred_img) + tgt_imgs = glob.glob(folder + f + "/" + tgt_img) + assert len(tgt_imgs) == 1 + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + for p_img in pred_imgs: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + perc_sim = min(perc_sim, t_perc_sim) + + ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item()) + psnr_sim = max(psnr_sim, psnr(p_img, t_img).item()) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + values_psnr += [psnr_sim] + + if take_every_other: + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + return { + "Perceptual similarity": (avg_percsim, std_percsim), + "PSNR": (avg_psnr, std_psnr), + "SSIM": (avg_ssim, std_ssim), + } + + +def compute_perceptual_similarity_from_list( + pred_imgs_list, tgt_imgs_list, take_every_other, simple_format=True +): + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + equal_count = 0 + ambig_count = 0 + for i, tgt_img in enumerate(tqdm(tgt_imgs_list)): + pred_imgs = pred_imgs_list[i] + tgt_imgs = [tgt_img] + assert len(tgt_imgs) == 1 + + if type(pred_imgs) != list: + pred_imgs = [pred_imgs] + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + assert len(pred_imgs) > 0 + for p_img in pred_imgs: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + perc_sim = min(perc_sim, t_perc_sim) + + ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item()) + psnr_sim = max(psnr_sim, psnr(p_img, t_img).item()) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + if psnr_sim != np.float("inf"): + values_psnr += [psnr_sim] + else: + if torch.allclose(p_img, t_img): + equal_count += 1 + print("{} equal src and wrp images.".format(equal_count)) + else: + ambig_count += 1 + print("{} ambiguous src and wrp images.".format(ambig_count)) + + if take_every_other: + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + if simple_format: + # just to make yaml formatting readable + return { + "Perceptual similarity": [float(avg_percsim), float(std_percsim)], + "PSNR": [float(avg_psnr), float(std_psnr)], + "SSIM": [float(avg_ssim), float(std_ssim)], + } + else: + return { + "Perceptual similarity": (avg_percsim, std_percsim), + "PSNR": (avg_psnr, std_psnr), + "SSIM": (avg_ssim, std_ssim), + } + + +def compute_perceptual_similarity_from_list_topk( + pred_imgs_list, tgt_imgs_list, take_every_other, resize=False +): + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + individual_percsim = [] + individual_ssim = [] + individual_psnr = [] + for i, tgt_img in enumerate(tqdm(tgt_imgs_list)): + pred_imgs = pred_imgs_list[i] + tgt_imgs = [tgt_img] + assert len(tgt_imgs) == 1 + + if type(pred_imgs) != list: + assert False + pred_imgs = [pred_imgs] + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + sample_percsim = list() + sample_ssim = list() + sample_psnr = list() + for p_img in pred_imgs: + if resize: + t_img = load_img(tgt_imgs[0], size=(256, 256)) + else: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + sample_percsim.append(t_perc_sim) + perc_sim = min(perc_sim, t_perc_sim) + + t_ssim = ssim_metric(p_img, t_img).item() + sample_ssim.append(t_ssim) + ssim_sim = max(ssim_sim, t_ssim) + + t_psnr = psnr(p_img, t_img).item() + sample_psnr.append(t_psnr) + psnr_sim = max(psnr_sim, t_psnr) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + values_psnr += [psnr_sim] + individual_percsim.append(sample_percsim) + individual_ssim.append(sample_ssim) + individual_psnr.append(sample_psnr) + + if take_every_other: + assert False, "Do this later, after specifying topk to get proper results" + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + individual_percsim = np.array(individual_percsim) + individual_psnr = np.array(individual_psnr) + individual_ssim = np.array(individual_ssim) + + return { + "avg_of_best": { + "Perceptual similarity": [float(avg_percsim), float(std_percsim)], + "PSNR": [float(avg_psnr), float(std_psnr)], + "SSIM": [float(avg_ssim), float(std_ssim)], + }, + "individual": { + "PSIM": individual_percsim, + "PSNR": individual_psnr, + "SSIM": individual_ssim, + }, + } + + +if __name__ == "__main__": + args = argparse.ArgumentParser() + args.add_argument("--folder", type=str, default="") + args.add_argument("--pred_image", type=str, default="") + args.add_argument("--target_image", type=str, default="") + args.add_argument("--take_every_other", action="store_true", default=False) + args.add_argument("--output_file", type=str, default="") + + opts = args.parse_args() + + folder = opts.folder + pred_img = opts.pred_image + tgt_img = opts.target_image + + results = compute_perceptual_similarity( + folder, pred_img, tgt_img, opts.take_every_other + ) + + f = open(opts.output_file, "w") + for key in results: + print("%s for %s: \n" % (key, opts.folder)) + print("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1])) + + f.write("%s for %s: \n" % (key, opts.folder)) + f.write("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1])) + + f.close() diff --git a/extern/ldm_zero123/modules/evaluate/frechet_video_distance.py b/extern/ldm_zero123/modules/evaluate/frechet_video_distance.py new file mode 100755 index 0000000..61688d0 --- /dev/null +++ b/extern/ldm_zero123/modules/evaluate/frechet_video_distance.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python2, python3 +"""Minimal Reference implementation for the Frechet Video Distance (FVD). + +FVD is a metric for the quality of video generation models. It is inspired by +the FID (Frechet Inception Distance) used for images, but uses a different +embedding to be better suitable for videos. +""" + +from __future__ import absolute_import, division, print_function + +import six +import tensorflow.compat.v1 as tf +import tensorflow_gan as tfgan +import tensorflow_hub as hub + + +def preprocess(videos, target_resolution): + """Runs some preprocessing on the videos for I3D model. + + Args: + videos: [batch_size, num_frames, height, width, depth] The videos to be + preprocessed. We don't care about the specific dtype of the videos, it can + be anything that tf.image.resize_bilinear accepts. Values are expected to + be in the range 0-255. + target_resolution: (width, height): target video resolution + + Returns: + videos: [batch_size, num_frames, height, width, depth] + """ + videos_shape = list(videos.shape) + all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) + resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) + target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] + output_videos = tf.reshape(resized_videos, target_shape) + scaled_videos = 2.0 * tf.cast(output_videos, tf.float32) / 255.0 - 1 + return scaled_videos + + +def _is_in_graph(tensor_name): + """Checks whether a given tensor does exists in the graph.""" + try: + tf.get_default_graph().get_tensor_by_name(tensor_name) + except KeyError: + return False + return True + + +def create_id3_embedding(videos, warmup=False, batch_size=16): + """Embeds the given videos using the Inflated 3D Convolution ne twork. + + Downloads the graph of the I3D from tf.hub and adds it to the graph on the + first call. + + Args: + videos: [batch_size, num_frames, height=224, width=224, depth=3]. + Expected range is [-1, 1]. + + Returns: + embedding: [batch_size, embedding_size]. embedding_size depends + on the model used. + + Raises: + ValueError: when a provided embedding_layer is not supported. + """ + + # batch_size = 16 + module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" + + # Making sure that we import the graph separately for + # each different input video tensor. + module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(videos.name).replace( + ":", "_" + ) + + assert_ops = [ + tf.Assert( + tf.reduce_max(videos) <= 1.001, ["max value in frame is > 1", videos] + ), + tf.Assert( + tf.reduce_min(videos) >= -1.001, ["min value in frame is < -1", videos] + ), + tf.assert_equal( + tf.shape(videos)[0], + batch_size, + ["invalid frame batch size: ", tf.shape(videos)], + summarize=6, + ), + ] + with tf.control_dependencies(assert_ops): + videos = tf.identity(videos) + + module_scope = "%s_apply_default/" % module_name + + # To check whether the module has already been loaded into the graph, we look + # for a given tensor name. If this tensor name exists, we assume the function + # has been called before and the graph was imported. Otherwise we import it. + # Note: in theory, the tensor could exist, but have wrong shapes. + # This will happen if create_id3_embedding is called with a frames_placehoder + # of wrong size/batch size, because even though that will throw a tf.Assert + # on graph-execution time, it will insert the tensor (with wrong shape) into + # the graph. This is why we need the following assert. + if warmup: + video_batch_size = int(videos.shape[0]) + assert video_batch_size in [ + batch_size, + -1, + None, + ], f"Invalid batch size {video_batch_size}" + tensor_name = module_scope + "RGB/inception_i3d/Mean:0" + if not _is_in_graph(tensor_name): + i3d_model = hub.Module(module_spec, name=module_name) + i3d_model(videos) + + # gets the kinetics-i3d-400-logits layer + tensor_name = module_scope + "RGB/inception_i3d/Mean:0" + tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) + return tensor + + +def calculate_fvd(real_activations, generated_activations): + """Returns a list of ops that compute metrics as funcs of activations. + + Args: + real_activations: [num_samples, embedding_size] + generated_activations: [num_samples, embedding_size] + + Returns: + A scalar that contains the requested FVD. + """ + return tfgan.eval.frechet_classifier_distance_from_activations( + real_activations, generated_activations + ) diff --git a/extern/ldm_zero123/modules/evaluate/ssim.py b/extern/ldm_zero123/modules/evaluate/ssim.py new file mode 100755 index 0000000..b640df0 --- /dev/null +++ b/extern/ldm_zero123/modules/evaluate/ssim.py @@ -0,0 +1,118 @@ +# MIT Licence + +# Methods to predict the SSIM, taken from +# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + +from math import exp + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + + +def gaussian(window_size, sigma): + gauss = torch.Tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable( + _2D_window.expand(channel, 1, window_size, window_size).contiguous() + ) + return window + + +def _ssim(img1, img2, window, window_size, channel, mask=None, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + ) + sigma2_sq = ( + F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + ) + sigma12 = ( + F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) + - mu1_mu2 + ) + + C1 = (0.01) ** 2 + C2 = (0.03) ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) + + if not (mask is None): + b = mask.size(0) + ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask + ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(dim=1).clamp( + min=1 + ) + return ssim_map + + import pdb + + pdb.set_trace + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2, mask=None): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim( + img1, + img2, + window, + self.window_size, + channel, + mask, + self.size_average, + ) + + +def ssim(img1, img2, window_size=11, mask=None, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, mask, size_average) diff --git a/extern/ldm_zero123/modules/evaluate/torch_frechet_video_distance.py b/extern/ldm_zero123/modules/evaluate/torch_frechet_video_distance.py new file mode 100755 index 0000000..c4cd40f --- /dev/null +++ b/extern/ldm_zero123/modules/evaluate/torch_frechet_video_distance.py @@ -0,0 +1,331 @@ +# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks! +import glob +import hashlib +import html +import io +import multiprocessing as mp +import os +import re +import urllib +import urllib.request +from typing import Any, Callable, Dict, List, Tuple, Union + +import numpy as np +import requests +import scipy.linalg +import torch +from torchvision.io import read_video +from tqdm import tqdm + +torch.set_grad_enabled(False) +from einops import rearrange +from nitro.util import isvideo + + +def compute_frechet_distance(mu_sample, sigma_sample, mu_ref, sigma_ref) -> float: + print("Calculate frechet distance...") + m = np.square(mu_sample - mu_ref).sum() + s, _ = scipy.linalg.sqrtm( + np.dot(sigma_sample, sigma_ref), disp=False + ) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2)) + + return float(fid) + + +def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + mu = feats.mean(axis=0) # [d] + sigma = np.cov(feats, rowvar=False) # [d, d] + + return mu, sigma + + +def open_url( + url: str, + num_attempts: int = 10, + verbose: bool = True, + return_filename: bool = False, +) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match("^[a-z]+://", url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith("file://"): + filename = urllib.parse.urlparse(url).path + if re.match(r"^/[a-zA-Z]:", filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [ + html.unescape(link) + for link in content_str.split('"') + if "export=download" in link + ] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError( + "Google Drive download quota exceeded -- please try again later" + ) + + match = re.search( + r'filename="([^"]*)"', + res.headers.get("Content-Disposition", ""), + ) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) + + +def load_video(ip): + vid, *_ = read_video(ip) + vid = rearrange(vid, "t h w c -> t c h w").to(torch.uint8) + return vid + + +def get_data_from_str(input_str, nprc=None): + assert os.path.isdir( + input_str + ), f'Specified input folder "{input_str}" is not a directory' + vid_filelist = glob.glob(os.path.join(input_str, "*.mp4")) + print(f"Found {len(vid_filelist)} videos in dir {input_str}") + + if nprc is None: + try: + nprc = mp.cpu_count() + except NotImplementedError: + print( + "WARNING: cpu_count() not avlailable, using only 1 cpu for video loading" + ) + nprc = 1 + + pool = mp.Pool(processes=nprc) + + vids = [] + for v in tqdm( + pool.imap_unordered(load_video, vid_filelist), + total=len(vid_filelist), + desc="Loading videos...", + ): + vids.append(v) + + vids = torch.stack(vids, dim=0).float() + + return vids + + +def get_stats(stats): + assert os.path.isfile(stats) and stats.endswith( + ".npz" + ), f"no stats found under {stats}" + + print(f"Using precomputed statistics under {stats}") + stats = np.load(stats) + stats = {key: stats[key] for key in stats.files} + + return stats + + +@torch.no_grad() +def compute_fvd( + ref_input, sample_input, bs=32, ref_stats=None, sample_stats=None, nprc_load=None +): + calc_stats = ref_stats is None or sample_stats is None + + if calc_stats: + only_ref = sample_stats is not None + only_sample = ref_stats is not None + + if isinstance(ref_input, str) and not only_sample: + ref_input = get_data_from_str(ref_input, nprc_load) + + if isinstance(sample_input, str) and not only_ref: + sample_input = get_data_from_str(sample_input, nprc_load) + + stats = compute_statistics( + sample_input, + ref_input, + device="cuda" if torch.cuda.is_available() else "cpu", + bs=bs, + only_ref=only_ref, + only_sample=only_sample, + ) + + if only_ref: + stats.update(get_stats(sample_stats)) + elif only_sample: + stats.update(get_stats(ref_stats)) + + else: + stats = get_stats(sample_stats) + stats.update(get_stats(ref_stats)) + + fvd = compute_frechet_distance(**stats) + + return { + "FVD": fvd, + } + + +@torch.no_grad() +def compute_statistics( + videos_fake, + videos_real, + device: str = "cuda", + bs=32, + only_ref=False, + only_sample=False, +) -> Dict: + detector_url = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1" + detector_kwargs = dict( + rescale=True, resize=True, return_features=True + ) # Return raw features before the softmax layer. + + with open_url(detector_url, verbose=False) as f: + detector = torch.jit.load(f).eval().to(device) + + assert not ( + only_sample and only_ref + ), "only_ref and only_sample arguments are mutually exclusive" + + ref_embed, sample_embed = [], [] + + info = f"Computing I3D activations for FVD score with batch size {bs}" + + if only_ref: + if not isvideo(videos_real): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() + print(videos_real.shape) + + if videos_real.shape[0] % bs == 0: + n_secs = videos_real.shape[0] // bs + else: + n_secs = videos_real.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + + for ref_v in tqdm(videos_real, total=len(videos_real), desc=info): + feats_ref = ( + detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + ) + ref_embed.append(feats_ref) + + elif only_sample: + if not isvideo(videos_fake): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() + print(videos_fake.shape) + + if videos_fake.shape[0] % bs == 0: + n_secs = videos_fake.shape[0] // bs + else: + n_secs = videos_fake.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + + for sample_v in tqdm(videos_fake, total=len(videos_real), desc=info): + feats_sample = ( + detector(sample_v.to(device).contiguous(), **detector_kwargs) + .cpu() + .numpy() + ) + sample_embed.append(feats_sample) + + else: + if not isvideo(videos_real): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() + + if not isvideo(videos_fake): + videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() + + if videos_fake.shape[0] % bs == 0: + n_secs = videos_fake.shape[0] // bs + else: + n_secs = videos_fake.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0) + + for ref_v, sample_v in tqdm( + zip(videos_real, videos_fake), total=len(videos_fake), desc=info + ): + # print(ref_v.shape) + # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) + # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) + + feats_sample = ( + detector(sample_v.to(device).contiguous(), **detector_kwargs) + .cpu() + .numpy() + ) + feats_ref = ( + detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + ) + sample_embed.append(feats_sample) + ref_embed.append(feats_ref) + + out = dict() + if len(sample_embed) > 0: + sample_embed = np.concatenate(sample_embed, axis=0) + mu_sample, sigma_sample = compute_stats(sample_embed) + out.update({"mu_sample": mu_sample, "sigma_sample": sigma_sample}) + + if len(ref_embed) > 0: + ref_embed = np.concatenate(ref_embed, axis=0) + mu_ref, sigma_ref = compute_stats(ref_embed) + out.update({"mu_ref": mu_ref, "sigma_ref": sigma_ref}) + + return out diff --git a/extern/ldm_zero123/modules/image_degradation/__init__.py b/extern/ldm_zero123/modules/image_degradation/__init__.py new file mode 100755 index 0000000..1143c44 --- /dev/null +++ b/extern/ldm_zero123/modules/image_degradation/__init__.py @@ -0,0 +1,6 @@ +from extern.ldm_zero123.modules.image_degradation.bsrgan import ( + degradation_bsrgan_variant as degradation_fn_bsr, +) +from extern.ldm_zero123.modules.image_degradation.bsrgan_light import ( + degradation_bsrgan_variant as degradation_fn_bsr_light, +) diff --git a/extern/ldm_zero123/modules/image_degradation/bsrgan.py b/extern/ldm_zero123/modules/image_degradation/bsrgan.py new file mode 100755 index 0000000..3b2e534 --- /dev/null +++ b/extern/ldm_zero123/modules/image_degradation/bsrgan.py @@ -0,0 +1,809 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import random +from functools import partial + +import albumentations +import cv2 +import numpy as np +import scipy +import scipy.stats as ss +import torch +from scipy import ndimage +from scipy.interpolate import interp2d +from scipy.linalg import orth + +import extern.ldm_zero123.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + """ + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + """ + w, h = img.shape[:2] + im = np.copy(img) + return im[: w - w % sf, : h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot( + np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), + np.array([1.0, 0.0]), + ) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + """ + x: image, NxcxHxW + k: kernel, Nx1xhxw + """ + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel( + k_size=np.array([15, 15]), + scale_factor=np.array([4, 4]), + min_var=0.6, + max_var=10.0, + noise_level=0, +): + """ " + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + """ + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + """ + if filter_type == "gaussian": + return fspecial_gaussian(*args, **kwargs) + if filter_type == "laplacian": + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + """ + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + """ + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + """blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + """ + x = ndimage.filters.convolve( + x, np.expand_dims(k, axis=2), mode="wrap" + ) # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + """bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + """ + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") + return x + + +def classical_degradation(x, k, sf=3): + """blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype("float32") + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian( + ksize=2 * random.randint(2, 11) + 3, + theta=random.random() * np.pi, + l1=l1, + l2=l2, + ) + else: + k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror") + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize( + img, + (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32 + ) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) + else: # add noise + L = noise_level2 / 255.0 + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32 + ) + elif rnum < 0.4: + img += img * np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) + else: + L = noise_level2 / 255.0 + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 + noise_gray = ( + np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + ) + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode( + ".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] + ) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[ + rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, : + ] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f"img size ({h1}X{w1}) is too small!") + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize( + img, + (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve( + img, np.expand_dims(k_shifted, axis=2), mode="mirror" + ) + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize( + img, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) + + for i in shuffle_order: + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode="mirror" + ) + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize( + image, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus( + img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None +): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f"img size ({h1}X{w1}) is too small!") + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print("check the shuffle!") + + # resize to desired size + img = cv2.resize( + img, + (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == "__main__": + print("hey") + img = util.imread_uint("utils/test.png", 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize( + max_size=h, interpolation=cv2.INTER_CUBIC + )(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize( + util.single2uint(img_lq), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + img_concat = np.concatenate( + [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 + ) + util.imsave(img_concat, str(i) + ".png") diff --git a/extern/ldm_zero123/modules/image_degradation/bsrgan_light.py b/extern/ldm_zero123/modules/image_degradation/bsrgan_light.py new file mode 100755 index 0000000..84318a7 --- /dev/null +++ b/extern/ldm_zero123/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,720 @@ +# -*- coding: utf-8 -*- +import random +from functools import partial + +import albumentations +import cv2 +import numpy as np +import scipy +import scipy.stats as ss +import torch +from scipy import ndimage +from scipy.interpolate import interp2d +from scipy.linalg import orth + +import extern.ldm_zero123.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + """ + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + """ + w, h = img.shape[:2] + im = np.copy(img) + return im[: w - w % sf, : h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot( + np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), + np.array([1.0, 0.0]), + ) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + """ + x: image, NxcxHxW + k: kernel, Nx1xhxw + """ + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel( + k_size=np.array([15, 15]), + scale_factor=np.array([4, 4]), + min_var=0.6, + max_var=10.0, + noise_level=0, +): + """ " + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + """ + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + """ + if filter_type == "gaussian": + return fspecial_gaussian(*args, **kwargs) + if filter_type == "laplacian": + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + """ + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + """ + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + """blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + """ + x = ndimage.convolve( + x, np.expand_dims(k, axis=2), mode="wrap" + ) # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + """bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + """ + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") + return x + + +def classical_degradation(x, k, sf=3): + """blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype("float32") + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2 / 4 + wd = wd / 4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian( + ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2 + ) + else: + k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror") + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize( + img, + (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32 + ) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) + else: # add noise + L = noise_level2 / 255.0 + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32 + ) + elif rnum < 0.4: + img += img * np.random.normal( + 0, noise_level / 255.0, (*img.shape[:2], 1) + ).astype(np.float32) + else: + L = noise_level2 / 255.0 + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] + ).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 + noise_gray = ( + np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + ) + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode( + ".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] + ) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[ + rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, : + ] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f"img size ({h1}X{w1}) is too small!") + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize( + img, + (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve( + img, np.expand_dims(k_shifted, axis=2), mode="mirror" + ) + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize( + img, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = ( + shuffle_order[idx2], + shuffle_order[idx1], + ) + + for i in shuffle_order: + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) + else: + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.convolve( + image, np.expand_dims(k_shifted, axis=2), mode="mirror" + ) + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize( + image, + (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3]), + ) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + +if __name__ == "__main__": + print("hey") + img = util.imread_uint("utils/test.png", 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize( + max_size=h, interpolation=cv2.INTER_CUBIC + )(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize( + util.single2uint(img_lq), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0, + ) + img_concat = np.concatenate( + [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 + ) + util.imsave(img_concat, str(i) + ".png") diff --git a/extern/ldm_zero123/modules/image_degradation/utils/test.png b/extern/ldm_zero123/modules/image_degradation/utils/test.png new file mode 100755 index 0000000..4249b43 Binary files /dev/null and b/extern/ldm_zero123/modules/image_degradation/utils/test.png differ diff --git a/extern/ldm_zero123/modules/image_degradation/utils_image.py b/extern/ldm_zero123/modules/image_degradation/utils_image.py new file mode 100755 index 0000000..c933af5 --- /dev/null +++ b/extern/ldm_zero123/modules/image_degradation/utils_image.py @@ -0,0 +1,988 @@ +import math +import os +import random +from datetime import datetime + +import cv2 +import numpy as np +import torch +from torchvision.utils import make_grid + +# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + + +""" +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +""" + + +IMG_EXTENSIONS = [ + ".jpg", + ".JPG", + ".jpeg", + ".JPEG", + ".png", + ".PNG", + ".ppm", + ".PPM", + ".bmp", + ".BMP", + ".tif", +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime("%y%m%d-%H%M%S") + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray") + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap="rainbow", figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection="3d") + + w, h = Z.shape[:2] + xx = np.arange(0, w, 1) + yy = np.arange(0, h, 1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X, Y, Z, cmap=cmap) + # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +""" +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +""" + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), "{:s} is not a valid directory".format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, "{:s} has no valid image file".format(path) + return images + + +""" +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +""" + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) + w1.append(w - p_size) + h1.append(h - p_size) + # print(w1) + # print(h1) + for i in w1: + for j in h1: + patches.append(img[i : i + p_size, j : j + p_size, :]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join( + os.path.dirname(img_path), img_name + str("_s{:04d}".format(i)) + ".png" + ) + cv2.imwrite(new_path, img) + + +def split_imageset( + original_dataroot, + taget_dataroot, + n_channels=3, + p_size=800, + p_overlap=96, + p_max=1000, +): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + # if original_dataroot == taget_dataroot: + # del img_path + + +""" +# -------------------------------------------- +# makedir +# -------------------------------------------- +""" + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + "_archived_" + get_timestamp() + print("Path already exists. Rename it to [{:s}]".format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +""" +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +""" + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255.0 + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +""" +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +""" + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + return np.float32(img / 255.0) + + +def single2uint(img): + return np.uint8((img.clip(0, 1) * 255.0).round()) + + +def uint162single(img): + return np.float32(img / 65535.0) + + +def single2uint16(img): + return np.uint16((img.clip(0, 1) * 65535.0).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return ( + torch.from_numpy(np.ascontiguousarray(img)) + .permute(2, 0, 1) + .float() + .div(255.0) + .unsqueeze(0) + ) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return ( + torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0) + ) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img * 255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return ( + torch.from_numpy(np.ascontiguousarray(img)) + .permute(2, 0, 1) + .float() + .unsqueeze(0) + ) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return ( + torch.from_numpy(np.ascontiguousarray(img)) + .permute(2, 0, 1, 3) + .float() + .unsqueeze(0) + ) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + """ + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + """ + tensor = ( + tensor.squeeze().float().cpu().clamp_(*min_max) + ) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + "Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format( + n_dim + ) + ) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +""" +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +""" + + +def augment_img(img, mode=0): + """Kai Zhang (github: https://github.com/cszn)""" + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + """Kai Zhang (github: https://github.com/cszn)""" + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + """Kai Zhang (github: https://github.com/cszn)""" + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +""" +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +""" + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[: H - H_r, : W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[: H - H_r, : W - W_r, :] + else: + raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border : h - border, border : w - border] + return img + + +""" +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +""" + + +def rgb2ycbcr(img, only_y=True): + """same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255.0 + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul( + img, + [ + [65.481, -37.797, 112.0], + [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214], + ], + ) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255.0 + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + """same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255.0 + # convert + rlt = np.matmul( + img, + [ + [0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0], + ], + ) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255.0 + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + """bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255.0 + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul( + img, + [ + [24.966, 112.0, -18.214], + [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0], + ], + ) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255.0 + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == "gray": # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == "y": # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == "RGB": # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +""" +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +""" + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + # img1 = img1.squeeze() + # img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError("Input images must have the same dimensions.") + h, w = img1.shape[:2] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2) ** 2) + if mse == 0: + return float("inf") + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + """calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + """ + # img1 = img1.squeeze() + # img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError("Input images must have the same dimensions.") + h, w = img1.shape[:2] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError("Wrong input image dimensions.") + + +def ssim(img1, img2): + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) + return ssim_map.mean() + + +""" +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +""" + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices( + in_length, out_length, scale, kernel, kernel_width, antialiasing +): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace( + 0, P - 1, P + ).view(1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = "cubic" + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = ( + img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + ) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = "cubic" + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = ( + img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + ) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == "__main__": + print("---") +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) diff --git a/extern/ldm_zero123/modules/losses/__init__.py b/extern/ldm_zero123/modules/losses/__init__.py new file mode 100755 index 0000000..15c99be --- /dev/null +++ b/extern/ldm_zero123/modules/losses/__init__.py @@ -0,0 +1 @@ +from extern.ldm_zero123.modules.losses.contperceptual import LPIPSWithDiscriminator diff --git a/extern/ldm_zero123/modules/losses/contperceptual.py b/extern/ldm_zero123/modules/losses/contperceptual.py new file mode 100755 index 0000000..422f749 --- /dev/null +++ b/extern/ldm_zero123/modules/losses/contperceptual.py @@ -0,0 +1,153 @@ +import torch +import torch.nn as nn +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split="train", + weights=None, + ): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1) + ) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + loss = ( + weighted_nll_loss + + self.kl_weight * kl_loss + + d_weight * disc_factor * g_loss + ) + + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), dim=1) + ) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log diff --git a/extern/ldm_zero123/modules/losses/vqperceptual.py b/extern/ldm_zero123/modules/losses/vqperceptual.py new file mode 100755 index 0000000..feb5885 --- /dev/null +++ b/extern/ldm_zero123/modules/losses/vqperceptual.py @@ -0,0 +1,218 @@ +import torch +import torch.nn.functional as F +from einops import repeat +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss +from torch import nn + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3]) + loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + + +def l1(x, y): + return torch.abs(x - y) + + +def l2(x, y): + return torch.pow((x - y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start, + codebook_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_ndf=64, + disc_loss="hinge", + n_classes=None, + perceptual_loss="lpips", + pixel_loss="l1", + ): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf, + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + codebook_loss, + inputs, + reconstructions, + optimizer_idx, + global_step, + last_layer=None, + cond=None, + split="train", + predicted_indices=None, + ): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.0]).to(inputs.device) + # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1) + ) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + loss = ( + nll_loss + + d_weight * disc_factor * g_loss + + self.codebook_weight * codebook_loss.mean() + ) + + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity( + predicted_indices, self.n_classes + ) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), dim=1) + ) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log diff --git a/extern/ldm_zero123/modules/x_transformer.py b/extern/ldm_zero123/modules/x_transformer.py new file mode 100755 index 0000000..ab8fab2 --- /dev/null +++ b/extern/ldm_zero123/modules/x_transformer.py @@ -0,0 +1,705 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +from collections import namedtuple +from functools import partial +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from torch import einsum, nn + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) + +LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = ( + torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + + offset + ) + sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + + return inner + + +def not_equals(val): + def inner(x): + return x != val + + return inner + + +def equals(val): + def inner(x): + return x == val + + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key( + partial(string_begins_with, prefix), d + ) + kwargs_without_prefix = dict( + map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())) + ) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d") + ) + + return gated_output.reshape_as(x) + + +# feedforward + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + ): + super().__init__() + if use_entmax15: + raise NotImplementedError( + "Check out entmax activation instead of softmax activation!" + ) + self.scale = dim_head**-0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) + if on_attn + else nn.Linear(inner_dim, dim) + ) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None, + ): + b, n, _, h, talking_heads, device = ( + *x.shape, + self.heads, + self.talking_heads, + x.device, + ) + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default( + k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool() + ) + q_mask = rearrange(q_mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map( + lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v) + ) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum( + "b h i j, h k -> b k i j", dots, self.pre_softmax_proj + ).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum( + "b h i j, h k -> b k i j", attn, self.post_softmax_proj + ).contiguous() + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs, + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) + + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = ( + FixedPositionalEmbedding(dim) if position_infused_attn else None + ) + self.rotary_pos_emb = always(None) + + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ("a", "c", "f") + elif cross_attend and only_cross: + default_block = ("c", "f") + else: + default_block = ("a", "f") + + if macaron: + default_block = ("f",) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) + par_attn = par_depth // par_ratio + depth_cut = ( + par_depth * 2 // 3 + ) # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert ( + len(default_block) <= par_width + ), "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ("f",) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert ( + sandwich_coef > 0 and sandwich_coef <= depth + ), "sandwich coefficient should be less than the depth" + layer_types = ( + ("a",) * sandwich_coef + + default_block * (depth - sandwich_coef) + + ("f",) * sandwich_coef + ) + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) + + for layer_type in self.layer_types: + if layer_type == "a": + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == "c": + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == "f": + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f"invalid layer type {layer_type}") + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False, + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate( + zip(self.layer_types, self.layers) + ): + is_last = ind == (len(self.layers) - 1) + + if layer_type == "a": + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == "a": + out, inter = block( + x, + mask=mask, + sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, + prev_attn=prev_attn, + mem=layer_mem, + ) + elif layer_type == "c": + out, inter = block( + x, + context=context, + mask=mask, + context_mask=context_mask, + prev_attn=prev_cross_attn, + ) + elif layer_type == "f": + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ("a", "c"): + intermediates.append(inter) + + if layer_type == "a" and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == "c" and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, + ): + super().__init__() + assert isinstance( + attn_layers, AttentionLayers + ), "attention layers must be one of Encoder or Decoder" + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = ( + nn.Linear(dim, num_tokens) + if not tie_embedding + else lambda t: t @ self.token_emb.weight.t() + ) + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, "num_memory_tokens"): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs, + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, "n d -> b n d", b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, **kwargs + ) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = ( + list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) + if exists(mems) + else hiddens + ) + new_mems = list( + map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems) + ) + return out, new_mems + + if return_attn: + attn_maps = list( + map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates) + ) + return out, attn_maps + + return out diff --git a/extern/ldm_zero123/thirdp/psp/helpers.py b/extern/ldm_zero123/thirdp/psp/helpers.py new file mode 100755 index 0000000..fc0de4d --- /dev/null +++ b/extern/ldm_zero123/thirdp/psp/helpers.py @@ -0,0 +1,144 @@ +# https://github.com/eladrich/pixel2style2pixel + +from collections import namedtuple + +import torch +from torch.nn import ( + AdaptiveAvgPool2d, + BatchNorm2d, + Conv2d, + MaxPool2d, + Module, + PReLU, + ReLU, + Sequential, + Sigmoid, +) + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])): + """A named tuple describing a ResNet block.""" + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [ + Bottleneck(depth, depth, 1) for i in range(num_units - 1) + ] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3), + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3), + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3), + ] + else: + raise ValueError( + "Invalid number of layers: {}. Must be one of [50, 100, 152]".format( + num_layers + ) + ) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, channels // reduction, kernel_size=1, padding=0, bias=False + ) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, channels, kernel_size=1, padding=0, bias=False + ) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth), + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth), + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16), + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/extern/ldm_zero123/thirdp/psp/id_loss.py b/extern/ldm_zero123/thirdp/psp/id_loss.py new file mode 100755 index 0000000..76ab9eb --- /dev/null +++ b/extern/ldm_zero123/thirdp/psp/id_loss.py @@ -0,0 +1,26 @@ +# https://github.com/eladrich/pixel2style2pixel +import torch +from torch import nn + +from extern.ldm_zero123.thirdp.psp.model_irse import Backbone + + +class IDFeatures(nn.Module): + def __init__(self, model_path): + super(IDFeatures, self).__init__() + print("Loading ResNet ArcFace") + self.facenet = Backbone( + input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se" + ) + self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + def forward(self, x, crop=False): + # Not sure of the image range here + if crop: + x = torch.nn.functional.interpolate(x, (256, 256), mode="area") + x = x[:, :, 35:223, 32:220] + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats diff --git a/extern/ldm_zero123/thirdp/psp/model_irse.py b/extern/ldm_zero123/thirdp/psp/model_irse.py new file mode 100755 index 0000000..50174ab --- /dev/null +++ b/extern/ldm_zero123/thirdp/psp/model_irse.py @@ -0,0 +1,118 @@ +# https://github.com/eladrich/pixel2style2pixel + +from torch.nn import ( + BatchNorm1d, + BatchNorm2d, + Conv2d, + Dropout, + Linear, + Module, + PReLU, + Sequential, +) + +from extern.ldm_zero123.thirdp.psp.helpers import ( + Flatten, + bottleneck_IR, + bottleneck_IR_SE, + get_blocks, + l2_norm, +) + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ["ir", "ir_se"], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == "ir": + unit_module = bottleneck_IR + elif mode == "ir_se": + unit_module = bottleneck_IR_SE + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64) + ) + if input_size == 112: + self.output_layer = Sequential( + BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine), + ) + else: + self.output_layer = Sequential( + BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine), + ) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module( + bottleneck.in_channel, bottleneck.depth, bottleneck.stride + ) + ) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone( + input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False + ) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone( + input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False + ) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone( + input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False + ) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone( + input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False + ) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone( + input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False + ) + return model diff --git a/extern/ldm_zero123/util.py b/extern/ldm_zero123/util.py new file mode 100755 index 0000000..4664798 --- /dev/null +++ b/extern/ldm_zero123/util.py @@ -0,0 +1,249 @@ +import importlib +import os +import time +from inspect import isfunction + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import PIL +import torch +import torchvision +from PIL import Image, ImageDraw, ImageFont +from torch import optim + + +def pil_rectangle_crop(im): + width, height = im.size # Get dimensions + + if width <= height: + left = 0 + right = width + top = (height - width) / 2 + bottom = (height + width) / 2 + else: + top = 0 + bottom = height + left = (width - height) / 2 + bottom = (width + height) / 2 + + # Crop the center of the image + im = im.crop((left, top, right, bottom)) + return im + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__( + self, + params, + lr=1.0e-3, + betas=(0.9, 0.999), + eps=1.0e-8, # TODO: check hyperparameters before using + weight_decay=1.0e-2, + amsgrad=False, + ema_decay=0.9999, # ema decay to match previous code + ema_power=1.0, + param_names=(), + ): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ema_decay=ema_decay, + ema_power=ema_power, + param_names=param_names, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group["amsgrad"] + beta1, beta2 = group["betas"] + ema_decay = group["ema_decay"] + ema_power = group["ema_power"] + + for p in group["params"]: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of parameter values + state["param_exp_avg"] = p.detach().float().clone() + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + ema_params_with_grad.append(state["param_exp_avg"]) + + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + # update the steps for each param group update + state["step"] += 1 + # record the step after step update + state_steps.append(state["step"]) + + optim._functional.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_( + param.float(), alpha=1 - cur_ema_decay + ) + + return loss diff --git a/extern/zero123.py b/extern/zero123.py new file mode 100644 index 0000000..2ee4343 --- /dev/null +++ b/extern/zero123.py @@ -0,0 +1,666 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +import torchvision.transforms.functional as TF +from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.modeling_utils import ModelMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, is_accelerate_available, logging +from diffusers.utils.torch_utils import randn_tensor +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CLIPCameraProjection(ModelMixin, ConfigMixin): + """ + A Projection layer for CLIP embedding and camera embedding. + + Parameters: + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + """ + + @register_to_config + def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4): + super().__init__() + self.embedding_dim = embedding_dim + self.additional_embeddings = additional_embeddings + + self.input_dim = self.embedding_dim + self.additional_embeddings + self.output_dim = self.embedding_dim + + self.proj = torch.nn.Linear(self.input_dim, self.output_dim) + + def forward( + self, + embedding: torch.FloatTensor, + ): + """ + The [`PriorTransformer`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, input_dim)`): + The currently input embeddings. + + Returns: + The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`). + """ + proj_embedding = self.proj(embedding) + return proj_embedding + + +class Zero123Pipeline(DiffusionPipeline): + r""" + Pipeline to generate variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + # TODO: feature_extractor is required to encode images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + clip_camera_projection: CLIPCameraProjection, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + clip_camera_projection=clip_camera_projection, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [ + self.unet, + self.image_encoder, + self.vae, + self.safety_checker, + ]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_image( + self, + image, + elevation, + azimuth, + distance, + device, + num_images_per_prompt, + do_classifier_free_guidance, + clip_image_embeddings=None, + image_camera_embeddings=None, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if image_camera_embeddings is None: + if image is None: + assert clip_image_embeddings is not None + image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype) + else: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor( + images=image, return_tensors="pt" + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + bs_embed, seq_len, _ = image_embeddings.shape + + if isinstance(elevation, float): + elevation = torch.as_tensor( + [elevation] * bs_embed, dtype=dtype, device=device + ) + if isinstance(azimuth, float): + azimuth = torch.as_tensor( + [azimuth] * bs_embed, dtype=dtype, device=device + ) + if isinstance(distance, float): + distance = torch.as_tensor( + [distance] * bs_embed, dtype=dtype, device=device + ) + + camera_embeddings = torch.stack( + [ + torch.deg2rad(elevation), + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + distance, + ], + dim=-1, + )[:, None, :] + + image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1) + + # project (image, camera) embeddings to the same dimension as clip embeddings + image_embeddings = self.clip_camera_projection(image_embeddings) + else: + image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype) + bs_embed, seq_len, _ = image_embeddings.shape + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + # TODO: check image size or adjust image size to (height, width) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_latent_model_input( + self, + latents: torch.FloatTensor, + image: Optional[ + Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] + ], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + image_latents: Optional[torch.FloatTensor] = None, + ): + if isinstance(image, PIL.Image.Image): + image_pt = TF.to_tensor(image).unsqueeze(0).to(latents) + elif isinstance(image, list): + image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to( + latents + ) + elif isinstance(image, torch.Tensor): + image_pt = image + else: + image_pt = None + + if image_pt is None: + assert image_latents is not None + image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0) + else: + image_pt = image_pt * 2.0 - 1.0 # scale to [-1, 1] + # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor + # but zero123 was not trained this way + image_pt = self.vae.encode(image_pt).latent_dist.mode() + image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + latent_model_input = torch.cat( + [ + torch.cat([latents, latents], dim=0), + torch.cat([torch.zeros_like(image_pt), image_pt], dim=0), + ], + dim=1, + ) + else: + latent_model_input = torch.cat([latents, image_pt], dim=1) + + return latent_model_input + + @torch.no_grad() + def __call__( + self, + image: Optional[ + Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] + ] = None, + elevation: Optional[Union[float, torch.FloatTensor]] = None, + azimuth: Optional[Union[float, torch.FloatTensor]] = None, + distance: Optional[Union[float, torch.FloatTensor]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 3.0, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + clip_image_embeddings: Optional[torch.FloatTensor] = None, + image_camera_embeddings: Optional[torch.FloatTensor] = None, + image_latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPImageProcessor` + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + # TODO: check input elevation, azimuth, and distance + # TODO: check image, clip_image_embeddings, image_latents + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + assert image_latents is not None + assert ( + clip_image_embeddings is not None or image_camera_embeddings is not None + ) + batch_size = image_latents.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + if isinstance(image, PIL.Image.Image) or isinstance(image, list): + pil_image = image + elif isinstance(image, torch.Tensor): + pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])] + else: + pil_image = None + image_embeddings = self._encode_image( + pil_image, + elevation, + azimuth, + distance, + device, + num_images_per_prompt, + do_classifier_free_guidance, + clip_image_embeddings, + image_camera_embeddings, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 # FIXME: hard-coded + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = self._get_latent_model_input( + latents, + image, + num_images_per_prompt, + do_classifier_free_guidance, + image_latents, + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, image_embeddings.dtype + ) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000..656b3a4 --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,450 @@ +import argparse +import glob +import os +import re +import signal +import subprocess +import tempfile +import time +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + +import gradio as gr +import numpy as np +import psutil +import trimesh + + +def tail(f, window=20): + # Returns the last `window` lines of file `f`. + if window == 0: + return [] + + BUFSIZ = 1024 + f.seek(0, 2) + remaining_bytes = f.tell() + size = window + 1 + block = -1 + data = [] + + while size > 0 and remaining_bytes > 0: + if remaining_bytes - BUFSIZ > 0: + # Seek back one whole BUFSIZ + f.seek(block * BUFSIZ, 2) + # read BUFFER + bunch = f.read(BUFSIZ) + else: + # file too small, start from beginning + f.seek(0, 0) + # only read what was not read + bunch = f.read(remaining_bytes) + + bunch = bunch.decode("utf-8") + data.insert(0, bunch) + size -= bunch.count("\n") + remaining_bytes -= BUFSIZ + block -= 1 + + return "\n".join("".join(data).splitlines()[-window:]) + + +@dataclass +class ExperimentStatus: + pid: Optional[int] = None + progress: str = "" + log: str = "" + output_image: Optional[str] = None + output_video: Optional[str] = None + output_mesh: Optional[str] = None + + def tolist(self): + return [ + self.pid, + self.progress, + self.log, + self.output_image, + self.output_video, + self.output_mesh, + ] + + +EXP_ROOT_DIR = "outputs-gradio" +DEFAULT_PROMPT = "a delicious hamburger" +model_config = [ + ("DreamFusion (DeepFloyd-IF)", "configs/gradio/dreamfusion-if.yaml"), + ("DreamFusion (Stable Diffusion)", "configs/gradio/dreamfusion-sd.yaml"), + ("TextMesh (DeepFloyd-IF)", "configs/gradio/textmesh-if.yaml"), + ("Fantasia3D (Stable Diffusion, Geometry Only)", "configs/gradio/fantasia3d.yaml"), + ("SJC (Stable Diffusion)", "configs/gradio/sjc.yaml"), + ("Latent-NeRF (Stable Diffusion)", "configs/gradio/latentnerf.yaml"), +] +model_choices = [m[0] for m in model_config] +model_name_to_config = {m[0]: m[1] for m in model_config} + + +def load_model_config(model_name): + return open(model_name_to_config[model_name]).read() + + +def load_model_config_attrs(model_name): + config_str = load_model_config(model_name) + from threestudio.utils.config import load_config + + cfg = load_config( + config_str, + cli_args=[ + "name=dummy", + "tag=dummy", + "use_timestamp=false", + f"exp_root_dir={EXP_ROOT_DIR}", + "system.prompt_processor.prompt=placeholder", + ], + from_string=True, + ) + return { + "source": config_str, + "guidance_scale": cfg.system.guidance.guidance_scale, + "max_steps": cfg.trainer.max_steps, + } + + +def on_model_selector_change(model_name): + cfg = load_model_config_attrs(model_name) + return [cfg["source"], cfg["guidance_scale"]] + + +def get_current_status(process, trial_dir, alive_path): + status = ExperimentStatus() + + status.pid = process.pid + + # write the current timestamp to the alive file + # the watcher will know the last active time of this process from this timestamp + if os.path.exists(os.path.dirname(alive_path)): + alive_fp = open(alive_path, "w") + alive_fp.seek(0) + alive_fp.write(str(time.time())) + alive_fp.flush() + + log_path = os.path.join(trial_dir, "logs") + progress_path = os.path.join(trial_dir, "progress") + save_path = os.path.join(trial_dir, "save") + + # read current progress from the progress file + # the progress file is created by GradioCallback + if os.path.exists(progress_path): + status.progress = open(progress_path).read() + else: + status.progress = "Setting up everything ..." + + # read the last 10 lines of the log file + if os.path.exists(log_path): + status.log = tail(open(log_path, "rb"), window=10) + else: + status.log = "" + + # get the validation image and testing video if they exist + if os.path.exists(save_path): + images = glob.glob(os.path.join(save_path, "*.png")) + steps = [ + int(re.match(r"it(\d+)-0\.png", os.path.basename(f)).group(1)) + for f in images + ] + images = sorted(list(zip(images, steps)), key=lambda x: x[1]) + if len(images) > 0: + status.output_image = images[-1][0] + + videos = glob.glob(os.path.join(save_path, "*.mp4")) + steps = [ + int(re.match(r"it(\d+)-test\.mp4", os.path.basename(f)).group(1)) + for f in videos + ] + videos = sorted(list(zip(videos, steps)), key=lambda x: x[1]) + if len(videos) > 0: + status.output_video = videos[-1][0] + + export_dirs = glob.glob(os.path.join(save_path, "*export")) + steps = [ + int(re.match(r"it(\d+)-export", os.path.basename(f)).group(1)) + for f in export_dirs + ] + export_dirs = sorted(list(zip(export_dirs, steps)), key=lambda x: x[1]) + if len(export_dirs) > 0: + obj = glob.glob(os.path.join(export_dirs[-1][0], "*.obj")) + if len(obj) > 0: + # FIXME + # seems the gr.Model3D cannot load our manually saved obj file + # here we load the obj and save it to a temporary file using trimesh + mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False) + trimesh.load(obj[0]).export(mesh_path.name) + status.output_mesh = mesh_path.name + + return status + + +def run( + model_name: str, + config: str, + prompt: str, + guidance_scale: float, + seed: int, + max_steps: int, +): + # update status every 1 second + status_update_interval = 1 + + # save the config to a temporary file + config_file = tempfile.NamedTemporaryFile() + + with open(config_file.name, "w") as f: + f.write(config) + + # manually assign the output directory, name and tag so that we know the trial directory + name = os.path.basename(model_name_to_config[model_name]).split(".")[0] + tag = datetime.now().strftime("@%Y%m%d-%H%M%S") + trial_dir = os.path.join(EXP_ROOT_DIR, name, tag) + alive_path = os.path.join(trial_dir, "alive") + + # spawn the training process + process = subprocess.Popen( + f"python launch.py --config {config_file.name} --train --gpu 0 --gradio trainer.enable_progress_bar=false".split() + + [ + f'name="{name}"', + f'tag="{tag}"', + f"exp_root_dir={EXP_ROOT_DIR}", + "use_timestamp=false", + f'system.prompt_processor.prompt="{prompt}"', + f"system.guidance.guidance_scale={guidance_scale}", + f"seed={seed}", + f"trainer.max_steps={max_steps}", + ] + ) + + # spawn the watcher process + watch_process = subprocess.Popen( + "python gradio_app.py watch".split() + + ["--pid", f"{process.pid}", "--trial-dir", f"{trial_dir}"] + ) + + # update status (progress, log, image, video) every status_update_interval senconds + # button status: Run -> Stop + while process.poll() is None: + time.sleep(status_update_interval) + yield get_current_status(process, trial_dir, alive_path).tolist() + [ + gr.update(visible=False), + gr.update(value="Stop", variant="stop", visible=True), + ] + + # wait for the processes to finish + process.wait() + watch_process.wait() + + # update status one last time + # button status: Stop / Reset -> Run + status = get_current_status(process, trial_dir, alive_path) + status.progress = "Finished." + yield status.tolist() + [ + gr.update(value="Run", variant="primary", visible=True), + gr.update(visible=False), + ] + + +def stop_run(pid): + # kill the process + print(f"Trying to kill process {pid} ...") + try: + os.kill(pid, signal.SIGKILL) + except: + print(f"Exception when killing process {pid}.") + # button status: Stop -> Reset + return [ + gr.update(value="Reset", variant="secondary", visible=True), + gr.update(visible=False), + ] + + +def launch(port, listen=False): + with gr.Blocks(title="threestudio - Web Demo") as demo: + with gr.Row(): + pid = gr.State() + with gr.Column(scale=1): + header = gr.Markdown( + """ + # threestudio + + - Select a model from the dropdown menu. + - Input a text prompt. + - Hit Run! + """ + ) + + # model selection dropdown + model_selector = gr.Dropdown( + value=model_choices[0], + choices=model_choices, + label="Select a model", + ) + + # prompt input + prompt_input = gr.Textbox(value=DEFAULT_PROMPT, label="Input prompt") + + # guidance scale slider + guidance_scale_input = gr.Slider( + minimum=0.0, + maximum=100.0, + value=load_model_config_attrs(model_selector.value)[ + "guidance_scale" + ], + step=0.5, + label="Guidance scale", + ) + + # seed slider + seed_input = gr.Slider( + minimum=0, maximum=2147483647, value=0, step=1, label="Seed" + ) + + max_steps_input = gr.Slider( + minimum=1, + maximum=5000, + value=5000, + step=1, + label="Number of training steps", + ) + + # full config viewer + with gr.Accordion("See full configurations", open=False): + config_editor = gr.Code( + value=load_model_config(model_selector.value), + language="yaml", + interactive=False, + ) + + # load config on model selection change + model_selector.change( + fn=on_model_selector_change, + inputs=model_selector, + outputs=[config_editor, guidance_scale_input], + ) + + run_btn = gr.Button(value="Run", variant="primary") + stop_btn = gr.Button(value="Stop", variant="stop", visible=False) + + # generation status + status = gr.Textbox( + value="Hit the Run button to start.", + label="Status", + lines=1, + max_lines=1, + ) + + with gr.Column(scale=1): + with gr.Accordion("See terminal logs", open=False): + # logs + logs = gr.Textbox(label="Logs", lines=10) + + # validation image display + output_image = gr.Image(value=None, label="Image") + + # testing video display + output_video = gr.Video(value=None, label="Video") + + # export mesh display + output_mesh = gr.Model3D(value=None, label="3D Mesh") + + run_event = run_btn.click( + fn=run, + inputs=[ + model_selector, + config_editor, + prompt_input, + guidance_scale_input, + seed_input, + max_steps_input, + ], + outputs=[ + pid, + status, + logs, + output_image, + output_video, + output_mesh, + run_btn, + stop_btn, + ], + ) + stop_btn.click( + fn=stop_run, inputs=[pid], outputs=[run_btn, stop_btn], cancels=[run_event] + ) + + launch_args = {"server_port": port} + if listen: + launch_args["server_name"] = "0.0.0.0" + demo.queue().launch(**launch_args) + + +def watch( + pid: int, trial_dir: str, alive_timeout: int, wait_timeout: int, check_interval: int +) -> None: + print(f"Spawn watcher for process {pid}") + + def timeout_handler(signum, frame): + exit(1) + + alive_path = os.path.join(trial_dir, "alive") + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(wait_timeout) + + def loop_find_progress_file(): + while True: + if not os.path.exists(alive_path): + time.sleep(check_interval) + else: + signal.alarm(0) + return + + def loop_check_alive(): + while True: + if not psutil.pid_exists(pid): + print(f"Process {pid} not exists, watcher exits.") + exit(0) + alive_timestamp = float(open(alive_path).read()) + if time.time() - alive_timestamp > alive_timeout: + print(f"Alive timeout for process {pid}, killed.") + try: + os.kill(pid, signal.SIGKILL) + except: + print(f"Exception when killing process {pid}.") + exit(0) + time.sleep(check_interval) + + # loop until alive file is found, or alive_timeout is reached + loop_find_progress_file() + # kill the process if it is not accessed for alive_timeout seconds + loop_check_alive() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("operation", type=str, choices=["launch", "watch"]) + args, extra = parser.parse_known_args() + if args.operation == "launch": + parser.add_argument("--listen", action="store_true") + parser.add_argument("--port", type=int, default=7860) + args = parser.parse_args() + launch(args.port, listen=args.listen) + if args.operation == "watch": + parser.add_argument("--pid", type=int) + parser.add_argument("--trial-dir", type=str) + parser.add_argument("--alive-timeout", type=int, default=10) + parser.add_argument("--wait-timeout", type=int, default=10) + parser.add_argument("--check-interval", type=int, default=1) + args = parser.parse_args() + watch( + args.pid, + args.trial_dir, + alive_timeout=args.alive_timeout, + wait_timeout=args.wait_timeout, + check_interval=args.check_interval, + ) diff --git a/launch.py b/launch.py new file mode 100644 index 0000000..fa25450 --- /dev/null +++ b/launch.py @@ -0,0 +1,252 @@ +import argparse +import contextlib +import importlib +import logging +import os +import sys + + +class ColoredFilter(logging.Filter): + """ + A logging filter to add color to certain log levels. + """ + + RESET = "\033[0m" + RED = "\033[31m" + GREEN = "\033[32m" + YELLOW = "\033[33m" + BLUE = "\033[34m" + MAGENTA = "\033[35m" + CYAN = "\033[36m" + + COLORS = { + "WARNING": YELLOW, + "INFO": GREEN, + "DEBUG": BLUE, + "CRITICAL": MAGENTA, + "ERROR": RED, + } + + RESET = "\x1b[0m" + + def __init__(self): + super().__init__() + + def filter(self, record): + if record.levelname in self.COLORS: + color_start = self.COLORS[record.levelname] + record.levelname = f"{color_start}[{record.levelname}]" + record.msg = f"{record.msg}{self.RESET}" + return True + + +def main(args, extras) -> None: + # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None) + env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else [] + selected_gpus = [0] + + # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified. + # As far as Pytorch Lightning is concerned, we always use all available GPUs + # (possibly filtered by CUDA_VISIBLE_DEVICES). + devices = -1 + if len(env_gpus) > 0: + # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script. + n_gpus = len(env_gpus) + else: + selected_gpus = list(args.gpu.split(",")) + n_gpus = len(selected_gpus) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu + + import pytorch_lightning as pl + import torch + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint + from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger + from pytorch_lightning.utilities.rank_zero import rank_zero_only + + if args.typecheck: + from jaxtyping import install_import_hook + + install_import_hook("threestudio", "typeguard.typechecked") + + import threestudio + from threestudio.systems.base import BaseSystem + from threestudio.utils.callbacks import ( + CodeSnapshotCallback, + ConfigSnapshotCallback, + CustomProgressBar, + ProgressCallback, + ) + from threestudio.utils.config import ExperimentConfig, load_config + from threestudio.utils.misc import get_rank + from threestudio.utils.typing import Optional + + logger = logging.getLogger("pytorch_lightning") + if args.verbose: + logger.setLevel(logging.DEBUG) + + for handler in logger.handlers: + if handler.stream == sys.stderr: # type: ignore + if not args.gradio: + handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) + handler.addFilter(ColoredFilter()) + else: + handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) + + # parse YAML config to OmegaConf + cfg: ExperimentConfig + cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus) + + if len(cfg.custom_import) > 0: + print(cfg.custom_import) + for extension in cfg.custom_import: + importlib.import_module(extension) + # set a different seed for each device + pl.seed_everything(cfg.seed + get_rank(), workers=True) + + dm = threestudio.find(cfg.data_type)(cfg.data) + + # Auto check resume files during training + if args.train and cfg.resume is None: + import glob + resume_file_list = glob.glob(f"{cfg.trial_dir}/ckpts/*") + if len(resume_file_list) != 0: + print(sorted(resume_file_list)) + cfg.resume = sorted(resume_file_list)[-1] + print(f"Find resume file: {cfg.resume}") + + system: BaseSystem = threestudio.find(cfg.system_type)( + cfg.system, resumed=cfg.resume is not None + ) + system.set_save_dir(os.path.join(cfg.trial_dir, "save")) + + if args.gradio: + fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs")) + fh.setLevel(logging.INFO) + if args.verbose: + fh.setLevel(logging.DEBUG) + fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) + logger.addHandler(fh) + + callbacks = [] + if args.train: + callbacks += [ + ModelCheckpoint( + dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint + ), + LearningRateMonitor(logging_interval="step"), + CodeSnapshotCallback( + os.path.join(cfg.trial_dir, "code"), use_version=False + ), + ConfigSnapshotCallback( + args.config, + cfg, + os.path.join(cfg.trial_dir, "configs"), + use_version=False, + ), + ] + if args.gradio: + callbacks += [ + ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress")) + ] + else: + callbacks += [CustomProgressBar(refresh_rate=1)] + + def write_to_text(file, lines): + with open(file, "w") as f: + for line in lines: + f.write(line + "\n") + + loggers = [] + if args.train: + # make tensorboard logging dir to suppress warning + rank_zero_only( + lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True) + )() + loggers += [ + TensorBoardLogger(cfg.trial_dir, name="tb_logs"), + CSVLogger(cfg.trial_dir, name="csv_logs"), + ] + system.get_loggers() + rank_zero_only( + lambda: write_to_text( + os.path.join(cfg.trial_dir, "cmd.txt"), + ["python " + " ".join(sys.argv), str(args)], + ) + )() + + trainer = Trainer( + callbacks=callbacks, + logger=loggers, + inference_mode=False, + accelerator="gpu", + devices=devices, + **cfg.trainer, + ) + + def set_system_status(system: BaseSystem, ckpt_path: Optional[str]): + if ckpt_path is None: + return + ckpt = torch.load(ckpt_path, map_location="cpu") + system.set_resume_status(ckpt["epoch"], ckpt["global_step"]) + + if args.train: + trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume) + trainer.test(system, datamodule=dm) + if args.gradio: + # also export assets if in gradio mode + trainer.predict(system, datamodule=dm) + elif args.validate: + # manually set epoch and global_step as they cannot be automatically resumed + set_system_status(system, cfg.resume) + trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume) + elif args.test: + # manually set epoch and global_step as they cannot be automatically resumed + set_system_status(system, cfg.resume) + trainer.test(system, datamodule=dm, ckpt_path=cfg.resume) + elif args.export: + set_system_status(system, cfg.resume) + trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True, help="path to config file") + parser.add_argument( + "--gpu", + default="0", + help="GPU(s) to be used. 0 means use the 1st available GPU. " + "1,2 means use the 2nd and 3rd available GPU. " + "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, " + "this argument is ignored and all available GPUs are always used.", + ) + + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--train", action="store_true") + group.add_argument("--validate", action="store_true") + group.add_argument("--test", action="store_true") + group.add_argument("--export", action="store_true") + + parser.add_argument( + "--gradio", action="store_true", help="if true, run in gradio mode" + ) + + parser.add_argument( + "--verbose", action="store_true", help="if true, set logging level to DEBUG" + ) + + parser.add_argument( + "--typecheck", + action="store_true", + help="whether to enable dynamic type checking", + ) + + args, extras = parser.parse_known_args() + + if args.gradio: + # FIXME: no effect, stdout is not captured + with contextlib.redirect_stdout(sys.stderr): + main(args, extras) + else: + main(args, extras) \ No newline at end of file diff --git a/metric_utils.py b/metric_utils.py new file mode 100644 index 0000000..7872b74 --- /dev/null +++ b/metric_utils.py @@ -0,0 +1,567 @@ +# * evaluate use laion/CLIP-ViT-H-14-laion2B-s32B-b79K +# best open source clip so far: laion/CLIP-ViT-bigG-14-laion2B-39B-b160k +# code adapted from NeuralLift-360 + +import torch +import torch.nn as nn +import os +import torchvision.transforms as T +import torchvision.transforms.functional as TF +import matplotlib.pyplot as plt +# import clip +from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor +from torchvision import transforms +import numpy as np +import torch.nn.functional as F +from tqdm import tqdm +import cv2 +from PIL import Image +# import torchvision.transforms as transforms +import glob +from skimage.metrics import peak_signal_noise_ratio as compare_psnr +import lpips +from os.path import join as osp +import argparse +import pandas as pd +import contextual_loss as cl + +criterion = cl.ContextualLoss(use_vgg=True, vgg_layer='relu5_4') + +class CLIP(nn.Module): + + def __init__(self, + device, + clip_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', + size=224): #'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'): + super().__init__() + self.size = size + self.device = f"cuda:{device}" + + clip_name = clip_name + + self.feature_extractor = CLIPFeatureExtractor.from_pretrained( + clip_name) + self.clip_model = CLIPModel.from_pretrained(clip_name).to(self.device) + self.tokenizer = CLIPTokenizer.from_pretrained( + 'openai/clip-vit-base-patch32') + + self.normalize = transforms.Normalize( + mean=self.feature_extractor.image_mean, + std=self.feature_extractor.image_std) + + self.resize = transforms.Resize(224) + self.to_tensor = transforms.ToTensor() + + # image augmentation + self.aug = T.Compose([ + T.Resize((224, 224)), + T.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + # * recommend to use this function for evaluation + @torch.no_grad() + def score_gt(self, ref_img_path, novel_views): + # assert len(novel_views) == 100 + clip_scores = [] + for novel in novel_views: + clip_scores.append(self.score_from_path(ref_img_path, [novel])) + return np.mean(clip_scores) + + # * recommend to use this function for evaluation + # def score_gt(self, ref_paths, novel_paths): + # clip_scores = [] + # for img1_path, img2_path in zip(ref_paths, novel_paths): + # clip_scores.append(self.score_from_path(img1_path, img2_path)) + + # return np.mean(clip_scores) + + def similarity(self, image1_features: torch.Tensor, + image2_features: torch.Tensor) -> float: + with torch.no_grad(), torch.cuda.amp.autocast(): + y = image1_features.T.view(image1_features.T.shape[1], + image1_features.T.shape[0]) + similarity = torch.matmul(y, image2_features.T) + # print(similarity) + return similarity[0][0].item() + + def get_img_embeds(self, img): + if img.shape[0] == 4: + img = img[:3, :, :] + + img = self.aug(img).to(self.device) + img = img.unsqueeze(0) # b,c,h,w + + # plt.imshow(img.cpu().squeeze(0).permute(1, 2, 0).numpy()) + # plt.show() + # print(img) + + image_z = self.clip_model.get_image_features(img) + image_z = image_z / image_z.norm(dim=-1, + keepdim=True) # normalize features + return image_z + + def score_from_feature(self, img1, img2): + img1_feature, img2_feature = self.get_img_embeds( + img1), self.get_img_embeds(img2) + # for debug + return self.similarity(img1_feature, img2_feature) + + def read_img_list(self, img_list): + size = self.size + images = [] + # white_background = np.ones((size, size, 3), dtype=np.uint8) * 255 + + for img_path in img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + # print(img_path) + if img.shape[2] == 4: # Handle BGRA images + alpha = img[:, :, 3] # Extract alpha channel + img = cv2.cvtColor(img,cv2.COLOR_BGRA2RGB) # Convert BGRA to BGR + img[np.where(alpha == 0)] = [ + 255, 255, 255 + ] # Set transparent pixels to white + else: # Handle other image formats like JPG and PNG + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) + + # plt.imshow(img) + # plt.show() + + images.append(img) + + images = np.stack(images, axis=0) + # images[np.where(images == 0)] = 255 # Set black pixels to white + # images = np.where(images == 0, white_background, images) # Set transparent pixels to white + # images = images.astype(np.float32) + + return images + + def score_from_path(self, img1_path, img2_path): + img1, img2 = self.read_img_list(img1_path), self.read_img_list(img2_path) + img1 = np.squeeze(img1) + img2 = np.squeeze(img2) + # plt.imshow(img1) + # plt.show() + # plt.imshow(img2) + # plt.show() + + img1, img2 = self.to_tensor(img1), self.to_tensor(img2) + # print("img1 to tensor ",img1) + return self.score_from_feature(img1, img2) + + +def numpy_to_torch(images): + images = images * 2.0 - 1.0 + images = torch.from_numpy(images.transpose((0, 3, 1, 2))).float() + return images.cuda() + + +class LPIPSMeter: + + def __init__(self, + net='alex', + device=None, + size=224): # or we can use 'alex', 'vgg' as network + self.size = size + self.net = net + self.results = [] + self.device = device if device is not None else torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.fn = lpips.LPIPS(net=net).eval().to(self.device) + + def measure(self): + return np.mean(self.results) + + def report(self): + return f'LPIPS ({self.net}) = {self.measure():.6f}' + + def read_img_list(self, img_list): + size = self.size + images = [] + white_background = np.ones((size, size, 3), dtype=np.uint8) * 255 + + for img_path in img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + + if img.shape[2] == 4: # Handle BGRA images + alpha = img[:, :, 3] # Extract alpha channel + img = cv2.cvtColor(img, + cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR + + img = cv2.cvtColor(img, + cv2.COLOR_BGR2RGB) # Convert BGR to RGB + img[np.where(alpha == 0)] = [ + 255, 255, 255 + ] # Set transparent pixels to white + else: # Handle other image formats like JPG and PNG + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) + images.append(img) + + images = np.stack(images, axis=0) + # images[np.where(images == 0)] = 255 # Set black pixels to white + # images = np.where(images == 0, white_background, images) # Set transparent pixels to white + images = images.astype(np.float32) / 255.0 + + return images + + # * recommend to use this function for evaluation + @torch.no_grad() + def score_gt(self, ref_paths, novel_paths): + self.results = [] + for path0, path1 in zip(ref_paths, novel_paths): + # Load images + # img0 = lpips.im2tensor(lpips.load_image(path0)).cuda() # RGB image from [-1,1] + # img1 = lpips.im2tensor(lpips.load_image(path1)).cuda() + img0, img1 = self.read_img_list([path0]), self.read_img_list( + [path1]) + img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1) + # print(img0.shape,img1.shape) + img0 = F.interpolate(img0, + size=(self.size, self.size), + mode='area') + img1 = F.interpolate(img1, + size=(self.size, self.size), + mode='area') + + # for debug vis + # plt.imshow(img0.cpu().squeeze(0).permute(1, 2, 0).numpy()) + # plt.show() + # plt.imshow(img1.cpu().squeeze(0).permute(1, 2, 0).numpy()) + # plt.show() + # equivalent to cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA + + # print(img0.shape,img1.shape) + + self.results.append(self.fn.forward(img0, img1).cpu().numpy()) + + return self.measure() + +class CXMeter: + + def __init__(self, + net='vgg', + device=None, + size=512): # or we can use 'alex', 'vgg' as network + self.size = size + self.net = net + self.results = [] + self.device = device if device is not None else torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.fn = lpips.LPIPS(net=net).eval().to(self.device) + + def measure(self): + return np.mean(self.results) + + def report(self): + return f'LPIPS ({self.net}) = {self.measure():.6f}' + + def read_img_list(self, img_list): + size = self.size + images = [] + white_background = np.ones((size, size, 3), dtype=np.uint8) * 255 + + for img_path in img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + + if img.shape[2] == 4: # Handle BGRA images + alpha = img[:, :, 3] # Extract alpha channel + img = cv2.cvtColor(img, + cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR + + img = cv2.cvtColor(img, + cv2.COLOR_BGR2RGB) # Convert BGR to RGB + img[np.where(alpha == 0)] = [ + 255, 255, 255 + ] # Set transparent pixels to white + else: # Handle other image formats like JPG and PNG + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) + images.append(img) + + images = np.stack(images, axis=0) + # images[np.where(images == 0)] = 255 # Set black pixels to white + # images = np.where(images == 0, white_background, images) # Set transparent pixels to white + images = images.astype(np.float32) / 255.0 + + return images + + # * recommend to use this function for evaluation + @torch.no_grad() + def score_gt(self, ref_paths, novel_paths): + self.results = [] + path0 = ref_paths[0] + print('calculating CX loss') + for path1 in tqdm(novel_paths): + # Load images + img0, img1 = self.read_img_list([path0]), self.read_img_list( + [path1]) + img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1) + img0, img1 = img0 * 0.5 + 0.5, img1 * 0.5 + 0.5 + img0 = F.interpolate(img0, + size=(self.size, self.size), + mode='area') + img1 = F.interpolate(img1, + size=(self.size, self.size), + mode='area') + loss = criterion(img0.cpu(), img1.cpu()) + self.results.append(loss.cpu().numpy()) + + return self.measure() + +class PSNRMeter: + + def __init__(self, size=800): + self.results = [] + self.size = size + + def read_img_list(self, img_list): + size = self.size + images = [] + white_background = np.ones((size, size, 3), dtype=np.uint8) * 255 + for img_path in img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + + if img.shape[2] == 4: # Handle BGRA images + alpha = img[:, :, 3] # Extract alpha channel + img = cv2.cvtColor(img, + cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR + + img = cv2.cvtColor(img, + cv2.COLOR_BGR2RGB) # Convert BGR to RGB + img[np.where(alpha == 0)] = [ + 255, 255, 255 + ] # Set transparent pixels to white + else: # Handle other image formats like JPG and PNG + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) + images.append(img) + + images = np.stack(images, axis=0) + # images[np.where(images == 0)] = 255 # Set black pixels to white + # images = np.where(images == 0, white_background, images) # Set transparent pixels to white + images = images.astype(np.float32) / 255.0 + # print(images.shape) + return images + + def update(self, preds, truths): + # print(preds.shape) + + psnr_values = [] + # For each pair of images in the batches + for img1, img2 in zip(preds, truths): + # Compute the PSNR and add it to the list + # print(img1.shape,img2.shape) + + # for debug + # plt.imshow(img1) + # plt.show() + # plt.imshow(img2) + # plt.show() + + psnr = compare_psnr( + img1, img2, + data_range=1.0) # assuming your images are scaled to [0,1] + # print(f"temp psnr {psnr}") + psnr_values.append(psnr) + + # Convert the list of PSNR values to a numpy array + self.results = psnr_values + + def measure(self): + return np.mean(self.results) + + def report(self): + return f'PSNR = {self.measure():.6f}' + + # * recommend to use this function for evaluation + def score_gt(self, ref_paths, novel_paths): + self.results = [] + # [B, N, 3] or [B, H, W, 3], range[0, 1] + preds = self.read_img_list(ref_paths) + print('novel_paths', novel_paths) + truths = self.read_img_list(novel_paths) + self.update(preds, truths) + return self.measure() + +# all_inputs = 'data' +# nerf_dataset = os.listdir(osp(all_inputs, 'nerf4')) +# realfusion_dataset = os.listdir(osp(all_inputs, 'realfusion15')) +# meta_examples = { +# 'nerf4': nerf_dataset, +# 'realfusion15': realfusion_dataset, +# } +# all_datasets = meta_examples.keys() + +# organization 1 +def deprecated_score_from_method_for_dataset(my_scorer, + method, + dataset, + input, + output, + score_type='clip', + ): # psnr, lpips + # print("\n\n\n") + # print(f"______{method}___{dataset}___{score_type}_________") + scores = {} + final_res = 0 + examples = meta_examples[dataset] + for i in range(len(examples)): + + # compare entire folder for clip + if score_type == 'clip': + novel_view = osp(pred_path, examples[i], 'colors') + # compare first image for other metrics + else: + if method == '3d_fuse': method = '3d_fuse_0' + novel_view = list( + glob.glob( + osp(pred_path, examples[i], 'colors', + 'step_0000*')))[0] + + score_i = my_scorer.score_gt( + [], [novel_view]) + scores[examples[i]] = score_i + final_res += score_i + # print(scores, " Avg : ", final_res / len(examples)) + # print("``````````````````````") + return scores + +# results organization 2 +def score_from_method_for_dataset(my_scorer, + input_path, + pred_path, + score_type='clip', + rgb_name='lambertian', + result_folder='results/images', + first_str='*0000*' + ): # psnr, lpips + scores = {} + final_res = 0 + examples = os.listdir(input_path) + for i in range(len(examples)): + # ref path + ref_path = osp(input_path, examples[i], 'rgba.png') + # compare entire folder for clip + print(pred_path,'*'+examples[i]+'*', result_folder, f'*{rgb_name}*') + exit(0) + if score_type == 'clip': + novel_view = glob.glob(osp(pred_path,'*'+examples[i]+'*', result_folder, f'*{rgb_name}*')) + print(f'[INOF] {score_type} loss for example {examples[i]} between 1 GT and {len(novel_view)} predictions') + # compare first image for other metrics + else: + novel_view = glob.glob(osp(pred_path, '*'+examples[i]+'*/', result_folder, f'{first_str}{rgb_name}*')) + print(f'[INOF] {score_type} loss for example {examples[i]} between {ref_path} and {novel_view}') + # breakpoint() + score_i = my_scorer.score_gt([ref_path], novel_view) + scores[examples[i]] = score_i + final_res += score_i + avg_score = final_res / len(examples) + scores['average'] = avg_score + return scores + + +# results organization 2 +def score_from_my_method_for_dataset(my_scorer, + input_path, dataset, + score_type='clip' + ): # psnr, lpips + scores = {} + final_res = 0 + input_path = osp(input_path, dataset) + ref_path = glob.glob(osp(input_path, "*_rgba.png")) + novel_view = [osp(input_path, '%d.png' % i) for i in range(120)] + # print(ref_path) + # print(novel_view) + for i in tqdm(range(120)): + if os.path.exists(osp(input_path, '%d_color.png' % i)): + continue + img = cv2.imread(novel_view[i]) + H = img.shape[0] + img = img[:, :H] + cv2.imwrite(osp(input_path, '%d_color.png' % i), img) + if score_type == 'clip' or score_type == 'cx': + novel_view = [osp(input_path, '%d_color.png' % i) for i in range(120)] + else: + novel_view = [osp(input_path, '%d_color.png' % i) for i in range(1)] + print(novel_view) + scores['%s_average' % dataset] = my_scorer.score_gt(ref_path, novel_view) + return scores + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Script to accept three string arguments") + parser.add_argument("--input_path", + default=None, + help="Specify the input path") + parser.add_argument("--pred_pattern", + default="out/magic123*", + help="Specify the pattern of predition paths") + parser.add_argument("--results_folder", + default="results/images", + help="where are the results under each pred_path") + parser.add_argument("--rgb_name", + default="lambertian", + help="the postfix of the image") + parser.add_argument("--first_str", + default="*0000*", + help="the str to indicate the first view") + parser.add_argument("--datasets", + default=None, + nargs='*', + help="Specify the output path") + parser.add_argument("--device", + type=int, + default=0, + help="Specify the GPU device to be used") + parser.add_argument("--save_dir", type=str, default='all_metrics/results') + args = parser.parse_args() + + clip_scorer = CLIP(args.device) + lpips_scorer = LPIPSMeter() + psnr_scorer = PSNRMeter() + CX_scorer = CXMeter() + # criterion = criterion.to(args.device) + + os.makedirs(args.save_dir, exist_ok=True) + + for dataset in os.listdir(args.input_path): + print(dataset) + results_dict = {} + results_dict['clip'] = score_from_my_method_for_dataset( + clip_scorer, args.input_path, dataset, 'clip') + + results_dict['psnr'] = score_from_my_method_for_dataset( + psnr_scorer, args.input_path, dataset, 'psnr') + + results_dict['lpips'] = score_from_my_method_for_dataset( + lpips_scorer, args.input_path, dataset, 'lpips') + + results_dict['CX'] = score_from_my_method_for_dataset( + CX_scorer, args.input_path, dataset, 'cx') + + df = pd.DataFrame(results_dict) + print(df) + df.to_csv(f"{args.save_dir}/result.csv") + + + # for dataset in args.datasets: + # input_path = osp(args.input_path, dataset) + + # # assume the pred_path is organized as: pred_path/methods/dataset + # pred_pattern = osp(args.pred_pattern, dataset) + # pred_paths = glob.glob(pred_pattern) + # print(f"[INFO] Following the pattern {pred_pattern}, find {len(pred_paths)} pred_paths: \n", pred_paths) + # if len(pred_paths) == 0: + # raise IOError + # for pred_path in pred_paths: + # if not os.path.exists(pred_path): + # print(f'[WARN] prediction does not exit for {pred_path}') + # else: + # print(f'[INFO] evaluate {pred_path}') + \ No newline at end of file diff --git a/preprocess_image.py b/preprocess_image.py new file mode 100644 index 0000000..a55ee79 --- /dev/null +++ b/preprocess_image.py @@ -0,0 +1,237 @@ +import os +import sys +import cv2 +import argparse +import numpy as np +import matplotlib.pyplot as plt + +import glob +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from PIL import Image + +class BackgroundRemoval(): + def __init__(self, device='cuda'): + + from carvekit.api.high import HiInterface + self.interface = HiInterface( + object_type="object", # Can be "object" or "hairs-like". + batch_size_seg=5, + batch_size_matting=1, + device=device, + seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net + matting_mask_size=2048, + trimap_prob_threshold=231, + trimap_dilation=30, + trimap_erosion_iters=5, + fp16=True, + ) + + @torch.no_grad() + def __call__(self, image): + # image: [H, W, 3] array in [0, 255]. + image = Image.fromarray(image) + + image = self.interface([image])[0] + image = np.array(image) + + return image + +class BLIP2(): + def __init__(self, device='cuda'): + self.device = device + from transformers import AutoProcessor, Blip2ForConditionalGeneration + self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") + self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device) + + @torch.no_grad() + def __call__(self, image): + image = Image.fromarray(image) + inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16) + + generated_ids = self.model.generate(**inputs, max_new_tokens=20) + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + + return generated_text + + +class DPT(): + def __init__(self, task='depth', device='cuda'): + + self.task = task + self.device = device + + from threestudio.utils.dpt import DPTDepthModel + + if task == 'depth': + path = 'load/omnidata/omnidata_dpt_depth_v2.ckpt' + self.model = DPTDepthModel(backbone='vitb_rn50_384') + self.aug = transforms.Compose([ + transforms.Resize((384, 384)), + transforms.ToTensor(), + transforms.Normalize(mean=0.5, std=0.5) + ]) + + else: # normal + path = 'load/omnidata/omnidata_dpt_normal_v2.ckpt' + self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) + self.aug = transforms.Compose([ + transforms.Resize((384, 384)), + transforms.ToTensor() + ]) + + # load model + checkpoint = torch.load(path, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = {} + for k, v in checkpoint['state_dict'].items(): + state_dict[k[6:]] = v + else: + state_dict = checkpoint + self.model.load_state_dict(state_dict) + self.model.eval().to(device) + + + @torch.no_grad() + def __call__(self, image): + # image: np.ndarray, uint8, [H, W, 3] + H, W = image.shape[:2] + image = Image.fromarray(image) + + image = self.aug(image).unsqueeze(0).to(self.device) + + if self.task == 'depth': + depth = self.model(image).clamp(0, 1) + depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False) + depth = depth.squeeze(1).cpu().numpy() + return depth + else: + normal = self.model(image).clamp(0, 1) + normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False) + normal = normal.cpu().numpy() + return normal + +def preprocess_single_image(img_path, args): + out_dir = os.path.dirname(img_path) + out_rgba = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_rgba.png') + out_depth = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_depth.png') + out_normal = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_normal.png') + out_caption = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_caption.txt') + + # load image + print(f'[INFO] loading image {img_path}...') + + # check the exisiting files + if os.path.isfile(out_rgba) and os.path.isfile(out_depth) and os.path.isfile(out_normal): + print(f"{img_path} has already been here!") + return + print(img_path) + image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + carved_image = None + # debug + if image.shape[-1] == 4: + if args.do_rm_bg_force: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) + else: + carved_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) + + else: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if args.do_seg: + if carved_image is None: + # carve background + print(f'[INFO] background removal...') + carved_image = BackgroundRemoval()(image) # [H, W, 4] + mask = carved_image[..., -1] > 0 + + # predict depth + print(f'[INFO] depth estimation...') + dpt_depth_model = DPT(task='depth') + depth = dpt_depth_model(image)[0] + depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9) + depth[~mask] = 0 + depth = (depth * 255).astype(np.uint8) + del dpt_depth_model + + # predict normal + print(f'[INFO] normal estimation...') + dpt_normal_model = DPT(task='normal') + normal = dpt_normal_model(image)[0] + normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0) + normal[~mask] = 0 + del dpt_normal_model + + opt.recenter=False + # recenter + if opt.recenter: + print(f'[INFO] recenter...') + final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8) + final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8) + final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8) + + coords = np.nonzero(mask) + x_min, x_max = coords[0].min(), coords[0].max() + y_min, y_max = coords[1].min(), coords[1].max() + h = x_max - x_min + w = y_max - y_min + desired_size = int(opt.size * (1 - opt.border_ratio)) + scale = desired_size / max(h, w) + h2 = int(h * scale) + w2 = int(w * scale) + x2_min = (opt.size - h2) // 2 + x2_max = x2_min + h2 + y2_min = (opt.size - w2) // 2 + y2_max = y2_min + w2 + final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + + else: + final_rgba = carved_image + final_depth = depth + final_normal = normal + + # write output + cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA)) + cv2.imwrite(out_depth, final_depth) + cv2.imwrite(out_normal, final_normal) + + if opt.do_caption: + # predict caption (it's too slow... use your brain instead) + print(f'[INFO] captioning...') + blip2 = BLIP2() + caption = blip2(image) + with open(out_caption, 'w') as f: + f.write(caption) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)") + parser.add_argument('--size', default=1024, type=int, help="output resolution") + parser.add_argument('--border_ratio', default=0.1, type=float, help="output border ratio") + parser.add_argument('--recenter', type=bool, default=False, help="recenter, potentially not helpful for multiview zero123") + parser.add_argument('--dont_recenter', dest='recenter', action='store_false') + parser.add_argument('--do_caption', type=bool, default=False, help="do text captioning") + parser.add_argument('--do_seg', type=bool, default=True) + parser.add_argument('--do_rm_bg_force', type=bool, default=False) + + opt = parser.parse_args() + + if os.path.isdir(opt.path): + img_list = sorted(os.path.join(root, fname) for root, _dirs, files in os.walk(opt.path) for fname in files) + img_list = [img for img in img_list if not img.endswith("rgba.png") and not img.endswith("depth.png") and not img.endswith("normal.png")] + img_list = [img for img in img_list if img.endswith(".png")] + for img in img_list: + # try: + preprocess_single_image(img, opt) + # except: + # with open("preprocess_images_invalid.txt", "a") as f: + # print(img, file=f) + else: # single image file + preprocess_single_image(opt.path, opt) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6966066 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,35 @@ +lightning==2.0.0 +omegaconf==2.3.0 +jaxtyping +typeguard +diffusers<=0.23.0 +transformers +accelerate +opencv-python +tensorboard +matplotlib +imageio>=2.28.0 +imageio[ffmpeg] +libigl +xatlas +trimesh[easy] +networkx +pysdf +PyMCubes +wandb +gradio + +# deepfloyd +xformers +bitsandbytes +sentencepiece +safetensors +huggingface_hub + +# for zero123 +einops +kornia +taming-transformers-rom1504 + +#controlnet +controlnet_aux diff --git a/threestudio/__init__.py b/threestudio/__init__.py new file mode 100644 index 0000000..2c83608 --- /dev/null +++ b/threestudio/__init__.py @@ -0,0 +1,36 @@ +__modules__ = {} + + +def register(name): + def decorator(cls): + __modules__[name] = cls + return cls + + return decorator + + +def find(name): + return __modules__[name] + + +### grammar sugar for logging utilities ### +import logging + +logger = logging.getLogger("pytorch_lightning") + +from pytorch_lightning.utilities.rank_zero import ( + rank_zero_debug, + rank_zero_info, + rank_zero_only, +) + +debug = rank_zero_debug +info = rank_zero_info + + +@rank_zero_only +def warn(*args, **kwargs): + logger.warn(*args, **kwargs) + + +from . import data, models, systems diff --git a/threestudio/data/__init__.py b/threestudio/data/__init__.py new file mode 100644 index 0000000..51832a0 --- /dev/null +++ b/threestudio/data/__init__.py @@ -0,0 +1 @@ +from . import image, uncond \ No newline at end of file diff --git a/threestudio/data/image.py b/threestudio/data/image.py new file mode 100644 index 0000000..c0e70fc --- /dev/null +++ b/threestudio/data/image.py @@ -0,0 +1,351 @@ +import bisect +import math +import os +from dataclasses import dataclass, field + +import cv2 +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + +import threestudio +from threestudio import register +from threestudio.data.uncond import ( + RandomCameraDataModuleConfig, + RandomCameraDataset, + RandomCameraIterableDataset, +) +from threestudio.utils.base import Updateable +from threestudio.utils.config import parse_structured +from threestudio.utils.misc import get_rank +from threestudio.utils.ops import ( + get_mvp_matrix, + get_projection_matrix, + get_ray_directions, + get_rays, +) +from threestudio.utils.typing import * + + +@dataclass +class SingleImageDataModuleConfig: + # height and width should be Union[int, List[int]] + # but OmegaConf does not support Union of containers + height: Any = 96 + width: Any = 96 + resolution_milestones: List[int] = field(default_factory=lambda: []) + default_elevation_deg: float = 0.0 + default_azimuth_deg: float = -180.0 + default_camera_distance: float = 1.2 + default_fovy_deg: float = 60.0 + image_path: str = "" + use_random_camera: bool = True + random_camera: dict = field(default_factory=dict) + rays_noise_scale: float = 2e-3 + batch_size: int = 1 + requires_depth: bool = False + requires_normal: bool = False + rays_d_normalize: bool = True + use_mixed_camera_config: bool = False + + +class SingleImageDataBase: + def setup(self, cfg, split): + self.split = split + self.rank = get_rank() + self.cfg: SingleImageDataModuleConfig = cfg + + if self.cfg.use_random_camera: + random_camera_cfg = parse_structured( + RandomCameraDataModuleConfig, self.cfg.get("random_camera", {}) + ) + # FIXME: + if self.cfg.use_mixed_camera_config: + if self.rank % 2 == 0: + random_camera_cfg.camera_distance_range=[self.cfg.default_camera_distance, self.cfg.default_camera_distance] + random_camera_cfg.fovy_range=[self.cfg.default_fovy_deg, self.cfg.default_fovy_deg] + self.fixed_camera_intrinsic = True + else: + self.fixed_camera_intrinsic = False + if split == "train": + self.random_pose_generator = RandomCameraIterableDataset( + random_camera_cfg + ) + else: + self.random_pose_generator = RandomCameraDataset( + random_camera_cfg, split + ) + + elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg]) + azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg]) + camera_distance = torch.FloatTensor([self.cfg.default_camera_distance]) + + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + camera_position: Float[Tensor, "1 3"] = torch.stack( + [ + camera_distance * torch.cos(elevation) * torch.cos(azimuth), + camera_distance * torch.cos(elevation) * torch.sin(azimuth), + camera_distance * torch.sin(elevation), + ], + dim=-1, + ) + + center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position) + up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None] + + light_position: Float[Tensor, "1 3"] = camera_position + lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1) + right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + self.c2w: Float[Tensor, "1 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]], + dim=-1, + ) + self.c2w4x4: Float[Tensor, "B 4 4"] = torch.cat( + [self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1 + ) + self.c2w4x4[:, 3, 3] = 1.0 + + self.camera_position = camera_position + self.light_position = light_position + self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg + self.camera_distance = camera_distance + self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg])) + + self.heights: List[int] = ( + [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height + ) + self.widths: List[int] = ( + [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width + ) + assert len(self.heights) == len(self.widths) + self.resolution_milestones: List[int] + if len(self.heights) == 1 and len(self.widths) == 1: + if len(self.cfg.resolution_milestones) > 0: + threestudio.warn( + "Ignoring resolution_milestones since height and width are not changing" + ) + self.resolution_milestones = [-1] + else: + assert len(self.heights) == len(self.cfg.resolution_milestones) + 1 + self.resolution_milestones = [-1] + self.cfg.resolution_milestones + + self.directions_unit_focals = [ + get_ray_directions(H=height, W=width, focal=1.0) + for (height, width) in zip(self.heights, self.widths) + ] + self.focal_lengths = [ + 0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights + ] + + self.height: int = self.heights[0] + self.width: int = self.widths[0] + self.directions_unit_focal = self.directions_unit_focals[0] + self.focal_length = self.focal_lengths[0] + self.set_rays() + self.load_images() + self.prev_height = self.height + + def set_rays(self): + # get directions by dividing directions_unit_focal by focal length + directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None] + directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length + + rays_o, rays_d = get_rays( + directions, + self.c2w, + keepdim=True, + noise_scale=self.cfg.rays_noise_scale, + normalize=self.cfg.rays_d_normalize, + ) + + proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix( + self.fovy, self.width / self.height, 0.01, 100.0 + ) # FIXME: hard-coded near and far + mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx) + + self.rays_o, self.rays_d = rays_o, rays_d + self.mvp_mtx = mvp_mtx + + def load_images(self): + # load image + assert os.path.exists( + self.cfg.image_path + ), f"Could not find image {self.cfg.image_path}!" + rgba = cv2.cvtColor( + cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA + ) + rgba = ( + cv2.resize( + rgba, (self.width, self.height), interpolation=cv2.INTER_AREA + ).astype(np.float32) + / 255.0 + ) + rgb = rgba[..., :3] + self.rgb: Float[Tensor, "1 H W 3"] = ( + torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank) + ) + self.mask: Float[Tensor, "1 H W 1"] = ( + torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank) + ) + print( + f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}" + ) + + # load depth + if self.cfg.requires_depth: + depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png") + assert os.path.exists(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) + depth = cv2.resize( + depth, (self.width, self.height), interpolation=cv2.INTER_AREA + ) + self.depth: Float[Tensor, "1 H W 1"] = ( + torch.from_numpy(depth.astype(np.float32) / 255.0) + .unsqueeze(0) + .to(self.rank) + ) + print( + f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}" + ) + else: + self.depth = None + + # load normal + if self.cfg.requires_normal: + normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png") + assert os.path.exists(normal_path) + normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED) + normal = cv2.resize( + normal, (self.width, self.height), interpolation=cv2.INTER_AREA + ) + self.normal: Float[Tensor, "1 H W 3"] = ( + torch.from_numpy(normal.astype(np.float32) / 255.0) + .unsqueeze(0) + .to(self.rank) + ) + print( + f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}" + ) + else: + self.normal = None + + def get_all_images(self): + return self.rgb + + def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False): + size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1 + self.height = self.heights[size_ind] + if self.height == self.prev_height: + return + + self.prev_height = self.height + self.width = self.widths[size_ind] + self.directions_unit_focal = self.directions_unit_focals[size_ind] + self.focal_length = self.focal_lengths[size_ind] + threestudio.debug(f"Training height: {self.height}, width: {self.width}") + self.set_rays() + self.load_images() + + +class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.setup(cfg, split) + + def collate(self, batch) -> Dict[str, Any]: + batch = { + "rays_o": self.rays_o, + "rays_d": self.rays_d, + "mvp_mtx": self.mvp_mtx, + "camera_positions": self.camera_position, + "light_positions": self.light_position, + "elevation": self.elevation_deg, + "azimuth": self.azimuth_deg, + "camera_distances": self.camera_distance, + "rgb": self.rgb, + "ref_depth": self.depth, + "ref_normal": self.normal, + "mask": self.mask, + "height": self.cfg.height, + "width": self.cfg.width, + "c2w": self.c2w, + "c2w4x4": self.c2w4x4, + } + if self.cfg.use_random_camera: + batch["random_camera"] = self.random_pose_generator.collate(None) + + return batch + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + self.update_step_(epoch, global_step, on_load_weights) + self.random_pose_generator.update_step(epoch, global_step, on_load_weights) + + def __iter__(self): + while True: + yield {} + + +class SingleImageDataset(Dataset, SingleImageDataBase): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.setup(cfg, split) + + def __len__(self): + return len(self.random_pose_generator) + + def __getitem__(self, index): + batch = self.random_pose_generator[index] + batch.update( + { + "height": self.random_pose_generator.cfg.eval_height, + "width": self.random_pose_generator.cfg.eval_width, + "mvp_mtx_ref": self.mvp_mtx[0], + "c2w_ref": self.c2w4x4, + } + ) + return batch + + +@register("single-image-datamodule") +class SingleImageDataModule(pl.LightningDataModule): + cfg: SingleImageDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(SingleImageDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = SingleImageIterableDataset(self.cfg, "train") + if stage in [None, "fit", "validate"]: + self.val_dataset = SingleImageDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = SingleImageDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader: + return DataLoader( + dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn + ) + + def train_dataloader(self) -> DataLoader: + return self.general_loader( + self.train_dataset, + batch_size=self.cfg.batch_size, + collate_fn=self.train_dataset.collate, + ) + + def val_dataloader(self) -> DataLoader: + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) \ No newline at end of file diff --git a/threestudio/data/images.py b/threestudio/data/images.py new file mode 100644 index 0000000..c0e70fc --- /dev/null +++ b/threestudio/data/images.py @@ -0,0 +1,351 @@ +import bisect +import math +import os +from dataclasses import dataclass, field + +import cv2 +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + +import threestudio +from threestudio import register +from threestudio.data.uncond import ( + RandomCameraDataModuleConfig, + RandomCameraDataset, + RandomCameraIterableDataset, +) +from threestudio.utils.base import Updateable +from threestudio.utils.config import parse_structured +from threestudio.utils.misc import get_rank +from threestudio.utils.ops import ( + get_mvp_matrix, + get_projection_matrix, + get_ray_directions, + get_rays, +) +from threestudio.utils.typing import * + + +@dataclass +class SingleImageDataModuleConfig: + # height and width should be Union[int, List[int]] + # but OmegaConf does not support Union of containers + height: Any = 96 + width: Any = 96 + resolution_milestones: List[int] = field(default_factory=lambda: []) + default_elevation_deg: float = 0.0 + default_azimuth_deg: float = -180.0 + default_camera_distance: float = 1.2 + default_fovy_deg: float = 60.0 + image_path: str = "" + use_random_camera: bool = True + random_camera: dict = field(default_factory=dict) + rays_noise_scale: float = 2e-3 + batch_size: int = 1 + requires_depth: bool = False + requires_normal: bool = False + rays_d_normalize: bool = True + use_mixed_camera_config: bool = False + + +class SingleImageDataBase: + def setup(self, cfg, split): + self.split = split + self.rank = get_rank() + self.cfg: SingleImageDataModuleConfig = cfg + + if self.cfg.use_random_camera: + random_camera_cfg = parse_structured( + RandomCameraDataModuleConfig, self.cfg.get("random_camera", {}) + ) + # FIXME: + if self.cfg.use_mixed_camera_config: + if self.rank % 2 == 0: + random_camera_cfg.camera_distance_range=[self.cfg.default_camera_distance, self.cfg.default_camera_distance] + random_camera_cfg.fovy_range=[self.cfg.default_fovy_deg, self.cfg.default_fovy_deg] + self.fixed_camera_intrinsic = True + else: + self.fixed_camera_intrinsic = False + if split == "train": + self.random_pose_generator = RandomCameraIterableDataset( + random_camera_cfg + ) + else: + self.random_pose_generator = RandomCameraDataset( + random_camera_cfg, split + ) + + elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg]) + azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg]) + camera_distance = torch.FloatTensor([self.cfg.default_camera_distance]) + + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + camera_position: Float[Tensor, "1 3"] = torch.stack( + [ + camera_distance * torch.cos(elevation) * torch.cos(azimuth), + camera_distance * torch.cos(elevation) * torch.sin(azimuth), + camera_distance * torch.sin(elevation), + ], + dim=-1, + ) + + center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position) + up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None] + + light_position: Float[Tensor, "1 3"] = camera_position + lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1) + right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + self.c2w: Float[Tensor, "1 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]], + dim=-1, + ) + self.c2w4x4: Float[Tensor, "B 4 4"] = torch.cat( + [self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1 + ) + self.c2w4x4[:, 3, 3] = 1.0 + + self.camera_position = camera_position + self.light_position = light_position + self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg + self.camera_distance = camera_distance + self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg])) + + self.heights: List[int] = ( + [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height + ) + self.widths: List[int] = ( + [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width + ) + assert len(self.heights) == len(self.widths) + self.resolution_milestones: List[int] + if len(self.heights) == 1 and len(self.widths) == 1: + if len(self.cfg.resolution_milestones) > 0: + threestudio.warn( + "Ignoring resolution_milestones since height and width are not changing" + ) + self.resolution_milestones = [-1] + else: + assert len(self.heights) == len(self.cfg.resolution_milestones) + 1 + self.resolution_milestones = [-1] + self.cfg.resolution_milestones + + self.directions_unit_focals = [ + get_ray_directions(H=height, W=width, focal=1.0) + for (height, width) in zip(self.heights, self.widths) + ] + self.focal_lengths = [ + 0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights + ] + + self.height: int = self.heights[0] + self.width: int = self.widths[0] + self.directions_unit_focal = self.directions_unit_focals[0] + self.focal_length = self.focal_lengths[0] + self.set_rays() + self.load_images() + self.prev_height = self.height + + def set_rays(self): + # get directions by dividing directions_unit_focal by focal length + directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None] + directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length + + rays_o, rays_d = get_rays( + directions, + self.c2w, + keepdim=True, + noise_scale=self.cfg.rays_noise_scale, + normalize=self.cfg.rays_d_normalize, + ) + + proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix( + self.fovy, self.width / self.height, 0.01, 100.0 + ) # FIXME: hard-coded near and far + mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx) + + self.rays_o, self.rays_d = rays_o, rays_d + self.mvp_mtx = mvp_mtx + + def load_images(self): + # load image + assert os.path.exists( + self.cfg.image_path + ), f"Could not find image {self.cfg.image_path}!" + rgba = cv2.cvtColor( + cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA + ) + rgba = ( + cv2.resize( + rgba, (self.width, self.height), interpolation=cv2.INTER_AREA + ).astype(np.float32) + / 255.0 + ) + rgb = rgba[..., :3] + self.rgb: Float[Tensor, "1 H W 3"] = ( + torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank) + ) + self.mask: Float[Tensor, "1 H W 1"] = ( + torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank) + ) + print( + f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}" + ) + + # load depth + if self.cfg.requires_depth: + depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png") + assert os.path.exists(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) + depth = cv2.resize( + depth, (self.width, self.height), interpolation=cv2.INTER_AREA + ) + self.depth: Float[Tensor, "1 H W 1"] = ( + torch.from_numpy(depth.astype(np.float32) / 255.0) + .unsqueeze(0) + .to(self.rank) + ) + print( + f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}" + ) + else: + self.depth = None + + # load normal + if self.cfg.requires_normal: + normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png") + assert os.path.exists(normal_path) + normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED) + normal = cv2.resize( + normal, (self.width, self.height), interpolation=cv2.INTER_AREA + ) + self.normal: Float[Tensor, "1 H W 3"] = ( + torch.from_numpy(normal.astype(np.float32) / 255.0) + .unsqueeze(0) + .to(self.rank) + ) + print( + f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}" + ) + else: + self.normal = None + + def get_all_images(self): + return self.rgb + + def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False): + size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1 + self.height = self.heights[size_ind] + if self.height == self.prev_height: + return + + self.prev_height = self.height + self.width = self.widths[size_ind] + self.directions_unit_focal = self.directions_unit_focals[size_ind] + self.focal_length = self.focal_lengths[size_ind] + threestudio.debug(f"Training height: {self.height}, width: {self.width}") + self.set_rays() + self.load_images() + + +class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.setup(cfg, split) + + def collate(self, batch) -> Dict[str, Any]: + batch = { + "rays_o": self.rays_o, + "rays_d": self.rays_d, + "mvp_mtx": self.mvp_mtx, + "camera_positions": self.camera_position, + "light_positions": self.light_position, + "elevation": self.elevation_deg, + "azimuth": self.azimuth_deg, + "camera_distances": self.camera_distance, + "rgb": self.rgb, + "ref_depth": self.depth, + "ref_normal": self.normal, + "mask": self.mask, + "height": self.cfg.height, + "width": self.cfg.width, + "c2w": self.c2w, + "c2w4x4": self.c2w4x4, + } + if self.cfg.use_random_camera: + batch["random_camera"] = self.random_pose_generator.collate(None) + + return batch + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + self.update_step_(epoch, global_step, on_load_weights) + self.random_pose_generator.update_step(epoch, global_step, on_load_weights) + + def __iter__(self): + while True: + yield {} + + +class SingleImageDataset(Dataset, SingleImageDataBase): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.setup(cfg, split) + + def __len__(self): + return len(self.random_pose_generator) + + def __getitem__(self, index): + batch = self.random_pose_generator[index] + batch.update( + { + "height": self.random_pose_generator.cfg.eval_height, + "width": self.random_pose_generator.cfg.eval_width, + "mvp_mtx_ref": self.mvp_mtx[0], + "c2w_ref": self.c2w4x4, + } + ) + return batch + + +@register("single-image-datamodule") +class SingleImageDataModule(pl.LightningDataModule): + cfg: SingleImageDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(SingleImageDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = SingleImageIterableDataset(self.cfg, "train") + if stage in [None, "fit", "validate"]: + self.val_dataset = SingleImageDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = SingleImageDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader: + return DataLoader( + dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn + ) + + def train_dataloader(self) -> DataLoader: + return self.general_loader( + self.train_dataset, + batch_size=self.cfg.batch_size, + collate_fn=self.train_dataset.collate, + ) + + def val_dataloader(self) -> DataLoader: + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self) -> DataLoader: + return self.general_loader(self.test_dataset, batch_size=1) \ No newline at end of file diff --git a/threestudio/data/uncond.py b/threestudio/data/uncond.py new file mode 100644 index 0000000..bb9be13 --- /dev/null +++ b/threestudio/data/uncond.py @@ -0,0 +1,518 @@ +import bisect +import math +import random +from dataclasses import dataclass, field + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + +import threestudio +from threestudio import register +from threestudio.utils.base import Updateable +from threestudio.utils.config import parse_structured +from threestudio.utils.misc import get_device +from threestudio.utils.ops import ( + get_full_projection_matrix, + get_mvp_matrix, + get_projection_matrix, + get_ray_directions, + get_rays, +) +from threestudio.utils.typing import * + + +@dataclass +class RandomCameraDataModuleConfig: + # height, width, and batch_size should be Union[int, List[int]] + # but OmegaConf does not support Union of containers + height: Any = 64 + width: Any = 64 + batch_size: Any = 1 + resolution_milestones: List[int] = field(default_factory=lambda: []) + eval_height: int = 512 + eval_width: int = 512 + eval_batch_size: int = 1 + n_val_views: int = 1 + n_test_views: int = 120 + elevation_range: Tuple[float, float] = (-10, 90) + azimuth_range: Tuple[float, float] = (-180, 180) + camera_distance_range: Tuple[float, float] = (1, 1.5) + fovy_range: Tuple[float, float] = ( + 40, + 70, + ) # in degrees, in vertical direction (along height) + camera_perturb: float = 0.1 + center_perturb: float = 0.2 + up_perturb: float = 0.02 + light_position_perturb: float = 1.0 + light_distance_range: Tuple[float, float] = (0.8, 1.5) + eval_elevation_deg: float = 15.0 + eval_camera_distance: float = 1.5 + eval_fovy_deg: float = 70.0 + light_sample_strategy: str = "dreamfusion" + batch_uniform_azimuth: bool = True + progressive_until: int = 0 # progressive ranges for elevation, azimuth, r, fovy + + rays_d_normalize: bool = True + + +class RandomCameraIterableDataset(IterableDataset, Updateable): + def __init__(self, cfg: Any) -> None: + super().__init__() + self.cfg: RandomCameraDataModuleConfig = cfg + self.heights: List[int] = ( + [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height + ) + self.widths: List[int] = ( + [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width + ) + self.batch_sizes: List[int] = ( + [self.cfg.batch_size] + if isinstance(self.cfg.batch_size, int) + else self.cfg.batch_size + ) + assert len(self.heights) == len(self.widths) == len(self.batch_sizes) + self.resolution_milestones: List[int] + if ( + len(self.heights) == 1 + and len(self.widths) == 1 + and len(self.batch_sizes) == 1 + ): + if len(self.cfg.resolution_milestones) > 0: + threestudio.warn( + "Ignoring resolution_milestones since height and width are not changing" + ) + self.resolution_milestones = [-1] + else: + assert len(self.heights) == len(self.cfg.resolution_milestones) + 1 + self.resolution_milestones = [-1] + self.cfg.resolution_milestones + + self.directions_unit_focals = [ + get_ray_directions(H=height, W=width, focal=1.0) + for (height, width) in zip(self.heights, self.widths) + ] + self.height: int = self.heights[0] + self.width: int = self.widths[0] + self.batch_size: int = self.batch_sizes[0] + self.directions_unit_focal = self.directions_unit_focals[0] + self.elevation_range = self.cfg.elevation_range + self.azimuth_range = self.cfg.azimuth_range + self.camera_distance_range = self.cfg.camera_distance_range + self.fovy_range = self.cfg.fovy_range + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1 + self.height = self.heights[size_ind] + self.width = self.widths[size_ind] + self.batch_size = self.batch_sizes[size_ind] + self.directions_unit_focal = self.directions_unit_focals[size_ind] + threestudio.debug( + f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}" + ) + # progressive view + self.progressive_view(global_step) + + def __iter__(self): + while True: + yield {} + + def progressive_view(self, global_step): + r = min(1.0, global_step / (self.cfg.progressive_until + 1)) + self.elevation_range = [ + (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[0], + (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[1], + ] + self.azimuth_range = [ + (1 - r) * 0.0 + r * self.cfg.azimuth_range[0], + (1 - r) * 0.0 + r * self.cfg.azimuth_range[1], + ] + # self.camera_distance_range = [ + # (1 - r) * self.cfg.eval_camera_distance + # + r * self.cfg.camera_distance_range[0], + # (1 - r) * self.cfg.eval_camera_distance + # + r * self.cfg.camera_distance_range[1], + # ] + # self.fovy_range = [ + # (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[0], + # (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[1], + # ] + + def collate(self, batch) -> Dict[str, Any]: + # sample elevation angles + elevation_deg: Float[Tensor, "B"] + elevation: Float[Tensor, "B"] + if random.random() < 0.5: + # sample elevation angles uniformly with a probability 0.5 (biased towards poles) + elevation_deg = ( + torch.rand(self.batch_size) + * (self.elevation_range[1] - self.elevation_range[0]) + + self.elevation_range[0] + ) + elevation = elevation_deg * math.pi / 180 + else: + # otherwise sample uniformly on sphere + elevation_range_percent = [ + self.elevation_range[0] / 180.0 * math.pi, + self.elevation_range[1] / 180.0 * math.pi, + ] + # inverse transform sampling + elevation = torch.asin( + ( + torch.rand(self.batch_size) + * ( + math.sin(elevation_range_percent[1]) + - math.sin(elevation_range_percent[0]) + ) + + math.sin(elevation_range_percent[0]) + ) + ) + elevation_deg = elevation / math.pi * 180.0 + + # sample azimuth angles from a uniform distribution bounded by azimuth_range + azimuth_deg: Float[Tensor, "B"] + if self.cfg.batch_uniform_azimuth: + # ensures sampled azimuth angles in a batch cover the whole range + azimuth_deg = ( + torch.rand(self.batch_size) + torch.arange(self.batch_size) + ) / self.batch_size * ( + self.azimuth_range[1] - self.azimuth_range[0] + ) + self.azimuth_range[ + 0 + ] + else: + # simple random sampling + azimuth_deg = ( + torch.rand(self.batch_size) + * (self.azimuth_range[1] - self.azimuth_range[0]) + + self.azimuth_range[0] + ) + azimuth = azimuth_deg * math.pi / 180 + + # sample distances from a uniform distribution bounded by distance_range + camera_distances: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) + * (self.camera_distance_range[1] - self.camera_distance_range[0]) + + self.camera_distance_range[0] + ) + + # convert spherical coordinates to cartesian coordinates + # right hand coordinate system, x back, y right, z up + # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) + camera_positions: Float[Tensor, "B 3"] = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + + # default scene center at origin + center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) + # default camera up direction as +z + up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ + None, : + ].repeat(self.batch_size, 1) + + # sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb] + camera_perturb: Float[Tensor, "B 3"] = ( + torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb + - self.cfg.camera_perturb + ) + camera_positions = camera_positions + camera_perturb + # sample center perturbations from a normal distribution with mean 0 and std center_perturb + center_perturb: Float[Tensor, "B 3"] = ( + torch.randn(self.batch_size, 3) * self.cfg.center_perturb + ) + center = center + center_perturb + # sample up perturbations from a normal distribution with mean 0 and std up_perturb + up_perturb: Float[Tensor, "B 3"] = ( + torch.randn(self.batch_size, 3) * self.cfg.up_perturb + ) + up = up + up_perturb + + # sample fovs from a uniform distribution bounded by fov_range + fovy_deg: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0]) + + self.fovy_range[0] + ) + fovy = fovy_deg * math.pi / 180 + + # sample light distance from a uniform distribution bounded by light_distance_range + light_distances: Float[Tensor, "B"] = ( + torch.rand(self.batch_size) + * (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0]) + + self.cfg.light_distance_range[0] + ) + + if self.cfg.light_sample_strategy == "dreamfusion": + # sample light direction from a normal distribution with mean camera_position and std light_position_perturb + light_direction: Float[Tensor, "B 3"] = F.normalize( + camera_positions + + torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb, + dim=-1, + ) + # get light position by scaling light direction by light distance + light_positions: Float[Tensor, "B 3"] = ( + light_direction * light_distances[:, None] + ) + elif self.cfg.light_sample_strategy == "magic3d": + # sample light direction within restricted angle range (pi/3) + local_z = F.normalize(camera_positions, dim=-1) + local_x = F.normalize( + torch.stack( + [local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])], + dim=-1, + ), + dim=-1, + ) + local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1) + rot = torch.stack([local_x, local_y, local_z], dim=-1) + light_azimuth = ( + torch.rand(self.batch_size) * math.pi * 2 - math.pi + ) # [-pi, pi] + light_elevation = ( + torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6 + ) # [pi/6, pi/2] + light_positions_local = torch.stack( + [ + light_distances + * torch.cos(light_elevation) + * torch.cos(light_azimuth), + light_distances + * torch.cos(light_elevation) + * torch.sin(light_azimuth), + light_distances * torch.sin(light_elevation), + ], + dim=-1, + ) + light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0] + else: + raise ValueError( + f"Unknown light sample strategy: {self.cfg.light_sample_strategy}" + ) + + lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) + right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w: Float[Tensor, "B 4 4"] = torch.cat( + [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 + ) + c2w[:, 3, 3] = 1.0 + + # get directions by dividing directions_unit_focal by focal length + focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy) + directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[ + None, :, :, : + ].repeat(self.batch_size, 1, 1, 1) + directions[:, :, :, :2] = ( + directions[:, :, :, :2] / focal_length[:, None, None, None] + ) + + # Importance note: the returned rays_d MUST be normalized! + rays_o, rays_d = get_rays( + directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize + ) + + self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix( + fovy, self.width / self.height, 0.1, 1000.0 + ) # FIXME: hard-coded near and far + mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx) + self.fovy = fovy + + return { + "rays_o": rays_o, + "rays_d": rays_d, + "mvp_mtx": mvp_mtx, + "camera_positions": camera_positions, + "c2w": c2w, + "light_positions": light_positions, + "elevation": elevation_deg, + "azimuth": azimuth_deg, + "camera_distances": camera_distances, + "height": self.height, + "width": self.width, + "fovy": self.fovy, + "proj_mtx": self.proj_mtx, + } + + +class RandomCameraDataset(Dataset): + def __init__(self, cfg: Any, split: str) -> None: + super().__init__() + self.cfg: RandomCameraDataModuleConfig = cfg + self.split = split + + if split == "val": + self.n_views = self.cfg.n_val_views + else: + self.n_views = self.cfg.n_test_views + + azimuth_deg: Float[Tensor, "B"] + if self.split == "val": + # make sure the first and last view are not the same + azimuth_deg = torch.linspace(0, 360.0, self.n_views + 1)[: self.n_views] + else: + azimuth_deg = torch.linspace(0, 360.0, self.n_views) + elevation_deg: Float[Tensor, "B"] = torch.full_like( + azimuth_deg, self.cfg.eval_elevation_deg + ) + camera_distances: Float[Tensor, "B"] = torch.full_like( + elevation_deg, self.cfg.eval_camera_distance + ) + + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + + # convert spherical coordinates to cartesian coordinates + # right hand coordinate system, x back, y right, z up + # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) + camera_positions: Float[Tensor, "B 3"] = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + + # default scene center at origin + center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) + # default camera up direction as +z + up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ + None, : + ].repeat(self.cfg.eval_batch_size, 1) + + fovy_deg: Float[Tensor, "B"] = torch.full_like( + elevation_deg, self.cfg.eval_fovy_deg + ) + fovy = fovy_deg * math.pi / 180 + light_positions: Float[Tensor, "B 3"] = camera_positions + + lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) + right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], + dim=-1, + ) + c2w: Float[Tensor, "B 4 4"] = torch.cat( + [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 + ) + c2w[:, 3, 3] = 1.0 + + # get directions by dividing directions_unit_focal by focal length + focal_length: Float[Tensor, "B"] = ( + 0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy) + ) + directions_unit_focal = get_ray_directions( + H=self.cfg.eval_height, W=self.cfg.eval_width, focal=1.0 + ) + directions: Float[Tensor, "B H W 3"] = directions_unit_focal[ + None, :, :, : + ].repeat(self.n_views, 1, 1, 1) + directions[:, :, :, :2] = ( + directions[:, :, :, :2] / focal_length[:, None, None, None] + ) + + rays_o, rays_d = get_rays( + directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize + ) + self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix( + fovy, self.cfg.eval_width / self.cfg.eval_height, 0.1, 1000.0 + ) # FIXME: hard-coded near and far + mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx) + + self.rays_o, self.rays_d = rays_o, rays_d + self.mvp_mtx = mvp_mtx + self.c2w = c2w + self.camera_positions = camera_positions + self.light_positions = light_positions + self.elevation, self.azimuth = elevation, azimuth + self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg + self.camera_distances = camera_distances + self.fovy = fovy + + def __len__(self): + return self.n_views + + def __getitem__(self, index): + return { + "index": index, + "rays_o": self.rays_o[index], + "rays_d": self.rays_d[index], + "mvp_mtx": self.mvp_mtx[index], + "c2w": self.c2w[index], + "camera_positions": self.camera_positions[index], + "light_positions": self.light_positions[index], + "elevation": self.elevation_deg[index], + "azimuth": self.azimuth_deg[index], + "camera_distances": self.camera_distances[index], + "height": self.cfg.eval_height, + "width": self.cfg.eval_width, + "fovy": self.fovy[index], + "proj_mtx": self.proj_mtx[index], + } + + def collate(self, batch): + batch = torch.utils.data.default_collate(batch) + batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width}) + return batch + + +@register("random-camera-datamodule") +class RandomCameraDataModule(pl.LightningDataModule): + cfg: RandomCameraDataModuleConfig + + def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: + super().__init__() + self.cfg = parse_structured(RandomCameraDataModuleConfig, cfg) + + def setup(self, stage=None) -> None: + if stage in [None, "fit"]: + self.train_dataset = RandomCameraIterableDataset(self.cfg) + if stage in [None, "fit", "validate"]: + self.val_dataset = RandomCameraDataset(self.cfg, "val") + if stage in [None, "test", "predict"]: + self.test_dataset = RandomCameraDataset(self.cfg, "test") + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader: + return DataLoader( + dataset, + # very important to disable multi-processing if you want to change self attributes at runtime! + # (for example setting self.width and self.height in update_step) + num_workers=0, # type: ignore + batch_size=batch_size, + collate_fn=collate_fn, + ) + + def train_dataloader(self) -> DataLoader: + return self.general_loader( + self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate + ) + + def val_dataloader(self) -> DataLoader: + return self.general_loader( + self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate + ) + # return self.general_loader(self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate) + + def test_dataloader(self) -> DataLoader: + return self.general_loader( + self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate + ) + + def predict_dataloader(self) -> DataLoader: + return self.general_loader( + self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate + ) \ No newline at end of file diff --git a/threestudio/models/__init__.py b/threestudio/models/__init__.py new file mode 100644 index 0000000..9738918 --- /dev/null +++ b/threestudio/models/__init__.py @@ -0,0 +1,9 @@ +from . import ( + background, + exporters, + geometry, + guidance, + materials, + prompt_processors, + renderers, +) diff --git a/threestudio/models/background/__init__.py b/threestudio/models/background/__init__.py new file mode 100644 index 0000000..c637e6b --- /dev/null +++ b/threestudio/models/background/__init__.py @@ -0,0 +1,6 @@ +from . import ( + base, + neural_environment_map_background, + solid_color_background, + textured_background, +) diff --git a/threestudio/models/background/base.py b/threestudio/models/background/base.py new file mode 100644 index 0000000..2d7351c --- /dev/null +++ b/threestudio/models/background/base.py @@ -0,0 +1,24 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.utils.base import BaseModule +from threestudio.utils.typing import * + + +class BaseBackground(BaseModule): + @dataclass + class Config(BaseModule.Config): + pass + + cfg: Config + + def configure(self): + pass + + def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: + raise NotImplementedError \ No newline at end of file diff --git a/threestudio/models/background/neural_environment_map_background.py b/threestudio/models/background/neural_environment_map_background.py new file mode 100644 index 0000000..e6f1ed9 --- /dev/null +++ b/threestudio/models/background/neural_environment_map_background.py @@ -0,0 +1,71 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.ops import get_activation +from threestudio.utils.typing import * + + +@threestudio.register("neural-environment-map-background") +class NeuralEnvironmentMapBackground(BaseBackground): + @dataclass + class Config(BaseBackground.Config): + n_output_dims: int = 3 + color_activation: str = "sigmoid" + dir_encoding_config: dict = field( + default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} + ) + mlp_network_config: dict = field( + default_factory=lambda: { + "otype": "VanillaMLP", + "activation": "ReLU", + "n_neurons": 16, + "n_hidden_layers": 2, + } + ) + random_aug: bool = False + random_aug_prob: float = 0.5 + eval_color: Optional[Tuple[float, float, float]] = None + + # multi-view diffusion + share_aug_bg: bool = False + + cfg: Config + + def configure(self) -> None: + self.encoding = get_encoding(3, self.cfg.dir_encoding_config) + self.network = get_mlp( + self.encoding.n_output_dims, + self.cfg.n_output_dims, + self.cfg.mlp_network_config, + ) + + def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: + if not self.training and self.cfg.eval_color is not None: + return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( + dirs + ) * torch.as_tensor(self.cfg.eval_color).to(dirs) + # viewdirs must be normalized before passing to this function + dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1) + dirs_embd = self.encoding(dirs.view(-1, 3)) + color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims) + color = get_activation(self.cfg.color_activation)(color) + if ( + self.training + and self.cfg.random_aug + and random.random() < self.cfg.random_aug_prob + ): + # use random background color with probability random_aug_prob + n_color = 1 if self.cfg.share_aug_bg else dirs.shape[0] + color = color * 0 + ( # prevent checking for unused parameters in DDP + torch.rand(n_color, 1, 1, self.cfg.n_output_dims) + .to(dirs) + .expand(*dirs.shape[:-1], -1) + ) + return color \ No newline at end of file diff --git a/threestudio/models/background/solid_color_background.py b/threestudio/models/background/solid_color_background.py new file mode 100644 index 0000000..13d1a2a --- /dev/null +++ b/threestudio/models/background/solid_color_background.py @@ -0,0 +1,51 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.utils.typing import * + + +@threestudio.register("solid-color-background") +class SolidColorBackground(BaseBackground): + @dataclass + class Config(BaseBackground.Config): + n_output_dims: int = 3 + color: Tuple = (1.0, 1.0, 1.0) + learned: bool = False + random_aug: bool = False + random_aug_prob: float = 0.5 + + cfg: Config + + def configure(self) -> None: + self.env_color: Float[Tensor, "Nc"] + if self.cfg.learned: + self.env_color = nn.Parameter( + torch.as_tensor(self.cfg.color, dtype=torch.float32) + ) + else: + self.register_buffer( + "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32) + ) + + def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: + color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( + dirs + ) * self.env_color.to(dirs) + if ( + self.training + and self.cfg.random_aug + and random.random() < self.cfg.random_aug_prob + ): + # use random background color with probability random_aug_prob + color = color * 0 + ( # prevent checking for unused parameters in DDP + torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) + .to(dirs) + .expand(*dirs.shape[:-1], -1) + ) + return color \ No newline at end of file diff --git a/threestudio/models/background/textured_background.py b/threestudio/models/background/textured_background.py new file mode 100644 index 0000000..d2962b0 --- /dev/null +++ b/threestudio/models/background/textured_background.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.utils.ops import get_activation +from threestudio.utils.typing import * + + +@threestudio.register("textured-background") +class TexturedBackground(BaseBackground): + @dataclass + class Config(BaseBackground.Config): + n_output_dims: int = 3 + height: int = 64 + width: int = 64 + color_activation: str = "sigmoid" + + cfg: Config + + def configure(self) -> None: + self.texture = nn.Parameter( + torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width)) + ) + + def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]: + x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2] + xy = (x**2 + y**2) ** 0.5 + u = torch.atan2(xy, z) / torch.pi + v = torch.atan2(y, x) / (torch.pi * 2) + 0.5 + uv = torch.stack([u, v], -1) + return uv + + def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]: + dirs_shape = dirs.shape[:-1] + uv = self.spherical_xyz_to_uv(dirs.reshape(-1, dirs.shape[-1])) + uv = 2 * uv - 1 # rescale to [-1, 1] for grid_sample + uv = uv.reshape(1, -1, 1, 2) + color = ( + F.grid_sample( + self.texture, + uv, + mode="bilinear", + padding_mode="reflection", + align_corners=False, + ) + .reshape(self.cfg.n_output_dims, -1) + .T.reshape(*dirs_shape, self.cfg.n_output_dims) + ) + color = get_activation(self.cfg.color_activation)(color) + return color \ No newline at end of file diff --git a/threestudio/models/estimators.py b/threestudio/models/estimators.py new file mode 100644 index 0000000..38d8ae0 --- /dev/null +++ b/threestudio/models/estimators.py @@ -0,0 +1,118 @@ +from typing import Callable, List, Optional, Tuple + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +from nerfacc.data_specs import RayIntervals +from nerfacc.estimators.base import AbstractEstimator +from nerfacc.pdf import importance_sampling, searchsorted +from nerfacc.volrend import render_transmittance_from_density +from torch import Tensor + + +class ImportanceEstimator(AbstractEstimator): + def __init__( + self, + ) -> None: + super().__init__() + + @torch.no_grad() + def sampling( + self, + prop_sigma_fns: List[Callable], + prop_samples: List[int], + num_samples: int, + # rendering options + n_rays: int, + near_plane: float, + far_plane: float, + sampling_type: Literal["uniform", "lindisp"] = "uniform", + # training options + stratified: bool = False, + requires_grad: bool = False, + ) -> Tuple[Tensor, Tensor]: + """Sampling with CDFs from proposal networks. + + Args: + prop_sigma_fns: Proposal network evaluate functions. It should be a list + of functions that take in samples {t_starts (n_rays, n_samples), + t_ends (n_rays, n_samples)} and returns the post-activation densities + (n_rays, n_samples). + prop_samples: Number of samples to draw from each proposal network. Should + be the same length as `prop_sigma_fns`. + num_samples: Number of samples to draw in the end. + n_rays: Number of rays. + near_plane: Near plane. + far_plane: Far plane. + sampling_type: Sampling type. Either "uniform" or "lindisp". Default to + "lindisp". + stratified: Whether to use stratified sampling. Default to `False`. + + Returns: + A tuple of {Tensor, Tensor}: + + - **t_starts**: The starts of the samples. Shape (n_rays, num_samples). + - **t_ends**: The ends of the samples. Shape (n_rays, num_samples). + + """ + assert len(prop_sigma_fns) == len(prop_samples), ( + "The number of proposal networks and the number of samples " + "should be the same." + ) + cdfs = torch.cat( + [ + torch.zeros((n_rays, 1), device=self.device), + torch.ones((n_rays, 1), device=self.device), + ], + dim=-1, + ) + intervals = RayIntervals(vals=cdfs) + + for level_fn, level_samples in zip(prop_sigma_fns, prop_samples): + intervals, _ = importance_sampling( + intervals, cdfs, level_samples, stratified + ) + t_vals = _transform_stot( + sampling_type, intervals.vals, near_plane, far_plane + ) + t_starts = t_vals[..., :-1] + t_ends = t_vals[..., 1:] + + with torch.set_grad_enabled(requires_grad): + sigmas = level_fn(t_starts, t_ends) + assert sigmas.shape == t_starts.shape + trans, _ = render_transmittance_from_density(t_starts, t_ends, sigmas) + cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1) + + intervals, _ = importance_sampling(intervals, cdfs, num_samples, stratified) + t_vals_fine = _transform_stot( + sampling_type, intervals.vals, near_plane, far_plane + ) + + t_vals = torch.cat([t_vals, t_vals_fine], dim=-1) + t_vals, _ = torch.sort(t_vals, dim=-1) + + t_starts_ = t_vals[..., :-1] + t_ends_ = t_vals[..., 1:] + + return t_starts_, t_ends_ + + +def _transform_stot( + transform_type: Literal["uniform", "lindisp"], + s_vals: torch.Tensor, + t_min: torch.Tensor, + t_max: torch.Tensor, +) -> torch.Tensor: + if transform_type == "uniform": + _contract_fn, _icontract_fn = lambda x: x, lambda x: x + elif transform_type == "lindisp": + _contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x + else: + raise ValueError(f"Unknown transform_type: {transform_type}") + s_min, s_max = _contract_fn(t_min), _contract_fn(t_max) + icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min) + return icontract_fn(s_vals) \ No newline at end of file diff --git a/threestudio/models/exporters/__init__.py b/threestudio/models/exporters/__init__.py new file mode 100644 index 0000000..add385e --- /dev/null +++ b/threestudio/models/exporters/__init__.py @@ -0,0 +1 @@ +from . import base, mesh_exporter diff --git a/threestudio/models/exporters/base.py b/threestudio/models/exporters/base.py new file mode 100644 index 0000000..4efe41f --- /dev/null +++ b/threestudio/models/exporters/base.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.utils.base import BaseObject +from threestudio.utils.typing import * + + +@dataclass +class ExporterOutput: + save_name: str + save_type: str + params: Dict[str, Any] + + +class Exporter(BaseObject): + @dataclass + class Config(BaseObject.Config): + save_video: bool = False + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + @dataclass + class SubModules: + geometry: BaseImplicitGeometry + material: BaseMaterial + background: BaseBackground + + self.sub_modules = SubModules(geometry, material, background) + + @property + def geometry(self) -> BaseImplicitGeometry: + return self.sub_modules.geometry + + @property + def material(self) -> BaseMaterial: + return self.sub_modules.material + + @property + def background(self) -> BaseBackground: + return self.sub_modules.background + + def __call__(self, *args, **kwargs) -> List[ExporterOutput]: + raise NotImplementedError + + +@threestudio.register("dummy-exporter") +class DummyExporter(Exporter): + def __call__(self, *args, **kwargs) -> List[ExporterOutput]: + # DummyExporter does not export anything + return [] \ No newline at end of file diff --git a/threestudio/models/exporters/mesh_exporter.py b/threestudio/models/exporters/mesh_exporter.py new file mode 100644 index 0000000..ed108b2 --- /dev/null +++ b/threestudio/models/exporters/mesh_exporter.py @@ -0,0 +1,175 @@ +from dataclasses import dataclass, field + +import cv2 +import numpy as np +import torch + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.exporters.base import Exporter, ExporterOutput +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.mesh import Mesh +from threestudio.utils.rasterize import NVDiffRasterizerContext +from threestudio.utils.typing import * + + +@threestudio.register("mesh-exporter") +class MeshExporter(Exporter): + @dataclass + class Config(Exporter.Config): + fmt: str = "obj-mtl" # in ['obj-mtl', 'obj'], TODO: fbx + save_name: str = "model" + save_normal: bool = False + save_uv: bool = True + save_texture: bool = True + texture_size: int = 1024 + texture_format: str = "jpg" + xatlas_chart_options: dict = field(default_factory=dict) + xatlas_pack_options: dict = field(default_factory=dict) + context_type: str = "gl" + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + super().configure(geometry, material, background) + self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device) + + def __call__(self) -> List[ExporterOutput]: + mesh: Mesh = self.geometry.isosurface() + + if self.cfg.fmt == "obj-mtl": + return self.export_obj_with_mtl(mesh) + elif self.cfg.fmt == "obj": + return self.export_obj(mesh) + else: + raise ValueError(f"Unsupported mesh export format: {self.cfg.fmt}") + + def export_obj_with_mtl(self, mesh: Mesh) -> List[ExporterOutput]: + params = { + "mesh": mesh, + "save_mat": True, + "save_normal": self.cfg.save_normal, + "save_uv": self.cfg.save_uv, + "save_vertex_color": False, + "map_Kd": None, # Base Color + "map_Ks": None, # Specular + "map_Bump": None, # Normal + # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering + "map_Pm": None, # Metallic + "map_Pr": None, # Roughness + "map_format": self.cfg.texture_format, + } + + if self.cfg.save_uv: + mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) + + if self.cfg.save_texture: + threestudio.info("Exporting textures ...") + assert self.cfg.save_uv, "save_uv must be True when save_texture is True" + # clip space transform + uv_clip = mesh.v_tex * 2.0 - 1.0 + # pad to four component coordinate + uv_clip4 = torch.cat( + ( + uv_clip, + torch.zeros_like(uv_clip[..., 0:1]), + torch.ones_like(uv_clip[..., 0:1]), + ), + dim=-1, + ) + # rasterize + rast, _ = self.ctx.rasterize_one( + uv_clip4, mesh.t_tex_idx, (self.cfg.texture_size, self.cfg.texture_size) + ) + + hole_mask = ~(rast[:, :, 3] > 0) + + def uv_padding(image): + uv_padding_size = self.cfg.xatlas_pack_options.get("padding", 2) + inpaint_image = ( + cv2.inpaint( + (image.detach().cpu().numpy() * 255).astype(np.uint8), + (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8), + uv_padding_size, + cv2.INPAINT_TELEA, + ) + / 255.0 + ) + return torch.from_numpy(inpaint_image).to(image) + + # Interpolate world space position + gb_pos, _ = self.ctx.interpolate_one( + mesh.v_pos, rast[None, ...], mesh.t_pos_idx + ) + gb_pos = gb_pos[0] + + # Sample out textures from MLP + geo_out = self.geometry.export(points=gb_pos) + mat_out = self.material.export(points=gb_pos, **geo_out) + + threestudio.info( + "Perform UV padding on texture maps to avoid seams, may take a while ..." + ) + + if "albedo" in mat_out: + params["map_Kd"] = uv_padding(mat_out["albedo"]) + else: + threestudio.warn( + "save_texture is True but no albedo texture found, using default white texture" + ) + if "metallic" in mat_out: + params["map_Pm"] = uv_padding(mat_out["metallic"]) + if "roughness" in mat_out: + params["map_Pr"] = uv_padding(mat_out["roughness"]) + if "bump" in mat_out: + params["map_Bump"] = uv_padding(mat_out["bump"]) + # TODO: map_Ks + return [ + ExporterOutput( + save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params + ) + ] + + def export_obj(self, mesh: Mesh) -> List[ExporterOutput]: + params = { + "mesh": mesh, + "save_mat": False, + "save_normal": self.cfg.save_normal, + "save_uv": self.cfg.save_uv, + "save_vertex_color": False, + "map_Kd": None, # Base Color + "map_Ks": None, # Specular + "map_Bump": None, # Normal + # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering + "map_Pm": None, # Metallic + "map_Pr": None, # Roughness + "map_format": self.cfg.texture_format, + } + + if self.cfg.save_uv: + mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options) + + if self.cfg.save_texture: + threestudio.info("Exporting textures ...") + geo_out = self.geometry.export(points=mesh.v_pos) + mat_out = self.material.export(points=mesh.v_pos, **geo_out) + + if "albedo" in mat_out: + mesh.set_vertex_color(mat_out["albedo"]) + params["save_vertex_color"] = True + else: + threestudio.warn( + "save_texture is True but no albedo texture found, not saving vertex color" + ) + + return [ + ExporterOutput( + save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params + ) + ] \ No newline at end of file diff --git a/threestudio/models/geometry/__init__.py b/threestudio/models/geometry/__init__.py new file mode 100644 index 0000000..499cf4f --- /dev/null +++ b/threestudio/models/geometry/__init__.py @@ -0,0 +1,8 @@ +from . import ( + base, + custom_mesh, + implicit_sdf, + implicit_volume, + tetrahedra_sdf_grid, + volume_grid, +) \ No newline at end of file diff --git a/threestudio/models/geometry/base.py b/threestudio/models/geometry/base.py new file mode 100644 index 0000000..1dd8783 --- /dev/null +++ b/threestudio/models/geometry/base.py @@ -0,0 +1,209 @@ +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.isosurface import ( + IsosurfaceHelper, + MarchingCubeCPUHelper, + MarchingTetrahedraHelper, +) +from threestudio.models.mesh import Mesh +from threestudio.utils.base import BaseModule +from threestudio.utils.ops import chunk_batch, scale_tensor +from threestudio.utils.typing import * + + +def contract_to_unisphere( + x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False +) -> Float[Tensor, "... 3"]: + if unbounded: + x = scale_tensor(x, bbox, (0, 1)) + x = x * 2 - 1 # aabb is at [-1, 1] + mag = x.norm(dim=-1, keepdim=True) + mask = mag.squeeze(-1) > 1 + x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask]) + x = x / 4 + 0.5 # [-inf, inf] is at [0, 1] + else: + x = scale_tensor(x, bbox, (0, 1)) + return x + + +class BaseGeometry(BaseModule): + @dataclass + class Config(BaseModule.Config): + pass + + cfg: Config + + @staticmethod + def create_from( + other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs + ) -> "BaseGeometry": + raise TypeError( + f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}" + ) + + def export(self, *args, **kwargs) -> Dict[str, Any]: + return {} + + +class BaseImplicitGeometry(BaseGeometry): + @dataclass + class Config(BaseGeometry.Config): + radius: float = 1.0 + isosurface: bool = True + isosurface_method: str = "mt" + isosurface_resolution: int = 128 + isosurface_threshold: Union[float, str] = 0.0 + isosurface_chunk: int = 0 + isosurface_coarse_to_fine: bool = True + isosurface_deformable_grid: bool = False + isosurface_remove_outliers: bool = True + isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01 + + cfg: Config + + def configure(self) -> None: + self.bbox: Float[Tensor, "2 3"] + self.register_buffer( + "bbox", + torch.as_tensor( + [ + [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], + [self.cfg.radius, self.cfg.radius, self.cfg.radius], + ], + dtype=torch.float32, + ), + ) + self.isosurface_helper: Optional[IsosurfaceHelper] = None + self.unbounded: bool = False + + def _initilize_isosurface_helper(self): + if self.cfg.isosurface and self.isosurface_helper is None: + if self.cfg.isosurface_method == "mc-cpu": + self.isosurface_helper = MarchingCubeCPUHelper( + self.cfg.isosurface_resolution + ).to(self.device) + elif self.cfg.isosurface_method == "mt": + self.isosurface_helper = MarchingTetrahedraHelper( + self.cfg.isosurface_resolution, + f"load/tets/{self.cfg.isosurface_resolution}_tets.npz", + ).to(self.device) + else: + raise AttributeError( + "Unknown isosurface method {self.cfg.isosurface_method}" + ) + + def forward( + self, points: Float[Tensor, "*N Di"], output_normal: bool = False + ) -> Dict[str, Float[Tensor, "..."]]: + raise NotImplementedError + + def forward_field( + self, points: Float[Tensor, "*N Di"] + ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: + # return the value of the implicit field, could be density / signed distance + # also return a deformation field if the grid vertices can be optimized + raise NotImplementedError + + def forward_level( + self, field: Float[Tensor, "*N 1"], threshold: float + ) -> Float[Tensor, "*N 1"]: + # return the value of the implicit field, where the zero level set represents the surface + raise NotImplementedError + + def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh: + def batch_func(x): + # scale to bbox as the input vertices are in [0, 1] + field, deformation = self.forward_field( + scale_tensor( + x.to(bbox.device), self.isosurface_helper.points_range, bbox + ), + ) + field = field.to( + x.device + ) # move to the same device as the input (could be CPU) + if deformation is not None: + deformation = deformation.to(x.device) + return field, deformation + + assert self.isosurface_helper is not None + + field, deformation = chunk_batch( + batch_func, + self.cfg.isosurface_chunk, + self.isosurface_helper.grid_vertices, + ) + + threshold: float + + if isinstance(self.cfg.isosurface_threshold, float): + threshold = self.cfg.isosurface_threshold + elif self.cfg.isosurface_threshold == "auto": + eps = 1.0e-5 + threshold = field[field > eps].mean().item() + threestudio.info( + f"Automatically determined isosurface threshold: {threshold}" + ) + else: + raise TypeError( + f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}" + ) + + level = self.forward_level(field, threshold) + mesh: Mesh = self.isosurface_helper(level, deformation=deformation) + mesh.v_pos = scale_tensor( + mesh.v_pos, self.isosurface_helper.points_range, bbox + ) # scale to bbox as the grid vertices are in [0, 1] + mesh.add_extra("bbox", bbox) + + if self.cfg.isosurface_remove_outliers: + # remove outliers components with small number of faces + # only enabled when the mesh is not differentiable + mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold) + + return mesh + + def isosurface(self) -> Mesh: + if not self.cfg.isosurface: + raise NotImplementedError( + "Isosurface is not enabled in the current configuration" + ) + self._initilize_isosurface_helper() + if self.cfg.isosurface_coarse_to_fine: + threestudio.debug("First run isosurface to get a tight bounding box ...") + with torch.no_grad(): + mesh_coarse = self._isosurface(self.bbox) + vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0) + vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0]) + vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1]) + threestudio.debug("Run isosurface again with the tight bounding box ...") + mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True) + else: + mesh = self._isosurface(self.bbox) + return mesh + + +class BaseExplicitGeometry(BaseGeometry): + @dataclass + class Config(BaseGeometry.Config): + radius: float = 1.0 + + cfg: Config + + def configure(self) -> None: + self.bbox: Float[Tensor, "2 3"] + self.register_buffer( + "bbox", + torch.as_tensor( + [ + [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], + [self.cfg.radius, self.cfg.radius, self.cfg.radius], + ], + dtype=torch.float32, + ), + ) \ No newline at end of file diff --git a/threestudio/models/geometry/custom_mesh.py b/threestudio/models/geometry/custom_mesh.py new file mode 100644 index 0000000..6284584 --- /dev/null +++ b/threestudio/models/geometry/custom_mesh.py @@ -0,0 +1,178 @@ +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.geometry.base import ( + BaseExplicitGeometry, + BaseGeometry, + contract_to_unisphere, +) +from threestudio.models.mesh import Mesh +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.ops import scale_tensor +from threestudio.utils.typing import * + + +@threestudio.register("custom-mesh") +class CustomMesh(BaseExplicitGeometry): + @dataclass + class Config(BaseExplicitGeometry.Config): + n_input_dims: int = 3 + n_feature_dims: int = 3 + pos_encoding_config: dict = field( + default_factory=lambda: { + "otype": "HashGrid", + "n_levels": 16, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": 16, + "per_level_scale": 1.447269237440378, + } + ) + mlp_network_config: dict = field( + default_factory=lambda: { + "otype": "VanillaMLP", + "activation": "ReLU", + "output_activation": "none", + "n_neurons": 64, + "n_hidden_layers": 1, + } + ) + shape_init: str = "" + shape_init_params: Optional[Any] = None + shape_init_mesh_up: str = "+z" + shape_init_mesh_front: str = "+x" + + cfg: Config + + def configure(self) -> None: + super().configure() + + self.encoding = get_encoding( + self.cfg.n_input_dims, self.cfg.pos_encoding_config + ) + self.feature_network = get_mlp( + self.encoding.n_output_dims, + self.cfg.n_feature_dims, + self.cfg.mlp_network_config, + ) + + # Initialize custom mesh + if self.cfg.shape_init.startswith("mesh:"): + assert isinstance(self.cfg.shape_init_params, float) + mesh_path = self.cfg.shape_init[5:] + if not os.path.exists(mesh_path): + raise ValueError(f"Mesh file {mesh_path} does not exist.") + + import trimesh + + scene = trimesh.load(mesh_path) + if isinstance(scene, trimesh.Trimesh): + mesh = scene + elif isinstance(scene, trimesh.scene.Scene): + mesh = trimesh.Trimesh() + for obj in scene.geometry.values(): + mesh = trimesh.util.concatenate([mesh, obj]) + else: + raise ValueError(f"Unknown mesh type at {mesh_path}.") + + # move to center + centroid = mesh.vertices.mean(0) + mesh.vertices = mesh.vertices - centroid + + # align to up-z and front-x + dirs = ["+x", "+y", "+z", "-x", "-y", "-z"] + dir2vec = { + "+x": np.array([1, 0, 0]), + "+y": np.array([0, 1, 0]), + "+z": np.array([0, 0, 1]), + "-x": np.array([-1, 0, 0]), + "-y": np.array([0, -1, 0]), + "-z": np.array([0, 0, -1]), + } + if ( + self.cfg.shape_init_mesh_up not in dirs + or self.cfg.shape_init_mesh_front not in dirs + ): + raise ValueError( + f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}." + ) + if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]: + raise ValueError( + "shape_init_mesh_up and shape_init_mesh_front must be orthogonal." + ) + z_, x_ = ( + dir2vec[self.cfg.shape_init_mesh_up], + dir2vec[self.cfg.shape_init_mesh_front], + ) + y_ = np.cross(z_, x_) + std2mesh = np.stack([x_, y_, z_], axis=0).T + mesh2std = np.linalg.inv(std2mesh) + + # scaling + scale = np.abs(mesh.vertices).max() + mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params + mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T + + v_pos = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device) + t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device) + self.mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx) + self.register_buffer( + "v_buffer", + v_pos, + ) + self.register_buffer( + "t_buffer", + t_pos_idx, + ) + + else: + raise ValueError( + f"Unknown shape initialization type: {self.cfg.shape_init}" + ) + print(self.mesh.v_pos.device) + + def isosurface(self) -> Mesh: + if hasattr(self, "mesh"): + return self.mesh + elif hasattr(self, "v_buffer"): + self.mesh = Mesh(v_pos=self.v_buffer, t_pos_idx=self.t_buffer) + return self.mesh + else: + raise ValueError(f"custom mesh is not initialized") + + def forward( + self, points: Float[Tensor, "*N Di"], output_normal: bool = False + ) -> Dict[str, Float[Tensor, "..."]]: + assert ( + output_normal == False + ), f"Normal output is not supported for {self.__class__.__name__}" + points_unscaled = points # points in the original scale + points = contract_to_unisphere(points, self.bbox) # points normalized to (0, 1) + enc = self.encoding(points.view(-1, self.cfg.n_input_dims)) + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + return {"features": features} + + def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: + out: Dict[str, Any] = {} + if self.cfg.n_feature_dims == 0: + return out + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox) + enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + out.update( + { + "features": features, + } + ) + return out \ No newline at end of file diff --git a/threestudio/models/geometry/implicit_sdf.py b/threestudio/models/geometry/implicit_sdf.py new file mode 100644 index 0000000..8ece576 --- /dev/null +++ b/threestudio/models/geometry/implicit_sdf.py @@ -0,0 +1,413 @@ +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere +from threestudio.models.mesh import Mesh +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.misc import broadcast, get_rank +from threestudio.utils.typing import * + + +@threestudio.register("implicit-sdf") +class ImplicitSDF(BaseImplicitGeometry): + @dataclass + class Config(BaseImplicitGeometry.Config): + n_input_dims: int = 3 + n_feature_dims: int = 3 + pos_encoding_config: dict = field( + default_factory=lambda: { + "otype": "HashGrid", + "n_levels": 16, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": 16, + "per_level_scale": 1.447269237440378, + } + ) + mlp_network_config: dict = field( + default_factory=lambda: { + "otype": "VanillaMLP", + "activation": "ReLU", + "output_activation": "none", + "n_neurons": 64, + "n_hidden_layers": 1, + } + ) + normal_type: Optional[ + str + ] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian'] + finite_difference_normal_eps: Union[ + float, str + ] = 0.01 # in [float, "progressive"] + shape_init: Optional[str] = None + shape_init_params: Optional[Any] = None + shape_init_mesh_up: str = "+z" + shape_init_mesh_front: str = "+x" + force_shape_init: bool = False + sdf_bias: Union[float, str] = 0.0 + sdf_bias_params: Optional[Any] = None + + # no need to removal outlier for SDF + isosurface_remove_outliers: bool = False + + cfg: Config + + def configure(self) -> None: + super().configure() + self.encoding = get_encoding( + self.cfg.n_input_dims, self.cfg.pos_encoding_config + ) + self.sdf_network = get_mlp( + self.encoding.n_output_dims, 1, self.cfg.mlp_network_config + ) + + if self.cfg.n_feature_dims > 0: + self.feature_network = get_mlp( + self.encoding.n_output_dims, + self.cfg.n_feature_dims, + self.cfg.mlp_network_config, + ) + + if self.cfg.normal_type == "pred": + self.normal_network = get_mlp( + self.encoding.n_output_dims, 3, self.cfg.mlp_network_config + ) + if self.cfg.isosurface_deformable_grid: + assert ( + self.cfg.isosurface_method == "mt" + ), "isosurface_deformable_grid only works with mt" + self.deformation_network = get_mlp( + self.encoding.n_output_dims, 3, self.cfg.mlp_network_config + ) + + self.finite_difference_normal_eps: Optional[float] = None + + def initialize_shape(self) -> None: + if self.cfg.shape_init is None and not self.cfg.force_shape_init: + return + + # do not initialize shape if weights are provided + if self.cfg.weights is not None and not self.cfg.force_shape_init: + return + + if self.cfg.sdf_bias != 0.0: + threestudio.warn( + "shape_init and sdf_bias are both specified, which may lead to unexpected results." + ) + + get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]] + assert isinstance(self.cfg.shape_init, str) + if self.cfg.shape_init == "ellipsoid": + assert ( + isinstance(self.cfg.shape_init_params, Sized) + and len(self.cfg.shape_init_params) == 3 + ) + size = torch.as_tensor(self.cfg.shape_init_params).to(self.device) + + def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: + return ((points_rand / size) ** 2).sum( + dim=-1, keepdim=True + ).sqrt() - 1.0 # pseudo signed distance of an ellipsoid + + get_gt_sdf = func + elif self.cfg.shape_init == "sphere": + assert isinstance(self.cfg.shape_init_params, float) + radius = self.cfg.shape_init_params + + def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: + return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius + + get_gt_sdf = func + elif self.cfg.shape_init.startswith("mesh:"): + assert isinstance(self.cfg.shape_init_params, float) + mesh_path = self.cfg.shape_init[5:] + if not os.path.exists(mesh_path): + raise ValueError(f"Mesh file {mesh_path} does not exist.") + + import trimesh + + scene = trimesh.load(mesh_path) + if isinstance(scene, trimesh.Trimesh): + mesh = scene + elif isinstance(scene, trimesh.scene.Scene): + mesh = trimesh.Trimesh() + for obj in scene.geometry.values(): + mesh = trimesh.util.concatenate([mesh, obj]) + else: + raise ValueError(f"Unknown mesh type at {mesh_path}.") + + # move to center + centroid = mesh.vertices.mean(0) + mesh.vertices = mesh.vertices - centroid + + # align to up-z and front-x + dirs = ["+x", "+y", "+z", "-x", "-y", "-z"] + dir2vec = { + "+x": np.array([1, 0, 0]), + "+y": np.array([0, 1, 0]), + "+z": np.array([0, 0, 1]), + "-x": np.array([-1, 0, 0]), + "-y": np.array([0, -1, 0]), + "-z": np.array([0, 0, -1]), + } + if ( + self.cfg.shape_init_mesh_up not in dirs + or self.cfg.shape_init_mesh_front not in dirs + ): + raise ValueError( + f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}." + ) + if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]: + raise ValueError( + "shape_init_mesh_up and shape_init_mesh_front must be orthogonal." + ) + z_, x_ = ( + dir2vec[self.cfg.shape_init_mesh_up], + dir2vec[self.cfg.shape_init_mesh_front], + ) + y_ = np.cross(z_, x_) + std2mesh = np.stack([x_, y_, z_], axis=0).T + mesh2std = np.linalg.inv(std2mesh) + + # scaling + scale = np.abs(mesh.vertices).max() + mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params + mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T + + from pysdf import SDF + + sdf = SDF(mesh.vertices, mesh.faces) + + def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: + # add a negative signed here + # as in pysdf the inside of the shape has positive signed distance + return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to( + points_rand + )[..., None] + + get_gt_sdf = func + + else: + raise ValueError( + f"Unknown shape initialization type: {self.cfg.shape_init}" + ) + + # Initialize SDF to a given shape when no weights are provided or force_shape_init is True + optim = torch.optim.Adam(self.parameters(), lr=1e-3) + from tqdm import tqdm + + for _ in tqdm( + range(1000), + desc=f"Initializing SDF to a(n) {self.cfg.shape_init}:", + disable=get_rank() != 0, + ): + points_rand = ( + torch.rand((10000, 3), dtype=torch.float32).to(self.device) * 2.0 - 1.0 + ) + sdf_gt = get_gt_sdf(points_rand) + sdf_pred = self.forward_sdf(points_rand) + loss = F.mse_loss(sdf_pred, sdf_gt) + optim.zero_grad() + loss.backward() + optim.step() + + # explicit broadcast to ensure param consistency across ranks + for param in self.parameters(): + broadcast(param, src=0) + + def get_shifted_sdf( + self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"] + ) -> Float[Tensor, "*N 1"]: + sdf_bias: Union[float, Float[Tensor, "*N 1"]] + if self.cfg.sdf_bias == "ellipsoid": + assert ( + isinstance(self.cfg.sdf_bias_params, Sized) + and len(self.cfg.sdf_bias_params) == 3 + ) + size = torch.as_tensor(self.cfg.sdf_bias_params).to(points) + sdf_bias = ((points / size) ** 2).sum( + dim=-1, keepdim=True + ).sqrt() - 1.0 # pseudo signed distance of an ellipsoid + elif self.cfg.sdf_bias == "sphere": + assert isinstance(self.cfg.sdf_bias_params, float) + radius = self.cfg.sdf_bias_params + sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius + elif isinstance(self.cfg.sdf_bias, float): + sdf_bias = self.cfg.sdf_bias + else: + raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}") + return sdf + sdf_bias + + def forward( + self, points: Float[Tensor, "*N Di"], output_normal: bool = False + ) -> Dict[str, Float[Tensor, "..."]]: + grad_enabled = torch.is_grad_enabled() + + if output_normal and self.cfg.normal_type == "analytic": + torch.set_grad_enabled(True) + points.requires_grad_(True) + + points_unscaled = points # points in the original scale + points = contract_to_unisphere( + points, self.bbox, self.unbounded + ) # points normalized to (0, 1) + + enc = self.encoding(points.view(-1, self.cfg.n_input_dims)) + sdf = self.sdf_network(enc).view(*points.shape[:-1], 1) + sdf = self.get_shifted_sdf(points_unscaled, sdf) + output = {"sdf": sdf} + + if self.cfg.n_feature_dims > 0: + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + output.update({"features": features}) + + if output_normal: + if ( + self.cfg.normal_type == "finite_difference" + or self.cfg.normal_type == "finite_difference_laplacian" + ): + assert self.finite_difference_normal_eps is not None + eps: float = self.finite_difference_normal_eps + if self.cfg.normal_type == "finite_difference_laplacian": + offsets: Float[Tensor, "6 3"] = torch.as_tensor( + [ + [eps, 0.0, 0.0], + [-eps, 0.0, 0.0], + [0.0, eps, 0.0], + [0.0, -eps, 0.0], + [0.0, 0.0, eps], + [0.0, 0.0, -eps], + ] + ).to(points_unscaled) + points_offset: Float[Tensor, "... 6 3"] = ( + points_unscaled[..., None, :] + offsets + ).clamp(-self.cfg.radius, self.cfg.radius) + sdf_offset: Float[Tensor, "... 6 1"] = self.forward_sdf( + points_offset + ) + sdf_grad = ( + 0.5 + * (sdf_offset[..., 0::2, 0] - sdf_offset[..., 1::2, 0]) + / eps + ) + else: + offsets: Float[Tensor, "3 3"] = torch.as_tensor( + [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]] + ).to(points_unscaled) + points_offset: Float[Tensor, "... 3 3"] = ( + points_unscaled[..., None, :] + offsets + ).clamp(-self.cfg.radius, self.cfg.radius) + sdf_offset: Float[Tensor, "... 3 1"] = self.forward_sdf( + points_offset + ) + sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps + normal = F.normalize(sdf_grad, dim=-1) + elif self.cfg.normal_type == "pred": + normal = self.normal_network(enc).view(*points.shape[:-1], 3) + normal = F.normalize(normal, dim=-1) + sdf_grad = normal + elif self.cfg.normal_type == "analytic": + sdf_grad = -torch.autograd.grad( + sdf, + points_unscaled, + grad_outputs=torch.ones_like(sdf), + create_graph=True, + )[0] + normal = F.normalize(sdf_grad, dim=-1) + if not grad_enabled: + sdf_grad = sdf_grad.detach() + normal = normal.detach() + else: + raise AttributeError(f"Unknown normal type {self.cfg.normal_type}") + output.update( + {"normal": normal, "shading_normal": normal, "sdf_grad": sdf_grad} + ) + return output + + def forward_sdf(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]: + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) + + sdf = self.sdf_network( + self.encoding(points.reshape(-1, self.cfg.n_input_dims)) + ).reshape(*points.shape[:-1], 1) + sdf = self.get_shifted_sdf(points_unscaled, sdf) + return sdf + + def forward_field( + self, points: Float[Tensor, "*N Di"] + ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) + enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) + sdf = self.sdf_network(enc).reshape(*points.shape[:-1], 1) + sdf = self.get_shifted_sdf(points_unscaled, sdf) + deformation: Optional[Float[Tensor, "*N 3"]] = None + if self.cfg.isosurface_deformable_grid: + deformation = self.deformation_network(enc).reshape(*points.shape[:-1], 3) + return sdf, deformation + + def forward_level( + self, field: Float[Tensor, "*N 1"], threshold: float + ) -> Float[Tensor, "*N 1"]: + return field - threshold + + def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: + out: Dict[str, Any] = {} + if self.cfg.n_feature_dims == 0: + return out + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) + enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + out.update( + { + "features": features, + } + ) + return out + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + if ( + self.cfg.normal_type == "finite_difference" + or self.cfg.normal_type == "finite_difference_laplacian" + ): + if isinstance(self.cfg.finite_difference_normal_eps, float): + self.finite_difference_normal_eps = ( + self.cfg.finite_difference_normal_eps + ) + elif self.cfg.finite_difference_normal_eps == "progressive": + # progressive finite difference eps from Neuralangelo + # https://arxiv.org/abs/2306.03092 + hg_conf: Any = self.cfg.pos_encoding_config + assert ( + hg_conf.otype == "ProgressiveBandHashGrid" + ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid" + current_level = min( + hg_conf.start_level + + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps, + hg_conf.n_levels, + ) + grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** ( + current_level - 1 + ) + grid_size = 2 * self.cfg.radius / grid_res + if grid_size != self.finite_difference_normal_eps: + threestudio.info( + f"Update finite_difference_normal_eps to {grid_size}" + ) + self.finite_difference_normal_eps = grid_size + else: + raise ValueError( + f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}" + ) \ No newline at end of file diff --git a/threestudio/models/geometry/implicit_volume.py b/threestudio/models/geometry/implicit_volume.py new file mode 100644 index 0000000..fff50a3 --- /dev/null +++ b/threestudio/models/geometry/implicit_volume.py @@ -0,0 +1,325 @@ +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.geometry.base import ( + BaseGeometry, + BaseImplicitGeometry, + contract_to_unisphere, +) +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.ops import get_activation +from threestudio.utils.typing import * + + +@threestudio.register("implicit-volume") +class ImplicitVolume(BaseImplicitGeometry): + @dataclass + class Config(BaseImplicitGeometry.Config): + n_input_dims: int = 3 + n_feature_dims: int = 3 + density_activation: Optional[str] = "softplus" + density_bias: Union[float, str] = "blob_magic3d" + density_blob_scale: float = 10.0 + density_blob_std: float = 0.5 + pos_encoding_config: dict = field( + default_factory=lambda: { + "otype": "HashGrid", + "n_levels": 16, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": 16, + "per_level_scale": 1.447269237440378, + } + ) + mlp_network_config: dict = field( + default_factory=lambda: { + "otype": "VanillaMLP", + "activation": "ReLU", + "output_activation": "none", + "n_neurons": 64, + "n_hidden_layers": 1, + } + ) + normal_type: Optional[ + str + ] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian'] + finite_difference_normal_eps: Union[ + float, str + ] = 0.01 # in [float, "progressive"] + + # automatically determine the threshold + isosurface_threshold: Union[float, str] = 25.0 + + # 4D Gaussian Annealing + anneal_density_blob_std_config: Optional[dict] = None + + cfg: Config + + def configure(self) -> None: + super().configure() + self.encoding = get_encoding( + self.cfg.n_input_dims, self.cfg.pos_encoding_config + ) + self.density_network = get_mlp( + self.encoding.n_output_dims, 1, self.cfg.mlp_network_config + ) + if self.cfg.n_feature_dims > 0: + self.feature_network = get_mlp( + self.encoding.n_output_dims, + self.cfg.n_feature_dims, + self.cfg.mlp_network_config, + ) + if self.cfg.normal_type == "pred": + self.normal_network = get_mlp( + self.encoding.n_output_dims, 3, self.cfg.mlp_network_config + ) + + self.finite_difference_normal_eps: Optional[float] = None + + def get_activated_density( + self, points: Float[Tensor, "*N Di"], density: Float[Tensor, "*N 1"] + ) -> Tuple[Float[Tensor, "*N 1"], Float[Tensor, "*N 1"]]: + density_bias: Union[float, Float[Tensor, "*N 1"]] + if self.cfg.density_bias == "blob_dreamfusion": + # pre-activation density bias + density_bias = ( + self.cfg.density_blob_scale + * torch.exp( + -0.5 * (points**2).sum(dim=-1) / self.cfg.density_blob_std**2 + )[..., None] + ) + elif self.cfg.density_bias == "blob_magic3d": + # pre-activation density bias + density_bias = ( + self.cfg.density_blob_scale + * ( + 1 + - torch.sqrt((points**2).sum(dim=-1)) / self.cfg.density_blob_std + )[..., None] + ) + elif isinstance(self.cfg.density_bias, float): + density_bias = self.cfg.density_bias + else: + raise ValueError(f"Unknown density bias {self.cfg.density_bias}") + raw_density: Float[Tensor, "*N 1"] = density + density_bias + density = get_activation(self.cfg.density_activation)(raw_density) + return raw_density, density + + def forward( + self, points: Float[Tensor, "*N Di"], output_normal: bool = False + ) -> Dict[str, Float[Tensor, "..."]]: + grad_enabled = torch.is_grad_enabled() + + if output_normal and self.cfg.normal_type == "analytic": + torch.set_grad_enabled(True) + points.requires_grad_(True) + + points_unscaled = points # points in the original scale + points = contract_to_unisphere( + points, self.bbox, self.unbounded + ) # points normalized to (0, 1) + + enc = self.encoding(points.view(-1, self.cfg.n_input_dims)) + density = self.density_network(enc).view(*points.shape[:-1], 1) + raw_density, density = self.get_activated_density(points_unscaled, density) + + output = { + "density": density, + } + + if self.cfg.n_feature_dims > 0: + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + output.update({"features": features}) + + if output_normal: + if ( + self.cfg.normal_type == "finite_difference" + or self.cfg.normal_type == "finite_difference_laplacian" + ): + # TODO: use raw density + assert self.finite_difference_normal_eps is not None + eps: float = self.finite_difference_normal_eps + if self.cfg.normal_type == "finite_difference_laplacian": + offsets: Float[Tensor, "6 3"] = torch.as_tensor( + [ + [eps, 0.0, 0.0], + [-eps, 0.0, 0.0], + [0.0, eps, 0.0], + [0.0, -eps, 0.0], + [0.0, 0.0, eps], + [0.0, 0.0, -eps], + ] + ).to(points_unscaled) + points_offset: Float[Tensor, "... 6 3"] = ( + points_unscaled[..., None, :] + offsets + ).clamp(-self.cfg.radius, self.cfg.radius) + density_offset: Float[Tensor, "... 6 1"] = self.forward_density( + points_offset + ) + normal = ( + -0.5 + * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0]) + / eps + ) + else: + offsets: Float[Tensor, "3 3"] = torch.as_tensor( + [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]] + ).to(points_unscaled) + points_offset: Float[Tensor, "... 3 3"] = ( + points_unscaled[..., None, :] + offsets + ).clamp(-self.cfg.radius, self.cfg.radius) + density_offset: Float[Tensor, "... 3 1"] = self.forward_density( + points_offset + ) + normal = -(density_offset[..., 0::1, 0] - density) / eps + normal = F.normalize(normal, dim=-1) + elif self.cfg.normal_type == "pred": + normal = self.normal_network(enc).view(*points.shape[:-1], 3) + normal = F.normalize(normal, dim=-1) + elif self.cfg.normal_type == "analytic": + normal = -torch.autograd.grad( + density, + points_unscaled, + grad_outputs=torch.ones_like(density), + create_graph=True, + )[0] + normal = F.normalize(normal, dim=-1) + if not grad_enabled: + normal = normal.detach() + else: + raise AttributeError(f"Unknown normal type {self.cfg.normal_type}") + output.update({"normal": normal, "shading_normal": normal}) + + torch.set_grad_enabled(grad_enabled) + return output + + def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]: + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) + + density = self.density_network( + self.encoding(points.reshape(-1, self.cfg.n_input_dims)) + ).reshape(*points.shape[:-1], 1) + + _, density = self.get_activated_density(points_unscaled, density) + return density + + def forward_field( + self, points: Float[Tensor, "*N Di"] + ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: + if self.cfg.isosurface_deformable_grid: + threestudio.warn( + f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring." + ) + density = self.forward_density(points) + return density, None + + def forward_level( + self, field: Float[Tensor, "*N 1"], threshold: float + ) -> Float[Tensor, "*N 1"]: + return -(field - threshold) + + def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: + out: Dict[str, Any] = {} + if self.cfg.n_feature_dims == 0: + return out + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) + enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + out.update( + { + "features": features, + } + ) + return out + + @staticmethod + @torch.no_grad() + def create_from( + other: BaseGeometry, + cfg: Optional[Union[dict, DictConfig]] = None, + copy_net: bool = True, + **kwargs, + ) -> "ImplicitVolume": + if isinstance(other, ImplicitVolume): + instance = ImplicitVolume(cfg, **kwargs) + instance.encoding.load_state_dict(other.encoding.state_dict()) + instance.density_network.load_state_dict(other.density_network.state_dict()) + if copy_net: + if ( + instance.cfg.n_feature_dims > 0 + and other.cfg.n_feature_dims == instance.cfg.n_feature_dims + ): + instance.feature_network.load_state_dict( + other.feature_network.state_dict() + ) + if ( + instance.cfg.normal_type == "pred" + and other.cfg.normal_type == "pred" + ): + instance.normal_network.load_state_dict( + other.normal_network.state_dict() + ) + return instance + else: + raise TypeError( + f"Cannot create {ImplicitVolume.__name__} from {other.__class__.__name__}" + ) + + # FIXME: use progressive normal eps + def update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ) -> None: + if self.cfg.anneal_density_blob_std_config is not None: + min_step = self.cfg.anneal_density_blob_std_config.min_anneal_step + max_step = self.cfg.anneal_density_blob_std_config.max_anneal_step + if global_step >= min_step and global_step <= max_step: + end_val = self.cfg.anneal_density_blob_std_config.end_val + start_val = self.cfg.anneal_density_blob_std_config.start_val + self.density_blob_std = start_val + (global_step - min_step) * ( + end_val - start_val + ) / (max_step - min_step) + + if ( + self.cfg.normal_type == "finite_difference" + or self.cfg.normal_type == "finite_difference_laplacian" + ): + if isinstance(self.cfg.finite_difference_normal_eps, float): + self.finite_difference_normal_eps = ( + self.cfg.finite_difference_normal_eps + ) + elif self.cfg.finite_difference_normal_eps == "progressive": + # progressive finite difference eps from Neuralangelo + # https://arxiv.org/abs/2306.03092 + hg_conf: Any = self.cfg.pos_encoding_config + assert ( + hg_conf.otype == "ProgressiveBandHashGrid" + ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid" + current_level = min( + hg_conf.start_level + + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps, + hg_conf.n_levels, + ) + grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** ( + current_level - 1 + ) + grid_size = 2 * self.cfg.radius / grid_res + if grid_size != self.finite_difference_normal_eps: + threestudio.info( + f"Update finite_difference_normal_eps to {grid_size}" + ) + self.finite_difference_normal_eps = grid_size + else: + raise ValueError( + f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}" + ) \ No newline at end of file diff --git a/threestudio/models/geometry/tetrahedra_sdf_grid.py b/threestudio/models/geometry/tetrahedra_sdf_grid.py new file mode 100644 index 0000000..c5ebea0 --- /dev/null +++ b/threestudio/models/geometry/tetrahedra_sdf_grid.py @@ -0,0 +1,369 @@ +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.geometry.base import ( + BaseExplicitGeometry, + BaseGeometry, + contract_to_unisphere, +) +from threestudio.models.geometry.implicit_sdf import ImplicitSDF +from threestudio.models.geometry.implicit_volume import ImplicitVolume +from threestudio.models.isosurface import MarchingTetrahedraHelper +from threestudio.models.mesh import Mesh +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.misc import broadcast +from threestudio.utils.ops import scale_tensor +from threestudio.utils.typing import * + + +@threestudio.register("tetrahedra-sdf-grid") +class TetrahedraSDFGrid(BaseExplicitGeometry): + @dataclass + class Config(BaseExplicitGeometry.Config): + isosurface_resolution: int = 128 + isosurface_deformable_grid: bool = True + isosurface_remove_outliers: bool = False + isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01 + + n_input_dims: int = 3 + n_feature_dims: int = 3 + pos_encoding_config: dict = field( + default_factory=lambda: { + "otype": "HashGrid", + "n_levels": 16, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": 16, + "per_level_scale": 1.447269237440378, + } + ) + mlp_network_config: dict = field( + default_factory=lambda: { + "otype": "VanillaMLP", + "activation": "ReLU", + "output_activation": "none", + "n_neurons": 64, + "n_hidden_layers": 1, + } + ) + shape_init: Optional[str] = None + shape_init_params: Optional[Any] = None + shape_init_mesh_up: str = "+z" + shape_init_mesh_front: str = "+x" + force_shape_init: bool = False + geometry_only: bool = False + fix_geometry: bool = False + + cfg: Config + + def configure(self) -> None: + super().configure() + + # this should be saved to state_dict, register as buffer + self.isosurface_bbox: Float[Tensor, "2 3"] + self.register_buffer("isosurface_bbox", self.bbox.clone()) + + self.isosurface_helper = MarchingTetrahedraHelper( + self.cfg.isosurface_resolution, + f"load/tets/{self.cfg.isosurface_resolution}_tets.npz", + ) + + self.sdf: Float[Tensor, "Nv 1"] + self.deformation: Optional[Float[Tensor, "Nv 3"]] + + if not self.cfg.fix_geometry: + self.register_parameter( + "sdf", + nn.Parameter( + torch.zeros( + (self.isosurface_helper.grid_vertices.shape[0], 1), + dtype=torch.float32, + ) + ), + ) + if self.cfg.isosurface_deformable_grid: + self.register_parameter( + "deformation", + nn.Parameter( + torch.zeros_like(self.isosurface_helper.grid_vertices) + ), + ) + else: + self.deformation = None + else: + self.register_buffer( + "sdf", + torch.zeros( + (self.isosurface_helper.grid_vertices.shape[0], 1), + dtype=torch.float32, + ), + ) + if self.cfg.isosurface_deformable_grid: + self.register_buffer( + "deformation", + torch.zeros_like(self.isosurface_helper.grid_vertices), + ) + else: + self.deformation = None + + if not self.cfg.geometry_only: + self.encoding = get_encoding( + self.cfg.n_input_dims, self.cfg.pos_encoding_config + ) + self.feature_network = get_mlp( + self.encoding.n_output_dims, + self.cfg.n_feature_dims, + self.cfg.mlp_network_config, + ) + + self.mesh: Optional[Mesh] = None + + def initialize_shape(self) -> None: + if self.cfg.shape_init is None and not self.cfg.force_shape_init: + return + + # do not initialize shape if weights are provided + if self.cfg.weights is not None and not self.cfg.force_shape_init: + return + + get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]] + assert isinstance(self.cfg.shape_init, str) + if self.cfg.shape_init == "ellipsoid": + assert ( + isinstance(self.cfg.shape_init_params, Sized) + and len(self.cfg.shape_init_params) == 3 + ) + size = torch.as_tensor(self.cfg.shape_init_params).to(self.device) + + def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: + return ((points_rand / size) ** 2).sum( + dim=-1, keepdim=True + ).sqrt() - 1.0 # pseudo signed distance of an ellipsoid + + get_gt_sdf = func + elif self.cfg.shape_init == "sphere": + assert isinstance(self.cfg.shape_init_params, float) + radius = self.cfg.shape_init_params + + def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: + return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius + + get_gt_sdf = func + elif self.cfg.shape_init.startswith("mesh:"): + assert isinstance(self.cfg.shape_init_params, float) + mesh_path = self.cfg.shape_init[5:] + if not os.path.exists(mesh_path): + raise ValueError(f"Mesh file {mesh_path} does not exist.") + + import trimesh + + mesh = trimesh.load(mesh_path) + + # move to center + centroid = mesh.vertices.mean(0) + mesh.vertices = mesh.vertices - centroid + + # align to up-z and front-x + dirs = ["+x", "+y", "+z", "-x", "-y", "-z"] + dir2vec = { + "+x": np.array([1, 0, 0]), + "+y": np.array([0, 1, 0]), + "+z": np.array([0, 0, 1]), + "-x": np.array([-1, 0, 0]), + "-y": np.array([0, -1, 0]), + "-z": np.array([0, 0, -1]), + } + if ( + self.cfg.shape_init_mesh_up not in dirs + or self.cfg.shape_init_mesh_front not in dirs + ): + raise ValueError( + f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}." + ) + if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]: + raise ValueError( + "shape_init_mesh_up and shape_init_mesh_front must be orthogonal." + ) + z_, x_ = ( + dir2vec[self.cfg.shape_init_mesh_up], + dir2vec[self.cfg.shape_init_mesh_front], + ) + y_ = np.cross(z_, x_) + std2mesh = np.stack([x_, y_, z_], axis=0).T + mesh2std = np.linalg.inv(std2mesh) + + # scaling + scale = np.abs(mesh.vertices).max() + mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params + mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T + + from pysdf import SDF + + sdf = SDF(mesh.vertices, mesh.faces) + + def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: + # add a negative signed here + # as in pysdf the inside of the shape has positive signed distance + return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to( + points_rand + )[..., None] + + get_gt_sdf = func + + else: + raise ValueError( + f"Unknown shape initialization type: {self.cfg.shape_init}" + ) + + sdf_gt = get_gt_sdf( + scale_tensor( + self.isosurface_helper.grid_vertices, + self.isosurface_helper.points_range, + self.isosurface_bbox, + ) + ) + self.sdf.data = sdf_gt + + # explicit broadcast to ensure param consistency across ranks + for param in self.parameters(): + broadcast(param, src=0) + + def isosurface(self) -> Mesh: + # return cached mesh if fix_geometry is True to save computation + if self.cfg.fix_geometry and self.mesh is not None: + return self.mesh + mesh = self.isosurface_helper(self.sdf, self.deformation) + mesh.v_pos = scale_tensor( + mesh.v_pos, self.isosurface_helper.points_range, self.isosurface_bbox + ) + if self.cfg.isosurface_remove_outliers: + mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold) + self.mesh = mesh + return mesh + + def forward( + self, points: Float[Tensor, "*N Di"], output_normal: bool = False + ) -> Dict[str, Float[Tensor, "..."]]: + if self.cfg.geometry_only: + return {} + assert ( + output_normal == False + ), f"Normal output is not supported for {self.__class__.__name__}" + points_unscaled = points # points in the original scale + points = contract_to_unisphere(points, self.bbox) # points normalized to (0, 1) + enc = self.encoding(points.view(-1, self.cfg.n_input_dims)) + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + return {"features": features} + + @staticmethod + @torch.no_grad() + def create_from( + other: BaseGeometry, + cfg: Optional[Union[dict, DictConfig]] = None, + copy_net: bool = True, + **kwargs, + ) -> "TetrahedraSDFGrid": + if isinstance(other, TetrahedraSDFGrid): + instance = TetrahedraSDFGrid(cfg, **kwargs) + assert instance.cfg.isosurface_resolution == other.cfg.isosurface_resolution + instance.isosurface_bbox = other.isosurface_bbox.clone() + instance.sdf.data = other.sdf.data.clone() + if ( + instance.cfg.isosurface_deformable_grid + and other.cfg.isosurface_deformable_grid + ): + assert ( + instance.deformation is not None and other.deformation is not None + ) + instance.deformation.data = other.deformation.data.clone() + if ( + not instance.cfg.geometry_only + and not other.cfg.geometry_only + and copy_net + ): + instance.encoding.load_state_dict(other.encoding.state_dict()) + instance.feature_network.load_state_dict( + other.feature_network.state_dict() + ) + return instance + elif isinstance(other, ImplicitVolume): + instance = TetrahedraSDFGrid(cfg, **kwargs) + if other.cfg.isosurface_method != "mt": + other.cfg.isosurface_method = "mt" + threestudio.warn( + f"Override isosurface_method of the source geometry to 'mt'" + ) + if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution: + other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution + threestudio.warn( + f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}" + ) + mesh = other.isosurface() + instance.isosurface_bbox = mesh.extras["bbox"] + instance.sdf.data = ( + mesh.extras["grid_level"].to(instance.sdf.data).clamp(-1, 1) + ) + if not instance.cfg.geometry_only and copy_net: + instance.encoding.load_state_dict(other.encoding.state_dict()) + instance.feature_network.load_state_dict( + other.feature_network.state_dict() + ) + return instance + elif isinstance(other, ImplicitSDF): + instance = TetrahedraSDFGrid(cfg, **kwargs) + if other.cfg.isosurface_method != "mt": + other.cfg.isosurface_method = "mt" + threestudio.warn( + f"Override isosurface_method of the source geometry to 'mt'" + ) + if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution: + other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution + threestudio.warn( + f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}" + ) + mesh = other.isosurface() + instance.isosurface_bbox = mesh.extras["bbox"] + instance.sdf.data = mesh.extras["grid_level"].to(instance.sdf.data) + if ( + instance.cfg.isosurface_deformable_grid + and other.cfg.isosurface_deformable_grid + ): + assert instance.deformation is not None + instance.deformation.data = mesh.extras["grid_deformation"].to( + instance.deformation.data + ) + if not instance.cfg.geometry_only and copy_net: + instance.encoding.load_state_dict(other.encoding.state_dict()) + instance.feature_network.load_state_dict( + other.feature_network.state_dict() + ) + return instance + else: + raise TypeError( + f"Cannot create {TetrahedraSDFGrid.__name__} from {other.__class__.__name__}" + ) + + def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: + out: Dict[str, Any] = {} + if self.cfg.geometry_only or self.cfg.n_feature_dims == 0: + return out + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox) + enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) + features = self.feature_network(enc).view( + *points.shape[:-1], self.cfg.n_feature_dims + ) + out.update( + { + "features": features, + } + ) + return out \ No newline at end of file diff --git a/threestudio/models/geometry/volume_grid.py b/threestudio/models/geometry/volume_grid.py new file mode 100644 index 0000000..700648c --- /dev/null +++ b/threestudio/models/geometry/volume_grid.py @@ -0,0 +1,190 @@ +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere +from threestudio.utils.ops import get_activation +from threestudio.utils.typing import * + + +@threestudio.register("volume-grid") +class VolumeGrid(BaseImplicitGeometry): + @dataclass + class Config(BaseImplicitGeometry.Config): + grid_size: Tuple[int, int, int] = field(default_factory=lambda: (100, 100, 100)) + n_feature_dims: int = 3 + density_activation: Optional[str] = "softplus" + density_bias: Union[float, str] = "blob" + density_blob_scale: float = 5.0 + density_blob_std: float = 0.5 + normal_type: Optional[ + str + ] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian'] + + # automatically determine the threshold + isosurface_threshold: Union[float, str] = "auto" + + cfg: Config + + def configure(self) -> None: + super().configure() + self.grid_size = self.cfg.grid_size + + self.grid = nn.Parameter( + torch.zeros(1, self.cfg.n_feature_dims + 1, *self.grid_size) + ) + if self.cfg.density_bias == "blob": + self.register_buffer("density_scale", torch.tensor(0.0)) + else: + self.density_scale = nn.Parameter(torch.tensor(0.0)) + + if self.cfg.normal_type == "pred": + self.normal_grid = nn.Parameter(torch.zeros(1, 3, *self.grid_size)) + + def get_density_bias(self, points: Float[Tensor, "*N Di"]): + if self.cfg.density_bias == "blob": + # density_bias: Float[Tensor, "*N 1"] = self.cfg.density_blob_scale * torch.exp(-0.5 * (points ** 2).sum(dim=-1) / self.cfg.density_blob_std ** 2)[...,None] + density_bias: Float[Tensor, "*N 1"] = ( + self.cfg.density_blob_scale + * ( + 1 + - torch.sqrt((points.detach() ** 2).sum(dim=-1)) + / self.cfg.density_blob_std + )[..., None] + ) + return density_bias + elif isinstance(self.cfg.density_bias, float): + return self.cfg.density_bias + else: + raise AttributeError(f"Unknown density bias {self.cfg.density_bias}") + + def get_trilinear_feature( + self, points: Float[Tensor, "*N Di"], grid: Float[Tensor, "1 Df G1 G2 G3"] + ) -> Float[Tensor, "*N Df"]: + points_shape = points.shape[:-1] + df = grid.shape[1] + di = points.shape[-1] + out = F.grid_sample( + grid, points.view(1, 1, 1, -1, di), align_corners=False, mode="bilinear" + ) + out = out.reshape(df, -1).T.reshape(*points_shape, df) + return out + + def forward( + self, points: Float[Tensor, "*N Di"], output_normal: bool = False + ) -> Dict[str, Float[Tensor, "..."]]: + points_unscaled = points # points in the original scale + points = contract_to_unisphere( + points, self.bbox, self.unbounded + ) # points normalized to (0, 1) + points = points * 2 - 1 # convert to [-1, 1] for grid sample + + out = self.get_trilinear_feature(points, self.grid) + density, features = out[..., 0:1], out[..., 1:] + density = density * torch.exp(self.density_scale) # exp scaling in DreamFusion + + # breakpoint() + density = get_activation(self.cfg.density_activation)( + density + self.get_density_bias(points_unscaled) + ) + + output = { + "density": density, + "features": features, + } + + if output_normal: + if ( + self.cfg.normal_type == "finite_difference" + or self.cfg.normal_type == "finite_difference_laplacian" + ): + eps = 1.0e-3 + if self.cfg.normal_type == "finite_difference_laplacian": + offsets: Float[Tensor, "6 3"] = torch.as_tensor( + [ + [eps, 0.0, 0.0], + [-eps, 0.0, 0.0], + [0.0, eps, 0.0], + [0.0, -eps, 0.0], + [0.0, 0.0, eps], + [0.0, 0.0, -eps], + ] + ).to(points_unscaled) + points_offset: Float[Tensor, "... 6 3"] = ( + points_unscaled[..., None, :] + offsets + ).clamp(-self.cfg.radius, self.cfg.radius) + density_offset: Float[Tensor, "... 6 1"] = self.forward_density( + points_offset + ) + normal = ( + -0.5 + * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0]) + / eps + ) + else: + offsets: Float[Tensor, "3 3"] = torch.as_tensor( + [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]] + ).to(points_unscaled) + points_offset: Float[Tensor, "... 3 3"] = ( + points_unscaled[..., None, :] + offsets + ).clamp(-self.cfg.radius, self.cfg.radius) + density_offset: Float[Tensor, "... 3 1"] = self.forward_density( + points_offset + ) + normal = -(density_offset[..., 0::1, 0] - density) / eps + normal = F.normalize(normal, dim=-1) + elif self.cfg.normal_type == "pred": + normal = self.get_trilinear_feature(points, self.normal_grid) + normal = F.normalize(normal, dim=-1) + else: + raise AttributeError(f"Unknown normal type {self.cfg.normal_type}") + output.update({"normal": normal, "shading_normal": normal}) + return output + + def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]: + points_unscaled = points + points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) + points = points * 2 - 1 # convert to [-1, 1] for grid sample + + out = self.get_trilinear_feature(points, self.grid) + density = out[..., 0:1] + density = density * torch.exp(self.density_scale) + + density = get_activation(self.cfg.density_activation)( + density + self.get_density_bias(points_unscaled) + ) + return density + + def forward_field( + self, points: Float[Tensor, "*N Di"] + ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: + if self.cfg.isosurface_deformable_grid: + threestudio.warn( + f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring." + ) + density = self.forward_density(points) + return density, None + + def forward_level( + self, field: Float[Tensor, "*N 1"], threshold: float + ) -> Float[Tensor, "*N 1"]: + return -(field - threshold) + + def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: + out: Dict[str, Any] = {} + if self.cfg.n_feature_dims == 0: + return out + points_unscaled = points + points = contract_to_unisphere(points, self.bbox, self.unbounded) + points = points * 2 - 1 # convert to [-1, 1] for grid sample + features = self.get_trilinear_feature(points, self.grid)[..., 1:] + out.update( + { + "features": features, + } + ) + return out \ No newline at end of file diff --git a/threestudio/models/guidance/__init__.py b/threestudio/models/guidance/__init__.py new file mode 100644 index 0000000..70e41ec --- /dev/null +++ b/threestudio/models/guidance/__init__.py @@ -0,0 +1,13 @@ +from . import ( + controlnet_guidance, + controlnet_reg_guidance, + deep_floyd_guidance, + stable_diffusion_guidance, + stable_diffusion_unified_guidance, + stable_diffusion_vsd_guidance, + stable_diffusion_bsd_guidance, + stable_zero123_guidance, + zero123_guidance, + zero123_unified_guidance, + clip_guidance, +) diff --git a/threestudio/models/guidance/clip_guidance.py b/threestudio/models/guidance/clip_guidance.py new file mode 100644 index 0000000..3d11f79 --- /dev/null +++ b/threestudio/models/guidance/clip_guidance.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass +import torch +import torch.nn.functional as F +import torchvision.transforms as T +import clip + +import threestudio +from threestudio.utils.base import BaseObject +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.typing import * + + +@threestudio.register("clip-guidance") +class CLIPGuidance(BaseObject): + @dataclass + class Config(BaseObject.Config): + cache_dir: Optional[str] = None + pretrained_model_name_or_path: str = "ViT-B/16" + view_dependent_prompting: bool = True + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading CLIP ...") + self.clip_model, self.clip_preprocess = clip.load( + self.cfg.pretrained_model_name_or_path, + device=self.device, + jit=False, + download_root=self.cfg.cache_dir + ) + + self.aug = T.Compose([ + T.Resize((224, 224)), + T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + threestudio.info(f"Loaded CLIP!") + + @torch.cuda.amp.autocast(enabled=False) + def get_embedding(self, input_value, is_text=True): + if is_text: + value = clip.tokenize(input_value).to(self.device) + z = self.clip_model.encode_text(value) + else: + input_value = self.aug(input_value) + z = self.clip_model.encode_image(input_value) + + return z / z.norm(dim=-1, keepdim=True) + + def get_loss(self, image_z, clip_z, loss_type='similarity_score', use_mean=True): + if loss_type == 'similarity_score': + loss = -((image_z * clip_z).sum(-1)) + elif loss_type == 'spherical_dist': + image_z, clip_z = F.normalize(image_z, dim=-1), F.normalize(clip_z, dim=-1) + loss = ((image_z - clip_z).norm(dim=-1).div(2).arcsin().pow(2).mul(2)) + else: + raise NotImplementedError + + return loss.mean() if use_mean else loss + + def __call__( + self, + pred_rgb: Float[Tensor, "B H W C"], + gt_rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + embedding_type: str = 'both', + loss_type: Optional[str] = 'similarity_score', + **kwargs, + ): + clip_text_loss, clip_img_loss = 0, 0 + + if embedding_type in ('both', 'text'): + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ).chunk(2)[0] + clip_text_loss = self.get_loss(self.get_embedding(pred_rgb, is_text=False), text_embeddings, loss_type=loss_type) + + if embedding_type in ('both', 'img'): + clip_img_loss = self.get_loss(self.get_embedding(pred_rgb, is_text=False), self.get_embedding(gt_rgb, is_text=False), loss_type=loss_type) + + return clip_text_loss + clip_img_loss diff --git a/threestudio/models/guidance/controlnet_guidance.py b/threestudio/models/guidance/controlnet_guidance.py new file mode 100644 index 0000000..11af9b7 --- /dev/null +++ b/threestudio/models/guidance/controlnet_guidance.py @@ -0,0 +1,517 @@ +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from controlnet_aux import CannyDetector, NormalBaeDetector +from diffusers import ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline +from diffusers.utils.import_utils import is_xformers_available +from tqdm import tqdm + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseObject +from threestudio.utils.misc import C, parse_version +from threestudio.utils.perceptual import PerceptualLoss +from threestudio.utils.typing import * + + +@threestudio.register("stable-diffusion-controlnet-guidance") +class ControlNetGuidance(BaseObject): + @dataclass + class Config(BaseObject.Config): + cache_dir: Optional[str] = None + pretrained_model_name_or_path: str = "SG161222/Realistic_Vision_V2.0" + ddim_scheduler_name_or_path: str = "runwayml/stable-diffusion-v1-5" + control_type: str = "normal" # normal/canny + + enable_memory_efficient_attention: bool = False + enable_sequential_cpu_offload: bool = False + enable_attention_slicing: bool = False + enable_channels_last_format: bool = False + guidance_scale: float = 7.5 + condition_scale: float = 1.5 + grad_clip: Optional[Any] = None + half_precision_weights: bool = True + + fixed_size: int = -1 + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + + diffusion_steps: int = 20 + + use_sds: bool = False + + use_du: bool = False + per_du_step: int = 10 + start_du_step: int = 1000 + cache_du: bool = False + + # Canny threshold + canny_lower_bound: int = 50 + canny_upper_bound: int = 100 + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading ControlNet ...") + + controlnet_name_or_path: str + if self.cfg.control_type in ("normal", "input_normal"): + controlnet_name_or_path = "lllyasviel/control_v11p_sd15_normalbae" + elif self.cfg.control_type == "canny": + controlnet_name_or_path = "lllyasviel/control_v11p_sd15_canny" + + self.weights_dtype = ( + torch.float16 if self.cfg.half_precision_weights else torch.float32 + ) + + pipe_kwargs = { + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + } + + controlnet = ControlNetModel.from_pretrained( + controlnet_name_or_path, + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + ) + self.pipe = StableDiffusionControlNetPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, controlnet=controlnet, **pipe_kwargs + ).to(self.device) + self.scheduler = DDIMScheduler.from_pretrained( + self.cfg.ddim_scheduler_name_or_path, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + ) + self.scheduler.set_timesteps(self.cfg.diffusion_steps) + + if self.cfg.enable_memory_efficient_attention: + if parse_version(torch.__version__) >= parse_version("2"): + threestudio.info( + "PyTorch2.0 uses memory efficient attention by default." + ) + elif not is_xformers_available(): + threestudio.warn( + "xformers is not available, memory efficient attention is not enabled." + ) + else: + self.pipe.enable_xformers_memory_efficient_attention() + + if self.cfg.enable_sequential_cpu_offload: + self.pipe.enable_sequential_cpu_offload() + + if self.cfg.enable_attention_slicing: + self.pipe.enable_attention_slicing(1) + + if self.cfg.enable_channels_last_format: + self.pipe.unet.to(memory_format=torch.channels_last) + + # Create model + self.vae = self.pipe.vae.eval() + self.unet = self.pipe.unet.eval() + self.controlnet = self.pipe.controlnet.eval() + + if self.cfg.control_type == "normal": + self.preprocessor = NormalBaeDetector.from_pretrained( + "lllyasviel/Annotators" + ) + self.preprocessor.model.to(self.device) + elif self.cfg.control_type == "canny": + self.preprocessor = CannyDetector() + + for p in self.vae.parameters(): + p.requires_grad_(False) + for p in self.unet.parameters(): + p.requires_grad_(False) + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.set_min_max_steps() # set to default value + + self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( + self.device + ) + + self.grad_clip_val: Optional[float] = None + + if self.cfg.use_du: + if self.cfg.cache_du: + self.edit_frames = {} + self.perceptual_loss = PerceptualLoss().eval().to(self.device) + + threestudio.info(f"Loaded ControlNet!") + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @torch.cuda.amp.autocast(enabled=False) + def forward_controlnet( + self, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + image_cond: Float[Tensor, "..."], + condition_scale: float, + encoder_hidden_states: Float[Tensor, "..."], + ) -> Float[Tensor, "..."]: + return self.controlnet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + controlnet_cond=image_cond.to(self.weights_dtype), + conditioning_scale=condition_scale, + return_dict=False, + ) + + @torch.cuda.amp.autocast(enabled=False) + def forward_control_unet( + self, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + cross_attention_kwargs, + down_block_additional_residuals, + mid_block_additional_residual, + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + return self.unet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + ).sample.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def encode_images( + self, imgs: Float[Tensor, "B 3 H W"] + ) -> Float[Tensor, "B 4 DH DW"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def encode_cond_images( + self, imgs: Float[Tensor, "B 3 H W"] + ) -> Float[Tensor, "B 4 DH DW"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist + latents = posterior.mode() + uncond_image_latents = torch.zeros_like(latents) + latents = torch.cat([latents, latents, uncond_image_latents], dim=0) + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def decode_latents( + self, latents: Float[Tensor, "B 4 DH DW"] + ) -> Float[Tensor, "B 3 H W"]: + input_dtype = latents.dtype + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents.to(self.weights_dtype)).sample + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + def edit_latents( + self, + text_embeddings: Float[Tensor, "BB 77 768"], + latents: Float[Tensor, "B 4 DH DW"], + image_cond: Float[Tensor, "B 3 H W"], + t: Int[Tensor, "B"], + mask = None + ) -> Float[Tensor, "B 4 DH DW"]: + self.scheduler.config.num_train_timesteps = t.item() + self.scheduler.set_timesteps(self.cfg.diffusion_steps) + if mask is not None: + mask = F.interpolate(mask, (latents.shape[-2], latents.shape[-1]), mode='bilinear') + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents = self.scheduler.add_noise(latents, noise, t) # type: ignore + + # sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py + threestudio.debug("Start editing...") + for i, t in enumerate(self.scheduler.timesteps): + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # pred noise + latent_model_input = torch.cat([latents] * 2) + ( + down_block_res_samples, + mid_block_res_sample, + ) = self.forward_controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + image_cond=image_cond, + condition_scale=self.cfg.condition_scale, + ) + + noise_pred = self.forward_control_unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # perform classifier-free guidance + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if mask is not None: + noise_pred = mask * noise_pred + (1 - mask) * noise + # get previous sample, continue loop + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + threestudio.debug("Editing finished.") + return latents + + def prepare_image_cond(self, cond_rgb: Float[Tensor, "B H W C"]): + if self.cfg.control_type == "normal": + cond_rgb = ( + (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy() + ) + detected_map = self.preprocessor(cond_rgb) + control = ( + torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0 + ) + control = control.unsqueeze(0) + control = control.permute(0, 3, 1, 2) + elif self.cfg.control_type == "canny": + cond_rgb = ( + (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy() + ) + blurred_img = cv2.blur(cond_rgb, ksize=(5, 5)) + detected_map = self.preprocessor( + blurred_img, self.cfg.canny_lower_bound, self.cfg.canny_upper_bound + ) + control = ( + torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0 + ) + # control = control.unsqueeze(-1).repeat(1, 1, 3) + control = control.unsqueeze(0) + control = control.permute(0, 3, 1, 2) + elif self.cfg.control_type == "input_normal": + cond_rgb[..., 0] = ( + 1 - cond_rgb[..., 0] + ) # Flip the sign on the x-axis to match bae system + control = cond_rgb.permute(0, 3, 1, 2) + else: + raise ValueError(f"Unknown control type: {self.cfg.control_type}") + + return control + + def compute_grad_sds( + self, + text_embeddings: Float[Tensor, "BB 77 768"], + latents: Float[Tensor, "B 4 DH DW"], + image_cond: Float[Tensor, "B 3 H W"], + t: Int[Tensor, "B"], + ): + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) # TODO: use torch generator + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + down_block_res_samples, mid_block_res_sample = self.forward_controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + image_cond=image_cond, + condition_scale=self.cfg.condition_scale, + ) + + noise_pred = self.forward_control_unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + # perform classifier-free guidance + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + grad = w * (noise_pred - noise) + return grad + + def compute_grad_du( + self, + latents: Float[Tensor, "B 4 H W"], + rgb_BCHW_HW8: Float[Tensor, "B 3 RH RW"], + cond_feature: Float[Tensor, "B 3 RH RW"], + cond_rgb: Float[Tensor, "B H W 3"], + text_embeddings: Float[Tensor, "BB 77 768"], + mask = None, + **kwargs, + ): + batch_size, _, RH, RW = cond_feature.shape + assert batch_size == 1 + + origin_gt_rgb = F.interpolate( + cond_rgb.permute(0, 3, 1, 2), (RH, RW), mode="bilinear" + ).permute(0, 2, 3, 1) + need_diffusion = ( + self.global_step % self.cfg.per_du_step == 0 + and self.global_step > self.cfg.start_du_step + ) + if self.cfg.cache_du: + if torch.is_tensor(kwargs["index"]): + batch_index = kwargs["index"].item() + else: + batch_index = kwargs["index"] + if ( + not (batch_index in self.edit_frames) + ) and self.global_step > self.cfg.start_du_step: + need_diffusion = True + need_loss = self.cfg.cache_du or need_diffusion + guidance_out = {} + + if need_diffusion: + t = torch.randint( + self.min_step, + self.max_step, + [1], + dtype=torch.long, + device=self.device, + ) + print("t:", t) + edit_latents = self.edit_latents(text_embeddings, latents, cond_feature, t, mask) + edit_images = self.decode_latents(edit_latents) + edit_images = F.interpolate( + edit_images, (RH, RW), mode="bilinear" + ).permute(0, 2, 3, 1) + self.edit_images = edit_images + if self.cfg.cache_du: + self.edit_frames[batch_index] = edit_images.detach().cpu() + + if need_loss: + if self.cfg.cache_du: + if batch_index in self.edit_frames: + gt_rgb = self.edit_frames[batch_index].to(cond_feature.device) + else: + gt_rgb = origin_gt_rgb + else: + gt_rgb = edit_images + + import cv2 + import numpy as np + + temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/test.jpg", temp[:, :, ::-1]) + + guidance_out.update( + { + "loss_l1": torch.nn.functional.l1_loss( + rgb_BCHW_HW8, gt_rgb.permute(0, 3, 1, 2), reduction="sum" + ), + "loss_p": self.perceptual_loss( + rgb_BCHW_HW8.contiguous(), + gt_rgb.permute(0, 3, 1, 2).contiguous(), + ).sum(), + } + ) + + return guidance_out + + def __call__( + self, + rgb: Float[Tensor, "B H W C"], + cond_rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + mask = None, + **kwargs, + ): + batch_size, H, W, _ = rgb.shape + assert batch_size == 1 + assert rgb.shape[:-1] == cond_rgb.shape[:-1] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + if mask is not None: mask = mask.permute(0, 3, 1, 2) + latents: Float[Tensor, "B 4 DH DW"] + if self.cfg.fixed_size > 0: + RH, RW = self.cfg.fixed_size, self.cfg.fixed_size + else: + RH, RW = H // 8 * 8, W // 8 * 8 + rgb_BCHW_HW8 = F.interpolate( + rgb_BCHW, (RH, RW), mode="bilinear", align_corners=False + ) + latents = self.encode_images(rgb_BCHW_HW8) + + image_cond = self.prepare_image_cond(cond_rgb) + image_cond = F.interpolate( + image_cond, (RH, RW), mode="bilinear", align_corners=False + ) + + temp = torch.zeros(1).to(rgb.device) + azimuth = kwargs.get("azimuth", temp) + camera_distance = kwargs.get("camera_distance", temp) + view_dependent_prompt = kwargs.get("view_dependent_prompt", False) + text_embeddings = prompt_utils.get_text_embeddings(temp, azimuth, camera_distance, view_dependent_prompt) # FIXME: change to view-conditioned prompt + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, + self.max_step + 1, + [batch_size], + dtype=torch.long, + device=self.device, + ) + + + guidance_out = {} + if self.cfg.use_sds: + grad = self.compute_grad_sds(text_embeddings, latents, image_cond, t) + grad = torch.nan_to_num(grad) + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + target = (latents - grad).detach() + loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + guidance_out.update( + { + "loss_sds": loss_sds, + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + } + ) + + if self.cfg.use_du: + grad = self.compute_grad_du( + latents, rgb_BCHW_HW8, image_cond, cond_rgb, text_embeddings, mask, **kwargs + ) + guidance_out.update(grad) + + return guidance_out + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) + + self.global_step = global_step \ No newline at end of file diff --git a/threestudio/models/guidance/controlnet_reg_guidance.py b/threestudio/models/guidance/controlnet_reg_guidance.py new file mode 100644 index 0000000..214a661 --- /dev/null +++ b/threestudio/models/guidance/controlnet_reg_guidance.py @@ -0,0 +1,454 @@ +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from controlnet_aux import CannyDetector, NormalBaeDetector +from diffusers import ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline, DPMSolverMultistepScheduler +from diffusers.utils.import_utils import is_xformers_available +from tqdm import tqdm + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseObject +from threestudio.utils.misc import C, parse_version +from threestudio.utils.typing import * + + +@threestudio.register("stable-diffusion-controlnet-reg-guidance") +class ControlNetGuidance(BaseObject): + @dataclass + class Config(BaseObject.Config): + cache_dir: Optional[str] = None + local_files_only: Optional[bool] = False + pretrained_model_name_or_path: str = "SG161222/Realistic_Vision_V2.0" + ddim_scheduler_name_or_path: str = "runwayml/stable-diffusion-v1-5" + control_type: str = "normal" # normal/canny + + enable_memory_efficient_attention: bool = False + enable_sequential_cpu_offload: bool = False + enable_attention_slicing: bool = False + enable_channels_last_format: bool = False + guidance_scale: float = 7.5 + condition_scale: float = 1.5 + grad_clip: Optional[Any] = None + half_precision_weights: bool = True + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + + diffusion_steps: int = 20 + + use_sds: bool = False + + # Canny threshold + canny_lower_bound: int = 50 + canny_upper_bound: int = 100 + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading ControlNet ...") + + self.weights_dtype = torch.float16 if self.cfg.half_precision_weights else torch.float32 + + self.preprocessor, controlnet_name_or_path = self.get_preprocessor_and_controlnet() + + pipe_kwargs = self.configure_pipeline() + + self.load_models(pipe_kwargs, controlnet_name_or_path) + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.scheduler.set_timesteps(self.cfg.diffusion_steps) + self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config) + self.scheduler = self.pipe.scheduler + + self.check_memory_efficiency_conditions() + + self.set_min_max_steps() + self.alphas = self.scheduler.alphas_cumprod.to(self.device) + self.grad_clip_val = None + + threestudio.info(f"Loaded ControlNet!") + + def get_preprocessor_and_controlnet(self): + if self.cfg.control_type in ("normal", "input_normal"): + if self.cfg.pretrained_model_name_or_path == "SG161222/Realistic_Vision_V2.0": + controlnet_name_or_path = "lllyasviel/control_v11p_sd15_normalbae" + else: + controlnet_name_or_path = "thibaud/controlnet-sd21-normalbae-diffusers" + preprocessor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators", cache_dir=self.cfg.cache_dir) + preprocessor.model.to(self.device) + elif self.cfg.control_type == "canny" or self.cfg.control_type == "canny2": + controlnet_name_or_path = self.get_canny_controlnet() + preprocessor = CannyDetector() + else: + raise ValueError(f"Unknown control type: {self.cfg.control_type}") + return preprocessor, controlnet_name_or_path + + def get_canny_controlnet(self): + if self.cfg.control_type == "canny": + return "lllyasviel/control_v11p_sd15_canny" + elif self.cfg.control_type == "canny2": + return "thepowefuldeez/sd21-controlnet-canny" + + def configure_pipeline(self): + return { + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only + } + + def load_models(self, pipe_kwargs, controlnet_name_or_path): + controlnet = ControlNetModel.from_pretrained( + controlnet_name_or_path, + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only + ) + self.pipe = StableDiffusionControlNetPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, controlnet=controlnet, **pipe_kwargs + ).to(self.device) + + self.scheduler = DDIMScheduler.from_pretrained( + self.cfg.ddim_scheduler_name_or_path, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only + ) + + self.vae = self.pipe.vae.eval() + self.unet = self.pipe.unet.eval() + self.controlnet = self.pipe.controlnet.eval() + + def check_memory_efficiency_conditions(self): + if self.cfg.enable_memory_efficient_attention: + self.memory_efficiency_status() + if self.cfg.enable_sequential_cpu_offload: + self.pipe.enable_sequential_cpu_offload() + if self.cfg.enable_attention_slicing: + self.pipe.enable_attention_slicing(1) + if self.cfg.enable_channels_last_format: + self.pipe.unet.to(memory_format=torch.channels_last) + + def memory_efficiency_status(self): + if parse_version(torch.__version__) >= parse_version("2"): + threestudio.info("PyTorch2.0 uses memory efficient attention by default.") + elif not is_xformers_available(): + threestudio.warn("xformers is not available, memory efficient attention is not enabled.") + else: + self.pipe.enable_xformers_memory_efficient_attention() + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @torch.cuda.amp.autocast(enabled=False) + def forward_controlnet( + self, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + image_cond: Float[Tensor, "..."], + condition_scale: float, + encoder_hidden_states: Float[Tensor, "..."], + ) -> Float[Tensor, "..."]: + return self.controlnet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + controlnet_cond=image_cond.to(self.weights_dtype), + conditioning_scale=condition_scale, + return_dict=False, + ) + + @torch.cuda.amp.autocast(enabled=False) + def forward_control_unet( + self, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + cross_attention_kwargs, + down_block_additional_residuals, + mid_block_additional_residual, + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + return self.unet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + ).sample.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def encode_images( + self, imgs: Float[Tensor, "B 3 512 512"] + ) -> Float[Tensor, "B 4 64 64"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def encode_cond_images( + self, imgs: Float[Tensor, "B 3 512 512"] + ) -> Float[Tensor, "B 4 64 64"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist + latents = posterior.mode() + uncond_image_latents = torch.zeros_like(latents) + latents = torch.cat([latents, latents, uncond_image_latents], dim=0) + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def decode_latents( + self, + latents: Float[Tensor, "B 4 H W"], + latent_height: int = 64, + latent_width: int = 64, + ) -> Float[Tensor, "B 3 512 512"]: + input_dtype = latents.dtype + latents = F.interpolate( + latents, (latent_height, latent_width), mode="bilinear", align_corners=False + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents.to(self.weights_dtype)).sample + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + def edit_latents( + self, + text_embeddings: Float[Tensor, "BB 77 768"], + latents: Float[Tensor, "B 4 64 64"], + image_cond: Float[Tensor, "B 3 512 512"], + t: Int[Tensor, "B"], + mask=None + ) -> Float[Tensor, "B 4 64 64"]: + batch_size = t.shape[0] + self.scheduler.set_timesteps(num_inference_steps=self.cfg.diffusion_steps) + init_timestep = max(1, min(int(self.cfg.diffusion_steps * t[0].item() / self.num_train_timesteps), self.cfg.diffusion_steps)) + t_start = max(self.cfg.diffusion_steps - init_timestep, 0) + latent_timestep = self.scheduler.timesteps[t_start : t_start + 1].repeat(batch_size) + B, _, DH, DW = latents.shape + origin_latents = latents.clone() + if mask is not None: + mask = F.interpolate(mask, (DH, DW), mode="bilinear", antialias=True) + + with torch.no_grad(): + # sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py + noise = torch.randn_like(latents) + latents = self.scheduler.add_noise(latents, noise, latent_timestep) # type: ignore + threestudio.debug("Start editing...") + for i, step in enumerate(range(t_start, self.cfg.diffusion_steps)): + timestep = self.scheduler.timesteps[step] + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # pred noise + latent_model_input = torch.cat([latents] * 2) + ( + down_block_res_samples, + mid_block_res_sample, + ) = self.forward_controlnet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + image_cond=image_cond, + condition_scale=self.cfg.condition_scale, + ) + + noise_pred = self.forward_control_unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # perform classifier-free guidance + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + if mask is not None: + noise_pred = noise_pred * mask + (1-mask) * noise + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + threestudio.debug("Editing finished.") + return latents + + def prepare_image_cond(self, cond_rgb: Float[Tensor, "B H W C"]): + if self.cfg.control_type == "normal": + cond_rgb = ( + (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy() + ) + detected_map = self.preprocessor(cond_rgb) + control = ( + torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0 + ) + control = control.unsqueeze(0) + control = control.permute(0, 3, 1, 2) + elif self.cfg.control_type == "canny" or self.cfg.control_type == "canny2": + cond_rgb = ( + (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy() + ) + blurred_img = cv2.blur(cond_rgb, ksize=(5, 5)) + detected_map = self.preprocessor( + blurred_img, self.cfg.canny_lower_bound, self.cfg.canny_upper_bound + ) + control = ( + torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0 + ) + control = control.unsqueeze(-1).repeat(1, 1, 3) + control = control.unsqueeze(0) + control = control.permute(0, 3, 1, 2) + elif self.cfg.control_type == "input_normal": + cond_rgb[..., 0] = ( + 1 - cond_rgb[..., 0] + ) # Flip the sign on the x-axis to match bae system + control = cond_rgb.permute(0, 3, 1, 2) + else: + raise ValueError(f"Unknown control type: {self.cfg.control_type}") + + return F.interpolate(control, (512, 512), mode="bilinear", align_corners=False) + + def compute_grad_sds( + self, + text_embeddings: Float[Tensor, "BB 77 768"], + latents: Float[Tensor, "B 4 64 64"], + image_cond: Float[Tensor, "B 3 512 512"], + t: Int[Tensor, "B"], + ): + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) # TODO: use torch generator + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + down_block_res_samples, mid_block_res_sample = self.forward_controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + image_cond=image_cond, + condition_scale=self.cfg.condition_scale, + ) + + noise_pred = self.forward_control_unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + # perform classifier-free guidance + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + grad = w * (noise_pred - noise) + return grad + + def __call__( + self, + rgb: Float[Tensor, "B H W C"], + cond_rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + mask: Float[Tensor, "B H W C"], + **kwargs, + ): + + batch_size, H, W, _ = rgb.shape + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents: Float[Tensor, "B 4 64 64"] + rgb_BCHW_512 = F.interpolate( + rgb_BCHW, (512, 512), mode="bilinear", align_corners=False + ) + latents = self.encode_images(rgb_BCHW_512) + + image_cond = self.prepare_image_cond(cond_rgb) + + temp = torch.zeros(1).to(rgb.device) + text_embeddings = prompt_utils.get_text_embeddings(temp, temp, temp, False) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, + self.max_step + 1, + [batch_size], + dtype=torch.long, + device=self.device, + ) + + if self.cfg.use_sds: + grad = self.compute_grad_sds(text_embeddings, latents, image_cond, t) + grad = torch.nan_to_num(grad) + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + target = (latents - grad).detach() + loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + return { + "loss_sds": loss_sds, + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + } + else: + + if mask is not None: mask = mask.permute(0, 3, 1, 2) + edit_latents = self.edit_latents(text_embeddings, latents, image_cond, t, mask) + edit_images = self.decode_latents(edit_latents) + edit_images = F.interpolate(edit_images, (H, W), mode="bilinear") + + return {"edit_images": edit_images.permute(0, 2, 3, 1), + "edit_latents": edit_latents} + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) + + +if __name__ == "__main__": + from threestudio.utils.config import ExperimentConfig, load_config + from threestudio.utils.typing import Optional + + cfg = load_config("configs/experimental/controlnet-normal.yaml") + guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance) + prompt_processor = threestudio.find(cfg.system.prompt_processor_type)( + cfg.system.prompt_processor + ) + + rgb_image = cv2.imread("assets/face.jpg")[:, :, ::-1].copy() / 255 + rgb_image = cv2.resize(rgb_image, (512, 512)) + rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device) + prompt_utils = prompt_processor() + guidance_out = guidance(rgb_image, rgb_image, prompt_utils) + edit_image = ( + (guidance_out["edit_images"][0].detach().cpu().clip(0, 1).numpy() * 255) + .astype(np.uint8)[:, :, ::-1] + .copy() + ) + os.makedirs(".threestudio_cache", exist_ok=True) + cv2.imwrite(".threestudio_cache/edit_image.jpg", edit_image) \ No newline at end of file diff --git a/threestudio/models/guidance/deep_floyd_guidance.py b/threestudio/models/guidance/deep_floyd_guidance.py new file mode 100644 index 0000000..94a7a38 --- /dev/null +++ b/threestudio/models/guidance/deep_floyd_guidance.py @@ -0,0 +1,582 @@ +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import IFPipeline, DDPMScheduler +from diffusers.utils.import_utils import is_xformers_available +from tqdm import tqdm + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseObject +from threestudio.utils.misc import C, parse_version +from threestudio.utils.ops import perpendicular_component +from threestudio.utils.typing import * + + +@threestudio.register("deep-floyd-guidance") +class DeepFloydGuidance(BaseObject): + @dataclass + class Config(BaseObject.Config): + cache_dir: Optional[str] = None + local_files_only: Optional[bool] = False + pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0" + # FIXME: xformers error + enable_memory_efficient_attention: bool = False + enable_sequential_cpu_offload: bool = False + enable_attention_slicing: bool = False + enable_channels_last_format: bool = True + guidance_scale: float = 20.0 + grad_clip: Optional[ + Any + ] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000]) + time_prior: Optional[Any] = None # [w1,w2,s1,s2] + half_precision_weights: bool = True + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + + weighting_strategy: str = "sds" + + view_dependent_prompting: bool = True + + """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items.""" + max_items_eval: int = 4 + + lora_weights_path: Optional[str] = None + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading Deep Floyd ...") + + self.weights_dtype = ( + torch.float16 if self.cfg.half_precision_weights else torch.float32 + ) + + # Create model + self.pipe = IFPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + text_encoder=None, + safety_checker=None, + watermarker=None, + feature_extractor=None, + requires_safety_checker=False, + variant="fp16" if self.cfg.half_precision_weights else None, + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only + ).to(self.device) + + # Load lora weights + if self.cfg.lora_weights_path is not None: + self.pipe.load_lora_weights(self.cfg.lora_weights_path) + self.pipe.scheduler = self.pipe.scheduler.__class__.from_config(self.pipe.scheduler.config, variance_type="fixed_small") + + if self.cfg.enable_memory_efficient_attention: + if parse_version(torch.__version__) >= parse_version("2"): + threestudio.info( + "PyTorch2.0 uses memory efficient attention by default." + ) + elif not is_xformers_available(): + threestudio.warn( + "xformers is not available, memory efficient attention is not enabled." + ) + else: + threestudio.warn( + f"Use DeepFloyd with xformers may raise error, see https://github.com/deep-floyd/IF/issues/52 to track this problem." + ) + self.pipe.enable_xformers_memory_efficient_attention() + + if self.cfg.enable_sequential_cpu_offload: + self.pipe.enable_sequential_cpu_offload() + + if self.cfg.enable_attention_slicing: + self.pipe.enable_attention_slicing(1) + + if self.cfg.enable_channels_last_format: + self.pipe.unet.to(memory_format=torch.channels_last) + + self.unet = self.pipe.unet.eval() + + for p in self.unet.parameters(): + p.requires_grad_(False) + + self.scheduler = self.pipe.scheduler + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.set_min_max_steps() # set to default value + if self.cfg.time_prior is not None: + m1, m2, s1, s2 = self.cfg.time_prior + weights = torch.cat( + ( + torch.exp( + -((torch.arange(self.num_train_timesteps, m1, -1) - m1) ** 2) + / (2 * s1**2) + ), + torch.ones(m1 - m2 + 1), + torch.exp( + -((torch.arange(m2 - 1, 0, -1) - m2) ** 2) / (2 * s2**2) + ), + ) + ) + weights = weights / torch.sum(weights) + self.time_prior_acc_weights = torch.cumsum(weights, dim=0) + + self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( + self.device + ) + + self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to( + self.device + ) + + self.grad_clip_val: Optional[float] = None + + threestudio.info(f"Loaded Deep Floyd!") + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @torch.cuda.amp.autocast(enabled=False) + def forward_unet( + self, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + return self.unet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + ).sample.to(input_dtype) + + def __call__( + self, + rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + current_step_ratio=None, + mask: Float[Tensor, "B H W 1"] = None, + rgb_as_latents=False, + guidance_eval=False, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + + if mask is not None: + mask = mask.permute(0, 3, 1, 2) + mask = F.interpolate( + mask, (64, 64), mode="bilinear", align_corners=False + ) + + assert rgb_as_latents == False, f"No latent space in {self.__class__.__name__}" + rgb_BCHW = rgb_BCHW * 2.0 - 1.0 # scale to [-1, 1] to match the diffusion range + latents = F.interpolate( + rgb_BCHW, (64, 64), mode="bilinear", align_corners=False + ) + + if self.cfg.time_prior is not None: + time_index = torch.where( + (self.time_prior_acc_weights - current_step_ratio) > 0 + )[0][0] + if time_index == 0 or torch.abs( + self.time_prior_acc_weights[time_index] - current_step_ratio + ) < torch.abs( + self.time_prior_acc_weights[time_index - 1] - current_step_ratio + ): + t = self.num_train_timesteps - time_index + else: + t = self.num_train_timesteps - time_index + 1 + t = torch.clip(t, self.min_step, self.max_step + 1) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + + else: + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, + self.max_step + 1, + [batch_size], + dtype=torch.long, + device=self.device, + ) + + if prompt_utils.use_perp_neg: + ( + text_embeddings, + neg_guidance_weights, + ) = prompt_utils.get_text_embeddings_perp_neg( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + if mask is not None: + latents_noisy = (1 - mask) * latents + mask * latents_noisy + latent_model_input = torch.cat([latents_noisy] * 4, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t] * 4), + encoder_hidden_states=text_embeddings, + ) # (4B, 6, 64, 64) + + noise_pred_text, _ = noise_pred[:batch_size].split(3, dim=1) + noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split( + 3, dim=1 + ) + noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1) + + e_pos = noise_pred_text - noise_pred_uncond + accum_grad = 0 + n_negative_prompts = neg_guidance_weights.shape[-1] + for i in range(n_negative_prompts): + e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond + accum_grad += neg_guidance_weights[:, i].view( + -1, 1, 1, 1 + ) * perpendicular_component(e_i_neg, e_pos) + + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + e_pos + accum_grad + ) + else: + neg_guidance_weights = None + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) # TODO: use torch generator + latents_noisy = self.scheduler.add_noise(latents, noise, t) + if mask is not None: + latents_noisy = (1 - mask) * latents + mask * latents_noisy + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=text_embeddings, + ) # (2B, 6, 64, 64) + + # perform guidance (high scale from paper!) + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1) + noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1) + noise_pred = noise_pred_text + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + """ + # thresholding, experimental + if self.cfg.thresholding: + assert batch_size == 1 + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + noise_pred = custom_ddpm_step(self.scheduler, + noise_pred, int(t.item()), latents_noisy, **self.pipe.prepare_extra_step_kwargs(None, 0.0) + ) + """ + + if self.cfg.weighting_strategy == "sds": + # w(t), sigma_t^2 + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + elif self.cfg.weighting_strategy == "uniform": + w = 1 + elif self.cfg.weighting_strategy == "fantasia3d": + w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1) + else: + raise ValueError( + f"Unknown weighting strategy: {self.cfg.weighting_strategy}" + ) + + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + # clip grad for stable training? + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # loss = SpecifyGradient.apply(latents, grad) + # SpecifyGradient is not straghtforward, use a reparameterization trick instead + target = (latents - grad).detach() + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + guidance_out = { + "loss_sd": loss_sd, + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + } + + # # FIXME: Visualize inpainting results + # self.scheduler.set_timesteps(20) + # latents = latents_noisy + # for t in tqdm(self.scheduler.timesteps): + # # pred noise + # noise_pred = self.get_noise_pred( + # latents, t, text_embeddings, prompt_utils.use_perp_neg, None + # ) + # # get prev latent + # prev_latents = latents + # latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"] + # if mask is not None: + # latents = (1 - mask) * prev_latents + mask * latents + + # denoised_img = (latents / 2 + 0.5).permute(0, 2, 3, 1) + # guidance_out.update( + # {"denoised_img": denoised_img} + # ) + + if guidance_eval: + guidance_eval_utils = { + "use_perp_neg": prompt_utils.use_perp_neg, + "neg_guidance_weights": neg_guidance_weights, + "text_embeddings": text_embeddings, + "t_orig": t, + "latents_noisy": latents_noisy, + "noise_pred": torch.cat([noise_pred, predicted_variance], dim=1), + } + guidance_eval_out = self.guidance_eval(**guidance_eval_utils) + texts = [] + for n, e, a, c in zip( + guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances + ): + texts.append( + f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}" + ) + guidance_eval_out.update({"texts": texts}) + guidance_out.update({"eval": guidance_eval_out}) + + return guidance_out + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def get_noise_pred( + self, + latents_noisy, + t, + text_embeddings, + use_perp_neg=False, + neg_guidance_weights=None, + ): + batch_size = latents_noisy.shape[0] + if use_perp_neg: + latent_model_input = torch.cat([latents_noisy] * 4, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t.reshape(1)] * 4).to(self.device), + encoder_hidden_states=text_embeddings, + ) # (4B, 6, 64, 64) + + noise_pred_text, _ = noise_pred[:batch_size].split(3, dim=1) + noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split( + 3, dim=1 + ) + noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1) + + e_pos = noise_pred_text - noise_pred_uncond + accum_grad = 0 + n_negative_prompts = neg_guidance_weights.shape[-1] + for i in range(n_negative_prompts): + e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond + accum_grad += neg_guidance_weights[:, i].view( + -1, 1, 1, 1 + ) * perpendicular_component(e_i_neg, e_pos) + + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + e_pos + accum_grad + ) + else: + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t.reshape(1)] * 2).to(self.device), + encoder_hidden_states=text_embeddings, + ) # (2B, 6, 64, 64) + + # perform guidance (high scale from paper!) + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1) + noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1) + noise_pred = noise_pred_text + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return torch.cat([noise_pred, predicted_variance], dim=1) + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def guidance_eval( + self, + t_orig, + text_embeddings, + latents_noisy, + noise_pred, + use_perp_neg=False, + neg_guidance_weights=None, + ): + # use only 50 timesteps, and find nearest of those to t + self.scheduler.set_timesteps(50) + self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device) + bs = ( + min(self.cfg.max_items_eval, latents_noisy.shape[0]) + if self.cfg.max_items_eval > 0 + else latents_noisy.shape[0] + ) # batch size + large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[ + :bs + ].unsqueeze( + -1 + ) # sized [bs,50] > [bs,1] + idxs = torch.min(large_enough_idxs, dim=1)[1] + t = self.scheduler.timesteps_gpu[idxs] + + fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy()) + imgs_noisy = (latents_noisy[:bs] / 2 + 0.5).permute(0, 2, 3, 1) + + # get prev latent + latents_1step = [] + pred_1orig = [] + for b in range(bs): + step_output = self.scheduler.step( + noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1] + ) + latents_1step.append(step_output["prev_sample"]) + pred_1orig.append(step_output["pred_original_sample"]) + latents_1step = torch.cat(latents_1step) + pred_1orig = torch.cat(pred_1orig) + imgs_1step = (latents_1step / 2 + 0.5).permute(0, 2, 3, 1) + imgs_1orig = (pred_1orig / 2 + 0.5).permute(0, 2, 3, 1) + + latents_final = [] + for b, i in enumerate(idxs): + latents = latents_1step[b : b + 1] + text_emb = ( + text_embeddings[ + [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ... + ] + if use_perp_neg + else text_embeddings[[b, b + len(idxs)], ...] + ) + neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None + for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False): + # pred noise + noise_pred = self.get_noise_pred( + latents, t, text_emb, use_perp_neg, neg_guid + ) + # get prev latent + latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"] + latents_final.append(latents) + + latents_final = torch.cat(latents_final) + imgs_final = (latents_final / 2 + 0.5).permute(0, 2, 3, 1) + + return { + "bs": bs, + "noise_levels": fracs, + "imgs_noisy": imgs_noisy, + "imgs_1step": imgs_1step, + "imgs_1orig": imgs_1orig, + "imgs_final": imgs_final, + } + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) + + +""" +# used by thresholding, experimental +def custom_ddpm_step(ddpm, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator=None, return_dict: bool = True): + self = ddpm + t = timestep + + prev_t = self.previous_timestep(t) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t].item() + alpha_prod_t_prev = self.alphas_cumprod[prev_t].item() if prev_t >= 0 else 1.0 + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + noise_thresholded = (sample - (alpha_prod_t ** 0.5) * pred_original_sample) / (beta_prod_t ** 0.5) + return noise_thresholded +""" + + +if __name__ == '__main__': + from threestudio.utils.config import load_config + import pytorch_lightning as pl + import numpy as np + import os + import cv2 + cfg = load_config("configs/debugging/deepfloyd.yaml") + guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance) + prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(cfg.system.prompt_processor) + prompt_utils = prompt_processor() + temp = torch.zeros(1).to(guidance.device) + # rgb_image = guidance.sample(prompt_utils, temp, temp, temp, seed=cfg.seed) + # rgb_image = (rgb_image[0].detach().cpu().clip(0, 1).numpy()*255).astype(np.uint8)[:, :, ::-1].copy() + # os.makedirs('.threestudio_cache', exist_ok=True) + # cv2.imwrite('.threestudio_cache/diffusion_image.jpg', rgb_image) + + ### inpaint + rgb_image = cv2.imread("assets/test.jpg")[:, :, ::-1].copy() / 255 + mask_image = cv2.imread("assets/mask.png")[:, :, :1].copy() / 255 + rgb_image = cv2.resize(rgb_image, (512, 512)) + mask_image = cv2.resize(mask_image, (512, 512)).reshape(512, 512, 1) + rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device) + mask_image = torch.FloatTensor(mask_image).unsqueeze(0).to(guidance.device) + + guidance_out = guidance(rgb_image, prompt_utils, temp, temp, temp, mask=mask_image) + edit_image = ( + (guidance_out["denoised_img"][0].detach().cpu().clip(0, 1).numpy() * 255) + .astype(np.uint8)[:, :, ::-1] + .copy() + ) + os.makedirs(".threestudio_cache", exist_ok=True) + cv2.imwrite(".threestudio_cache/edit_image.jpg", edit_image) \ No newline at end of file diff --git a/threestudio/models/guidance/stable_diffusion_bsd_guidance.py b/threestudio/models/guidance/stable_diffusion_bsd_guidance.py new file mode 100644 index 0000000..459aff6 --- /dev/null +++ b/threestudio/models/guidance/stable_diffusion_bsd_guidance.py @@ -0,0 +1,1134 @@ +import random +from contextlib import contextmanager +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.embeddings import TimestepEmbedding +from diffusers.utils.import_utils import is_xformers_available + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseModule +from threestudio.utils.misc import C, cleanup, parse_version +from threestudio.utils.perceptual import PerceptualLoss +from threestudio.utils.typing import * + + +class ToWeightsDType(nn.Module): + def __init__(self, module: nn.Module, dtype: torch.dtype): + super().__init__() + self.module = module + self.dtype = dtype + + def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]: + return self.module(x).to(self.dtype) + + +@threestudio.register("stable-diffusion-bsd-guidance") +class StableDiffusionBSDGuidance(BaseModule): + @dataclass + class Config(BaseModule.Config): + cache_dir: Optional[str] = None + local_files_only: Optional[bool] = False + pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base" + pretrained_model_name_or_path_lora: str = "stabilityai/stable-diffusion-2-1" + enable_memory_efficient_attention: bool = False + enable_sequential_cpu_offload: bool = False + enable_attention_slicing: bool = False + enable_channels_last_format: bool = False + guidance_scale: float = 7.5 + guidance_scale_lora: float = 1.0 + grad_clip: Optional[ + Any + ] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000]) + half_precision_weights: bool = True + lora_cfg_training: bool = True + lora_n_timestamp_samples: int = 1 + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + + view_dependent_prompting: bool = True + camera_condition_type: str = "extrinsics" + + use_du: bool = False + per_du_step: int = 10 + start_du_step: int = 0 + du_diffusion_steps: int = 20 + + lora_pretrain_cfg_training: bool = True + lora_pretrain_n_timestamp_samples: int = 1 + per_update_pretrain_step: int = 25 + only_pretrain_step: int = 1000 + + + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading Stable Diffusion ...") + + self.weights_dtype = ( + torch.float16 if self.cfg.half_precision_weights else torch.float32 + ) + + pipe_kwargs = { + "tokenizer": None, + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only + } + + pipe_lora_kwargs = { + "tokenizer": None, + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only + } + + @dataclass + class SubModules: + pipe: StableDiffusionPipeline + pipe_lora: StableDiffusionPipeline + pipe_fix: StableDiffusionPipeline + + pipe = StableDiffusionPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + **pipe_kwargs, + ).to(self.device) + self.single_model = False + pipe_lora = StableDiffusionPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path_lora, + **pipe_lora_kwargs, + ).to(self.device) + del pipe_lora.vae + cleanup() + pipe_lora.vae = pipe.vae + + pipe_fix = pipe + + self.submodules = SubModules(pipe=pipe, pipe_lora=pipe_lora, pipe_fix=pipe_fix) + + if self.cfg.enable_memory_efficient_attention: + if parse_version(torch.__version__) >= parse_version("2"): + threestudio.info( + "PyTorch2.0 uses memory efficient attention by default." + ) + elif not is_xformers_available(): + threestudio.warn( + "xformers is not available, memory efficient attention is not enabled." + ) + else: + self.pipe.enable_xformers_memory_efficient_attention() + self.pipe_lora.enable_xformers_memory_efficient_attention() + + if self.cfg.enable_sequential_cpu_offload: + self.pipe.enable_sequential_cpu_offload() + self.pipe_lora.enable_sequential_cpu_offload() + + if self.cfg.enable_attention_slicing: + self.pipe.enable_attention_slicing(1) + self.pipe_lora.enable_attention_slicing(1) + + if self.cfg.enable_channels_last_format: + self.pipe.unet.to(memory_format=torch.channels_last) + self.pipe_lora.unet.to(memory_format=torch.channels_last) + + del self.pipe.text_encoder + if not self.single_model: + del self.pipe_lora.text_encoder + cleanup() + + for p in self.vae.parameters(): + p.requires_grad_(False) + + for p in self.vae_fix.parameters(): + p.requires_grad_(False) + for p in self.unet_fix.parameters(): + p.requires_grad_(False) + + # FIXME: hard-coded dims + self.camera_embedding = ToWeightsDType( + TimestepEmbedding(16, 1280), self.weights_dtype + ).to(self.device) + # self.unet_lora.class_embedding = self.camera_embedding + + # set up LoRA layers + # self.set_up_lora_layers(self.unet_lora) + # self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to( + # self.device + # ) + # self.lora_layers._load_state_dict_pre_hooks.clear() + # self.lora_layers._state_dict_hooks.clear() + + # set up LoRA layers for pretrain + # self.set_up_lora_layers(self.unet) + # self.lora_layers_pretrain = AttnProcsLayers(self.unet.attn_processors).to( + # self.device + # ) + # self.lora_layers_pretrain._load_state_dict_pre_hooks.clear() + # self.lora_layers_pretrain._state_dict_hooks.clear() + + self.train_unet = UNet2DConditionModel.from_pretrained( + self.cfg.pretrained_model_name_or_path, subfolder="unet", + torch_dtype=self.weights_dtype + ) + self.train_unet.enable_xformers_memory_efficient_attention() + self.train_unet.enable_gradient_checkpointing() + + self.train_unet_lora = UNet2DConditionModel.from_pretrained( + self.cfg.pretrained_model_name_or_path_lora, subfolder="unet", + torch_dtype=self.weights_dtype + ) + self.train_unet_lora.enable_xformers_memory_efficient_attention() + self.train_unet_lora.enable_gradient_checkpointing() + + for p in self.train_unet.parameters(): + p.requires_grad_(True) + for p in self.train_unet_lora.parameters(): + p.requires_grad_(True) + # for p in self.lora_layers.parameters(): + # p.requires_grad_(False) + + self.scheduler = DDPMScheduler.from_pretrained( # DDPM + self.cfg.pretrained_model_name_or_path, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only, + ) + + self.scheduler_lora = DDPMScheduler.from_pretrained( + self.cfg.pretrained_model_name_or_path_lora, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only, + ) + + self.scheduler_sample = DPMSolverMultistepScheduler.from_config( + self.pipe.scheduler.config + ) + self.scheduler_lora_sample = DPMSolverMultistepScheduler.from_config( + self.pipe_lora.scheduler.config + ) + + self.pipe.scheduler = self.scheduler + self.pipe_lora.scheduler = self.scheduler_lora + + self.pipe_fix.scheduler = self.scheduler + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.set_min_max_steps() # set to default value + + self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( + self.device + ) + + self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device) + + self.grad_clip_val: Optional[float] = None + + if self.cfg.use_du: + self.perceptual_loss = PerceptualLoss().eval().to(self.device) + for p in self.perceptual_loss.parameters(): + p.requires_grad_(False) + + self.cache_frames = [] + + threestudio.info(f"Loaded Stable Diffusion!") + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @property + def pipe(self): + return self.submodules.pipe + + @property + def pipe_lora(self): + return self.submodules.pipe_lora + + @property + def unet(self): + return self.train_unet + + @property + def unet_lora(self): + return self.train_unet_lora + + @property + def vae(self): + return self.submodules.pipe.vae + + @property + def vae_lora(self): + return self.submodules.pipe_lora.vae + + @property + def pipe_fix(self): + return self.submodules.pipe_fix + + @property + def unet_fix(self): + return self.submodules.pipe_fix.unet + + @property + def vae_fix(self): + return self.submodules.pipe_fix.vae + + def set_up_lora_layers(self, unet): + # set up LoRA layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[ + block_id + ] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + unet.set_attn_processor(lora_attn_procs) + + return lora_attn_procs + + @torch.no_grad() + @torch.cuda.amp.autocast(enabled=False) + def _sample( + self, + pipe: StableDiffusionPipeline, + sample_scheduler: DPMSolverMultistepScheduler, + text_embeddings: Float[Tensor, "BB N Nf"], + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + class_labels: Optional[Float[Tensor, "BB 16"]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents_inp: Optional[Float[Tensor, "..."]] = None, + ) -> Float[Tensor, "B H W 3"]: + vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1) + height = height or pipe.unet.config.sample_size * vae_scale_factor + width = width or pipe.unet.config.sample_size * vae_scale_factor + batch_size = text_embeddings.shape[0] // 2 + device = self.device + + sample_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = sample_scheduler.timesteps + num_channels_latents = pipe.unet.config.in_channels + + if latents_inp is not None: + B = latents_inp.shape[0] + t = torch.randint( + self.max_step, + self.max_step+1, + [B], + dtype=torch.long, + device=self.device, + ) + noise = torch.randn_like(latents_inp) + # latents = sample_scheduler.add_noise(latents_inp, noise, t).to(self.weights_dtype) + + init_timestep = max(1, min(int(num_inference_steps * t[0].item() / self.num_train_timesteps), num_inference_steps)) + t_start = max(num_inference_steps - init_timestep, 0) + latent_timestep = sample_scheduler.timesteps[t_start : t_start + 1].repeat(batch_size) + latents = sample_scheduler.add_noise(latents_inp, noise, latent_timestep).to(self.weights_dtype) + + else: + latents = pipe.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + self.weights_dtype, + device, + generator, + ) + t_start = 0 + + for i, t in enumerate(timesteps[t_start:]): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = sample_scheduler.scale_model_input( + latent_model_input, t + ) + t_start = 0 + + # predict the noise residual + if class_labels is None: + with self.disable_unet_class_embedding(pipe.unet) as unet: + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings.to(self.weights_dtype), + cross_attention_kwargs=cross_attention_kwargs, + ).sample + else: + noise_pred = pipe.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings.to(self.weights_dtype), + class_labels=class_labels, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = sample_scheduler.step(noise_pred, t, latents).prev_sample + + latents = 1 / pipe.vae.config.scaling_factor * latents + images = pipe.vae.decode(latents).sample + images = (images / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + images = images.permute(0, 2, 3, 1).float() + + return images + + def sample( + self, + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + seed: int = 0, + **kwargs, + ) -> Float[Tensor, "N H W 3"]: + # view-dependent text embeddings + text_embeddings_vd = prompt_utils.get_text_embeddings( + elevation, + azimuth, + camera_distances, + view_dependent_prompting=self.cfg.view_dependent_prompting, + ) + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + generator = torch.Generator(device=self.device).manual_seed(seed) + + return self._sample( + pipe=self.pipe, + sample_scheduler=self.scheduler_sample, + text_embeddings=text_embeddings_vd, + num_inference_steps=25, + guidance_scale=self.cfg.guidance_scale, + cross_attention_kwargs=cross_attention_kwargs, + generator=generator, + ) + + def sample_img2img( + self, + rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + seed: int = 0, + mask = None, + **kwargs, + ) -> Float[Tensor, "N H W 3"]: + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + mask_BCHW = mask.permute(0, 3, 1, 2) + latents = self.get_latents(rgb_BCHW, rgb_as_latents=False) # TODO: 有部分概率是du或者ref image + + # view-dependent text embeddings + text_embeddings_vd = prompt_utils.get_text_embeddings( + elevation, + azimuth, + camera_distances, + view_dependent_prompting=self.cfg.view_dependent_prompting, + ) + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + generator = torch.Generator(device=self.device).manual_seed(seed) + + # return self._sample( + # pipe=self.pipe, + # sample_scheduler=self.scheduler_sample, + # text_embeddings=text_embeddings_vd, + # num_inference_steps=25, + # guidance_scale=self.cfg.guidance_scale, + # cross_attention_kwargs=cross_attention_kwargs, + # generator=generator, + # latents_inp=latents + # ) + + return self.compute_grad_du(latents, rgb_BCHW, text_embeddings_vd, mask=mask_BCHW) + + def sample_lora( + self, + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + mvp_mtx: Float[Tensor, "B 4 4"], + c2w: Float[Tensor, "B 4 4"], + seed: int = 0, + **kwargs, + ) -> Float[Tensor, "N H W 3"]: + # input text embeddings, view-independent + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ) + + if self.cfg.camera_condition_type == "extrinsics": + camera_condition = c2w + elif self.cfg.camera_condition_type == "mvp": + camera_condition = mvp_mtx + else: + raise ValueError( + f"Unknown camera_condition_type {self.cfg.camera_condition_type}" + ) + + B = elevation.shape[0] + camera_condition_cfg = torch.cat( + [ + camera_condition.view(B, -1), + torch.zeros_like(camera_condition.view(B, -1)), + ], + dim=0, + ) + + generator = torch.Generator(device=self.device).manual_seed(seed) + return self._sample( + sample_scheduler=self.scheduler_lora_sample, + pipe=self.pipe_lora, + text_embeddings=text_embeddings, + num_inference_steps=25, + guidance_scale=self.cfg.guidance_scale_lora, + class_labels=camera_condition_cfg, + cross_attention_kwargs={"scale": 1.0}, + generator=generator, + ) + + @torch.cuda.amp.autocast(enabled=False) + def forward_unet( + self, + unet: UNet2DConditionModel, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + class_labels: Optional[Float[Tensor, "B 16"]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + return unet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + class_labels=class_labels, + cross_attention_kwargs=cross_attention_kwargs, + ).sample.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def encode_images( + self, imgs: Float[Tensor, "B 3 512 512"] + ) -> Float[Tensor, "B 4 64 64"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def decode_latents( + self, + latents: Float[Tensor, "B 4 H W"], + latent_height: int = 64, + latent_width: int = 64, + ) -> Float[Tensor, "B 3 512 512"]: + input_dtype = latents.dtype + latents = F.interpolate( + latents, (latent_height, latent_width), mode="bilinear", align_corners=False + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents.to(self.weights_dtype)).sample + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + @contextmanager + def disable_unet_class_embedding(self, unet: UNet2DConditionModel): + class_embedding = unet.class_embedding + try: + unet.class_embedding = None + yield unet + finally: + unet.class_embedding = class_embedding + + def compute_grad_du( + self, + latents: Float[Tensor, "B 4 64 64"], + rgb_BCHW_512: Float[Tensor, "B 3 512 512"], + text_embeddings: Float[Tensor, "BB 77 768"], + mask = None, + **kwargs, + ): + batch_size, _, _, _ = latents.shape + rgb_BCHW_512 = F.interpolate(rgb_BCHW_512, (512, 512), mode="bilinear") + assert batch_size == 1 + need_diffusion = ( + self.global_step % self.cfg.per_du_step == 0 + and self.global_step > self.cfg.start_du_step + ) + guidance_out = {} + + if need_diffusion: + t = torch.randint( + self.min_step, + self.max_step, + [1], + dtype=torch.long, + device=self.device, + ) + self.scheduler.config.num_train_timesteps = t.item() + self.scheduler.set_timesteps(self.cfg.du_diffusion_steps) + + if mask is not None: + mask = F.interpolate(mask, (64, 64), mode="bilinear", antialias=True) + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents = self.scheduler.add_noise(latents, noise, t) # type: ignore + for i, timestep in enumerate(self.scheduler.timesteps): + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + latent_model_input = torch.cat([latents] * 2) + with self.disable_unet_class_embedding(self.unet) as unet: + cross_attention_kwargs = ( + {"scale": 0.0} if self.single_model else None + ) + noise_pred = self.forward_unet( + unet, + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ) + # perform classifier-free guidance + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if mask is not None: + noise_pred = mask * noise_pred + (1 - mask) * noise + # get previous sample, continue loop + latents = self.scheduler.step( + noise_pred, timestep, latents + ).prev_sample + edit_images = self.decode_latents(latents) + edit_images = F.interpolate( + edit_images, (512, 512), mode="bilinear" + ).permute(0, 2, 3, 1) + gt_rgb = edit_images + # import cv2 + # import numpy as np + # mask_temp = mask_BCHW_512.permute(0,2,3,1) + # # edit_images = edit_images * mask_temp + torch.rand(3)[None, None, None].to(self.device).repeat(*edit_images.shape[:-1],1) * (1 - mask_temp) + # temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8) + # cv2.imwrite(f".threestudio_cache/pig_sd_noise_500/test_{kwargs.get('name', 'none')}.jpg", temp[:, :, ::-1]) + + guidance_out.update( + { + "loss_l1": torch.nn.functional.l1_loss( + rgb_BCHW_512, gt_rgb.permute(0, 3, 1, 2), reduction="sum" + ), + "loss_p": self.perceptual_loss( + rgb_BCHW_512.contiguous(), + gt_rgb.permute(0, 3, 1, 2).contiguous(), + ).sum(), + "edit_image": edit_images.detach() + } + ) + + return guidance_out + + def compute_grad_vsd( + self, + latents: Float[Tensor, "B 4 64 64"], + text_embeddings_vd: Float[Tensor, "BB 77 768"], + text_embeddings: Float[Tensor, "BB 77 768"], + camera_condition: Float[Tensor, "B 4 4"], + ): + B, C, H, W = latents.shape + + with torch.no_grad(): + # random timestamp + t = torch.randint( + self.min_step, + self.max_step + 1, + [B], + dtype=torch.long, + device=self.device, + ) + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + cross_attention_kwargs = {"scale": 0.0} + noise_pred_pretrain = self.forward_unet( + self.train_unet, + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=text_embeddings_vd, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # use view-independent text embeddings in LoRA + text_embeddings_cond, _ = text_embeddings.chunk(2) + noise_pred_est = self.forward_unet( + self.train_unet_lora, + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=torch.cat([text_embeddings_cond] * 2), + # class_labels=torch.cat( + # [ + # camera_condition.view(B, -1), + # torch.zeros_like(camera_condition.view(B, -1)), + # ], + # dim=0, + # ), + cross_attention_kwargs={"scale": 0.0}, + ) + + + # TODO: more general cases + assert self.scheduler.config.prediction_type == "epsilon" + if self.scheduler_lora.config.prediction_type == "v_prediction": + alphas_cumprod = self.scheduler_lora.alphas_cumprod.to( + device=latents_noisy.device, dtype=latents_noisy.dtype + ) + alpha_t = alphas_cumprod[t] ** 0.5 + sigma_t = (1 - alphas_cumprod[t]) ** 0.5 + + noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).view( + -1, 1, 1, 1 + ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).view(-1, 1, 1, 1) + + ( + noise_pred_est_camera, + noise_pred_est_uncond, + ) = noise_pred_est.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * ( + noise_pred_est_camera - noise_pred_est_uncond + ) + + ( + noise_pred_pretrain_text, + noise_pred_pretrain_uncond, + ) = noise_pred_pretrain.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * ( + noise_pred_pretrain_text - noise_pred_pretrain_uncond + ) + + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + + grad = w * (noise_pred_pretrain - noise_pred_est) + return grad + + def compute_grad_vsd_hifa( + self, + latents: Float[Tensor, "B 4 64 64"], + text_embeddings_vd: Float[Tensor, "BB 77 768"], + text_embeddings: Float[Tensor, "BB 77 768"], + camera_condition: Float[Tensor, "B 4 4"], + mask=None, + ): + B, _, DH, DW = latents.shape + rgb = self.decode_latents(latents) + self.name = "hifa" + + if mask is not None: + mask = F.interpolate(mask, (DH, DW), mode="bilinear", antialias=True) + with torch.no_grad(): + # random timestamp + t = torch.randint( + self.min_step, + self.max_step + 1, + [B], + dtype=torch.long, + device=self.device, + ) + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler_sample.add_noise(latents, noise, t) + latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents, noise, t) + # pred noise + + self.scheduler_sample.config.num_train_timesteps = t.item() + self.scheduler_sample.set_timesteps(t.item() // 50 + 1) + self.scheduler_lora_sample.config.num_train_timesteps = t.item() + self.scheduler_lora_sample.set_timesteps(t.item() // 50 + 1) + + for i, timestep in enumerate(self.scheduler_sample.timesteps): + # for i, timestep in tqdm(enumerate(self.scheduler.timesteps)): + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + latent_model_input_lora = torch.cat([latents_noisy_lora] * 2, dim=0) + + # print(latent_model_input.shape) + with self.disable_unet_class_embedding(self.unet) as unet: + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + noise_pred_pretrain = self.forward_unet( + unet, + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings_vd, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # use view-independent text embeddings in LoRA + noise_pred_est = self.forward_unet( + self.unet_lora, + latent_model_input_lora, + timestep, + encoder_hidden_states=text_embeddings, + class_labels=torch.cat( + [ + camera_condition.view(B, -1), + torch.zeros_like(camera_condition.view(B, -1)), + ], + dim=0, + ), + cross_attention_kwargs={"scale": 1.0}, + ) + + ( + noise_pred_pretrain_text, + noise_pred_pretrain_uncond, + ) = noise_pred_pretrain.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * ( + noise_pred_pretrain_text - noise_pred_pretrain_uncond + ) + if mask is not None: + noise_pred_pretrain = mask * noise_pred_pretrain + (1 - mask) * noise + + ( + noise_pred_est_text, + noise_pred_est_uncond, + ) = noise_pred_est.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + # noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * ( + # noise_pred_est_text - noise_pred_est_uncond + # ) + noise_pred_est = noise_pred_est_text + if mask is not None: + noise_pred_est = mask * noise_pred_est + (1 - mask) * noise + + latents_noisy = self.scheduler_sample.step(noise_pred_pretrain, timestep, latents_noisy).prev_sample + latents_noisy_lora = self.scheduler_lora_sample.step(noise_pred_est, timestep, latents_noisy_lora).prev_sample + + # noise = torch.randn_like(latents) + # latents_noisy = self.scheduler.step(noise_pred_pretrain, timestep, latents_noisy).prev_sample + # latents_noisy = mask * latents_noisy + (1-mask) * latents + # latents_noisy = self.scheduler_sample.add_noise(latents_noisy, noise, timestep) + + # latents_noisy_lora = self.scheduler_lora.step(noise_pred_est, timestep, latents_noisy_lora).prev_sample + # latents_noisy_lora = mask * latents_noisy_lora + (1-mask) * latents + # latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents_noisy_lora, noise, timestep) + + hifa_images = self.decode_latents(latents_noisy) + hifa_lora_images = self.decode_latents(latents_noisy_lora) + + import cv2 + import numpy as np + if mask is not None: + print('hifa mask!') + prefix = 'vsd_mask' + else: + prefix = '' + temp = (hifa_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/%s%s_test.jpg" % (prefix, self.name), temp[:, :, ::-1]) + temp = (hifa_lora_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/%s%s_test_lora.jpg" % (prefix, self.name), temp[:, :, ::-1]) + + target = (latents_noisy - latents_noisy_lora + latents).detach() + # target = latents_noisy.detach() + targets_rgb = self.decode_latents(target) + # targets_rgb = (hifa_images - hifa_lora_images + rgb).detach() + temp = (targets_rgb.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/%s_target.jpg" % self.name, temp[:, :, ::-1]) + + return w * 0.5 * F.mse_loss(target, latents, reduction='sum') + + def train_lora( + self, + latents: Float[Tensor, "B 4 64 64"], + text_embeddings: Float[Tensor, "BB 77 768"], + camera_condition: Float[Tensor, "B 4 4"], + ): + B = latents.shape[0] + latents = latents.detach().repeat(self.cfg.lora_n_timestamp_samples, 1, 1, 1) + + t = torch.randint( + int(self.num_train_timesteps * 0.0), + int(self.num_train_timesteps * 1.0), + [B * self.cfg.lora_n_timestamp_samples], + dtype=torch.long, + device=self.device, + ) + + noise = torch.randn_like(latents) + noisy_latents = self.scheduler_lora.add_noise(latents, noise, t) + if self.scheduler_lora.config.prediction_type == "epsilon": + target = noise + elif self.scheduler_lora.config.prediction_type == "v_prediction": + target = self.scheduler_lora.get_velocity(latents, noise, t) + else: + raise ValueError( + f"Unknown prediction type {self.scheduler_lora.config.prediction_type}" + ) + # use view-independent text embeddings in LoRA + text_embeddings_cond, _ = text_embeddings.chunk(2) + if self.cfg.lora_cfg_training and random.random() < 0.1: + camera_condition = torch.zeros_like(camera_condition) + noise_pred = self.forward_unet( + self.train_unet_lora, + noisy_latents, + t, + encoder_hidden_states=text_embeddings_cond.repeat( + self.cfg.lora_n_timestamp_samples, 1, 1 + ), + # class_labels=camera_condition.view(B, -1).repeat( + # self.cfg.lora_n_timestamp_samples, 1 + # ), + cross_attention_kwargs={"scale": 0.0}, + ) + return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + def train_pretrain( + self, + latents: Float[Tensor, "B 4 64 64"], + text_embeddings: Float[Tensor, "BB 77 768"], + camera_condition: Float[Tensor, "B 4 4"], + sample_new_img=False, + ): + B = latents.shape[0] + if sample_new_img or len(self.cache_frames) == 0: + latents = latents.detach().repeat(self.cfg.lora_pretrain_n_timestamp_samples, 1, 1, 1) + images_sample = self._sample( + pipe=self.pipe_fix, + sample_scheduler=self.scheduler_sample, + text_embeddings=text_embeddings, + num_inference_steps=25, + guidance_scale=7.5, + cross_attention_kwargs = {"scale": 0.0}, + latents_inp=latents, + ).permute(0,3,1,2) + from torchvision.utils import save_image + save_image(images_sample, f".threestudio_cache/test_sample.jpg") + self.cache_frames.append(images_sample) + + self.pipe.unet = self.train_unet + pretrain_images_sample = self._sample( + pipe=self.pipe, + sample_scheduler=self.scheduler_sample, + text_embeddings=text_embeddings, + num_inference_steps=25, + guidance_scale=1.0, + cross_attention_kwargs = {"scale": 0.0}, + latents_inp=latents, + ).permute(0,3,1,2) + save_image(pretrain_images_sample, f".threestudio_cache/test_pretrain.jpg") + if len(self.cache_frames) > 10: + self.cache_frames.pop(0) + random_idx = torch.randint(0, len(self.cache_frames), [1]).item() + images_sample = self.cache_frames[random_idx] + + with torch.no_grad(): + latents_sample = self.get_latents(images_sample, rgb_as_latents=False) + + t = torch.randint( + int(self.num_train_timesteps * 0.0), + int(self.num_train_timesteps * 1.0), + [B * self.cfg.lora_pretrain_n_timestamp_samples], + dtype=torch.long, + device=self.device, + ) + + noise = torch.randn_like(latents) + noisy_latents = self.scheduler.add_noise(latents_sample, noise, t) + if self.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.scheduler.config.prediction_type == "v_prediction": + target = self.scheduler.get_velocity(latents_sample, noise, t) + else: + raise ValueError( + f"Unknown prediction type {self.scheduler.config.prediction_type}" + ) + # FIXME: use view-independent or dependent embeddings? + text_embeddings_cond, _ = text_embeddings.chunk(2) + if self.cfg.lora_pretrain_cfg_training and random.random() < 0.1: + text_embeddings_cond = torch.zeros_like(text_embeddings_cond) + noise_pred = self.forward_unet( + self.train_unet, + noisy_latents, + t, + encoder_hidden_states=text_embeddings_cond.repeat( + self.cfg.lora_pretrain_n_timestamp_samples, 1, 1 + ) + ) + loss_pretrain = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + return loss_pretrain + + def get_latents( + self, rgb_BCHW: Float[Tensor, "B C H W"], rgb_as_latents=False + ) -> Float[Tensor, "B 4 64 64"]: + if rgb_as_latents: + latents = F.interpolate( + rgb_BCHW, (64, 64), mode="bilinear", align_corners=False + ) + else: + rgb_BCHW_512 = F.interpolate( + rgb_BCHW, (512, 512), mode="bilinear", align_corners=False + ) + # encode image into latents with vae + latents = self.encode_images(rgb_BCHW_512) + return latents + + def forward( + self, + rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + mvp_mtx: Float[Tensor, "B 4 4"], + c2w: Float[Tensor, "B 4 4"], + rgb_as_latents=False, + mask: Float[Tensor, "B H W 1"] = None, + lora_prompt_utils = None, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents = self.get_latents(rgb_BCHW, rgb_as_latents=rgb_as_latents) + + if mask is not None: mask = mask.permute(0, 3, 1, 2) + + # view-dependent text embeddings + text_embeddings_vd = prompt_utils.get_text_embeddings( + elevation, + azimuth, + camera_distances, + view_dependent_prompting=self.cfg.view_dependent_prompting, + ) + if lora_prompt_utils is not None: + # input text embeddings, view-independent + text_embeddings = lora_prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ) + else: + # input text embeddings, view-independent + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ) + + if self.cfg.camera_condition_type == "extrinsics": + camera_condition = c2w + elif self.cfg.camera_condition_type == "mvp": + camera_condition = mvp_mtx + else: + raise ValueError( + f"Unknown camera_condition_type {self.cfg.camera_condition_type}" + ) + + do_update_pretrain = (self.cfg.only_pretrain_step > 0) and ( + (self.global_step % self.cfg.only_pretrain_step) < (self.cfg.only_pretrain_step // 5) + ) + + guidance_out = {} + if do_update_pretrain: + sample_new_img = self.global_step % self.cfg.per_update_pretrain_step == 0 + loss_pretrain = self.train_pretrain(latents, text_embeddings_vd, camera_condition, sample_new_img=sample_new_img) + guidance_out.update({ + "loss_pretrain": loss_pretrain, + "min_step": self.min_step, + "max_step": self.max_step, + }) + return guidance_out + + grad = self.compute_grad_vsd( + latents, text_embeddings_vd, text_embeddings, camera_condition + ) + + grad = torch.nan_to_num(grad) + # clip grad for stable training? + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # reparameterization trick + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + target = (latents - grad).detach() + loss_vsd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + loss_lora = self.train_lora(latents, text_embeddings, camera_condition) + + guidance_out.update({ + "loss_sd": loss_vsd, + "loss_lora": loss_lora, + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + }) + + if self.cfg.use_du: + du_out = self.compute_grad_du(latents, rgb_BCHW, text_embeddings_vd, mask=mask) + guidance_out.update(du_out) + + return guidance_out + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + self.global_step = global_step + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) \ No newline at end of file diff --git a/threestudio/models/guidance/stable_diffusion_guidance.py b/threestudio/models/guidance/stable_diffusion_guidance.py new file mode 100644 index 0000000..5821340 --- /dev/null +++ b/threestudio/models/guidance/stable_diffusion_guidance.py @@ -0,0 +1,632 @@ +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline +from diffusers.utils.import_utils import is_xformers_available +from tqdm import tqdm + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseObject +from threestudio.utils.misc import C, cleanup, parse_version +from threestudio.utils.ops import perpendicular_component +from threestudio.utils.typing import * + + +@threestudio.register("stable-diffusion-guidance") +class StableDiffusionGuidance(BaseObject): + @dataclass + class Config(BaseObject.Config): + cache_dir: Optional[str] = None + local_files_only: Optional[bool] = False + pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5" + enable_memory_efficient_attention: bool = False + enable_sequential_cpu_offload: bool = False + enable_attention_slicing: bool = False + enable_channels_last_format: bool = False + guidance_scale: float = 100.0 + grad_clip: Optional[ + Any + ] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000]) + time_prior: Optional[Any] = None # [w1,w2,s1,s2] + half_precision_weights: bool = True + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + max_step_percent_annealed: float = 0.5 + anneal_start_step: Optional[int] = None + + use_sjc: bool = False + var_red: bool = True + weighting_strategy: str = "sds" + + token_merging: bool = False + token_merging_params: Optional[dict] = field(default_factory=dict) + + view_dependent_prompting: bool = True + + """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items.""" + max_items_eval: int = 4 + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading Stable Diffusion ...") + + self.weights_dtype = ( + torch.float16 if self.cfg.half_precision_weights else torch.float32 + ) + + pipe_kwargs = { + "tokenizer": None, + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only + } + self.pipe = StableDiffusionPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + **pipe_kwargs, + ).to(self.device) + + if self.cfg.enable_memory_efficient_attention: + if parse_version(torch.__version__) >= parse_version("2"): + threestudio.info( + "PyTorch2.0 uses memory efficient attention by default." + ) + elif not is_xformers_available(): + threestudio.warn( + "xformers is not available, memory efficient attention is not enabled." + ) + else: + self.pipe.enable_xformers_memory_efficient_attention() + + if self.cfg.enable_sequential_cpu_offload: + self.pipe.enable_sequential_cpu_offload() + + if self.cfg.enable_attention_slicing: + self.pipe.enable_attention_slicing(1) + + if self.cfg.enable_channels_last_format: + self.pipe.unet.to(memory_format=torch.channels_last) + + del self.pipe.text_encoder + cleanup() + + # Create model + self.vae = self.pipe.vae.eval() + self.unet = self.pipe.unet.eval() + + for p in self.vae.parameters(): + p.requires_grad_(False) + for p in self.unet.parameters(): + p.requires_grad_(False) + + if self.cfg.token_merging: + import tomesd + + tomesd.apply_patch(self.unet, **self.cfg.token_merging_params) + + if self.cfg.use_sjc: + # score jacobian chaining use DDPM + self.scheduler = DDPMScheduler.from_pretrained( + self.cfg.pretrained_model_name_or_path, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + beta_start=0.00085, + beta_end=0.0120, + beta_schedule="scaled_linear", + cache_dir=self.cfg.cache_dir, + ) + else: + self.scheduler = DDIMScheduler.from_pretrained( + self.cfg.pretrained_model_name_or_path, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only, + ) + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.set_min_max_steps() # set to default value + if self.cfg.time_prior is not None: + m1, m2, s1, s2 = self.cfg.time_prior + weights = torch.cat( + ( + torch.exp( + -((torch.arange(self.num_train_timesteps, m1, -1) - m1) ** 2) + / (2 * s1**2) + ), + torch.ones(m1 - m2 + 1), + torch.exp( + -((torch.arange(m2 - 1, 0, -1) - m2) ** 2) / (2 * s2**2) + ), + ) + ) + weights = weights / torch.sum(weights) + self.time_prior_acc_weights = torch.cumsum(weights, dim=0) + + self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( + self.device + ) + if self.cfg.use_sjc: + # score jacobian chaining need mu + self.us: Float[Tensor, "..."] = torch.sqrt((1 - self.alphas) / self.alphas) + + self.grad_clip_val: Optional[float] = None + + threestudio.info(f"Loaded Stable Diffusion!") + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @torch.cuda.amp.autocast(enabled=False) + def forward_unet( + self, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + return self.unet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + ).sample.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def encode_images( + self, imgs: Float[Tensor, "B 3 512 512"] + ) -> Float[Tensor, "B 4 64 64"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def decode_latents( + self, + latents: Float[Tensor, "B 4 H W"], + latent_height: int = 64, + latent_width: int = 64, + ) -> Float[Tensor, "B 3 512 512"]: + input_dtype = latents.dtype + latents = F.interpolate( + latents, (latent_height, latent_width), mode="bilinear", align_corners=False + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents.to(self.weights_dtype)).sample + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + def compute_grad_sds( + self, + latents: Float[Tensor, "B 4 64 64"], + t: Int[Tensor, "B"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + ): + batch_size = elevation.shape[0] + + if prompt_utils.use_perp_neg: + ( + text_embeddings, + neg_guidance_weights, + ) = prompt_utils.get_text_embeddings_perp_neg( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + latent_model_input = torch.cat([latents_noisy] * 4, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t] * 4), + encoder_hidden_states=text_embeddings, + ) # (4B, 3, 64, 64) + + noise_pred_text = noise_pred[:batch_size] + noise_pred_uncond = noise_pred[batch_size : batch_size * 2] + noise_pred_neg = noise_pred[batch_size * 2 :] + + e_pos = noise_pred_text - noise_pred_uncond + accum_grad = 0 + n_negative_prompts = neg_guidance_weights.shape[-1] + for i in range(n_negative_prompts): + e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond + accum_grad += neg_guidance_weights[:, i].view( + -1, 1, 1, 1 + ) * perpendicular_component(e_i_neg, e_pos) + + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + e_pos + accum_grad + ) + else: + neg_guidance_weights = None + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) # TODO: use torch generator + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=text_embeddings, + ) + + # perform guidance (high scale from paper!) + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_text + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + if self.cfg.weighting_strategy == "sds": + # w(t), sigma_t^2 + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + elif self.cfg.weighting_strategy == "uniform": + w = 1 + elif self.cfg.weighting_strategy == "fantasia3d": + w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1) + else: + raise ValueError( + f"Unknown weighting strategy: {self.cfg.weighting_strategy}" + ) + + grad = w * (noise_pred - noise) + + guidance_eval_utils = { + "use_perp_neg": prompt_utils.use_perp_neg, + "neg_guidance_weights": neg_guidance_weights, + "text_embeddings": text_embeddings, + "t_orig": t, + "latents_noisy": latents_noisy, + "noise_pred": noise_pred, + } + + return grad, guidance_eval_utils + + def compute_grad_sjc( + self, + latents: Float[Tensor, "B 4 64 64"], + t: Int[Tensor, "B"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + ): + batch_size = elevation.shape[0] + + sigma = self.us[t] + sigma = sigma.view(-1, 1, 1, 1) + + if prompt_utils.use_perp_neg: + ( + text_embeddings, + neg_guidance_weights, + ) = prompt_utils.get_text_embeddings_perp_neg( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + with torch.no_grad(): + noise = torch.randn_like(latents) + y = latents + zs = y + sigma * noise + scaled_zs = zs / torch.sqrt(1 + sigma**2) + # pred noise + latent_model_input = torch.cat([scaled_zs] * 4, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t] * 4), + encoder_hidden_states=text_embeddings, + ) # (4B, 3, 64, 64) + + noise_pred_text = noise_pred[:batch_size] + noise_pred_uncond = noise_pred[batch_size : batch_size * 2] + noise_pred_neg = noise_pred[batch_size * 2 :] + + e_pos = noise_pred_text - noise_pred_uncond + accum_grad = 0 + n_negative_prompts = neg_guidance_weights.shape[-1] + for i in range(n_negative_prompts): + e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond + accum_grad += neg_guidance_weights[:, i].view( + -1, 1, 1, 1 + ) * perpendicular_component(e_i_neg, e_pos) + + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + e_pos + accum_grad + ) + else: + neg_guidance_weights = None + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) # TODO: use torch generator + y = latents + + zs = y + sigma * noise + scaled_zs = zs / torch.sqrt(1 + sigma**2) + + # pred noise + latent_model_input = torch.cat([scaled_zs] * 2, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=text_embeddings, + ) + + # perform guidance (high scale from paper!) + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_text + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + Ds = zs - sigma * noise_pred + + if self.cfg.var_red: + grad = -(Ds - y) / sigma + else: + grad = -(Ds - zs) / sigma + + guidance_eval_utils = { + "use_perp_neg": prompt_utils.use_perp_neg, + "neg_guidance_weights": neg_guidance_weights, + "text_embeddings": text_embeddings, + "t_orig": t, + "latents_noisy": scaled_zs, + "noise_pred": noise_pred, + } + + return grad, guidance_eval_utils + + def __call__( + self, + rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + rgb_as_latents=False, + guidance_eval=False, + current_step_ratio=None, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents: Float[Tensor, "B 4 64 64"] + if rgb_as_latents: + latents = F.interpolate( + rgb_BCHW, (64, 64), mode="bilinear", align_corners=False + ) + else: + rgb_BCHW_512 = F.interpolate( + rgb_BCHW, (512, 512), mode="bilinear", align_corners=False + ) + # encode image into latents with vae + latents = self.encode_images(rgb_BCHW_512) + + if self.cfg.time_prior is not None: + time_index = torch.where( + (self.time_prior_acc_weights - current_step_ratio) > 0 + )[0][0] + if time_index == 0 or torch.abs( + self.time_prior_acc_weights[time_index] - current_step_ratio + ) < torch.abs( + self.time_prior_acc_weights[time_index - 1] - current_step_ratio + ): + t = self.num_train_timesteps - time_index + else: + t = self.num_train_timesteps - time_index + 1 + t = torch.clip(t, self.min_step, self.max_step + 1) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + + else: + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, + self.max_step + 1, + [batch_size], + dtype=torch.long, + device=self.device, + ) + + if self.cfg.use_sjc: + grad, guidance_eval_utils = self.compute_grad_sjc( + latents, t, prompt_utils, elevation, azimuth, camera_distances + ) + else: + grad, guidance_eval_utils = self.compute_grad_sds( + latents, t, prompt_utils, elevation, azimuth, camera_distances + ) + + grad = torch.nan_to_num(grad) + # clip grad for stable training? + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # loss = SpecifyGradient.apply(latents, grad) + # SpecifyGradient is not straghtforward, use a reparameterization trick instead + target = (latents - grad).detach() + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + guidance_out = { + "loss_sd": loss_sds, + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + } + + if guidance_eval: + guidance_eval_out = self.guidance_eval(**guidance_eval_utils) + texts = [] + for n, e, a, c in zip( + guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances + ): + texts.append( + f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}" + ) + guidance_eval_out.update({"texts": texts}) + guidance_out.update({"eval": guidance_eval_out}) + + return guidance_out + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def get_noise_pred( + self, + latents_noisy, + t, + text_embeddings, + use_perp_neg=False, + neg_guidance_weights=None, + ): + batch_size = latents_noisy.shape[0] + + if use_perp_neg: + # pred noise + latent_model_input = torch.cat([latents_noisy] * 4, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t.reshape(1)] * 4).to(self.device), + encoder_hidden_states=text_embeddings, + ) # (4B, 3, 64, 64) + + noise_pred_text = noise_pred[:batch_size] + noise_pred_uncond = noise_pred[batch_size : batch_size * 2] + noise_pred_neg = noise_pred[batch_size * 2 :] + + e_pos = noise_pred_text - noise_pred_uncond + accum_grad = 0 + n_negative_prompts = neg_guidance_weights.shape[-1] + for i in range(n_negative_prompts): + e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond + accum_grad += neg_guidance_weights[:, i].view( + -1, 1, 1, 1 + ) * perpendicular_component(e_i_neg, e_pos) + + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + e_pos + accum_grad + ) + else: + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + noise_pred = self.forward_unet( + latent_model_input, + torch.cat([t.reshape(1)] * 2).to(self.device), + encoder_hidden_states=text_embeddings, + ) + # perform guidance (high scale from paper!) + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_text + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return noise_pred + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def guidance_eval( + self, + t_orig, + text_embeddings, + latents_noisy, + noise_pred, + use_perp_neg=False, + neg_guidance_weights=None, + ): + # use only 50 timesteps, and find nearest of those to t + self.scheduler.set_timesteps(50) + self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device) + bs = ( + min(self.cfg.max_items_eval, latents_noisy.shape[0]) + if self.cfg.max_items_eval > 0 + else latents_noisy.shape[0] + ) # batch size + large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[ + :bs + ].unsqueeze( + -1 + ) # sized [bs,50] > [bs,1] + idxs = torch.min(large_enough_idxs, dim=1)[1] + t = self.scheduler.timesteps_gpu[idxs] + + fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy()) + imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1) + + # get prev latent + latents_1step = [] + pred_1orig = [] + for b in range(bs): + step_output = self.scheduler.step( + noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1 + ) + latents_1step.append(step_output["prev_sample"]) + pred_1orig.append(step_output["pred_original_sample"]) + latents_1step = torch.cat(latents_1step) + pred_1orig = torch.cat(pred_1orig) + imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1) + imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1) + + latents_final = [] + for b, i in enumerate(idxs): + latents = latents_1step[b : b + 1] + text_emb = ( + text_embeddings[ + [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ... + ] + if use_perp_neg + else text_embeddings[[b, b + len(idxs)], ...] + ) + neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None + for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False): + # pred noise + noise_pred = self.get_noise_pred( + latents, t, text_emb, use_perp_neg, neg_guid + ) + # get prev latent + latents = self.scheduler.step(noise_pred, t, latents, eta=1)[ + "prev_sample" + ] + latents_final.append(latents) + + latents_final = torch.cat(latents_final) + imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1) + + return { + "bs": bs, + "noise_levels": fracs, + "imgs_noisy": imgs_noisy, + "imgs_1step": imgs_1step, + "imgs_1orig": imgs_1orig, + "imgs_final": imgs_final, + } + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) \ No newline at end of file diff --git a/threestudio/models/guidance/stable_diffusion_unified_guidance.py b/threestudio/models/guidance/stable_diffusion_unified_guidance.py new file mode 100644 index 0000000..c7fc4f0 --- /dev/null +++ b/threestudio/models/guidance/stable_diffusion_unified_guidance.py @@ -0,0 +1,729 @@ +import random +from contextlib import contextmanager +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDPMScheduler, + DPMSolverSinglestepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.embeddings import TimestepEmbedding +from diffusers.utils.import_utils import is_xformers_available +from tqdm import tqdm + +import threestudio +from threestudio.models.networks import ToDTypeWrapper +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseModule +from threestudio.utils.misc import C, cleanup, enable_gradient, parse_version +from threestudio.utils.ops import perpendicular_component +from threestudio.utils.typing import * + + +@threestudio.register("stable-diffusion-unified-guidance") +class StableDiffusionUnifiedGuidance(BaseModule): + @dataclass + class Config(BaseModule.Config): + cache_dir: Optional[str] = None + local_files_only: Optional[bool] = False + + # guidance type, in ["sds", "vsd"] + guidance_type: str = "sds" + + pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5" + guidance_scale: float = 100.0 + weighting_strategy: str = "dreamfusion" + view_dependent_prompting: bool = True + + min_step_percent: Any = 0.02 + max_step_percent: Any = 0.98 + grad_clip: Optional[Any] = None + + return_rgb_1step_orig: bool = False + return_rgb_multistep_orig: bool = False + n_rgb_multistep_orig_steps: int = 4 + + # TODO + # controlnet + controlnet_model_name_or_path: Optional[str] = None + preprocessor: Optional[str] = None + control_scale: float = 1.0 + + # TODO + # lora + lora_model_name_or_path: Optional[str] = None + + # efficiency-related configurations + half_precision_weights: bool = True + enable_memory_efficient_attention: bool = False + enable_sequential_cpu_offload: bool = False + enable_attention_slicing: bool = False + enable_channels_last_format: bool = False + token_merging: bool = False + token_merging_params: Optional[dict] = field(default_factory=dict) + + # VSD configurations, only used when guidance_type is "vsd" + vsd_phi_model_name_or_path: Optional[str] = None + vsd_guidance_scale_phi: float = 1.0 + vsd_use_lora: bool = True + vsd_lora_cfg_training: bool = False + vsd_lora_n_timestamp_samples: int = 1 + vsd_use_camera_condition: bool = True + # camera condition type, in ["extrinsics", "mvp", "spherical"] + vsd_camera_condition_type: Optional[str] = "extrinsics" + + cfg: Config + + def configure(self) -> None: + self.min_step: Optional[int] = None + self.max_step: Optional[int] = None + self.grad_clip_val: Optional[float] = None + + @dataclass + class NonTrainableModules: + pipe: StableDiffusionPipeline + pipe_phi: Optional[StableDiffusionPipeline] = None + controlnet: Optional[ControlNetModel] = None + + self.weights_dtype = ( + torch.float16 if self.cfg.half_precision_weights else torch.float32 + ) + + threestudio.info(f"Loading Stable Diffusion ...") + + pipe_kwargs = { + "tokenizer": None, + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only, + } + pipe = StableDiffusionPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + **pipe_kwargs, + ).to(self.device) + self.prepare_pipe(pipe) + self.configure_pipe_token_merging(pipe) + + # phi network for VSD + # introduce two trainable modules: + # - self.camera_embedding + # - self.lora_layers + pipe_phi = None + + # if the phi network shares the same unet with the pretrain network + # we need to pass additional cross attention kwargs to the unet + self.vsd_share_model = ( + self.cfg.guidance_type == "vsd" + and self.cfg.vsd_phi_model_name_or_path is None + ) + if self.cfg.guidance_type == "vsd": + if self.cfg.vsd_phi_model_name_or_path is None: + pipe_phi = pipe + else: + pipe_phi = StableDiffusionPipeline.from_pretrained( + self.cfg.vsd_phi_model_name_or_path, + **pipe_kwargs, + ).to(self.device) + self.prepare_pipe(pipe_phi) + self.configure_pipe_token_merging(pipe_phi) + + # set up camera embedding + if self.cfg.vsd_use_camera_condition: + if self.cfg.vsd_camera_condition_type in ["extrinsics", "mvp"]: + self.camera_embedding_dim = 16 + elif self.cfg.vsd_camera_condition_type == "spherical": + self.camera_embedding_dim = 4 + else: + raise ValueError("Invalid camera condition type!") + + # FIXME: hard-coded output dim + self.camera_embedding = ToDTypeWrapper( + TimestepEmbedding(self.camera_embedding_dim, 1280), + self.weights_dtype, + ).to(self.device) + pipe_phi.unet.class_embedding = self.camera_embedding + + if self.cfg.vsd_use_lora: + # set up LoRA layers + lora_attn_procs = {} + for name in pipe_phi.unet.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else pipe_phi.unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = pipe_phi.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list( + reversed(pipe_phi.unet.config.block_out_channels) + )[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = pipe_phi.unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + pipe_phi.unet.set_attn_processor(lora_attn_procs) + + self.lora_layers = AttnProcsLayers(pipe_phi.unet.attn_processors).to( + self.device + ) + self.lora_layers._load_state_dict_pre_hooks.clear() + self.lora_layers._state_dict_hooks.clear() + + threestudio.info(f"Loaded Stable Diffusion!") + + # controlnet + controlnet = None + if self.cfg.controlnet_model_name_or_path is not None: + threestudio.info(f"Loading ControlNet ...") + + controlnet = ControlNetModel.from_pretrained( + self.cfg.controlnet_model_name_or_path, + torch_dtype=self.weights_dtype, + ).to(self.device) + controlnet.eval() + enable_gradient(controlnet, enabled=False) + + threestudio.info(f"Loaded ControlNet!") + + self.scheduler = DDPMScheduler.from_config(pipe.scheduler.config) + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + + # q(z_t|x) = N(alpha_t x, sigma_t^2 I) + # in DDPM, alpha_t = sqrt(alphas_cumprod_t), sigma_t^2 = 1 - alphas_cumprod_t + self.alphas_cumprod: Float[Tensor, "T"] = self.scheduler.alphas_cumprod.to( + self.device + ) + self.alphas: Float[Tensor, "T"] = self.alphas_cumprod**0.5 + self.sigmas: Float[Tensor, "T"] = (1 - self.alphas_cumprod) ** 0.5 + # log SNR + self.lambdas: Float[Tensor, "T"] = self.sigmas / self.alphas + + self._non_trainable_modules = NonTrainableModules( + pipe=pipe, + pipe_phi=pipe_phi, + controlnet=controlnet, + ) + + @property + def pipe(self) -> StableDiffusionPipeline: + return self._non_trainable_modules.pipe + + @property + def pipe_phi(self) -> StableDiffusionPipeline: + if self._non_trainable_modules.pipe_phi is None: + raise RuntimeError("phi model is not available.") + return self._non_trainable_modules.pipe_phi + + @property + def controlnet(self) -> ControlNetModel: + if self._non_trainable_modules.controlnet is None: + raise RuntimeError("ControlNet model is not available.") + return self._non_trainable_modules.controlnet + + def prepare_pipe(self, pipe: StableDiffusionPipeline): + if self.cfg.enable_memory_efficient_attention: + if parse_version(torch.__version__) >= parse_version("2"): + threestudio.info( + "PyTorch2.0 uses memory efficient attention by default." + ) + elif not is_xformers_available(): + threestudio.warn( + "xformers is not available, memory efficient attention is not enabled." + ) + else: + pipe.enable_xformers_memory_efficient_attention() + + if self.cfg.enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload() + + if self.cfg.enable_attention_slicing: + pipe.enable_attention_slicing(1) + + if self.cfg.enable_channels_last_format: + pipe.unet.to(memory_format=torch.channels_last) + + # FIXME: pipe.__call__ requires text_encoder.dtype + # pipe.text_encoder.to("meta") + cleanup() + + pipe.vae.eval() + pipe.unet.eval() + + enable_gradient(pipe.vae, enabled=False) + enable_gradient(pipe.unet, enabled=False) + + # disable progress bar + pipe.set_progress_bar_config(disable=True) + + def configure_pipe_token_merging(self, pipe: StableDiffusionPipeline): + if self.cfg.token_merging: + import tomesd + + tomesd.apply_patch(pipe.unet, **self.cfg.token_merging_params) + + @torch.cuda.amp.autocast(enabled=False) + def forward_unet( + self, + unet: UNet2DConditionModel, + latents: Float[Tensor, "..."], + t: Int[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + class_labels: Optional[Float[Tensor, "..."]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Float[Tensor, "..."]] = None, + mid_block_additional_residual: Optional[Float[Tensor, "..."]] = None, + velocity_to_epsilon: bool = False, + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + pred = unet( + latents.to(unet.dtype), + t.to(unet.dtype), + encoder_hidden_states=encoder_hidden_states.to(unet.dtype), + class_labels=class_labels, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + ).sample + if velocity_to_epsilon: + pred = latents * self.sigmas[t].view(-1, 1, 1, 1) + pred * self.alphas[ + t + ].view(-1, 1, 1, 1) + return pred.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def vae_encode( + self, vae: AutoencoderKL, imgs: Float[Tensor, "B 3 H W"], mode=False + ) -> Float[Tensor, "B 4 Hl Wl"]: + # expect input in [-1, 1] + input_dtype = imgs.dtype + posterior = vae.encode(imgs.to(vae.dtype)).latent_dist + if mode: + latents = posterior.mode() + else: + latents = posterior.sample() + latents = latents * vae.config.scaling_factor + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def vae_decode( + self, vae: AutoencoderKL, latents: Float[Tensor, "B 4 Hl Wl"] + ) -> Float[Tensor, "B 3 H W"]: + # output in [0, 1] + input_dtype = latents.dtype + latents = 1 / vae.config.scaling_factor * latents + image = vae.decode(latents.to(vae.dtype)).sample + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + @contextmanager + def disable_unet_class_embedding(self, unet: UNet2DConditionModel): + class_embedding = unet.class_embedding + try: + unet.class_embedding = None + yield unet + finally: + unet.class_embedding = class_embedding + + @contextmanager + def set_scheduler( + self, pipe: StableDiffusionPipeline, scheduler_class: Any, **kwargs + ): + scheduler_orig = pipe.scheduler + pipe.scheduler = scheduler_class.from_config(scheduler_orig.config, **kwargs) + yield pipe + pipe.scheduler = scheduler_orig + + def get_eps_pretrain( + self, + latents_noisy: Float[Tensor, "B 4 Hl Wl"], + t: Int[Tensor, "B"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + ) -> Float[Tensor, "B 4 Hl Wl"]: + batch_size = latents_noisy.shape[0] + + if prompt_utils.use_perp_neg: + ( + text_embeddings, + neg_guidance_weights, + ) = prompt_utils.get_text_embeddings_perp_neg( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + with torch.no_grad(): + with self.disable_unet_class_embedding(self.pipe.unet) as unet: + noise_pred = self.forward_unet( + unet, + torch.cat([latents_noisy] * 4, dim=0), + torch.cat([t] * 4, dim=0), + encoder_hidden_states=text_embeddings, + cross_attention_kwargs={"scale": 0.0} + if self.vsd_share_model + else None, + velocity_to_epsilon=self.pipe.scheduler.config.prediction_type + == "v_prediction", + ) # (4B, 3, Hl, Wl) + + noise_pred_text = noise_pred[:batch_size] + noise_pred_uncond = noise_pred[batch_size : batch_size * 2] + noise_pred_neg = noise_pred[batch_size * 2 :] + + e_pos = noise_pred_text - noise_pred_uncond + accum_grad = 0 + n_negative_prompts = neg_guidance_weights.shape[-1] + for i in range(n_negative_prompts): + e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond + accum_grad += neg_guidance_weights[:, i].view( + -1, 1, 1, 1 + ) * perpendicular_component(e_i_neg, e_pos) + + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + e_pos + accum_grad + ) + else: + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting + ) + with torch.no_grad(): + with self.disable_unet_class_embedding(self.pipe.unet) as unet: + noise_pred = self.forward_unet( + unet, + torch.cat([latents_noisy] * 2, dim=0), + torch.cat([t] * 2, dim=0), + encoder_hidden_states=text_embeddings, + cross_attention_kwargs={"scale": 0.0} + if self.vsd_share_model + else None, + velocity_to_epsilon=self.pipe.scheduler.config.prediction_type + == "v_prediction", + ) + + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return noise_pred + + def get_eps_phi( + self, + latents_noisy: Float[Tensor, "B 4 Hl Wl"], + t: Int[Tensor, "B"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + camera_condition: Float[Tensor, "B ..."], + ) -> Float[Tensor, "B 4 Hl Wl"]: + batch_size = latents_noisy.shape[0] + + # not using view-dependent prompting in LoRA + text_embeddings, _ = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ).chunk(2) + with torch.no_grad(): + noise_pred = self.forward_unet( + self.pipe_phi.unet, + torch.cat([latents_noisy] * 2, dim=0), + torch.cat([t] * 2, dim=0), + encoder_hidden_states=torch.cat([text_embeddings] * 2, dim=0), + class_labels=torch.cat( + [ + camera_condition.view(batch_size, -1), + torch.zeros_like(camera_condition.view(batch_size, -1)), + ], + dim=0, + ) + if self.cfg.vsd_use_camera_condition + else None, + cross_attention_kwargs={"scale": 1.0}, + velocity_to_epsilon=self.pipe_phi.scheduler.config.prediction_type + == "v_prediction", + ) + + noise_pred_camera, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.vsd_guidance_scale_phi * ( + noise_pred_camera - noise_pred_uncond + ) + + return noise_pred + + def train_phi( + self, + latents: Float[Tensor, "B 4 Hl Wl"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + camera_condition: Float[Tensor, "B ..."], + ): + B = latents.shape[0] + latents = latents.detach().repeat( + self.cfg.vsd_lora_n_timestamp_samples, 1, 1, 1 + ) + + num_train_timesteps = self.pipe_phi.scheduler.config.num_train_timesteps + t = torch.randint( + int(num_train_timesteps * 0.0), + int(num_train_timesteps * 1.0), + [B * self.cfg.vsd_lora_n_timestamp_samples], + dtype=torch.long, + device=self.device, + ) + + noise = torch.randn_like(latents) + latents_noisy = self.pipe_phi.scheduler.add_noise(latents, noise, t) + if self.pipe_phi.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.pipe_phi.scheduler.prediction_type == "v_prediction": + target = self.pipe_phi.scheduler.get_velocity(latents, noise, t) + else: + raise ValueError( + f"Unknown prediction type {self.pipe_phi.scheduler.prediction_type}" + ) + + # not using view-dependent prompting in LoRA + text_embeddings, _ = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ).chunk(2) + + if ( + self.cfg.vsd_use_camera_condition + and self.cfg.vsd_lora_cfg_training + and random.random() < 0.1 + ): + camera_condition = torch.zeros_like(camera_condition) + + noise_pred = self.forward_unet( + self.pipe_phi.unet, + latents_noisy, + t, + encoder_hidden_states=text_embeddings.repeat( + self.cfg.vsd_lora_n_timestamp_samples, 1, 1 + ), + class_labels=camera_condition.view(B, -1).repeat( + self.cfg.vsd_lora_n_timestamp_samples, 1 + ) + if self.cfg.vsd_use_camera_condition + else None, + cross_attention_kwargs={"scale": 1.0}, + ) + return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + def forward( + self, + rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + mvp_mtx: Float[Tensor, "B 4 4"], + c2w: Float[Tensor, "B 4 4"], + rgb_as_latents=False, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents: Float[Tensor, "B 4 Hl Wl"] + if rgb_as_latents: + # treat input rgb as latents + # input rgb should be in range [-1, 1] + latents = F.interpolate( + rgb_BCHW, (64, 64), mode="bilinear", align_corners=False + ) + else: + # treat input rgb as rgb + # input rgb should be in range [0, 1] + rgb_BCHW = F.interpolate( + rgb_BCHW, (512, 512), mode="bilinear", align_corners=False + ) + # encode image into latents with vae + latents = self.vae_encode(self.pipe.vae, rgb_BCHW * 2.0 - 1.0) + + # sample timestep + # use the same timestep for each batch + assert self.min_step is not None and self.max_step is not None + t = torch.randint( + self.min_step, + self.max_step + 1, + [1], + dtype=torch.long, + device=self.device, + ).repeat(batch_size) + + # sample noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + + eps_pretrain = self.get_eps_pretrain( + latents_noisy, t, prompt_utils, elevation, azimuth, camera_distances + ) + + latents_1step_orig = ( + 1 + / self.alphas[t].view(-1, 1, 1, 1) + * (latents_noisy - self.sigmas[t].view(-1, 1, 1, 1) * eps_pretrain) + ).detach() + + if self.cfg.guidance_type == "sds": + eps_phi = noise + elif self.cfg.guidance_type == "vsd": + if self.cfg.vsd_camera_condition_type == "extrinsics": + camera_condition = c2w + elif self.cfg.vsd_camera_condition_type == "mvp": + camera_condition = mvp_mtx + elif self.cfg.vsd_camera_condition_type == "spherical": + camera_condition = torch.stack( + [ + torch.deg2rad(elevation), + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + camera_distances, + ], + dim=-1, + ) + else: + raise ValueError( + f"Unknown camera_condition_type {self.cfg.vsd_camera_condition_type}" + ) + eps_phi = self.get_eps_phi( + latents_noisy, + t, + prompt_utils, + elevation, + azimuth, + camera_distances, + camera_condition, + ) + + loss_train_phi = self.train_phi( + latents, + prompt_utils, + elevation, + azimuth, + camera_distances, + camera_condition, + ) + + if self.cfg.weighting_strategy == "dreamfusion": + w = (1.0 - self.alphas[t]).view(-1, 1, 1, 1) + elif self.cfg.weighting_strategy == "uniform": + w = 1.0 + elif self.cfg.weighting_strategy == "fantasia3d": + w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1) + else: + raise ValueError( + f"Unknown weighting strategy: {self.cfg.weighting_strategy}" + ) + + grad = w * (eps_pretrain - eps_phi) + + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # reparameterization trick: + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + target = (latents - grad).detach() + loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + guidance_out = { + "loss_sd": loss_sd, + "grad_norm": grad.norm(), + "timesteps": t, + "min_step": self.min_step, + "max_step": self.max_step, + "latents": latents, + "latents_1step_orig": latents_1step_orig, + "rgb": rgb_BCHW.permute(0, 2, 3, 1), + "weights": w, + "lambdas": self.lambdas[t], + } + + if self.cfg.return_rgb_1step_orig: + with torch.no_grad(): + rgb_1step_orig = self.vae_decode( + self.pipe.vae, latents_1step_orig + ).permute(0, 2, 3, 1) + guidance_out.update({"rgb_1step_orig": rgb_1step_orig}) + + if self.cfg.return_rgb_multistep_orig: + with self.set_scheduler( + self.pipe, + DPMSolverSinglestepScheduler, + solver_order=1, + num_train_timesteps=int(t[0]), + ) as pipe: + text_embeddings = prompt_utils.get_text_embeddings( + elevation, + azimuth, + camera_distances, + self.cfg.view_dependent_prompting, + ) + text_embeddings_cond, text_embeddings_uncond = text_embeddings.chunk(2) + with torch.cuda.amp.autocast(enabled=False): + latents_multistep_orig = pipe( + num_inference_steps=self.cfg.n_rgb_multistep_orig_steps, + guidance_scale=self.cfg.guidance_scale, + eta=1.0, + latents=latents_noisy.to(pipe.unet.dtype), + prompt_embeds=text_embeddings_cond.to(pipe.unet.dtype), + negative_prompt_embeds=text_embeddings_uncond.to( + pipe.unet.dtype + ), + cross_attention_kwargs={"scale": 0.0} + if self.vsd_share_model + else None, + output_type="latent", + ).images.to(latents.dtype) + with torch.no_grad(): + rgb_multistep_orig = self.vae_decode( + self.pipe.vae, latents_multistep_orig + ) + guidance_out.update( + { + "latents_multistep_orig": latents_multistep_orig, + "rgb_multistep_orig": rgb_multistep_orig.permute(0, 2, 3, 1), + } + ) + + if self.cfg.guidance_type == "vsd": + guidance_out.update( + { + "loss_train_phi": loss_train_phi, + } + ) + + return guidance_out + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.min_step = int( + self.num_train_timesteps * C(self.cfg.min_step_percent, epoch, global_step) + ) + self.max_step = int( + self.num_train_timesteps * C(self.cfg.max_step_percent, epoch, global_step) + ) diff --git a/threestudio/models/guidance/stable_diffusion_vsd_guidance.py b/threestudio/models/guidance/stable_diffusion_vsd_guidance.py new file mode 100644 index 0000000..4ac2a95 --- /dev/null +++ b/threestudio/models/guidance/stable_diffusion_vsd_guidance.py @@ -0,0 +1,1003 @@ +import random +from contextlib import contextmanager +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.embeddings import TimestepEmbedding +from diffusers.utils.import_utils import is_xformers_available + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseModule +from threestudio.utils.misc import C, cleanup, parse_version +from threestudio.utils.perceptual import PerceptualLoss +from threestudio.utils.typing import * + + +class ToWeightsDType(nn.Module): + def __init__(self, module: nn.Module, dtype: torch.dtype): + super().__init__() + self.module = module + self.dtype = dtype + + def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]: + return self.module(x).to(self.dtype) + + +@threestudio.register("stable-diffusion-vsd-guidance") +class StableDiffusionVSDGuidance(BaseModule): + @dataclass + class Config(BaseModule.Config): + cache_dir: Optional[str] = None + local_files_only: Optional[bool] = False + pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base" + pretrained_model_name_or_path_lora: str = "stabilityai/stable-diffusion-2-1" + enable_memory_efficient_attention: bool = False + enable_sequential_cpu_offload: bool = False + enable_attention_slicing: bool = False + enable_channels_last_format: bool = False + guidance_scale: float = 7.5 + guidance_scale_lora: float = 1.0 + grad_clip: Optional[ + Any + ] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000]) + half_precision_weights: bool = True + lora_cfg_training: bool = True + lora_n_timestamp_samples: int = 1 + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + + view_dependent_prompting: bool = True + camera_condition_type: str = "extrinsics" + + use_du: bool = False + per_du_step: int = 10 + start_du_step: int = 0 + du_diffusion_steps: int = 20 + + use_bsd: bool = False + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading Stable Diffusion ...") + + self.weights_dtype = ( + torch.float16 if self.cfg.half_precision_weights else torch.float32 + ) + + pipe_kwargs = { + "tokenizer": None, + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only + } + + pipe_lora_kwargs = { + "tokenizer": None, + "safety_checker": None, + "feature_extractor": None, + "requires_safety_checker": False, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only + } + + @dataclass + class SubModules: + pipe: StableDiffusionPipeline + pipe_lora: StableDiffusionPipeline + + pipe = StableDiffusionPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + **pipe_kwargs, + ).to(self.device) + if ( + self.cfg.pretrained_model_name_or_path + == self.cfg.pretrained_model_name_or_path_lora + ): + self.single_model = True + pipe_lora = pipe + else: + self.single_model = False + pipe_lora = StableDiffusionPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path_lora, + **pipe_lora_kwargs, + ).to(self.device) + del pipe_lora.vae + cleanup() + pipe_lora.vae = pipe.vae + self.submodules = SubModules(pipe=pipe, pipe_lora=pipe_lora) + + if self.cfg.enable_memory_efficient_attention: + if parse_version(torch.__version__) >= parse_version("2"): + threestudio.info( + "PyTorch2.0 uses memory efficient attention by default." + ) + elif not is_xformers_available(): + threestudio.warn( + "xformers is not available, memory efficient attention is not enabled." + ) + else: + self.pipe.enable_xformers_memory_efficient_attention() + self.pipe_lora.enable_xformers_memory_efficient_attention() + + if self.cfg.enable_sequential_cpu_offload: + self.pipe.enable_sequential_cpu_offload() + self.pipe_lora.enable_sequential_cpu_offload() + + if self.cfg.enable_attention_slicing: + self.pipe.enable_attention_slicing(1) + self.pipe_lora.enable_attention_slicing(1) + + if self.cfg.enable_channels_last_format: + self.pipe.unet.to(memory_format=torch.channels_last) + self.pipe_lora.unet.to(memory_format=torch.channels_last) + + del self.pipe.text_encoder + if not self.single_model: + del self.pipe_lora.text_encoder + cleanup() + + for p in self.vae.parameters(): + p.requires_grad_(False) + for p in self.unet.parameters(): + p.requires_grad_(False) + for p in self.unet_lora.parameters(): + p.requires_grad_(False) + + # FIXME: hard-coded dims + self.camera_embedding = ToWeightsDType( + TimestepEmbedding(16, 1280), self.weights_dtype + ).to(self.device) + self.unet_lora.class_embedding = self.camera_embedding + + # set up LoRA layers + lora_attn_procs = {} + for name in self.unet_lora.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else self.unet_lora.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = self.unet_lora.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[ + block_id + ] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.unet_lora.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + self.unet_lora.set_attn_processor(lora_attn_procs) + + self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to( + self.device + ) + self.lora_layers._load_state_dict_pre_hooks.clear() + self.lora_layers._state_dict_hooks.clear() + + self.scheduler = DDIMScheduler.from_pretrained( # DDPM + self.cfg.pretrained_model_name_or_path, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only, + ) + + self.scheduler_lora = DDPMScheduler.from_pretrained( + self.cfg.pretrained_model_name_or_path_lora, + subfolder="scheduler", + torch_dtype=self.weights_dtype, + cache_dir=self.cfg.cache_dir, + local_files_only=self.cfg.local_files_only, + ) + + self.scheduler_sample = DPMSolverMultistepScheduler.from_config( + self.pipe.scheduler.config + ) + self.scheduler_lora_sample = DPMSolverMultistepScheduler.from_config( + self.pipe_lora.scheduler.config + ) + + self.pipe.scheduler = self.scheduler + self.pipe_lora.scheduler = self.scheduler_lora + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.set_min_max_steps() # set to default value + + self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( + self.device + ) + + self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device) + + self.grad_clip_val: Optional[float] = None + + if self.cfg.use_du: + self.perceptual_loss = PerceptualLoss().eval().to(self.device) + for p in self.perceptual_loss.parameters(): + p.requires_grad_(False) + + threestudio.info(f"Loaded Stable Diffusion!") + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @property + def pipe(self): + return self.submodules.pipe + + @property + def pipe_lora(self): + return self.submodules.pipe_lora + + @property + def unet(self): + return self.submodules.pipe.unet + + @property + def unet_lora(self): + return self.submodules.pipe_lora.unet + + @property + def vae(self): + return self.submodules.pipe.vae + + @property + def vae_lora(self): + return self.submodules.pipe_lora.vae + + @torch.no_grad() + @torch.cuda.amp.autocast(enabled=False) + def _sample( + self, + pipe: StableDiffusionPipeline, + sample_scheduler: DPMSolverMultistepScheduler, + text_embeddings: Float[Tensor, "BB N Nf"], + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + class_labels: Optional[Float[Tensor, "BB 16"]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents_inp: Optional[Float[Tensor, "..."]] = None, + ) -> Float[Tensor, "B H W 3"]: + vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1) + height = height or pipe.unet.config.sample_size * vae_scale_factor + width = width or pipe.unet.config.sample_size * vae_scale_factor + batch_size = text_embeddings.shape[0] // 2 + device = self.device + + sample_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = sample_scheduler.timesteps + num_channels_latents = pipe.unet.config.in_channels + + if latents_inp is not None: + t = torch.randint( + self.min_step, + self.max_step, + [1], + dtype=torch.long, + device=self.device, + ) + noise = torch.randn_like(latents_inp) + init_timestep = max(1, min(int(num_inference_steps * t[0].item() / self.num_train_timesteps), num_inference_steps)) + t_start = max(num_inference_steps - init_timestep, 0) + latent_timestep = sample_scheduler.timesteps[t_start : t_start + 1].repeat(batch_size) + latents = sample_scheduler.add_noise(latents_inp, noise, latent_timestep).to(self.weights_dtype) + + else: + latents = pipe.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + self.weights_dtype, + device, + generator, + ) + t_start = 0 + + for i, t in enumerate(timesteps[t_start:]): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = sample_scheduler.scale_model_input( + latent_model_input, t + ) + + # predict the noise residual + if class_labels is None: + with self.disable_unet_class_embedding(pipe.unet) as unet: + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings.to(self.weights_dtype), + cross_attention_kwargs=cross_attention_kwargs, + ).sample + else: + noise_pred = pipe.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings.to(self.weights_dtype), + class_labels=class_labels, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = sample_scheduler.step(noise_pred, t, latents).prev_sample + + latents = 1 / pipe.vae.config.scaling_factor * latents + images = pipe.vae.decode(latents).sample + images = (images / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + images = images.permute(0, 2, 3, 1).float() + return images + + def sample( + self, + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + seed: int = 0, + **kwargs, + ) -> Float[Tensor, "N H W 3"]: + # view-dependent text embeddings + text_embeddings_vd = prompt_utils.get_text_embeddings( + elevation, + azimuth, + camera_distances, + view_dependent_prompting=self.cfg.view_dependent_prompting, + ) + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + generator = torch.Generator(device=self.device).manual_seed(seed) + + return self._sample( + pipe=self.pipe, + sample_scheduler=self.scheduler_sample, + text_embeddings=text_embeddings_vd, + num_inference_steps=25, + guidance_scale=self.cfg.guidance_scale, + cross_attention_kwargs=cross_attention_kwargs, + generator=generator, + ) + + def sample_img2img( + self, + rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + seed: int = 0, + mask = None, + **kwargs, + ) -> Float[Tensor, "N H W 3"]: + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + mask_BCHW = mask.permute(0, 3, 1, 2) + latents = self.get_latents(rgb_BCHW, rgb_as_latents=False) # TODO: 有部分概率是du或者ref image + + # view-dependent text embeddings + text_embeddings_vd = prompt_utils.get_text_embeddings( + elevation, + azimuth, + camera_distances, + view_dependent_prompting=self.cfg.view_dependent_prompting, + ) + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + generator = torch.Generator(device=self.device).manual_seed(seed) + + # return self._sample( + # pipe=self.pipe, + # sample_scheduler=self.scheduler_sample, + # text_embeddings=text_embeddings_vd, + # num_inference_steps=25, + # guidance_scale=self.cfg.guidance_scale, + # cross_attention_kwargs=cross_attention_kwargs, + # generator=generator, + # latents_inp=latents + # ) + + return self.compute_grad_du(latents, rgb_BCHW, text_embeddings_vd, mask=mask_BCHW) + + def sample_lora( + self, + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + mvp_mtx: Float[Tensor, "B 4 4"], + c2w: Float[Tensor, "B 4 4"], + seed: int = 0, + **kwargs, + ) -> Float[Tensor, "N H W 3"]: + # input text embeddings, view-independent + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ) + + if self.cfg.camera_condition_type == "extrinsics": + camera_condition = c2w + elif self.cfg.camera_condition_type == "mvp": + camera_condition = mvp_mtx + else: + raise ValueError( + f"Unknown camera_condition_type {self.cfg.camera_condition_type}" + ) + + B = elevation.shape[0] + camera_condition_cfg = torch.cat( + [ + camera_condition.view(B, -1), + torch.zeros_like(camera_condition.view(B, -1)), + ], + dim=0, + ) + + generator = torch.Generator(device=self.device).manual_seed(seed) + return self._sample( + sample_scheduler=self.scheduler_lora_sample, + pipe=self.pipe_lora, + text_embeddings=text_embeddings, + num_inference_steps=25, + guidance_scale=self.cfg.guidance_scale_lora, + class_labels=camera_condition_cfg, + cross_attention_kwargs={"scale": 1.0}, + generator=generator, + ) + + @torch.cuda.amp.autocast(enabled=False) + def forward_unet( + self, + unet: UNet2DConditionModel, + latents: Float[Tensor, "..."], + t: Float[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + class_labels: Optional[Float[Tensor, "B 16"]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + return unet( + latents.to(self.weights_dtype), + t.to(self.weights_dtype), + encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype), + class_labels=class_labels, + cross_attention_kwargs=cross_attention_kwargs, + ).sample.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def encode_images( + self, imgs: Float[Tensor, "B 3 512 512"] + ) -> Float[Tensor, "B 4 64 64"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def decode_latents( + self, + latents: Float[Tensor, "B 4 H W"], + latent_height: int = 64, + latent_width: int = 64, + ) -> Float[Tensor, "B 3 512 512"]: + input_dtype = latents.dtype + latents = F.interpolate( + latents, (latent_height, latent_width), mode="bilinear", align_corners=False + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents.to(self.weights_dtype)).sample + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + @contextmanager + def disable_unet_class_embedding(self, unet: UNet2DConditionModel): + class_embedding = unet.class_embedding + try: + unet.class_embedding = None + yield unet + finally: + unet.class_embedding = class_embedding + + def compute_grad_du( + self, + latents: Float[Tensor, "B 4 64 64"], + rgb_BCHW_512: Float[Tensor, "B 3 512 512"], + text_embeddings: Float[Tensor, "BB 77 768"], + mask = None, + **kwargs, + ): + batch_size, _, _, _ = latents.shape + rgb_BCHW_512 = F.interpolate(rgb_BCHW_512, (512, 512), mode="bilinear") + assert batch_size == 1 + need_diffusion = ( + self.global_step % self.cfg.per_du_step == 0 + and self.global_step > self.cfg.start_du_step + ) + guidance_out = {} + + if need_diffusion: + t = torch.randint( + self.min_step, + self.max_step, + [1], + dtype=torch.long, + device=self.device, + ) + self.scheduler.config.num_train_timesteps = t.item() + self.scheduler.set_timesteps(self.cfg.du_diffusion_steps) + + if mask is not None: + mask = F.interpolate(mask, (64, 64), mode="bilinear", antialias=True) + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents = self.scheduler.add_noise(latents, noise, t) # type: ignore + for i, timestep in enumerate(self.scheduler.timesteps): + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + latent_model_input = torch.cat([latents] * 2) + with self.disable_unet_class_embedding(self.unet) as unet: + cross_attention_kwargs = ( + {"scale": 0.0} if self.single_model else None + ) + noise_pred = self.forward_unet( + unet, + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ) + # perform classifier-free guidance + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if mask is not None: + noise_pred = mask * noise_pred + (1 - mask) * noise + # get previous sample, continue loop + latents = self.scheduler.step( + noise_pred, timestep, latents + ).prev_sample + edit_images = self.decode_latents(latents) + edit_images = F.interpolate( + edit_images, (512, 512), mode="bilinear" + ).permute(0, 2, 3, 1) + gt_rgb = edit_images + # import cv2 + # import numpy as np + # mask_temp = mask_BCHW_512.permute(0,2,3,1) + # # edit_images = edit_images * mask_temp + torch.rand(3)[None, None, None].to(self.device).repeat(*edit_images.shape[:-1],1) * (1 - mask_temp) + # temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8) + # cv2.imwrite(f".threestudio_cache/pig_sd_noise_500/test_{kwargs.get('name', 'none')}.jpg", temp[:, :, ::-1]) + + guidance_out.update( + { + "loss_l1": torch.nn.functional.l1_loss( + rgb_BCHW_512, gt_rgb.permute(0, 3, 1, 2), reduction="sum" + ), + "loss_p": self.perceptual_loss( + rgb_BCHW_512.contiguous(), + gt_rgb.permute(0, 3, 1, 2).contiguous(), + ).sum(), + "edit_image": edit_images.detach() + } + ) + + return guidance_out + + def compute_grad_vsd( + self, + latents: Float[Tensor, "B 4 64 64"], + text_embeddings_vd: Float[Tensor, "BB 77 768"], + text_embeddings: Float[Tensor, "BB 77 768"], + camera_condition: Float[Tensor, "B 4 4"], + ): + B = latents.shape[0] + + with torch.no_grad(): + # random timestamp + t = torch.randint( + self.min_step, + self.max_step + 1, + [B], + dtype=torch.long, + device=self.device, + ) + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + if self.cfg.use_bsd: + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + noise_pred_pretrain = self.forward_unet( + self.unet, + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=text_embeddings_vd, + class_labels=torch.cat( + [ + camera_condition.view(B, -1), + torch.zeros_like(camera_condition.view(B, -1)), + ], + dim=0, + ), + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + with self.disable_unet_class_embedding(self.unet) as unet: + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + noise_pred_pretrain = self.forward_unet( + unet, + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=text_embeddings_vd, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # use view-independent text embeddings in LoRA + text_embeddings_cond, _ = text_embeddings.chunk(2) + noise_pred_est = self.forward_unet( + self.unet_lora, + latent_model_input, + torch.cat([t] * 2), + encoder_hidden_states=torch.cat([text_embeddings_cond] * 2), + class_labels=torch.cat( + [ + camera_condition.view(B, -1), + torch.zeros_like(camera_condition.view(B, -1)), + ], + dim=0, + ), + cross_attention_kwargs={"scale": 1.0}, + ) + + ( + noise_pred_pretrain_text, + noise_pred_pretrain_uncond, + ) = noise_pred_pretrain.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * ( + noise_pred_pretrain_text - noise_pred_pretrain_uncond + ) + + # TODO: more general cases + assert self.scheduler.config.prediction_type == "epsilon" + if self.scheduler_lora.config.prediction_type == "v_prediction": + alphas_cumprod = self.scheduler_lora.alphas_cumprod.to( + device=latents_noisy.device, dtype=latents_noisy.dtype + ) + alpha_t = alphas_cumprod[t] ** 0.5 + sigma_t = (1 - alphas_cumprod[t]) ** 0.5 + + noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).view( + -1, 1, 1, 1 + ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).view(-1, 1, 1, 1) + + ( + noise_pred_est_camera, + noise_pred_est_uncond, + ) = noise_pred_est.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * ( + noise_pred_est_camera - noise_pred_est_uncond + ) + + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + + grad = w * (noise_pred_pretrain - noise_pred_est) + return grad + + def compute_grad_vsd_hifa( + self, + latents: Float[Tensor, "B 4 64 64"], + text_embeddings_vd: Float[Tensor, "BB 77 768"], + text_embeddings: Float[Tensor, "BB 77 768"], + camera_condition: Float[Tensor, "B 4 4"], + mask=None, + ): + B, _, DH, DW = latents.shape + rgb = self.decode_latents(latents) + self.name = "hifa" + + if mask is not None: + mask = F.interpolate(mask, (DH, DW), mode="bilinear", antialias=True) + with torch.no_grad(): + # random timestamp + t = torch.randint( + self.min_step, + self.max_step + 1, + [B], + dtype=torch.long, + device=self.device, + ) + w = (1 - self.alphas[t]).view(-1, 1, 1, 1) + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler_sample.add_noise(latents, noise, t) + latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents, noise, t) + # pred noise + + self.scheduler_sample.config.num_train_timesteps = t.item() + self.scheduler_sample.set_timesteps(t.item() // 50 + 1) + self.scheduler_lora_sample.config.num_train_timesteps = t.item() + self.scheduler_lora_sample.set_timesteps(t.item() // 50 + 1) + + for i, timestep in enumerate(self.scheduler_sample.timesteps): + # for i, timestep in tqdm(enumerate(self.scheduler.timesteps)): + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + latent_model_input_lora = torch.cat([latents_noisy_lora] * 2, dim=0) + + # print(latent_model_input.shape) + with self.disable_unet_class_embedding(self.unet) as unet: + cross_attention_kwargs = {"scale": 0.0} if self.single_model else None + noise_pred_pretrain = self.forward_unet( + unet, + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings_vd, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # use view-independent text embeddings in LoRA + noise_pred_est = self.forward_unet( + self.unet_lora, + latent_model_input_lora, + timestep, + encoder_hidden_states=text_embeddings, + class_labels=torch.cat( + [ + camera_condition.view(B, -1), + torch.zeros_like(camera_condition.view(B, -1)), + ], + dim=0, + ), + cross_attention_kwargs={"scale": 1.0}, + ) + + ( + noise_pred_pretrain_text, + noise_pred_pretrain_uncond, + ) = noise_pred_pretrain.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * ( + noise_pred_pretrain_text - noise_pred_pretrain_uncond + ) + if mask is not None: + noise_pred_pretrain = mask * noise_pred_pretrain + (1 - mask) * noise + + ( + noise_pred_est_text, + noise_pred_est_uncond, + ) = noise_pred_est.chunk(2) + + # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance + # noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * ( + # noise_pred_est_text - noise_pred_est_uncond + # ) + noise_pred_est = noise_pred_est_text + if mask is not None: + noise_pred_est = mask * noise_pred_est + (1 - mask) * noise + + latents_noisy = self.scheduler_sample.step(noise_pred_pretrain, timestep, latents_noisy).prev_sample + latents_noisy_lora = self.scheduler_lora_sample.step(noise_pred_est, timestep, latents_noisy_lora).prev_sample + + # noise = torch.randn_like(latents) + # latents_noisy = self.scheduler.step(noise_pred_pretrain, timestep, latents_noisy).prev_sample + # latents_noisy = mask * latents_noisy + (1-mask) * latents + # latents_noisy = self.scheduler_sample.add_noise(latents_noisy, noise, timestep) + + # latents_noisy_lora = self.scheduler_lora.step(noise_pred_est, timestep, latents_noisy_lora).prev_sample + # latents_noisy_lora = mask * latents_noisy_lora + (1-mask) * latents + # latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents_noisy_lora, noise, timestep) + + hifa_images = self.decode_latents(latents_noisy) + hifa_lora_images = self.decode_latents(latents_noisy_lora) + + import cv2 + import numpy as np + if mask is not None: + print('hifa mask!') + prefix = 'vsd_mask' + else: + prefix = '' + temp = (hifa_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/%s%s_test.jpg" % (prefix, self.name), temp[:, :, ::-1]) + temp = (hifa_lora_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/%s%s_test_lora.jpg" % (prefix, self.name), temp[:, :, ::-1]) + + target = (latents_noisy - latents_noisy_lora + latents).detach() + # target = latents_noisy.detach() + targets_rgb = self.decode_latents(target) + # targets_rgb = (hifa_images - hifa_lora_images + rgb).detach() + temp = (targets_rgb.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/%s_target.jpg" % self.name, temp[:, :, ::-1]) + + return w * 0.5 * F.mse_loss(target, latents, reduction='sum') + + def train_lora( + self, + latents: Float[Tensor, "B 4 64 64"], + text_embeddings: Float[Tensor, "BB 77 768"], + camera_condition: Float[Tensor, "B 4 4"], + ): + B = latents.shape[0] + latents = latents.detach().repeat(self.cfg.lora_n_timestamp_samples, 1, 1, 1) + + t = torch.randint( + int(self.num_train_timesteps * 0.0), + int(self.num_train_timesteps * 1.0), + [B * self.cfg.lora_n_timestamp_samples], + dtype=torch.long, + device=self.device, + ) + + noise = torch.randn_like(latents) + noisy_latents = self.scheduler_lora.add_noise(latents, noise, t) + if self.scheduler_lora.config.prediction_type == "epsilon": + target = noise + elif self.scheduler_lora.config.prediction_type == "v_prediction": + target = self.scheduler_lora.get_velocity(latents, noise, t) + else: + raise ValueError( + f"Unknown prediction type {self.scheduler_lora.config.prediction_type}" + ) + # use view-independent text embeddings in LoRA + text_embeddings_cond, _ = text_embeddings.chunk(2) + if self.cfg.lora_cfg_training and random.random() < 0.1: + camera_condition = torch.zeros_like(camera_condition) + noise_pred = self.forward_unet( + self.unet_lora, + noisy_latents, + t, + encoder_hidden_states=text_embeddings_cond.repeat( + self.cfg.lora_n_timestamp_samples, 1, 1 + ), + class_labels=camera_condition.view(B, -1).repeat( + self.cfg.lora_n_timestamp_samples, 1 + ), + cross_attention_kwargs={"scale": 1.0}, + ) + return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + def get_latents( + self, rgb_BCHW: Float[Tensor, "B C H W"], rgb_as_latents=False + ) -> Float[Tensor, "B 4 64 64"]: + if rgb_as_latents: + latents = F.interpolate( + rgb_BCHW, (64, 64), mode="bilinear", align_corners=False + ) + else: + rgb_BCHW_512 = F.interpolate( + rgb_BCHW, (512, 512), mode="bilinear", align_corners=False + ) + # encode image into latents with vae + latents = self.encode_images(rgb_BCHW_512) + return latents + + def forward( + self, + rgb: Float[Tensor, "B H W C"], + prompt_utils: PromptProcessorOutput, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + mvp_mtx: Float[Tensor, "B 4 4"], + c2w: Float[Tensor, "B 4 4"], + rgb_as_latents=False, + mask: Float[Tensor, "B H W 1"] = None, + lora_prompt_utils = None, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents = self.get_latents(rgb_BCHW, rgb_as_latents=rgb_as_latents) + + if mask is not None: mask = mask.permute(0, 3, 1, 2) + + # view-dependent text embeddings + text_embeddings_vd = prompt_utils.get_text_embeddings( + elevation, + azimuth, + camera_distances, + view_dependent_prompting=self.cfg.view_dependent_prompting, + ) + if lora_prompt_utils is not None: + # input text embeddings, view-independent + text_embeddings = lora_prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ) + else: + # input text embeddings, view-independent + text_embeddings = prompt_utils.get_text_embeddings( + elevation, azimuth, camera_distances, view_dependent_prompting=False + ) + + if self.cfg.camera_condition_type == "extrinsics": + camera_condition = c2w + elif self.cfg.camera_condition_type == "mvp": + camera_condition = mvp_mtx + else: + raise ValueError( + f"Unknown camera_condition_type {self.cfg.camera_condition_type}" + ) + + grad = self.compute_grad_vsd( + latents, text_embeddings_vd, text_embeddings, camera_condition + ) + + grad = torch.nan_to_num(grad) + # clip grad for stable training? + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # reparameterization trick + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + target = (latents - grad).detach() + loss_vsd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + loss_lora = self.train_lora(latents, text_embeddings, camera_condition) + + guidance_out = { + "loss_sd": loss_vsd, + "loss_lora": loss_lora, + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + } + + if self.cfg.use_du: + du_out = self.compute_grad_du(latents, rgb_BCHW, text_embeddings_vd, mask=mask) + guidance_out.update(du_out) + + return guidance_out + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + self.global_step = global_step + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) \ No newline at end of file diff --git a/threestudio/models/guidance/stable_zero123_guidance.py b/threestudio/models/guidance/stable_zero123_guidance.py new file mode 100644 index 0000000..ecd5d04 --- /dev/null +++ b/threestudio/models/guidance/stable_zero123_guidance.py @@ -0,0 +1,340 @@ +import importlib +import os +from dataclasses import dataclass, field + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline +from diffusers.utils.import_utils import is_xformers_available +from omegaconf import OmegaConf +from tqdm import tqdm + +import threestudio +from threestudio.utils.base import BaseObject +from threestudio.utils.misc import C, parse_version +from threestudio.utils.typing import * + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +# load model +def load_model_from_config(config, ckpt, device, vram_O=True, verbose=False): + pl_sd = torch.load(ckpt, map_location="cpu") + + if "global_step" in pl_sd and verbose: + print(f'[INFO] Global Step: {pl_sd["global_step"]}') + + sd = pl_sd["state_dict"] + + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("[INFO] missing keys: \n", m) + if len(u) > 0 and verbose: + print("[INFO] unexpected keys: \n", u) + + # manually load ema and delete it to save GPU memory + if model.use_ema: + if verbose: + print("[INFO] loading EMA...") + model.model_ema.copy_to(model.model) + del model.model_ema + + if vram_O: + # we don't need decoder + del model.first_stage_model.decoder + + torch.cuda.empty_cache() + + model.eval().to(device) + + return model + + +@threestudio.register("stable-zero123-guidance") +class StableZero123Guidance(BaseObject): + @dataclass + class Config(BaseObject.Config): + pretrained_model_name_or_path: str = "load/zero123/stable-zero123.ckpt" + pretrained_config: str = "load/zero123/sd-objaverse-finetune-c_concat-256.yaml" + vram_O: bool = True + + cond_image_path: str = "load/images/hamburger_rgba.png" + cond_elevation_deg: float = 0.0 + cond_azimuth_deg: float = 0.0 + cond_camera_distance: float = 1.2 + + guidance_scale: float = 5.0 + + grad_clip: Optional[ + Any + ] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000]) + half_precision_weights: bool = False + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading Stable Zero123 ...") + + self.config = OmegaConf.load(self.cfg.pretrained_config) + # TODO: seems it cannot load into fp16... + self.weights_dtype = torch.float32 + self.model = load_model_from_config( + self.config, + self.cfg.pretrained_model_name_or_path, + device=self.device, + vram_O=self.cfg.vram_O, + ) + + for p in self.model.parameters(): + p.requires_grad_(False) + + # timesteps: use diffuser for convenience... hope it's alright. + self.num_train_timesteps = self.config.model.params.timesteps + + self.scheduler = DDIMScheduler( + self.num_train_timesteps, + self.config.model.params.linear_start, + self.config.model.params.linear_end, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.set_min_max_steps() # set to default value + + self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( + self.device + ) + + self.grad_clip_val: Optional[float] = None + + self.prepare_embeddings(self.cfg.cond_image_path) + + threestudio.info(f"Loaded Stable Zero123!") + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @torch.cuda.amp.autocast(enabled=False) + def prepare_embeddings(self, image_path: str) -> None: + # load cond image for zero123 + assert os.path.exists(image_path) + rgba = cv2.cvtColor( + cv2.imread(image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA + ) + rgba = ( + cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype( + np.float32 + ) + / 255.0 + ) + rgb = rgba[..., :3] * rgba[..., 3:] + (1 - rgba[..., 3:]) + self.rgb_256: Float[Tensor, "1 3 H W"] = ( + torch.from_numpy(rgb) + .unsqueeze(0) + .permute(0, 3, 1, 2) + .contiguous() + .to(self.device) + ) + self.c_crossattn, self.c_concat = self.get_img_embeds(self.rgb_256) + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def get_img_embeds( + self, + img: Float[Tensor, "B 3 256 256"], + ) -> Tuple[Float[Tensor, "B 1 768"], Float[Tensor, "B 4 32 32"]]: + img = img * 2.0 - 1.0 + c_crossattn = self.model.get_learned_conditioning(img.to(self.weights_dtype)) + c_concat = self.model.encode_first_stage(img.to(self.weights_dtype)).mode() + return c_crossattn, c_concat + + @torch.cuda.amp.autocast(enabled=False) + def encode_images( + self, imgs: Float[Tensor, "B 3 256 256"] + ) -> Float[Tensor, "B 4 32 32"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + latents = self.model.get_first_stage_encoding( + self.model.encode_first_stage(imgs.to(self.weights_dtype)) + ) + return latents.to(input_dtype) # [B, 4, 32, 32] Latent space image + + @torch.cuda.amp.autocast(enabled=False) + def decode_latents( + self, + latents: Float[Tensor, "B 4 H W"], + ) -> Float[Tensor, "B 3 512 512"]: + input_dtype = latents.dtype + image = self.model.decode_first_stage(latents) + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def get_cond( + self, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + c_crossattn=None, + c_concat=None, + **kwargs, + ) -> dict: + T = torch.stack( + [ + torch.deg2rad( + (90 - elevation) - (90 - self.cfg.cond_elevation_deg) + ), # Zero123 polar is 90-elevation + torch.sin(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)), + torch.cos(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)), + torch.deg2rad( + 90 - torch.full_like(elevation, self.cfg.cond_elevation_deg) + ), + ], + dim=-1, + )[:, None, :].to(self.device) + cond = {} + clip_emb = self.model.cc_projection( + torch.cat( + [ + (self.c_crossattn if c_crossattn is None else c_crossattn).repeat( + len(T), 1, 1 + ), + T, + ], + dim=-1, + ) + ) + cond["c_crossattn"] = [ + torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0) + ] + cond["c_concat"] = [ + torch.cat( + [ + torch.zeros_like(self.c_concat) + .repeat(len(T), 1, 1, 1) + .to(self.device), + (self.c_concat if c_concat is None else c_concat).repeat( + len(T), 1, 1, 1 + ), + ], + dim=0, + ) + ] + return cond + + def __call__( + self, + rgb: Float[Tensor, "B H W C"], + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + rgb_as_latents=False, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents: Float[Tensor, "B 4 64 64"] + if rgb_as_latents: + latents = ( + F.interpolate(rgb_BCHW, (32, 32), mode="bilinear", align_corners=False) + * 2 + - 1 + ) + else: + rgb_BCHW_512 = F.interpolate( + rgb_BCHW, (256, 256), mode="bilinear", align_corners=False + ) + # encode image into latents with vae + latents = self.encode_images(rgb_BCHW_512) + + cond = self.get_cond(elevation, azimuth, camera_distances) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, + self.max_step + 1, + [batch_size], + dtype=torch.long, + device=self.device, + ) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) # TODO: use torch generator + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + x_in = torch.cat([latents_noisy] * 2) + t_in = torch.cat([t] * 2) + noise_pred = self.model.apply_model(x_in, t_in, cond) + + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + w = (1 - self.alphas[t]).reshape(-1, 1, 1, 1) + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + # clip grad for stable training? + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # loss = SpecifyGradient.apply(latents, grad) + # SpecifyGradient is not straghtforward, use a reparameterization trick instead + target = (latents - grad).detach() + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + guidance_out = { + "loss_sd": loss_sds, + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + } + + return guidance_out + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) \ No newline at end of file diff --git a/threestudio/models/guidance/zero123_guidance.py b/threestudio/models/guidance/zero123_guidance.py new file mode 100644 index 0000000..4f6a176 --- /dev/null +++ b/threestudio/models/guidance/zero123_guidance.py @@ -0,0 +1,528 @@ +import importlib +import os +from dataclasses import dataclass, field + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline +from diffusers.utils.import_utils import is_xformers_available +from omegaconf import OmegaConf +from tqdm import tqdm + +import threestudio +from threestudio.utils.base import BaseObject +from threestudio.utils.misc import C, parse_version +from threestudio.utils.typing import * + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +# load model +def load_model_from_config(config, ckpt, device, vram_O=True, verbose=False): + pl_sd = torch.load(ckpt, map_location="cpu") + + if "global_step" in pl_sd and verbose: + print(f'[INFO] Global Step: {pl_sd["global_step"]}') + + sd = pl_sd["state_dict"] + + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("[INFO] missing keys: \n", m) + if len(u) > 0 and verbose: + print("[INFO] unexpected keys: \n", u) + + # manually load ema and delete it to save GPU memory + if model.use_ema: + if verbose: + print("[INFO] loading EMA...") + model.model_ema.copy_to(model.model) + del model.model_ema + + if vram_O: + # we don't need decoder + del model.first_stage_model.decoder + + torch.cuda.empty_cache() + + model.eval().to(device) + + return model + + +@threestudio.register("zero123-guidance") +class Zero123Guidance(BaseObject): + @dataclass + class Config(BaseObject.Config): + pretrained_model_name_or_path: str = "load/zero123/105000.ckpt" + pretrained_config: str = "load/zero123/sd-objaverse-finetune-c_concat-256.yaml" + vram_O: bool = True + + cond_image_path: str = "load/images/hamburger_rgba.png" + cond_elevation_deg: float = 0.0 + cond_azimuth_deg: float = 0.0 + cond_camera_distance: float = 1.2 + + guidance_scale: float = 5.0 + + grad_clip: Optional[ + Any + ] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000]) + half_precision_weights: bool = False + + min_step_percent: float = 0.02 + max_step_percent: float = 0.98 + + """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items.""" + max_items_eval: int = 4 + + cfg: Config + + def configure(self) -> None: + threestudio.info(f"Loading Zero123 ...") + + self.config = OmegaConf.load(self.cfg.pretrained_config) + # TODO: seems it cannot load into fp16... + self.weights_dtype = torch.float32 + self.model = load_model_from_config( + self.config, + self.cfg.pretrained_model_name_or_path, + device=self.device, + vram_O=self.cfg.vram_O, + ) + + for p in self.model.parameters(): + p.requires_grad_(False) + + # timesteps: use diffuser for convenience... hope it's alright. + self.num_train_timesteps = self.config.model.params.timesteps + + self.scheduler = DDIMScheduler( + self.num_train_timesteps, + self.config.model.params.linear_start, + self.config.model.params.linear_end, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.set_min_max_steps() # set to default value + + self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to( + self.device + ) + + self.grad_clip_val: Optional[float] = None + + self.prepare_embeddings(self.cfg.cond_image_path) + + threestudio.info(f"Loaded Zero123!") + + @torch.cuda.amp.autocast(enabled=False) + def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): + self.min_step = int(self.num_train_timesteps * min_step_percent) + self.max_step = int(self.num_train_timesteps * max_step_percent) + + @torch.cuda.amp.autocast(enabled=False) + def prepare_embeddings(self, image_path: str) -> None: + # load cond image for zero123 + assert os.path.exists(image_path) + rgba = cv2.cvtColor( + cv2.imread(image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA + ) + rgba = ( + cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype( + np.float32 + ) + / 255.0 + ) + rgb = rgba[..., :3] * rgba[..., 3:] + (1 - rgba[..., 3:]) + self.rgb_256: Float[Tensor, "1 3 H W"] = ( + torch.from_numpy(rgb) + .unsqueeze(0) + .permute(0, 3, 1, 2) + .contiguous() + .to(self.device) + ) + self.c_crossattn, self.c_concat = self.get_img_embeds(self.rgb_256) + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def get_img_embeds( + self, + img: Float[Tensor, "B 3 256 256"], + ) -> Tuple[Float[Tensor, "B 1 768"], Float[Tensor, "B 4 32 32"]]: + img = img * 2.0 - 1.0 + c_crossattn = self.model.get_learned_conditioning(img.to(self.weights_dtype)) + c_concat = self.model.encode_first_stage(img.to(self.weights_dtype)).mode() + return c_crossattn, c_concat + + @torch.cuda.amp.autocast(enabled=False) + def encode_images( + self, imgs: Float[Tensor, "B 3 256 256"] + ) -> Float[Tensor, "B 4 32 32"]: + input_dtype = imgs.dtype + imgs = imgs * 2.0 - 1.0 + latents = self.model.get_first_stage_encoding( + self.model.encode_first_stage(imgs.to(self.weights_dtype)) + ) + return latents.to(input_dtype) # [B, 4, 32, 32] Latent space image + + @torch.cuda.amp.autocast(enabled=False) + def decode_latents( + self, + latents: Float[Tensor, "B 4 H W"], + ) -> Float[Tensor, "B 3 512 512"]: + input_dtype = latents.dtype + image = self.model.decode_first_stage(latents) + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def get_cond( + self, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + c_crossattn=None, + c_concat=None, + **kwargs, + ) -> dict: + T = torch.stack( + [ + torch.deg2rad( + (90 - elevation) - (90 - self.cfg.cond_elevation_deg) + ), # Zero123 polar is 90-elevation + torch.sin(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)), + torch.cos(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)), + camera_distances - self.cfg.cond_camera_distance, + ], + dim=-1, + )[:, None, :].to(self.device) + cond = {} + clip_emb = self.model.cc_projection( + torch.cat( + [ + (self.c_crossattn if c_crossattn is None else c_crossattn).repeat( + len(T), 1, 1 + ), + T, + ], + dim=-1, + ) + ) + cond["c_crossattn"] = [ + torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0) + ] + cond["c_concat"] = [ + torch.cat( + [ + torch.zeros_like(self.c_concat) + .repeat(len(T), 1, 1, 1) + .to(self.device), + (self.c_concat if c_concat is None else c_concat).repeat( + len(T), 1, 1, 1 + ), + ], + dim=0, + ) + ] + return cond + + def __call__( + self, + rgb: Float[Tensor, "B H W C"], + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + rgb_as_latents=False, + guidance_eval=False, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents: Float[Tensor, "B 4 64 64"] + if rgb_as_latents: + latents = ( + F.interpolate(rgb_BCHW, (32, 32), mode="bilinear", align_corners=False) + * 2 + - 1 + ) + else: + rgb_BCHW_512 = F.interpolate( + rgb_BCHW, (256, 256), mode="bilinear", align_corners=False + ) + # encode image into latents with vae + latents = self.encode_images(rgb_BCHW_512) + + cond = self.get_cond(elevation, azimuth, camera_distances) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, + self.max_step + 1, + [batch_size], + dtype=torch.long, + device=self.device, + ) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) # TODO: use torch generator + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + x_in = torch.cat([latents_noisy] * 2) + t_in = torch.cat([t] * 2) + noise_pred = self.model.apply_model(x_in, t_in, cond) + + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + w = (1 - self.alphas[t]).reshape(-1, 1, 1, 1) + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + # clip grad for stable training? + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # loss = SpecifyGradient.apply(latents, grad) + # SpecifyGradient is not straghtforward, use a reparameterization trick instead + target = (latents - grad).detach() + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + guidance_out = { + "loss_sd": loss_sds, # loss_sds + "grad_norm": grad.norm(), + "min_step": self.min_step, + "max_step": self.max_step, + } + + if guidance_eval: + guidance_eval_utils = { + "cond": cond, + "t_orig": t, + "latents_noisy": latents_noisy, + "noise_pred": noise_pred, + } + guidance_eval_out = self.guidance_eval(**guidance_eval_utils) + texts = [] + for n, e, a, c in zip( + guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances + ): + texts.append( + f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}" + ) + guidance_eval_out.update({"texts": texts}) + guidance_out.update({"eval": guidance_eval_out}) + + return guidance_out + + @torch.cuda.amp.autocast(enabled=False) + @torch.no_grad() + def guidance_eval(self, cond, t_orig, latents_noisy, noise_pred): + # use only 50 timesteps, and find nearest of those to t + self.scheduler.set_timesteps(50) + self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device) + bs = ( + min(self.cfg.max_items_eval, latents_noisy.shape[0]) + if self.cfg.max_items_eval > 0 + else latents_noisy.shape[0] + ) # batch size + large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[ + :bs + ].unsqueeze( + -1 + ) # sized [bs,50] > [bs,1] + idxs = torch.min(large_enough_idxs, dim=1)[1] + t = self.scheduler.timesteps_gpu[idxs] + + fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy()) + imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1) + + # get prev latent + latents_1step = [] + pred_1orig = [] + for b in range(bs): + step_output = self.scheduler.step( + noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1 + ) + latents_1step.append(step_output["prev_sample"]) + pred_1orig.append(step_output["pred_original_sample"]) + latents_1step = torch.cat(latents_1step) + pred_1orig = torch.cat(pred_1orig) + imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1) + imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1) + + latents_final = [] + for b, i in enumerate(idxs): + latents = latents_1step[b : b + 1] + c = { + "c_crossattn": [cond["c_crossattn"][0][[b, b + len(idxs)], ...]], + "c_concat": [cond["c_concat"][0][[b, b + len(idxs)], ...]], + } + for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False): + # pred noise + x_in = torch.cat([latents] * 2) + t_in = torch.cat([t.reshape(1)] * 2).to(self.device) + noise_pred = self.model.apply_model(x_in, t_in, c) + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + # get prev latent + latents = self.scheduler.step(noise_pred, t, latents, eta=1)[ + "prev_sample" + ] + latents_final.append(latents) + + latents_final = torch.cat(latents_final) + imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1) + + return { + "bs": bs, + "noise_levels": fracs, + "imgs_noisy": imgs_noisy, + "imgs_1step": imgs_1step, + "imgs_1orig": imgs_1orig, + "imgs_final": imgs_final, + } + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.set_min_max_steps( + min_step_percent=C(self.cfg.min_step_percent, epoch, global_step), + max_step_percent=C(self.cfg.max_step_percent, epoch, global_step), + ) + + # verification - requires `vram_O = False` in load_model_from_config + @torch.no_grad() + def generate( + self, + image, # image tensor [1, 3, H, W] in [0, 1] + elevation=0, + azimuth=0, + camera_distances=0, # new view params + c_crossattn=None, + c_concat=None, + scale=3, + ddim_steps=50, + post_process=True, + ddim_eta=1, + ): + if c_crossattn is None: + c_crossattn, c_concat = self.get_img_embeds(image) + + cond = self.get_cond( + elevation, azimuth, camera_distances, c_crossattn, c_concat + ) + + imgs = self.gen_from_cond(cond, scale, ddim_steps, post_process, ddim_eta) + + return imgs + + # verification - requires `vram_O = False` in load_model_from_config + @torch.no_grad() + def gen_from_cond( + self, + cond, + scale=3, + ddim_steps=50, + post_process=True, + ddim_eta=1, + ): + # produce latents loop + B = cond["c_crossattn"][0].shape[0] // 2 + latents = torch.randn((B, 4, 32, 32), device=self.device) + self.scheduler.set_timesteps(ddim_steps) + + for t in self.scheduler.timesteps: + x_in = torch.cat([latents] * 2) + t_in = torch.cat([t.reshape(1).repeat(B)] * 2).to(self.device) + + noise_pred = self.model.apply_model(x_in, t_in, cond) + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + scale * ( + noise_pred_cond - noise_pred_uncond + ) + + latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)[ + "prev_sample" + ] + + imgs = self.decode_latents(latents) + imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs + + return imgs + + +if __name__ == '__main__': + from threestudio.utils.config import load_config + import pytorch_lightning as pl + import numpy as np + import os + import cv2 + cfg = load_config("configs/experimental/zero123.yaml") + guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance) + elevations = [0, 20, -20] + azimuths = [45, 90, 135, -45, -90] + radius = torch.tensor([3.8]).to(guidance.device) + outdir = ".threestudio_cache/saiyan" + os.makedirs(outdir, exist_ok=True) + # rgb_image = (rgb_image[0].detach().cpu().clip(0, 1).numpy()*255).astype(np.uint8)[:, :, ::-1].copy() + # os.makedirs('.threestudio_cache', exist_ok=True) + # cv2.imwrite('.threestudio_cache/diffusion_image.jpg', rgb_image) + + + rgb_image = cv2.imread(cfg.system.guidance.cond_image_path)[:, :, ::-1].copy() / 255 + rgb_image = cv2.resize(rgb_image, (256, 256)) + rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device).permute(0,3,1,2) + + for elevation in elevations: + for azimuth in azimuths: + output1 = guidance.generate( + rgb_image, + torch.tensor([elevation]).to(guidance.device), + torch.tensor([azimuth]).to(guidance.device), + radius, + c_crossattn=guidance.c_crossattn, + c_concat=guidance.c_concat + ) + from torchvision.utils import save_image + save_image(torch.tensor(output1).float().permute(0,3,1,2), f"{outdir}/result_e_{elevation}_a_{azimuth}.png", normalize=True, value_range=(0,1)) + \ No newline at end of file diff --git a/threestudio/models/guidance/zero123_unified_guidance.py b/threestudio/models/guidance/zero123_unified_guidance.py new file mode 100644 index 0000000..6b643cb --- /dev/null +++ b/threestudio/models/guidance/zero123_unified_guidance.py @@ -0,0 +1,721 @@ +import os +import random +import sys +from contextlib import contextmanager +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DPMSolverSinglestepScheduler, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.embeddings import TimestepEmbedding +from PIL import Image +from tqdm import tqdm + +import threestudio +from extern.zero123 import Zero123Pipeline +from threestudio.models.networks import ToDTypeWrapper +from threestudio.models.prompt_processors.base import PromptProcessorOutput +from threestudio.utils.base import BaseModule +from threestudio.utils.misc import C, cleanup, enable_gradient, parse_version +from threestudio.utils.typing import * + + +@threestudio.register("zero123-unified-guidance") +class Zero123UnifiedGuidance(BaseModule): + @dataclass + class Config(BaseModule.Config): + cache_dir: Optional[str] = None + local_files_only: Optional[bool] = False + + # guidance type, in ["sds", "vsd"] + guidance_type: str = "sds" + + pretrained_model_name_or_path: str = "bennyguo/zero123-diffusers" + guidance_scale: float = 5.0 + weighting_strategy: str = "dreamfusion" + + min_step_percent: Any = 0.02 + max_step_percent: Any = 0.98 + grad_clip: Optional[Any] = None + + return_rgb_1step_orig: bool = False + return_rgb_multistep_orig: bool = False + n_rgb_multistep_orig_steps: int = 4 + + cond_image_path: str = "" + cond_elevation_deg: float = 0.0 + cond_azimuth_deg: float = 0.0 + cond_camera_distance: float = 1.2 + + # efficiency-related configurations + half_precision_weights: bool = True + + # VSD configurations, only used when guidance_type is "vsd" + vsd_phi_model_name_or_path: Optional[str] = None + vsd_guidance_scale_phi: float = 1.0 + vsd_use_lora: bool = True + vsd_lora_cfg_training: bool = False + vsd_lora_n_timestamp_samples: int = 1 + vsd_use_camera_condition: bool = True + # camera condition type, in ["extrinsics", "mvp", "spherical"] + vsd_camera_condition_type: Optional[str] = "extrinsics" + + cfg: Config + + def configure(self) -> None: + self.min_step: Optional[int] = None + self.max_step: Optional[int] = None + self.grad_clip_val: Optional[float] = None + + @dataclass + class NonTrainableModules: + pipe: Zero123Pipeline + pipe_phi: Optional[Zero123Pipeline] = None + + self.weights_dtype = ( + torch.float16 if self.cfg.half_precision_weights else torch.float32 + ) + + threestudio.info(f"Loading Zero123 ...") + + # need to make sure the pipeline file is in path + sys.path.append("extern/") + + pipe_kwargs = { + "safety_checker": None, + "requires_safety_checker": False, + "variant": "fp16" if self.cfg.half_precision_weights else None, + "torch_dtype": self.weights_dtype, + "cache_dir": self.cfg.cache_dir, + "local_files_only": self.cfg.local_files_only, + } + pipe = Zero123Pipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + **pipe_kwargs, + ).to(self.device) + self.prepare_pipe(pipe) + + # phi network for VSD + # introduce two trainable modules: + # - self.camera_embedding + # - self.lora_layers + pipe_phi = None + + # if the phi network shares the same unet with the pretrain network + # we need to pass additional cross attention kwargs to the unet + self.vsd_share_model = ( + self.cfg.guidance_type == "vsd" + and self.cfg.vsd_phi_model_name_or_path is None + ) + if self.cfg.guidance_type == "vsd": + if self.cfg.vsd_phi_model_name_or_path is None: + pipe_phi = pipe + else: + pipe_phi = Zero123Pipeline.from_pretrained( + self.cfg.vsd_phi_model_name_or_path, + **pipe_kwargs, + ).to(self.device) + self.prepare_pipe(pipe_phi) + + # set up camera embedding + if self.cfg.vsd_use_camera_condition: + if self.cfg.vsd_camera_condition_type in ["extrinsics", "mvp"]: + self.camera_embedding_dim = 16 + elif self.cfg.vsd_camera_condition_type == "spherical": + self.camera_embedding_dim = 4 + else: + raise ValueError("Invalid camera condition type!") + + # FIXME: hard-coded output dim + self.camera_embedding = ToDTypeWrapper( + TimestepEmbedding(self.camera_embedding_dim, 1280), + self.weights_dtype, + ).to(self.device) + pipe_phi.unet.class_embedding = self.camera_embedding + + if self.cfg.vsd_use_lora: + # set up LoRA layers + lora_attn_procs = {} + for name in pipe_phi.unet.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else pipe_phi.unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = pipe_phi.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list( + reversed(pipe_phi.unet.config.block_out_channels) + )[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = pipe_phi.unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + pipe_phi.unet.set_attn_processor(lora_attn_procs) + + self.lora_layers = AttnProcsLayers(pipe_phi.unet.attn_processors).to( + self.device + ) + self.lora_layers._load_state_dict_pre_hooks.clear() + self.lora_layers._state_dict_hooks.clear() + + threestudio.info(f"Loaded Stable Diffusion!") + + self.scheduler = DDPMScheduler.from_config(pipe.scheduler.config) + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + + # q(z_t|x) = N(alpha_t x, sigma_t^2 I) + # in DDPM, alpha_t = sqrt(alphas_cumprod_t), sigma_t^2 = 1 - alphas_cumprod_t + self.alphas_cumprod: Float[Tensor, "T"] = self.scheduler.alphas_cumprod.to( + self.device + ) + self.alphas: Float[Tensor, "T"] = self.alphas_cumprod**0.5 + self.sigmas: Float[Tensor, "T"] = (1 - self.alphas_cumprod) ** 0.5 + # log SNR + self.lambdas: Float[Tensor, "T"] = self.sigmas / self.alphas + + self._non_trainable_modules = NonTrainableModules( + pipe=pipe, + pipe_phi=pipe_phi, + ) + + # self.clip_image_embeddings and self.image_latents + self.prepare_image_embeddings() + + @property + def pipe(self) -> Zero123Pipeline: + return self._non_trainable_modules.pipe + + @property + def pipe_phi(self) -> Zero123Pipeline: + if self._non_trainable_modules.pipe_phi is None: + raise RuntimeError("phi model is not available.") + return self._non_trainable_modules.pipe_phi + + def prepare_pipe(self, pipe: Zero123Pipeline): + cleanup() + + pipe.image_encoder.eval() + pipe.vae.eval() + pipe.unet.eval() + pipe.clip_camera_projection.eval() + + enable_gradient(pipe.image_encoder, enabled=False) + enable_gradient(pipe.vae, enabled=False) + enable_gradient(pipe.unet, enabled=False) + enable_gradient(pipe.clip_camera_projection, enabled=False) + + # disable progress bar + pipe.set_progress_bar_config(disable=True) + + def prepare_image_embeddings(self) -> None: + if not os.path.exists(self.cfg.cond_image_path): + raise RuntimeError( + f"Condition image not found at {self.cfg.cond_image_path}" + ) + image = Image.open(self.cfg.cond_image_path).convert("RGBA").resize((256, 256)) + image = ( + TF.to_tensor(image) + .unsqueeze(0) + .to(device=self.device, dtype=self.weights_dtype) + ) + # rgba -> rgb, apply white background + image = image[:, :3] * image[:, 3:4] + (1 - image[:, 3:4]) + + with torch.no_grad(): + self.clip_image_embeddings: Float[ + Tensor, "1 1 D" + ] = self.extract_clip_image_embeddings(image) + + # encoded latents should be multiplied with vae.config.scaling_factor + # but zero123 was not trained this way + self.image_latents: Float[Tensor, "1 4 Hl Wl"] = ( + self.vae_encode(self.pipe.vae, image * 2.0 - 1.0, mode=True) + / self.pipe.vae.config.scaling_factor + ) + + def extract_clip_image_embeddings( + self, images: Float[Tensor, "B 3 H W"] + ) -> Float[Tensor, "B 1 D"]: + # expect images in [0, 1] + images_pil = [TF.to_pil_image(image) for image in images] + images_processed = self.pipe.feature_extractor( + images=images_pil, return_tensors="pt" + ).pixel_values.to(device=self.device, dtype=self.weights_dtype) + clip_image_embeddings = self.pipe.image_encoder(images_processed).image_embeds + return clip_image_embeddings.to(images.dtype) + + def get_image_camera_embeddings( + self, + elevation_deg: Float[Tensor, "B"], + azimuth_deg: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + ) -> Float[Tensor, "B 1 D"]: + batch_size = elevation_deg.shape[0] + camera_embeddings: Float[Tensor, "B 1 4"] = torch.stack( + [ + torch.deg2rad(self.cfg.cond_elevation_deg - elevation_deg), + torch.sin(torch.deg2rad(azimuth_deg - self.cfg.cond_azimuth_deg)), + torch.cos(torch.deg2rad(azimuth_deg - self.cfg.cond_azimuth_deg)), + camera_distances - self.cfg.cond_camera_distance, + ], + dim=-1, + )[:, None, :] + + image_camera_embeddings = self.pipe.clip_camera_projection( + torch.cat( + [ + self.clip_image_embeddings.repeat(batch_size, 1, 1), + camera_embeddings, + ], + dim=-1, + ).to(self.weights_dtype) + ) + + return image_camera_embeddings + + @torch.cuda.amp.autocast(enabled=False) + def forward_unet( + self, + unet: UNet2DConditionModel, + latents: Float[Tensor, "..."], + t: Int[Tensor, "..."], + encoder_hidden_states: Float[Tensor, "..."], + class_labels: Optional[Float[Tensor, "..."]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Float[Tensor, "..."]] = None, + mid_block_additional_residual: Optional[Float[Tensor, "..."]] = None, + velocity_to_epsilon: bool = False, + ) -> Float[Tensor, "..."]: + input_dtype = latents.dtype + pred = unet( + latents.to(unet.dtype), + t.to(unet.dtype), + encoder_hidden_states=encoder_hidden_states.to(unet.dtype), + class_labels=class_labels, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + ).sample + if velocity_to_epsilon: + pred = latents * self.sigmas[t].view(-1, 1, 1, 1) + pred * self.alphas[ + t + ].view(-1, 1, 1, 1) + return pred.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def vae_encode( + self, vae: AutoencoderKL, imgs: Float[Tensor, "B 3 H W"], mode=False + ) -> Float[Tensor, "B 4 Hl Wl"]: + # expect input in [-1, 1] + input_dtype = imgs.dtype + posterior = vae.encode(imgs.to(vae.dtype)).latent_dist + if mode: + latents = posterior.mode() + else: + latents = posterior.sample() + latents = latents * vae.config.scaling_factor + return latents.to(input_dtype) + + @torch.cuda.amp.autocast(enabled=False) + def vae_decode( + self, vae: AutoencoderKL, latents: Float[Tensor, "B 4 Hl Wl"] + ) -> Float[Tensor, "B 3 H W"]: + # output in [0, 1] + input_dtype = latents.dtype + latents = 1 / vae.config.scaling_factor * latents + image = vae.decode(latents.to(vae.dtype)).sample + image = (image * 0.5 + 0.5).clamp(0, 1) + return image.to(input_dtype) + + @contextmanager + def disable_unet_class_embedding(self, unet: UNet2DConditionModel): + class_embedding = unet.class_embedding + try: + unet.class_embedding = None + yield unet + finally: + unet.class_embedding = class_embedding + + @contextmanager + def set_scheduler(self, pipe: Zero123Pipeline, scheduler_class: Any, **kwargs): + scheduler_orig = pipe.scheduler + pipe.scheduler = scheduler_class.from_config(scheduler_orig.config, **kwargs) + yield pipe + pipe.scheduler = scheduler_orig + + def get_eps_pretrain( + self, + latents_noisy: Float[Tensor, "B 4 Hl Wl"], + t: Int[Tensor, "B"], + image_camera_embeddings: Float[Tensor, "B 1 D"], + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + ) -> Float[Tensor, "B 4 Hl Wl"]: + batch_size = latents_noisy.shape[0] + + with torch.no_grad(): + with self.disable_unet_class_embedding(self.pipe.unet) as unet: + noise_pred = self.forward_unet( + unet, + torch.cat( + [ + torch.cat([latents_noisy] * 2, dim=0), + torch.cat( + [ + self.image_latents.repeat(batch_size, 1, 1, 1), + torch.zeros_like(self.image_latents).repeat( + batch_size, 1, 1, 1 + ), + ], + dim=0, + ), + ], + dim=1, + ), + torch.cat([t] * 2, dim=0), + encoder_hidden_states=torch.cat( + [ + image_camera_embeddings, + torch.zeros_like(image_camera_embeddings), + ], + dim=0, + ), + cross_attention_kwargs={"scale": 0.0} + if self.vsd_share_model + else None, + velocity_to_epsilon=self.pipe.scheduler.config.prediction_type + == "v_prediction", + ) + + noise_pred_image, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.guidance_scale * ( + noise_pred_image - noise_pred_uncond + ) + + return noise_pred + + def get_eps_phi( + self, + latents_noisy: Float[Tensor, "B 4 Hl Wl"], + t: Int[Tensor, "B"], + image_camera_embeddings: Float[Tensor, "B 1 D"], + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + camera_condition: Float[Tensor, "B ..."], + ) -> Float[Tensor, "B 4 Hl Wl"]: + batch_size = latents_noisy.shape[0] + + with torch.no_grad(): + noise_pred = self.forward_unet( + self.pipe_phi.unet, + torch.cat( + [ + torch.cat([latents_noisy] * 2, dim=0), + torch.cat( + [self.image_latents.repeat(batch_size, 1, 1, 1)] * 2, + dim=0, + ), + ], + dim=1, + ), + torch.cat([t] * 2, dim=0), + encoder_hidden_states=torch.cat([image_camera_embeddings] * 2, dim=0), + class_labels=torch.cat( + [ + camera_condition.view(batch_size, -1), + torch.zeros_like(camera_condition.view(batch_size, -1)), + ], + dim=0, + ) + if self.cfg.vsd_use_camera_condition + else None, + cross_attention_kwargs={"scale": 1.0}, + velocity_to_epsilon=self.pipe_phi.scheduler.config.prediction_type + == "v_prediction", + ) + + noise_pred_camera, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg.vsd_guidance_scale_phi * ( + noise_pred_camera - noise_pred_uncond + ) + + return noise_pred + + def train_phi( + self, + latents: Float[Tensor, "B 4 Hl Wl"], + image_camera_embeddings: Float[Tensor, "B 1 D"], + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + camera_condition: Float[Tensor, "B ..."], + ): + B = latents.shape[0] + latents = latents.detach().repeat( + self.cfg.vsd_lora_n_timestamp_samples, 1, 1, 1 + ) + + num_train_timesteps = self.pipe_phi.scheduler.config.num_train_timesteps + t = torch.randint( + int(num_train_timesteps * 0.0), + int(num_train_timesteps * 1.0), + [B * self.cfg.vsd_lora_n_timestamp_samples], + dtype=torch.long, + device=self.device, + ) + + noise = torch.randn_like(latents) + latents_noisy = self.pipe_phi.scheduler.add_noise(latents, noise, t) + if self.pipe_phi.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.pipe_phi.scheduler.prediction_type == "v_prediction": + target = self.pipe_phi.scheduler.get_velocity(latents, noise, t) + else: + raise ValueError( + f"Unknown prediction type {self.pipe_phi.scheduler.prediction_type}" + ) + + if ( + self.cfg.vsd_use_camera_condition + and self.cfg.vsd_lora_cfg_training + and random.random() < 0.1 + ): + camera_condition = torch.zeros_like(camera_condition) + + noise_pred = self.forward_unet( + self.pipe_phi.unet, + torch.cat([latents_noisy, self.image_latents.repeat(B, 1, 1, 1)], dim=1), + t, + encoder_hidden_states=image_camera_embeddings.repeat( + self.cfg.vsd_lora_n_timestamp_samples, 1, 1 + ), + class_labels=camera_condition.view(B, -1).repeat( + self.cfg.vsd_lora_n_timestamp_samples, 1 + ) + if self.cfg.vsd_use_camera_condition + else None, + cross_attention_kwargs={"scale": 1.0}, + ) + return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + def forward( + self, + rgb: Float[Tensor, "B H W C"], + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + mvp_mtx: Float[Tensor, "B 4 4"], + c2w: Float[Tensor, "B 4 4"], + rgb_as_latents=False, + **kwargs, + ): + batch_size = rgb.shape[0] + + rgb_BCHW = rgb.permute(0, 3, 1, 2) + latents: Float[Tensor, "B 4 32 32"] + if rgb_as_latents: + # treat input rgb as latents + # input rgb should be in range [-1, 1] + latents = F.interpolate( + rgb_BCHW, (32, 32), mode="bilinear", align_corners=False + ) + else: + # treat input rgb as rgb + # input rgb should be in range [0, 1] + rgb_BCHW = F.interpolate( + rgb_BCHW, (256, 256), mode="bilinear", align_corners=False + ) + # encode image into latents with vae + latents = self.vae_encode(self.pipe.vae, rgb_BCHW * 2.0 - 1.0) + + # sample timestep + # use the same timestep for each batch + assert self.min_step is not None and self.max_step is not None + t = torch.randint( + self.min_step, + self.max_step + 1, + [1], + dtype=torch.long, + device=self.device, + ).repeat(batch_size) + + # sample noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + + # image-camera feature condition + image_camera_embeddings = self.get_image_camera_embeddings( + elevation, azimuth, camera_distances + ) + + eps_pretrain = self.get_eps_pretrain( + latents_noisy, + t, + image_camera_embeddings, + elevation, + azimuth, + camera_distances, + ) + + latents_1step_orig = ( + 1 + / self.alphas[t].view(-1, 1, 1, 1) + * (latents_noisy - self.sigmas[t].view(-1, 1, 1, 1) * eps_pretrain) + ).detach() + + if self.cfg.guidance_type == "sds": + eps_phi = noise + elif self.cfg.guidance_type == "vsd": + if self.cfg.vsd_camera_condition_type == "extrinsics": + camera_condition = c2w + elif self.cfg.vsd_camera_condition_type == "mvp": + camera_condition = mvp_mtx + elif self.cfg.vsd_camera_condition_type == "spherical": + camera_condition = torch.stack( + [ + torch.deg2rad(elevation), + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + camera_distances, + ], + dim=-1, + ) + else: + raise ValueError( + f"Unknown camera_condition_type {self.cfg.vsd_camera_condition_type}" + ) + eps_phi = self.get_eps_phi( + latents_noisy, + t, + image_camera_embeddings, + elevation, + azimuth, + camera_distances, + camera_condition, + ) + + loss_train_phi = self.train_phi( + latents, + image_camera_embeddings, + elevation, + azimuth, + camera_distances, + camera_condition, + ) + + if self.cfg.weighting_strategy == "dreamfusion": + w = (1.0 - self.alphas[t]).view(-1, 1, 1, 1) + elif self.cfg.weighting_strategy == "uniform": + w = 1.0 + elif self.cfg.weighting_strategy == "fantasia3d": + w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1) + else: + raise ValueError( + f"Unknown weighting strategy: {self.cfg.weighting_strategy}" + ) + + grad = w * (eps_pretrain - eps_phi) + + if self.grad_clip_val is not None: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # reparameterization trick: + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + target = (latents - grad).detach() + loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size + + guidance_out = { + "loss_sd": loss_sd, + "grad_norm": grad.norm(), + "timesteps": t, + "min_step": self.min_step, + "max_step": self.max_step, + "latents": latents, + "latents_1step_orig": latents_1step_orig, + "rgb": rgb_BCHW.permute(0, 2, 3, 1), + "weights": w, + "lambdas": self.lambdas[t], + } + + if self.cfg.return_rgb_1step_orig: + with torch.no_grad(): + rgb_1step_orig = self.vae_decode( + self.pipe.vae, latents_1step_orig + ).permute(0, 2, 3, 1) + guidance_out.update({"rgb_1step_orig": rgb_1step_orig}) + + if self.cfg.return_rgb_multistep_orig: + with self.set_scheduler( + self.pipe, + DPMSolverSinglestepScheduler, + solver_order=1, + num_train_timesteps=int(t[0]), + ) as pipe: + with torch.cuda.amp.autocast(enabled=False): + latents_multistep_orig = pipe( + num_inference_steps=self.cfg.n_rgb_multistep_orig_steps, + guidance_scale=self.cfg.guidance_scale, + eta=1.0, + latents=latents_noisy.to(pipe.unet.dtype), + image_camera_embeddings=image_camera_embeddings.to( + pipe.unet.dtype + ), + image_latents=self.image_latents.repeat(batch_size, 1, 1, 1).to( + pipe.unet.dtype + ), + cross_attention_kwargs={"scale": 0.0} + if self.vsd_share_model + else None, + output_type="latent", + ).images.to(latents.dtype) + with torch.no_grad(): + rgb_multistep_orig = self.vae_decode( + self.pipe.vae, latents_multistep_orig + ) + guidance_out.update( + { + "latents_multistep_orig": latents_multistep_orig, + "rgb_multistep_orig": rgb_multistep_orig.permute(0, 2, 3, 1), + } + ) + + if self.cfg.guidance_type == "vsd": + guidance_out.update( + { + "loss_train_phi": loss_train_phi, + } + ) + + return guidance_out + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # clip grad for stable training as demonstrated in + # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation + # http://arxiv.org/abs/2303.15413 + if self.cfg.grad_clip is not None: + self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step) + + self.min_step = int( + self.num_train_timesteps * C(self.cfg.min_step_percent, epoch, global_step) + ) + self.max_step = int( + self.num_train_timesteps * C(self.cfg.max_step_percent, epoch, global_step) + ) \ No newline at end of file diff --git a/threestudio/models/isosurface.py b/threestudio/models/isosurface.py new file mode 100644 index 0000000..3a7f149 --- /dev/null +++ b/threestudio/models/isosurface.py @@ -0,0 +1,253 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.mesh import Mesh +from threestudio.utils.typing import * + + +class IsosurfaceHelper(nn.Module): + points_range: Tuple[float, float] = (0, 1) + + @property + def grid_vertices(self) -> Float[Tensor, "N 3"]: + raise NotImplementedError + + +class MarchingCubeCPUHelper(IsosurfaceHelper): + def __init__(self, resolution: int) -> None: + super().__init__() + self.resolution = resolution + import mcubes + + self.mc_func: Callable = mcubes.marching_cubes + self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None + self._dummy: Float[Tensor, "..."] + self.register_buffer( + "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False + ) + + @property + def grid_vertices(self) -> Float[Tensor, "N3 3"]: + if self._grid_vertices is None: + # keep the vertices on CPU so that we can support very large resolution + x, y, z = ( + torch.linspace(*self.points_range, self.resolution), + torch.linspace(*self.points_range, self.resolution), + torch.linspace(*self.points_range, self.resolution), + ) + x, y, z = torch.meshgrid(x, y, z, indexing="ij") + verts = torch.cat( + [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1 + ).reshape(-1, 3) + self._grid_vertices = verts + return self._grid_vertices + + def forward( + self, + level: Float[Tensor, "N3 1"], + deformation: Optional[Float[Tensor, "N3 3"]] = None, + ) -> Mesh: + if deformation is not None: + threestudio.warn( + f"{self.__class__.__name__} does not support deformation. Ignoring." + ) + level = -level.view(self.resolution, self.resolution, self.resolution) + v_pos, t_pos_idx = self.mc_func( + level.detach().cpu().numpy(), 0.0 + ) # transform to numpy + v_pos, t_pos_idx = ( + torch.from_numpy(v_pos).float().to(self._dummy.device), + torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device), + ) # transform back to torch tensor on CUDA + v_pos = v_pos / (self.resolution - 1.0) + return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx) + + +class MarchingTetrahedraHelper(IsosurfaceHelper): + def __init__(self, resolution: int, tets_path: str): + super().__init__() + self.resolution = resolution + self.tets_path = tets_path + + self.triangle_table: Float[Tensor, "..."] + self.register_buffer( + "triangle_table", + torch.as_tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1], + ], + dtype=torch.long, + ), + persistent=False, + ) + self.num_triangles_table: Integer[Tensor, "..."] + self.register_buffer( + "num_triangles_table", + torch.as_tensor( + [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long + ), + persistent=False, + ) + self.base_tet_edges: Integer[Tensor, "..."] + self.register_buffer( + "base_tet_edges", + torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long), + persistent=False, + ) + + tets = np.load(self.tets_path) + self._grid_vertices: Float[Tensor, "..."] + self.register_buffer( + "_grid_vertices", + torch.from_numpy(tets["vertices"]).float(), + persistent=False, + ) + self.indices: Integer[Tensor, "..."] + self.register_buffer( + "indices", torch.from_numpy(tets["indices"]).long(), persistent=False + ) + + self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None + + def normalize_grid_deformation( + self, grid_vertex_offsets: Float[Tensor, "Nv 3"] + ) -> Float[Tensor, "Nv 3"]: + return ( + (self.points_range[1] - self.points_range[0]) + / (self.resolution) # half tet size is approximately 1 / self.resolution + * torch.tanh(grid_vertex_offsets) + ) # FIXME: hard-coded activation + + @property + def grid_vertices(self) -> Float[Tensor, "Nv 3"]: + return self._grid_vertices + + @property + def all_edges(self) -> Integer[Tensor, "Ne 2"]: + if self._all_edges is None: + # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation) + edges = torch.tensor( + [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], + dtype=torch.long, + device=self.indices.device, + ) + _all_edges = self.indices[:, edges].reshape(-1, 2) + _all_edges_sorted = torch.sort(_all_edges, dim=1)[0] + _all_edges = torch.unique(_all_edges_sorted, dim=0) + self._all_edges = _all_edges + return self._all_edges + + def sort_edges(self, edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + + return torch.stack([a, b], -1) + + def _forward(self, pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2) + all_edges = self.sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = ( + torch.ones( + (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device + ) + * -1 + ) + mapping[mask_edges] = torch.arange( + mask_edges.sum(), dtype=torch.long, device=pos_nx3.device + ) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device)) + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = self.num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 1]][:, :3], + ).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 2]][:, :6], + ).reshape(-1, 3), + ), + dim=0, + ) + + return verts, faces + + def forward( + self, + level: Float[Tensor, "N3 1"], + deformation: Optional[Float[Tensor, "N3 3"]] = None, + ) -> Mesh: + if deformation is not None: + grid_vertices = self.grid_vertices + self.normalize_grid_deformation( + deformation + ) + else: + grid_vertices = self.grid_vertices + + v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices) + + mesh = Mesh( + v_pos=v_pos, + t_pos_idx=t_pos_idx, + # extras + grid_vertices=grid_vertices, + tet_edges=self.all_edges, + grid_level=level, + grid_deformation=deformation, + ) + + return mesh diff --git a/threestudio/models/materials/__init__.py b/threestudio/models/materials/__init__.py new file mode 100644 index 0000000..85d50ba --- /dev/null +++ b/threestudio/models/materials/__init__.py @@ -0,0 +1,9 @@ +from . import ( + base, + diffuse_with_point_light_material, + hybrid_rgb_latent_material, + neural_radiance_material, + no_material, + pbr_material, + sd_latent_adapter_material, +) diff --git a/threestudio/models/materials/base.py b/threestudio/models/materials/base.py new file mode 100644 index 0000000..9df8f30 --- /dev/null +++ b/threestudio/models/materials/base.py @@ -0,0 +1,29 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.utils.base import BaseModule +from threestudio.utils.typing import * + + +class BaseMaterial(BaseModule): + @dataclass + class Config(BaseModule.Config): + pass + + cfg: Config + requires_normal: bool = False + requires_tangent: bool = False + + def configure(self): + pass + + def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]: + raise NotImplementedError + + def export(self, *args, **kwargs) -> Dict[str, Any]: + return {} diff --git a/threestudio/models/materials/diffuse_with_point_light_material.py b/threestudio/models/materials/diffuse_with_point_light_material.py new file mode 100644 index 0000000..abf0671 --- /dev/null +++ b/threestudio/models/materials/diffuse_with_point_light_material.py @@ -0,0 +1,120 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.materials.base import BaseMaterial +from threestudio.utils.ops import dot, get_activation +from threestudio.utils.typing import * + + +@threestudio.register("diffuse-with-point-light-material") +class DiffuseWithPointLightMaterial(BaseMaterial): + @dataclass + class Config(BaseMaterial.Config): + ambient_light_color: Tuple[float, float, float] = (0.1, 0.1, 0.1) + diffuse_light_color: Tuple[float, float, float] = (0.9, 0.9, 0.9) + ambient_only_steps: int = 1000 + diffuse_prob: float = 0.75 + textureless_prob: float = 0.5 + albedo_activation: str = "sigmoid" + soft_shading: bool = False + + cfg: Config + + def configure(self) -> None: + self.requires_normal = True + + self.ambient_light_color: Float[Tensor, "3"] + self.register_buffer( + "ambient_light_color", + torch.as_tensor(self.cfg.ambient_light_color, dtype=torch.float32), + ) + self.diffuse_light_color: Float[Tensor, "3"] + self.register_buffer( + "diffuse_light_color", + torch.as_tensor(self.cfg.diffuse_light_color, dtype=torch.float32), + ) + self.ambient_only = False + + def forward( + self, + features: Float[Tensor, "B ... Nf"], + positions: Float[Tensor, "B ... 3"], + shading_normal: Float[Tensor, "B ... 3"], + light_positions: Float[Tensor, "B ... 3"], + ambient_ratio: Optional[float] = None, + shading: Optional[str] = None, + **kwargs, + ) -> Float[Tensor, "B ... 3"]: + albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]) + + if ambient_ratio is not None: + # if ambient ratio is specified, use it + diffuse_light_color = (1 - ambient_ratio) * torch.ones_like( + self.diffuse_light_color + ) + ambient_light_color = ambient_ratio * torch.ones_like( + self.ambient_light_color + ) + elif self.training and self.cfg.soft_shading: + # otherwise if in training and soft shading is enabled, random a ambient ratio + diffuse_light_color = torch.full_like( + self.diffuse_light_color, random.random() + ) + ambient_light_color = 1.0 - diffuse_light_color + else: + # otherwise use the default fixed values + diffuse_light_color = self.diffuse_light_color + ambient_light_color = self.ambient_light_color + + light_directions: Float[Tensor, "B ... 3"] = F.normalize( + light_positions - positions, dim=-1 + ) + diffuse_light: Float[Tensor, "B ... 3"] = ( + dot(shading_normal, light_directions).clamp(min=0.0) * diffuse_light_color + ) + textureless_color = diffuse_light + ambient_light_color + # clamp albedo to [0, 1] to compute shading + color = albedo.clamp(0.0, 1.0) * textureless_color + + if shading is None: + if self.training: + # adopt the same type of augmentation for the whole batch + if self.ambient_only or random.random() > self.cfg.diffuse_prob: + shading = "albedo" + elif random.random() < self.cfg.textureless_prob: + shading = "textureless" + else: + shading = "diffuse" + else: + if self.ambient_only: + shading = "albedo" + else: + # return shaded color by default in evaluation + shading = "diffuse" + + # multiply by 0 to prevent checking for unused parameters in DDP + if shading == "albedo": + return albedo + textureless_color * 0 + elif shading == "textureless": + return albedo * 0 + textureless_color + elif shading == "diffuse": + return color + else: + raise ValueError(f"Unknown shading type {shading}") + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + if global_step < self.cfg.ambient_only_steps: + self.ambient_only = True + else: + self.ambient_only = False + + def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: + albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]).clamp( + 0.0, 1.0 + ) + return {"albedo": albedo} diff --git a/threestudio/models/materials/hybrid_rgb_latent_material.py b/threestudio/models/materials/hybrid_rgb_latent_material.py new file mode 100644 index 0000000..f5f2c51 --- /dev/null +++ b/threestudio/models/materials/hybrid_rgb_latent_material.py @@ -0,0 +1,36 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.ops import dot, get_activation +from threestudio.utils.typing import * + + +@threestudio.register("hybrid-rgb-latent-material") +class HybridRGBLatentMaterial(BaseMaterial): + @dataclass + class Config(BaseMaterial.Config): + n_output_dims: int = 3 + color_activation: str = "sigmoid" + requires_normal: bool = True + + cfg: Config + + def configure(self) -> None: + self.requires_normal = self.cfg.requires_normal + + def forward( + self, features: Float[Tensor, "B ... Nf"], **kwargs + ) -> Float[Tensor, "B ... Nc"]: + assert ( + features.shape[-1] == self.cfg.n_output_dims + ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." + color = features + color[..., :3] = get_activation(self.cfg.color_activation)(color[..., :3]) + return color diff --git a/threestudio/models/materials/neural_radiance_material.py b/threestudio/models/materials/neural_radiance_material.py new file mode 100644 index 0000000..c9dcc50 --- /dev/null +++ b/threestudio/models/materials/neural_radiance_material.py @@ -0,0 +1,54 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.ops import dot, get_activation +from threestudio.utils.typing import * + + +@threestudio.register("neural-radiance-material") +class NeuralRadianceMaterial(BaseMaterial): + @dataclass + class Config(BaseMaterial.Config): + input_feature_dims: int = 8 + color_activation: str = "sigmoid" + dir_encoding_config: dict = field( + default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} + ) + mlp_network_config: dict = field( + default_factory=lambda: { + "otype": "FullyFusedMLP", + "activation": "ReLU", + "n_neurons": 16, + "n_hidden_layers": 2, + } + ) + + cfg: Config + + def configure(self) -> None: + self.encoding = get_encoding(3, self.cfg.dir_encoding_config) + self.n_input_dims = self.cfg.input_feature_dims + self.encoding.n_output_dims # type: ignore + self.network = get_mlp(self.n_input_dims, 3, self.cfg.mlp_network_config) + + def forward( + self, + features: Float[Tensor, "*B Nf"], + viewdirs: Float[Tensor, "*B 3"], + **kwargs, + ) -> Float[Tensor, "*B 3"]: + # viewdirs and normals must be normalized before passing to this function + viewdirs = (viewdirs + 1.0) / 2.0 # (-1, 1) => (0, 1) + viewdirs_embd = self.encoding(viewdirs.view(-1, 3)) + network_inp = torch.cat( + [features.view(-1, features.shape[-1]), viewdirs_embd], dim=-1 + ) + color = self.network(network_inp).view(*features.shape[:-1], 3) + color = get_activation(self.cfg.color_activation)(color) + return color diff --git a/threestudio/models/materials/no_material.py b/threestudio/models/materials/no_material.py new file mode 100644 index 0000000..402a951 --- /dev/null +++ b/threestudio/models/materials/no_material.py @@ -0,0 +1,63 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.networks import get_encoding, get_mlp +from threestudio.utils.ops import dot, get_activation +from threestudio.utils.typing import * + + +@threestudio.register("no-material") +class NoMaterial(BaseMaterial): + @dataclass + class Config(BaseMaterial.Config): + n_output_dims: int = 3 + color_activation: str = "sigmoid" + input_feature_dims: Optional[int] = None + mlp_network_config: Optional[dict] = None + requires_normal: bool = False + + cfg: Config + + def configure(self) -> None: + self.use_network = False + if ( + self.cfg.input_feature_dims is not None + and self.cfg.mlp_network_config is not None + ): + self.network = get_mlp( + self.cfg.input_feature_dims, + self.cfg.n_output_dims, + self.cfg.mlp_network_config, + ) + self.use_network = True + self.requires_normal = self.cfg.requires_normal + + def forward( + self, features: Float[Tensor, "B ... Nf"], **kwargs + ) -> Float[Tensor, "B ... Nc"]: + if not self.use_network: + assert ( + features.shape[-1] == self.cfg.n_output_dims + ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." + color = get_activation(self.cfg.color_activation)(features) + else: + color = self.network(features.view(-1, features.shape[-1])).view( + *features.shape[:-1], self.cfg.n_output_dims + ) + color = get_activation(self.cfg.color_activation)(color) + return color + + def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: + color = self(features, **kwargs).clamp(0, 1) + assert color.shape[-1] >= 3, "Output color must have at least 3 channels" + if color.shape[-1] > 3: + threestudio.warn( + "Output color has >3 channels, treating the first 3 as RGB" + ) + return {"albedo": color[..., :3]} diff --git a/threestudio/models/materials/pbr_material.py b/threestudio/models/materials/pbr_material.py new file mode 100644 index 0000000..c81f67b --- /dev/null +++ b/threestudio/models/materials/pbr_material.py @@ -0,0 +1,143 @@ +import random +from dataclasses import dataclass, field + +import envlight +import numpy as np +import nvdiffrast.torch as dr +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.materials.base import BaseMaterial +from threestudio.utils.ops import get_activation +from threestudio.utils.typing import * + + +@threestudio.register("pbr-material") +class PBRMaterial(BaseMaterial): + @dataclass + class Config(BaseMaterial.Config): + material_activation: str = "sigmoid" + environment_texture: str = "load/lights/mud_road_puresky_1k.hdr" + environment_scale: float = 2.0 + min_metallic: float = 0.0 + max_metallic: float = 0.9 + min_roughness: float = 0.08 + max_roughness: float = 0.9 + use_bump: bool = True + + cfg: Config + + def configure(self) -> None: + self.requires_normal = True + self.requires_tangent = self.cfg.use_bump + + self.light = envlight.EnvLight( + self.cfg.environment_texture, scale=self.cfg.environment_scale + ) + + FG_LUT = torch.from_numpy( + np.fromfile("load/lights/bsdf_256_256.bin", dtype=np.float32).reshape( + 1, 256, 256, 2 + ) + ) + self.register_buffer("FG_LUT", FG_LUT) + + def forward( + self, + features: Float[Tensor, "*B Nf"], + viewdirs: Float[Tensor, "*B 3"], + shading_normal: Float[Tensor, "B ... 3"], + tangent: Optional[Float[Tensor, "B ... 3"]] = None, + **kwargs, + ) -> Float[Tensor, "*B 3"]: + prefix_shape = features.shape[:-1] + + material: Float[Tensor, "*B Nf"] = get_activation(self.cfg.material_activation)( + features + ) + albedo = material[..., :3] + metallic = ( + material[..., 3:4] * (self.cfg.max_metallic - self.cfg.min_metallic) + + self.cfg.min_metallic + ) + roughness = ( + material[..., 4:5] * (self.cfg.max_roughness - self.cfg.min_roughness) + + self.cfg.min_roughness + ) + + if self.cfg.use_bump: + assert tangent is not None + # perturb_normal is a delta to the initialization [0, 0, 1] + perturb_normal = (material[..., 5:8] * 2 - 1) + torch.tensor( + [0, 0, 1], dtype=material.dtype, device=material.device + ) + perturb_normal = F.normalize(perturb_normal.clamp(-1, 1), dim=-1) + + # apply normal perturbation in tangent space + bitangent = F.normalize(torch.cross(tangent, shading_normal), dim=-1) + shading_normal = ( + tangent * perturb_normal[..., 0:1] + - bitangent * perturb_normal[..., 1:2] + + shading_normal * perturb_normal[..., 2:3] + ) + shading_normal = F.normalize(shading_normal, dim=-1) + + v = -viewdirs + n_dot_v = (shading_normal * v).sum(-1, keepdim=True) + reflective = n_dot_v * shading_normal * 2 - v + + diffuse_albedo = (1 - metallic) * albedo + + fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) + fg = dr.texture( + self.FG_LUT, + fg_uv.reshape(1, -1, 1, 2).contiguous(), + filter_mode="linear", + boundary_mode="clamp", + ).reshape(*prefix_shape, 2) + F0 = (1 - metallic) * 0.04 + metallic * albedo + specular_albedo = F0 * fg[:, 0:1] + fg[:, 1:2] + + diffuse_light = self.light(shading_normal) + specular_light = self.light(reflective, roughness) + + color = diffuse_albedo * diffuse_light + specular_albedo * specular_light + color = color.clamp(0.0, 1.0) + + return color + + def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: + material: Float[Tensor, "*N Nf"] = get_activation(self.cfg.material_activation)( + features + ) + albedo = material[..., :3] + metallic = ( + material[..., 3:4] * (self.cfg.max_metallic - self.cfg.min_metallic) + + self.cfg.min_metallic + ) + roughness = ( + material[..., 4:5] * (self.cfg.max_roughness - self.cfg.min_roughness) + + self.cfg.min_roughness + ) + + out = { + "albedo": albedo, + "metallic": metallic, + "roughness": roughness, + } + + if self.cfg.use_bump: + perturb_normal = (material[..., 5:8] * 2 - 1) + torch.tensor( + [0, 0, 1], dtype=material.dtype, device=material.device + ) + perturb_normal = F.normalize(perturb_normal.clamp(-1, 1), dim=-1) + perturb_normal = (perturb_normal + 1) / 2 + out.update( + { + "bump": perturb_normal, + } + ) + + return out diff --git a/threestudio/models/materials/sd_latent_adapter_material.py b/threestudio/models/materials/sd_latent_adapter_material.py new file mode 100644 index 0000000..046cabb --- /dev/null +++ b/threestudio/models/materials/sd_latent_adapter_material.py @@ -0,0 +1,42 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.materials.base import BaseMaterial +from threestudio.utils.typing import * + + +@threestudio.register("sd-latent-adapter-material") +class StableDiffusionLatentAdapterMaterial(BaseMaterial): + @dataclass + class Config(BaseMaterial.Config): + pass + + cfg: Config + + def configure(self) -> None: + adapter = nn.Parameter( + torch.as_tensor( + [ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + ) + ) + self.register_parameter("adapter", adapter) + + def forward( + self, features: Float[Tensor, "B ... 4"], **kwargs + ) -> Float[Tensor, "B ... 3"]: + assert features.shape[-1] == 4 + color = features @ self.adapter + color = (color + 1) / 2 + color = color.clamp(0.0, 1.0) + return color diff --git a/threestudio/models/mesh.py b/threestudio/models/mesh.py new file mode 100644 index 0000000..d7324f4 --- /dev/null +++ b/threestudio/models/mesh.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F + +import threestudio +from threestudio.utils.ops import dot +from threestudio.utils.typing import * + + +class Mesh: + def __init__( + self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs + ) -> None: + self.v_pos: Float[Tensor, "Nv 3"] = v_pos + self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx + self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None + self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None + self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None + self._edges: Optional[Integer[Tensor, "Ne 2"]] = None + self.extras: Dict[str, Any] = {} + for k, v in kwargs.items(): + self.add_extra(k, v) + + def add_extra(self, k, v) -> None: + self.extras[k] = v + + def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh: + if self.requires_grad: + threestudio.debug("Mesh is differentiable, not removing outliers") + return self + + # use trimesh to first split the mesh into connected components + # then remove the components with less than n_face_threshold faces + import trimesh + + # construct a trimesh object + mesh = trimesh.Trimesh( + vertices=self.v_pos.detach().cpu().numpy(), + faces=self.t_pos_idx.detach().cpu().numpy(), + ) + + # split the mesh into connected components + components = mesh.split(only_watertight=False) + # log the number of faces in each component + threestudio.debug( + "Mesh has {} components, with faces: {}".format( + len(components), [c.faces.shape[0] for c in components] + ) + ) + + n_faces_threshold: int + if isinstance(outlier_n_faces_threshold, float): + # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold + n_faces_threshold = int( + max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold + ) + else: + # set the threshold directly to outlier_n_faces_threshold + n_faces_threshold = outlier_n_faces_threshold + + # log the threshold + threestudio.debug( + "Removing components with less than {} faces".format(n_faces_threshold) + ) + + # remove the components with less than n_face_threshold faces + components = [c for c in components if c.faces.shape[0] >= n_faces_threshold] + + # log the number of faces in each component after removing outliers + threestudio.debug( + "Mesh has {} components after removing outliers, with faces: {}".format( + len(components), [c.faces.shape[0] for c in components] + ) + ) + # merge the components + mesh = trimesh.util.concatenate(components) + + # convert back to our mesh format + v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos) + t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx) + + clean_mesh = Mesh(v_pos, t_pos_idx) + # keep the extras unchanged + + if len(self.extras) > 0: + clean_mesh.extras = self.extras + threestudio.debug( + f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}" + ) + return clean_mesh + + @property + def requires_grad(self): + return self.v_pos.requires_grad + + @property + def v_nrm(self): + if self._v_nrm is None: + self._v_nrm = self._compute_vertex_normal() + return self._v_nrm + + @property + def v_tng(self): + if self._v_tng is None: + self._v_tng = self._compute_vertex_tangent() + return self._v_tng + + @property + def v_tex(self): + if self._v_tex is None: + self._v_tex, self._t_tex_idx = self._unwrap_uv() + return self._v_tex + + @property + def t_tex_idx(self): + if self._t_tex_idx is None: + self._v_tex, self._t_tex_idx = self._unwrap_uv() + return self._t_tex_idx + + @property + def v_rgb(self): + return self._v_rgb + + @property + def edges(self): + if self._edges is None: + self._edges = self._compute_edges() + return self._edges + + def _compute_vertex_normal(self): + i0 = self.t_pos_idx[:, 0] + i1 = self.t_pos_idx[:, 1] + i2 = self.t_pos_idx[:, 2] + + v0 = self.v_pos[i0, :] + v1 = self.v_pos[i1, :] + v2 = self.v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(self.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + def _compute_vertex_tangent(self): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0, 3): + pos[i] = self.v_pos[self.t_pos_idx[:, i]] + tex[i] = self.v_tex[self.t_tex_idx[:, i]] + # t_nrm_idx is always the same as t_pos_idx + vn_idx[i] = self.t_pos_idx[:, i] + + tangents = torch.zeros_like(self.v_nrm) + tansum = torch.zeros_like(self.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2] + denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1] + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where( + denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6) + ) + + # Update all 3 vertices + for i in range(0, 3): + idx = vn_idx[i][:, None].repeat(1, 3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + tansum.scatter_add_( + 0, idx, torch.ones_like(tang) + ) # tansum[n_i] = tansum[n_i] + 1 + tangents = tangents / tansum + + # Normalize and make sure tangent is perpendicular to normal + tangents = F.normalize(tangents, dim=1) + tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return tangents + + def _unwrap_uv( + self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {} + ): + threestudio.info("Using xatlas to perform UV unwrapping, may take a while ...") + + import xatlas + + atlas = xatlas.Atlas() + atlas.add_mesh( + self.v_pos.detach().cpu().numpy(), + self.t_pos_idx.cpu().numpy(), + ) + co = xatlas.ChartOptions() + po = xatlas.PackOptions() + for k, v in xatlas_chart_options.items(): + setattr(co, k, v) + for k, v in xatlas_pack_options.items(): + setattr(po, k, v) + atlas.generate(co, po) + vmapping, indices, uvs = atlas.get_mesh(0) + vmapping = ( + torch.from_numpy( + vmapping.astype(np.uint64, casting="same_kind").view(np.int64) + ) + .to(self.v_pos.device) + .long() + ) + uvs = torch.from_numpy(uvs).to(self.v_pos.device).float() + indices = ( + torch.from_numpy( + indices.astype(np.uint64, casting="same_kind").view(np.int64) + ) + .to(self.v_pos.device) + .long() + ) + return uvs, indices + + def unwrap_uv( + self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {} + ): + self._v_tex, self._t_tex_idx = self._unwrap_uv( + xatlas_chart_options, xatlas_pack_options + ) + + def set_vertex_color(self, v_rgb): + assert v_rgb.shape[0] == self.v_pos.shape[0] + self._v_rgb = v_rgb + + def _compute_edges(self): + # Compute edges + edges = torch.cat( + [ + self.t_pos_idx[:, [0, 1]], + self.t_pos_idx[:, [1, 2]], + self.t_pos_idx[:, [2, 0]], + ], + dim=0, + ) + edges = edges.sort()[0] + edges = torch.unique(edges, dim=0) + return edges + + def normal_consistency(self) -> Float[Tensor, ""]: + edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges] + nc = ( + 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1) + ).mean() + return nc + + def _laplacian_uniform(self): + # from stable-dreamfusion + # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224 + verts, faces = self.v_pos, self.t_pos_idx + + V = verts.shape[0] + F = faces.shape[0] + + # Neighbor indices + ii = faces[:, [1, 2, 0]].flatten() + jj = faces[:, [2, 0, 1]].flatten() + adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique( + dim=1 + ) + adj_values = torch.ones(adj.shape[1]).to(verts) + + # Diagonal indices + diag_idx = adj[0] + + # Build the sparse matrix + idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1) + values = torch.cat((-adj_values, adj_values)) + + # The coalesce operation sums the duplicate indices, resulting in the + # correct diagonal + return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce() + + def laplacian(self) -> Float[Tensor, ""]: + with torch.no_grad(): + L = self._laplacian_uniform() + loss = L.mm(self.v_pos) + loss = loss.norm(dim=1) + loss = loss.mean() + return loss \ No newline at end of file diff --git a/threestudio/models/networks.py b/threestudio/models/networks.py new file mode 100644 index 0000000..d86df23 --- /dev/null +++ b/threestudio/models/networks.py @@ -0,0 +1,411 @@ +import math + +import tinycudann as tcnn +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.utils.base import Updateable +from threestudio.utils.config import config_to_primitive +from threestudio.utils.misc import get_rank +from threestudio.utils.ops import get_activation +from threestudio.utils.typing import * + + +class ProgressiveBandFrequency(nn.Module, Updateable): + def __init__(self, in_channels: int, config: dict): + super().__init__() + self.N_freqs = config["n_frequencies"] + self.in_channels, self.n_input_dims = in_channels, in_channels + self.funcs = [torch.sin, torch.cos] + self.freq_bands = 2 ** torch.linspace(0, self.N_freqs - 1, self.N_freqs) + self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs) + self.n_masking_step = config.get("n_masking_step", 0) + self.update_step( + None, None + ) # mask should be updated at the beginning each step + + def forward(self, x): + out = [] + for freq, mask in zip(self.freq_bands, self.mask): + for func in self.funcs: + out += [func(freq * x) * mask] + return torch.cat(out, -1) + + def update_step(self, epoch, global_step, on_load_weights=False): + if self.n_masking_step <= 0 or global_step is None: + self.mask = torch.ones(self.N_freqs, dtype=torch.float32) + else: + self.mask = ( + 1.0 + - torch.cos( + math.pi + * ( + global_step / self.n_masking_step * self.N_freqs + - torch.arange(0, self.N_freqs) + ).clamp(0, 1) + ) + ) / 2.0 + threestudio.debug( + f"Update mask: {global_step}/{self.n_masking_step} {self.mask}" + ) + + +class TCNNEncoding(nn.Module): + def __init__(self, in_channels, config, dtype=torch.float32) -> None: + super().__init__() + self.n_input_dims = in_channels + with torch.cuda.device(get_rank()): + self.encoding = tcnn.Encoding(in_channels, config, dtype=dtype) + self.n_output_dims = self.encoding.n_output_dims + + def forward(self, x): + return self.encoding(x) + + +# 4D implicit decomposition of space and time (4D-fy) +class TCNNEncodingSpatialTime(nn.Module): + def __init__( + self, in_channels, config, dtype=torch.float32, init_time_zero=False + ) -> None: + super().__init__() + self.n_input_dims = in_channels + config["otype"] = "HashGrid" + self.num_frames = 1 # config["num_frames"] + self.static = config["static"] + self.cfg = config_to_primitive(config) + self.cfg_time = self.cfg + self.n_key_frames = config.get("n_key_frames", 1) + with torch.cuda.device(get_rank()): + self.encoding = tcnn.Encoding(self.n_input_dims, self.cfg, dtype=dtype) + self.encoding_time = tcnn.Encoding( + self.n_input_dims + 1, self.cfg_time, dtype=dtype + ) + self.n_output_dims = self.encoding.n_output_dims + self.frame_time = None + if self.static: + self.set_temp_param_grad(requires_grad=False) + self.use_key_frame = config.get("use_key_frame", False) + self.is_video = True + self.update_occ_grid = False + + def set_temp_param_grad(self, requires_grad=False): + self.set_param_grad(self.encoding_time, requires_grad=requires_grad) + + def set_param_grad(self, param_list, requires_grad=False): + if isinstance(param_list, nn.Parameter): + param_list.requires_grad = requires_grad + else: + for param in param_list.parameters(): + param.requires_grad = requires_grad + + def forward(self, x): + # TODO frame_time only supports batch_size == 1 cases + if self.update_occ_grid and not isinstance(self.frame_time, float): + frame_time = self.frame_time + else: + if (self.static or not self.training) and self.frame_time is None: + frame_time = torch.zeros( + (self.num_frames, 1), device=x.device, dtype=x.dtype + ).expand(x.shape[0], 1) + else: + if self.frame_time is None: + frame_time = 0.0 + else: + frame_time = self.frame_time + frame_time = ( + torch.ones((self.num_frames, 1), device=x.device, dtype=x.dtype) + * frame_time + ).expand(x.shape[0], 1) + frame_time = frame_time.view(-1, 1) + enc_space = self.encoding(x) + x_frame_time = torch.cat((x, frame_time), 1) + enc_space_time = self.encoding_time(x_frame_time) + enc = enc_space + enc_space_time + return enc + + +class ProgressiveBandHashGrid(nn.Module, Updateable): + def __init__(self, in_channels, config, dtype=torch.float32): + super().__init__() + self.n_input_dims = in_channels + encoding_config = config.copy() + encoding_config["otype"] = "Grid" + encoding_config["type"] = "Hash" + with torch.cuda.device(get_rank()): + self.encoding = tcnn.Encoding(in_channels, encoding_config, dtype=dtype) + self.n_output_dims = self.encoding.n_output_dims + self.n_level = config["n_levels"] + self.n_features_per_level = config["n_features_per_level"] + self.start_level, self.start_step, self.update_steps = ( + config["start_level"], + config["start_step"], + config["update_steps"], + ) + self.current_level = self.start_level + self.mask = torch.zeros( + self.n_level * self.n_features_per_level, + dtype=torch.float32, + device=get_rank(), + ) + + def forward(self, x): + enc = self.encoding(x) + enc = enc * self.mask + return enc + + def update_step(self, epoch, global_step, on_load_weights=False): + current_level = min( + self.start_level + + max(global_step - self.start_step, 0) // self.update_steps, + self.n_level, + ) + if current_level > self.current_level: + threestudio.debug(f"Update current level to {current_level}") + self.current_level = current_level + self.mask[: self.current_level * self.n_features_per_level] = 1.0 + + +class CompositeEncoding(nn.Module, Updateable): + def __init__(self, encoding, include_xyz=False, xyz_scale=2.0, xyz_offset=-1.0): + super(CompositeEncoding, self).__init__() + self.encoding = encoding + self.include_xyz, self.xyz_scale, self.xyz_offset = ( + include_xyz, + xyz_scale, + xyz_offset, + ) + self.n_output_dims = ( + int(self.include_xyz) * self.encoding.n_input_dims + + self.encoding.n_output_dims + ) + + def forward(self, x, *args): + return ( + self.encoding(x, *args) + if not self.include_xyz + else torch.cat( + [x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1 + ) + ) + + +def get_encoding(n_input_dims: int, config) -> nn.Module: + # input suppose to be range [0, 1] + encoding: nn.Module + if config.otype == "ProgressiveBandFrequency": + encoding = ProgressiveBandFrequency(n_input_dims, config_to_primitive(config)) + elif config.otype == "ProgressiveBandHashGrid": + encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config)) + elif config.otype == "HashGridSpatialTime": + encoding = TCNNEncodingSpatialTime(n_input_dims, config) # 4D-fy encoding + else: + encoding = TCNNEncoding(n_input_dims, config_to_primitive(config)) + encoding = CompositeEncoding( + encoding, + include_xyz=config.get("include_xyz", False), + xyz_scale=2.0, + xyz_offset=-1.0, + ) # FIXME: hard coded + return encoding + + +class VanillaMLP(nn.Module): + def __init__(self, dim_in: int, dim_out: int, config: dict): + super().__init__() + self.n_neurons, self.n_hidden_layers = ( + config["n_neurons"], + config["n_hidden_layers"], + ) + layers = [ + self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), + self.make_activation(), + ] + for i in range(self.n_hidden_layers - 1): + layers += [ + self.make_linear( + self.n_neurons, self.n_neurons, is_first=False, is_last=False + ), + self.make_activation(), + ] + layers += [ + self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True) + ] + self.layers = nn.Sequential(*layers) + self.output_activation = get_activation(config.get("output_activation", None)) + + def forward(self, x): + # disable autocast + # strange that the parameters will have empty gradients if autocast is enabled in AMP + with torch.cuda.amp.autocast(enabled=False): + x = self.layers(x) + x = self.output_activation(x) + return x + + def make_linear(self, dim_in, dim_out, is_first, is_last): + layer = nn.Linear(dim_in, dim_out, bias=False) + return layer + + def make_activation(self): + return nn.ReLU(inplace=True) + + +class SphereInitVanillaMLP(nn.Module): + def __init__(self, dim_in, dim_out, config): + super().__init__() + self.n_neurons, self.n_hidden_layers = ( + config["n_neurons"], + config["n_hidden_layers"], + ) + self.sphere_init, self.weight_norm = True, True + self.sphere_init_radius = config["sphere_init_radius"] + self.sphere_init_inside_out = config["inside_out"] + + self.layers = [ + self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), + self.make_activation(), + ] + for i in range(self.n_hidden_layers - 1): + self.layers += [ + self.make_linear( + self.n_neurons, self.n_neurons, is_first=False, is_last=False + ), + self.make_activation(), + ] + self.layers += [ + self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True) + ] + self.layers = nn.Sequential(*self.layers) + self.output_activation = get_activation(config.get("output_activation", None)) + + def forward(self, x): + # disable autocast + # strange that the parameters will have empty gradients if autocast is enabled in AMP + with torch.cuda.amp.autocast(enabled=False): + x = self.layers(x) + x = self.output_activation(x) + return x + + def make_linear(self, dim_in, dim_out, is_first, is_last): + layer = nn.Linear(dim_in, dim_out, bias=True) + + if is_last: + if not self.sphere_init_inside_out: + torch.nn.init.constant_(layer.bias, -self.sphere_init_radius) + torch.nn.init.normal_( + layer.weight, + mean=math.sqrt(math.pi) / math.sqrt(dim_in), + std=0.0001, + ) + else: + torch.nn.init.constant_(layer.bias, self.sphere_init_radius) + torch.nn.init.normal_( + layer.weight, + mean=-math.sqrt(math.pi) / math.sqrt(dim_in), + std=0.0001, + ) + elif is_first: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.constant_(layer.weight[:, 3:], 0.0) + torch.nn.init.normal_( + layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out) + ) + else: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out)) + + if self.weight_norm: + layer = nn.utils.weight_norm(layer) + return layer + + def make_activation(self): + return nn.Softplus(beta=100) + + +class TCNNNetwork(nn.Module): + def __init__(self, dim_in: int, dim_out: int, config: dict) -> None: + super().__init__() + with torch.cuda.device(get_rank()): + self.network = tcnn.Network(dim_in, dim_out, config) + + def forward(self, x): + return self.network(x).float() # transform to float32 + + +def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module: + network: nn.Module + if config.otype == "VanillaMLP": + network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config)) + elif config.otype == "SphereInitVanillaMLP": + network = SphereInitVanillaMLP( + n_input_dims, n_output_dims, config_to_primitive(config) + ) + else: + assert ( + config.get("sphere_init", False) is False + ), "sphere_init=True only supported by VanillaMLP" + network = TCNNNetwork(n_input_dims, n_output_dims, config_to_primitive(config)) + return network + + +class NetworkWithInputEncoding(nn.Module, Updateable): + def __init__(self, encoding, network): + super().__init__() + self.encoding, self.network = encoding, network + + def forward(self, x): + return self.network(self.encoding(x)) + + +class TCNNNetworkWithInputEncoding(nn.Module): + def __init__( + self, + n_input_dims: int, + n_output_dims: int, + encoding_config: dict, + network_config: dict, + ) -> None: + super().__init__() + with torch.cuda.device(get_rank()): + self.network_with_input_encoding = tcnn.NetworkWithInputEncoding( + n_input_dims=n_input_dims, + n_output_dims=n_output_dims, + encoding_config=encoding_config, + network_config=network_config, + ) + + def forward(self, x): + return self.network_with_input_encoding(x).float() # transform to float32 + + +def create_network_with_input_encoding( + n_input_dims: int, n_output_dims: int, encoding_config, network_config +) -> nn.Module: + # input suppose to be range [0, 1] + network_with_input_encoding: nn.Module + if encoding_config.otype in [ + "VanillaFrequency", + "ProgressiveBandHashGrid", + ] or network_config.otype in ["VanillaMLP", "SphereInitVanillaMLP"]: + encoding = get_encoding(n_input_dims, encoding_config) + network = get_mlp(encoding.n_output_dims, n_output_dims, network_config) + network_with_input_encoding = NetworkWithInputEncoding(encoding, network) + else: + network_with_input_encoding = TCNNNetworkWithInputEncoding( + n_input_dims=n_input_dims, + n_output_dims=n_output_dims, + encoding_config=config_to_primitive(encoding_config), + network_config=config_to_primitive(network_config), + ) + return network_with_input_encoding + + +class ToDTypeWrapper(nn.Module): + def __init__(self, module: nn.Module, dtype: torch.dtype): + super().__init__() + self.module = module + self.dtype = dtype + + def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]: + return self.module(x).to(self.dtype) \ No newline at end of file diff --git a/threestudio/models/prompt_processors/__init__.py b/threestudio/models/prompt_processors/__init__.py new file mode 100644 index 0000000..86294eb --- /dev/null +++ b/threestudio/models/prompt_processors/__init__.py @@ -0,0 +1,7 @@ +from . import ( + base, + deepfloyd_prompt_processor, + dummy_prompt_processor, + stable_diffusion_prompt_processor, + clip_prompt_processor, +) \ No newline at end of file diff --git a/threestudio/models/prompt_processors/base.py b/threestudio/models/prompt_processors/base.py new file mode 100644 index 0000000..2fa1a2e --- /dev/null +++ b/threestudio/models/prompt_processors/base.py @@ -0,0 +1,517 @@ +import json +import os +from dataclasses import dataclass, field + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from transformers import AutoTokenizer, BertForMaskedLM + +import threestudio +from threestudio.utils.base import BaseObject +from threestudio.utils.misc import barrier, cleanup, get_rank +from threestudio.utils.ops import shifted_cosine_decay, shifted_expotional_decay +from threestudio.utils.typing import * + + +def hash_prompt(model: str, prompt: str) -> str: + import hashlib + + identifier = f"{model}-{prompt}" + return hashlib.md5(identifier.encode()).hexdigest() + + +@dataclass +class DirectionConfig: + name: str + prompt: Callable[[str], str] + negative_prompt: Callable[[str], str] + condition: Callable[ + [Float[Tensor, "B"], Float[Tensor, "B"], Float[Tensor, "B"]], + Float[Tensor, "B"], + ] + + +@dataclass +class PromptProcessorOutput: + text_embeddings: Float[Tensor, "N Nf"] + uncond_text_embeddings: Float[Tensor, "N Nf"] + text_embeddings_vd: Float[Tensor, "Nv N Nf"] + uncond_text_embeddings_vd: Float[Tensor, "Nv N Nf"] + directions: List[DirectionConfig] + direction2idx: Dict[str, int] + use_perp_neg: bool + perp_neg_f_sb: Tuple[float, float, float] + perp_neg_f_fsb: Tuple[float, float, float] + perp_neg_f_fs: Tuple[float, float, float] + perp_neg_f_sf: Tuple[float, float, float] + + def get_text_embeddings( + self, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + view_dependent_prompting: bool = True, + ) -> Float[Tensor, "BB N Nf"]: + batch_size = elevation.shape[0] + + if view_dependent_prompting: + # Get direction + direction_idx = torch.zeros_like(elevation, dtype=torch.long) + for d in self.directions: + direction_idx[ + d.condition(elevation, azimuth, camera_distances) + ] = self.direction2idx[d.name] + + # Get text embeddings + text_embeddings = self.text_embeddings_vd[direction_idx] # type: ignore + uncond_text_embeddings = self.uncond_text_embeddings_vd[direction_idx] # type: ignore + else: + text_embeddings = self.text_embeddings.expand(batch_size, -1, -1) # type: ignore + uncond_text_embeddings = self.uncond_text_embeddings.expand( # type: ignore + batch_size, -1, -1 + ) + + # IMPORTANT: we return (cond, uncond), which is in different order than other implementations! + return torch.cat([text_embeddings, uncond_text_embeddings], dim=0) + + def get_text_embeddings_perp_neg( + self, + elevation: Float[Tensor, "B"], + azimuth: Float[Tensor, "B"], + camera_distances: Float[Tensor, "B"], + view_dependent_prompting: bool = True, + ) -> Tuple[Float[Tensor, "BBBB N Nf"], Float[Tensor, "B 2"]]: + assert ( + view_dependent_prompting + ), "Perp-Neg only works with view-dependent prompting" + + batch_size = elevation.shape[0] + + direction_idx = torch.zeros_like(elevation, dtype=torch.long) + for d in self.directions: + direction_idx[ + d.condition(elevation, azimuth, camera_distances) + ] = self.direction2idx[d.name] + # 0 - side view + # 1 - front view + # 2 - back view + # 3 - overhead view + + pos_text_embeddings = [] + neg_text_embeddings = [] + neg_guidance_weights = [] + uncond_text_embeddings = [] + + side_emb = self.text_embeddings_vd[0] + front_emb = self.text_embeddings_vd[1] + back_emb = self.text_embeddings_vd[2] + overhead_emb = self.text_embeddings_vd[3] + + for idx, ele, azi, dis in zip( + direction_idx, elevation, azimuth, camera_distances + ): + azi = shift_azimuth_deg(azi) # to (-180, 180) + uncond_text_embeddings.append( + self.uncond_text_embeddings_vd[idx] + ) # should be "" + if idx.item() == 3: # overhead view + pos_text_embeddings.append(overhead_emb) # side view + # dummy + neg_text_embeddings += [ + self.uncond_text_embeddings_vd[idx], + self.uncond_text_embeddings_vd[idx], + ] + neg_guidance_weights += [0.0, 0.0] + else: # interpolating views + if torch.abs(azi) < 90: + # front-side interpolation + # 0 - complete side, 1 - complete front + r_inter = 1 - torch.abs(azi) / 90 + pos_text_embeddings.append( + r_inter * front_emb + (1 - r_inter) * side_emb + ) + neg_text_embeddings += [front_emb, side_emb] + neg_guidance_weights += [ + -shifted_expotional_decay(*self.perp_neg_f_fs, r_inter), + -shifted_expotional_decay(*self.perp_neg_f_sf, 1 - r_inter), + ] + else: + # side-back interpolation + # 0 - complete back, 1 - complete side + r_inter = 2.0 - torch.abs(azi) / 90 + pos_text_embeddings.append( + r_inter * side_emb + (1 - r_inter) * back_emb + ) + neg_text_embeddings += [side_emb, front_emb] + neg_guidance_weights += [ + -shifted_expotional_decay(*self.perp_neg_f_sb, r_inter), + -shifted_expotional_decay(*self.perp_neg_f_fsb, r_inter), + ] + + text_embeddings = torch.cat( + [ + torch.stack(pos_text_embeddings, dim=0), + torch.stack(uncond_text_embeddings, dim=0), + torch.stack(neg_text_embeddings, dim=0), + ], + dim=0, + ) + + return text_embeddings, torch.as_tensor( + neg_guidance_weights, device=elevation.device + ).reshape(batch_size, 2) + + +def shift_azimuth_deg(azimuth: Float[Tensor, "..."]) -> Float[Tensor, "..."]: + # shift azimuth angle (in degrees), to [-180, 180] + return (azimuth + 180) % 360 - 180 + + +class PromptProcessor(BaseObject): + @dataclass + class Config(BaseObject.Config): + prompt: str = "a hamburger" + + # manually assigned view-dependent prompts + prompt_front: Optional[str] = None + prompt_side: Optional[str] = None + prompt_back: Optional[str] = None + prompt_overhead: Optional[str] = None + + negative_prompt: str = "" + pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5" + overhead_threshold: float = 60.0 + front_threshold: float = 45.0 + back_threshold: float = 45.0 + view_dependent_prompt_front: bool = False + use_cache: bool = True + spawn: bool = True + + # perp neg + use_perp_neg: bool = False + # a*e(-b*r) + c + # a * e(-b) + c = 0 + perp_neg_f_sb: Tuple[float, float, float] = (1, 0.5, -0.606) + perp_neg_f_fsb: Tuple[float, float, float] = (1, 0.5, +0.967) + perp_neg_f_fs: Tuple[float, float, float] = ( + 4, + 0.5, + -2.426, + ) # f_fs(1) = 0, a, b > 0 + perp_neg_f_sf: Tuple[float, float, float] = (4, 0.5, -2.426) + + # prompt debiasing + use_prompt_debiasing: bool = False + pretrained_model_name_or_path_prompt_debiasing: str = "bert-base-uncased" + # index of words that can potentially be removed + prompt_debiasing_mask_ids: Optional[List[int]] = None + + cfg: Config + + @rank_zero_only + def configure_text_encoder(self) -> None: + raise NotImplementedError + + @rank_zero_only + def destroy_text_encoder(self) -> None: + raise NotImplementedError + + def configure(self) -> None: + self._cache_dir = ".threestudio_cache/text_embeddings" # FIXME: hard-coded path + + # view-dependent text embeddings + self.directions: List[DirectionConfig] + if self.cfg.view_dependent_prompt_front: + self.directions = [ + DirectionConfig( + "side", + lambda s: f"side view of {s}", + lambda s: s, + lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool), + ), + DirectionConfig( + "front", + lambda s: f"front view of {s}", + lambda s: s, + lambda ele, azi, dis: ( + shift_azimuth_deg(azi) > -self.cfg.front_threshold + ) + & (shift_azimuth_deg(azi) < self.cfg.front_threshold), + ), + DirectionConfig( + "back", + lambda s: f"backside view of {s}", + lambda s: s, + lambda ele, azi, dis: ( + shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold + ) + | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold), + ), + DirectionConfig( + "overhead", + lambda s: f"overhead view of {s}", + lambda s: s, + lambda ele, azi, dis: ele > self.cfg.overhead_threshold, + ), + ] + else: + self.directions = [ + DirectionConfig( + "side", + lambda s: f"{s}, side view", + lambda s: s, + lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool), + ), + DirectionConfig( + "front", + lambda s: f"{s}, front view", + lambda s: s, + lambda ele, azi, dis: ( + shift_azimuth_deg(azi) > -self.cfg.front_threshold + ) + & (shift_azimuth_deg(azi) < self.cfg.front_threshold), + ), + DirectionConfig( + "back", + lambda s: f"{s}, back view", + lambda s: s, + lambda ele, azi, dis: ( + shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold + ) + | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold), + ), + DirectionConfig( + "overhead", + lambda s: f"{s}, overhead view", + lambda s: s, + lambda ele, azi, dis: ele > self.cfg.overhead_threshold, + ), + ] + + self.direction2idx = {d.name: i for i, d in enumerate(self.directions)} + + with open(os.path.join("load/prompt_library.json"), "r") as f: + self.prompt_library = json.load(f) + # use provided prompt or find prompt in library + self.prompt = self.preprocess_prompt(self.cfg.prompt) + # use provided negative prompt + self.negative_prompt = self.cfg.negative_prompt + + threestudio.info( + f"Using prompt [{self.prompt}] and negative prompt [{self.negative_prompt}]" + ) + + # view-dependent prompting + if self.cfg.use_prompt_debiasing: + assert ( + self.cfg.prompt_side is None + and self.cfg.prompt_back is None + and self.cfg.prompt_overhead is None + ), "Do not manually assign prompt_side, prompt_back or prompt_overhead when using prompt debiasing" + prompts = self.get_debiased_prompt(self.prompt) + self.prompts_vd = [ + d.prompt(prompt) for d, prompt in zip(self.directions, prompts) + ] + else: + self.prompts_vd = [ + self.cfg.get(f"prompt_{d.name}", None) or d.prompt(self.prompt) # type: ignore + for d in self.directions + ] + + prompts_vd_display = " ".join( + [ + f"[{d.name}]:[{prompt}]" + for prompt, d in zip(self.prompts_vd, self.directions) + ] + ) + threestudio.info(f"Using view-dependent prompts {prompts_vd_display}") + + self.negative_prompts_vd = [ + d.negative_prompt(self.negative_prompt) for d in self.directions + ] + + self.prepare_text_embeddings() + self.load_text_embeddings() + + @staticmethod + def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device): + raise NotImplementedError + + @rank_zero_only + def prepare_text_embeddings(self): + os.makedirs(self._cache_dir, exist_ok=True) + + all_prompts = ( + [self.prompt] + + [self.negative_prompt] + + self.prompts_vd + + self.negative_prompts_vd + ) + prompts_to_process = [] + for prompt in all_prompts: + if self.cfg.use_cache: + # some text embeddings are already in cache + # do not process them + cache_path = os.path.join( + self._cache_dir, + f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt", + ) + if os.path.exists(cache_path): + threestudio.debug( + f"Text embeddings for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] are already in cache, skip processing." + ) + continue + prompts_to_process.append(prompt) + + if len(prompts_to_process) > 0: + if self.cfg.spawn: + ctx = mp.get_context("spawn") + subprocess = ctx.Process( + target=self.spawn_func, + args=( + self.cfg.pretrained_model_name_or_path, + prompts_to_process, + self._cache_dir, + self.device + ), + ) + subprocess.start() + subprocess.join() + else: + self.spawn_func( + self.cfg.pretrained_model_name_or_path, + prompts_to_process, + self._cache_dir, + self.device + ) + cleanup() + + def load_text_embeddings(self): + # synchronize, to ensure the text embeddings have been computed and saved to cache + barrier() + self.text_embeddings = self.load_from_cache(self.prompt)[None, ...] + self.uncond_text_embeddings = self.load_from_cache(self.negative_prompt)[ + None, ... + ] + self.text_embeddings_vd = torch.stack( + [self.load_from_cache(prompt) for prompt in self.prompts_vd], dim=0 + ) + self.uncond_text_embeddings_vd = torch.stack( + [self.load_from_cache(prompt) for prompt in self.negative_prompts_vd], dim=0 + ) + threestudio.debug(f"Loaded text embeddings.") + + def load_from_cache(self, prompt): + cache_path = os.path.join( + self._cache_dir, + f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt", + ) + if not os.path.exists(cache_path): + raise FileNotFoundError( + f"Text embedding file {cache_path} for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] not found." + ) + return torch.load(cache_path, map_location=self.device) + + def preprocess_prompt(self, prompt: str) -> str: + if prompt.startswith("lib:"): + # find matches in the library + candidate = None + keywords = prompt[4:].lower().split("_") + for prompt in self.prompt_library["dreamfusion"]: + if all([k in prompt.lower() for k in keywords]): + if candidate is not None: + raise ValueError( + f"Multiple prompts matched with keywords {keywords} in library" + ) + candidate = prompt + if candidate is None: + raise ValueError( + f"Cannot find prompt with keywords {keywords} in library" + ) + threestudio.info("Find matched prompt in library: " + candidate) + return candidate + else: + return prompt + + def get_text_embeddings( + self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] + ) -> Tuple[Float[Tensor, "B ..."], Float[Tensor, "B ..."]]: + raise NotImplementedError + + def get_debiased_prompt(self, prompt: str) -> List[str]: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + tokenizer = AutoTokenizer.from_pretrained( + self.cfg.pretrained_model_name_or_path_prompt_debiasing + ) + model = BertForMaskedLM.from_pretrained( + self.cfg.pretrained_model_name_or_path_prompt_debiasing + ) + + views = [d.name for d in self.directions] + view_ids = tokenizer(" ".join(views), return_tensors="pt").input_ids[0] + view_ids = view_ids[1:5] + + def modulate(prompt): + prompt_vd = f"This image is depicting a [MASK] view of {prompt}" + tokens = tokenizer( + prompt_vd, + padding="max_length", + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + mask_idx = torch.where(tokens.input_ids == tokenizer.mask_token_id)[1] + + logits = model(**tokens).logits + logits = F.softmax(logits[0, mask_idx], dim=-1) + logits = logits[0, view_ids] + probes = logits / logits.sum() + return probes + + prompts = [prompt.split(" ") for _ in range(4)] + full_probe = modulate(prompt) + n_words = len(prompt.split(" ")) + prompt_debiasing_mask_ids = ( + self.cfg.prompt_debiasing_mask_ids + if self.cfg.prompt_debiasing_mask_ids is not None + else list(range(n_words)) + ) + words_to_debias = [prompt.split(" ")[idx] for idx in prompt_debiasing_mask_ids] + threestudio.info(f"Words that can potentially be removed: {words_to_debias}") + for idx in prompt_debiasing_mask_ids: + words = prompt.split(" ") + prompt_ = " ".join(words[:idx] + words[(idx + 1) :]) + part_probe = modulate(prompt_) + + pmi = full_probe / torch.lerp(part_probe, full_probe, 0.5) + for i in range(pmi.shape[0]): + if pmi[i].item() < 0.95: + prompts[i][idx] = "" + + debiased_prompts = [" ".join([word for word in p if word]) for p in prompts] + for d, debiased_prompt in zip(views, debiased_prompts): + threestudio.info(f"Debiased prompt of the {d} view is [{debiased_prompt}]") + + del tokenizer, model + cleanup() + + return debiased_prompts + + def __call__(self) -> PromptProcessorOutput: + return PromptProcessorOutput( + text_embeddings=self.text_embeddings, + uncond_text_embeddings=self.uncond_text_embeddings, + text_embeddings_vd=self.text_embeddings_vd, + uncond_text_embeddings_vd=self.uncond_text_embeddings_vd, + directions=self.directions, + direction2idx=self.direction2idx, + use_perp_neg=self.cfg.use_perp_neg, + perp_neg_f_sb=self.cfg.perp_neg_f_sb, + perp_neg_f_fsb=self.cfg.perp_neg_f_fsb, + perp_neg_f_fs=self.cfg.perp_neg_f_fs, + perp_neg_f_sf=self.cfg.perp_neg_f_sf, + ) \ No newline at end of file diff --git a/threestudio/models/prompt_processors/clip_prompt_processor.py b/threestudio/models/prompt_processors/clip_prompt_processor.py new file mode 100644 index 0000000..acc1d1d --- /dev/null +++ b/threestudio/models/prompt_processors/clip_prompt_processor.py @@ -0,0 +1,44 @@ +import json +import os +from dataclasses import dataclass + +import clip +import torch +import torch +import torch.nn as nn + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt +from threestudio.utils.misc import cleanup +from threestudio.utils.typing import * + + +@threestudio.register("clip-prompt-processor") +class ClipPromptProcessor(PromptProcessor): + @dataclass + class Config(PromptProcessor.Config): + pass + + cfg: Config + + @staticmethod + def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + clip_model, _ = clip.load(pretrained_model_name_or_path, jit=False) + with torch.no_grad(): + tokens = clip.tokenize( + prompts, + ).to(device) + text_embeddings = clip_model.encode_text(tokens) + text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) + + for prompt, embedding in zip(prompts, text_embeddings): + torch.save( + embedding, + os.path.join( + cache_dir, + f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", + ), + ) + + del clip_model diff --git a/threestudio/models/prompt_processors/deepfloyd_prompt_processor.py b/threestudio/models/prompt_processors/deepfloyd_prompt_processor.py new file mode 100644 index 0000000..b4316ee --- /dev/null +++ b/threestudio/models/prompt_processors/deepfloyd_prompt_processor.py @@ -0,0 +1,98 @@ +import json +import os +from dataclasses import dataclass + +import torch +import torch.nn as nn +from diffusers import IFPipeline +from transformers import T5EncoderModel, T5Tokenizer + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt +from threestudio.utils.misc import cleanup +from threestudio.utils.typing import * + + +@threestudio.register("deep-floyd-prompt-processor") +class DeepFloydPromptProcessor(PromptProcessor): + @dataclass + class Config(PromptProcessor.Config): + pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0" + + cfg: Config + + ### these functions are unused, kept for debugging ### + def configure_text_encoder(self) -> None: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self.text_encoder = T5EncoderModel.from_pretrained( + self.cfg.pretrained_model_name_or_path, + subfolder="text_encoder", + load_in_8bit=True, + variant="8bit", + device_map="auto", + ) # FIXME: behavior of auto device map in multi-GPU training + self.pipe = IFPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + text_encoder=self.text_encoder, # pass the previously instantiated 8bit text encoder + unet=None, + ) + + def destroy_text_encoder(self) -> None: + del self.text_encoder + del self.pipe + cleanup() + + def get_text_embeddings( + self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] + ) -> Tuple[Float[Tensor, "B 77 4096"], Float[Tensor, "B 77 4096"]]: + text_embeddings, uncond_text_embeddings = self.pipe.encode_prompt( + prompt=prompt, negative_prompt=negative_prompt, device=self.device + ) + return text_embeddings, uncond_text_embeddings + + ### + + @staticmethod + def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device): + max_length = 77 + tokenizer = T5Tokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + local_files_only=True + ) + text_encoder = T5EncoderModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=torch.float16, # suppress warning + load_in_8bit=True, + variant="8bit", + device_map="auto", + local_files_only=True + ) + with torch.no_grad(): + text_inputs = tokenizer( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + text_embeddings = text_encoder( + text_input_ids.to(text_encoder.device), + attention_mask=attention_mask.to(text_encoder.device), + ) + text_embeddings = text_embeddings[0] + + for prompt, embedding in zip(prompts, text_embeddings): + torch.save( + embedding, + os.path.join( + cache_dir, + f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", + ), + ) + + del text_encoder \ No newline at end of file diff --git a/threestudio/models/prompt_processors/dummy_prompt_processor.py b/threestudio/models/prompt_processors/dummy_prompt_processor.py new file mode 100644 index 0000000..1782953 --- /dev/null +++ b/threestudio/models/prompt_processors/dummy_prompt_processor.py @@ -0,0 +1,18 @@ +import json +import os +from dataclasses import dataclass + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt +from threestudio.utils.misc import cleanup +from threestudio.utils.typing import * + + +@threestudio.register("dummy-prompt-processor") +class DummyPromptProcessor(PromptProcessor): + @dataclass + class Config(PromptProcessor.Config): + pretrained_model_name_or_path: str = "" + prompt: str = "" + + cfg: Config \ No newline at end of file diff --git a/threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py b/threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py new file mode 100644 index 0000000..8423b20 --- /dev/null +++ b/threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py @@ -0,0 +1,136 @@ +import json +import os +from dataclasses import dataclass + +import torch +import torch.nn as nn +from transformers import AutoTokenizer, CLIPTextModel + +import threestudio +from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt +from threestudio.utils.misc import cleanup +from threestudio.utils.typing import * + + +@threestudio.register("stable-diffusion-prompt-processor") +class StableDiffusionPromptProcessor(PromptProcessor): + @dataclass + class Config(PromptProcessor.Config): + pass + + cfg: Config + + ### these functions are unused, kept for debugging ### + def configure_text_encoder(self) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + self.cfg.pretrained_model_name_or_path, subfolder="tokenizer" + ) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self.text_encoder = CLIPTextModel.from_pretrained( + self.cfg.pretrained_model_name_or_path, subfolder="text_encoder" + ).to(self.device) + + for p in self.text_encoder.parameters(): + p.requires_grad_(False) + + def destroy_text_encoder(self) -> None: + del self.tokenizer + del self.text_encoder + cleanup() + + def get_text_embeddings( + self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] + ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]: + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + # Tokenize text and get embeddings + tokens = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + uncond_tokens = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + + with torch.no_grad(): + text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0] + uncond_text_embeddings = self.text_encoder( + uncond_tokens.input_ids.to(self.device) + )[0] + + return text_embeddings, uncond_text_embeddings + + ### + + @staticmethod + def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + local_files_only=True, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + device_map="auto", + local_files_only=True, + ) + + with torch.no_grad(): + tokens = tokenizer( + prompts, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ) + text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0] + + for prompt, embedding in zip(prompts, text_embeddings): + torch.save( + embedding, + os.path.join( + cache_dir, + f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", + ), + ) + + del text_encoder + + +from transformers.models.clip import CLIPTextModel, CLIPTokenizer +def add_tokens_to_model(learned_embeds_path, text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, override_token: Optional[Union[str, dict]] = None) -> None: + r"""Adds tokens to the tokenizer and text encoder of a model.""" + + learned_embeds = torch.load(learned_embeds_path, map_location='cpu') + + # Loop over learned embeddings + new_tokens = [] + for token, embedding in learned_embeds.items(): + embedding = embedding.to(text_encoder.get_input_embeddings().weight.dtype) + if override_token is not None: + token = override_token if isinstance(override_token, str) else override_token[token] + + # Add the token to the tokenizer + num_added_tokens = tokenizer.add_tokens(token) + if num_added_tokens == 0: + raise ValueError((f"The tokenizer already contains the token {token}. Please pass a " + "different `token` that is not already in the tokenizer.")) + + # Resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Get the id for the token and assign the embeds + token_id = tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embedding + new_tokens.append(token) + + print(f'Added {len(new_tokens)} tokens to tokenizer and text embedding: {new_tokens}') \ No newline at end of file diff --git a/threestudio/models/renderers/__init__.py b/threestudio/models/renderers/__init__.py new file mode 100644 index 0000000..d33e0fc --- /dev/null +++ b/threestudio/models/renderers/__init__.py @@ -0,0 +1,9 @@ +from . import ( + base, + deferred_volume_renderer, + gan_volume_renderer, + nerf_volume_renderer, + neus_volume_renderer, + nvdiff_rasterizer, + patch_renderer, +) diff --git a/threestudio/models/renderers/base.py b/threestudio/models/renderers/base.py new file mode 100644 index 0000000..06ee713 --- /dev/null +++ b/threestudio/models/renderers/base.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass + +import nerfacc +import torch +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.utils.base import BaseModule +from threestudio.utils.typing import * + + +class Renderer(BaseModule): + @dataclass + class Config(BaseModule.Config): + radius: float = 1.0 + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + # keep references to submodules using namedtuple, avoid being registered as modules + @dataclass + class SubModules: + geometry: BaseImplicitGeometry + material: BaseMaterial + background: BaseBackground + + self.sub_modules = SubModules(geometry, material, background) + + # set up bounding box + self.bbox: Float[Tensor, "2 3"] + self.register_buffer( + "bbox", + torch.as_tensor( + [ + [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], + [self.cfg.radius, self.cfg.radius, self.cfg.radius], + ], + dtype=torch.float32, + ), + ) + + def forward(self, *args, **kwargs) -> Dict[str, Any]: + raise NotImplementedError + + @property + def geometry(self) -> BaseImplicitGeometry: + return self.sub_modules.geometry + + @property + def material(self) -> BaseMaterial: + return self.sub_modules.material + + @property + def background(self) -> BaseBackground: + return self.sub_modules.background + + def set_geometry(self, geometry: BaseImplicitGeometry) -> None: + self.sub_modules.geometry = geometry + + def set_material(self, material: BaseMaterial) -> None: + self.sub_modules.material = material + + def set_background(self, background: BaseBackground) -> None: + self.sub_modules.background = background + + +class VolumeRenderer(Renderer): + pass + + +class Rasterizer(Renderer): + pass \ No newline at end of file diff --git a/threestudio/models/renderers/deferred_volume_renderer.py b/threestudio/models/renderers/deferred_volume_renderer.py new file mode 100644 index 0000000..4c8f2ac --- /dev/null +++ b/threestudio/models/renderers/deferred_volume_renderer.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +import threestudio +from threestudio.models.renderers.base import VolumeRenderer + + +class DeferredVolumeRenderer(VolumeRenderer): + pass diff --git a/threestudio/models/renderers/gan_volume_renderer.py b/threestudio/models/renderers/gan_volume_renderer.py new file mode 100644 index 0000000..61032f5 --- /dev/null +++ b/threestudio/models/renderers/gan_volume_renderer.py @@ -0,0 +1,159 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.renderers.base import VolumeRenderer +from threestudio.utils.GAN.discriminator import NLayerDiscriminator, weights_init +from threestudio.utils.GAN.distribution import DiagonalGaussianDistribution +from threestudio.utils.GAN.mobilenet import MobileNetV3 as GlobalEncoder +from threestudio.utils.GAN.vae import Decoder as Generator +from threestudio.utils.GAN.vae import Encoder as LocalEncoder +from threestudio.utils.typing import * + + +@threestudio.register("gan-volume-renderer") +class GANVolumeRenderer(VolumeRenderer): + @dataclass + class Config(VolumeRenderer.Config): + base_renderer_type: str = "" + base_renderer: Optional[VolumeRenderer.Config] = None + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + self.base_renderer = threestudio.find(self.cfg.base_renderer_type)( + self.cfg.base_renderer, + geometry=geometry, + material=material, + background=background, + ) + self.ch_mult = [1, 2, 4] + self.generator = Generator( + ch=64, + out_ch=3, + ch_mult=self.ch_mult, + num_res_blocks=1, + attn_resolutions=[], + dropout=0.0, + resamp_with_conv=True, + in_channels=7, + resolution=512, + z_channels=4, + ) + self.local_encoder = LocalEncoder( + ch=32, + out_ch=3, + ch_mult=self.ch_mult, + num_res_blocks=1, + attn_resolutions=[], + dropout=0.0, + resamp_with_conv=True, + in_channels=3, + resolution=512, + z_channels=4, + ) + self.global_encoder = GlobalEncoder(n_class=64) + self.discriminator = NLayerDiscriminator( + input_nc=3, n_layers=3, use_actnorm=False, ndf=64 + ).apply(weights_init) + + def forward( + self, + rays_o: Float[Tensor, "B H W 3"], + rays_d: Float[Tensor, "B H W 3"], + light_positions: Float[Tensor, "B 3"], + bg_color: Optional[Tensor] = None, + gt_rgb: Float[Tensor, "B H W 3"] = None, + multi_level_guidance: Bool = False, + **kwargs + ) -> Dict[str, Float[Tensor, "..."]]: + B, H, W, _ = rays_o.shape + if gt_rgb is not None and multi_level_guidance: + generator_level = torch.randint(0, 3, (1,)).item() + interval_x = torch.randint(0, 8, (1,)).item() + interval_y = torch.randint(0, 8, (1,)).item() + int_rays_o = rays_o[:, interval_y::8, interval_x::8] + int_rays_d = rays_d[:, interval_y::8, interval_x::8] + out = self.base_renderer( + int_rays_o, int_rays_d, light_positions, bg_color, **kwargs + ) + comp_int_rgb = out["comp_rgb"][..., :3] + comp_gt_rgb = gt_rgb[:, interval_y::8, interval_x::8] + else: + generator_level = 0 + scale_ratio = 2 ** (len(self.ch_mult) - 1) + rays_o = torch.nn.functional.interpolate( + rays_o.permute(0, 3, 1, 2), + (H // scale_ratio, W // scale_ratio), + mode="bilinear", + ).permute(0, 2, 3, 1) + rays_d = torch.nn.functional.interpolate( + rays_d.permute(0, 3, 1, 2), + (H // scale_ratio, W // scale_ratio), + mode="bilinear", + ).permute(0, 2, 3, 1) + + out = self.base_renderer(rays_o, rays_d, light_positions, bg_color, **kwargs) + comp_rgb = out["comp_rgb"][..., :3] + latent = out["comp_rgb"][..., 3:] + out["comp_lr_rgb"] = comp_rgb.clone() + + posterior = DiagonalGaussianDistribution(latent.permute(0, 3, 1, 2)) + if multi_level_guidance: + z_map = posterior.sample() + else: + z_map = posterior.mode() + lr_rgb = comp_rgb.permute(0, 3, 1, 2) + + if generator_level == 0: + g_code_rgb = self.global_encoder(F.interpolate(lr_rgb, (224, 224))) + comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb) + elif generator_level == 1: + g_code_rgb = self.global_encoder( + F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224)) + ) + comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb) + elif generator_level == 2: + g_code_rgb = self.global_encoder( + F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224)) + ) + l_code_rgb = self.local_encoder(gt_rgb.permute(0, 3, 1, 2)) + posterior = DiagonalGaussianDistribution(l_code_rgb) + z_map = posterior.sample() + comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb) + + comp_rgb = F.interpolate(comp_rgb.permute(0, 3, 1, 2), (H, W), mode="bilinear") + comp_gan_rgb = F.interpolate(comp_gan_rgb, (H, W), mode="bilinear") + out.update( + { + "posterior": posterior, + "comp_gan_rgb": comp_gan_rgb.permute(0, 2, 3, 1), + "comp_rgb": comp_rgb.permute(0, 2, 3, 1), + "generator_level": generator_level, + } + ) + + if gt_rgb is not None and multi_level_guidance: + out.update({"comp_int_rgb": comp_int_rgb, "comp_gt_rgb": comp_gt_rgb}) + return out + + def update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ) -> None: + self.base_renderer.update_step(epoch, global_step, on_load_weights) + + def train(self, mode=True): + return self.base_renderer.train(mode) + + def eval(self): + return self.base_renderer.eval() \ No newline at end of file diff --git a/threestudio/models/renderers/nerf_volume_renderer.py b/threestudio/models/renderers/nerf_volume_renderer.py new file mode 100644 index 0000000..d512f24 --- /dev/null +++ b/threestudio/models/renderers/nerf_volume_renderer.py @@ -0,0 +1,462 @@ +from dataclasses import dataclass, field +from functools import partial + +import nerfacc +import torch +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.estimators import ImportanceEstimator +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.networks import create_network_with_input_encoding +from threestudio.models.renderers.base import VolumeRenderer +from threestudio.systems.utils import parse_optimizer, parse_scheduler_to_instance +from threestudio.utils.ops import chunk_batch, get_activation, validate_empty_rays +from threestudio.utils.typing import * + + +@threestudio.register("nerf-volume-renderer") +class NeRFVolumeRenderer(VolumeRenderer): + @dataclass + class Config(VolumeRenderer.Config): + num_samples_per_ray: int = 512 + eval_chunk_size: int = 160000 + randomized: bool = True + + near_plane: float = 0.0 + far_plane: float = 1e10 + + return_comp_normal: bool = False + return_normal_perturb: bool = False + + # in ["occgrid", "proposal", "importance"] + estimator: str = "occgrid" + + # for occgrid + grid_prune: bool = True + prune_alpha_threshold: bool = True + + # for proposal + proposal_network_config: Optional[dict] = None + prop_optimizer_config: Optional[dict] = None + prop_scheduler_config: Optional[dict] = None + num_samples_per_ray_proposal: int = 64 + + # for importance + num_samples_per_ray_importance: int = 64 + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + super().configure(geometry, material, background) + if self.cfg.estimator == "occgrid": + self.estimator = nerfacc.OccGridEstimator( + roi_aabb=self.bbox.view(-1), resolution=32, levels=1 + ) + if not self.cfg.grid_prune: + self.estimator.occs.fill_(True) + self.estimator.binaries.fill_(True) + self.render_step_size = ( + 1.732 * 2 * self.cfg.radius / self.cfg.num_samples_per_ray + ) + self.randomized = self.cfg.randomized + elif self.cfg.estimator == "importance": + self.estimator = ImportanceEstimator() + elif self.cfg.estimator == "proposal": + self.prop_net = create_network_with_input_encoding( + **self.cfg.proposal_network_config + ) + self.prop_optim = parse_optimizer( + self.cfg.prop_optimizer_config, self.prop_net + ) + self.prop_scheduler = ( + parse_scheduler_to_instance( + self.cfg.prop_scheduler_config, self.prop_optim + ) + if self.cfg.prop_scheduler_config is not None + else None + ) + self.estimator = nerfacc.PropNetEstimator( + self.prop_optim, self.prop_scheduler + ) + + def get_proposal_requires_grad_fn( + target: float = 5.0, num_steps: int = 1000 + ): + schedule = lambda s: min(s / num_steps, 1.0) * target + + steps_since_last_grad = 0 + + def proposal_requires_grad_fn(step: int) -> bool: + nonlocal steps_since_last_grad + target_steps_since_last_grad = schedule(step) + requires_grad = steps_since_last_grad > target_steps_since_last_grad + if requires_grad: + steps_since_last_grad = 0 + steps_since_last_grad += 1 + return requires_grad + + return proposal_requires_grad_fn + + self.proposal_requires_grad_fn = get_proposal_requires_grad_fn() + self.randomized = self.cfg.randomized + else: + raise NotImplementedError( + "Unknown estimator, should be one of ['occgrid', 'proposal', 'importance']." + ) + + # for proposal + self.vars_in_forward = {} + + def forward( + self, + rays_o: Float[Tensor, "B H W 3"], + rays_d: Float[Tensor, "B H W 3"], + light_positions: Float[Tensor, "B 3"], + bg_color: Optional[Tensor] = None, + **kwargs + ) -> Dict[str, Float[Tensor, "..."]]: + batch_size, height, width = rays_o.shape[:3] + rays_o_flatten: Float[Tensor, "Nr 3"] = rays_o.reshape(-1, 3) + rays_d_flatten: Float[Tensor, "Nr 3"] = rays_d.reshape(-1, 3) + light_positions_flatten: Float[Tensor, "Nr 3"] = ( + light_positions.reshape(-1, 1, 1, 3) + .expand(-1, height, width, -1) + .reshape(-1, 3) + ) + n_rays = rays_o_flatten.shape[0] + + if self.cfg.estimator == "occgrid": + if not self.cfg.grid_prune: + with torch.no_grad(): + ray_indices, t_starts_, t_ends_ = self.estimator.sampling( + rays_o_flatten, + rays_d_flatten, + sigma_fn=None, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + render_step_size=self.render_step_size, + alpha_thre=0.0, + stratified=self.randomized, + cone_angle=0.0, + early_stop_eps=0, + ) + else: + + def sigma_fn(t_starts, t_ends, ray_indices): + t_starts, t_ends = t_starts[..., None], t_ends[..., None] + t_origins = rays_o_flatten[ray_indices] + t_positions = (t_starts + t_ends) / 2.0 + t_dirs = rays_d_flatten[ray_indices] + positions = t_origins + t_dirs * t_positions + if self.training: + sigma = self.geometry.forward_density(positions)[..., 0] + else: + sigma = chunk_batch( + self.geometry.forward_density, + self.cfg.eval_chunk_size, + positions, + )[..., 0] + return sigma + + with torch.no_grad(): + ray_indices, t_starts_, t_ends_ = self.estimator.sampling( + rays_o_flatten, + rays_d_flatten, + sigma_fn=sigma_fn if self.cfg.prune_alpha_threshold else None, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + render_step_size=self.render_step_size, + alpha_thre=0.01 if self.cfg.prune_alpha_threshold else 0.0, + stratified=self.randomized, + cone_angle=0.0, + ) + elif self.cfg.estimator == "proposal": + + def prop_sigma_fn( + t_starts: Float[Tensor, "Nr Ns"], + t_ends: Float[Tensor, "Nr Ns"], + proposal_network, + ): + t_origins: Float[Tensor, "Nr 1 3"] = rays_o_flatten.unsqueeze(-2) + t_dirs: Float[Tensor, "Nr 1 3"] = rays_d_flatten.unsqueeze(-2) + positions: Float[Tensor, "Nr Ns 3"] = ( + t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0 + ) + aabb_min, aabb_max = self.bbox[0], self.bbox[1] + positions = (positions - aabb_min) / (aabb_max - aabb_min) + selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1) + density_before_activation = ( + proposal_network(positions.view(-1, 3)) + .view(*positions.shape[:-1], 1) + .to(positions) + ) + density: Float[Tensor, "Nr Ns 1"] = ( + get_activation("shifted_trunc_exp")(density_before_activation) + * selector[..., None] + ) + return density.squeeze(-1) + + t_starts_, t_ends_ = self.estimator.sampling( + prop_sigma_fns=[partial(prop_sigma_fn, proposal_network=self.prop_net)], + prop_samples=[self.cfg.num_samples_per_ray_proposal], + num_samples=self.cfg.num_samples_per_ray, + n_rays=n_rays, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + sampling_type="uniform", + stratified=self.randomized, + requires_grad=self.vars_in_forward["requires_grad"], + ) + ray_indices = ( + torch.arange(n_rays, device=rays_o_flatten.device) + .unsqueeze(-1) + .expand(-1, t_starts_.shape[1]) + ) + ray_indices = ray_indices.flatten() + t_starts_ = t_starts_.flatten() + t_ends_ = t_ends_.flatten() + elif self.cfg.estimator == "importance": + + def prop_sigma_fn( + t_starts: Float[Tensor, "Nr Ns"], + t_ends: Float[Tensor, "Nr Ns"], + proposal_network, + ): + t_origins: Float[Tensor, "Nr 1 3"] = rays_o_flatten.unsqueeze(-2) + t_dirs: Float[Tensor, "Nr 1 3"] = rays_d_flatten.unsqueeze(-2) + positions: Float[Tensor, "Nr Ns 3"] = ( + t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0 + ) + with torch.no_grad(): + geo_out = chunk_batch( + proposal_network, + self.cfg.eval_chunk_size, + positions.reshape(-1, 3), + output_normal=False, + ) + density = geo_out["density"] + return density.reshape(positions.shape[:2]) + + t_starts_, t_ends_ = self.estimator.sampling( + prop_sigma_fns=[partial(prop_sigma_fn, proposal_network=self.geometry)], + prop_samples=[self.cfg.num_samples_per_ray_importance], + num_samples=self.cfg.num_samples_per_ray, + n_rays=n_rays, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + sampling_type="uniform", + stratified=self.randomized, + ) + ray_indices = ( + torch.arange(n_rays, device=rays_o_flatten.device) + .unsqueeze(-1) + .expand(-1, t_starts_.shape[1]) + ) + ray_indices = ray_indices.flatten() + t_starts_ = t_starts_.flatten() + t_ends_ = t_ends_.flatten() + else: + raise NotImplementedError + + ray_indices, t_starts_, t_ends_ = validate_empty_rays( + ray_indices, t_starts_, t_ends_ + ) + ray_indices = ray_indices.long() + t_starts, t_ends = t_starts_[..., None], t_ends_[..., None] + t_origins = rays_o_flatten[ray_indices] + t_dirs = rays_d_flatten[ray_indices] + t_light_positions = light_positions_flatten[ray_indices] + t_positions = (t_starts + t_ends) / 2.0 + positions = t_origins + t_dirs * t_positions + t_intervals = t_ends - t_starts + + if self.training: + geo_out = self.geometry( + positions, output_normal=self.material.requires_normal + ) + rgb_fg_all = self.material( + viewdirs=t_dirs, + positions=positions, + light_positions=t_light_positions, + **geo_out, + **kwargs + ) + comp_rgb_bg = self.background(dirs=rays_d) + else: + geo_out = chunk_batch( + self.geometry, + self.cfg.eval_chunk_size, + positions, + output_normal=self.material.requires_normal, + ) + rgb_fg_all = chunk_batch( + self.material, + self.cfg.eval_chunk_size, + viewdirs=t_dirs, + positions=positions, + light_positions=t_light_positions, + **geo_out + ) + comp_rgb_bg = chunk_batch( + self.background, self.cfg.eval_chunk_size, dirs=rays_d + ) + + weights: Float[Tensor, "Nr 1"] + weights_, trans_, _ = nerfacc.render_weight_from_density( + t_starts[..., 0], + t_ends[..., 0], + geo_out["density"][..., 0], + ray_indices=ray_indices, + n_rays=n_rays, + ) + if self.training and self.cfg.estimator == "proposal": + self.vars_in_forward["trans"] = trans_.reshape(n_rays, -1) + + weights = weights_[..., None] + opacity: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays( + weights[..., 0], values=None, ray_indices=ray_indices, n_rays=n_rays + ) + depth: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays( + weights[..., 0], values=t_positions, ray_indices=ray_indices, n_rays=n_rays + ) + comp_rgb_fg: Float[Tensor, "Nr Nc"] = nerfacc.accumulate_along_rays( + weights[..., 0], values=rgb_fg_all, ray_indices=ray_indices, n_rays=n_rays + ) + + # populate depth and opacity to each point + t_depth = depth[ray_indices] + z_variance = nerfacc.accumulate_along_rays( + weights[..., 0], + values=(t_positions - t_depth) ** 2, + ray_indices=ray_indices, + n_rays=n_rays, + ) + + if bg_color is None: + bg_color = comp_rgb_bg + else: + if bg_color.shape[:-1] == (batch_size,): + # e.g. constant random color used for Zero123 + # [bs,3] -> [bs, 1, 1, 3]): + bg_color = bg_color.unsqueeze(1).unsqueeze(1) + # -> [bs, height, width, 3]): + bg_color = bg_color.expand(-1, height, width, -1) + + if bg_color.shape[:-1] == (batch_size, height, width): + bg_color = bg_color.reshape(batch_size * height * width, -1) + + comp_rgb = comp_rgb_fg + bg_color * (1.0 - opacity) + + out = { + "comp_rgb": comp_rgb.view(batch_size, height, width, -1), + "comp_rgb_fg": comp_rgb_fg.view(batch_size, height, width, -1), + "comp_rgb_bg": comp_rgb_bg.view(batch_size, height, width, -1), + "opacity": opacity.view(batch_size, height, width, 1), + "depth": depth.view(batch_size, height, width, 1), + "z_variance": z_variance.view(batch_size, height, width, 1), + } + + if self.training: + out.update( + { + "weights": weights, + "t_points": t_positions, + "t_intervals": t_intervals, + "t_dirs": t_dirs, + "ray_indices": ray_indices, + "points": positions, + **geo_out, + } + ) + if "normal" in geo_out: + if self.cfg.return_comp_normal: + comp_normal: Float[Tensor, "Nr 3"] = nerfacc.accumulate_along_rays( + weights[..., 0], + values=geo_out["normal"], + ray_indices=ray_indices, + n_rays=n_rays, + ) + comp_normal = F.normalize(comp_normal, dim=-1) + comp_normal = ( + (comp_normal + 1.0) / 2.0 * opacity + ) # for visualization + out.update( + { + "comp_normal": comp_normal.view( + batch_size, height, width, 3 + ), + } + ) + if self.cfg.return_normal_perturb: + normal_perturb = self.geometry( + positions + torch.randn_like(positions) * 1e-2, + output_normal=self.material.requires_normal, + )["normal"] + out.update({"normal_perturb": normal_perturb}) + else: + if "normal" in geo_out: + comp_normal = nerfacc.accumulate_along_rays( + weights[..., 0], + values=geo_out["normal"], + ray_indices=ray_indices, + n_rays=n_rays, + ) + comp_normal = F.normalize(comp_normal, dim=-1) + comp_normal = (comp_normal + 1.0) / 2.0 * opacity # for visualization + out.update( + { + "comp_normal": comp_normal.view(batch_size, height, width, 3), + } + ) + + return out + + def update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ) -> None: + if self.cfg.estimator == "occgrid": + if self.cfg.grid_prune: + + def occ_eval_fn(x): + density = self.geometry.forward_density(x) + # approximate for 1 - torch.exp(-density * self.render_step_size) based on taylor series + return density * self.render_step_size + + if self.training and not on_load_weights: + self.estimator.update_every_n_steps( + step=global_step, occ_eval_fn=occ_eval_fn + ) + elif self.cfg.estimator == "proposal": + if self.training: + requires_grad = self.proposal_requires_grad_fn(global_step) + self.vars_in_forward["requires_grad"] = requires_grad + else: + self.vars_in_forward["requires_grad"] = False + + def update_step_end(self, epoch: int, global_step: int) -> None: + if self.cfg.estimator == "proposal" and self.training: + self.estimator.update_every_n_steps( + self.vars_in_forward["trans"], + self.vars_in_forward["requires_grad"], + loss_scaler=1.0, + ) + + def train(self, mode=True): + self.randomized = mode and self.cfg.randomized + if self.cfg.estimator == "proposal": + self.prop_net.train() + return super().train(mode=mode) + + def eval(self): + self.randomized = False + if self.cfg.estimator == "proposal": + self.prop_net.eval() + return super().eval() \ No newline at end of file diff --git a/threestudio/models/renderers/neus_volume_renderer.py b/threestudio/models/renderers/neus_volume_renderer.py new file mode 100644 index 0000000..9c8410b --- /dev/null +++ b/threestudio/models/renderers/neus_volume_renderer.py @@ -0,0 +1,390 @@ +from dataclasses import dataclass +from functools import partial + +import nerfacc +import torch +import torch.nn as nn +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.estimators import ImportanceEstimator +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.renderers.base import VolumeRenderer +from threestudio.utils.ops import chunk_batch, validate_empty_rays +from threestudio.utils.typing import * + + +def volsdf_density(sdf, inv_std): + inv_std = inv_std.clamp(0.0, 80.0) + beta = 1 / inv_std + alpha = inv_std + return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) + + +class LearnedVariance(nn.Module): + def __init__(self, init_val): + super(LearnedVariance, self).__init__() + self.register_parameter("_inv_std", nn.Parameter(torch.tensor(init_val))) + + @property + def inv_std(self): + val = torch.exp(self._inv_std * 10.0) + return val + + def forward(self, x): + return torch.ones_like(x) * self.inv_std.clamp(1.0e-6, 1.0e6) + + +@threestudio.register("neus-volume-renderer") +class NeuSVolumeRenderer(VolumeRenderer): + @dataclass + class Config(VolumeRenderer.Config): + num_samples_per_ray: int = 512 + randomized: bool = True + eval_chunk_size: int = 160000 + learned_variance_init: float = 0.3 + cos_anneal_end_steps: int = 0 + use_volsdf: bool = False + + near_plane: float = 0.0 + far_plane: float = 1e10 + + # in ['occgrid', 'importance'] + estimator: str = "occgrid" + + # for occgrid + grid_prune: bool = True + prune_alpha_threshold: bool = True + + # for importance + num_samples_per_ray_importance: int = 64 + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + super().configure(geometry, material, background) + self.variance = LearnedVariance(self.cfg.learned_variance_init) + if self.cfg.estimator == "occgrid": + self.estimator = nerfacc.OccGridEstimator( + roi_aabb=self.bbox.view(-1), resolution=32, levels=1 + ) + if not self.cfg.grid_prune: + self.estimator.occs.fill_(True) + self.estimator.binaries.fill_(True) + self.render_step_size = ( + 1.732 * 2 * self.cfg.radius / self.cfg.num_samples_per_ray + ) + self.randomized = self.cfg.randomized + elif self.cfg.estimator == "importance": + self.estimator = ImportanceEstimator() + else: + raise NotImplementedError( + "unknown estimator, should be in ['occgrid', 'importance']" + ) + self.cos_anneal_ratio = 1.0 + + def get_alpha(self, sdf, normal, dirs, dists): + inv_std = self.variance(sdf) + if self.cfg.use_volsdf: + alpha = torch.abs(dists.detach()) * volsdf_density(sdf, inv_std) + else: + true_cos = (dirs * normal).sum(-1, keepdim=True) + # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes + # the cos value "not dead" at the beginning training iterations, for better convergence. + iter_cos = -( + F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio) + + F.relu(-true_cos) * self.cos_anneal_ratio + ) # always non-positive + + # Estimate signed distances at section points + estimated_next_sdf = sdf + iter_cos * dists * 0.5 + estimated_prev_sdf = sdf - iter_cos * dists * 0.5 + + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_std) + + p = prev_cdf - next_cdf + c = prev_cdf + + alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) + return alpha + + def forward( + self, + rays_o: Float[Tensor, "B H W 3"], + rays_d: Float[Tensor, "B H W 3"], + light_positions: Float[Tensor, "B 3"], + bg_color: Optional[Tensor] = None, + **kwargs + ) -> Dict[str, Float[Tensor, "..."]]: + batch_size, height, width = rays_o.shape[:3] + rays_o_flatten: Float[Tensor, "Nr 3"] = rays_o.reshape(-1, 3) + rays_d_flatten: Float[Tensor, "Nr 3"] = rays_d.reshape(-1, 3) + light_positions_flatten: Float[Tensor, "Nr 3"] = ( + light_positions.reshape(-1, 1, 1, 3) + .expand(-1, height, width, -1) + .reshape(-1, 3) + ) + n_rays = rays_o_flatten.shape[0] + + if self.cfg.estimator == "occgrid": + + def alpha_fn(t_starts, t_ends, ray_indices): + t_starts, t_ends = t_starts[..., None], t_ends[..., None] + t_origins = rays_o_flatten[ray_indices] + t_positions = (t_starts + t_ends) / 2.0 + t_dirs = rays_d_flatten[ray_indices] + positions = t_origins + t_dirs * t_positions + if self.training: + sdf = self.geometry.forward_sdf(positions)[..., 0] + else: + sdf = chunk_batch( + self.geometry.forward_sdf, + self.cfg.eval_chunk_size, + positions, + )[..., 0] + + inv_std = self.variance(sdf) + if self.cfg.use_volsdf: + alpha = self.render_step_size * volsdf_density(sdf, inv_std) + else: + estimated_next_sdf = sdf - self.render_step_size * 0.5 + estimated_prev_sdf = sdf + self.render_step_size * 0.5 + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_std) + p = prev_cdf - next_cdf + c = prev_cdf + alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) + + return alpha + + if not self.cfg.grid_prune: + with torch.no_grad(): + ray_indices, t_starts_, t_ends_ = self.estimator.sampling( + rays_o_flatten, + rays_d_flatten, + alpha_fn=None, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + render_step_size=self.render_step_size, + alpha_thre=0.0, + stratified=self.randomized, + cone_angle=0.0, + early_stop_eps=0, + ) + else: + with torch.no_grad(): + ray_indices, t_starts_, t_ends_ = self.estimator.sampling( + rays_o_flatten, + rays_d_flatten, + alpha_fn=alpha_fn if self.cfg.prune_alpha_threshold else None, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + render_step_size=self.render_step_size, + alpha_thre=0.01 if self.cfg.prune_alpha_threshold else 0.0, + stratified=self.randomized, + cone_angle=0.0, + ) + elif self.cfg.estimator == "importance": + + def prop_sigma_fn( + t_starts: Float[Tensor, "Nr Ns"], + t_ends: Float[Tensor, "Nr Ns"], + proposal_network, + ): + if self.cfg.use_volsdf: + t_origins: Float[Tensor, "Nr 1 3"] = rays_o_flatten.unsqueeze(-2) + t_dirs: Float[Tensor, "Nr 1 3"] = rays_d_flatten.unsqueeze(-2) + positions: Float[Tensor, "Nr Ns 3"] = ( + t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0 + ) + with torch.no_grad(): + geo_out = chunk_batch( + proposal_network, + self.cfg.eval_chunk_size, + positions.reshape(-1, 3), + output_normal=False, + ) + inv_std = self.variance(geo_out["sdf"]) + density = volsdf_density(geo_out["sdf"], inv_std) + return density.reshape(positions.shape[:2]) + else: + raise ValueError( + "Currently only VolSDF supports importance sampling." + ) + + t_starts_, t_ends_ = self.estimator.sampling( + prop_sigma_fns=[partial(prop_sigma_fn, proposal_network=self.geometry)], + prop_samples=[self.cfg.num_samples_per_ray_importance], + num_samples=self.cfg.num_samples_per_ray, + n_rays=n_rays, + near_plane=self.cfg.near_plane, + far_plane=self.cfg.far_plane, + sampling_type="uniform", + stratified=self.randomized, + ) + ray_indices = ( + torch.arange(n_rays, device=rays_o_flatten.device) + .unsqueeze(-1) + .expand(-1, t_starts_.shape[1]) + ) + ray_indices = ray_indices.flatten() + t_starts_ = t_starts_.flatten() + t_ends_ = t_ends_.flatten() + else: + raise NotImplementedError + + ray_indices, t_starts_, t_ends_ = validate_empty_rays( + ray_indices, t_starts_, t_ends_ + ) + ray_indices = ray_indices.long() + t_starts, t_ends = t_starts_[..., None], t_ends_[..., None] + t_origins = rays_o_flatten[ray_indices] + t_dirs = rays_d_flatten[ray_indices] + t_light_positions = light_positions_flatten[ray_indices] + t_positions = (t_starts + t_ends) / 2.0 + positions = t_origins + t_dirs * t_positions + t_intervals = t_ends - t_starts + + if self.training: + geo_out = self.geometry(positions, output_normal=True) + rgb_fg_all = self.material( + viewdirs=t_dirs, + positions=positions, + light_positions=t_light_positions, + **geo_out, + **kwargs + ) + comp_rgb_bg = self.background(dirs=rays_d) + else: + geo_out = chunk_batch( + self.geometry, + self.cfg.eval_chunk_size, + positions, + output_normal=True, + ) + rgb_fg_all = chunk_batch( + self.material, + self.cfg.eval_chunk_size, + viewdirs=t_dirs, + positions=positions, + light_positions=t_light_positions, + **geo_out + ) + comp_rgb_bg = chunk_batch( + self.background, self.cfg.eval_chunk_size, dirs=rays_d + ) + + # grad or normal? + alpha: Float[Tensor, "Nr 1"] = self.get_alpha( + geo_out["sdf"], geo_out["normal"], t_dirs, t_intervals + ) + + weights: Float[Tensor, "Nr 1"] + weights_, _ = nerfacc.render_weight_from_alpha( + alpha[..., 0], + ray_indices=ray_indices, + n_rays=n_rays, + ) + weights = weights_[..., None] + opacity: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays( + weights[..., 0], values=None, ray_indices=ray_indices, n_rays=n_rays + ) + depth: Float[Tensor, "Nr 1"] = nerfacc.accumulate_along_rays( + weights[..., 0], values=t_positions, ray_indices=ray_indices, n_rays=n_rays + ) + comp_rgb_fg: Float[Tensor, "Nr Nc"] = nerfacc.accumulate_along_rays( + weights[..., 0], values=rgb_fg_all, ray_indices=ray_indices, n_rays=n_rays + ) + + if bg_color is None: + bg_color = comp_rgb_bg + + if bg_color.shape[:-1] == (batch_size, height, width): + bg_color = bg_color.reshape(batch_size * height * width, -1) + + comp_rgb = comp_rgb_fg + bg_color * (1.0 - opacity) + + out = { + "comp_rgb": comp_rgb.view(batch_size, height, width, -1), + "comp_rgb_fg": comp_rgb_fg.view(batch_size, height, width, -1), + "comp_rgb_bg": comp_rgb_bg.view(batch_size, height, width, -1), + "opacity": opacity.view(batch_size, height, width, 1), + "depth": depth.view(batch_size, height, width, 1), + } + + if self.training: + out.update( + { + "weights": weights, + "t_points": t_positions, + "t_intervals": t_intervals, + "t_dirs": t_dirs, + "ray_indices": ray_indices, + "points": positions, + **geo_out, + } + ) + else: + if "normal" in geo_out: + comp_normal: Float[Tensor, "Nr 3"] = nerfacc.accumulate_along_rays( + weights[..., 0], + values=geo_out["normal"], + ray_indices=ray_indices, + n_rays=n_rays, + ) + comp_normal = F.normalize(comp_normal, dim=-1) + comp_normal = (comp_normal + 1.0) / 2.0 * opacity # for visualization + out.update( + { + "comp_normal": comp_normal.view(batch_size, height, width, 3), + } + ) + out.update({"inv_std": self.variance.inv_std}) + return out + + def update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ) -> None: + self.cos_anneal_ratio = ( + 1.0 + if self.cfg.cos_anneal_end_steps == 0 + else min(1.0, global_step / self.cfg.cos_anneal_end_steps) + ) + if self.cfg.estimator == "occgrid": + if self.cfg.grid_prune: + + def occ_eval_fn(x): + sdf = self.geometry.forward_sdf(x) + inv_std = self.variance(sdf) + if self.cfg.use_volsdf: + alpha = self.render_step_size * volsdf_density(sdf, inv_std) + else: + estimated_next_sdf = sdf - self.render_step_size * 0.5 + estimated_prev_sdf = sdf + self.render_step_size * 0.5 + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_std) + p = prev_cdf - next_cdf + c = prev_cdf + alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) + return alpha + + if self.training and not on_load_weights: + self.estimator.update_every_n_steps( + step=global_step, occ_eval_fn=occ_eval_fn + ) + + def train(self, mode=True): + self.randomized = mode and self.cfg.randomized + return super().train(mode=mode) + + def eval(self): + self.randomized = False + return super().eval() \ No newline at end of file diff --git a/threestudio/models/renderers/nvdiff_rasterizer.py b/threestudio/models/renderers/nvdiff_rasterizer.py new file mode 100644 index 0000000..ced6433 --- /dev/null +++ b/threestudio/models/renderers/nvdiff_rasterizer.py @@ -0,0 +1,188 @@ +from dataclasses import dataclass + +import nerfacc +import torch +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.renderers.base import Rasterizer, VolumeRenderer +from threestudio.utils.misc import get_device +from threestudio.utils.rasterize import NVDiffRasterizerContext +from threestudio.utils.typing import * + + +@threestudio.register("nvdiff-rasterizer") +class NVDiffRasterizer(Rasterizer): + @dataclass + class Config(VolumeRenderer.Config): + context_type: str = "gl" + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + super().configure(geometry, material, background) + self.ctx = NVDiffRasterizerContext(self.cfg.context_type, get_device()) + + def forward( + self, + mvp_mtx: Float[Tensor, "B 4 4"], + camera_positions: Float[Tensor, "B 3"], + light_positions: Float[Tensor, "B 3"], + height: int, + width: int, + render_rgb: bool = True, + render_mask: bool = False, + **kwargs + ) -> Dict[str, Any]: + batch_size = mvp_mtx.shape[0] + mesh = self.geometry.isosurface() + + v_pos_clip: Float[Tensor, "B Nv 4"] = self.ctx.vertex_transform( + mesh.v_pos, mvp_mtx + ) + rast, _ = self.ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width)) + mask = rast[..., 3:] > 0 + mask_aa = self.ctx.antialias(mask.float(), rast, v_pos_clip, mesh.t_pos_idx) + + out = {"opacity": mask_aa, "mesh": mesh} + + if render_mask: + # get front-view visibility mask + with torch.no_grad(): + mvp_mtx_ref = kwargs["mvp_mtx_ref"] # FIXME + v_pos_clip_front: Float[Tensor, "B Nv 4"] = self.ctx.vertex_transform( + mesh.v_pos, mvp_mtx_ref + ) + rast_front, _ = self.ctx.rasterize(v_pos_clip_front, mesh.t_pos_idx, (height, width)) + mask_front = rast_front[..., 3:] + mask_front = mask_front[mask_front > 0] - 1. + faces_vis = mesh.t_pos_idx[mask_front.long()] + + mesh._v_rgb = torch.zeros(mesh.v_pos.shape[0], 1).to(mesh.v_pos) + mesh._v_rgb[faces_vis[:,0]] = 1. + mesh._v_rgb[faces_vis[:,1]] = 1. + mesh._v_rgb[faces_vis[:,2]] = 1. + mask_vis, _ = self.ctx.interpolate_one(mesh._v_rgb, rast, mesh.t_pos_idx) + mask_vis = mask_vis > 0. + # from torchvision.utils import save_image + # save_image(mask_vis.permute(0,3,1,2).float(), "debug.png") + out.update({"mask": 1.0 - mask_vis.float()}) + + # FIXME: paste texture back to mesh + # import cv2 + # import imageio + # import numpy as np + + # gt_rgb = imageio.imread("load/images/tiger_nurse_rgba.png")/255. + # gt_rgb = cv2.resize(gt_rgb[:,:,:3],(512, 512)) + # gt_rgb = torch.Tensor(gt_rgb[None,...]).permute(0,3,1,2).to(v_pos_clip_front) + + # # align to up-z and front-x + # dir2vec = { + # "+x": np.array([1, 0, 0]), + # "+y": np.array([0, 1, 0]), + # "+z": np.array([0, 0, 1]), + # "-x": np.array([-1, 0, 0]), + # "-y": np.array([0, -1, 0]), + # "-z": np.array([0, 0, -1]), + # } + # z_, x_ = ( + # dir2vec["-y"], + # dir2vec["-z"], + # ) + + # y_ = np.cross(z_, x_) + # std2mesh = np.stack([x_, y_, z_], axis=0).T + # v_pos_ = (torch.mm(torch.tensor(std2mesh).to(mesh.v_pos), mesh.v_pos.T).T) * 2 + # print(v_pos_.min(), v_pos_.max()) + + # mesh._v_rgb=F.grid_sample(gt_rgb, v_pos_[None, None][..., :2], mode="nearest").permute(3,1,0,2).squeeze(-1).squeeze(-1).contiguous() + # rgb_vis, _ = self.ctx.interpolate_one(mesh._v_rgb, rast, mesh.t_pos_idx) + # rgb_vis_aa = self.ctx.antialias( + # rgb_vis, rast, v_pos_clip, mesh.t_pos_idx + # ) + # from torchvision.utils import save_image + # save_image(rgb_vis_aa.permute(0,3,1,2), "debug.png") + + + gb_normal, _ = self.ctx.interpolate_one(mesh.v_nrm, rast, mesh.t_pos_idx) + gb_normal = F.normalize(gb_normal, dim=-1) + gb_normal_aa = torch.lerp( + torch.zeros_like(gb_normal), (gb_normal + 1.0) / 2.0, mask.float() + ) + gb_normal_aa = self.ctx.antialias( + gb_normal_aa, rast, v_pos_clip, mesh.t_pos_idx + ) + out.update({"comp_normal": gb_normal_aa}) # in [0, 1] + + # Compute normal in view space. + # TODO: make is clear whether to compute this. + w2c = kwargs["c2w"][:, :3, :3].inverse() + gb_normal_viewspace = torch.einsum("bij,bhwj->bhwi", w2c, gb_normal) + gb_normal_viewspace = F.normalize(gb_normal_viewspace, dim=-1) + bg_normal = torch.zeros_like(gb_normal_viewspace) + bg_normal[..., 2] = 1 + gb_normal_viewspace_aa = torch.lerp( + (bg_normal + 1.0) / 2.0, + (gb_normal_viewspace + 1.0) / 2.0, + mask.float(), + ).contiguous() + gb_normal_viewspace_aa = self.ctx.antialias( + gb_normal_viewspace_aa, rast, v_pos_clip, mesh.t_pos_idx + ) + out.update({"comp_normal_viewspace": gb_normal_viewspace_aa}) + + # TODO: make it clear whether to compute the normal, now we compute it in all cases + # consider using: require_normal_computation = render_normal or (render_rgb and material.requires_normal) + # or + # render_normal = render_normal or (render_rgb and material.requires_normal) + + if render_rgb: + selector = mask[..., 0] + + gb_pos, _ = self.ctx.interpolate_one(mesh.v_pos, rast, mesh.t_pos_idx) + gb_viewdirs = F.normalize( + gb_pos - camera_positions[:, None, None, :], dim=-1 + ) + gb_light_positions = light_positions[:, None, None, :].expand( + -1, height, width, -1 + ) + + positions = gb_pos[selector] + geo_out = self.geometry(positions, output_normal=False) + + extra_geo_info = {} + if self.material.requires_normal: + extra_geo_info["shading_normal"] = gb_normal[selector] + if self.material.requires_tangent: + gb_tangent, _ = self.ctx.interpolate_one( + mesh.v_tng, rast, mesh.t_pos_idx + ) + gb_tangent = F.normalize(gb_tangent, dim=-1) + extra_geo_info["tangent"] = gb_tangent[selector] + + rgb_fg = self.material( + viewdirs=gb_viewdirs[selector], + positions=positions, + light_positions=gb_light_positions[selector], + **extra_geo_info, + **geo_out + ) + gb_rgb_fg = torch.zeros(batch_size, height, width, 3).to(rgb_fg) + gb_rgb_fg[selector] = rgb_fg + + gb_rgb_bg = self.background(dirs=gb_viewdirs) + gb_rgb = torch.lerp(gb_rgb_bg, gb_rgb_fg, mask.float()) + gb_rgb_aa = self.ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx) + + out.update({"comp_rgb": gb_rgb_aa, "comp_rgb_bg": gb_rgb_bg}) + + return out \ No newline at end of file diff --git a/threestudio/models/renderers/patch_renderer.py b/threestudio/models/renderers/patch_renderer.py new file mode 100644 index 0000000..dae374f --- /dev/null +++ b/threestudio/models/renderers/patch_renderer.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +import threestudio +from threestudio.models.background.base import BaseBackground +from threestudio.models.geometry.base import BaseImplicitGeometry +from threestudio.models.materials.base import BaseMaterial +from threestudio.models.renderers.base import VolumeRenderer +from threestudio.utils.typing import * + + +@threestudio.register("patch-renderer") +class PatchRenderer(VolumeRenderer): + @dataclass + class Config(VolumeRenderer.Config): + patch_size: int = 128 + base_renderer_type: str = "" + base_renderer: Optional[VolumeRenderer.Config] = None + global_detach: bool = False + global_downsample: int = 4 + + cfg: Config + + def configure( + self, + geometry: BaseImplicitGeometry, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + self.base_renderer = threestudio.find(self.cfg.base_renderer_type)( + self.cfg.base_renderer, + geometry=geometry, + material=material, + background=background, + ) + + def forward( + self, + rays_o: Float[Tensor, "B H W 3"], + rays_d: Float[Tensor, "B H W 3"], + light_positions: Float[Tensor, "B 3"], + bg_color: Optional[Tensor] = None, + **kwargs + ) -> Dict[str, Float[Tensor, "..."]]: + B, H, W, _ = rays_o.shape + + if self.base_renderer.training: + downsample = self.cfg.global_downsample + global_rays_o = torch.nn.functional.interpolate( + rays_o.permute(0, 3, 1, 2), + (H // downsample, W // downsample), + mode="bilinear", + ).permute(0, 2, 3, 1) + global_rays_d = torch.nn.functional.interpolate( + rays_d.permute(0, 3, 1, 2), + (H // downsample, W // downsample), + mode="bilinear", + ).permute(0, 2, 3, 1) + out_global = self.base_renderer( + global_rays_o, global_rays_d, light_positions, bg_color, **kwargs + ) + + PS = self.cfg.patch_size + patch_x = torch.randint(0, W - PS, (1,)).item() + patch_y = torch.randint(0, H - PS, (1,)).item() + patch_rays_o = rays_o[:, patch_y : patch_y + PS, patch_x : patch_x + PS] + patch_rays_d = rays_d[:, patch_y : patch_y + PS, patch_x : patch_x + PS] + out = self.base_renderer( + patch_rays_o, patch_rays_d, light_positions, bg_color, **kwargs + ) + + valid_patch_key = [] + for key in out: + if torch.is_tensor(out[key]): + if len(out[key].shape) == len(out["comp_rgb"].shape): + if out[key][..., 0].shape == out["comp_rgb"][..., 0].shape: + valid_patch_key.append(key) + for key in valid_patch_key: + out_global[key] = F.interpolate( + out_global[key].permute(0, 3, 1, 2), (H, W), mode="bilinear" + ).permute(0, 2, 3, 1) + if self.cfg.global_detach: + out_global[key] = out_global[key].detach() + out_global[key][ + :, patch_y : patch_y + PS, patch_x : patch_x + PS + ] = out[key] + out = out_global + else: + out = self.base_renderer( + rays_o, rays_d, light_positions, bg_color, **kwargs + ) + + return out + + def update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ) -> None: + self.base_renderer.update_step(epoch, global_step, on_load_weights) + + def train(self, mode=True): + return self.base_renderer.train(mode) + + def eval(self): + return self.base_renderer.eval() \ No newline at end of file diff --git a/threestudio/scripts/convert_zero123_to_diffusers.py b/threestudio/scripts/convert_zero123_to_diffusers.py new file mode 100644 index 0000000..e212774 --- /dev/null +++ b/threestudio/scripts/convert_zero123_to_diffusers.py @@ -0,0 +1,1025 @@ +import argparse +import sys + +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import logging +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +sys.path.append("extern/") +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from zero123 import CLIPCameraProjection, Zero123Pipeline + +logger = logging.get_logger(__name__) + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if ( + "unet_config" in original_config.model.params + and original_config.model.params.unet_config is not None + ): + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [ + unet_params.model_channels * mult for mult in unet_params.channel_mult + ] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnDownBlock2D" + if resolution in unet_params.attention_resolutions + else "DownBlock2D" + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnUpBlock2D" + if resolution in unet_params.attention_resolutions + else "UpBlock2D" + ) + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer + if "use_linear_in_transformer" in unet_params + else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim + if isinstance(unet_params.context_dim, int) + else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError( + f"Unknown conditional unet num_classes config: {unet_params.num_classes}" + ) + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance( + paths, list + ), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape( + (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] + ) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if ( + attention_paths_to_split is not None + and new_path in attention_paths_to_split + ): + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ( + "attentions" in new_path and "to_" in new_path + ) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint( + checkpoint, + config, + path=None, + extract_ema=False, + controlnet=False, + skip_extract_state_dict=False, +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint[ + flat_ema_key + ] + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint[key] + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ + "time_embed.0.weight" + ] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[ + "time_embed.0.bias" + ] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[ + "time_embed.2.weight" + ] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[ + "time_embed.2.bias" + ] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif ( + config["class_embed_type"] == "timestep" + or config["class_embed_type"] == "projection" + ): + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict[ + "label_emb.0.0.weight" + ] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict[ + "label_emb.0.0.bias" + ] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict[ + "label_emb.0.2.weight" + ] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict[ + "label_emb.0.2.bias" + ] + else: + raise NotImplementedError( + f"Not implemented `class_embed_type`: {config['class_embed_type']}" + ) + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict[ + "label_emb.0.0.weight" + ] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict[ + "label_emb.0.0.bias" + ] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict[ + "label_emb.0.2.weight" + ] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict[ + "label_emb.0.2.bias" + ] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "input_blocks" in layer + } + ) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "middle_block" in layer + } + ) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "output_blocks" in layer + } + ) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.weight" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.bias" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key + ] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index( + ["conv.bias", "conv.weight"] + ) + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.weight" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.bias" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + else: + resnet_0_paths = renew_resnet_paths( + output_block_layers, n_shave_prefix_segments=1 + ) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + [ + "up_blocks", + str(block_id), + "resnets", + str(layer_in_block_id), + path["new"], + ] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint[ + "controlnet_cond_embedding.conv_in.weight" + ] = unet_state_dict.pop(f"input_hint_block.{orig_index}.weight") + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[ + f"controlnet_cond_embedding.blocks.{diffusers_index}.weight" + ] = unet_state_dict.pop(f"input_hint_block.{orig_index}.weight") + new_checkpoint[ + f"controlnet_cond_embedding.blocks.{diffusers_index}.bias" + ] = unet_state_dict.pop(f"input_hint_block.{orig_index}.bias") + diffusers_index += 1 + orig_index += 2 + + new_checkpoint[ + "controlnet_cond_embedding.conv_out.weight" + ] = unet_state_dict.pop(f"input_hint_block.{orig_index}.weight") + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop( + f"zero_convs.{i}.0.weight" + ) + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop( + f"zero_convs.{i}.0.bias" + ) + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop( + "middle_block_out.0.weight" + ) + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop( + "middle_block_out.0.bias" + ) + + return new_checkpoint + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ + "encoder.conv_out.weight" + ] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ + "encoder.norm_out.weight" + ] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ + "encoder.norm_out.bias" + ] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ + "decoder.conv_out.weight" + ] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ + "decoder.norm_out.weight" + ] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ + "decoder.norm_out.bias" + ] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "encoder.down" in layer + } + ) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] + for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "decoder.up" in layer + } + ) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] + for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [ + key + for key in down_blocks[i] + if f"down.{i}" in key and f"down.{i}.downsample" not in key + ] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key + for key in up_blocks[block_id] + if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def convert_from_original_zero123_ckpt( + checkpoint_path, original_config_file, extract_ema, device +): + ckpt = torch.load(checkpoint_path, map_location=device) + global_step = ckpt["global_step"] + checkpoint = ckpt["state_dict"] + del ckpt + torch.cuda.empty_cache() + + from omegaconf import OmegaConf + + original_config = OmegaConf.load(original_config_file) + model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + num_in_channels = 8 + original_config["model"]["params"]["unet_config"]["params"][ + "in_channels" + ] = num_in_channels + prediction_type = "epsilon" + image_size = 256 + num_train_timesteps = ( + getattr(original_config.model.params, "timesteps", None) or 1000 + ) + + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + scheduler.register_to_config(clip_sample=False) + + # Convert the UNet2DConditionModel model. + upcast_attention = None + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + with init_empty_weights(): + unet = UNet2DConditionModel(**unet_config) + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=None, extract_ema=extract_ema + ) + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + + # Convert the VAE model. + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + with init_empty_weights(): + vae = AutoencoderKL(**vae_config) + + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + + feature_extractor = CLIPImageProcessor.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", subfolder="feature_extractor" + ) + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", subfolder="image_encoder" + ) + + clip_camera_projection = CLIPCameraProjection(additional_embeddings=4) + clip_camera_projection.load_state_dict( + { + "proj.weight": checkpoint["cc_projection.weight"].cpu(), + "proj.bias": checkpoint["cc_projection.bias"].cpu(), + } + ) + + pipe = Zero123Pipeline( + vae, + image_encoder, + unet, + scheduler, + None, + feature_extractor, + clip_camera_projection, + requires_safety_checker=False, + ) + + return pipe + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the checkpoint to convert.", + ) + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument( + "--half", action="store_true", help="Save weights in half precision." + ) + parser.add_argument( + "--dump_path", + default=None, + type=str, + required=True, + help="Path to the output model.", + ) + parser.add_argument( + "--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)" + ) + args = parser.parse_args() + + pipe = convert_from_original_zero123_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + extract_ema=args.extract_ema, + device=args.device, + ) + + if args.half: + pipe.to(torch_dtype=torch.float16) + + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/threestudio/scripts/dreamcraft3d_dreambooth.py b/threestudio/scripts/dreamcraft3d_dreambooth.py new file mode 100644 index 0000000..32fa988 --- /dev/null +++ b/threestudio/scripts/dreamcraft3d_dreambooth.py @@ -0,0 +1,53 @@ +import argparse +import os +from subprocess import run, CalledProcessError + +import cv2 +import glob +import numpy as np +import pytorch_lightning as pl +import torch +from tqdm import tqdm +from torchvision.utils import save_image + +from threestudio.scripts.generate_mv_datasets import generate_mv_dataset +from threestudio.utils.config import load_config +import threestudio + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True, help="path to config file") + parser.add_argument("--action", default="both", help="action to perform", choices=["gen_data", "dreambooth", "both""]) + args, extras = parser.parse_known_args() + return args, extras + + +def main(args, extras): + cfg = load_config(args.config, cli_args=extras, n_gpus=1) + + if args.action == "gen_data" or args.action == "both": + # Generate multi-view dataset + generate_mv_dataset(cfg) + + if args.action == "dreambooth" or args.action == "both": + # Run DreamBooth. + command = f'accelerate launch threestudio/scripts/train_dreambooth.py \ + --pretrained_model_name_or_path="{cfg.custom_import.dreambooth.model_name}" \ + --instance_data_dir="{cfg.custom_import.dreambooth.instance_dir}" \ + --output_dir="{cfg.custom_import.dreambooth.output_dir}"\ + --instance_prompt="{cfg.custom_import.dreambooth.prompt_dreambooth}" \ + --resolution=512 \ + --train_batch_size=2 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=1000' + + os.system(command) + + +if __name__ == "__main__": + args, extras = parse_args() + main(args, extras) diff --git a/threestudio/scripts/generate_images_if.py b/threestudio/scripts/generate_images_if.py new file mode 100644 index 0000000..e668f96 --- /dev/null +++ b/threestudio/scripts/generate_images_if.py @@ -0,0 +1,92 @@ +from diffusers import DiffusionPipeline +from diffusers.utils import pt_to_pil +import torch + +import os +import glob +import json +import argparse +import numpy as np +from tqdm import tqdm + + +SAVE_FOLDER = "./load/images_dreamfusion" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--rank", default=0, type=int, help="# of GPU") + parser.add_argument("--prompt",required=True, type=str) + + args = parser.parse_args() + + # stage 1 + stage_1 = DiffusionPipeline.from_pretrained( + "DeepFloyd/IF-I-XL-v1.0", + variant="fp16", + torch_dtype=torch.float16, + local_files_only=True + ) + stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 + stage_1.enable_model_cpu_offload() + + # stage 2 + stage_2 = DiffusionPipeline.from_pretrained( + "DeepFloyd/IF-II-L-v1.0", + text_encoder=None, + variant="fp16", + torch_dtype=torch.float16, + local_files_only=True + ) + # stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 + stage_2.enable_model_cpu_offload() + + # stage 3 + # safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} + safety_modules = None + stage_3 = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", + torch_dtype=torch.float16, + local_files_only=True + ) + stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 + stage_3.enable_model_cpu_offload() + + # # load prompt library + # with open(os.path.join("load/prompt_library.json"), "r") as f: + # prompt_library = json.load(f) + + # n_prompts = len(prompt_library["dreamfusion"]) + # n_prompts_per_rank = int(np.ceil(n_prompts / 8)) + + # for prompt in tqdm(prompt_library["dreamfusion"][args.rank * n_prompts_per_rank : (args.rank + 1) * n_prompts_per_rank]): + + prompt = args.prompt + print("Prompt:", prompt) + + save_folder = os.path.join(SAVE_FOLDER, prompt) + os.makedirs(save_folder, exist_ok=True) + + # if len(glob.glob(f"{save_folder}/*.png")) >= 30: + # continue + + # enhance prompt + prompt = prompt + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, hyperrealistic, intricate details, ultra-realistic, award-winning" + + prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) + for _ in tqdm(range(30)): + seed = np.random.randint(low=0, high=10000000, size=1)[0] + generator = torch.manual_seed(seed) + + ### Stage 1 + image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images + # pt_to_pil(image)[0].save("./if_stage_I.png") + + ### Stage 2 + image = stage_2( + image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt" + ).images + # pt_to_pil(image)[0].save("./if_stage_II.png") + + ### Stage 3 + image = stage_3(prompt=prompt, image=(image.float() * 0.5 + 0.5), generator=generator, noise_level=100).images + image[0].save(f"{save_folder}/img_{seed:08d}.png") \ No newline at end of file diff --git a/threestudio/scripts/generate_images_if_prompt_library.py b/threestudio/scripts/generate_images_if_prompt_library.py new file mode 100644 index 0000000..af6bd19 --- /dev/null +++ b/threestudio/scripts/generate_images_if_prompt_library.py @@ -0,0 +1,90 @@ +from diffusers import DiffusionPipeline +from diffusers.utils import pt_to_pil +import torch + +import os +import glob +import json +import argparse +import numpy as np +from tqdm import tqdm + + +SAVE_FOLDER = "./load/images_dreamfusion" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--rank", default=0, type=int, help="# of GPU") + + args = parser.parse_args() + + # stage 1 + stage_1 = DiffusionPipeline.from_pretrained( + "DeepFloyd/IF-I-XL-v1.0", + variant="fp16", + torch_dtype=torch.float16, + local_files_only=True + ) + stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 + stage_1.enable_model_cpu_offload() + + # stage 2 + stage_2 = DiffusionPipeline.from_pretrained( + "DeepFloyd/IF-II-L-v1.0", + text_encoder=None, + variant="fp16", + torch_dtype=torch.float16, + local_files_only=True + ) + # stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 + stage_2.enable_model_cpu_offload() + + # stage 3 + # safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} + safety_modules = None + stage_3 = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", + torch_dtype=torch.float16, + local_files_only=True + ) + stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 + stage_3.enable_model_cpu_offload() + + # load prompt library + with open(os.path.join("load/prompt_library.json"), "r") as f: + prompt_library = json.load(f) + + n_prompts = len(prompt_library["dreamfusion"]) + n_prompts_per_rank = int(np.ceil(n_prompts / 8)) + + for prompt in tqdm(prompt_library["dreamfusion"][args.rank * n_prompts_per_rank : (args.rank + 1) * n_prompts_per_rank]): + + print("Prompt:", prompt) + + save_folder = os.path.join(SAVE_FOLDER, prompt) + os.makedirs(save_folder, exist_ok=True) + + if len(glob.glob(f"{save_folder}/*.png")) >= 30: + continue + + # enhance prompt + prompt = prompt + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, hyperrealistic, intricate details, ultra-realistic, award-winning" + + prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) + for _ in tqdm(range(30)): + seed = np.random.randint(low=0, high=10000000, size=1)[0] + generator = torch.manual_seed(seed) + + ### Stage 1 + image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images + # pt_to_pil(image)[0].save("./if_stage_I.png") + + ### Stage 2 + image = stage_2( + image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt" + ).images + # pt_to_pil(image)[0].save("./if_stage_II.png") + + ### Stage 3 + image = stage_3(prompt=prompt, image=(image.float() * 0.5 + 0.5), generator=generator, noise_level=100).images + image[0].save(f"{save_folder}/img_{seed:08d}.png") \ No newline at end of file diff --git a/threestudio/scripts/generate_mv_datasets.py b/threestudio/scripts/generate_mv_datasets.py new file mode 100644 index 0000000..d0f53fd --- /dev/null +++ b/threestudio/scripts/generate_mv_datasets.py @@ -0,0 +1,95 @@ +import os +import cv2 +import glob +import torch +import argparse +import numpy as np +from tqdm import tqdm +import pytorch_lightning as pl +from torchvision.utils import save_image +from subprocess import run, CalledProcessError +from threestudio.utils.config import load_config +import threestudio + +# Constants +AZIMUTH_FACTOR = 360 +IMAGE_SIZE = (512, 512) + + +def copy_file(source, destination): + try: + command = ['cp', source, destination] + result = run(command, capture_output=True, text=True) + result.check_returncode() + except CalledProcessError as e: + print(f'Error: {e.output}') + + +def prepare_images(cfg): + rgb_list = sorted(glob.glob(os.path.join(cfg.data.render_image_path, "*.png"))) + rgb_list.sort(key=lambda file: int(os.path.splitext(os.path.basename(file))[0])) + n_rgbs = len(rgb_list) + n_samples = cfg.data.n_samples + + os.makedirs(cfg.data.save_path, exist_ok=True) + + copy_file(cfg.data.ref_image_path, f"{cfg.data.save_path}/ref_0.0.png") + + sampled_indices = np.linspace(0, len(rgb_list)-1, n_samples, dtype=int) + rgb_samples = [rgb_list[index] for index in sampled_indices] + + return rgb_samples + + +def process_images(rgb_samples, cfg, guidance, prompt_utils): + n_rgbs = 120 + for rgb_name in tqdm(rgb_samples): + rgb_idx = int(os.path.basename(rgb_name).split(".")[0]) + rgb = cv2.imread(rgb_name)[:, :, :3][:, :, ::-1].copy() / 255.0 + H, W = rgb.shape[0:2] + rgb_image, mask_image = rgb[:, :H], rgb[:, -H:, :1] + rgb_image = cv2.resize(rgb_image, IMAGE_SIZE) + rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device) + + mask_image = cv2.resize(mask_image, IMAGE_SIZE).reshape(IMAGE_SIZE[0], IMAGE_SIZE[1], 1) + mask_image = torch.FloatTensor(mask_image).unsqueeze(0).to(guidance.device) + + temp = torch.zeros(1).to(guidance.device) + azimuth = torch.tensor([rgb_idx/n_rgbs * AZIMUTH_FACTOR]).to(guidance.device) + camera_distance = torch.tensor([cfg.data.default_camera_distance]).to(guidance.device) + + if cfg.data.view_dependent_noise: + guidance.min_step_percent = 0. + (rgb_idx/n_rgbs) * (cfg.system.guidance.min_step_percent) + guidance.max_step_percent = 0. + (rgb_idx/n_rgbs) * (cfg.system.guidance.max_step_percent) + + denoised_image = process_guidance(cfg, guidance, prompt_utils, rgb_image, azimuth, temp, camera_distance, mask_image) + + save_image(denoised_image.permute(0,3,1,2), f"{cfg.data.save_path}/img_{azimuth[0]}.png", normalize=True, value_range=(0, 1)) + + copy_file(rgb_name.replace("png", "npy"), f"{cfg.data.save_path}/img_{azimuth[0]}.npy") + + if rgb_idx == 0: + copy_file(rgb_name.replace("png", "npy"), f"{cfg.data.save_path}/ref_{azimuth[0]}.npy") + + +def process_guidance(cfg, guidance, prompt_utils, rgb_image, azimuth, temp, camera_distance, mask_image): + if cfg.data.azimuth_range[0] < azimuth < cfg.data.azimuth_range[1]: + return guidance.sample_img2img( + rgb_image, prompt_utils, temp, + azimuth, camera_distance, seed=0, mask=mask_image + )["edit_image"] + else: + return rgb_image + + +def generate_mv_dataset(cfg): + + guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance) + prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(cfg.system.prompt_processor) + prompt_utils = prompt_processor() + + guidance.update_step(epoch=0, global_step=0) + rgb_samples = prepare_images(cfg) + print(rgb_samples) + process_images(rgb_samples, cfg, guidance, prompt_utils) + diff --git a/threestudio/scripts/img_to_mv.py b/threestudio/scripts/img_to_mv.py new file mode 100644 index 0000000..43f0a26 --- /dev/null +++ b/threestudio/scripts/img_to_mv.py @@ -0,0 +1,84 @@ +import os +import argparse +from PIL import Image +import torch +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, StableDiffusionUpscalePipeline + + +def load_model(superres): + mv_model = DiffusionPipeline.from_pretrained( + "sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline", + torch_dtype=torch.float16, cache_dir="load/checkpoints/huggingface/hub", local_files_only=True, + ) + mv_model.scheduler = EulerAncestralDiscreteScheduler.from_config( + mv_model.scheduler.config, timestep_spacing='trailing', cache_dir="load/checkpoints/huggingface/hub", local_files_only=True, + ) + + if superres: + superres_model = StableDiffusionUpscalePipeline.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", revision="fp16", + torch_dtype=torch.float16, cache_dir="load/checkpoints/huggingface/hub", local_files_only=True, + ) + else: + superres_model = None + + return mv_model, superres_model + + +def superres_4x(image, model, prompt): + low_res_img = image.resize((256, 256)) + model.to('cuda:1') + result = model(prompt=prompt, image=low_res_img).images[0] + return result + + +def img_to_mv(image_path, model): + cond = Image.open(image_path) + model.to('cuda:1') + result = model(cond, num_inference_steps=75).images[0] + return result + + +def crop_save_image_to_2x3_grid(image, args, model): + save_path = args.save_path + width, height = image.size + grid_width = width//2 + grid_height = height//3 + + images = [] + for i in range(3): + for j in range(2): + left = j * grid_width + upper = i * grid_height + right = (j+1) * grid_width + lower = (i+1) * grid_height + + cropped_image = image.crop((left, upper, right, lower)) + if args.superres: + cropped_image = superres_4x(cropped_image, model, args.prompt) + images.append(cropped_image) + + for idx, img in enumerate(images): + img.save(os.path.join(save_path, f'cropped_{idx}.jpg')) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--image_path', type=str, help="path to image (png, jpeg, etc.)") + parser.add_argument('--save_path', type=str, help="path to save output images") + parser.add_argument('--prompt', type=str, help="prompt to use for superres") + parser.add_argument('--superres', action='store_true', help="whether to use superres") + args = parser.parse_args() + + print(args.superres) + + os.makedirs(args.save_path, exist_ok=True) + os.system(f"cp '{args.image_path}' '{args.save_path}'") + + mv_model, superres_model = load_model(args.superres) + images = img_to_mv(args.image_path, mv_model) + crop_save_image_to_2x3_grid(images, args, superres_model) + + +# Example usage: +# python threestudio/scripts/img_to_mv.py --image_path 'mushroom.png' --save_path '.cache/temp' --prompt 'a photo of mushroom' --superres \ No newline at end of file diff --git a/threestudio/scripts/make_training_vid.py b/threestudio/scripts/make_training_vid.py new file mode 100644 index 0000000..71b5f08 --- /dev/null +++ b/threestudio/scripts/make_training_vid.py @@ -0,0 +1,77 @@ +# make_training_vid("outputs/zero123/64_teddy_rgba.png@20230627-195615", frames_per_vid=30, fps=20, max_iters=200) +import argparse +import glob +import os + +import imageio +import numpy as np +from PIL import Image, ImageDraw +from tqdm import tqdm + + +def draw_text_in_image(img, texts): + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) + return np.asarray(img) + + +def make_training_vid(exp, frames_per_vid=1, fps=3, max_iters=None, max_vids=None): + # exp = "/admin/home-vikram/git/threestudio/outputs/zero123/64_teddy_rgba.png@20230627-195615" + files = glob.glob(os.path.join(exp, "save", "*.mp4")) + if os.path.join(exp, "save", "training_vid.mp4") in files: + files.remove(os.path.join(exp, "save", "training_vid.mp4")) + its = [int(os.path.basename(file).split("-")[0].split("it")[-1]) for file in files] + it_sort = np.argsort(its) + files = list(np.array(files)[it_sort]) + its = list(np.array(its)[it_sort]) + max_vids = max_iters // its[0] if max_iters is not None else max_vids + files, its = files[:max_vids], its[:max_vids] + frames, i = [], 0 + for it, file in tqdm(zip(its, files), total=len(files)): + vid = imageio.mimread(file) + for _ in range(frames_per_vid): + frame = vid[i % len(vid)] + frame = draw_text_in_image(frame, [str(it)]) + frames.append(frame) + i += 1 + # Save + imageio.mimwrite(os.path.join(exp, "save", "training_vid.mp4"), frames, fps=fps) + + +def join(file1, file2, name): + # file1 = "/admin/home-vikram/git/threestudio/outputs/zero123/OLD_64_dragon2_rgba.png@20230629-023028/save/it200-val.mp4" + # file2 = "/admin/home-vikram/git/threestudio/outputs/zero123/64_dragon2_rgba.png@20230628-152734/save/it200-val.mp4" + vid1 = imageio.mimread(file1) + vid2 = imageio.mimread(file2) + frames = [] + for f1, f2 in zip(vid1, vid2): + frames.append( + np.concatenate([f1[:, : f1.shape[0]], f2[:, : f2.shape[0]]], axis=1) + ) + imageio.mimwrite(name, frames) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--exp", help="directory of experiment") + parser.add_argument( + "--frames_per_vid", type=int, default=1, help="# of frames from each val vid" + ) + parser.add_argument("--fps", type=int, help="max # of iters to save") + parser.add_argument("--max_iters", type=int, help="max # of iters to save") + parser.add_argument( + "--max_vids", + type=int, + help="max # of val videos to save. Will be overridden by max_iters", + ) + args = parser.parse_args() + make_training_vid( + args.exp, args.frames_per_vid, args.fps, args.max_iters, args.max_vids + ) \ No newline at end of file diff --git a/threestudio/scripts/metric_utils.py b/threestudio/scripts/metric_utils.py new file mode 100644 index 0000000..a5f134a --- /dev/null +++ b/threestudio/scripts/metric_utils.py @@ -0,0 +1,459 @@ +# * evaluate use laion/CLIP-ViT-H-14-laion2B-s32B-b79K +# best open source clip so far: laion/CLIP-ViT-bigG-14-laion2B-39B-b160k +# code adapted from NeuralLift-360 + +import torch +import torch.nn as nn +import os +import torchvision.transforms as T +import torchvision.transforms.functional as TF +import matplotlib.pyplot as plt +# import clip +from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor +from torchvision import transforms +import numpy as np +import torch.nn.functional as F +from tqdm import tqdm +import cv2 +from PIL import Image +# import torchvision.transforms as transforms +import glob +from skimage.metrics import peak_signal_noise_ratio as compare_psnr +import lpips +from os.path import join as osp +import argparse +import pandas as pd + +class CLIP(nn.Module): + + def __init__(self, + device, + clip_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', + size=224): #'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'): + super().__init__() + self.size = size + self.device = f"cuda:{device}" + + clip_name = clip_name + + self.feature_extractor = CLIPFeatureExtractor.from_pretrained( + clip_name) + self.clip_model = CLIPModel.from_pretrained(clip_name).to(self.device) + self.tokenizer = CLIPTokenizer.from_pretrained( + 'openai/clip-vit-base-patch32') + + self.normalize = transforms.Normalize( + mean=self.feature_extractor.image_mean, + std=self.feature_extractor.image_std) + + self.resize = transforms.Resize(224) + self.to_tensor = transforms.ToTensor() + + # image augmentation + self.aug = T.Compose([ + T.Resize((224, 224)), + T.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + # * recommend to use this function for evaluation + @torch.no_grad() + def score_gt(self, ref_img_path, novel_views): + # assert len(novel_views) == 100 + clip_scores = [] + for novel in novel_views: + clip_scores.append(self.score_from_path(ref_img_path, [novel])) + return np.mean(clip_scores) + + # * recommend to use this function for evaluation + # def score_gt(self, ref_paths, novel_paths): + # clip_scores = [] + # for img1_path, img2_path in zip(ref_paths, novel_paths): + # clip_scores.append(self.score_from_path(img1_path, img2_path)) + + # return np.mean(clip_scores) + + def similarity(self, image1_features: torch.Tensor, + image2_features: torch.Tensor) -> float: + with torch.no_grad(), torch.cuda.amp.autocast(): + y = image1_features.T.view(image1_features.T.shape[1], + image1_features.T.shape[0]) + similarity = torch.matmul(y, image2_features.T) + # print(similarity) + return similarity[0][0].item() + + def get_img_embeds(self, img): + if img.shape[0] == 4: + img = img[:3, :, :] + + img = self.aug(img).to(self.device) + img = img.unsqueeze(0) # b,c,h,w + + # plt.imshow(img.cpu().squeeze(0).permute(1, 2, 0).numpy()) + # plt.show() + # print(img) + + image_z = self.clip_model.get_image_features(img) + image_z = image_z / image_z.norm(dim=-1, + keepdim=True) # normalize features + return image_z + + def score_from_feature(self, img1, img2): + img1_feature, img2_feature = self.get_img_embeds( + img1), self.get_img_embeds(img2) + # for debug + return self.similarity(img1_feature, img2_feature) + + def read_img_list(self, img_list): + size = self.size + images = [] + # white_background = np.ones((size, size, 3), dtype=np.uint8) * 255 + + for img_path in img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + # print(img_path) + if img.shape[2] == 4: # Handle BGRA images + alpha = img[:, :, 3] # Extract alpha channel + img = cv2.cvtColor(img,cv2.COLOR_BGRA2RGB) # Convert BGRA to BGR + img[np.where(alpha == 0)] = [ + 255, 255, 255 + ] # Set transparent pixels to white + else: # Handle other image formats like JPG and PNG + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) + + # plt.imshow(img) + # plt.show() + + images.append(img) + + images = np.stack(images, axis=0) + # images[np.where(images == 0)] = 255 # Set black pixels to white + # images = np.where(images == 0, white_background, images) # Set transparent pixels to white + # images = images.astype(np.float32) + + return images + + def score_from_path(self, img1_path, img2_path): + img1, img2 = self.read_img_list(img1_path), self.read_img_list(img2_path) + img1 = np.squeeze(img1) + img2 = np.squeeze(img2) + # plt.imshow(img1) + # plt.show() + # plt.imshow(img2) + # plt.show() + + img1, img2 = self.to_tensor(img1), self.to_tensor(img2) + # print("img1 to tensor ",img1) + return self.score_from_feature(img1, img2) + + +def numpy_to_torch(images): + images = images * 2.0 - 1.0 + images = torch.from_numpy(images.transpose((0, 3, 1, 2))).float() + return images.cuda() + + +class LPIPSMeter: + + def __init__(self, + net='alex', + device=None, + size=224): # or we can use 'alex', 'vgg' as network + self.size = size + self.net = net + self.results = [] + self.device = device if device is not None else torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.fn = lpips.LPIPS(net=net).eval().to(self.device) + + def measure(self): + return np.mean(self.results) + + def report(self): + return f'LPIPS ({self.net}) = {self.measure():.6f}' + + def read_img_list(self, img_list): + size = self.size + images = [] + white_background = np.ones((size, size, 3), dtype=np.uint8) * 255 + + for img_path in img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + + if img.shape[2] == 4: # Handle BGRA images + alpha = img[:, :, 3] # Extract alpha channel + img = cv2.cvtColor(img, + cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR + + img = cv2.cvtColor(img, + cv2.COLOR_BGR2RGB) # Convert BGR to RGB + img[np.where(alpha == 0)] = [ + 255, 255, 255 + ] # Set transparent pixels to white + else: # Handle other image formats like JPG and PNG + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) + images.append(img) + + images = np.stack(images, axis=0) + # images[np.where(images == 0)] = 255 # Set black pixels to white + # images = np.where(images == 0, white_background, images) # Set transparent pixels to white + images = images.astype(np.float32) / 255.0 + + return images + + # * recommend to use this function for evaluation + @torch.no_grad() + def score_gt(self, ref_paths, novel_paths): + self.results = [] + for path0, path1 in zip(ref_paths, novel_paths): + # Load images + # img0 = lpips.im2tensor(lpips.load_image(path0)).cuda() # RGB image from [-1,1] + # img1 = lpips.im2tensor(lpips.load_image(path1)).cuda() + img0, img1 = self.read_img_list([path0]), self.read_img_list( + [path1]) + img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1) + # print(img0.shape,img1.shape) + img0 = F.interpolate(img0, + size=(self.size, self.size), + mode='area') + img1 = F.interpolate(img1, + size=(self.size, self.size), + mode='area') + + # for debug vis + # plt.imshow(img0.cpu().squeeze(0).permute(1, 2, 0).numpy()) + # plt.show() + # plt.imshow(img1.cpu().squeeze(0).permute(1, 2, 0).numpy()) + # plt.show() + # equivalent to cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA + + # print(img0.shape,img1.shape) + + self.results.append(self.fn.forward(img0, img1).cpu().numpy()) + + return self.measure() + + +class PSNRMeter: + + def __init__(self, size=800): + self.results = [] + self.size = size + + def read_img_list(self, img_list): + size = self.size + images = [] + white_background = np.ones((size, size, 3), dtype=np.uint8) * 255 + for img_path in img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + + if img.shape[2] == 4: # Handle BGRA images + alpha = img[:, :, 3] # Extract alpha channel + img = cv2.cvtColor(img, + cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR + + img = cv2.cvtColor(img, + cv2.COLOR_BGR2RGB) # Convert BGR to RGB + img[np.where(alpha == 0)] = [ + 255, 255, 255 + ] # Set transparent pixels to white + else: # Handle other image formats like JPG and PNG + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) + images.append(img) + + images = np.stack(images, axis=0) + # images[np.where(images == 0)] = 255 # Set black pixels to white + # images = np.where(images == 0, white_background, images) # Set transparent pixels to white + images = images.astype(np.float32) / 255.0 + # print(images.shape) + return images + + def update(self, preds, truths): + # print(preds.shape) + + psnr_values = [] + # For each pair of images in the batches + for img1, img2 in zip(preds, truths): + # Compute the PSNR and add it to the list + # print(img1.shape,img2.shape) + + # for debug + # plt.imshow(img1) + # plt.show() + # plt.imshow(img2) + # plt.show() + + psnr = compare_psnr( + img1, img2, + data_range=1.0) # assuming your images are scaled to [0,1] + # print(f"temp psnr {psnr}") + psnr_values.append(psnr) + + # Convert the list of PSNR values to a numpy array + self.results = psnr_values + + def measure(self): + return np.mean(self.results) + + def report(self): + return f'PSNR = {self.measure():.6f}' + + # * recommend to use this function for evaluation + def score_gt(self, ref_paths, novel_paths): + self.results = [] + # [B, N, 3] or [B, H, W, 3], range[0, 1] + preds = self.read_img_list(ref_paths) + truths = self.read_img_list(novel_paths) + self.update(preds, truths) + return self.measure() + +all_inputs = 'data' +nerf_dataset = os.listdir(osp(all_inputs, 'nerf4')) +realfusion_dataset = os.listdir(osp(all_inputs, 'realfusion15')) +meta_examples = { + 'nerf4': nerf_dataset, + 'realfusion15': realfusion_dataset, +} +all_datasets = meta_examples.keys() + +# organization 1 +def deprecated_score_from_method_for_dataset(my_scorer, + method, + dataset, + input, + output, + score_type='clip', + ): # psnr, lpips + # print("\n\n\n") + # print(f"______{method}___{dataset}___{score_type}_________") + scores = {} + final_res = 0 + examples = meta_examples[dataset] + for i in range(len(examples)): + + # compare entire folder for clip + if score_type == 'clip': + novel_view = osp(pred_path, examples[i], 'colors') + # compare first image for other metrics + else: + if method == '3d_fuse': method = '3d_fuse_0' + novel_view = list( + glob.glob( + osp(pred_path, examples[i], 'colors', + 'step_0000*')))[0] + + score_i = my_scorer.score_gt( + [], [novel_view]) + scores[examples[i]] = score_i + final_res += score_i + # print(scores, " Avg : ", final_res / len(examples)) + # print("``````````````````````") + return scores + +# results organization 2 +def score_from_method_for_dataset(my_scorer, + input_path, + pred_path, + score_type='clip', + rgb_name='lambertian', + result_folder='results/images', + first_str='*0000*' + ): # psnr, lpips + scores = {} + final_res = 0 + examples = os.listdir(input_path) + for i in range(len(examples)): + # ref path + ref_path = osp(input_path, examples[i], 'rgba.png') + # compare entire folder for clip + if score_type == 'clip': + novel_view = glob.glob(osp(pred_path,'*'+examples[i]+'*', result_folder, f'*{rgb_name}*')) + print(f'[INOF] {score_type} loss for example {examples[i]} between 1 GT and {len(novel_view)} predictions') + # compare first image for other metrics + else: + novel_view = glob.glob(osp(pred_path, '*'+examples[i]+'*/', result_folder, f'{first_str}{rgb_name}*')) + print(f'[INOF] {score_type} loss for example {examples[i]} between {ref_path} and {novel_view}') + # breakpoint() + score_i = my_scorer.score_gt([ref_path], novel_view) + scores[examples[i]] = score_i + final_res += score_i + avg_score = final_res / len(examples) + scores['average'] = avg_score + return scores + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Script to accept three string arguments") + parser.add_argument("--input_path", + default=all_inputs, + help="Specify the input path") + parser.add_argument("--pred_pattern", + default="out/magic123*", + help="Specify the pattern of predition paths") + parser.add_argument("--results_folder", + default="results/images", + help="where are the results under each pred_path") + parser.add_argument("--rgb_name", + default="lambertian", + help="the postfix of the image") + parser.add_argument("--first_str", + default="*0000*", + help="the str to indicate the first view") + parser.add_argument("--datasets", + default=all_datasets, + nargs='*', + help="Specify the output path") + parser.add_argument("--device", + type=int, + default=0, + help="Specify the GPU device to be used") + parser.add_argument("--save_dir", type=str, default='all_metrics/results') + args = parser.parse_args() + + clip_scorer = CLIP(args.device) + lpips_scorer = LPIPSMeter() + psnr_scorer = PSNRMeter() + + os.makedirs(args.save_dir, exist_ok=True) + + for dataset in args.datasets: + input_path = osp(args.input_path, dataset) + + # assume the pred_path is organized as: pred_path/methods/dataset + pred_pattern = osp(args.pred_pattern, dataset) + pred_paths = glob.glob(pred_pattern) + print(f"[INFO] Following the pattern {pred_pattern}, find {len(pred_paths)} pred_paths: \n", pred_paths) + if len(pred_paths) == 0: + raise IOError + for pred_path in pred_paths: + if not os.path.exists(pred_path): + print(f'[WARN] prediction does not exit for {pred_path}') + else: + print(f'[INFO] evaluate {pred_path}') + results_dict = {} + results_dict['clip'] = score_from_method_for_dataset( + clip_scorer, input_path, pred_path, 'clip', + result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str) + + results_dict['psnr'] = score_from_method_for_dataset( + psnr_scorer, input_path, pred_path, 'psnr', + result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str) + + results_dict['lpips'] = score_from_method_for_dataset( + lpips_scorer, input_path, pred_path, 'lpips', + result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str) + + df = pd.DataFrame(results_dict) + method = pred_path.split('/')[-2] + print(osp(pred_path, args.results_folder)) + results_str = '_'.join(args.results_folder.split('/')) + print(method+'-'+results_str) + print(df) + df.to_csv(f"{args.save_dir}/{method}-{results_str}-{dataset}.csv") \ No newline at end of file diff --git a/threestudio/scripts/run_gaussian.sh b/threestudio/scripts/run_gaussian.sh new file mode 100755 index 0000000..a39092b --- /dev/null +++ b/threestudio/scripts/run_gaussian.sh @@ -0,0 +1,36 @@ +import subprocess + +prompt_list = [ + "a delicious hamburger", + "A DSLR photo of a roast turkey on a platter", + "A high quality photo of a dragon", + "A DSLR photo of a bald eagle", + "A bunch of blue rose, highly detailed", + "A 3D model of an adorable cottage with a thatched roof", + "A high quality photo of a furry corgi", + "A DSLR photo of a panda", + "a DSLR photo of a cat lying on its side batting at a ball of yarn", + "a beautiful dress made out of fruit, on a mannequin. Studio lighting, high quality, high resolution", + "a DSLR photo of a corgi wearing a beret and holding a baguette, standing up on two hind legs", + "a zoomed out DSLR photo of a stack of pancakes", + "a zoomed out DSLR photo of a baby bunny sitting on top of a stack of pancakes", +] +negative_prompt = "oversaturated color, ugly, tiling, low quality, noise, ugly pattern" + +gpu_id = 0 +max_steps = 10 +val_check = 1 +out_name = "gsgen_baseline" +for prompt in prompt_list: + print(f"Running model on device {gpu_id}: ", prompt) + command = [ + "python", "launch.py", + "--config", "configs/gaussian_splatting.yaml", + "--train", + f"system.prompt_processor.prompt={prompt}", + f"system.prompt_processor.negative_prompt={negative_prompt}", + f"name={out_name}", + "--gpu", f"{gpu_id}" + ] + subprocess.run(command) + \ No newline at end of file diff --git a/threestudio/scripts/run_zero123.py b/threestudio/scripts/run_zero123.py new file mode 100644 index 0000000..17361f2 --- /dev/null +++ b/threestudio/scripts/run_zero123.py @@ -0,0 +1,13 @@ +NAME="dragon2" + +# Phase 1 - 64x64 +python launch.py --config configs/zero123.yaml --train --gpu 7 data.image_path=./load/images/${NAME}_rgba.png use_timestamp=False name=${NAME} tag=Phase1 # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase1 + +# Phase 1.5 - 512 refine +python launch.py --config configs/zero123-geometry.yaml --train --gpu 4 data.image_path=./load/images/${NAME}_rgba.png system.geometry_convert_from=./outputs/${NAME}/Phase1/ckpts/last.ckpt use_timestamp=False name=${NAME} tag=Phase1p5 # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase1p5 + +# Phase 2 - dreamfusion +python launch.py --config configs/experimental/imagecondition_zero123nerf.yaml --train --gpu 5 data.image_path=./load/images/${NAME}_rgba.png system.prompt_processor.prompt="A 3D model of a friendly dragon" system.weights="/admin/home-vikram/git/threestudio/outputs/${NAME}/Phase1/ckpts/last.ckpt" name=${NAME} tag=Phase2 # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase2 + +# Phase 2 - SDF + dreamfusion +python launch.py --config configs/experimental/imagecondition_zero123nerf_refine.yaml --train --gpu 5 data.image_path=./load/images/${NAME}_rgba.png system.prompt_processor.prompt="A 3D model of a friendly dragon" system.geometry_convert_from="/admin/home-vikram/git/threestudio/outputs/${NAME}/Phase1/ckpts/last.ckpt" name=${NAME} tag=Phase2_refine # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase2_refine \ No newline at end of file diff --git a/threestudio/scripts/run_zero123_comparison.sh b/threestudio/scripts/run_zero123_comparison.sh new file mode 100755 index 0000000..7f9f60a --- /dev/null +++ b/threestudio/scripts/run_zero123_comparison.sh @@ -0,0 +1,23 @@ +# with standard zero123 +threestudio/scripts/run_zero123_phase.sh 6 anya_front 105000 0 + +# with zero123XL (not released yet!) +threestudio/scripts/run_zero123_phase.sh 1 anya_front XL_20230604 0 +threestudio/scripts/run_zero123_phase.sh 2 baby_phoenix_on_ice XL_20230604 20 +threestudio/scripts/run_zero123_phase.sh 3 beach_house_1 XL_20230604 50 +threestudio/scripts/run_zero123_phase.sh 4 bollywood_actress XL_20230604 0 +threestudio/scripts/run_zero123_phase.sh 5 beach_house_2 XL_20230604 30 +threestudio/scripts/run_zero123_phase.sh 6 hamburger XL_20230604 10 +threestudio/scripts/run_zero123_phase.sh 7 cactus XL_20230604 8 +threestudio/scripts/run_zero123_phase.sh 0 catstatue XL_20230604 50 +threestudio/scripts/run_zero123_phase.sh 1 church_ruins XL_20230604 0 +threestudio/scripts/run_zero123_phase.sh 2 firekeeper XL_20230604 10 +threestudio/scripts/run_zero123_phase.sh 3 futuristic_car XL_20230604 20 +threestudio/scripts/run_zero123_phase.sh 4 mona_lisa XL_20230604 10 +threestudio/scripts/run_zero123_phase.sh 5 teddy XL_20230604 20 + +# set guidance_eval to 0, to greatly speed up training +threestudio/scripts/run_zero123_phase.sh 7 anya_front XL_20230604 0 system.freq.guidance_eval=0 + +# disable wandb for faster training (or if you don't want to use it) +threestudio/scripts/run_zero123_phase.sh 7 anya_front XL_20230604 0 system.loggers.wandb.enable=false system.freq.guidance_eval=0 diff --git a/threestudio/scripts/run_zero123_demo.sh b/threestudio/scripts/run_zero123_demo.sh new file mode 100644 index 0000000..de45993 --- /dev/null +++ b/threestudio/scripts/run_zero123_demo.sh @@ -0,0 +1,25 @@ +NAME="dragon2" + +# Phase 1 - 64x64 +python launch.py --config configs/zero123_64.yaml --train --gpu 7 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-anya-new" system.loggers.wandb.name=${NAME} data.image_path=./load/images/${NAME}_rgba.png system.freq.guidance_eval=0 system.guidance.pretrained_model_name_or_path="./load/zero123/XL_20230604.ckpt" use_timestamp=False name=${NAME} tag="Phase1_64" + +# python threestudio/scripts/make_training_vid.py --exp /admin/home-vikram/git/threestudio/outputs/zero123/64_dragon2_rgba.png@20230628-152734 --frames_per_vid 30 --fps 20 --max_iters 200 + +# # Phase 1.5 - 512 +# python launch.py --config configs/zero123_512.yaml --train --gpu 5 system.loggers.wandb.enable=true system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEgeom" data.image_path=./load/images/robot_rgba.png system.freq.guidance_eval=0 system.guidance.pretrained_model_name_or_path="./load/zero123/XL_20230604.ckpt" tag='${data.random_camera.height}_${rmspace:${basename:${data.image_path}},_}_XL_SAMEgeom' system.weights="/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128]_robot_rgba.png_OLD@20230630-052314/ckpts/last.ckpt" + +# Phase 1.5 - 512 refine +python launch.py --config configs/zero123-geometry.yaml --train --gpu 4 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEg" system.freq.guidance_eval=0 data.image_path=./load/images/${NAME}_rgba.png system.geometry_convert_from=./outputs/${NAME}/Phase1_64/ckpts/last.ckpt use_timestamp=False name=${NAME} tag="Phase2_512geom" + +# Phase 2 - dreamfusion +python launch.py --config configs/experimental/imagecondition_zero123nerf.yaml --train --gpu 5 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEw" tag='${data.random_camera.height}_${rmspace:${basename:${data.image_path}},_}_XL_Phase2' system.freq.guidance_eval=0 data.image_path=./load/images/robot_rgba.png system.prompt_processor.prompt="A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" system.weights="/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128]_robot_rgba.png_OLD@20230630-052314/ckpts/last.ckpt" + +python launch.py --config configs/experimental/imagecondition_zero123nerf_refine.yaml --train --gpu 5 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEw" tag='${data.random_camera.height}_${rmspace:${basename:${data.image_path}},_}_XL_Phase2_refine' system.freq.guidance_eval=0 data.image_path=./load/images/robot_rgba.png system.prompt_processor.prompt="A 3D model of a friendly dragon" system.geometry_convert_from="/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128, 256]_dragon2_rgba.png_XL_REPEAT@20230705-023531/ckpts/last.ckpt" + +# A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" +# "/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128]_robot_rgba.png_OLD@20230630-052314/ckpts/last.ckpt" + +# Adds zero123_512-refine.yaml +# Adds resolution_milestones to image.py +# guidance_eval gets max batch_size 4 +# Introduces random_bg in solid_color_bg \ No newline at end of file diff --git a/threestudio/scripts/run_zero123_phase.sh b/threestudio/scripts/run_zero123_phase.sh new file mode 100755 index 0000000..fb816d1 --- /dev/null +++ b/threestudio/scripts/run_zero123_phase.sh @@ -0,0 +1,14 @@ + +GPU_ID=$1 # e.g. 0 +IMAGE_PREFIX=$2 # e.g. "anya_front" +ZERO123_PREFIX=$3 # e.g. "XL_20230604" +ELEVATION=$4 # e.g. 0 +REST=${@:5:99} # e.g. "system.guidance.min_step_percent=0.1 system.guidance.max_step_percent=0.9" + +# change this config if you don't use wandb or want to speed up training +python launch.py --config configs/zero123.yaml --train --gpu $GPU_ID system.loggers.wandb.enable=true system.loggers.wandb.project="claforte-noise_atten" \ + system.loggers.wandb.name="${IMAGE_PREFIX}_zero123_${ZERO123_PREFIX}...fov20_${REST}" \ + data.image_path=./load/images/${IMAGE_PREFIX}_rgba.png system.freq.guidance_eval=37 \ + system.guidance.pretrained_model_name_or_path="./load/zero123/${ZERO123_PREFIX}.ckpt" \ + system.guidance.cond_elevation_deg=$ELEVATION \ + ${REST} diff --git a/threestudio/scripts/run_zero123_phase2.sh b/threestudio/scripts/run_zero123_phase2.sh new file mode 100644 index 0000000..be03021 --- /dev/null +++ b/threestudio/scripts/run_zero123_phase2.sh @@ -0,0 +1,5 @@ +# Reconstruct Anya using latest Zero123XL, in <2000 steps. +python launch.py --config configs/zero123.yaml --train --gpu 0 system.loggers.wandb.enable=true system.loggers.wandb.project="voletiv-anya-new" system.loggers.wandb.name="claforte_params" data.image_path=./load/images/anya_front_rgba.png system.freq.ref_or_zero123="accumulate" system.freq.guidance_eval=13 system.guidance.pretrained_model_name_or_path="./load/zero123/XL_20230604.ckpt" + +# PHASE 2 +python launch.py --config configs/experimental/imagecondition_zero123nerf.yaml --train --gpu 0 system.prompt_processor.prompt="A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" system.weights=outputs/zero123/128_anya_front_rgba.png@20230623-145711/ckpts/last.ckpt system.freq.guidance_eval=13 system.loggers.wandb.enable=true system.loggers.wandb.project="voletiv-anya-new" data.image_path=./load/images/anya_front_rgba.png system.loggers.wandb.name="anya" data.random_camera.progressive_until=500 \ No newline at end of file diff --git a/threestudio/scripts/test_dreambooth.py b/threestudio/scripts/test_dreambooth.py new file mode 100644 index 0000000..57efeb3 --- /dev/null +++ b/threestudio/scripts/test_dreambooth.py @@ -0,0 +1,54 @@ +from diffusers import StableDiffusionPipeline, DDIMScheduler +import torch + +# model_id = "load/checkpoints/sd_21_base_mushroom_vd_prompt" +# model_id = "load/checkpoints/sd_base_mushroom" +model_id = ".cache/checkpoints/sd_21_base_rabbit" +# scheduler = DDIMScheduler() +pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") +guidance_scale = 7.5 + +prompt = "a sks rabbit, front view" +image = pipe(prompt, num_inference_steps=50, guidance_scale=guidance_scale).images[0] + +image.save("debug.png") + + +# import os +# import cv2 +# import glob +# import torch +# import argparse +# import numpy as np +# from tqdm import tqdm +# import pytorch_lightning as pl +# from torchvision.utils import save_image + +# import threestudio +# from threestudio.utils.config import load_config + + +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument("--config", required=True, help="path to config file") +# parser.add_argument("--view_dependent_noise", action="store_true", help="use view depdendent noise strength") + +# args, extras = parser.parse_known_args() + +# cfg = load_config(args.config, cli_args=extras, n_gpus=1) +# guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance) +# prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(cfg.system.prompt_processor) +# prompt_utils = prompt_processor() + +# guidance.update_step(epoch=0, global_step=0) +# elevation, azimuth = torch.zeros(1).cuda(), torch.zeros(1).cuda() +# camera_distances = torch.tensor([3.0]).cuda() +# c2w = torch.zeros(4,4).cuda() +# a = guidance.sample(prompt_utils, elevation, azimuth, camera_distances) # sample_lora +# from torchvision.utils import save_image +# save_image(a.permute(0,3,1,2), "debug.png", normalize=True, value_range=(0,1)) + + + +# python threestudio/scripts/test_dreambooth.py --config configs/experimental/stablediffusion.yaml system.prompt_processor.prompt="a sks mushroom growing on a log" \ +# system.guidance.pretrained_model_name_or_path_lora="load/checkpoints/sd_21_base_mushroom_camera_condition" \ No newline at end of file diff --git a/threestudio/scripts/test_dreambooth_lora.py b/threestudio/scripts/test_dreambooth_lora.py new file mode 100644 index 0000000..a3439f0 --- /dev/null +++ b/threestudio/scripts/test_dreambooth_lora.py @@ -0,0 +1,25 @@ +import torch +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + + +# model_base = "stabilityai/stable-diffusion-2-1-base" + +# pipe = DiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16, cache_dir=CACHE_DIR, local_files_only=True) +# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, cache_dir=CACHE_DIR, local_files_only=True) +# lora_model_path = "load/checkpoints/sd_21_base_bear_dreambooth_lora" +# pipe.unet.load_attn_procs(lora_model_path) + +# pipe.to("cuda") + + +# image = pipe("A picture of a sks bear in the sky", num_inference_steps=50, guidance_scale=7.5).images[0] +# image.save("bear_dreambooth_lora.png") + + +pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", local_files_only=True, safety_checker=None) +pipe.load_lora_weights("if_dreambooth_mushroom") +pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small") +pipe.to("cuda:7") + +image = pipe("A photo of a sks mushroom, front view", num_inference_steps=50, guidance_scale=7.5).images[0] +image.save("mushroom_dreambooth_lora.png") \ No newline at end of file diff --git a/threestudio/scripts/train_dreambooth.py b/threestudio/scripts/train_dreambooth.py new file mode 100644 index 0000000..c90af63 --- /dev/null +++ b/threestudio/scripts/train_dreambooth.py @@ -0,0 +1,1500 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import gc +import hashlib +import importlib +import itertools +import logging +import math +import os +import shutil +import warnings +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, model_info, upload_folder +from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.embeddings import TimestepEmbedding + +from threestudio.utils.typing import * + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.23.0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +instance_prompt: {prompt} +tags: +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} +- text-to-image +- diffusers +- dreambooth +inference: true +--- + """ + model_card = f""" +# DreamBooth - {repo_id} + +This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). +You can find some example images in the following. \n +{img_str} + +DreamBooth for the text encoder was enabled: {train_text_encoder}. +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation( + text_encoder, + tokenizer, + unet, + vae, + args, + accelerator, + weight_dtype, + global_step, + prompt_embeds, + negative_prompt_embeds, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + + pipeline_args = {} + + if vae is not None: + pipeline_args["vae"] = vae + + if text_encoder is not None: + text_encoder = accelerator.unwrap_model(text_encoder) + + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + **pipeline_args, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + module = importlib.import_module("diffusers") + scheduler_class = getattr(module, args.validation_scheduler) + pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + + # run inference + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + if args.validation_images is None: + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) + else: + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more details" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + + parser.add_argument( + "--offset_noise", + action="store_true", + default=False, + help=( + "Fine-tuning against a modified noise" + " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`, `camera_pose`.", + ) + parser.add_argument( + "--validation_scheduler", + type=str, + default="DPMSolverMultistepScheduler", + choices=["DPMSolverMultistepScheduler", "DDPMScheduler"], + help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.", + ) + + parser.add_argument( + "--use_view_dependent_prompt", + action="store_true", + help="Whether to use view-dependent prompt.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + class_num=None, + size=512, + center_crop=False, + encoder_hidden_states=None, + class_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, + use_view_dependent_prompt=False, + class_labels_conditioning=None, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length + self.use_view_dependent_prompt = use_view_dependent_prompt + self.class_labels_conditioning = class_labels_conditioning + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.instance_images_path = [img for img in self.instance_images_path if str(img).endswith("png")] + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + + if self.class_labels_conditioning=="camera_pose": + instance_camera_pose = np.load(str(self.instance_images_path[index % self.num_instance_images]).replace("png", "npy")) + example["instance_camera_pose"] = torch.tensor(instance_camera_pose).reshape(1, -1) + + if self.use_view_dependent_prompt: + angle = float(os.path.basename(self.instance_images_path[index % self.num_instance_images])[4:-4]) + if angle < 45 or angle >= 315: + view = "front view" + elif 45 <= angle < 135 or 225 <= angle < 315: + view = "side view" + else: + view = "back view" + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + # view-dependent prompt + if self.use_view_dependent_prompt: + instance_prompt = self.instance_prompt + f", {view}" + else: + instance_prompt = self.instance_prompt + text_inputs = tokenize_prompt( + self.tokenizer, instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + + if self.class_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask + + return example + + +def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + has_camera_pose = "instance_camera_pose" in examples[0] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + + if has_camera_pose: + camera_pose = [example["instance_camera_pose"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + + if has_attention_mask: + attention_mask = torch.cat(attention_mask, dim=0) + batch["attention_mask"] = attention_mask + + if has_camera_pose: + camera_pose = torch.cat(camera_pose, dim=0) + batch["camera_pose"] = camera_pose + + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def model_has_vae(args): + # config_file_name = os.path.join("vae", AutoencoderKL.config_name) + # if os.path.isdir(args.pretrained_model_name_or_path): + # config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) + # return os.path.isfile(config_file_name) + # else: + # files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings + # return any(file.rfilename == config_file_name for file in files_in_repo) + if args.pretrained_model_name_or_path.startswith("DeepFloyd"): + return False + else: + return True + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + +class ToWeightsDType(torch.nn.Module): + def __init__(self, module: torch.nn.Module, dtype: torch.dtype): + super().__init__() + self.module = module + self.dtype = dtype + + def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]: + return self.module(x).to(self.dtype) + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + local_files_only=True, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", cache_dir=CACHE_DIR,local_files_only=True) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device) + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant, cache_dir=CACHE_DIR, local_files_only=True, + ) + + if model_has_vae(args): + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, cache_dir=CACHE_DIR, + local_files_only=True, + ) + else: + vae = None + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant, cache_dir=CACHE_DIR, + ) + + # set camera condition embedding + if args.class_labels_conditioning=="camera_pose": + camera_embedding = ToWeightsDType( + TimestepEmbedding(16, 1280), torch.float32 + ).to(accelerator.device) + unet.class_embedding = camera_embedding + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(text_encoder))): + # load transformers style into model + load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") + model.config = load_model.config + else: + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if vae is not None: + vae.requires_grad_(False) + + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + "Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training. copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: + raise ValueError( + f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." + f" {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.class_prompt is not None: + pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) + else: + pre_computed_class_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_class_prompt_encoder_hidden_states = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + class_num=args.num_class_images, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + use_view_dependent_prompt=args.use_view_dependent_prompt, + class_labels_conditioning=args.class_labels_conditioning, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + + if not args.train_text_encoder and text_encoder is not None: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = vars(copy.deepcopy(args)) + tracker_config.pop("validation_images") + accelerator.init_trackers("dreambooth", config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the model input + if args.offset_noise: + noise = torch.randn_like(model_input) + 0.1 * torch.randn( + model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device + ) + else: + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + elif args.class_labels_conditioning == "camera_pose": + class_labels = batch["camera_pose"].to(dtype=weight_dtype) + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample + + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Compute instance loss + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + images = [] + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + images = log_validation( + text_encoder, + tokenizer, + unet, + vae, + args, + accelerator, + weight_dtype, + global_step, + validation_prompt_encoder_hidden_states, + validation_prompt_negative_prompt_embeds, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + pipeline_args = {} + + if text_encoder is not None: + pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + + if args.skip_save_text_encoder: + pipeline_args["text_encoder"] = None + + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + variant=args.variant, + cache_dir=CACHE_DIR, + local_files_only=True, + **pipeline_args, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + pipeline=pipeline, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/threestudio/scripts/train_dreambooth_lora.py b/threestudio/scripts/train_dreambooth_lora.py new file mode 100644 index 0000000..adfd488 --- /dev/null +++ b/threestudio/scripts/train_dreambooth_lora.py @@ -0,0 +1,1480 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import gc +import hashlib +import itertools +import logging +import math +import os +import shutil +import warnings +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import ( + LoraLoaderMixin, + text_encoder_lora_state_dict, +) +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, +) +from diffusers.models.lora import LoRALinearLayer +from diffusers.optimization import get_scheduler +from diffusers.training_utils import unet_lora_state_dict +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.23.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +instance_prompt: {prompt} +tags: +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA DreamBooth - {repo_id} + +These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n +{img_str} + +LoRA for the text encoder was enabled: {train_text_encoder}. +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + local_files_only=True, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + parser.add_argument( + "--use_view_dependent_prompt", + action="store_true", + help="Whether to use view-dependent prompt.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + class_num=None, + size=512, + center_crop=False, + encoder_hidden_states=None, + class_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, + use_view_dependent_prompt=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length + self.use_view_dependent_prompt = use_view_dependent_prompt + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + Image.init() + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.instance_images_path = [p for p in self.instance_images_path if self._file_ext(p) in Image.EXTENSION] + print("images:", self.instance_images_path) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + @staticmethod + def _file_ext(fname): + return os.path.splitext(fname)[1].lower() + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + + if self.use_view_dependent_prompt: + angle = float(os.path.basename(self.instance_images_path[index % self.num_instance_images])[4:-4]) + if angle < 45 or angle >= 315: + view = "front view" + elif 45 <= angle < 135 or 225 <= angle < 315: + view = "side view" + else: + view = "back view" + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + # view-depnedent prompt + if self.use_view_dependent_prompt: + instance_prompt = self.instance_prompt + f", {view}" + else: + instance_prompt = self.instance_prompt + text_inputs = tokenize_prompt( + self.tokenizer, instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + + if self.class_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask + + return example + + +def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + + if has_attention_mask: + batch["attention_mask"] = attention_mask + + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", cache_dir=CACHE_DIR, local_files_only=True) + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant, cache_dir=CACHE_DIR, local_files_only=True, + ) + try: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, cache_dir=CACHE_DIR, local_files_only=True, + ) + except OSError: + # IF does not have a VAE so let's just set it to None + # We don't have to error out here + vae = None + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant, cache_dir=CACHE_DIR, local_files_only=True, + ) + + # We only train the additional adapter LoRA layers + if vae is not None: + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. + # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18 + # => 32 layers + + # Set correct lora layers + unet_lora_parameters = [] + for attn_processor_name, attn_processor in unet.attn_processors.items(): + # Parse the attention module. + attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=args.rank, + ) + ) + + # Accumulate the LoRA params to optimize. + unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + attn_module.add_k_proj.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.add_k_proj.in_features, + out_features=attn_module.add_k_proj.out_features, + rank=args.rank, + ) + ) + attn_module.add_v_proj.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.add_v_proj.in_features, + out_features=attn_module.add_v_proj.out_features, + rank=args.rank, + ) + ) + unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if args.train_text_encoder: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_lora_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + unet_ = None + text_encoder_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet_lora_parameters, text_lora_parameters) + if args.train_text_encoder + else unet_lora_parameters + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.class_prompt is not None: + pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) + else: + pre_computed_class_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_class_prompt_encoder_hidden_states = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + class_num=args.num_class_images, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + use_view_dependent_prompt=args.use_view_dependent_prompt, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = vars(copy.deepcopy(args)) + tracker_config.pop("validation_images") + accelerator.init_trackers("dreambooth-lora", config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample + + # if model predicts variance, throw away the prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet_lora_parameters, text_lora_parameters) + if args.train_text_encoder + else unet_lora_parameters + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + safety_checker=None, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": validation_prompt_encoder_hidden_states, + "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + + if args.validation_images is None: + images = [] + for _ in range(args.num_validation_images): + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) + else: + images = [] + for image in args.validation_images: + image = Image.open(image) + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_layers = unet_lora_state_dict(unet) + + if text_encoder is not None and args.train_text_encoder: + text_encoder = accelerator.unwrap_model(text_encoder) + text_encoder = text_encoder.to(torch.float32) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) + else: + text_encoder_lora_layers = None + + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + # Final inference + # Load previous pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, + cache_dir=CACHE_DIR, local_files_only=True, safety_checker=None, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args, cache_dir=CACHE_DIR, local_files_only=True,) + + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + pipeline=pipeline, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/threestudio/scripts/train_text_to_image_lora.py b/threestudio/scripts/train_text_to_image_lora.py new file mode 100644 index 0000000..32b60e3 --- /dev/null +++ b/threestudio/scripts/train_text_to_image_lora.py @@ -0,0 +1,927 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA text2image fine-tuning - {repo_id} +These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + return args + + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + # freeze parameters of models to save more memory + unet.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # now we will add new LoRA weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. + # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 + # => 32 layers + + # Set correct lora layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=args.rank, + ) + + unet.set_attn_processor(lora_attn_procs) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + lora_layers = AttnProcsLayers(unet.attn_processors) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + lora_layers.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return {"pixel_values": pixel_values, "input_ids": input_ids} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + lora_layers, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2image-fine-tune", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = lora_layers.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unet.to(torch.float32) + unet.save_attn_procs(args.output_dir) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + dataset_name=args.dataset_name, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + # Final inference + # Load previous pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + ) + pipeline = pipeline.to(accelerator.device) + + # load attention processors + pipeline.unet.load_attn_procs(args.output_dir) + + # run inference + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + + if accelerator.is_main_process: + for tracker in accelerator.trackers: + if len(images) != 0: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/threestudio/systems/__init__.py b/threestudio/systems/__init__.py new file mode 100644 index 0000000..e3bcf23 --- /dev/null +++ b/threestudio/systems/__init__.py @@ -0,0 +1 @@ +from . import dreamcraft3d, zero123 diff --git a/threestudio/systems/base.py b/threestudio/systems/base.py new file mode 100644 index 0000000..37b7fba --- /dev/null +++ b/threestudio/systems/base.py @@ -0,0 +1,396 @@ +import os +from dataclasses import dataclass, field + +import pytorch_lightning as pl +import torch.nn.functional as F + +import threestudio +from threestudio.models.exporters.base import Exporter, ExporterOutput +from threestudio.systems.utils import parse_optimizer, parse_scheduler +from threestudio.utils.base import ( + Updateable, + update_end_if_possible, + update_if_possible, +) +from threestudio.utils.config import parse_structured +from threestudio.utils.misc import C, cleanup, get_device, load_module_weights, find_last_path +from threestudio.utils.saving import SaverMixin +from threestudio.utils.typing import * + + +class BaseSystem(pl.LightningModule, Updateable, SaverMixin): + @dataclass + class Config: + loggers: dict = field(default_factory=dict) + loss: dict = field(default_factory=dict) + optimizer: dict = field(default_factory=dict) + scheduler: Optional[dict] = None + weights: Optional[str] = None + weights_ignore_modules: Optional[List[str]] = None + cleanup_after_validation_step: bool = False + cleanup_after_test_step: bool = False + + cfg: Config + + def __init__(self, cfg, resumed=False) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self._save_dir: Optional[str] = None + self._resumed: bool = resumed + self._resumed_eval: bool = False + self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0} + if "loggers" in cfg: + self.create_loggers(cfg.loggers) + + self.configure() + if self.cfg.weights is not None: + self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules) + self.post_configure() + + def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None): + state_dict, epoch, global_step = load_module_weights( + weights, ignore_modules=ignore_modules, map_location="cpu" + ) + self.load_state_dict(state_dict, strict=False) + # restore step-dependent states + self.do_update_step(epoch, global_step, on_load_weights=True) + + def set_resume_status(self, current_epoch: int, global_step: int): + # restore correct epoch and global step in eval + self._resumed_eval = True + self._resumed_eval_status["current_epoch"] = current_epoch + self._resumed_eval_status["global_step"] = global_step + + @property + def resumed(self): + # whether from resumed checkpoint + return self._resumed + + @property + def true_global_step(self): + if self._resumed_eval: + return self._resumed_eval_status["global_step"] + else: + return self.global_step + + @property + def true_current_epoch(self): + if self._resumed_eval: + return self._resumed_eval_status["current_epoch"] + else: + return self.current_epoch + + def configure(self) -> None: + pass + + def post_configure(self) -> None: + """ + executed after weights are loaded + """ + pass + + def C(self, value: Any) -> float: + return C(value, self.true_current_epoch, self.true_global_step) + + def configure_optimizers(self): + optim = parse_optimizer(self.cfg.optimizer, self) + ret = { + "optimizer": optim, + } + if self.cfg.scheduler is not None: + ret.update( + { + "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim), + } + ) + return ret + + def training_step(self, batch, batch_idx): + raise NotImplementedError + + def validation_step(self, batch, batch_idx): + raise NotImplementedError + + def on_train_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.train_dataloader.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + + def on_validation_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.val_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_validation_step: + # cleanup to save vram + cleanup() + + def on_validation_epoch_end(self): + raise NotImplementedError + + def test_step(self, batch, batch_idx): + raise NotImplementedError + + def on_test_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.test_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_test_step: + # cleanup to save vram + cleanup() + + def on_test_epoch_end(self): + pass + + def predict_step(self, batch, batch_idx): + raise NotImplementedError + + def on_predict_batch_end(self, outputs, batch, batch_idx): + self.dataset = self.trainer.predict_dataloaders.dataset + update_end_if_possible( + self.dataset, self.true_current_epoch, self.true_global_step + ) + self.do_update_step_end(self.true_current_epoch, self.true_global_step) + if self.cfg.cleanup_after_test_step: + # cleanup to save vram + cleanup() + + def on_predict_epoch_end(self): + pass + + def preprocess_data(self, batch, stage): + pass + + """ + Implementing on_after_batch_transfer of DataModule does the same. + But on_after_batch_transfer does not support DP. + """ + + def on_train_batch_start(self, batch, batch_idx, unused=0): + self.preprocess_data(batch, "train") + self.dataset = self.trainer.train_dataloader.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "validation") + self.dataset = self.trainer.val_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "test") + self.dataset = self.trainer.test_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): + self.preprocess_data(batch, "predict") + self.dataset = self.trainer.predict_dataloaders.dataset + update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) + self.do_update_step(self.true_current_epoch, self.true_global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + pass + + def on_before_optimizer_step(self, optimizer): + """ + # some gradient-related debugging goes here, example: + from lightning.pytorch.utilities import grad_norm + norms = grad_norm(self.geometry, norm_type=2) + print(norms) + """ + pass + + +class BaseLift3DSystem(BaseSystem): + @dataclass + class Config(BaseSystem.Config): + geometry_type: str = "" + geometry: dict = field(default_factory=dict) + geometry_convert_from: Optional[str] = None + geometry_convert_inherit_texture: bool = False + # used to override configurations of the previous geometry being converted from, + # for example isosurface_threshold + geometry_convert_override: dict = field(default_factory=dict) + + material_type: str = "" + material: dict = field(default_factory=dict) + + background_type: str = "" + background: dict = field(default_factory=dict) + + renderer_type: str = "" + renderer: dict = field(default_factory=dict) + + guidance_type: str = "" + guidance: dict = field(default_factory=dict) + + prompt_processor_type: str = "" + prompt_processor: dict = field(default_factory=dict) + + # geometry export configurations, no need to specify in training + exporter_type: str = "mesh-exporter" + exporter: dict = field(default_factory=dict) + + cfg: Config + + def configure(self) -> None: + self.cfg.geometry_convert_from = find_last_path(self.cfg.geometry_convert_from) + self.cfg.weights = find_last_path(self.cfg.weights) + if ( + self.cfg.geometry_convert_from # from_coarse must be specified + and not self.cfg.weights # not initialized from coarse when weights are specified + and not self.resumed # not initialized from coarse when resumed from checkpoints + ): + threestudio.info("Initializing geometry from a given checkpoint ...") + from threestudio.utils.config import load_config, parse_structured + + prev_cfg = load_config( + os.path.join( + os.path.dirname(self.cfg.geometry_convert_from), + "../configs/parsed.yaml", + ) + ) # TODO: hard-coded relative path + prev_system_cfg: BaseLift3DSystem.Config = parse_structured( + self.Config, prev_cfg.system + ) + prev_geometry_cfg = prev_system_cfg.geometry + prev_geometry_cfg.update(self.cfg.geometry_convert_override) + prev_geometry = threestudio.find(prev_system_cfg.geometry_type)( + prev_geometry_cfg + ) + state_dict, epoch, global_step = load_module_weights( + self.cfg.geometry_convert_from, + module_name="geometry", + map_location="cpu", + ) + prev_geometry.load_state_dict(state_dict, strict=False) + # restore step-dependent states + prev_geometry.do_update_step(epoch, global_step, on_load_weights=True) + # convert from coarse stage geometry + prev_geometry = prev_geometry.to(get_device()) + self.geometry = threestudio.find(self.cfg.geometry_type).create_from( + prev_geometry, + self.cfg.geometry, + copy_net=self.cfg.geometry_convert_inherit_texture, + ) + del prev_geometry + cleanup() + else: + self.geometry = threestudio.find(self.cfg.geometry_type)(self.cfg.geometry) + + self.material = threestudio.find(self.cfg.material_type)(self.cfg.material) + self.background = threestudio.find(self.cfg.background_type)( + self.cfg.background + ) + self.renderer = threestudio.find(self.cfg.renderer_type)( + self.cfg.renderer, + geometry=self.geometry, + material=self.material, + background=self.background, + ) + + def on_fit_start(self) -> None: + if self._save_dir is not None: + threestudio.info(f"Validation results will be saved to {self._save_dir}") + else: + threestudio.warn( + f"Saving directory not set for the system, visualization results will not be saved" + ) + + def on_test_end(self) -> None: + if self._save_dir is not None: + threestudio.info(f"Test results saved to {self._save_dir}") + + def on_predict_start(self) -> None: + self.exporter: Exporter = threestudio.find(self.cfg.exporter_type)( + self.cfg.exporter, + geometry=self.geometry, + material=self.material, + background=self.background, + ) + + def predict_step(self, batch, batch_idx): + if self.exporter.cfg.save_video: + self.test_step(batch, batch_idx) + + def on_predict_epoch_end(self) -> None: + if self.exporter.cfg.save_video: + self.on_test_epoch_end() + exporter_output: List[ExporterOutput] = self.exporter() + for out in exporter_output: + save_func_name = f"save_{out.save_type}" + if not hasattr(self, save_func_name): + raise ValueError(f"{save_func_name} not supported by the SaverMixin") + save_func = getattr(self, save_func_name) + save_func(f"it{self.true_global_step}-export/{out.save_name}", **out.params) + + def on_predict_end(self) -> None: + if self._save_dir is not None: + threestudio.info(f"Export assets saved to {self._save_dir}") + + def guidance_evaluation_save(self, comp_rgb, guidance_eval_out): + B, size = comp_rgb.shape[:2] + resize = lambda x: F.interpolate( + x.permute(0, 3, 1, 2), (size, size), mode="bilinear", align_corners=False + ).permute(0, 2, 3, 1) + filename = f"it{self.true_global_step}-train.png" + + def merge12(x): + return x.reshape(-1, *x.shape[2:]) + + self.save_image_grid( + filename, + [ + { + "type": "rgb", + "img": merge12(comp_rgb), + "kwargs": {"data_format": "HWC"}, + }, + ] + + ( + [ + { + "type": "rgb", + "img": merge12(resize(guidance_eval_out["imgs_noisy"])), + "kwargs": {"data_format": "HWC"}, + } + ] + ) + + ( + [ + { + "type": "rgb", + "img": merge12(resize(guidance_eval_out["imgs_1step"])), + "kwargs": {"data_format": "HWC"}, + } + ] + ) + + ( + [ + { + "type": "rgb", + "img": merge12(resize(guidance_eval_out["imgs_1orig"])), + "kwargs": {"data_format": "HWC"}, + } + ] + ) + + ( + [ + { + "type": "rgb", + "img": merge12(resize(guidance_eval_out["imgs_final"])), + "kwargs": {"data_format": "HWC"}, + } + ] + ), + name="train_step", + step=self.true_global_step, + texts=guidance_eval_out["texts"], + ) \ No newline at end of file diff --git a/threestudio/systems/dreamcraft3d.py b/threestudio/systems/dreamcraft3d.py new file mode 100644 index 0000000..5328d39 --- /dev/null +++ b/threestudio/systems/dreamcraft3d.py @@ -0,0 +1,608 @@ +import os +import random +import shutil +from dataclasses import dataclass, field +import cv2 +import clip +import torch +import shutil +import numpy as np +import torch.nn.functional as F +from torchmetrics import PearsonCorrCoef + +import threestudio +from threestudio.systems.base import BaseLift3DSystem +from threestudio.utils.ops import binary_cross_entropy, dot +from threestudio.utils.typing import * +from threestudio.utils.misc import get_rank, get_device, load_module_weights +from threestudio.utils.perceptual import PerceptualLoss + + +@threestudio.register("dreamcraft3d-system") +class ImageConditionDreamFusion(BaseLift3DSystem): + @dataclass + class Config(BaseLift3DSystem.Config): + # in ['coarse', 'geometry', 'texture']. + # Note that in the paper we consolidate 'coarse' and 'geometry' into a single phase called 'geometry-sculpting'. + stage: str = "coarse" + freq: dict = field(default_factory=dict) + guidance_3d_type: str = "" + guidance_3d: dict = field(default_factory=dict) + use_mixed_camera_config: bool = False + control_guidance_type: str = "" + control_guidance: dict = field(default_factory=dict) + control_prompt_processor_type: str = "" + control_prompt_processor: dict = field(default_factory=dict) + visualize_samples: bool = False + + cfg: Config + + def configure(self): + # create geometry, material, background, renderer + super().configure() + self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance) + if self.cfg.guidance_3d_type != "": + self.guidance_3d = threestudio.find(self.cfg.guidance_3d_type)( + self.cfg.guidance_3d + ) + else: + self.guidance_3d = None + self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)( + self.cfg.prompt_processor + ) + self.prompt_utils = self.prompt_processor() + + p_config = {} + self.perceptual_loss = threestudio.find("perceptual-loss")(p_config) + + if not (self.cfg.control_guidance_type == ""): + self.control_guidance = threestudio.find(self.cfg.control_guidance_type)(self.cfg.control_guidance) + self.control_prompt_processor = threestudio.find(self.cfg.control_prompt_processor_type)( + self.cfg.control_prompt_processor + ) + self.control_prompt_utils = self.control_prompt_processor() + + def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: + if self.cfg.stage == "texture": + render_out = self.renderer(**batch, render_mask=True) + else: + render_out = self.renderer(**batch) + return { + **render_out, + } + + def on_fit_start(self) -> None: + super().on_fit_start() + + # visualize all training images + all_images = self.trainer.datamodule.train_dataloader().dataset.get_all_images() + self.save_image_grid( + "all_training_images.png", + [ + {"type": "rgb", "img": image, "kwargs": {"data_format": "HWC"}} + for image in all_images + ], + name="on_fit_start", + step=self.true_global_step, + ) + + self.pearson = PearsonCorrCoef().to(self.device) + + def training_substep(self, batch, batch_idx, guidance: str, render_type="rgb"): + """ + Args: + guidance: one of "ref" (reference image supervision), "guidance" + """ + + gt_mask = batch["mask"] + gt_rgb = batch["rgb"] + gt_depth = batch["ref_depth"] + gt_normal = batch["ref_normal"] + mvp_mtx_ref = batch["mvp_mtx"] + c2w_ref = batch["c2w4x4"] + + if guidance == "guidance": + batch = batch["random_camera"] + + # Support rendering visibility mask + batch["mvp_mtx_ref"] = mvp_mtx_ref + batch["c2w_ref"] = c2w_ref + + out = self(batch) + loss_prefix = f"loss_{guidance}_" + + loss_terms = {} + + def set_loss(name, value): + loss_terms[f"{loss_prefix}{name}"] = value + + guidance_eval = ( + guidance == "guidance" + and self.cfg.freq.guidance_eval > 0 + and self.true_global_step % self.cfg.freq.guidance_eval == 0 + ) + + prompt_utils = self.prompt_processor() + + if guidance == "ref": + if render_type == "rgb": + # color loss. Use l2 loss in coarse and geometry satge; use l1 loss in texture stage. + if self.C(self.cfg.loss.lambda_rgb) > 0: + gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * ( + 1 - gt_mask.float() + ) + pred_rgb = out["comp_rgb"] + if self.cfg.stage in ["coarse", "geometry"]: + set_loss("rgb", F.mse_loss(gt_rgb, pred_rgb)) + else: + if self.cfg.stage == "texture": + grow_mask = F.max_pool2d(1 - gt_mask.float().permute(0, 3, 1, 2), (9, 9), 1, 4) + grow_mask = (1 - grow_mask).permute(0, 2, 3, 1) + set_loss("rgb", F.l1_loss(gt_rgb*grow_mask, pred_rgb*grow_mask)) + else: + set_loss("rgb", F.l1_loss(gt_rgb, pred_rgb)) + + # mask loss + if self.C(self.cfg.loss.lambda_mask) > 0: + set_loss("mask", F.mse_loss(gt_mask.float(), out["opacity"])) + + # mask binary cross loss + if self.C(self.cfg.loss.lambda_mask_binary) > 0: + set_loss("mask_binary", F.binary_cross_entropy( + out["opacity"].clamp(1.0e-5, 1.0 - 1.0e-5), + batch["mask"].float(),)) + + # depth loss + if self.C(self.cfg.loss.lambda_depth) > 0: + valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)].unsqueeze(1) + valid_pred_depth = out["depth"][gt_mask].unsqueeze(1) + with torch.no_grad(): + A = torch.cat( + [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1 + ) # [B, 2] + X = torch.linalg.lstsq(A, valid_pred_depth).solution # [2, 1] + valid_gt_depth = A @ X # [B, 1] + set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth)) + + # relative depth loss + if self.C(self.cfg.loss.lambda_depth_rel) > 0: + valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)] # [B,] + valid_pred_depth = out["depth"][gt_mask] # [B,] + set_loss( + "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth) + ) + + # normal loss + if self.C(self.cfg.loss.lambda_normal) > 0: + valid_gt_normal = ( + 1 - 2 * gt_normal[gt_mask.squeeze(-1)] + ) # [B, 3] + # FIXME: reverse x axis + pred_normal = out["comp_normal_viewspace"] + pred_normal[..., 0] = 1 - pred_normal[..., 0] + valid_pred_normal = ( + 2 * pred_normal[gt_mask.squeeze(-1)] - 1 + ) # [B, 3] + set_loss( + "normal", + 1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(), + ) + + elif guidance == "guidance" and self.true_global_step > self.cfg.freq.no_diff_steps: + if self.cfg.stage == "geometry" and render_type == "normal": + guidance_inp = out["comp_normal"] + else: + guidance_inp = out["comp_rgb"] + guidance_out = self.guidance( + guidance_inp, + prompt_utils, + **batch, + rgb_as_latents=False, + guidance_eval=guidance_eval, + mask=out["mask"] if "mask" in out else None, + ) + for name, value in guidance_out.items(): + self.log(f"train/{name}", value) + if name.startswith("loss_"): + set_loss(name.split("_")[-1], value) + + if self.guidance_3d is not None: + + # FIXME: use mixed camera config + if not self.cfg.use_mixed_camera_config or get_rank() % 2 == 0: + guidance_3d_out = self.guidance_3d( + out["comp_rgb"], + **batch, + rgb_as_latents=False, + guidance_eval=guidance_eval, + ) + for name, value in guidance_3d_out.items(): + if not (isinstance(value, torch.Tensor) and len(value.shape) > 0): + self.log(f"train/{name}_3d", value) + if name.startswith("loss_"): + set_loss("3d_"+name.split("_")[-1], value) + # set_loss("3d_sd", guidance_out["loss_sd"]) + + # Regularization + if self.C(self.cfg.loss.lambda_normal_smooth) > 0: + if "comp_normal" not in out: + raise ValueError( + "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output." + ) + normal = out["comp_normal"] + set_loss( + "normal_smooth", + (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean() + + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(), + ) + + if self.C(self.cfg.loss.lambda_3d_normal_smooth) > 0: + if "normal" not in out: + raise ValueError( + "Normal is required for normal smooth loss, no normal is found in the output." + ) + if "normal_perturb" not in out: + raise ValueError( + "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." + ) + normals = out["normal"] + normals_perturb = out["normal_perturb"] + set_loss("3d_normal_smooth", (normals - normals_perturb).abs().mean()) + + if self.cfg.stage == "coarse": + if self.C(self.cfg.loss.lambda_orient) > 0: + if "normal" not in out: + raise ValueError( + "Normal is required for orientation loss, no normal is found in the output." + ) + set_loss( + "orient", + ( + out["weights"].detach() + * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 + ).sum() + / (out["opacity"] > 0).sum(), + ) + + if guidance != "ref" and self.C(self.cfg.loss.lambda_sparsity) > 0: + set_loss("sparsity", (out["opacity"] ** 2 + 0.01).sqrt().mean()) + + if self.C(self.cfg.loss.lambda_opaque) > 0: + opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) + set_loss( + "opaque", binary_cross_entropy(opacity_clamped, opacity_clamped) + ) + + if "lambda_eikonal" in self.cfg.loss and self.C(self.cfg.loss.lambda_eikonal) > 0: + if "sdf_grad" not in out: + raise ValueError( + "SDF grad is required for eikonal loss, no normal is found in the output." + ) + set_loss( + "eikonal", ( + (torch.linalg.norm(out["sdf_grad"], ord=2, dim=-1) - 1.0) ** 2 + ).mean() + ) + + if "lambda_z_variance"in self.cfg.loss and self.C(self.cfg.loss.lambda_z_variance) > 0: + # z variance loss proposed in HiFA: http://arxiv.org/abs/2305.18766 + # helps reduce floaters and produce solid geometry + loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean() + set_loss("z_variance", loss_z_variance) + + elif self.cfg.stage == "geometry": + if self.C(self.cfg.loss.lambda_normal_consistency) > 0: + set_loss("normal_consistency", out["mesh"].normal_consistency()) + if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0: + set_loss("laplacian_smoothness", out["mesh"].laplacian()) + elif self.cfg.stage == "texture": + if self.C(self.cfg.loss.lambda_reg) > 0 and guidance == "guidance" and self.true_global_step % 5 == 0: + + rgb = out["comp_rgb"] + rgb = F.interpolate(rgb.permute(0, 3, 1, 2), (512, 512), mode='bilinear').permute(0, 2, 3, 1) + control_prompt_utils = self.control_prompt_processor() + with torch.no_grad(): + control_dict = self.control_guidance( + rgb=rgb, + cond_rgb=rgb, + prompt_utils=control_prompt_utils, + mask=out["mask"] if "mask" in out else None, + ) + + edit_images = control_dict["edit_images"] + temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8) + cv2.imwrite(".threestudio_cache/control_debug.jpg", temp[:, :, ::-1]) + + loss_reg = (rgb.shape[1] // 8) * (rgb.shape[2] // 8) * self.perceptual_loss(edit_images.permute(0, 3, 1, 2), rgb.permute(0, 3, 1, 2)).mean() + set_loss("reg", loss_reg) + else: + raise ValueError(f"Unknown stage {self.cfg.stage}") + + loss = 0.0 + for name, value in loss_terms.items(): + self.log(f"train/{name}", value) + if name.startswith(loss_prefix): + loss_weighted = value * self.C( + self.cfg.loss[name.replace(loss_prefix, "lambda_")] + ) + self.log(f"train/{name}_w", loss_weighted) + loss += loss_weighted + + for name, value in self.cfg.loss.items(): + self.log(f"train_params/{name}", self.C(value)) + + self.log(f"train/loss_{guidance}", loss) + + if guidance_eval: + self.guidance_evaluation_save( + out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]], + guidance_out["eval"], + ) + + return {"loss": loss} + + def training_step(self, batch, batch_idx): + if self.cfg.freq.ref_or_guidance == "accumulate": + do_ref = True + do_guidance = True + elif self.cfg.freq.ref_or_guidance == "alternate": + do_ref = ( + self.true_global_step < self.cfg.freq.ref_only_steps + or self.true_global_step % self.cfg.freq.n_ref == 0 + ) + do_guidance = not do_ref + if hasattr(self.guidance.cfg, "only_pretrain_step"): + if (self.guidance.cfg.only_pretrain_step > 0) and (self.global_step % self.guidance.cfg.only_pretrain_step) < (self.guidance.cfg.only_pretrain_step // 5): + do_guidance = True + do_ref = False + + if self.cfg.stage == "geometry": + render_type = "rgb" if self.true_global_step % self.cfg.freq.n_rgb == 0 else "normal" + else: + render_type = "rgb" + + total_loss = 0.0 + + if do_guidance: + out = self.training_substep(batch, batch_idx, guidance="guidance", render_type=render_type) + total_loss += out["loss"] + + if do_ref: + out = self.training_substep(batch, batch_idx, guidance="ref", render_type=render_type) + total_loss += out["loss"] + + self.log("train/loss", total_loss, prog_bar=True) + + # sch = self.lr_schedulers() + # sch.step() + + return {"loss": total_loss} + + def validation_step(self, batch, batch_idx): + out = self(batch) + self.save_image_grid( + f"it{self.true_global_step}-val/{batch['index'][0]}.png", + ( + [ + { + "type": "rgb", + "img": batch["rgb"][0], + "kwargs": {"data_format": "HWC"}, + } + ] + if "rgb" in batch + else [] + ) + + ( + [ + { + "type": "rgb", + "img": out["comp_rgb"][0], + "kwargs": {"data_format": "HWC"}, + }, + ] + if "comp_rgb" in out + else [] + ) + + ( + [ + { + "type": "rgb", + "img": out["comp_normal"][0], + "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, + } + ] + if "comp_normal" in out + else [] + ) + + ( + [ + { + "type": "rgb", + "img": out["comp_normal_viewspace"][0], + "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, + } + ] + if "comp_normal_viewspace" in out + else [] + ) + + ( + [ + { + "type": "grayscale", + "img": out["depth"][0], + "kwargs": {} + } + ] + if "depth" in out + else [] + ) + + [ + { + "type": "grayscale", + "img": out["opacity"][0, :, :, 0], + "kwargs": {"cmap": None, "data_range": (0, 1)}, + }, + ], + + name="validation_step", + step=self.true_global_step, + ) + + if self.cfg.stage=="texture" and self.cfg.visualize_samples: + self.save_image_grid( + f"it{self.true_global_step}-{batch['index'][0]}-sample.png", + [ + { + "type": "rgb", + "img": self.guidance.sample( + self.prompt_utils, **batch, seed=self.global_step + )[0], + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": self.guidance.sample_lora(self.prompt_utils, **batch)[0], + "kwargs": {"data_format": "HWC"}, + }, + ], + name="validation_step_samples", + step=self.true_global_step, + ) + + def on_validation_epoch_end(self): + filestem = f"it{self.true_global_step}-val" + + try: + self.save_img_sequence( + filestem, + filestem, + "(\d+)\.png", + save_format="mp4", + fps=30, + name="validation_epoch_end", + step=self.true_global_step, + ) + shutil.rmtree( + os.path.join(self.get_save_dir(), f"it{self.true_global_step}-val") + ) + except: + pass + + def test_step(self, batch, batch_idx): + out = self(batch) + self.save_image_grid( + f"it{self.true_global_step}-test/{batch['index'][0]}.png", + ( + [ + { + "type": "rgb", + "img": batch["rgb"][0], + "kwargs": {"data_format": "HWC"}, + } + ] + if "rgb" in batch + else [] + ) + + ( + [ + { + "type": "rgb", + "img": out["comp_rgb"][0], + "kwargs": {"data_format": "HWC"}, + }, + ] + if "comp_rgb" in out + else [] + ) + + ( + [ + { + "type": "rgb", + "img": out["comp_normal"][0], + "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, + } + ] + if "comp_normal" in out + else [] + ) + + ( + [ + { + "type": "rgb", + "img": out["comp_normal_viewspace"][0], + "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, + } + ] + if "comp_normal_viewspace" in out + else [] + ) + + ( + [ + { + "type": "grayscale", "img": out["depth"][0], "kwargs": {} + } + ] + if "depth" in out + else [] + ) + + [ + { + "type": "grayscale", + "img": out["opacity"][0, :, :, 0], + "kwargs": {"cmap": None, "data_range": (0, 1)}, + }, + ] + + ( + [ + { + "type": "grayscale", "img": out["opacity_vis"][0, :, :, 0], + "kwargs": {"cmap": None, "data_range": (0, 1)} + } + ] + if "opacity_vis" in out + else [] + ) + , + name="test_step", + step=self.true_global_step, + ) + + # FIXME: save camera extrinsics + c2w = batch["c2w"] + save_path = os.path.join(self.get_save_dir(), f"it{self.true_global_step}-test/{batch['index'][0]}.npy") + np.save(save_path, c2w.detach().cpu().numpy()[0]) + + def on_test_epoch_end(self): + self.save_img_sequence( + f"it{self.true_global_step}-test", + f"it{self.true_global_step}-test", + "(\d+)\.png", + save_format="mp4", + fps=30, + name="test", + step=self.true_global_step, + ) + + def on_before_optimizer_step(self, optimizer) -> None: + # print("on_before_opt enter") + # for n, p in self.geometry.named_parameters(): + # if p.grad is None: + # print(n) + # print("on_before_opt exit") + + pass + + def on_load_checkpoint(self, checkpoint): + for k in list(checkpoint['state_dict'].keys()): + if k.startswith("guidance."): + return + guidance_state_dict = {"guidance."+k : v for (k,v) in self.guidance.state_dict().items()} + checkpoint['state_dict'] = {**checkpoint['state_dict'], **guidance_state_dict} + return + + def on_save_checkpoint(self, checkpoint): + for k in list(checkpoint['state_dict'].keys()): + if k.startswith("guidance."): + checkpoint['state_dict'].pop(k) + return \ No newline at end of file diff --git a/threestudio/systems/utils.py b/threestudio/systems/utils.py new file mode 100644 index 0000000..177c105 --- /dev/null +++ b/threestudio/systems/utils.py @@ -0,0 +1,104 @@ +import sys +import warnings +from bisect import bisect_right + +import torch +import torch.nn as nn +from torch.optim import lr_scheduler + +import threestudio + + +def get_scheduler(name): + if hasattr(lr_scheduler, name): + return getattr(lr_scheduler, name) + else: + raise NotImplementedError + + +def getattr_recursive(m, attr): + for name in attr.split("."): + m = getattr(m, name) + return m + + +def get_parameters(model, name): + module = getattr_recursive(model, name) + if isinstance(module, nn.Module): + return module.parameters() + elif isinstance(module, nn.Parameter): + return module + return [] + + +def parse_optimizer(config, model): + if hasattr(config, "params"): + params = [ + {"params": get_parameters(model, name), "name": name, **args} + for name, args in config.params.items() + ] + threestudio.debug(f"Specify optimizer params: {config.params}") + else: + params = model.parameters() + if config.name in ["FusedAdam"]: + import apex + + optim = getattr(apex.optimizers, config.name)(params, **config.args) + elif config.name in ["Adan"]: + from threestudio.systems import optimizers + + optim = getattr(optimizers, config.name)(params, **config.args) + else: + optim = getattr(torch.optim, config.name)(params, **config.args) + return optim + + +def parse_scheduler_to_instance(config, optimizer): + if config.name == "ChainedScheduler": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.ChainedScheduler(schedulers) + elif config.name == "Sequential": + schedulers = [ + parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers + ] + scheduler = lr_scheduler.SequentialLR( + optimizer, schedulers, milestones=config.milestones + ) + else: + scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) + return scheduler + + +def parse_scheduler(config, optimizer): + interval = config.get("interval", "epoch") + assert interval in ["epoch", "step"] + if config.name == "SequentialLR": + scheduler = { + "scheduler": lr_scheduler.SequentialLR( + optimizer, + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ], + milestones=config.milestones, + ), + "interval": interval, + } + elif config.name == "ChainedScheduler": + scheduler = { + "scheduler": lr_scheduler.ChainedScheduler( + [ + parse_scheduler(conf, optimizer)["scheduler"] + for conf in config.schedulers + ] + ), + "interval": interval, + } + else: + scheduler = { + "scheduler": get_scheduler(config.name)(optimizer, **config.args), + "interval": interval, + } + return scheduler \ No newline at end of file diff --git a/threestudio/systems/zero123.py b/threestudio/systems/zero123.py new file mode 100644 index 0000000..751c423 --- /dev/null +++ b/threestudio/systems/zero123.py @@ -0,0 +1,390 @@ +import os +import random +import shutil +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F +from PIL import Image, ImageDraw +from torchmetrics import PearsonCorrCoef + +import threestudio +from threestudio.systems.base import BaseLift3DSystem +from threestudio.utils.ops import binary_cross_entropy, dot +from threestudio.utils.typing import * + + +@threestudio.register("zero123-system") +class Zero123(BaseLift3DSystem): + @dataclass + class Config(BaseLift3DSystem.Config): + freq: dict = field(default_factory=dict) + refinement: bool = False + ambient_ratio_min: float = 0.5 + + cfg: Config + + def configure(self): + # create geometry, material, background, renderer + super().configure() + + def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: + render_out = self.renderer(**batch) + return { + **render_out, + } + + def on_fit_start(self) -> None: + super().on_fit_start() + # no prompt processor + self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance) + + # visualize all training images + all_images = self.trainer.datamodule.train_dataloader().dataset.get_all_images() + self.save_image_grid( + "all_training_images.png", + [ + {"type": "rgb", "img": image, "kwargs": {"data_format": "HWC"}} + for image in all_images + ], + name="on_fit_start", + step=self.true_global_step, + ) + + self.pearson = PearsonCorrCoef().to(self.device) + + def training_substep(self, batch, batch_idx, guidance: str): + """ + Args: + guidance: one of "ref" (reference image supervision), "zero123" + """ + if guidance == "ref": + # bg_color = torch.rand_like(batch['rays_o']) + ambient_ratio = 1.0 + shading = "diffuse" + batch["shading"] = shading + elif guidance == "zero123": + batch = batch["random_camera"] + ambient_ratio = ( + self.cfg.ambient_ratio_min + + (1 - self.cfg.ambient_ratio_min) * random.random() + ) + + batch["bg_color"] = None + batch["ambient_ratio"] = ambient_ratio + + out = self(batch) + loss_prefix = f"loss_{guidance}_" + + loss_terms = {} + + def set_loss(name, value): + loss_terms[f"{loss_prefix}{name}"] = value + + guidance_eval = ( + guidance == "zero123" + and self.cfg.freq.guidance_eval > 0 + and self.true_global_step % self.cfg.freq.guidance_eval == 0 + ) + + if guidance == "ref": + gt_mask = batch["mask"] + gt_rgb = batch["rgb"] + + # color loss + gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * ( + 1 - gt_mask.float() + ) + set_loss("rgb", F.mse_loss(gt_rgb, out["comp_rgb"])) + + # mask loss + set_loss("mask", F.mse_loss(gt_mask.float(), out["opacity"])) + + # depth loss + if self.C(self.cfg.loss.lambda_depth) > 0: + valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)].unsqueeze(1) + valid_pred_depth = out["depth"][gt_mask].unsqueeze(1) + with torch.no_grad(): + A = torch.cat( + [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1 + ) # [B, 2] + X = torch.linalg.lstsq(A, valid_pred_depth).solution # [2, 1] + valid_gt_depth = A @ X # [B, 1] + set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth)) + + # relative depth loss + if self.C(self.cfg.loss.lambda_depth_rel) > 0: + valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)] # [B,] + valid_pred_depth = out["depth"][gt_mask] # [B,] + set_loss( + "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth) + ) + + # normal loss + if self.C(self.cfg.loss.lambda_normal) > 0: + valid_gt_normal = ( + 1 - 2 * batch["ref_normal"][gt_mask.squeeze(-1)] + ) # [B, 3] + valid_pred_normal = ( + 2 * out["comp_normal"][gt_mask.squeeze(-1)] - 1 + ) # [B, 3] + set_loss( + "normal", + 1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(), + ) + elif guidance == "zero123": + # zero123 + guidance_out = self.guidance( + out["comp_rgb"], + **batch, + rgb_as_latents=False, + guidance_eval=guidance_eval, + ) + # claforte: TODO: rename the loss_terms keys + set_loss("sds", guidance_out["loss_sds"]) + + if self.C(self.cfg.loss.lambda_normal_smooth) > 0: + if "comp_normal" not in out: + raise ValueError( + "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output." + ) + normal = out["comp_normal"] + set_loss( + "normal_smooth", + (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean() + + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(), + ) + + if self.C(self.cfg.loss.lambda_3d_normal_smooth) > 0: + if "normal" not in out: + raise ValueError( + "Normal is required for normal smooth loss, no normal is found in the output." + ) + if "normal_perturb" not in out: + raise ValueError( + "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." + ) + normals = out["normal"] + normals_perturb = out["normal_perturb"] + set_loss("3d_normal_smooth", (normals - normals_perturb).abs().mean()) + + if not self.cfg.refinement: + if self.C(self.cfg.loss.lambda_orient) > 0: + if "normal" not in out: + raise ValueError( + "Normal is required for orientation loss, no normal is found in the output." + ) + set_loss( + "orient", + ( + out["weights"].detach() + * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2 + ).sum() + / (out["opacity"] > 0).sum(), + ) + + if guidance != "ref" and self.C(self.cfg.loss.lambda_sparsity) > 0: + set_loss("sparsity", (out["opacity"] ** 2 + 0.01).sqrt().mean()) + + if self.C(self.cfg.loss.lambda_opaque) > 0: + opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3) + set_loss( + "opaque", binary_cross_entropy(opacity_clamped, opacity_clamped) + ) + else: + if self.C(self.cfg.loss.lambda_normal_consistency) > 0: + set_loss("normal_consistency", out["mesh"].normal_consistency()) + if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0: + set_loss("laplacian_smoothness", out["mesh"].laplacian()) + + loss = 0.0 + for name, value in loss_terms.items(): + self.log(f"train/{name}", value) + if name.startswith(loss_prefix): + loss_weighted = value * self.C( + self.cfg.loss[name.replace(loss_prefix, "lambda_")] + ) + self.log(f"train/{name}_w", loss_weighted) + loss += loss_weighted + + for name, value in self.cfg.loss.items(): + self.log(f"train_params/{name}", self.C(value)) + + self.log(f"train/loss_{guidance}", loss) + + if guidance_eval: + self.guidance_evaluation_save( + out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]], + guidance_out["eval"], + ) + + return {"loss": loss} + + def training_step(self, batch, batch_idx): + if self.cfg.freq.get("ref_or_zero123", "accumulate") == "accumulate": + do_ref = True + do_zero123 = True + elif self.cfg.freq.get("ref_or_zero123", "accumulate") == "alternate": + do_ref = ( + self.true_global_step < self.cfg.freq.ref_only_steps + or self.true_global_step % self.cfg.freq.n_ref == 0 + ) + do_zero123 = not do_ref + + total_loss = 0.0 + if do_zero123: + out = self.training_substep(batch, batch_idx, guidance="zero123") + total_loss += out["loss"] + + if do_ref: + out = self.training_substep(batch, batch_idx, guidance="ref") + total_loss += out["loss"] + + self.log("train/loss", total_loss, prog_bar=True) + + # sch = self.lr_schedulers() + # sch.step() + + return {"loss": total_loss} + + def validation_step(self, batch, batch_idx): + out = self(batch) + self.save_image_grid( + f"it{self.true_global_step}-val/{batch['index'][0]}.png", + ( + [ + { + "type": "rgb", + "img": batch["rgb"][0], + "kwargs": {"data_format": "HWC"}, + } + ] + if "rgb" in batch + else [] + ) + + [ + { + "type": "rgb", + "img": out["comp_rgb"][0], + "kwargs": {"data_format": "HWC"}, + }, + ] + + ( + [ + { + "type": "rgb", + "img": out["comp_normal"][0], + "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, + } + ] + if "comp_normal" in out + else [] + ) + + ( + [ + { + "type": "grayscale", + "img": out["depth"][0], + "kwargs": {}, + } + ] + if "depth" in out + else [] + ) + + [ + { + "type": "grayscale", + "img": out["opacity"][0, :, :, 0], + "kwargs": {"cmap": None, "data_range": (0, 1)}, + }, + ], + # claforte: TODO: don't hardcode the frame numbers to record... read them from cfg instead. + name=f"validation_step_batchidx_{batch_idx}" + if batch_idx in [0, 7, 15, 23, 29] + else None, + step=self.true_global_step, + ) + + def on_validation_epoch_end(self): + filestem = f"it{self.true_global_step}-val" + self.save_img_sequence( + filestem, + filestem, + "(\d+)\.png", + save_format="mp4", + fps=30, + name="validation_epoch_end", + step=self.true_global_step, + ) + shutil.rmtree( + os.path.join(self.get_save_dir(), f"it{self.true_global_step}-val") + ) + + def test_step(self, batch, batch_idx): + out = self(batch) + self.save_image_grid( + f"it{self.true_global_step}-test/{batch['index'][0]}.png", + ( + [ + { + "type": "rgb", + "img": batch["rgb"][0], + "kwargs": {"data_format": "HWC"}, + } + ] + if "rgb" in batch + else [] + ) + + [ + { + "type": "rgb", + "img": out["comp_rgb"][0], + "kwargs": {"data_format": "HWC"}, + }, + ] + + ( + [ + { + "type": "rgb", + "img": out["comp_normal"][0], + "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, + } + ] + if "comp_normal" in out + else [] + ) + + ( + [ + { + "type": "grayscale", + "img": out["depth"][0], + "kwargs": {}, + } + ] + if "depth" in out + else [] + ) + + [ + { + "type": "grayscale", + "img": out["opacity"][0, :, :, 0], + "kwargs": {"cmap": None, "data_range": (0, 1)}, + }, + ], + name="test_step", + step=self.true_global_step, + ) + + def on_test_epoch_end(self): + self.save_img_sequence( + f"it{self.true_global_step}-test", + f"it{self.true_global_step}-test", + "(\d+)\.png", + save_format="mp4", + fps=30, + name="test", + step=self.true_global_step, + ) + shutil.rmtree( + os.path.join(self.get_save_dir(), f"it{self.true_global_step}-test") + ) \ No newline at end of file diff --git a/threestudio/utils/GAN/attention.py b/threestudio/utils/GAN/attention.py new file mode 100644 index 0000000..b2b8b66 --- /dev/null +++ b/threestudio/utils/GAN/attention.py @@ -0,0 +1,278 @@ +import math +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn + +from threestudio.utils.GAN.network_util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None + ): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/threestudio/utils/GAN/discriminator.py b/threestudio/utils/GAN/discriminator.py new file mode 100644 index 0000000..3b60174 --- /dev/null +++ b/threestudio/utils/GAN/discriminator.py @@ -0,0 +1,217 @@ +import functools + +import torch +import torch.nn as nn + + +def count_params(model): + total_params = sum(p.numel() for p in model.parameters()) + return total_params + + +class ActNorm(nn.Module): + def __init__( + self, num_features, logdet=False, affine=True, allow_reverse_init=False + ): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class Labelator(AbstractEncoder): + """Net2Net Interface for Class-Conditional Model""" + + def __init__(self, n_classes, quantize_interface=True): + super().__init__() + self.n_classes = n_classes + self.quantize_interface = quantize_interface + + def encode(self, c): + c = c[:, None] + if self.quantize_interface: + return c, None, [None, None, c.long()] + return c + + +class SOSProvider(AbstractEncoder): + # for unconditional training + def __init__(self, sos_token, quantize_interface=True): + super().__init__() + self.sos_token = sos_token + self.quantize_interface = quantize_interface + + def encode(self, x): + # get batch size from data and replicate sos_token + c = torch.ones(x.shape[0], 1) * self.sos_token + c = c.long().to(x.device) + if self.quantize_interface: + return c, None, [None, None, c] + return c + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) \ No newline at end of file diff --git a/threestudio/utils/GAN/distribution.py b/threestudio/utils/GAN/distribution.py new file mode 100644 index 0000000..36b8a3e --- /dev/null +++ b/threestudio/utils/GAN/distribution.py @@ -0,0 +1,102 @@ +import numpy as np +import torch + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) \ No newline at end of file diff --git a/threestudio/utils/GAN/loss.py b/threestudio/utils/GAN/loss.py new file mode 100644 index 0000000..ecd3b71 --- /dev/null +++ b/threestudio/utils/GAN/loss.py @@ -0,0 +1,35 @@ +import torch +import torch.nn.functional as F + + +def generator_loss(discriminator, inputs, reconstructions, cond=None): + if cond is None: + logits_fake = discriminator(reconstructions.contiguous()) + else: + logits_fake = discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1) + ) + g_loss = -torch.mean(logits_fake) + return g_loss + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def discriminator_loss(discriminator, inputs, reconstructions, cond=None): + if cond is None: + logits_real = discriminator(inputs.contiguous().detach()) + logits_fake = discriminator(reconstructions.contiguous().detach()) + else: + logits_real = discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), dim=1) + ) + d_loss = hinge_d_loss(logits_real, logits_fake).mean() + return d_loss \ No newline at end of file diff --git a/threestudio/utils/GAN/mobilenet.py b/threestudio/utils/GAN/mobilenet.py new file mode 100644 index 0000000..d0ab7d2 --- /dev/null +++ b/threestudio/utils/GAN/mobilenet.py @@ -0,0 +1,254 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["MobileNetV3", "mobilenetv3"] + + +def conv_bn( + inp, + oup, + stride, + conv_layer=nn.Conv2d, + norm_layer=nn.BatchNorm2d, + nlin_layer=nn.ReLU, +): + return nn.Sequential( + conv_layer(inp, oup, 3, stride, 1, bias=False), + norm_layer(oup), + nlin_layer(inplace=True), + ) + + +def conv_1x1_bn( + inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU +): + return nn.Sequential( + conv_layer(inp, oup, 1, 1, 0, bias=False), + norm_layer(oup), + nlin_layer(inplace=True), + ) + + +class Hswish(nn.Module): + def __init__(self, inplace=True): + super(Hswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 + + +class Hsigmoid(nn.Module): + def __init__(self, inplace=True): + super(Hsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return F.relu6(x + 3.0, inplace=self.inplace) / 6.0 + + +class SEModule(nn.Module): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + Hsigmoid() + # nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class Identity(nn.Module): + def __init__(self, channel): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +def make_divisible(x, divisible_by=8): + import numpy as np + + return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) + + +class MobileBottleneck(nn.Module): + def __init__(self, inp, oup, kernel, stride, exp, se=False, nl="RE"): + super(MobileBottleneck, self).__init__() + assert stride in [1, 2] + assert kernel in [3, 5] + padding = (kernel - 1) // 2 + self.use_res_connect = stride == 1 and inp == oup + + conv_layer = nn.Conv2d + norm_layer = nn.BatchNorm2d + if nl == "RE": + nlin_layer = nn.ReLU # or ReLU6 + elif nl == "HS": + nlin_layer = Hswish + else: + raise NotImplementedError + if se: + SELayer = SEModule + else: + SELayer = Identity + + self.conv = nn.Sequential( + # pw + conv_layer(inp, exp, 1, 1, 0, bias=False), + norm_layer(exp), + nlin_layer(inplace=True), + # dw + conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), + norm_layer(exp), + SELayer(exp), + nlin_layer(inplace=True), + # pw-linear + conv_layer(exp, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV3(nn.Module): + def __init__( + self, n_class=1000, input_size=224, dropout=0.0, mode="small", width_mult=1.0 + ): + super(MobileNetV3, self).__init__() + input_channel = 16 + last_channel = 1280 + if mode == "large": + # refer to Table 1 in paper + mobile_setting = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, "RE", 1], + [3, 64, 24, False, "RE", 2], + [3, 72, 24, False, "RE", 1], + [5, 72, 40, True, "RE", 2], + [5, 120, 40, True, "RE", 1], + [5, 120, 40, True, "RE", 1], + [3, 240, 80, False, "HS", 2], + [3, 200, 80, False, "HS", 1], + [3, 184, 80, False, "HS", 1], + [3, 184, 80, False, "HS", 1], + [3, 480, 112, True, "HS", 1], + [3, 672, 112, True, "HS", 1], + [5, 672, 160, True, "HS", 2], + [5, 960, 160, True, "HS", 1], + [5, 960, 160, True, "HS", 1], + ] + elif mode == "small": + # refer to Table 2 in paper + mobile_setting = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, "RE", 2], + [3, 72, 24, False, "RE", 2], + [3, 88, 24, False, "RE", 1], + [5, 96, 40, True, "HS", 2], + [5, 240, 40, True, "HS", 1], + [5, 240, 40, True, "HS", 1], + [5, 120, 48, True, "HS", 1], + [5, 144, 48, True, "HS", 1], + [5, 288, 96, True, "HS", 2], + [5, 576, 96, True, "HS", 1], + [5, 576, 96, True, "HS", 1], + ] + else: + raise NotImplementedError + + # building first layer + assert input_size % 32 == 0 + last_channel = ( + make_divisible(last_channel * width_mult) + if width_mult > 1.0 + else last_channel + ) + self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] + self.classifier = [] + + # building mobile blocks + for k, exp, c, se, nl, s in mobile_setting: + output_channel = make_divisible(c * width_mult) + exp_channel = make_divisible(exp * width_mult) + self.features.append( + MobileBottleneck( + input_channel, output_channel, k, s, exp_channel, se, nl + ) + ) + input_channel = output_channel + + # building last several layers + if mode == "large": + last_conv = make_divisible(960 * width_mult) + self.features.append( + conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish) + ) + self.features.append(nn.AdaptiveAvgPool2d(1)) + self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) + self.features.append(Hswish(inplace=True)) + elif mode == "small": + last_conv = make_divisible(576 * width_mult) + self.features.append( + conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish) + ) + # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake + self.features.append(nn.AdaptiveAvgPool2d(1)) + self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) + self.features.append(Hswish(inplace=True)) + else: + raise NotImplementedError + + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(p=dropout), # refer to paper section 6 + nn.Linear(last_channel, n_class), + ) + + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = x.mean(3).mean(2) + x = self.classifier(x) + return x + + def _initialize_weights(self): + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def mobilenetv3(pretrained=False, **kwargs): + model = MobileNetV3(**kwargs) + if pretrained: + state_dict = torch.load("mobilenetv3_small_67.4.pth.tar") + model.load_state_dict(state_dict, strict=True) + # raise NotImplementedError + return model \ No newline at end of file diff --git a/threestudio/utils/GAN/network_util.py b/threestudio/utils/GAN/network_util.py new file mode 100644 index 0000000..3dd0374 --- /dev/null +++ b/threestudio/utils/GAN/network_util.py @@ -0,0 +1,296 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import math +import os + +import numpy as np +import torch +import torch.nn as nn +from einops import repeat + +from threestudio.utils.GAN.util import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/threestudio/utils/GAN/normalunet.py b/threestudio/utils/GAN/normalunet.py new file mode 100644 index 0000000..daa1894 --- /dev/null +++ b/threestudio/utils/GAN/normalunet.py @@ -0,0 +1,401 @@ +""" +Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. +BSD License. All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL +DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING +OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +""" +import functools + +import numpy as np +import torch +import torch.nn as nn +from torch.autograd import Variable + + +############################################################################### +# Functions +############################################################################### +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find("BatchNorm2d") != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def get_norm_layer(norm_type="instance"): + if norm_type == "batch": + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == "instance": + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + else: + raise NotImplementedError("normalization layer [%s] is not found" % norm_type) + return norm_layer + + +def define_G( + input_nc, + output_nc, + ngf, + netG, + n_downsample_global=3, + n_blocks_global=9, + n_local_enhancers=1, + n_blocks_local=3, + norm="instance", + gpu_ids=[], + last_op=nn.Tanh(), +): + norm_layer = get_norm_layer(norm_type=norm) + if netG == "global": + netG = GlobalGenerator( + input_nc, + output_nc, + ngf, + n_downsample_global, + n_blocks_global, + norm_layer, + last_op=last_op, + ) + elif netG == "local": + netG = LocalEnhancer( + input_nc, + output_nc, + ngf, + n_downsample_global, + n_blocks_global, + n_local_enhancers, + n_blocks_local, + norm_layer, + ) + elif netG == "encoder": + netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer) + else: + raise ("generator not implemented!") + # print(netG) + if len(gpu_ids) > 0: + assert torch.cuda.is_available() + netG.cuda(gpu_ids[0]) + netG.apply(weights_init) + return netG + + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print("Total number of parameters: %d" % num_params) + + +############################################################################## +# Generator +############################################################################## +class LocalEnhancer(nn.Module): + def __init__( + self, + input_nc, + output_nc, + ngf=32, + n_downsample_global=3, + n_blocks_global=9, + n_local_enhancers=1, + n_blocks_local=3, + norm_layer=nn.BatchNorm2d, + padding_type="reflect", + ): + super(LocalEnhancer, self).__init__() + self.n_local_enhancers = n_local_enhancers + + ###### global generator model ##### + ngf_global = ngf * (2**n_local_enhancers) + model_global = GlobalGenerator( + input_nc, + output_nc, + ngf_global, + n_downsample_global, + n_blocks_global, + norm_layer, + ).model + model_global = [ + model_global[i] for i in range(len(model_global) - 3) + ] # get rid of final convolution layers + self.model = nn.Sequential(*model_global) + + ###### local enhancer layers ##### + for n in range(1, n_local_enhancers + 1): + ### downsample + ngf_global = ngf * (2 ** (n_local_enhancers - n)) + model_downsample = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), + norm_layer(ngf_global), + nn.ReLU(True), + nn.Conv2d( + ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1 + ), + norm_layer(ngf_global * 2), + nn.ReLU(True), + ] + ### residual blocks + model_upsample = [] + for i in range(n_blocks_local): + model_upsample += [ + ResnetBlock( + ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer + ) + ] + + ### upsample + model_upsample += [ + nn.ConvTranspose2d( + ngf_global * 2, + ngf_global, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + norm_layer(ngf_global), + nn.ReLU(True), + ] + + ### final convolution + if n == n_local_enhancers: + model_upsample += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), + nn.Tanh(), + ] + + setattr(self, "model" + str(n) + "_1", nn.Sequential(*model_downsample)) + setattr(self, "model" + str(n) + "_2", nn.Sequential(*model_upsample)) + + self.downsample = nn.AvgPool2d( + 3, stride=2, padding=[1, 1], count_include_pad=False + ) + + def forward(self, input): + ### create input pyramid + input_downsampled = [input] + for i in range(self.n_local_enhancers): + input_downsampled.append(self.downsample(input_downsampled[-1])) + + ### output at coarest level + output_prev = self.model(input_downsampled[-1]) + ### build up one layer at a time + for n_local_enhancers in range(1, self.n_local_enhancers + 1): + model_downsample = getattr(self, "model" + str(n_local_enhancers) + "_1") + model_upsample = getattr(self, "model" + str(n_local_enhancers) + "_2") + input_i = input_downsampled[self.n_local_enhancers - n_local_enhancers] + output_prev = model_upsample(model_downsample(input_i) + output_prev) + return output_prev + + +class NormalNet(nn.Module): + def __init__( + self, + name="normalnet", + input_nc=3, + output_nc=3, + ngf=64, + n_downsampling=4, + n_blocks=9, + norm_layer=nn.BatchNorm2d, + padding_type="reflect", + last_op=nn.Sigmoid(), + ): + assert n_blocks >= 0 + super(NormalNet, self).__init__() + self.name = name + activation = nn.ReLU(True) + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + nn.BatchNorm2d(ngf), + activation, + ] + ### downsample + for i in range(n_downsampling): + mult = 2**i + model += [ + nn.Conv2d( + ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1 + ), + nn.BatchNorm2d(ngf * mult * 2), + activation, + ] + + ### resnet blocks + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ + ResnetBlock( + ngf * mult, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + ) + ] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [ + nn.Upsample(scale_factor=2), + nn.Conv2d( + ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1 + ), + nn.BatchNorm2d(int(ngf * mult / 2)), + activation, + ] + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), + ] + if last_op is not None: + model += [last_op] + self.model = nn.Sequential(*model) + + def forward(self, in_x, label=None): + res_list = [] + return self.model(in_x) + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__( + self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False + ): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block( + dim, padding_type, norm_layer, activation, use_dropout + ) + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): + conv_block = [] + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == "zero": + p = 1 + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p), + nn.BatchNorm2d(dim), + activation, + ] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == "zero": + p = 1 + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p), + nn.BatchNorm2d(dim), + ] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class Encoder(nn.Module): + def __init__( + self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d + ): + super(Encoder, self).__init__() + self.output_nc = output_nc + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + nn.ReLU(True), + ] + ### downsample + for i in range(n_downsampling): + mult = 2**i + model += [ + nn.Conv2d( + ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1 + ), + norm_layer(ngf * mult * 2), + nn.ReLU(True), + ] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [ + nn.ConvTranspose2d( + ngf * mult, + int(ngf * mult / 2), + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True), + ] + + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), + nn.Tanh(), + ] + self.model = nn.Sequential(*model) + + def forward(self, input, inst): + outputs = self.model(input) + + # instance-wise average pooling + outputs_mean = outputs.clone() + inst_list = np.unique(inst.cpu().numpy().astype(int)) + for i in inst_list: + for b in range(input.size()[0]): + indices = (inst[b : b + 1] == int(i)).nonzero() # n x 4 + for j in range(self.output_nc): + output_ins = outputs[ + indices[:, 0] + b, + indices[:, 1] + j, + indices[:, 2], + indices[:, 3], + ] + mean_feat = torch.mean(output_ins).expand_as(output_ins) + outputs_mean[ + indices[:, 0] + b, + indices[:, 1] + j, + indices[:, 2], + indices[:, 3], + ] = mean_feat + return outputs_mean \ No newline at end of file diff --git a/threestudio/utils/GAN/util.py b/threestudio/utils/GAN/util.py new file mode 100644 index 0000000..7f4604a --- /dev/null +++ b/threestudio/utils/GAN/util.py @@ -0,0 +1,208 @@ +import importlib +import multiprocessing as mp +from collections import abc +from functools import partial +from inspect import isfunction +from queue import Queue +from threading import Thread + +import numpy as np +import torch +from einops import rearrange +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, + data, + n_proc, + target_data_type="ndarray", + cpu_intensive=True, + use_worker_id=False, +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i : i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == "ndarray": + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == "list": + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res \ No newline at end of file diff --git a/threestudio/utils/GAN/vae.py b/threestudio/utils/GAN/vae.py new file mode 100644 index 0000000..3920d56 --- /dev/null +++ b/threestudio/utils/GAN/vae.py @@ -0,0 +1,1028 @@ +# pytorch_diffusion + derived encoder decoder +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from threestudio.utils.GAN.attention import LinearAttention +from threestudio.utils.GAN.util import instantiate_from_config + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.BatchNorm2d(num_features=in_channels) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.attn_resolutions = attn_resolutions + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + if len(attn_resolutions) > 0: + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + if len(self.attn_resolutions) > 0: + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + # self.temb_ch = 3 + self.temb_ch = 64 + # self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.attn_resolutions = attn_resolutions + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + self.conv_in3 = torch.nn.Conv2d( + z_channels + 3, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + up.rgb_conv = torch.nn.Conv2d( + block_in + 3, 3, kernel_size=3, stride=1, padding=1 + ) + up.rgb_cat_conv = torch.nn.Conv2d( + block_in + 3, block_in, kernel_size=3, stride=1, padding=1 + ) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z, temb=None): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + # temb = None + + # z to block_in + rgb = z[:, :3] + if z.shape[1] == self.z_shape[1] + 3: + h = self.conv_in3(z) + else: + h = self.conv_in(z) + + # middle + # h = self.mid.block_1(h, temb) + # h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + + rgb = torch.nn.functional.interpolate(rgb, scale_factor=4.0, mode="bilinear") + rgb = torch.sigmoid(torch.logit(rgb, eps=1e-3) + h) + return rgb + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.upsample_blocks = nn.ModuleList() + self.rgb_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z \ No newline at end of file diff --git a/threestudio/utils/__init__.py b/threestudio/utils/__init__.py new file mode 100644 index 0000000..0e44449 --- /dev/null +++ b/threestudio/utils/__init__.py @@ -0,0 +1 @@ +from . import base diff --git a/threestudio/utils/base.py b/threestudio/utils/base.py new file mode 100644 index 0000000..97f1f66 --- /dev/null +++ b/threestudio/utils/base.py @@ -0,0 +1,118 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from threestudio.utils.config import parse_structured +from threestudio.utils.misc import get_device, load_module_weights +from threestudio.utils.typing import * + + +class Configurable: + @dataclass + class Config: + pass + + def __init__(self, cfg: Optional[dict] = None) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + + +class Updateable: + def do_update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step( + epoch, global_step, on_load_weights=on_load_weights + ) + self.update_step(epoch, global_step, on_load_weights=on_load_weights) + + def do_update_step_end(self, epoch: int, global_step: int): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + self.update_step_end(epoch, global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # override this method to implement custom update logic + # if on_load_weights is True, you should be careful doing things related to model evaluations, + # as the models and tensors are not guarenteed to be on the same device + pass + + def update_step_end(self, epoch: int, global_step: int): + pass + + +def update_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step(epoch, global_step) + + +def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + + +class BaseObject(Updateable): + @dataclass + class Config: + pass + + cfg: Config # add this to every subclass of BaseObject to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self.configure(*args, **kwargs) + + def configure(self, *args, **kwargs) -> None: + pass + + +class BaseModule(nn.Module, Updateable): + @dataclass + class Config: + weights: Optional[str] = None + + cfg: Config # add this to every subclass of BaseModule to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self.configure(*args, **kwargs) + if self.cfg.weights is not None: + # format: path/to/weights:module_name + weights_path, module_name = self.cfg.weights.split(":") + state_dict, epoch, global_step = load_module_weights( + weights_path, module_name=module_name, map_location="cpu" + ) + self.load_state_dict(state_dict) + self.do_update_step( + epoch, global_step, on_load_weights=True + ) # restore states + # dummy tensor to indicate model state + self._dummy: Float[Tensor, "..."] + self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) + + def configure(self, *args, **kwargs) -> None: + pass diff --git a/threestudio/utils/callbacks.py b/threestudio/utils/callbacks.py new file mode 100644 index 0000000..f3ae798 --- /dev/null +++ b/threestudio/utils/callbacks.py @@ -0,0 +1,156 @@ +import os +import shutil +import subprocess + +import pytorch_lightning + +from threestudio.utils.config import dump_config +from threestudio.utils.misc import parse_version + +if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): + from pytorch_lightning.callbacks import Callback +else: + from pytorch_lightning.callbacks.base import Callback + +from pytorch_lightning.callbacks.progress import TQDMProgressBar +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn + + +class VersionedCallback(Callback): + def __init__(self, save_root, version=None, use_version=True): + self.save_root = save_root + self._version = version + self.use_version = use_version + + @property + def version(self) -> int: + """Get the experiment version. + + Returns: + The experiment version if specified else the next version. + """ + if self._version is None: + self._version = self._get_next_version() + return self._version + + def _get_next_version(self): + existing_versions = [] + if os.path.isdir(self.save_root): + for f in os.listdir(self.save_root): + bn = os.path.basename(f) + if bn.startswith("version_"): + dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") + existing_versions.append(int(dir_ver)) + if len(existing_versions) == 0: + return 0 + return max(existing_versions) + 1 + + @property + def savedir(self): + if not self.use_version: + return self.save_root + return os.path.join( + self.save_root, + self.version + if isinstance(self.version, str) + else f"version_{self.version}", + ) + + +class CodeSnapshotCallback(VersionedCallback): + def __init__(self, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + + def get_file_list(self): + return [ + b.decode() + for b in set( + subprocess.check_output( + 'git ls-files -- ":!:load/*"', shell=True + ).splitlines() + ) + | set( # hard code, TODO: use config to exclude folders or files + subprocess.check_output( + "git ls-files --others --exclude-standard", shell=True + ).splitlines() + ) + ] + + @rank_zero_only + def save_code_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + for f in self.get_file_list(): + if not os.path.exists(f) or os.path.isdir(f): + continue + os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) + shutil.copyfile(f, os.path.join(self.savedir, f)) + + def on_fit_start(self, trainer, pl_module): + try: + self.save_code_snapshot() + except: + rank_zero_warn( + "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." + ) + + +class ConfigSnapshotCallback(VersionedCallback): + def __init__(self, config_path, config, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + self.config_path = config_path + self.config = config + + @rank_zero_only + def save_config_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) + shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) + + def on_fit_start(self, trainer, pl_module): + self.save_config_snapshot() + + +class CustomProgressBar(TQDMProgressBar): + def get_metrics(self, *args, **kwargs): + # don't show the version number + items = super().get_metrics(*args, **kwargs) + items.pop("v_num", None) + return items + + +class ProgressCallback(Callback): + def __init__(self, save_path): + super().__init__() + self.save_path = save_path + self._file_handle = None + + @property + def file_handle(self): + if self._file_handle is None: + self._file_handle = open(self.save_path, "w") + return self._file_handle + + @rank_zero_only + def write(self, msg: str) -> None: + self.file_handle.seek(0) + self.file_handle.truncate() + self.file_handle.write(msg) + self.file_handle.flush() + + @rank_zero_only + def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): + self.write( + f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%" + ) + + @rank_zero_only + def on_validation_start(self, trainer, pl_module): + self.write(f"Rendering validation image ...") + + @rank_zero_only + def on_test_start(self, trainer, pl_module): + self.write(f"Rendering video ...") + + @rank_zero_only + def on_predict_start(self, trainer, pl_module): + self.write(f"Exporting mesh assets ...") \ No newline at end of file diff --git a/threestudio/utils/config.py b/threestudio/utils/config.py new file mode 100644 index 0000000..cd86333 --- /dev/null +++ b/threestudio/utils/config.py @@ -0,0 +1,131 @@ +import os +from dataclasses import dataclass, field +from datetime import datetime + +from omegaconf import OmegaConf + +import threestudio +from threestudio.utils.typing import * + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver( + "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) +) +OmegaConf.register_new_resolver("add", lambda a, b: a + b) +OmegaConf.register_new_resolver("sub", lambda a, b: a - b) +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +OmegaConf.register_new_resolver("div", lambda a, b: a / b) +OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) +OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) +OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) +OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) +OmegaConf.register_new_resolver("gt0", lambda s: s > 0) +OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) +OmegaConf.register_new_resolver("not", lambda s: not s) +OmegaConf.register_new_resolver( + "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 +) +# ======================================================= # + + +def C_max(value: Any) -> float: + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError("Scalar specification only supports list, got", type(value)) + if len(value) >= 6: + max_value = value[2] + for i in range(4, len(value), 2): + max_value = max(max_value, value[i]) + value = [value[0], value[1], max_value, value[3]] + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + value = max(start_value, end_value) + return value + + +@dataclass +class ExperimentConfig: + name: str = "default" + description: str = "" + tag: str = "" + seed: int = 0 + use_timestamp: bool = True + timestamp: Optional[str] = None + exp_root_dir: str = "outputs" + + # import custom extension + custom_import: Tuple[str] = () + + ### these shouldn't be set manually + exp_dir: str = "outputs/default" + trial_name: str = "exp" + trial_dir: str = "outputs/default/exp" + n_gpus: int = 1 + ### + + resume: Optional[str] = None + + data_type: str = "" + data: dict = field(default_factory=dict) + + system_type: str = "" + system: dict = field(default_factory=dict) + + # accept pytorch-lightning trainer parameters + # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api + trainer: dict = field(default_factory=dict) + + # accept pytorch-lightning checkpoint callback parameters + # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint + checkpoint: dict = field(default_factory=dict) + + def __post_init__(self): + if not self.tag and not self.use_timestamp: + raise ValueError("Either tag is specified or use_timestamp is True.") + self.trial_name = self.tag + # if resume from an existing config, self.timestamp should not be None + if self.timestamp is None: + self.timestamp = "" + if self.use_timestamp: + if self.n_gpus > 1: + threestudio.warn( + "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." + ) + else: + self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") + self.trial_name += self.timestamp + self.exp_dir = os.path.join(self.exp_root_dir, self.name) + self.trial_dir = os.path.join(self.exp_dir, self.trial_name) + os.makedirs(self.trial_dir, exist_ok=True) + + +def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: + if from_string: + yaml_confs = [OmegaConf.create(s) for s in yamls] + else: + yaml_confs = [OmegaConf.load(f) for f in yamls] + cli_conf = OmegaConf.from_cli(cli_args) + cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) + OmegaConf.resolve(cfg) + assert isinstance(cfg, DictConfig) + scfg = parse_structured(ExperimentConfig, cfg) + return scfg + + +def config_to_primitive(config, resolve: bool = True) -> Any: + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path: str, config) -> None: + with open(path, "w") as fp: + OmegaConf.save(config=config, f=fp) + + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + scfg = OmegaConf.structured(fields(**cfg)) + return scfg \ No newline at end of file diff --git a/threestudio/utils/dpt.py b/threestudio/utils/dpt.py new file mode 100644 index 0000000..8cc0479 --- /dev/null +++ b/threestudio/utils/dpt.py @@ -0,0 +1,924 @@ +import math +import types + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import timm + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) + + +def unflatten_with_named_tensor(input, dim, sizes): + """Workaround for unflattening with named tensor.""" + # tracer acts up with unflatten. See https://github.com/pytorch/pytorch/issues/49538 + new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim+1:] + return input.view(*new_shape) + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + + unflattened_dim = 2 + unflattened_size = ( + int(torch.div(h, pretrained.model.patch_size[1], rounding_mode='floor')), + int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')), + ) + unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size)) + + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten_with_named_tensor(layer_3, unflattened_dim, unflattened_size) + if layer_4.ndim == 3: + layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(posemb_grid.shape[0])) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, torch.div(h, self.patch_size[1], rounding_mode='floor'), torch.div(w, self.patch_size[0], rounding_mode='floor') + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + True, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) \ No newline at end of file diff --git a/threestudio/utils/lpips/__init__.py b/threestudio/utils/lpips/__init__.py new file mode 100644 index 0000000..4059bdc --- /dev/null +++ b/threestudio/utils/lpips/__init__.py @@ -0,0 +1 @@ +from .lpips import LPIPS \ No newline at end of file diff --git a/threestudio/utils/lpips/lpips.py b/threestudio/utils/lpips/lpips.py new file mode 100644 index 0000000..4a98f4e --- /dev/null +++ b/threestudio/utils/lpips/lpips.py @@ -0,0 +1,123 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + +from threestudio.utils.lpips.utils import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "threestudio/utils/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) + diff --git a/threestudio/utils/lpips/utils.py b/threestudio/utils/lpips/utils.py new file mode 100644 index 0000000..06053e5 --- /dev/null +++ b/threestudio/utils/lpips/utils.py @@ -0,0 +1,157 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = {"keya": "a", + "keyb": "b", + "keyc": + {"cc1": 1, + "cc2": 2, + } + } + from omegaconf import OmegaConf + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") + diff --git a/threestudio/utils/misc.py b/threestudio/utils/misc.py new file mode 100644 index 0000000..412e5b7 --- /dev/null +++ b/threestudio/utils/misc.py @@ -0,0 +1,156 @@ +import gc +import os +import re + +import tinycudann as tcnn +import torch +from packaging import version + +from threestudio.utils.config import config_to_primitive +from threestudio.utils.typing import * + + +def parse_version(ver: str): + return version.parse(ver) + + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("LOCAL_RANK", "RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def get_device(): + return torch.device(f"cuda:{get_rank()}") + + +def load_module_weights( + path, module_name=None, ignore_modules=None, map_location=None +) -> Tuple[dict, int, int]: + if module_name is not None and ignore_modules is not None: + raise ValueError("module_name and ignore_modules cannot be both set") + if map_location is None: + map_location = get_device() + + ckpt = torch.load(path, map_location=map_location) + state_dict = ckpt["state_dict"] + state_dict_to_load = state_dict + + if ignore_modules is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + ignore = any( + [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] + ) + if ignore: + continue + state_dict_to_load[k] = v + + if module_name is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + m = re.match(rf"^{module_name}\.(.*)$", k) + if m is None: + continue + state_dict_to_load[m.group(1)] = v + + return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] + + +def C(value: Any, epoch: int, global_step: int) -> float: + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError("Scalar specification only supports list, got", type(value)) + if len(value) == 3: + value = [0] + value + if len(value) >= 6: + select_i = 3 + for i in range(3, len(value) - 2, 2): + if global_step >= value[i]: + select_i = i + 2 + if select_i != 3: + start_value, start_step = value[select_i - 3], value[select_i - 2] + else: + start_step, start_value = value[:2] + end_value, end_step = value[select_i - 1], value[select_i] + value = [start_step, start_value, end_value, end_step] + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + if isinstance(end_step, int): + current_step = global_step + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + elif isinstance(end_step, float): + current_step = epoch + value = start_value + (end_value - start_value) * max( + min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 + ) + return value + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + tcnn.free_temporary_memory() + + +def finish_with_cleanup(func: Callable): + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + cleanup() + return out + + return wrapper + + +def _distributed_available(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + +def barrier(): + if not _distributed_available(): + return + else: + torch.distributed.barrier() + + +def broadcast(tensor, src=0): + if not _distributed_available(): + return tensor + else: + torch.distributed.broadcast(tensor, src=src) + return tensor + + +def enable_gradient(model, enabled: bool = True) -> None: + for param in model.parameters(): + param.requires_grad_(enabled) + +def find_last_path(path: str): + if (path is not None) and ("LAST" in path): + path = path.replace(" ", "_") + base_dir_prefix, suffix = path.split("LAST", 1) + base_dir = os.path.dirname(base_dir_prefix) + prefix = os.path.split(base_dir_prefix)[-1] + base_dir_prefix = os.path.join(base_dir, prefix) + all_path = os.listdir(base_dir) + all_path = [os.path.join(base_dir, dir) for dir in all_path] + filtered_path = [dir for dir in all_path if dir.startswith(base_dir_prefix)] + filtered_path.sort(reverse=True) + last_path = filtered_path[0] + new_path = last_path + suffix + if os.path.exists(new_path): + return new_path + else: + raise FileNotFoundError(new_path) + else: + return path \ No newline at end of file diff --git a/threestudio/utils/ops.py b/threestudio/utils/ops.py new file mode 100644 index 0000000..5a92278 --- /dev/null +++ b/threestudio/utils/ops.py @@ -0,0 +1,459 @@ +import math +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from igl import fast_winding_number_for_meshes, point_mesh_squared_distance, read_obj +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +import threestudio +from threestudio.utils.typing import * + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def reflect(x, n): + return 2 * dot(x, n) * n - x + + +ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] + + +def scale_tensor( + dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale +): + if inp_scale is None: + inp_scale = (0, 1) + if tgt_scale is None: + tgt_scale = (0, 1) + if isinstance(tgt_scale, Tensor): + assert dat.shape[-1] == tgt_scale.shape[-1] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +class _TruncExp(Function): # pylint: disable=abstract-method + # Implementation from torch-ngp: + # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x): # pylint: disable=arguments-differ + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): # pylint: disable=arguments-differ + x = ctx.saved_tensors[0] + return g * torch.exp(torch.clamp(x, max=15)) + + +class SpecifyGradient(Function): + # Implementation from stable-dreamfusion + # https://github.com/ashawkey/stable-dreamfusion + @staticmethod + @custom_fwd + def forward(ctx, input_tensor, gt_grad): + ctx.save_for_backward(gt_grad) + # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward. + return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, grad_scale): + (gt_grad,) = ctx.saved_tensors + gt_grad = gt_grad * grad_scale + return gt_grad, None + + +trunc_exp = _TruncExp.apply + + +def get_activation(name) -> Callable: + if name is None: + return lambda x: x + name = name.lower() + if name == "none": + return lambda x: x + elif name == "lin2srgb": + return lambda x: torch.where( + x > 0.0031308, + torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, + 12.92 * x, + ).clamp(0.0, 1.0) + elif name == "exp": + return lambda x: torch.exp(x) + elif name == "shifted_exp": + return lambda x: torch.exp(x - 1.0) + elif name == "trunc_exp": + return trunc_exp + elif name == "shifted_trunc_exp": + return lambda x: trunc_exp(x - 1.0) + elif name == "sigmoid": + return lambda x: torch.sigmoid(x) + elif name == "tanh": + return lambda x: torch.tanh(x) + elif name == "shifted_softplus": + return lambda x: F.softplus(x - 1.0) + elif name == "scale_-11_01": + return lambda x: x * 0.5 + 0.5 + else: + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") + + +def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any: + if chunk_size <= 0: + return func(*args, **kwargs) + B = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + assert ( + B is not None + ), "No tensor found in args or kwargs, cannot determine batch size." + out = defaultdict(list) + out_type = None + # max(1, B) to support B == 0 + for i in range(0, max(1, B), chunk_size): + out_chunk = func( + *[ + arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for arg in args + ], + **{ + k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for k, arg in kwargs.items() + }, + ) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + print( + f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}." + ) + exit(1) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + out[k].append(v) + + if out_type is None: + return None + + out_merged: Dict[Any, Optional[torch.Tensor]] = {} + for k, v in out.items(): + if all([vv is None for vv in v]): + # allow None in return value + out_merged[k] = None + elif all([isinstance(vv, torch.Tensor) for vv in v]): + out_merged[k] = torch.cat(v, dim=0) + else: + raise TypeError( + f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" + ) + + if out_type is torch.Tensor: + return out_merged[0] + elif out_type in [tuple, list]: + return out_type([out_merged[i] for i in range(chunk_length)]) + elif out_type is dict: + return out_merged + + +def get_ray_directions( + H: int, + W: int, + focal: Union[float, Tuple[float, float]], + principal: Optional[Tuple[float, float]] = None, + use_pixel_centers: bool = True, +) -> Float[Tensor, "H W 3"]: + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + + Inputs: + H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + pixel_center = 0.5 if use_pixel_centers else 0 + + if isinstance(focal, float): + fx, fy = focal, focal + cx, cy = W / 2, H / 2 + else: + fx, fy = focal + assert principal is not None + cx, cy = principal + + i, j = torch.meshgrid( + torch.arange(W, dtype=torch.float32) + pixel_center, + torch.arange(H, dtype=torch.float32) + pixel_center, + indexing="xy", + ) + + directions: Float[Tensor, "H W 3"] = torch.stack( + [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1 + ) + + return directions + + +def get_rays( + directions: Float[Tensor, "... 3"], + c2w: Float[Tensor, "... 4 4"], + keepdim=False, + noise_scale=0.0, + normalize=True, +) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]: + # Rotate ray directions from camera coordinate to the world coordinate + assert directions.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + if c2w.ndim == 2: # (4, 4) + c2w = c2w[None, :, :] + assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) + rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) + rays_o = c2w[:, :3, 3].expand(rays_d.shape) + elif directions.ndim == 3: # (H, W, 3) + assert c2w.ndim in [2, 3] + if c2w.ndim == 2: # (4, 4) + rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( + -1 + ) # (H, W, 3) + rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + elif directions.ndim == 4: # (B, H, W, 3) + assert c2w.ndim == 3 # (B, 4, 4) + rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + + # add camera noise to avoid grid-like artifect + # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373 + if noise_scale > 0: + rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale + rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale + + if normalize: + rays_d = F.normalize(rays_d, dim=-1) + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d + + +def get_projection_matrix( + fovy: Float[Tensor, "B"], aspect_wh: float, near: float, far: float +) -> Float[Tensor, "B 4 4"]: + batch_size = fovy.shape[0] + proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32) + proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh) + proj_mtx[:, 1, 1] = -1.0 / torch.tan( + fovy / 2.0 + ) # add a negative sign here as the y axis is flipped in nvdiffrast output + proj_mtx[:, 2, 2] = -(far + near) / (far - near) + proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near) + proj_mtx[:, 3, 2] = -1.0 + return proj_mtx + + +def get_mvp_matrix( + c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"] +) -> Float[Tensor, "B 4 4"]: + # calculate w2c from c2w: R' = Rt, t' = -Rt * t + # mathematically equivalent to (c2w)^-1 + w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w) + w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1) + w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:] + w2c[:, 3, 3] = 1.0 + # calculate mvp matrix by proj_mtx @ w2c (mv_mtx) + mvp_mtx = proj_mtx @ w2c + return mvp_mtx + + +def get_full_projection_matrix( + c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"] +) -> Float[Tensor, "B 4 4"]: + return (c2w.unsqueeze(0).bmm(proj_mtx.unsqueeze(0))).squeeze(0) + + +def binary_cross_entropy(input, target): + """ + F.binary_cross_entropy is not numerically stable in mixed-precision training. + """ + return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean() + + +def tet_sdf_diff( + vert_sdf: Float[Tensor, "Nv 1"], tet_edges: Integer[Tensor, "Ne 2"] +) -> Float[Tensor, ""]: + sdf_f1x6x2 = vert_sdf[:, 0][tet_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float() + ) + F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float() + ) + return sdf_diff + + +# Implementation from Latent-NeRF +# https://github.com/eladrich/latent-nerf/blob/f49ecefcd48972e69a28e3116fe95edf0fac4dc8/src/latent_nerf/models/mesh_utils.py +class MeshOBJ: + dx = torch.zeros(3).float() + dx[0] = 1 + dy, dz = dx[[1, 0, 2]], dx[[2, 1, 0]] + dx, dy, dz = dx[None, :], dy[None, :], dz[None, :] + + def __init__(self, v: np.ndarray, f: np.ndarray): + self.v = v + self.f = f + self.dx, self.dy, self.dz = MeshOBJ.dx, MeshOBJ.dy, MeshOBJ.dz + self.v_tensor = torch.from_numpy(self.v) + + vf = self.v[self.f, :] + self.f_center = vf.mean(axis=1) + self.f_center_tensor = torch.from_numpy(self.f_center).float() + + e1 = vf[:, 1, :] - vf[:, 0, :] + e2 = vf[:, 2, :] - vf[:, 0, :] + self.face_normals = np.cross(e1, e2) + self.face_normals = ( + self.face_normals / np.linalg.norm(self.face_normals, axis=-1)[:, None] + ) + self.face_normals_tensor = torch.from_numpy(self.face_normals) + + def normalize_mesh(self, target_scale=0.5): + verts = self.v + + # Compute center of bounding box + # center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]])) + center = verts.mean(axis=0) + verts = verts - center + scale = np.max(np.linalg.norm(verts, axis=1)) + verts = (verts / scale) * target_scale + + return MeshOBJ(verts, self.f) + + def winding_number(self, query: torch.Tensor): + device = query.device + shp = query.shape + query_np = query.detach().cpu().reshape(-1, 3).numpy() + target_alphas = fast_winding_number_for_meshes( + self.v.astype(np.float32), self.f, query_np + ) + return torch.from_numpy(target_alphas).reshape(shp[:-1]).to(device) + + def gaussian_weighted_distance(self, query: torch.Tensor, sigma): + device = query.device + shp = query.shape + query_np = query.detach().cpu().reshape(-1, 3).numpy() + distances, _, _ = point_mesh_squared_distance( + query_np, self.v.astype(np.float32), self.f + ) + distances = torch.from_numpy(distances).reshape(shp[:-1]).to(device) + weight = torch.exp(-(distances / (2 * sigma**2))) + return weight + + +def ce_pq_loss(p, q, weight=None): + def clamp(v, T=0.0001): + return v.clamp(T, 1 - T) + + p = p.view(q.shape) + ce = -1 * (p * torch.log(clamp(q)) + (1 - p) * torch.log(clamp(1 - q))) + if weight is not None: + ce *= weight + return ce.sum() + + +class ShapeLoss(nn.Module): + def __init__(self, guide_shape): + super().__init__() + self.mesh_scale = 0.7 + self.proximal_surface = 0.3 + self.delta = 0.2 + self.shape_path = guide_shape + v, _, _, f, _, _ = read_obj(self.shape_path, float) + mesh = MeshOBJ(v, f) + matrix_rot = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ np.array( + [[0, 0, 1], [0, 1, 0], [-1, 0, 0]] + ) + self.sketchshape = mesh.normalize_mesh(self.mesh_scale) + self.sketchshape = MeshOBJ( + np.ascontiguousarray( + (matrix_rot @ self.sketchshape.v.transpose(1, 0)).transpose(1, 0) + ), + f, + ) + + def forward(self, xyzs, sigmas): + mesh_occ = self.sketchshape.winding_number(xyzs) + if self.proximal_surface > 0: + weight = 1 - self.sketchshape.gaussian_weighted_distance( + xyzs, self.proximal_surface + ) + else: + weight = None + indicator = (mesh_occ > 0.5).float() + nerf_occ = 1 - torch.exp(-self.delta * sigmas) + nerf_occ = nerf_occ.clamp(min=0, max=1.1) + loss = ce_pq_loss( + nerf_occ, indicator, weight=weight + ) # order is important for CE loss + second argument may not be optimized + return loss + + +def shifted_expotional_decay(a, b, c, r): + return a * torch.exp(-b * r) + c + + +def shifted_cosine_decay(a, b, c, r): + return a * torch.cos(b * r + c) + a + + +def perpendicular_component(x: Float[Tensor, "B C H W"], y: Float[Tensor, "B C H W"]): + # get the component of x that is perpendicular to y + eps = torch.ones_like(x[:, 0, 0, 0]) * 1e-6 + return ( + x + - ( + torch.mul(x, y).sum(dim=[1, 2, 3]) + / torch.maximum(torch.mul(y, y).sum(dim=[1, 2, 3]), eps) + ).view(-1, 1, 1, 1) + * y + ) + + +def validate_empty_rays(ray_indices, t_start, t_end): + if ray_indices.nelement() == 0: + threestudio.warn("Empty rays_indices!") + ray_indices = torch.LongTensor([0]).to(ray_indices) + t_start = torch.Tensor([0]).to(ray_indices) + t_end = torch.Tensor([0]).to(ray_indices) + return ray_indices, t_start, t_end \ No newline at end of file diff --git a/threestudio/utils/perceptual/__init__.py b/threestudio/utils/perceptual/__init__.py new file mode 100644 index 0000000..a4d2c7a --- /dev/null +++ b/threestudio/utils/perceptual/__init__.py @@ -0,0 +1 @@ +from .perceptual import PerceptualLoss diff --git a/threestudio/utils/perceptual/perceptual.py b/threestudio/utils/perceptual/perceptual.py new file mode 100644 index 0000000..aaa5648 --- /dev/null +++ b/threestudio/utils/perceptual/perceptual.py @@ -0,0 +1,173 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +from collections import namedtuple +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +from torchvision import models + +import threestudio +from threestudio.utils.perceptual.utils import get_ckpt_path +from threestudio.utils.base import BaseObject +from threestudio.utils.typing import * + +@threestudio.register("perceptual-loss") +class PerceptualLossObject(BaseObject): + @dataclass + class Config(BaseObject.Config): + use_dropout: bool = True + cfg: Config + + def configure(self) -> None: + self.perceptual_loss = PerceptualLoss(self.cfg.use_dropout).to(self.device) + + def __call__( + self, + x: Float[Tensor, "B 3 256 256"], + y: Float[Tensor, "B 3 256 256"], + ): + return self.perceptual_loss(x, y) + + +class PerceptualLoss(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "threestudio/utils/lpips") + self.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns)) + ] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer( + "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + ) + self.register_buffer( + "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + ) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + try: + vgg_pretrained = models.vgg16(pretrained=True) + vgg_pretrained_features = vgg_pretrained.features + except: + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) \ No newline at end of file diff --git a/threestudio/utils/perceptual/utils.py b/threestudio/utils/perceptual/utils.py new file mode 100644 index 0000000..c3f295f --- /dev/null +++ b/threestudio/utils/perceptual/utils.py @@ -0,0 +1,154 @@ +import hashlib +import os + +import requests +from tqdm import tqdm + +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = { + "keya": "a", + "keyb": "b", + "keyc": { + "cc1": 1, + "cc2": 2, + }, + } + from omegaconf import OmegaConf + + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") \ No newline at end of file diff --git a/threestudio/utils/rasterize.py b/threestudio/utils/rasterize.py new file mode 100644 index 0000000..a174bc1 --- /dev/null +++ b/threestudio/utils/rasterize.py @@ -0,0 +1,78 @@ +import nvdiffrast.torch as dr +import torch + +from threestudio.utils.typing import * + + +class NVDiffRasterizerContext: + def __init__(self, context_type: str, device: torch.device) -> None: + self.device = device + self.ctx = self.initialize_context(context_type, device) + + def initialize_context( + self, context_type: str, device: torch.device + ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]: + if context_type == "gl": + return dr.RasterizeGLContext(device=device) + elif context_type == "cuda": + return dr.RasterizeCudaContext(device=device) + else: + raise ValueError(f"Unknown rasterizer context type: {context_type}") + + def vertex_transform( + self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"] + ) -> Float[Tensor, "B Nv 4"]: + verts_homo = torch.cat( + [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1 + ) + return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1)) + + def rasterize( + self, + pos: Float[Tensor, "B Nv 4"], + tri: Integer[Tensor, "Nf 3"], + resolution: Union[int, Tuple[int, int]], + ): + # rasterize in instance mode (single topology) + return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True) + + def rasterize_one( + self, + pos: Float[Tensor, "Nv 4"], + tri: Integer[Tensor, "Nf 3"], + resolution: Union[int, Tuple[int, int]], + ): + # rasterize one single mesh under a single viewpoint + rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) + return rast[0], rast_db[0] + + def antialias( + self, + color: Float[Tensor, "B H W C"], + rast: Float[Tensor, "B H W 4"], + pos: Float[Tensor, "B Nv 4"], + tri: Integer[Tensor, "Nf 3"], + ) -> Float[Tensor, "B H W C"]: + return dr.antialias(color.float(), rast, pos.float(), tri.int()) + + def interpolate( + self, + attr: Float[Tensor, "B Nv C"], + rast: Float[Tensor, "B H W 4"], + tri: Integer[Tensor, "Nf 3"], + rast_db=None, + diff_attrs=None, + ) -> Float[Tensor, "B H W C"]: + return dr.interpolate( + attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs + ) + + def interpolate_one( + self, + attr: Float[Tensor, "Nv C"], + rast: Float[Tensor, "B H W 4"], + tri: Integer[Tensor, "Nf 3"], + rast_db=None, + diff_attrs=None, + ) -> Float[Tensor, "B H W C"]: + return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) diff --git a/threestudio/utils/saving.py b/threestudio/utils/saving.py new file mode 100644 index 0000000..a9040fa --- /dev/null +++ b/threestudio/utils/saving.py @@ -0,0 +1,652 @@ +import json +import os +import re +import shutil + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import torch +import trimesh +import wandb +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +from PIL import Image, ImageDraw +from pytorch_lightning.loggers import WandbLogger + +from threestudio.models.mesh import Mesh +from threestudio.utils.typing import * + + +class SaverMixin: + _save_dir: Optional[str] = None + _wandb_logger: Optional[WandbLogger] = None + + def set_save_dir(self, save_dir: str): + self._save_dir = save_dir + + def get_save_dir(self): + if self._save_dir is None: + raise ValueError("Save dir is not set") + return self._save_dir + + def convert_data(self, data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.get_save_dir(), filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + def create_loggers(self, cfg_loggers: DictConfig) -> None: + if "wandb" in cfg_loggers.keys() and cfg_loggers.wandb.enable: + self._wandb_logger = WandbLogger( + project=cfg_loggers.wandb.project, name=cfg_loggers.wandb.name + ) + + def get_loggers(self) -> List: + if self._wandb_logger: + return [self._wandb_logger] + else: + return [] + + DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "HWC", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + DEFAULT_GRID_KWARGS = {"align": "max"} + + def get_rgb_image_(self, img, data_format, data_range, rgba=False): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + if img.dtype != np.uint8: + img = img.clip(min=data_range[0], max=data_range[1]) + img = ( + (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 + ).astype(np.uint8) + nc = 4 if rgba else 3 + imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] + imgs = [ + img_ + if img_.shape[-1] == nc + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], nc - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + if rgba: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_rgb_image( + self, + filename, + img, + data_format, + data_range, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Image(self.get_save_path(filename)), + "trainer/global_step": step, + } + ) + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_rgb_image(save_path, img, data_format, data_range, name, step) + return save_path + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ["checkerboard", "color"] + if cmap == "checkerboard": + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[..., 0] + mask[..., 1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == "color": + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image( + self, + filename, + img, + data_format=DEFAULT_UV_KWARGS["data_format"], + data_range=DEFAULT_UV_KWARGS["data_range"], + cmap=DEFAULT_UV_KWARGS["cmap"], + ) -> str: + save_path = self.get_save_path(filename) + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(save_path, img) + return save_path + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma", "spectral"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + elif cmap == "spectral": + colormap = plt.get_cmap("Spectral") + + def blend_rgba(image): + image = image[..., :3] * image[..., -1:] + ( + 1.0 - image[..., -1:] + ) # blend A to RGB + return image + + img = colormap(img) + img = blend_rgba(img) + img = (img * 255).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_grayscale_image( + self, + filename, + img, + data_range, + cmap, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Image(self.get_save_path(filename)), + "trainer/global_step": step, + } + ) + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_grayscale_image(save_path, img, data_range, cmap, name, step) + return save_path + + def get_image_grid_(self, imgs, align): + if isinstance(imgs[0], list): + return np.concatenate( + [self.get_image_grid_(row, align) for row in imgs], axis=0 + ) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + + if align == "max": + h = max([col.shape[0] for col in cols]) + w = max([col.shape[1] for col in cols]) + elif align == "min": + h = min([col.shape[0] for col in cols]) + w = min([col.shape[1] for col in cols]) + elif isinstance(align, int): + h = align + w = align + elif ( + isinstance(align, tuple) + and isinstance(align[0], int) + and isinstance(align[1], int) + ): + h, w = align + else: + raise ValueError( + f"Unsupported image grid align: {align}, should be min, max, int or (int, int)" + ) + + for i in range(len(cols)): + if cols[i].shape[0] != h or cols[i].shape[1] != w: + cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_LINEAR) + return np.concatenate(cols, axis=1) + + def save_image_grid( + self, + filename, + imgs, + align=DEFAULT_GRID_KWARGS["align"], + name: Optional[str] = None, + step: Optional[int] = None, + texts: Optional[List[float]] = None, + ): + save_path = self.get_save_path(filename) + img = self.get_image_grid_(imgs, align=align) + + if texts is not None: + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) + img = np.asarray(img) + + cv2.imwrite(save_path, img) + if name and self._wandb_logger: + wandb.log({name: wandb.Image(save_path), "trainer/global_step": step}) + return save_path + + def save_image(self, filename, img) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.dtype == np.uint8 or img.dtype == np.uint16 + if img.ndim == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.ndim == 3 and img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(save_path, img) + return save_path + + def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[..., start : start + 3] + img_ = np.stack( + [ + self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba) + for i in range(img_.shape[0]) + ], + axis=0, + ) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate( + [ + np.concatenate( + [placeholder, img_[2], placeholder, placeholder], axis=1 + ), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate( + [placeholder, img_[3], placeholder, placeholder], axis=1 + ), + ], + axis=0, + ) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(save_path, imgs_full) + return save_path + + def save_data(self, filename, data) -> str: + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith(".npz"): + filename += ".npz" + save_path = self.get_save_path(filename) + np.savez(save_path, **data) + else: + if not filename.endswith(".npy"): + filename += ".npy" + save_path = self.get_save_path(filename) + np.save(save_path, data) + return save_path + + def save_state_dict(self, filename, data) -> str: + save_path = self.get_save_path(filename) + torch.save(data, save_path) + return save_path + + def save_img_sequence( + self, + filename, + img_dir, + matcher, + save_format="mp4", + fps=30, + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + save_path = self.get_save_path(filename) + matcher = re.compile(matcher) + img_dir = os.path.join(self.get_save_dir(), img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps) + if name and self._wandb_logger: + wandb.log( + { + name: wandb.Video(save_path, format="mp4"), + "trainer/global_step": step, + } + ) + return save_path + + def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str: + save_path = self.get_save_path(filename) + v_pos = self.convert_data(v_pos) + t_pos_idx = self.convert_data(t_pos_idx) + mesh = trimesh.Trimesh(vertices=v_pos, faces=t_pos_idx) + mesh.export(save_path) + return save_path + + def save_obj( + self, + filename: str, + mesh: Mesh, + save_mat: bool = False, + save_normal: bool = False, + save_uv: bool = False, + save_vertex_color: bool = False, + map_Kd: Optional[Float[Tensor, "H W 3"]] = None, + map_Ks: Optional[Float[Tensor, "H W 3"]] = None, + map_Bump: Optional[Float[Tensor, "H W 3"]] = None, + map_Pm: Optional[Float[Tensor, "H W 1"]] = None, + map_Pr: Optional[Float[Tensor, "H W 1"]] = None, + map_format: str = "jpg", + ) -> List[str]: + save_paths: List[str] = [] + if not filename.endswith(".obj"): + filename += ".obj" + v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data( + mesh.t_pos_idx + ) + v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None + if save_normal: + v_nrm = self.convert_data(mesh.v_nrm) + if save_uv: + v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data( + mesh.t_tex_idx + ) + if save_vertex_color: + v_rgb = self.convert_data(mesh.v_rgb) + matname, mtllib = None, None + if save_mat: + matname = "default" + mtl_filename = filename.replace(".obj", ".mtl") + mtllib = os.path.basename(mtl_filename) + mtl_save_paths = self._save_mtl( + mtl_filename, + matname, + map_Kd=self.convert_data(map_Kd), + map_Ks=self.convert_data(map_Ks), + map_Bump=self.convert_data(map_Bump), + map_Pm=self.convert_data(map_Pm), + map_Pr=self.convert_data(map_Pr), + map_format=map_format, + ) + save_paths += mtl_save_paths + obj_save_path = self._save_obj( + filename, + v_pos, + t_pos_idx, + v_nrm=v_nrm, + v_tex=v_tex, + t_tex_idx=t_tex_idx, + v_rgb=v_rgb, + matname=matname, + mtllib=mtllib, + ) + save_paths.append(obj_save_path) + return save_paths + + def _save_obj( + self, + filename, + v_pos, + t_pos_idx, + v_nrm=None, + v_tex=None, + t_tex_idx=None, + v_rgb=None, + matname=None, + mtllib=None, + ) -> str: + obj_str = "" + if matname is not None: + obj_str += f"mtllib {mtllib}\n" + obj_str += f"g object\n" + obj_str += f"usemtl {matname}\n" + for i in range(len(v_pos)): + obj_str += f"v {v_pos[i][0]} {v_pos[i][1]} {v_pos[i][2]}" + if v_rgb is not None: + obj_str += f" {v_rgb[i][0]} {v_rgb[i][1]} {v_rgb[i][2]}" + obj_str += "\n" + if v_nrm is not None: + for v in v_nrm: + obj_str += f"vn {v[0]} {v[1]} {v[2]}\n" + if v_tex is not None: + for v in v_tex: + obj_str += f"vt {v[0]} {1.0 - v[1]}\n" + + for i in range(len(t_pos_idx)): + obj_str += "f" + for j in range(3): + obj_str += f" {t_pos_idx[i][j] + 1}/" + if v_tex is not None: + obj_str += f"{t_tex_idx[i][j] + 1}" + obj_str += "/" + if v_nrm is not None: + obj_str += f"{t_pos_idx[i][j] + 1}" + obj_str += "\n" + + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(obj_str) + return save_path + + def _save_mtl( + self, + filename, + matname, + Ka=(0.0, 0.0, 0.0), + Kd=(1.0, 1.0, 1.0), + Ks=(0.0, 0.0, 0.0), + map_Kd=None, + map_Ks=None, + map_Bump=None, + map_Pm=None, + map_Pr=None, + map_format="jpg", + step: Optional[int] = None, + ) -> List[str]: + mtl_save_path = self.get_save_path(filename) + save_paths = [mtl_save_path] + mtl_str = f"newmtl {matname}\n" + mtl_str += f"Ka {Ka[0]} {Ka[1]} {Ka[2]}\n" + if map_Kd is not None: + map_Kd_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_kd.{map_format}" + ) + mtl_str += f"map_Kd texture_kd.{map_format}\n" + self._save_rgb_image( + map_Kd_save_path, + map_Kd, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Kd", + step=step, + ) + save_paths.append(map_Kd_save_path) + else: + mtl_str += f"Kd {Kd[0]} {Kd[1]} {Kd[2]}\n" + if map_Ks is not None: + map_Ks_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_ks.{map_format}" + ) + mtl_str += f"map_Ks texture_ks.{map_format}\n" + self._save_rgb_image( + map_Ks_save_path, + map_Ks, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Ks", + step=step, + ) + save_paths.append(map_Ks_save_path) + else: + mtl_str += f"Ks {Ks[0]} {Ks[1]} {Ks[2]}\n" + if map_Bump is not None: + map_Bump_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_nrm.{map_format}" + ) + mtl_str += f"map_Bump texture_nrm.{map_format}\n" + self._save_rgb_image( + map_Bump_save_path, + map_Bump, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Bump", + step=step, + ) + save_paths.append(map_Bump_save_path) + if map_Pm is not None: + map_Pm_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_metallic.{map_format}" + ) + mtl_str += f"map_Pm texture_metallic.{map_format}\n" + self._save_grayscale_image( + map_Pm_save_path, + map_Pm, + data_range=(0, 1), + cmap=None, + name=f"{matname}_refl", + step=step, + ) + save_paths.append(map_Pm_save_path) + if map_Pr is not None: + map_Pr_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_roughness.{map_format}" + ) + mtl_str += f"map_Pr texture_roughness.{map_format}\n" + self._save_grayscale_image( + map_Pr_save_path, + map_Pr, + data_range=(0, 1), + cmap=None, + name=f"{matname}_Ns", + step=step, + ) + save_paths.append(map_Pr_save_path) + with open(self.get_save_path(filename), "w") as f: + f.write(mtl_str) + return save_paths + + def save_file(self, filename, src_path) -> str: + save_path = self.get_save_path(filename) + shutil.copyfile(src_path, save_path) + return save_path + + def save_json(self, filename, payload) -> str: + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(json.dumps(payload)) + return save_path diff --git a/threestudio/utils/typing.py b/threestudio/utils/typing.py new file mode 100644 index 0000000..dee9f96 --- /dev/null +++ b/threestudio/utils/typing.py @@ -0,0 +1,40 @@ +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker