chores: rebase commits

MrTornado24 2023-12-13 00:17:53 +08:00
commit 50ecd13a88
177 changed files with 45954 additions and 0 deletions

12
.editorconfig Normal file

@ -0,0 +1,12 @@
root = true
[*.py]
charset = utf-8
trim_trailing_whitespace = true
end_of_line = lf
insert_final_newline = true
indent_style = space
indent_size = 4
[*.md]
trim_trailing_whitespace = false

195
.gitignore vendored Normal file

@ -0,0 +1,195 @@
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python
.vscode/
.threestudio_cache/
outputs/
outputs-gradio/
# pretrained model weights
*.ckpt
*.pt
*.pth
# wandb
wandb/
load/tets/256_tets.npz
# dataset
dataset/
load/

34
.pre-commit-config.yaml Normal file

@ -0,0 +1,34 @@
default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
exclude: README.md
args: ["--profile", "black"]
# temporarily disable static type checking
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.2.0
# hooks:
# - id: mypy
# args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"]

21
LICENSE Normal file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 deepseek-ai
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

21
LICENSE-CODE Normal file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

91
LICENSE-MODEL Normal file

@ -0,0 +1,91 @@
DEEPSEEK LICENSE AGREEMENT
Version 1.0, 23 October 2023
Copyright (c) 2023 DeepSeek
Section I: PREAMBLE
Large generative models are being widely adopted and used, and have the potential to transform the way individuals conceive and benefit from AI or ML technologies.
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for content generation.
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI.
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
NOW THEREFORE, You and DeepSeek agree as follows:
1. Definitions
"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates.
"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc.
"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You.
Section II: INTELLECTUAL PROPERTY RIGHTS
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
2. Grant of Copyright License. Subject to the terms and conditions of this License, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by DeepSeek that are necessarily infringed by its contribution(s). If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or works shall terminate as of the date such litigation is asserted or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
c. You must cause any modified files to carry prominent notices stating that You changed the files;
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
6. The Output You Generate. Except as set forth herein, DeepSeek claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
Section IV: OTHER PROVISIONS
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, DeepSeek reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.
8. Trademarks and related. Nothing in this License permits You to make use of DeepSeek trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by DeepSeek.
9. Personal information, IP rights and related. This Model may contain personal information and works with IP rights. You commit to complying with applicable laws and regulations in the handling of personal information and the use of such works. Please note that DeepSeek's license granted to you to use the Model does not imply that you have obtained a legitimate basis for processing the related information or works. As an independent personal information processor and IP rights user, you need to ensure full compliance with relevant legal and regulatory requirements when handling personal information and works with IP rights that may be contained in the Model, and are willing to assume solely any risks and consequences that may arise from that.
10. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, DeepSeek provides the Model and the Complementary Material on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
11. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall DeepSeek be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if DeepSeek has been advised of the possibility of such damages.
12. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of DeepSeek, and only if You agree to indemnify, defend, and hold DeepSeek harmless for any liability incurred by, or claims asserted against, DeepSeek by reason of your accepting any such warranty or additional liability.
13. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
14. Governing Law and Jurisdiction. This agreement will be governed and construed under PRC laws without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this agreement. The courts located in the domicile of Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. shall have exclusive jurisdiction of any dispute arising out of this agreement.
END OF TERMS AND CONDITIONS
Attachment A
Use Restrictions
You agree not to use the Model or Derivatives of the Model:
- In any way that violates any applicable national or international law or regulation or infringes upon the lawful rights and interests of any third party;
- For military use in any way;
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
- To generate or disseminate inappropriate content subject to applicable regulatory requirements;
- To generate or disseminate personal identifiable information without due authorization or for unreasonable use;
- To defame, disparage or otherwise harass others;
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation;
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories.

175
README.md Normal file

@ -0,0 +1,175 @@
# DreamCraft3D
[**Paper**](https://arxiv.org/abs/2310.16818) | [**Project Page**](https://mrtornado24.github.io/DreamCraft3D/) | [**YouTube video**](https://www.youtube.com/watch?v=0FazXENkQms)
Official implementation of DreamCraft3D: Hierarchical 3D Generation with Bootstrapped Diffusion Prior
[Jingxiang Sun](https://mrtornado24.github.io/), [Bo Zhang](https://bo-zhang.me/), [Ruizhi Shao](https://dsaurus.github.io/saurus/), [Lizhen Wang](https://lizhenwangt.github.io/), [Wen Liu](https://github.com/StevenLiuWen), [Zhenda Xie](https://zdaxie.github.io/), [Yebin Liu](https://liuyebin.com/)
Abstract: *We present DreamCraft3D, a hierarchical 3D content generation method that produces high-fidelity and coherent 3D objects. We tackle the problem by leveraging a 2D reference image to guide the stages of geometry sculpting and texture boosting. A central focus of this work is to address the consistency issue that existing works encounter. To sculpt geometries that render coherently, we perform score distillation sampling via a view-dependent diffusion model. This 3D prior, alongside several training strategies, prioritizes the geometry consistency but compromises the texture fidelity. We further propose **Bootstrapped Score Distillation** to specifically boost the texture. We train a personalized diffusion model, Dreambooth, on the augmented renderings of the scene, imbuing it with 3D knowledge of the scene being optimized. The score distillation from this 3D-aware diffusion prior provides view-consistent guidance for the scene. Notably, through an alternating optimization of the diffusion prior and 3D scene representation, we achieve mutually reinforcing improvements: the optimized 3D scene aids in training the scene-specific diffusion model, which offers increasingly view-consistent guidance for 3D optimization. The optimization is thus bootstrapped and leads to substantial texture boosting. With tailored 3D priors throughout the hierarchical generation, DreamCraft3D generates coherent 3D objects with photorealistic renderings, advancing the state-of-the-art in 3D content generation.*
<p align="center">
<img src="assets/repo_static_v2.png">
</p>
## Method Overview
<p align="center">
<img src="assets/diagram-1.png">
</p>
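
For intuition, the alternating optimization described in the abstract can be summarized by the pseudocode below. This is a minimal sketch only: every name (`dreambooth_finetune`, `score_distillation_step`, `sample_views`, etc.) is an illustrative placeholder rather than a function from this repository; the actual training loop is implemented by the threestudio system and the configs used in the Quickstart below.

```python
# Minimal sketch of the bootstrapped optimization loop described above.
# Every argument is an illustrative placeholder, not an object from this repository.
def bootstrapped_score_distillation(
    scene,                    # 3D representation with .render(view) and .update(loss)
    diffusion_prior,          # 2D diffusion model used for score distillation
    dreambooth_finetune,      # callable: (model, images) -> personalized model
    score_distillation_step,  # callable: (model, rendering) -> scalar loss
    sample_views,             # callable: (n) -> list of camera views
    n_rounds=4,
    n_steps=500,
):
    for _ in range(n_rounds):
        # 1) Personalize the diffusion model on renderings of the current scene,
        #    imbuing it with 3D knowledge of the subject being optimized.
        renderings = [scene.render(v) for v in sample_views(32)]
        diffusion_prior = dreambooth_finetune(diffusion_prior, renderings)
        # 2) Optimize the 3D scene with score distillation from the now
        #    scene-specific, 3D-aware prior.
        for _ in range(n_steps):
            view = sample_views(1)[0]
            loss = score_distillation_step(diffusion_prior, scene.render(view))
            scene.update(loss)
    return scene
```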
<!-- https://github.com/MrTornado24/DreamCraft3D/assets/45503891/8e70610c-d812-4544-86bf-7f8764e41067
https://github.com/MrTornado24/DreamCraft3D/assets/45503891/b1e8ae54-1afd-4e0f-88f7-9bd5b70fd44d
https://github.com/MrTornado24/DreamCraft3D/assets/45503891/ead40f9b-d7ee-4ee8-8d98-dbd0b8fbab97 -->
## Installation
### Install threestudio
**This part is the same as the original threestudio. Skip it if you have already installed the environment.**
See [installation.md](docs/installation.md) for additional information, including installation via Docker.
- You must have an NVIDIA graphics card with at least 20GB VRAM and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed.
- Install `Python >= 3.8`.
- (Optional, Recommended) Create a virtual environment:
```sh
python3 -m virtualenv venv
. venv/bin/activate
# Newer pip versions, e.g. pip-23.x, can be much faster than old versions, e.g. pip-20.x.
# For instance, it caches the wheels of git packages to avoid unnecessarily rebuilding them later.
python3 -m pip install --upgrade pip
```
- Install `PyTorch >= 1.12`. We have tested on `torch1.12.1+cu113` and `torch2.0.0+cu118`, but other versions should also work fine.
```sh
# torch1.12.1+cu113
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# or torch2.0.0+cu118
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
```
- (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions:
```sh
pip install ninja
```
- Install dependencies:
```sh
pip install -r requirements.txt
```
## Quickstart
Our model is trained in multiple stages. You can run the full pipeline as follows:
```sh
prompt="a brightly colored mushroom growing on a log"
image_path="load/images/mushroom_log_rgba.png"
# --------- Stage 1 (NeRF & NeuS) --------- #
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path"
ckpt=outputs/dreamcraft3d-coarse-nerf/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-coarse-neus.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.weights="$ckpt"
# --------- Stage 2 (Geometry Refinement) --------- #
ckpt=outputs/dreamcraft3d-coarse-neus/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-geometry.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"
# --------- Stage 3 (Texture Refinement) --------- #
ckpt=outputs/dreamcraft3d-geometry/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-texture.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"
```
<details>
<summary>[Optional] If the "Janus problem" arises in Stage 1, consider training a custom Text2Image model.</summary>
First, generate multi-view images from a single reference image with Zero123++.
```sh
python threestudio/scripts/img_to_mv.py --image_path 'load/mushroom.png' --save_path '.cache/temp' --prompt 'a photo of mushroom' --superres
```
Then train a personalized DeepFloyd model with DreamBooth LoRA. Before training, check that the multi-view images generated above look reasonable.
```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR=".cache/temp"
export OUTPUT_DIR=".cache/if_dreambooth_mushroom"
accelerate launch threestudio/scripts/train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a sks mushroom" \
--resolution=64 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--scale_lr \
--max_train_steps=1200 \
--checkpointing_steps=600 \
--pre_compute_text_embeddings \
--tokenizer_max_length=77 \
--text_encoder_use_attention_mask
```
The personalized DeepFloyd LoRA weights are saved at `.cache/if_dreambooth_mushroom`. Now you can point the guidance in the training scripts to these weights:
```sh
# --------- Stage 1 (NeRF & NeuS) --------- #
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.guidance.lora_weights_path=".cache/if_dreambooth_mushroom"
```
</details>
## Tips
- **Memory Usage**. We run the default configs on 40 GB A100 GPUs. To reduce memory usage, you can lower the NeuS rendering resolution with `data.height=128 data.width=128 data.random_camera.height=128 data.random_camera.width=128`. You can reduce the resolution for the other stages in the same way.
## Todo
- [x] Release the reorganized code.
- [ ] Clean the original dreambooth training code.
- [ ] Provide some running results and checkpoints.
## Credits
This code is built on the amazing open-source [threestudio-project](https://github.com/threestudio-project/threestudio).
## Related links
- [DreamFusion](https://dreamfusion3d.github.io/)
- [Magic3D](https://research.nvidia.com/labs/dir/magic3d/)
- [Make-it-3D](https://make-it-3d.github.io/)
- [Magic123](https://guochengqian.github.io/project/magic123/)
- [ProlificDreamer](https://ml.cs.tsinghua.edu.cn/prolificdreamer/)
- [DreamBooth](https://dreambooth.github.io/)
## BibTeX
```bibtex
@article{sun2023dreamcraft3d,
title={Dreamcraft3d: Hierarchical 3d generation with bootstrapped diffusion prior},
author={Sun, Jingxiang and Zhang, Bo and Shao, Ruizhi and Wang, Lizhen and Liu, Wen and Xie, Zhenda and Liu, Yebin},
journal={arXiv preprint arXiv:2310.16818},
year={2023}
}
```

BIN
assets/diagram-1.png Normal file (396 KiB)

BIN
assets/logo.png Normal file (354 KiB)

BIN
assets/repo_demo_0.mp4 Normal file

BIN
assets/repo_demo_01.mp4 Normal file

BIN
assets/repo_demo_02.mp4 Normal file

BIN
assets/repo_static_v2.png Normal file (21 MiB)

BIN
assets/result_mushroom.mp4 Normal file

159
configs/dreamcraft3d-coarse-nerf.yaml Normal file

@ -0,0 +1,159 @@
name: "dreamcraft3d-coarse-nerf"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: [128, 384]
width: [128, 384]
resolution_milestones: [3000]
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: true
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
random_camera:
height: [128, 384]
width: [128, 384]
batch_size: [1, 1]
resolution_milestones: [3000]
eval_height: 512
eval_width: 512
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 200
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: coarse
geometry_type: "implicit-volume"
geometry:
radius: 2.0
normal_type: "finite_difference"
# the density initialization proposed in the DreamFusion paper
# does not work very well
# density_bias: "blob_dreamfusion"
# density_activation: exp
# density_blob_scale: 5.
# density_blob_std: 0.2
# use Magic3D density initialization instead
density_bias: "blob_magic3d"
density_activation: softplus
density_blob_scale: 10.
density_blob_std: 0.5
# coarse to fine hash grid encoding
# to ensure smooth analytic normals
pos_encoding_config:
otype: ProgressiveBandHashGrid
n_levels: 16
n_features_per_level: 2
log2_hashmap_size: 19
base_resolution: 16
per_level_scale: 1.447269237440378 # max resolution 4096
start_level: 8 # resolution ~200
start_step: 2000
update_steps: 500
material_type: "no-material"
material:
requires_normal: true
background_type: "solid-color-background"
renderer_type: "nerf-volume-renderer"
renderer:
radius: ${system.geometry.radius}
num_samples_per_ray: 512
return_normal_perturb: true
return_comp_normal: ${cmaxgt0:${system.loss.lambda_normal_smooth}}
prompt_processor_type: "deep-floyd-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
prompt: ???
use_perp_neg: true
guidance_type: "deep-floyd-guidance"
guidance:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
guidance_scale: 20
min_step_percent: [0, 0.7, 0.2, 200]
max_step_percent: [0, 0.85, 0.5, 200]
guidance_3d_type: "stable-zero123-guidance"
guidance_3d:
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 5.0
min_step_percent: [0, 0.7, 0.2, 200] # (start_iter, start_val, end_val, end_iter)
max_step_percent: [0, 0.85, 0.5, 200]
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "alternate"
no_diff_steps: 0
guidance_eval: 0
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.1
lambda_3d_sd: 0.1
lambda_rgb: 1000.0
lambda_mask: 100.0
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.05
lambda_normal: 0.0
lambda_normal_smooth: 1.0
lambda_3d_normal_smooth: [2000, 5., 1., 2001]
lambda_orient: [2000, 1., 10., 2001]
lambda_sparsity: [2000, 0.1, 10., 2001]
lambda_opaque: [2000, 0.1, 10., 2001]
lambda_clip: 0.0
optimizer:
name: Adam
args:
lr: 0.01
betas: [0.9, 0.99]
eps: 1.e-8
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 16-mixed
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}
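
Several entries above, e.g. `min_step_percent: [0, 0.7, 0.2, 200]`, use the annotated `(start_iter, start_val, end_val, end_iter)` format: the value is annealed from `start_val` to `end_val` between the two iteration milestones. A minimal sketch of how such an entry could be evaluated (illustrative only; the actual threestudio scheduling code may differ):

```python
def eval_schedule(spec, step):
    """Evaluate a `[start_iter, start_val, end_val, end_iter]` config entry at `step`.

    Illustrative sketch only: scalars pass through unchanged, and list entries are
    clamped to `start_val`/`end_val` outside the milestones and linearly
    interpolated in between.
    """
    if not isinstance(spec, (list, tuple)):
        return spec
    start_iter, start_val, end_val, end_iter = spec
    if step <= start_iter:
        return start_val
    if step >= end_iter:
        return end_val
    t = (step - start_iter) / (end_iter - start_iter)
    return start_val + t * (end_val - start_val)

# min_step_percent: [0, 0.7, 0.2, 200] anneals from 0.7 to 0.2 over the first 200 steps
assert abs(eval_schedule([0, 0.7, 0.2, 200], 100) - 0.45) < 1e-9
```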

155
configs/dreamcraft3d-coarse-neus.yaml Normal file

@ -0,0 +1,155 @@
name: "dreamcraft3d-coarse-neus"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: 256
width: 256
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: true
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
random_camera:
height: 256
width: 256
batch_size: 1
eval_height: 512
eval_width: 512
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 0
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: coarse
geometry_type: "implicit-sdf"
geometry:
radius: 2.0
normal_type: "finite_difference"
sdf_bias: sphere
sdf_bias_params: 0.5
# coarse to fine hash grid encoding
pos_encoding_config:
otype: HashGrid
n_levels: 16
n_features_per_level: 2
log2_hashmap_size: 19
base_resolution: 16
per_level_scale: 1.447269237440378 # max resolution 4096
start_level: 8 # resolution ~200
start_step: 2000
update_steps: 500
material_type: "no-material"
material:
requires_normal: true
background_type: "solid-color-background"
renderer_type: "neus-volume-renderer"
renderer:
radius: ${system.geometry.radius}
num_samples_per_ray: 512
cos_anneal_end_steps: ${trainer.max_steps}
eval_chunk_size: 8192
prompt_processor_type: "deep-floyd-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
prompt: ???
use_perp_neg: true
guidance_type: "deep-floyd-guidance"
guidance:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
guidance_scale: 20
min_step_percent: 0.2
max_step_percent: 0.5
guidance_3d_type: "stable-zero123-guidance"
guidance_3d:
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 5.0
min_step_percent: 0.2
max_step_percent: 0.5
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "alternate"
no_diff_steps: 0
guidance_eval: 0
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.1
lambda_3d_sd: 0.1
lambda_rgb: 1000.0
lambda_mask: 100.0
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.05
lambda_normal: 0.0
lambda_normal_smooth: 0.0
lambda_3d_normal_smooth: 0.0
lambda_orient: 10.0
lambda_sparsity: 0.1
lambda_opaque: 0.1
lambda_clip: 0.0
lambda_eikonal: 0.0
optimizer:
name: Adam
args:
betas: [0.9, 0.99]
eps: 1.e-15
params:
geometry.encoding:
lr: 0.01
geometry.sdf_network:
lr: 0.001
geometry.feature_network:
lr: 0.001
renderer:
lr: 0.001
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 16-mixed
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}

133
configs/dreamcraft3d-geometry.yaml Normal file

@ -0,0 +1,133 @@
name: "dreamcraft3d-geometry"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: 1024
width: 1024
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: ${cmaxgt0orcmaxgt0:${system.loss.lambda_depth},${system.loss.lambda_depth_rel}}
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
use_mixed_camera_config: false
random_camera:
height: 1024
width: 1024
batch_size: 1
eval_height: 1024
eval_width: 1024
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 0
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: geometry
use_mixed_camera_config: ${data.use_mixed_camera_config}
geometry_convert_from: ???
geometry_convert_inherit_texture: true
geometry_type: "tetrahedra-sdf-grid"
geometry:
radius: 2.0 # consistent with coarse
isosurface_resolution: 128
isosurface_deformable_grid: true
material_type: "no-material"
material:
n_output_dims: 3
background_type: "solid-color-background"
renderer_type: "nvdiff-rasterizer"
renderer:
context_type: cuda
prompt_processor_type: "deep-floyd-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
prompt: ???
use_perp_neg: true
guidance_type: "deep-floyd-guidance"
guidance:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
guidance_scale: 20
min_step_percent: 0.02
max_step_percent: 0.5
guidance_3d_type: "stable-zero123-guidance"
guidance_3d:
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 5.0
min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter)
max_step_percent: 0.5
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "accumulate"
no_diff_steps: 0
guidance_eval: 0
n_rgb: 4
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.1
lambda_3d_sd: 0.1
lambda_rgb: 1000.0
lambda_mask: 100.0
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.0
lambda_normal: 0.0
lambda_normal_smooth: 0.
lambda_3d_normal_smooth: 0.
lambda_normal_consistency: [1000,10.0,1,2000]
lambda_laplacian_smoothness: 0.0
optimizer:
name: Adam
args:
lr: 0.005
betas: [0.9, 0.99]
eps: 1.e-15
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 32
strategy: "ddp_find_unused_parameters_true"
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}

166
configs/dreamcraft3d-texture.yaml Normal file

@ -0,0 +1,166 @@
name: "dreamcraft3d-texture"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: 1024
width: 1024
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: false
requires_normal: false
use_mixed_camera_config: false
random_camera:
height: 1024
width: 1024
batch_size: 1
eval_height: 1024
eval_width: 1024
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 0
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: texture
use_mixed_camera_config: ${data.use_mixed_camera_config}
geometry_convert_from: ???
geometry_convert_inherit_texture: true
geometry_type: "tetrahedra-sdf-grid"
geometry:
radius: 2.0 # consistent with coarse
isosurface_resolution: 128
isosurface_deformable_grid: true
isosurface_remove_outliers: true
pos_encoding_config:
otype: HashGrid
n_levels: 16
n_features_per_level: 2
log2_hashmap_size: 19
base_resolution: 16
per_level_scale: 1.447269237440378 # max resolution 4096
fix_geometry: true
material_type: "no-material"
material:
n_output_dims: 3
background_type: "solid-color-background"
renderer_type: "nvdiff-rasterizer"
renderer:
context_type: cuda
prompt_processor_type: "stable-diffusion-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
prompt: ???
front_threshold: 30.
back_threshold: 30.
guidance_type: "stable-diffusion-bsd-guidance"
guidance:
pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1-base"
# pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1"
guidance_scale: 2.0
min_step_percent: 0.05
max_step_percent: [0, 0.5, 0.2, 5000]
only_pretrain_step: 1000
# guidance_3d_type: "stable-zero123-guidance"
# guidance_3d:
# pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
# pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
# cond_image_path: ${data.image_path}
# cond_elevation_deg: ${data.default_elevation_deg}
# cond_azimuth_deg: ${data.default_azimuth_deg}
# cond_camera_distance: ${data.default_camera_distance}
# guidance_scale: 5.0
# min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter)
# max_step_percent: 0.5
# control_guidance_type: "stable-diffusion-controlnet-reg-guidance"
# control_guidance:
# min_step_percent: 0.1
# max_step_percent: 0.5
# control_prompt_processor_type: "stable-diffusion-prompt-processor"
# control_prompt_processor:
# pretrained_model_name_or_path: "SG161222/Realistic_Vision_V2.0"
# prompt: ${system.prompt_processor.prompt}
# front_threshold: 30.
# back_threshold: 30.
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "alternate"
no_diff_steps: -1
guidance_eval: 0
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.01
lambda_lora: 0.1
lambda_pretrain: 0.1
lambda_3d_sd: 0.0
lambda_rgb: 1000.
lambda_mask: 100.
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.0
lambda_normal: 0.0
lambda_normal_smooth: 0.0
lambda_3d_normal_smooth: 0.0
lambda_z_variance: 0.0
lambda_reg: 0.0
optimizer:
name: AdamW
args:
betas: [0.9, 0.99]
eps: 1.e-4
params:
geometry.encoding:
lr: 0.01
geometry.feature_network:
lr: 0.001
guidance.train_unet:
lr: 0.00001
guidance.train_unet_lora:
lr: 0.00001
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 32
strategy: "ddp_find_unused_parameters_true"
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}

60
docker/Dockerfile Normal file

@ -0,0 +1,60 @@
# Reference:
# https://github.com/cvpaperchallenge/Ascender
# https://github.com/nerfstudio-project/nerfstudio
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ARG USER_NAME=dreamer
ARG GROUP_NAME=dreamers
ARG UID=1000
ARG GID=1000
# Set compute capability for nerfacc and tiny-cuda-nn
# See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60
# Speed-up build for RTX 30xx
# ENV TORCH_CUDA_ARCH_LIST="8.6"
# ENV TCNN_CUDA_ARCHITECTURES=86
# Speed-up build for RTX 40xx
# ENV TORCH_CUDA_ARCH_LIST="8.9"
# ENV TCNN_CUDA_ARCHITECTURES=89
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
# apt install by root user
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
curl \
git \
libegl1-mesa-dev \
libgl1-mesa-dev \
libgles2-mesa-dev \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender1 \
python-is-python3 \
python3.10-dev \
python3-pip \
wget \
&& rm -rf /var/lib/apt/lists/*
# Change user to non-root user
RUN groupadd -g ${GID} ${GROUP_NAME} \
&& useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME}
USER ${USER_NAME}
RUN pip install --upgrade pip setuptools ninja
RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
# Install nerfacc and tiny-cuda-nn before installing requirements.txt
# because these two installations are time consuming and error prone
RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2
RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch
COPY requirements.txt /tmp
RUN cd /tmp && pip install -r requirements.txt
WORKDIR /home/${USER_NAME}/threestudio

23
docker/compose.yaml Normal file

@ -0,0 +1,23 @@
services:
threestudio:
build:
context: ../
dockerfile: docker/Dockerfile
args:
# you can set environment variables, otherwise default values will be used
USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER
GROUP_NAME: ${HOST_GROUP_NAME:-dreamers}
UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u)
GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g)
shm_size: '4gb'
environment:
NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error`
tty: true
volumes:
- ../:/home/${HOST_USER_NAME:-dreamer}/threestudio
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]

59
docs/installation.md Normal file

@ -0,0 +1,59 @@
# Installation
## Prerequisite
- NVIDIA GPU with at least 6GB VRAM. The more memory you have, the more methods and higher resolutions you can try.
- [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) whose version is higher than the [Minimum Required Driver Version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html) of CUDA Toolkit you want to use.
## Install CUDA Toolkit
You can skip this step if you have already installed a sufficiently new version of CUDA or if you use Docker.
Install [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive).
- Example for Ubuntu 22.04:
- Run [command for CUDA 11.8 Ubuntu 22.04](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local)
- Example for Ubuntu on WSL2:
- `sudo apt-key del 7fa2af80`
- Run [command for CUDA 11.8 WSL-Ubuntu](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local)
## Git Clone
```bash
git clone https://github.com/threestudio-project/threestudio.git
cd threestudio/
```
## Install threestudio via Docker
1. [Install Docker Engine](https://docs.docker.com/engine/install/).
This document assumes you [install Docker Engine on Ubuntu](https://docs.docker.com/engine/install/ubuntu/).
2. [Create `docker` group](https://docs.docker.com/engine/install/linux-postinstall/).
Otherwise, you need to type `sudo docker` instead of `docker`.
3. [Install NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#setting-up-nvidia-container-toolkit).
4. If you use WSL2, [enable systemd](https://learn.microsoft.com/en-us/windows/wsl/wsl-config#systemd-support).
5. Edit [Dockerfile](../docker/Dockerfile) for your GPU to speed-up build.
The default Dockerfile takes into account many types of GPUs.
6. Run Docker via `docker compose`.
```bash
cd docker/
docker compose build # build Docker image
docker compose up -d # create and start a container in background
docker compose exec threestudio bash # run bash in the container
# Enjoy threestudio!
exit # or Ctrl+D
docker compose stop # stop the container
docker compose start # start the container
docker compose down # stop and remove the container
```
Note: The current Dockerfile will cause errors when using the OpenGL-based rasterizer of nvdiffrast.
You can use the CUDA-based rasterizer instead by adding the following options to your commands or editing the configs:
- `system.renderer.context_type=cuda` for training
- `system.exporter.context_type=cuda` for exporting meshes
[This comment by the nvdiffrast author](https://github.com/NVlabs/nvdiffrast/issues/94#issuecomment-1288566038) could be a guide to resolve this limitation.

1
extern/MVDream vendored Submodule

@ -0,0 +1 @@
Subproject commit 853c51b5575e179b25d3aef3d9dbdff950e922ee

1
extern/One-2-3-45 vendored Submodule

@ -0,0 +1 @@
Subproject commit ea885683ee1a5ad93ba369057dc3d71b7a5ae061

0
extern/__init__.py vendored Normal file

78
extern/ldm_zero123/extras.py vendored Executable file

@ -0,0 +1,78 @@
import logging
from contextlib import contextmanager
from pathlib import Path
import torch
from omegaconf import OmegaConf
from extern.ldm_zero123.util import instantiate_from_config
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
"""
A context manager that will prevent any logging messages
triggered during the body from being processed.
:param highest_level: the maximum logging level in use.
This would only need to be changed if a custom level greater than CRITICAL
is defined.
https://gist.github.com/simon-weber/7853144
"""
# two kind-of hacks here:
# * can't get the highest logging level in effect => delegate to the user
# * can't get the current module-level override => use an undocumented
# (but non-private!) interface
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
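# Example usage (illustrative):
#     with all_logging_disabled():
#         run_noisy_third_party_code()  # no log records are emitted in this block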
def load_training_dir(train_dir, device, epoch="last"):
"""Load a checkpoint and config from training directory"""
train_dir = Path(train_dir)
ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
config = list(train_dir.rglob("*-project.yaml"))
assert len(config) > 0, f"didn't find any config in {train_dir}"
if len(config) > 1:
print(f"found {len(config)} matching config files")
config = sorted(config)[-1]
print(f"selecting {config}")
else:
config = config[0]
config = OmegaConf.load(config)
return load_model_from_config(config, ckpt[0], device)
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
"""Loads a model from config and a ckpt
if config is a path will use omegaconf to load
"""
if isinstance(config, (str, Path)):
config = OmegaConf.load(config)
with all_logging_disabled():
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.to(device)
model.eval()
model.cond_stage_model.device = device
return model

110
extern/ldm_zero123/guidance.py vendored Executable file

@ -0,0 +1,110 @@
import abc
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output
from scipy import interpolate
class GuideModel(torch.nn.Module, abc.ABC):
def __init__(self) -> None:
super().__init__()
@abc.abstractmethod
def preprocess(self, x_img):
pass
@abc.abstractmethod
def compute_loss(self, inp):
pass
class Guider(torch.nn.Module):
def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
"""Apply classifier guidance
Specify a guidance scale as either a scalar
Or a schedule as a list of tuples t = 0->1 and scale, e.g.
[(0, 10), (0.5, 20), (1, 50)]
"""
super().__init__()
self.sampler = sampler
self.index = 0
self.show = verbose
self.guide_model = guide_model
self.history = []
if isinstance(scale, (Tuple, List)):
times = np.array([x[0] for x in scale])
values = np.array([x[1] for x in scale])
self.scale_schedule = {"times": times, "values": values}
else:
self.scale_schedule = float(scale)
self.ddim_timesteps = sampler.ddim_timesteps
self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
def get_scales(self):
if isinstance(self.scale_schedule, float):
return len(self.ddim_timesteps) * [self.scale_schedule]
interpolater = interpolate.interp1d(
self.scale_schedule["times"], self.scale_schedule["values"]
)
fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
return interpolater(fractional_steps)
def modify_score(self, model, e_t, x, t, c):
# TODO look up index by t
scale = self.get_scales()[self.index]
if scale == 0:
return e_t
sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)
inp = self.guide_model.preprocess(x_img)
loss = self.guide_model.compute_loss(inp)
grads = torch.autograd.grad(loss.sum(), x_in)[0]
correction = grads * scale
if self.show:
clear_output(wait=True)
print(
loss.item(),
scale,
correction.abs().max().item(),
e_t.abs().max().item(),
)
self.history.append(
[
loss.item(),
scale,
correction.min().item(),
correction.max().item(),
]
)
plt.imshow(
(inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2
)
plt.axis("off")
plt.show()
plt.imshow(correction[0][0].detach().cpu())
plt.axis("off")
plt.show()
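# Classifier guidance: shift the predicted noise by the guide-loss gradient,
# scaled by the guidance scale and sqrt(1 - alpha_t).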
e_t_mod = e_t - sqrt_1ma * correction
if self.show:
fig, axs = plt.subplots(1, 3)
axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
plt.show()
self.index += 1
return e_t_mod

135
extern/ldm_zero123/lr_scheduler.py vendored Executable file

@ -0,0 +1,135 @@
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f
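A minimal usage sketch for these schedulers: they return a learning-rate multiplier, so they are wrapped in torch's LambdaLR around the optimizer's real base learning rate (the same pattern as configure_optimizers in autoencoder.py below):
import torch
from torch.optim.lr_scheduler import LambdaLR
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1e-4)  # real base LR; the schedule multiplies it
sched_fn = LambdaWarmUpCosineScheduler2(
    warm_up_steps=[100], f_min=[0.1], f_max=[1.0], f_start=[0.01],
    cycle_lengths=[1000], verbosity_interval=0,
)
scheduler = LambdaLR(opt, lr_lambda=sched_fn)
for step in range(5):
    opt.step()
    scheduler.step()  # multiplier ramps linearly from 0.01 toward 1.0 over the first 100 steps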

551
extern/ldm_zero123/models/autoencoder.py vendored Executable file
View File

@ -0,0 +1,551 @@
from contextlib import contextmanager
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from packaging import version
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from torch.optim.lr_scheduler import LambdaLR
from extern.ldm_zero123.modules.diffusionmodules.model import Decoder, Encoder
from extern.ldm_zero123.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)
from extern.ldm_zero123.modules.ema import LitEma  # assumed path, mirroring ldm.modules.ema
from extern.ldm_zero123.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False,
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(
f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
predicted_indices=ind,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + suffix,
predicted_indices=ind,
)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + suffix,
predicted_indices=ind,
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(
f"val{suffix}/rec_loss",
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log(
f"val{suffix}/aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
if version.parse(pl.__version__) >= version.parse("1.4.0"):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor * self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
{
"scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class AutoencoderKL(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
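A hedged sketch of the intended call pattern for AutoencoderKL (assumptions: `ae` is an instance built from a config via instantiate_from_config, and `img` is a (B, 3, H, W) tensor scaled to [-1, 1]):
posterior = ae.encode(img)      # DiagonalGaussianDistribution over the latent code
z = posterior.sample()          # or posterior.mode() for the deterministic latent
rec = ae.decode(z)              # back to image space
rec2, posterior2 = ae(img)      # forward() = encode -> sample -> decode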

View File

@ -0,0 +1,319 @@
import os
from copy import deepcopy
from glob import glob
import pytorch_lightning as pl
import torch
from einops import rearrange
from natsort import natsorted
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from extern.ldm_zero123.modules.diffusionmodules.openaimodel import (
EncoderUNetModel,
UNetModel,
)
from extern.ldm_zero123.util import (
default,
instantiate_from_config,
ismap,
log_txt_as_img,
)
__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(
self,
diffusion_path,
num_classes,
ckpt_path=None,
pool="attention",
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.0e-2,
log_steps=10,
monitor="val/loss",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(
glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.log_steps = log_steps
self.label_key = (
label_key
if not hasattr(self.diffusion_model, "cond_stage_key")
else self.diffusion_model.cond_stage_key
)
assert (
self.label_key is not None
), "label_key neither in diffusion model nor in model.params"
if self.label_key not in __models__:
raise NotImplementedError()
self.load_classifier(ckpt_path, pool)
self.scheduler_config = scheduler_config
self.use_scheduler = self.scheduler_config is not None
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
self.diffusion_model = model.eval()
self.diffusion_model.train = disabled_train
for param in self.diffusion_model.parameters():
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes
if self.label_key == "class_label":
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print(
"#####################################################################"
)
print(f'load from ckpt "{ckpt_path}"')
print(
"#####################################################################"
)
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
def get_x_noisy(self, x, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(
x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@torch.no_grad()
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, "b h w c -> b c h w")
x = x.to(memory_format=torch.contiguous_format).float()
return x
@torch.no_grad()
def get_conditioning(self, batch, k=None):
if k is None:
k = self.label_key
assert k is not None, "Needs to provide label key"
targets = batch[k].to(self.device)
if self.label_key == "segmentation":
targets = rearrange(targets, "b h w c -> b c h w")
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self):
# save some memory
self.diffusion_model.model.to("cpu")
@torch.no_grad()
def write_logs(self, loss, logits, targets):
log_prefix = "train" if self.training else "val"
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
)
self.log_dict(
log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True
)
self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log(
"global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True
)
lr = self.optimizers().param_groups[0]["lr"]
self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(
0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device
).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
loss = F.cross_entropy(logits, targets, reduction="none")
self.write_logs(loss.detach(), logits.detach(), targets.detach())
loss = loss.mean()
return loss, logits, x_noisy, targets
def training_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
return loss
def reset_noise_accs(self):
self.noisy_acc = {
t: {"acc@1": [], "acc@5": []}
for t in range(
0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t
)
}
def on_validation_start(self):
self.reset_noise_accs()
@torch.no_grad()
def validation_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]["acc@1"].append(
self.compute_top_k(logits, targets, k=1, reduction="mean")
)
self.noisy_acc[t]["acc@5"].append(
self.compute_top_k(logits, targets, k=5, reduction="mean")
)
return loss
def configure_optimizers(self):
optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [optimizer], scheduler
return optimizer
@torch.no_grad()
def log_images(self, batch, N=8, *args, **kwargs):
log = dict()
x = self.get_input(batch, self.diffusion_model.first_stage_key)
log["inputs"] = x
y = self.get_conditioning(batch)
if self.label_key == "class_label":
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log["labels"] = y
if ismap(y):
log["labels"] = self.diffusion_model.to_rgb(y)
for step in range(self.log_steps):
current_time = step * self.log_time_interval
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
log[f"inputs@t{current_time}"] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = rearrange(pred, "b h w c -> b c h w")
log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)
for key in log:
log[key] = log[key][:N]
return log
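A small standalone check of the compute_top_k arithmetic used above (same tensor operations, outside the Lightning module):
import torch
logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.9, 0.05, 0.05]])
labels = torch.tensor([1, 2])
_, top1 = torch.topk(logits, k=1, dim=1)
acc1 = (top1 == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc1)  # 0.5 -- only the first row's argmax matches its label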

488
extern/ldm_zero123/models/diffusion/ddim.py vendored Executable file
View File

@ -0,0 +1,488 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from extern.ldm_zero123.models.diffusion.sampling_util import (
norm_thresholding,
renorm_thresholding,
spatial_norm_thresholding,
)
from extern.ldm_zero123.modules.diffusionmodules.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def to(self, device):
"""Same as to in torch module
Don't really underestand why this isn't a module in the first place"""
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
new_v = getattr(self, k).to(device)
setattr(self, k, new_v)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
t_start=-1,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
timesteps = timesteps[:t_start]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
img, pred_x0 = outs
if callback:
img = callback(i, img, pred_x0)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        # print(t, sqrt_one_minus_at, a_t)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(
self,
x0,
c,
t_enc,
use_original_steps=False,
return_intermediates=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
num_reference_steps = (
self.ddpm_num_timesteps
if use_original_steps
else self.ddim_timesteps.shape[0]
)
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc="Encoding Image"):
t = torch.full(
(x0.shape[0],), i, device=self.model.device, dtype=torch.long
)
if unconditional_guidance_scale == 1.0:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(
torch.cat((x_next, x_next)),
torch.cat((t, t)),
torch.cat((unconditional_conditioning, c)),
),
2,
)
noise_pred = e_t_uncond + unconditional_guidance_scale * (
noise_pred - e_t_uncond
)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = (
alphas_next[i].sqrt()
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
* noise_pred
)
x_next = xt_weighted + weighted_noise_pred
if (
return_intermediates
and i % (num_steps // return_intermediates) == 0
and i < num_steps - 1
):
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
if return_intermediates:
out.update({"intermediates": intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
)
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return x_dec
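A hedged sketch of the img2img path built from stochastic_encode and decode (assumptions: `model` is a trained LatentDiffusion, `z0` a clean latent batch on the model device, `cond`/`uc` matching conditional/unconditional conditionings):
import torch
sampler = DDIMSampler(model)
sampler.make_schedule(ddim_num_steps=50, ddim_eta=0.0, verbose=False)
t_enc = 25                                   # how far along the 50-step DDIM trajectory to noise
t = torch.full((z0.shape[0],), t_enc, device=z0.device, dtype=torch.long)
z_t = sampler.stochastic_encode(z0, t)       # forward diffusion up to step t_enc
z_out = sampler.decode(
    z_t, cond, t_enc,
    unconditional_guidance_scale=3.0,
    unconditional_conditioning=uc,
)                                            # denoise back with classifier-free guidance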

2689
extern/ldm_zero123/models/diffusion/ddpm.py vendored Executable file

File diff suppressed because it is too large.

383
extern/ldm_zero123/models/diffusion/plms.py vendored Executable file
View File

@ -0,0 +1,383 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from extern.ldm_zero123.models.diffusion.sampling_util import norm_thresholding
from extern.ldm_zero123.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f"Data shape for PLMS sampling is {size}")
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
dynamic_threshold=dynamic_threshold,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
dynamic_threshold=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
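The multistep branches above use the standard Adams-Bashforth weights; a quick standalone check that each coefficient set sums to 1, so e_t_prime is an affine combination of recent noise predictions:
ab_weights = {
    2: [3 / 2, -1 / 2],
    3: [23 / 12, -16 / 12, 5 / 12],
    4: [55 / 24, -59 / 24, 37 / 24, -9 / 24],
}
assert all(abs(sum(w) - 1.0) < 1e-12 for w in ab_weights.values())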

51
extern/ldm_zero123/models/diffusion/sampling_util.py vendored Executable file
View File

@ -0,0 +1,51 @@
import numpy as np
import torch
from einops import rearrange
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
return x[(...,) + (None,) * dims_to_append]
def renorm_thresholding(x0, value):
# renorm
pred_max = x0.max()
pred_min = x0.min()
pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
pred_x0 = 2 * pred_x0 - 1.0 # -1 ... 1
s = torch.quantile(rearrange(pred_x0, "b ... -> b (...)").abs(), value, dim=-1)
s.clamp_(min=1.0)
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
# clip by threshold
# pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
# temporary hack: numpy on cpu
pred_x0 = (
np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy())
/ s.cpu().numpy()
)
    pred_x0 = torch.tensor(pred_x0).to(x0.device)  # move back to the input tensor's device
# re.renorm
pred_x0 = (pred_x0 + 1.0) / 2.0 # 0 ... 1
pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
return pred_x0
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)
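A quick standalone check of norm_thresholding (runnable with the functions above): it rescales each sample so its RMS over all elements is at most the given value.
import torch
x0 = torch.randn(2, 4, 8, 8) * 5.0
clipped = norm_thresholding(x0, value=1.0)
print(clipped.flatten(1).pow(2).mean(1).sqrt())  # per-sample RMS, each <= 1.0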

364
extern/ldm_zero123/modules/attention.py vendored Executable file
View File

@ -0,0 +1,364 @@
import math
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.lora = False
self.query_dim = query_dim
self.inner_dim = inner_dim
self.context_dim = context_dim
def setup_lora(self, rank=4, network_alpha=None):
self.lora = True
self.rank = rank
self.to_q_lora = LoRALinearLayer(self.query_dim, self.inner_dim, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(self.inner_dim, self.query_dim, rank, network_alpha)
self.lora_layers = nn.ModuleList()
self.lora_layers.append(self.to_q_lora)
self.lora_layers.append(self.to_k_lora)
self.lora_layers.append(self.to_v_lora)
self.lora_layers.append(self.to_out_lora)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if self.lora:
q += self.to_q_lora(x)
k += self.to_k_lora(context)
v += self.to_v_lora(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
# return self.to_out(out)
# linear proj
o = self.to_out[0](out)
if self.lora:
o += self.to_out_lora(out)
# dropout
out = self.to_out[1](o)
return out
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
# return checkpoint(
# self._forward, (x, context), self.parameters(), self.checkpoint
# )
return self._forward(x, context)
def _forward(self, x, context=None):
x = (
self.attn1(
self.norm1(x), context=context if self.disable_self_attn else None
)
+ x
)
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
x = self.proj_out(x)
return x + x_in
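A hedged sketch exercising the LoRA hooks added to CrossAttention above (shapes are illustrative; the up projection is zero-initialized, so right after setup_lora the LoRA path contributes nothing):
import torch
attn = CrossAttention(query_dim=64, context_dim=77, heads=4, dim_head=16)
attn.setup_lora(rank=4, network_alpha=4)
x = torch.randn(2, 10, 64)     # (batch, query tokens, query_dim)
ctx = torch.randn(2, 5, 77)    # (batch, context tokens, context_dim)
out = attn(x, context=ctx)     # LoRA deltas are added to q, k, v and the output projection
assert out.shape == (2, 10, 64)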

301
extern/ldm_zero123/modules/attention_ori.py vendored Executable file
View File

@ -0,0 +1,301 @@
import math
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
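# Shape sketch for the forward pass above: x is (b, n, query_dim) and context (when
# given) is (b, m, context_dim); q, k, v are projected to inner_dim = heads * dim_head
# and split into heads as (b*h, n, dim_head) / (b*h, m, dim_head), the similarity
# matrix is (b*h, n, m) softmaxed over the j (context) axis, and the attended output
# is merged back to (b, n, inner_dim) before to_out maps it to (b, n, query_dim).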
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = (
self.attn1(
self.norm1(x), context=context if self.disable_self_attn else None
)
+ x
)
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
x = self.proj_out(x)
return x + x_in

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,296 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import math
import os
import numpy as np
import torch
import torch.nn as nn
from einops import repeat
from extern.ldm_zero123.util import instantiate_from_config
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == "linear":
betas = (
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == "sqrt":
betas = (
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == "quad":
ddim_timesteps = (
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
).astype(int)
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f"Selected timesteps for ddim sampler: {steps_out}")
return steps_out
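# Worked example: with num_ddpm_timesteps=1000 and num_ddim_timesteps=50 the "uniform"
# method uses a stride of c = 1000 // 50 = 20, i.e. DDPM steps [0, 20, ..., 980];
# after the +1 shift the sampler visits [1, 21, 41, ..., 981].
#
#     steps = make_ddim_timesteps("uniform", 50, 1000, verbose=False)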
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
    # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
)
print(
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
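# Example use (a sketch; this particular schedule is not defined in this file): the
# cosine schedule of Nichol & Dhariwal passes alpha_bar(t) = cos((t + 0.008) / 1.008
# * pi / 2) ** 2 and recovers betas from consecutive ratios of alpha_bar:
#
#     betas = betas_for_alpha_bar(
#         1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
#     )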
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
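# Usage sketch (hedged): wrapping a sub-module's forward pass trades memory for
# compute, since its activations are recomputed during backward() rather than stored.
# Assuming some module `block` exposing `_forward(h, ctx)`:
#
#     out = checkpoint(block._forward, (h, ctx), block.parameters(), True)
#
# With the flag set to False this reduces to calling block._forward(h, ctx) directly.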
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding
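# Shape sketch: for a batch of N (possibly fractional) timesteps the result is
# [N x dim]; the first dim // 2 channels are cosines and the next dim // 2 are sines
# of t scaled by frequencies spaced geometrically from 1 down to roughly 1/max_period.
#
#     t = torch.tensor([0.0, 1.0, 250.0])
#     emb = timestep_embedding(t, dim=128)   # emb.shape == (3, 128)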
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()


@ -0,0 +1,102 @@
import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
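# Usage sketch (hedged): an autoencoder typically emits 2*C channels per latent
# position, which are split into mean and log-variance along dim=1.
#
#     params = torch.randn(4, 8, 32, 32)   # 2*C parameters with C=4 latent channels
#     post = DiagonalGaussianDistribution(params)
#     z = post.sample()                    # (4, 4, 32, 32)
#     kl = post.kl()                       # per-sample KL to N(0, I), shape (4,)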
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
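# Closed form used above: for diagonal Gaussians,
#     KL(N(mu1, var1) || N(mu2, var2))
#         = 0.5 * (log(var2 / var1) + var1 / var2 + (mu1 - mu2)^2 / var2 - 1),
# evaluated elementwise with broadcasting, so scalars can stand in for full tensors.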

extern/ldm_zero123/modules/ema.py vendored Executable file

@ -0,0 +1,82 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
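# Usage sketch (hedged; mirrors how this class is typically driven from a training loop):
#
#     ema = LitEma(model)
#     ema(model)                          # after each optimizer step: update shadow weights
#     ema.store(model.parameters())       # before validation: stash the live weights
#     ema.copy_to(model)                  # load the EMA weights into the model
#     ...                                 # run validation / save a checkpoint
#     ema.restore(model.parameters())     # put the live training weights back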

extern/ldm_zero123/modules/encoders/modules.py vendored Executable file

@ -0,0 +1,712 @@
from functools import partial
import clip
import kornia
import numpy as np
import torch
import torch.nn as nn
from extern.ldm_zero123.modules.x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
Encoder,
TransformerWrapper,
)
from extern.ldm_zero123.util import default
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class FaceClipEncoder(AbstractEncoder):
def __init__(self, augment=True, retreival_key=None):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
self.augment = augment
self.retreival_key = retreival_key
def forward(self, img):
encodings = []
with torch.no_grad():
x_offset = 125
if self.retreival_key:
                # Assumes the retrieved image is packed into the second half of channels
face = img[:, 3:, 190:440, x_offset : (512 - x_offset)]
other = img[:, :3, ...].clone()
else:
face = img[:, :, 190:440, x_offset : (512 - x_offset)]
other = img.clone()
if self.augment:
face = K.RandomHorizontalFlip()(face)
other[:, :, 190:440, x_offset : (512 - x_offset)] *= 0
encodings = [
self.encoder.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros(
(1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
)
return self(img)
class FaceIdClipEncoder(AbstractEncoder):
def __init__(self):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
for p in self.encoder.parameters():
p.requires_grad = False
self.id = FrozenFaceEncoder(
"/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True
)
def forward(self, img):
encodings = []
with torch.no_grad():
face = kornia.geometry.resize(
img, (256, 256), interpolation="bilinear", align_corners=True
)
other = img.clone()
other[:, :, 184:452, 122:396] *= 0
encodings = [
self.id.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros(
(1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
)
return self(img)
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class"):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
)
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, x):
return self(x)
class BERTTokenizer(AbstractEncoder):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
        from transformers import BertTokenizerFast # TODO: add to requirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
return tokens
@torch.no_grad()
def encode(self, text):
tokens = self(text)
if not self.vq_interface:
return tokens
return None, None, [None, None, tokens]
def decode(self, text):
return text
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device="cuda",
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout,
)
def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, text):
# output of length 77
return self(text)
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(
self, version="google/t5-v1_1-large", device="cuda", max_length=77
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
import kornia.augmentation as K
from extern.ldm_zero123.thirdp.psp.id_loss import IDFeatures
class FrozenFaceEncoder(AbstractEncoder):
def __init__(self, model_path, augment=False):
super().__init__()
self.loss_fn = IDFeatures(model_path)
# face encoder is frozen
for p in self.loss_fn.parameters():
p.requires_grad = False
# Mapper is trainable
self.mapper = torch.nn.Linear(512, 768)
p = 0.25
if augment:
self.augment = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomEqualize(p=p),
# K.RandomPlanckianJitter(p=p),
# K.RandomPlasmaBrightness(p=p),
# K.RandomPlasmaContrast(p=p),
# K.ColorJiggle(0.02, 0.2, 0.2, p=p),
)
else:
self.augment = False
def forward(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 1, 768), device=self.mapper.weight.device)
if self.augment is not None:
# Transforms require 0-1
img = self.augment((img + 1) / 2)
img = 2 * img - 1
feat = self.loss_fn(img, crop=True)
feat = self.mapper(feat.unsqueeze(1))
return feat
def encode(self, img):
return self(img)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(
self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
import torch.nn.functional as F
from transformers import CLIPVisionModel
class ClipImageProjector(AbstractEncoder):
"""
Uses the CLIP image encoder.
"""
def __init__(
self, version="openai/clip-vit-large-patch14", max_length=77
): # clip-vit-base-patch32
super().__init__()
self.model = CLIPVisionModel.from_pretrained(version)
self.model.train()
self.max_length = max_length # TODO: typical value?
self.antialias = True
self.mapper = torch.nn.Linear(1024, 768)
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
null_cond = self.get_null_cond(version, max_length)
self.register_buffer("null_cond", null_cond)
@torch.no_grad()
def get_null_cond(self, version, max_length):
device = self.mean.device
embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length
)
null_cond = embedder([""])
return null_cond
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
if isinstance(x, list):
return self.null_cond
# x is assumed to be in range [-1,1]
x = self.preprocess(x)
outputs = self.model(pixel_values=x)
last_hidden_state = outputs.last_hidden_state
last_hidden_state = self.mapper(last_hidden_state)
return F.pad(
last_hidden_state,
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0],
)
def encode(self, im):
return self(im)
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def __init__(
self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
): # clip-vit-base-patch32
super().__init__()
self.embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length
)
self.projection = torch.nn.Linear(768, 768)
def forward(self, text):
z = self.embedder(text)
return self.projection(z)
def encode(self, text):
return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model="ViT-L/14",
jit=False,
device="cpu",
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit, download_root=None)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, 768, device=device)
return self.model.encode_image(self.preprocess(x)).float()
def encode(self, im):
return self(im).unsqueeze(1)
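# Shape sketch (hedged): forward() returns pooled CLIP image features of shape
# (B, 768) for ViT-L/14, and encode() unsqueezes them to (B, 1, 768) so the image
# acts as a single conditioning token for the cross-attention layers.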
import random
from torchvision import transforms
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model="ViT-L/14",
jit=False,
device="cpu",
antialias=True,
max_crops=5,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
self.max_crops = max_crops
def preprocess(self, x):
# Expects inputs in the range -1, 1
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1, 1))
max_crops = self.max_crops
patches = []
crops = [randcrop(x) for _ in range(max_crops)]
patches.extend(crops)
x = torch.cat(patches, dim=0)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, self.max_crops, 768, device=device)
batch_tokens = []
for im in x:
patches = self.preprocess(im.unsqueeze(0))
tokens = self.model.encode_image(patches).float()
for t in tokens:
if random.random() < 0.1:
t *= 0
batch_tokens.append(tokens.unsqueeze(0))
return torch.cat(batch_tokens, dim=0)
def encode(self, im):
return self(im)
class SpatialRescaler(nn.Module):
def __init__(
self,
n_stages=1,
method="bilinear",
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in [
"nearest",
"linear",
"bilinear",
"trilinear",
"bicubic",
"area",
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
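# Example (a sketch): with n_stages=2 and multiplier=0.5 a (B, 3, 256, 256) input is
# bilinearly resized twice to (B, 3, 64, 64); if out_channels is set, the 1x1
# channel_mapper then remaps the channel count after the final resize.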
from extern.ldm_zero123.modules.diffusionmodules.util import (
extract_into_tensor,
make_beta_schedule,
noise_like,
)
from extern.ldm_zero123.util import instantiate_from_config
class LowScaleEncoder(nn.Module):
def __init__(
self,
model_config,
linear_start,
linear_end,
timesteps=1000,
max_noise_level=250,
output_size=64,
scale_factor=1.0,
):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(
timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
)
self.out_size = output_size
self.scale_factor = scale_factor
def register_schedule(
self,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def forward(self, x):
z = self.model.encode(x).sample()
z = z * self.scale_factor
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0],), device=x.device
).long()
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(
z, size=self.out_size, mode="nearest"
) # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
def decode(self, z):
z = z / self.scale_factor
return self.model.decode(z)
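# Sketch of the forward pass above: q_sample applies the closed-form forward diffusion
#     q(z_t | z_0) = N(sqrt(alpha_bar_t) * z_0, (1 - alpha_bar_t) * I),
# so forward() returns a latent corrupted to a random level t < max_noise_level
# together with that noise level, which the downstream model can be conditioned on.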
if __name__ == "__main__":
from extern.ldm_zero123.util import count_params
sentences = [
"a hedgehog drinking a whiskey",
"der mond ist aufgegangen",
"Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'",
]
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
model = FrozenCLIPEmbedder().cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
print("done.")


@ -0,0 +1,703 @@
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple
import numpy as np
import requests
import tensorflow.compat.v1 as tf
import yaml
from scipy import linalg
from tqdm.auto import tqdm
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"
REQUIREMENTS = (
f"This script has the following requirements: \n"
"tensorflow-gpu>=2.0" + "\n" + "scipy" + "\n" + "requests" + "\n" + "tqdm"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--ref_batch", help="path to reference batch npz file")
parser.add_argument("--sample_batch", help="path to sample batch npz file")
args = parser.parse_args()
config = tf.ConfigProto(
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
)
config.gpu_options.allow_growth = True
evaluator = Evaluator(tf.Session(config=config))
print("warming up TensorFlow...")
# This will cause TF to print a bunch of verbose stuff now rather
# than after the next print(), to help prevent confusion.
evaluator.warmup()
print("computing reference batch activations...")
ref_acts = evaluator.read_activations(args.ref_batch)
print("computing/reading reference batch statistics...")
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
print("computing sample batch activations...")
sample_acts = evaluator.read_activations(args.sample_batch)
print("computing/reading sample batch statistics...")
sample_stats, sample_stats_spatial = evaluator.read_statistics(
args.sample_batch, sample_acts
)
print("Computing evaluations...")
is_ = evaluator.compute_inception_score(sample_acts[0])
print("Inception Score:", is_)
fid = sample_stats.frechet_distance(ref_stats)
print("FID:", fid)
sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
print("sFID:", sfid)
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
print("Precision:", prec)
print("Recall:", recall)
savepath = "/".join(args.sample_batch.split("/")[:-1])
results_file = os.path.join(savepath, "evaluation_metrics.yaml")
print(f'Saving evaluation results to "{results_file}"')
results = {
"IS": is_,
"FID": fid,
"sFID": sfid,
"Precision:": prec,
"Recall": recall,
}
with open(results_file, "w") as f:
yaml.dump(results, f, default_flow_style=False)
class InvalidFIDException(Exception):
pass
class FIDStatistics:
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
self.mu = mu
self.sigma = sigma
def frechet_distance(self, other, eps=1e-6):
"""
Compute the Frechet distance between two sets of statistics.
"""
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
mu1, sigma1 = self.mu, self.sigma
mu2, sigma2 = other.mu, other.sigma
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
assert (
sigma1.shape == sigma2.shape
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = (
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
% eps
)
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
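# The quantity returned above is the Frechet distance between two Gaussians fitted to
# the activations:
#     FID = ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 * (Sigma_1 @ Sigma_2)^(1/2)),
# with the eps * I regularization applied only when Sigma_1 @ Sigma_2 is near-singular.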
class Evaluator:
def __init__(
self,
session,
batch_size=64,
softmax_batch_size=512,
):
self.sess = session
self.batch_size = batch_size
self.softmax_batch_size = softmax_batch_size
self.manifold_estimator = ManifoldEstimator(session)
with self.sess.graph.as_default():
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
self.pool_features, self.spatial_features = _create_feature_graph(
self.image_input
)
self.softmax = _create_softmax_graph(self.softmax_input)
def warmup(self):
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
with open_npz_array(npz_path, "arr_0") as reader:
return self.compute_activations(reader.read_batches(self.batch_size))
def compute_activations(
self, batches: Iterable[np.ndarray], silent=False
) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute image features for downstream evals.
        :param batches: an iterator over NHWC numpy arrays in [0, 255].
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
dimension. The tuple is (pool_3, spatial).
"""
preds = []
spatial_preds = []
it = batches if silent else tqdm(batches)
for batch in it:
batch = batch.astype(np.float32)
pred, spatial_pred = self.sess.run(
[self.pool_features, self.spatial_features], {self.image_input: batch}
)
preds.append(pred.reshape([pred.shape[0], -1]))
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
return (
np.concatenate(preds, axis=0),
np.concatenate(spatial_preds, axis=0),
)
def read_statistics(
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
) -> Tuple[FIDStatistics, FIDStatistics]:
obj = np.load(npz_path)
if "mu" in list(obj.keys()):
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
obj["mu_s"], obj["sigma_s"]
)
return tuple(self.compute_statistics(x) for x in activations)
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
mu = np.mean(activations, axis=0)
sigma = np.cov(activations, rowvar=False)
return FIDStatistics(mu, sigma)
def compute_inception_score(
self, activations: np.ndarray, split_size: int = 5000
) -> float:
softmax_out = []
for i in range(0, len(activations), self.softmax_batch_size):
acts = activations[i : i + self.softmax_batch_size]
softmax_out.append(
self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})
)
preds = np.concatenate(softmax_out, axis=0)
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
scores = []
for i in range(0, len(preds), split_size):
part = preds[i : i + split_size]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return float(np.mean(scores))
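    # The value returned above is the standard Inception Score,
    #     IS = exp( E_x [ KL( p(y|x) || p(y) ) ] ),
    # averaged over splits of `split_size` samples; `preds` holds p(y|x) from the
    # InceptionV3 softmax head and np.mean(part, 0) estimates the marginal p(y).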
def compute_prec_recall(
self, activations_ref: np.ndarray, activations_sample: np.ndarray
) -> Tuple[float, float]:
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
pr = self.manifold_estimator.evaluate_pr(
activations_ref, radii_1, activations_sample, radii_2
)
return (float(pr[0][0]), float(pr[1][0]))
class ManifoldEstimator:
"""
A helper for comparing manifolds of feature vectors.
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
"""
def __init__(
self,
session,
row_batch_size=10000,
col_batch_size=10000,
nhood_sizes=(3,),
clamp_to_percentile=None,
eps=1e-5,
):
"""
Estimate the manifold of given feature vectors.
:param session: the TensorFlow session.
:param row_batch_size: row batch size to compute pairwise distances
(parameter to trade-off between memory usage and performance).
:param col_batch_size: column batch size to compute pairwise distances.
:param nhood_sizes: number of neighbors used to estimate the manifold.
:param clamp_to_percentile: prune hyperspheres that have radius larger than
the given percentile.
:param eps: small number for numerical stability.
"""
self.distance_block = DistanceBlock(session)
self.row_batch_size = row_batch_size
self.col_batch_size = col_batch_size
self.nhood_sizes = nhood_sizes
self.num_nhoods = len(nhood_sizes)
self.clamp_to_percentile = clamp_to_percentile
self.eps = eps
def warmup(self):
feats, radii = (
np.zeros([1, 2048], dtype=np.float32),
np.zeros([1, 1], dtype=np.float32),
)
self.evaluate_pr(feats, radii, feats, radii)
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
num_images = len(features)
# Estimate manifold of features by calculating distances to k-NN of each sample.
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
for begin1 in range(0, num_images, self.row_batch_size):
end1 = min(begin1 + self.row_batch_size, num_images)
row_batch = features[begin1:end1]
for begin2 in range(0, num_images, self.col_batch_size):
end2 = min(begin2 + self.col_batch_size, num_images)
col_batch = features[begin2:end2]
# Compute distances between batches.
distance_batch[
0 : end1 - begin1, begin2:end2
] = self.distance_block.pairwise_distances(row_batch, col_batch)
# Find the k-nearest neighbor from the current batch.
radii[begin1:end1, :] = np.concatenate(
[
x[:, self.nhood_sizes]
for x in _numpy_partition(
distance_batch[0 : end1 - begin1, :], seq, axis=1
)
],
axis=0,
)
if self.clamp_to_percentile is not None:
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
radii[radii > max_distances] = 0
return radii
def evaluate(
self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray
):
"""
Evaluate if new feature vectors are at the manifold.
"""
num_eval_images = eval_features.shape[0]
num_ref_images = radii.shape[0]
distance_batch = np.zeros(
[self.row_batch_size, num_ref_images], dtype=np.float32
)
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
for begin1 in range(0, num_eval_images, self.row_batch_size):
end1 = min(begin1 + self.row_batch_size, num_eval_images)
feature_batch = eval_features[begin1:end1]
for begin2 in range(0, num_ref_images, self.col_batch_size):
end2 = min(begin2 + self.col_batch_size, num_ref_images)
ref_batch = features[begin2:end2]
distance_batch[
0 : end1 - begin1, begin2:end2
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
# If a feature vector is inside a hypersphere of some reference sample, then
# the new sample lies at the estimated manifold.
# The radii of the hyperspheres are determined from distances of neighborhood size k.
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(
np.int32
)
max_realism_score[begin1:end1] = np.max(
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
)
nearest_indices[begin1:end1] = np.argmin(
distance_batch[0 : end1 - begin1, :], axis=1
)
return {
"fraction": float(np.mean(batch_predictions)),
"batch_predictions": batch_predictions,
"max_realisim_score": max_realism_score,
"nearest_indices": nearest_indices,
}
def evaluate_pr(
self,
features_1: np.ndarray,
radii_1: np.ndarray,
features_2: np.ndarray,
radii_2: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Evaluate precision and recall efficiently.
:param features_1: [N1 x D] feature vectors for reference batch.
:param radii_1: [N1 x K1] radii for reference vectors.
:param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N2 x K2] radii for the other vectors.
:return: a tuple of arrays for (precision, recall):
- precision: an np.ndarray of length K1
- recall: an np.ndarray of length K2
"""
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
for begin_1 in range(0, len(features_1), self.row_batch_size):
end_1 = begin_1 + self.row_batch_size
batch_1 = features_1[begin_1:end_1]
for begin_2 in range(0, len(features_2), self.col_batch_size):
end_2 = begin_2 + self.col_batch_size
batch_2 = features_2[begin_2:end_2]
batch_1_in, batch_2_in = self.distance_block.less_thans(
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
)
features_1_status[begin_1:end_1] |= batch_1_in
features_2_status[begin_2:end_2] |= batch_2_in
return (
np.mean(features_2_status.astype(np.float64), axis=0),
np.mean(features_1_status.astype(np.float64), axis=0),
)
class DistanceBlock:
"""
Calculate pairwise distances between vectors.
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
"""
def __init__(self, session):
self.session = session
# Initialize TF graph to calculate pairwise distances.
with session.graph.as_default():
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
distance_block_16 = _batch_pairwise_distances(
tf.cast(self._features_batch1, tf.float16),
tf.cast(self._features_batch2, tf.float16),
)
self.distance_block = tf.cond(
tf.reduce_all(tf.math.is_finite(distance_block_16)),
lambda: tf.cast(distance_block_16, tf.float32),
lambda: _batch_pairwise_distances(
self._features_batch1, self._features_batch2
),
)
# Extra logic for less thans.
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
self._batch_2_in = tf.math.reduce_any(
dist32 <= self._radii1[:, None], axis=0
)
def pairwise_distances(self, U, V):
"""
Evaluate pairwise distances between two batches of feature vectors.
"""
return self.session.run(
self.distance_block,
feed_dict={self._features_batch1: U, self._features_batch2: V},
)
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
return self.session.run(
[self._batch_1_in, self._batch_2_in],
feed_dict={
self._features_batch1: batch_1,
self._features_batch2: batch_2,
self._radii1: radii_1,
self._radii2: radii_2,
},
)
def _batch_pairwise_distances(U, V):
"""
Compute pairwise distances between two batches of feature vectors.
"""
with tf.variable_scope("pairwise_dist_block"):
# Squared norms of each row in U and V.
norm_u = tf.reduce_sum(tf.square(U), 1)
norm_v = tf.reduce_sum(tf.square(V), 1)
# norm_u as a column and norm_v as a row vectors.
norm_u = tf.reshape(norm_u, [-1, 1])
norm_v = tf.reshape(norm_v, [1, -1])
# Pairwise squared Euclidean distances.
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
return D
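# The kernel above uses the standard expansion of the squared Euclidean distance,
#     ||u - v||^2 = ||u||^2 - 2 * <u, v> + ||v||^2,
# evaluated for all pairs at once through a single matmul; tf.maximum(..., 0.0)
# guards against small negative values caused by floating-point cancellation.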
class NpzArrayReader(ABC):
@abstractmethod
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
pass
@abstractmethod
def remaining(self) -> int:
pass
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
def gen_fn():
while True:
batch = self.read_batch(batch_size)
if batch is None:
break
yield batch
rem = self.remaining()
num_batches = rem // batch_size + int(rem % batch_size != 0)
return BatchIterator(gen_fn, num_batches)
class BatchIterator:
def __init__(self, gen_fn, length):
self.gen_fn = gen_fn
self.length = length
def __len__(self):
return self.length
def __iter__(self):
return self.gen_fn()
class StreamingNpzArrayReader(NpzArrayReader):
def __init__(self, arr_f, shape, dtype):
self.arr_f = arr_f
self.shape = shape
self.dtype = dtype
self.idx = 0
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.shape[0]:
return None
bs = min(batch_size, self.shape[0] - self.idx)
self.idx += bs
if self.dtype.itemsize == 0:
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
read_count = bs * np.prod(self.shape[1:])
read_size = int(read_count * self.dtype.itemsize)
data = _read_bytes(self.arr_f, read_size, "array data")
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
def remaining(self) -> int:
return max(0, self.shape[0] - self.idx)
class MemoryNpzArrayReader(NpzArrayReader):
def __init__(self, arr):
self.arr = arr
self.idx = 0
@classmethod
def load(cls, path: str, arr_name: str):
with open(path, "rb") as f:
arr = np.load(f)[arr_name]
return cls(arr)
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.arr.shape[0]:
return None
res = self.arr[self.idx : self.idx + batch_size]
self.idx += batch_size
return res
def remaining(self) -> int:
return max(0, self.arr.shape[0] - self.idx)
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
with _open_npy_file(path, arr_name) as arr_f:
version = np.lib.format.read_magic(arr_f)
if version == (1, 0):
header = np.lib.format.read_array_header_1_0(arr_f)
elif version == (2, 0):
header = np.lib.format.read_array_header_2_0(arr_f)
else:
yield MemoryNpzArrayReader.load(path, arr_name)
return
shape, fortran, dtype = header
if fortran or dtype.hasobject:
yield MemoryNpzArrayReader.load(path, arr_name)
else:
yield StreamingNpzArrayReader(arr_f, shape, dtype)
def _read_bytes(fp, size, error_template="ran out of data"):
"""
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
Read from file-like object until size bytes are read.
    Raises ValueError if EOF is encountered before size bytes are read.
Non-blocking objects only supported if they derive from io objects.
Required as e.g. ZipExtFile in python 2.6 can return less data than
requested.
"""
data = bytes()
while True:
# io files (default in python3) return None or raise on
# would-block, python2 file will truncate, probably nothing can be
# done about that. note that regular files can't be non-blocking
try:
r = fp.read(size - len(data))
data += r
if len(r) == 0 or len(data) == size:
break
except io.BlockingIOError:
pass
if len(data) != size:
msg = "EOF: reading %s, expected %d bytes got %d"
raise ValueError(msg % (error_template, size, len(data)))
else:
return data
@contextmanager
def _open_npy_file(path: str, arr_name: str):
with open(path, "rb") as f:
with zipfile.ZipFile(f, "r") as zip_f:
if f"{arr_name}.npy" not in zip_f.namelist():
raise ValueError(f"missing {arr_name} in npz file")
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
yield arr_f
def _download_inception_model():
if os.path.exists(INCEPTION_V3_PATH):
return
print("downloading InceptionV3 model...")
with requests.get(INCEPTION_V3_URL, stream=True) as r:
r.raise_for_status()
tmp_path = INCEPTION_V3_PATH + ".tmp"
with open(tmp_path, "wb") as f:
for chunk in tqdm(r.iter_content(chunk_size=8192)):
f.write(chunk)
os.rename(tmp_path, INCEPTION_V3_PATH)
def _create_feature_graph(input_batch):
_download_inception_model()
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
with open(INCEPTION_V3_PATH, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
pool3, spatial = tf.import_graph_def(
graph_def,
input_map={f"ExpandDims:0": input_batch},
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
name=prefix,
)
_update_shapes(pool3)
spatial = spatial[..., :7]
return pool3, spatial
def _create_softmax_graph(input_batch):
_download_inception_model()
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
with open(INCEPTION_V3_PATH, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
(matmul,) = tf.import_graph_def(
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
)
w = matmul.inputs[1]
logits = tf.matmul(input_batch, w)
return tf.nn.softmax(logits)
def _update_shapes(pool3):
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
ops = pool3.graph.get_operations()
for op in ops:
for o in op.outputs:
shape = o.get_shape()
if shape._dims is not None: # pylint: disable=protected-access
# shape = [s.value for s in shape] TF 1.x
shape = [s for s in shape] # TF 2.x
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
return pool3
def _numpy_partition(arr, kth, **kwargs):
num_workers = min(cpu_count(), len(arr))
chunk_size = len(arr) // num_workers
extra = len(arr) % num_workers
start_idx = 0
batches = []
for i in range(num_workers):
size = chunk_size + (1 if i < extra else 0)
batches.append(arr[start_idx : start_idx + size])
start_idx += size
with ThreadPool(num_workers) as pool:
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
if __name__ == "__main__":
print(REQUIREMENTS)
main()


@ -0,0 +1,606 @@
import argparse
import glob
import os
from collections import namedtuple
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
from tqdm import tqdm
from extern.ldm_zero123.modules.evaluate.ssim import ssim
transform = transforms.Compose([transforms.ToTensor()])
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view(
in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
)
return in_feat / (norm_factor.expand_as(in_feat) + eps)
def cos_sim(in0, in1):
in0_norm = normalize_tensor(in0)
in1_norm = normalize_tensor(in1)
N = in0.size()[0]
X = in0.size()[2]
Y = in0.size()[3]
return torch.mean(
torch.mean(torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2).view(
N, 1, 1, Y
),
dim=3,
).view(N)
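# Descriptive note: after channel-wise L2 normalization, the channel-summed
# product is the per-pixel cosine similarity, so for each batch element n
#     cos_sim(a, b)[n] = mean_{x, y} <a_hat[n, :, x, y], b_hat[n, :, x, y]>
# i.e. the spatial average of cosine similarities between feature vectors.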
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple(
"SqueezeOutputs",
["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
)
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple(
"AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
)
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
"VggOutputs",
["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if num == 18:
self.net = models.resnet18(pretrained=pretrained)
elif num == 34:
self.net = models.resnet34(pretrained=pretrained)
elif num == 50:
self.net = models.resnet50(pretrained=pretrained)
elif num == 101:
self.net = models.resnet101(pretrained=pretrained)
elif num == 152:
self.net = models.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"])
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out
# Off-the-shelf deep network
class PNet(torch.nn.Module):
"""Pre-trained network with all channels equally weighted by default"""
def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
super(PNet, self).__init__()
self.use_gpu = use_gpu
self.pnet_type = pnet_type
self.pnet_rand = pnet_rand
self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
if self.pnet_type in ["vgg", "vgg16"]:
self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)
elif self.pnet_type == "alex":
self.net = alexnet(pretrained=not self.pnet_rand, requires_grad=False)
elif self.pnet_type[:-2] == "resnet":
self.net = resnet(
pretrained=not self.pnet_rand,
requires_grad=False,
num=int(self.pnet_type[-2:]),
)
elif self.pnet_type == "squeeze":
self.net = squeezenet(pretrained=not self.pnet_rand, requires_grad=False)
self.L = self.net.N_slices
if use_gpu:
self.net.cuda()
self.shift = self.shift.cuda()
self.scale = self.scale.cuda()
def forward(self, in0, in1, retPerLayer=False):
in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
outs0 = self.net.forward(in0_sc)
outs1 = self.net.forward(in1_sc)
if retPerLayer:
all_scores = []
for kk, out0 in enumerate(outs0):
cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])
if kk == 0:
val = 1.0 * cur_score
else:
val = val + cur_score
if retPerLayer:
all_scores += [cur_score]
if retPerLayer:
return (val, all_scores)
else:
return val
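# Descriptive note: the returned value is a perceptual distance accumulated
# over the backbone's slices as sum_k (1 - cos_sim(feat0_k, feat1_k)); with
# retPerLayer=True the per-slice terms are returned alongside the total.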
# The SSIM metric
def ssim_metric(img1, img2, mask=None):
return ssim(img1, img2, mask=mask, size_average=False)
# The PSNR metric
def psnr(img1, img2, mask=None, reshape=False):
b = img1.size(0)
if not (mask is None):
b = img1.size(0)
mse_err = (img1 - img2).pow(2) * mask
if reshape:
mse_err = mse_err.reshape(b, -1).sum(dim=1) / (
3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
)
else:
mse_err = mse_err.view(b, -1).sum(dim=1) / (
3 * mask.view(b, -1).sum(dim=1).clamp(min=1)
)
else:
if reshape:
mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)
else:
mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)
psnr = 10 * (1 / mse_err).log10()
return psnr
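# Descriptive note: inputs are assumed to lie in [0, 1], so the peak signal is
# 1 and PSNR = 10 * log10(1 / MSE). With a mask, the squared error is averaged
# only over masked pixels; the factor of 3 accounts for the RGB channels.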
# The perceptual similarity metric
def perceptual_sim(img1, img2, vgg16):
# First extract features
dist = vgg16(img1 * 2 - 1, img2 * 2 - 1)
return dist
def load_img(img_name, size=None):
try:
img = Image.open(img_name)
if type(size) == int:
img = img.resize((size, size))
elif size is not None:
img = img.resize((size[1], size[0]))
img = transform(img).cuda()
img = img.unsqueeze(0)
except Exception as e:
print("Failed at loading %s " % img_name)
print(e)
img = torch.zeros(1, 3, 256, 256).cuda()
raise
return img
def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
folders = os.listdir(folder)
for i, f in tqdm(enumerate(sorted(folders))):
pred_imgs = glob.glob(folder + f + "/" + pred_img)
tgt_imgs = glob.glob(folder + f + "/" + tgt_img)
assert len(tgt_imgs) == 1
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
for p_img in pred_imgs:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
perc_sim = min(perc_sim, t_perc_sim)
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
values_psnr += [psnr_sim]
if take_every_other:
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
return {
"Perceptual similarity": (avg_percsim, std_percsim),
"PSNR": (avg_psnr, std_psnr),
"SSIM": (avg_ssim, std_ssim),
}
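# Hedged usage sketch (the directory layout below is an assumption): `folder`
# should end with a slash and contain one sub-directory per scene, each holding
# images matched by the two glob patterns, e.g.
#
#     results = compute_perceptual_similarity(
#         "outputs/", "pred_*.png", "gt.png", take_every_other=False
#     )
#     print(results["PSNR"])  # (mean, std) over scenes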
def compute_perceptual_similarity_from_list(
pred_imgs_list, tgt_imgs_list, take_every_other, simple_format=True
):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
equal_count = 0
ambig_count = 0
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
pred_imgs = pred_imgs_list[i]
tgt_imgs = [tgt_img]
assert len(tgt_imgs) == 1
if type(pred_imgs) != list:
pred_imgs = [pred_imgs]
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
assert len(pred_imgs) > 0
for p_img in pred_imgs:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
perc_sim = min(perc_sim, t_perc_sim)
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
if psnr_sim != np.float("inf"):
values_psnr += [psnr_sim]
else:
if torch.allclose(p_img, t_img):
equal_count += 1
print("{} equal src and wrp images.".format(equal_count))
else:
ambig_count += 1
print("{} ambiguous src and wrp images.".format(ambig_count))
if take_every_other:
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
if simple_format:
# just to make yaml formatting readable
return {
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
"PSNR": [float(avg_psnr), float(std_psnr)],
"SSIM": [float(avg_ssim), float(std_ssim)],
}
else:
return {
"Perceptual similarity": (avg_percsim, std_percsim),
"PSNR": (avg_psnr, std_psnr),
"SSIM": (avg_ssim, std_ssim),
}
def compute_perceptual_similarity_from_list_topk(
pred_imgs_list, tgt_imgs_list, take_every_other, resize=False
):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
individual_percsim = []
individual_ssim = []
individual_psnr = []
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
pred_imgs = pred_imgs_list[i]
tgt_imgs = [tgt_img]
assert len(tgt_imgs) == 1
if type(pred_imgs) != list:
assert False
pred_imgs = [pred_imgs]
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
sample_percsim = list()
sample_ssim = list()
sample_psnr = list()
for p_img in pred_imgs:
if resize:
t_img = load_img(tgt_imgs[0], size=(256, 256))
else:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
sample_percsim.append(t_perc_sim)
perc_sim = min(perc_sim, t_perc_sim)
t_ssim = ssim_metric(p_img, t_img).item()
sample_ssim.append(t_ssim)
ssim_sim = max(ssim_sim, t_ssim)
t_psnr = psnr(p_img, t_img).item()
sample_psnr.append(t_psnr)
psnr_sim = max(psnr_sim, t_psnr)
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
values_psnr += [psnr_sim]
individual_percsim.append(sample_percsim)
individual_ssim.append(sample_ssim)
individual_psnr.append(sample_psnr)
if take_every_other:
assert False, "Do this later, after specifying topk to get proper results"
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
individual_percsim = np.array(individual_percsim)
individual_psnr = np.array(individual_psnr)
individual_ssim = np.array(individual_ssim)
return {
"avg_of_best": {
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
"PSNR": [float(avg_psnr), float(std_psnr)],
"SSIM": [float(avg_ssim), float(std_ssim)],
},
"individual": {
"PSIM": individual_percsim,
"PSNR": individual_psnr,
"SSIM": individual_ssim,
},
}
if __name__ == "__main__":
args = argparse.ArgumentParser()
args.add_argument("--folder", type=str, default="")
args.add_argument("--pred_image", type=str, default="")
args.add_argument("--target_image", type=str, default="")
args.add_argument("--take_every_other", action="store_true", default=False)
args.add_argument("--output_file", type=str, default="")
opts = args.parse_args()
folder = opts.folder
pred_img = opts.pred_image
tgt_img = opts.target_image
results = compute_perceptual_similarity(
folder, pred_img, tgt_img, opts.take_every_other
)
f = open(opts.output_file, "w")
for key in results:
print("%s for %s: \n" % (key, opts.folder))
print("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))
f.write("%s for %s: \n" % (key, opts.folder))
f.write("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))
f.close()


@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Minimal Reference implementation for the Frechet Video Distance (FVD).
FVD is a metric for the quality of video generation models. It is inspired by
the FID (Frechet Inception Distance) used for images, but uses a different
embedding that is better suited to videos.
"""
from __future__ import absolute_import, division, print_function
import six
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
def preprocess(videos, target_resolution):
"""Runs some preprocessing on the videos for I3D model.
Args:
videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
preprocessed. We don't care about the specific dtype of the videos; it can
be anything that tf.image.resize_bilinear accepts. Values are expected to
be in the range 0-255.
target_resolution: (width, height): target video resolution
Returns:
videos: <float32>[batch_size, num_frames, height, width, depth]
"""
videos_shape = list(videos.shape)
all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
output_videos = tf.reshape(resized_videos, target_shape)
scaled_videos = 2.0 * tf.cast(output_videos, tf.float32) / 255.0 - 1
return scaled_videos
def _is_in_graph(tensor_name):
"""Checks whether a given tensor does exists in the graph."""
try:
tf.get_default_graph().get_tensor_by_name(tensor_name)
except KeyError:
return False
return True
def create_id3_embedding(videos, warmup=False, batch_size=16):
"""Embeds the given videos using the Inflated 3D Convolution ne twork.
Downloads the graph of the I3D from tf.hub and adds it to the graph on the
first call.
Args:
videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
Expected range is [-1, 1].
Returns:
embedding: <float32>[batch_size, embedding_size]. embedding_size depends
on the model used.
Raises:
ValueError: when a provided embedding_layer is not supported.
"""
# batch_size = 16
module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
# Making sure that we import the graph separately for
# each different input video tensor.
module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(videos.name).replace(
":", "_"
)
assert_ops = [
tf.Assert(
tf.reduce_max(videos) <= 1.001, ["max value in frame is > 1", videos]
),
tf.Assert(
tf.reduce_min(videos) >= -1.001, ["min value in frame is < -1", videos]
),
tf.assert_equal(
tf.shape(videos)[0],
batch_size,
["invalid frame batch size: ", tf.shape(videos)],
summarize=6,
),
]
with tf.control_dependencies(assert_ops):
videos = tf.identity(videos)
module_scope = "%s_apply_default/" % module_name
# To check whether the module has already been loaded into the graph, we look
# for a given tensor name. If this tensor name exists, we assume the function
# has been called before and the graph was imported. Otherwise we import it.
# Note: in theory, the tensor could exist, but have wrong shapes.
# This will happen if create_id3_embedding is called with a frames_placeholder
# of wrong size/batch size, because even though that will throw a tf.Assert
# on graph-execution time, it will insert the tensor (with wrong shape) into
# the graph. This is why we need the following assert.
if warmup:
video_batch_size = int(videos.shape[0])
assert video_batch_size in [
batch_size,
-1,
None,
], f"Invalid batch size {video_batch_size}"
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
if not _is_in_graph(tensor_name):
i3d_model = hub.Module(module_spec, name=module_name)
i3d_model(videos)
# gets the kinetics-i3d-400-logits layer
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
return tensor
def calculate_fvd(real_activations, generated_activations):
"""Returns a list of ops that compute metrics as funcs of activations.
Args:
real_activations: <float32>[num_samples, embedding_size]
generated_activations: <float32>[num_samples, embedding_size]
Returns:
A scalar that contains the requested FVD.
"""
return tfgan.eval.frechet_classifier_distance_from_activations(
real_activations, generated_activations
)
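# Hedged usage sketch (shapes and session boilerplate are assumptions based on
# the FVD reference implementation): preprocess to 224x224, embed with I3D,
# then compare the two sets of activations.
#
#     with tf.Graph().as_default():
#         real = tf.placeholder(tf.uint8, [16, 15, 64, 64, 3])
#         fake = tf.placeholder(tf.uint8, [16, 15, 64, 64, 3])
#         fvd = calculate_fvd(
#             create_id3_embedding(preprocess(real, (224, 224))),
#             create_id3_embedding(preprocess(fake, (224, 224))),
#         )
#         with tf.Session() as sess:
#             sess.run(tf.global_variables_initializer())
#             sess.run(tf.tables_initializer())
#             print("FVD:", sess.run(fvd, feed_dict={real: ..., fake: ...}))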

118
extern/ldm_zero123/modules/evaluate/ssim.py vendored Executable file

@ -0,0 +1,118 @@
# MIT Licence
# Methods to predict the SSIM, taken from
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
from math import exp
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def gaussian(window_size, sigma):
gauss = torch.Tensor(
[
exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
for x in range(window_size)
]
)
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(
_2D_window.expand(channel, 1, window_size, window_size).contiguous()
)
return window
def _ssim(img1, img2, window, window_size, channel, mask=None, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = (
F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
)
sigma2_sq = (
F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
)
sigma12 = (
F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
- mu1_mu2
)
C1 = (0.01) ** 2
C2 = (0.03) ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
if not (mask is None):
b = mask.size(0)
ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(dim=1).clamp(
min=1
)
return ssim_map
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2, mask=None):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(
img1,
img2,
window,
self.window_size,
channel,
mask,
self.size_average,
)
def ssim(img1, img2, window_size=11, mask=None, size_average=True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, mask, size_average)
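# Hedged usage sketch: both inputs are NCHW tensors with values in [0, 1];
# scores close to 1 indicate near-identical images under the 11x11 window.
#
#     score = ssim(img1, img2)                          # scalar mean SSIM
#     per_image = ssim(img1, img2, size_average=False)  # one score per image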


@ -0,0 +1,331 @@
# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
import glob
import hashlib
import html
import io
import multiprocessing as mp
import os
import re
import urllib
import urllib.request
from typing import Any, Callable, Dict, List, Tuple, Union
import numpy as np
import requests
import scipy.linalg
import torch
from torchvision.io import read_video
from tqdm import tqdm
torch.set_grad_enabled(False)
from einops import rearrange
from nitro.util import isvideo
def compute_frechet_distance(mu_sample, sigma_sample, mu_ref, sigma_ref) -> float:
print("Calculate frechet distance...")
m = np.square(mu_sample - mu_ref).sum()
s, _ = scipy.linalg.sqrtm(
np.dot(sigma_sample, sigma_ref), disp=False
) # pylint: disable=no-member
fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))
return float(fid)
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
mu = feats.mean(axis=0) # [d]
sigma = np.cov(feats, rowvar=False) # [d, d]
return mu, sigma
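# Descriptive note: together with compute_frechet_distance above, this fits a
# Gaussian (mu, Sigma) to each set of I3D activations and evaluates
#     FVD = ||mu_s - mu_r||^2 + Tr(Sigma_s + Sigma_r - 2 (Sigma_s Sigma_r)^{1/2})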
def open_url(
url: str,
num_attempts: int = 10,
verbose: bool = True,
return_filename: bool = False,
) -> Any:
"""Download the given URL and return a binary-mode file object to access the data."""
assert num_attempts >= 1
# Doesn't look like a URL scheme, so interpret it as a local filename.
if not re.match("^[a-z]+://", url):
return url if return_filename else open(url, "rb")
# Handle file URLs. This code handles unusual file:// patterns that
# arise on Windows:
#
# file:///c:/foo.txt
#
# which would translate to a local '/c:/foo.txt' filename that's
# invalid. Drop the forward slash for such pathnames.
#
# If you touch this code path, you should test it on both Linux and
# Windows.
#
# Some internet resources suggest using urllib.request.url2pathname(),
# but that converts forward slashes to backslashes and this causes
# its own set of problems.
if url.startswith("file://"):
filename = urllib.parse.urlparse(url).path
if re.match(r"^/[a-zA-Z]:", filename):
filename = filename[1:]
return filename if return_filename else open(filename, "rb")
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
# Download.
url_name = None
url_data = None
with requests.Session() as session:
if verbose:
print("Downloading %s ..." % url, end="", flush=True)
for attempts_left in reversed(range(num_attempts)):
try:
with session.get(url) as res:
res.raise_for_status()
if len(res.content) == 0:
raise IOError("No data received")
if len(res.content) < 8192:
content_str = res.content.decode("utf-8")
if "download_warning" in res.headers.get("Set-Cookie", ""):
links = [
html.unescape(link)
for link in content_str.split('"')
if "export=download" in link
]
if len(links) == 1:
url = requests.compat.urljoin(url, links[0])
raise IOError("Google Drive virus checker nag")
if "Google Drive - Quota exceeded" in content_str:
raise IOError(
"Google Drive download quota exceeded -- please try again later"
)
match = re.search(
r'filename="([^"]*)"',
res.headers.get("Content-Disposition", ""),
)
url_name = match[1] if match else url
url_data = res.content
if verbose:
print(" done")
break
except KeyboardInterrupt:
raise
except:
if not attempts_left:
if verbose:
print(" failed")
raise
if verbose:
print(".", end="", flush=True)
# Return data as file object.
assert not return_filename
return io.BytesIO(url_data)
def load_video(ip):
vid, *_ = read_video(ip)
vid = rearrange(vid, "t h w c -> t c h w").to(torch.uint8)
return vid
def get_data_from_str(input_str, nprc=None):
assert os.path.isdir(
input_str
), f'Specified input folder "{input_str}" is not a directory'
vid_filelist = glob.glob(os.path.join(input_str, "*.mp4"))
print(f"Found {len(vid_filelist)} videos in dir {input_str}")
if nprc is None:
try:
nprc = mp.cpu_count()
except NotImplementedError:
print(
"WARNING: cpu_count() not avlailable, using only 1 cpu for video loading"
)
nprc = 1
pool = mp.Pool(processes=nprc)
vids = []
for v in tqdm(
pool.imap_unordered(load_video, vid_filelist),
total=len(vid_filelist),
desc="Loading videos...",
):
vids.append(v)
vids = torch.stack(vids, dim=0).float()
return vids
def get_stats(stats):
assert os.path.isfile(stats) and stats.endswith(
".npz"
), f"no stats found under {stats}"
print(f"Using precomputed statistics under {stats}")
stats = np.load(stats)
stats = {key: stats[key] for key in stats.files}
return stats
@torch.no_grad()
def compute_fvd(
ref_input, sample_input, bs=32, ref_stats=None, sample_stats=None, nprc_load=None
):
calc_stats = ref_stats is None or sample_stats is None
if calc_stats:
only_ref = sample_stats is not None
only_sample = ref_stats is not None
if isinstance(ref_input, str) and not only_sample:
ref_input = get_data_from_str(ref_input, nprc_load)
if isinstance(sample_input, str) and not only_ref:
sample_input = get_data_from_str(sample_input, nprc_load)
stats = compute_statistics(
sample_input,
ref_input,
device="cuda" if torch.cuda.is_available() else "cpu",
bs=bs,
only_ref=only_ref,
only_sample=only_sample,
)
if only_ref:
stats.update(get_stats(sample_stats))
elif only_sample:
stats.update(get_stats(ref_stats))
else:
stats = get_stats(sample_stats)
stats.update(get_stats(ref_stats))
fvd = compute_frechet_distance(**stats)
return {
"FVD": fvd,
}
@torch.no_grad()
def compute_statistics(
videos_fake,
videos_real,
device: str = "cuda",
bs=32,
only_ref=False,
only_sample=False,
) -> Dict:
detector_url = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
detector_kwargs = dict(
rescale=True, resize=True, return_features=True
) # Return raw features before the softmax layer.
with open_url(detector_url, verbose=False) as f:
detector = torch.jit.load(f).eval().to(device)
assert not (
only_sample and only_ref
), "only_ref and only_sample arguments are mutually exclusive"
ref_embed, sample_embed = [], []
info = f"Computing I3D activations for FVD score with batch size {bs}"
if only_ref:
if not isvideo(videos_real):
# if not a video tensor, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
print(videos_real.shape)
if videos_real.shape[0] % bs == 0:
n_secs = videos_real.shape[0] // bs
else:
n_secs = videos_real.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
for ref_v in tqdm(videos_real, total=len(videos_real), desc=info):
feats_ref = (
detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
)
ref_embed.append(feats_ref)
elif only_sample:
if not isvideo(videos_fake):
# if not a video tensor, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
print(videos_fake.shape)
if videos_fake.shape[0] % bs == 0:
n_secs = videos_fake.shape[0] // bs
else:
n_secs = videos_fake.shape[0] // bs + 1
videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
for sample_v in tqdm(videos_fake, total=len(videos_fake), desc=info):
feats_sample = (
detector(sample_v.to(device).contiguous(), **detector_kwargs)
.cpu()
.numpy()
)
sample_embed.append(feats_sample)
else:
if not isvideo(videos_real):
# if not a video tensor, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
if not isvideo(videos_fake):
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
if videos_fake.shape[0] % bs == 0:
n_secs = videos_fake.shape[0] // bs
else:
n_secs = videos_fake.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
for ref_v, sample_v in tqdm(
zip(videos_real, videos_fake), total=len(videos_fake), desc=info
):
# print(ref_v.shape)
# ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
# sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
feats_sample = (
detector(sample_v.to(device).contiguous(), **detector_kwargs)
.cpu()
.numpy()
)
feats_ref = (
detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
)
sample_embed.append(feats_sample)
ref_embed.append(feats_ref)
out = dict()
if len(sample_embed) > 0:
sample_embed = np.concatenate(sample_embed, axis=0)
mu_sample, sigma_sample = compute_stats(sample_embed)
out.update({"mu_sample": mu_sample, "sigma_sample": sigma_sample})
if len(ref_embed) > 0:
ref_embed = np.concatenate(ref_embed, axis=0)
mu_ref, sigma_ref = compute_stats(ref_embed)
out.update({"mu_ref": mu_ref, "sigma_ref": sigma_ref})
return out
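# Hedged usage sketch (paths and batch size are illustrative): compute_fvd
# accepts either directories of .mp4 files or precomputed .npz statistics.
#
#     scores = compute_fvd("videos/real", "videos/generated", bs=16)
#     print(scores["FVD"])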


@ -0,0 +1,6 @@
from extern.ldm_zero123.modules.image_degradation.bsrgan import (
degradation_bsrgan_variant as degradation_fn_bsr,
)
from extern.ldm_zero123.modules.image_degradation.bsrgan_light import (
degradation_bsrgan_variant as degradation_fn_bsr_light,
)


@ -0,0 +1,809 @@
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
import random
from functools import partial
import albumentations
import cv2
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
import extern.ldm_zero123.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[: w - w % sf, : h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size = k.shape[0]
# Calculate the big kernels size
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
# Normalize to 1
return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v = np.dot(
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
return k
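# Descriptive note: the covariance is built as Sigma = V D V^{-1}, where V
# rotates the axes by theta and D = diag(l1, l2) scales them, so the resulting
# Gaussian blur kernel is elongated along the rotated principal axis.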
def gm_blur_kernel(mean, cov, size=15):
center = size / 2.0 + 0.5
k = np.zeros([size, size])
for y in range(size):
for x in range(size):
cy = y - center + 1
cx = x - center + 1
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
k = k / np.sum(k)
return k
def shift_pixel(x, sf, upper_left=True):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h, w = x.shape[:2]
shift = (sf - 1) * 0.5
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
if upper_left:
x1 = xv + shift
y1 = yv + shift
else:
x1 = xv - shift
y1 = yv - shift
x1 = np.clip(x1, 0, w - 1)
y1 = np.clip(y1, 0, h - 1)
if x.ndim == 2:
x = interp2d(xv, yv, x)(x1, y1)
if x.ndim == 3:
for i in range(x.shape[-1]):
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
return x
def blur(x, k):
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
theta = np.random.rand() * np.pi # random theta
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
Z = np.stack([X, Y], 2)[:, :, :, None]
# Calculate Gaussian for every pixel of the kernel
ZZ = Z - MU
ZZ_t = ZZ.transpose(0, 1, 3, 2)
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel = raw_kernel / np.sum(raw_kernel)
return kernel
def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < np.finfo(float).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
h = h / sumh
return h
def fspecial_laplacian(alpha):
alpha = max([0, min([alpha, 1])])
h1 = alpha / (alpha + 1)
h2 = (1 - alpha) / (alpha + 1)
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
h = np.array(h)
return h
def fspecial(filter_type, *args, **kwargs):
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
"""
if filter_type == "gaussian":
return fspecial_gaussian(*args, **kwargs)
if filter_type == "laplacian":
return fspecial_laplacian(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode="wrap"
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
return x
def classical_degradation(x, k, sf=3):
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 0.5.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int): Residual threshold (on the 0-255 scale) used to build the sharpening mask. Default: 10.
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype("float32")
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
K = img + weight * residual
K = np.clip(K, 0, 1)
return soft_mask * K + (1 - soft_mask) * img
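# Descriptive note: this is unsharp masking. With blur B of the input I, the
# sharpened candidate is K = clip(I + weight * (I - B)); a soft mask (a blurred
# threshold on |I - B|) then blends K with I so that only strong edges are
# sharpened while flat regions stay untouched.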
def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(
ksize=2 * random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
return img
def add_resize(img, sf=4):
rnum = np.random.rand()
if rnum > 0.8: # up
sf1 = random.uniform(1, 2)
elif rnum < 0.7: # down
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4:
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode(
".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :
]
return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HxWxC, [0, 1]; its size should be larger than (lq_patchsize * sf) x (lq_patchsize * sf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f"img size ({h1}X{w1}) is too small!")
hq = img.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_blur(img, sf=sf)
elif i == 2:
a, b = img.shape[1], img.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
img = add_JPEG_noise(img)
elif i == 6:
# add processed camera sensor noise
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
return img, hq
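# Hedged usage sketch (the image path is illustrative): given an HR image in
# [0, 1], the pipeline returns an aligned low-/high-quality training pair.
#
#     hr = util.uint2single(util.imread_uint("some_image.png", 3))
#     lq, hq = degradation_bsrgan(hr, sf=4, lq_patchsize=72)
#     # lq: 72x72x3, hq: 288x288x3, both in [0, 1]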
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
example: dict with key "image" holding the degraded low-quality image (uint8, HxWxC)
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
image = add_blur(image, sf=sf)
elif i == 1:
image = add_blur(image, sf=sf)
elif i == 2:
a, b = image.shape[1], image.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(
image,
(int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
image = add_JPEG_noise(image)
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
return example
# TODO: in case there is a pickle error, one needs to replace a += x with a = a + x in add_speckle_noise etc.
def degradation_bsrgan_plus(
img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None
):
"""
This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN
----------
img: HxWxC, [0, 1]; its size should be larger than (lq_patchsize * sf) x (lq_patchsize * sf)
sf: scale factor
use_shuffle: the degradation shuffle
use_sharp: sharpening the img
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
h1, w1 = img.shape[:2]
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f"img size ({h1}X{w1}) is too small!")
if use_sharp:
img = add_sharpening(img)
hq = img.copy()
if random.random() < shuffle_prob:
shuffle_order = random.sample(range(13), 13)
else:
shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_resize(img, sf=sf)
elif i == 2:
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 3:
if random.random() < poisson_prob:
img = add_Poisson_noise(img)
elif i == 4:
if random.random() < speckle_prob:
img = add_speckle_noise(img)
elif i == 5:
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
elif i == 6:
img = add_JPEG_noise(img)
elif i == 7:
img = add_blur(img, sf=sf)
elif i == 8:
img = add_resize(img, sf=sf)
elif i == 9:
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 10:
if random.random() < poisson_prob:
img = add_Poisson_noise(img)
elif i == 11:
if random.random() < speckle_prob:
img = add_speckle_noise(img)
elif i == 12:
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
else:
print("check the shuffle!")
# resize to desired size
img = cv2.resize(
img,
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf, lq_patchsize)
return img, hq
if __name__ == "__main__":
print("hey")
img = util.imread_uint("utils/test.png", 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = util.uint2single(deg_fn(img)["image"])  # the variant returns a dict of uint8 images
img_hq = img  # the uncorrupted input serves as the high-quality reference
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + ".png")


@ -0,0 +1,720 @@
# -*- coding: utf-8 -*-
import random
from functools import partial
import albumentations
import cv2
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
import extern.ldm_zero123.modules.image_degradation.utils_image as util
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def modcrop_np(img, sf):
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[: w - w % sf, : h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size = k.shape[0]
# Calculate the big kernels size
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
# Normalize to 1
return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v = np.dot(
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
return k
def gm_blur_kernel(mean, cov, size=15):
center = size / 2.0 + 0.5
k = np.zeros([size, size])
for y in range(size):
for x in range(size):
cy = y - center + 1
cx = x - center + 1
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
k = k / np.sum(k)
return k
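# Usage sketch (hypothetical helper, illustrative parameter values): sampling an
# anisotropic Gaussian kernel and checking that it comes back normalized to sum 1.
def _demo_anisotropic_kernel():
    k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=2)
    assert k.shape == (15, 15)
    assert np.isclose(k.sum(), 1.0)
    return k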
def shift_pixel(x, sf, upper_left=True):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h, w = x.shape[:2]
shift = (sf - 1) * 0.5
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
if upper_left:
x1 = xv + shift
y1 = yv + shift
else:
x1 = xv - shift
y1 = yv - shift
x1 = np.clip(x1, 0, w - 1)
y1 = np.clip(y1, 0, h - 1)
if x.ndim == 2:
x = interp2d(xv, yv, x)(x1, y1)
if x.ndim == 3:
for i in range(x.shape[-1]):
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
return x
def blur(x, k):
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
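# Usage sketch (hypothetical helper): blurring a batch of NxCxHxW images, each sample
# with its own Nx1xhxw kernel; replicate padding keeps the spatial size unchanged.
def _demo_blur():
    x = torch.rand(2, 3, 32, 32)       # N x C x H x W
    k = torch.ones(2, 1, 5, 5) / 25.0  # one normalized box kernel per sample
    y = blur(x, k)
    assert y.shape == x.shape
    return y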
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
theta = np.random.rand() * np.pi # random theta
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
Z = np.stack([X, Y], 2)[:, :, :, None]
# Calculate Gaussian for every pixel of the kernel
ZZ = Z - MU
ZZ_t = ZZ.transpose(0, 1, 3, 2)
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel = raw_kernel / np.sum(raw_kernel)
return kernel
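# Usage sketch (hypothetical helper): sampling one random SR blur kernel; the default
# 15x15 kernel is returned normalized to sum 1.
def _demo_gen_kernel():
    k = gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]))
    assert k.shape == (15, 15)
    assert np.isclose(k.sum(), 1.0)
    return k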
def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < np.finfo(float).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
h = h / sumh
return h
def fspecial_laplacian(alpha):
alpha = max([0, min([alpha, 1])])
h1 = alpha / (alpha + 1)
h2 = (1 - alpha) / (alpha + 1)
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
h = np.array(h)
return h
def fspecial(filter_type, *args, **kwargs):
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
"""
if filter_type == "gaussian":
return fspecial_gaussian(*args, **kwargs)
if filter_type == "laplacian":
return fspecial_laplacian(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
"""
x = ndimage.convolve(
x, np.expand_dims(k, axis=2), mode="wrap"
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
return x
def classical_degradation(x, k, sf=3):
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
"""
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype("float32")
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
K = img + weight * residual
K = np.clip(K, 0, 1)
return soft_mask * K + (1 - soft_mask) * img
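# Usage sketch (hypothetical helper, illustrative arguments): USM sharpening keeps the
# input shape and the [0, 1] range for a float32 BGR image.
def _demo_add_sharpening():
    img = np.random.rand(64, 64, 3).astype(np.float32)
    out = add_sharpening(img, weight=0.5, radius=50, threshold=10)
    assert out.shape == img.shape
    assert out.min() >= 0.0 and out.max() <= 1.0
    return out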
def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
wd2 = wd2 / 4
wd = wd / 4
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(
ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2
)
else:
k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random())
img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
return img
def add_resize(img, sf=4):
rnum = np.random.rand()
if rnum > 0.8: # up
sf1 = random.uniform(1, 2)
elif rnum < 0.7: # down
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
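# Usage sketch (hypothetical helper; noise levels are the function defaults): the noisy
# image keeps its shape and stays clipped to [0, 1].
def _demo_add_gaussian_noise():
    img = np.random.rand(32, 32, 3).astype(np.float32)
    noisy = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
    assert noisy.shape == img.shape
    assert noisy.min() >= 0.0 and noisy.max() <= 1.0
    return noisy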
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4:
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode(
".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
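# Usage sketch (hypothetical helper; shapes are illustrative): the JPEG encode/decode
# round trip keeps the image shape and returns a float32 image in [0, 1].
def _demo_add_jpeg_noise():
    img = np.random.rand(32, 32, 3).astype(np.float32)
    out = add_JPEG_noise(img)
    assert out.shape == img.shape and out.dtype == np.float32
    return out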
def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :
]
return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HxWxC, [0, 1]; its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f"img size ({h1}X{w1}) is too small!")
hq = img.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_blur(img, sf=sf)
elif i == 2:
a, b = img.shape[1], img.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.convolve(
img, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
img = add_JPEG_noise(img)
elif i == 6:
# add processed camera sensor noise
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
return img, hq
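# Usage sketch (hypothetical helper): running the shuffled BSRGAN degradation on a random
# HR image; the LR patch is lq_patchsize x lq_patchsize and the HR patch is sf times larger.
def _demo_degradation_bsrgan(sf=4, lq_patchsize=72):
    hr = np.random.rand(lq_patchsize * sf, lq_patchsize * sf, 3).astype(np.float32)
    lq, hq = degradation_bsrgan(hr, sf=sf, lq_patchsize=lq_patchsize)
    assert lq.shape == (lq_patchsize, lq_patchsize, 3)
    assert hq.shape == (lq_patchsize * sf, lq_patchsize * sf, 3)
    return lq, hq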
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
example: dict with key "image" holding the degraded low-quality image (uint8, HxWxC)
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
    image = add_blur(image, sf=sf)
elif i == 1:
    pass  # the second blur step is disabled in this light variant
elif i == 2:
a, b = image.shape[1], image.shape[0]
# downsample2
if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(
image,
(int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.convolve(
image, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
image = add_JPEG_noise(image)
#
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
return example
if __name__ == "__main__":
print("hey")
img = util.imread_uint("utils/test.png", 3)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_hq = img
img_lq = deg_fn(img)["image"]
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img_hq)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + ".png")

Binary file not shown.

New image file (431 KiB)

View File

@ -0,0 +1,988 @@
import math
import os
import random
from datetime import datetime
import cv2
import numpy as np
import torch
from torchvision.utils import make_grid
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
"""
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
"""
IMG_EXTENSIONS = [
".jpg",
".JPG",
".jpeg",
".JPEG",
".png",
".PNG",
".ppm",
".PPM",
".bmp",
".BMP",
".tif",
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def get_timestamp():
return datetime.now().strftime("%y%m%d-%H%M%S")
def imshow(x, title=None, cbar=False, figsize=None):
plt.figure(figsize=figsize)
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()
def surf(Z, cmap="rainbow", figsize=None):
plt.figure(figsize=figsize)
ax3 = plt.axes(projection="3d")
w, h = Z.shape[:2]
xx = np.arange(0, w, 1)
yy = np.arange(0, h, 1)
X, Y = np.meshgrid(xx, yy)
ax3.plot_surface(X, Y, Z, cmap=cmap)
# ax3.contour(X, Y, Z, zdim='z', offset=-2, cmap=cmap)
plt.show()
"""
# --------------------------------------------
# get image paths
# --------------------------------------------
"""
def get_image_paths(dataroot):
paths = None # return None if dataroot is None
if dataroot is not None:
paths = sorted(_get_paths_from_images(dataroot))
return paths
def _get_paths_from_images(path):
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, "{:s} has no valid image file".format(path)
return images
"""
# --------------------------------------------
# split large images into small images
# --------------------------------------------
"""
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2]
patches = []
if w > p_max and h > p_max:
w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))
h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
w1.append(w - p_size)
h1.append(h - p_size)
# print(w1)
# print(h1)
for i in w1:
for j in h1:
patches.append(img[i : i + p_size, j : j + p_size, :])
else:
patches.append(img)
return patches
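# Usage sketch (hypothetical helper): a 1000x1200 image is split into overlapping
# 512x512 patches, while an image smaller than p_max is returned unchanged.
def _demo_patches_from_image():
    big = np.zeros((1000, 1200, 3), dtype=np.uint8)
    patches = patches_from_image(big, p_size=512, p_overlap=64, p_max=800)
    assert all(p.shape == (512, 512, 3) for p in patches)
    small = np.zeros((256, 256, 3), dtype=np.uint8)
    assert len(patches_from_image(small)) == 1
    return patches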
def imssave(imgs, img_path):
"""
imgs: list, N images of size WxHxC
"""
img_name, ext = os.path.splitext(os.path.basename(img_path))
for i, img in enumerate(imgs):
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
new_path = os.path.join(
os.path.dirname(img_path), img_name + str("_s{:04d}".format(i)) + ".png"
)
cv2.imwrite(new_path, img)
def split_imageset(
original_dataroot,
taget_dataroot,
n_channels=3,
p_size=800,
p_overlap=96,
p_max=1000,
):
"""
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
will be splitted.
Args:
original_dataroot:
taget_dataroot:
p_size: size of small images
p_overlap: patch size in training is a good choice
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
"""
paths = get_image_paths(original_dataroot)
for img_path in paths:
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
# if original_dataroot == taget_dataroot:
# del img_path
"""
# --------------------------------------------
# makedir
# --------------------------------------------
"""
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def mkdirs(paths):
if isinstance(paths, str):
mkdir(paths)
else:
for path in paths:
mkdir(path)
def mkdir_and_rename(path):
if os.path.exists(path):
new_name = path + "_archived_" + get_timestamp()
print("Path already exists. Rename it to [{:s}]".format(new_name))
os.rename(path, new_name)
os.makedirs(path)
"""
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
"""
# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
# input: path
# output: HxWx3(RGB or GGG), or HxWx1 (G)
if n_channels == 1:
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
img = np.expand_dims(img, axis=2) # HxWx1
elif n_channels == 3:
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
return img
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
def imwrite(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
# --------------------------------------------
# get single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
# read image by cv2
# return: Numpy float32, HWC, BGR, [0,1]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
img = img.astype(np.float32) / 255.0
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(uint)
# numpy(single) <---> tensor
# numpy(uint) <---> tensor
# --------------------------------------------
"""
# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(uint)
# --------------------------------------------
def uint2single(img):
return np.float32(img / 255.0)
def single2uint(img):
return np.uint8((img.clip(0, 1) * 255.0).round())
def uint162single(img):
return np.float32(img / 65535.0)
def single2uint16(img):
return np.uint16((img.clip(0, 1) * 65535.0).round())
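# Usage sketch (hypothetical helper): the uint8 <-> single conversion round trip is
# lossless for [0, 255] data.
def _demo_uint_single_roundtrip():
    img_uint = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
    assert np.array_equal(single2uint(uint2single(img_uint)), img_uint)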
# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <---> tensor
# --------------------------------------------
# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
.unsqueeze(0)
)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return (
torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
)
# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
return np.uint8((img * 255.0).round())
# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.unsqueeze(0)
)
# convert torch tensor to single
def tensor2single(img):
img = img.data.squeeze().float().cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
return img
# convert torch tensor to single
def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
elif img.ndim == 2:
img = np.expand_dims(img, axis=2)
return img
def single2tensor5(img):
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1, 3)
.float()
.unsqueeze(0)
)
def single32tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
def single42tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
"""
Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
"""
tensor = (
tensor.squeeze().float().cpu().clamp_(*min_max)
) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3:
img_np = tensor.numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 2:
img_np = tensor.numpy()
else:
raise TypeError(
"Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(
n_dim
)
)
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
return img_np.astype(out_type)
"""
# --------------------------------------------
# Augmentation, flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
"""
def augment_img(img, mode=0):
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
return np.flipud(np.rot90(img))
elif mode == 2:
return np.flipud(img)
elif mode == 3:
return np.rot90(img, k=3)
elif mode == 4:
return np.flipud(np.rot90(img, k=2))
elif mode == 5:
return np.rot90(img)
elif mode == 6:
return np.rot90(img, k=2)
elif mode == 7:
return np.flipud(np.rot90(img, k=3))
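# Usage sketch (hypothetical helper): every mode is a flip/rotation, so it only permutes
# pixels; the sorted pixel values are identical for all eight modes.
def _demo_augment_img():
    img = np.arange(16, dtype=np.float32).reshape(4, 4)
    for mode in range(8):
        out = augment_img(img, mode=mode)
        assert np.array_equal(np.sort(out, axis=None), np.sort(img, axis=None))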
def augment_img_tensor4(img, mode=0):
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
return img.rot90(1, [2, 3]).flip([2])
elif mode == 2:
return img.flip([2])
elif mode == 3:
return img.rot90(3, [2, 3])
elif mode == 4:
return img.rot90(2, [2, 3]).flip([2])
elif mode == 5:
return img.rot90(1, [2, 3])
elif mode == 6:
return img.rot90(2, [2, 3])
elif mode == 7:
return img.rot90(3, [2, 3]).flip([2])
def augment_img_tensor(img, mode=0):
"""Kai Zhang (github: https://github.com/cszn)"""
img_size = img.size()
img_np = img.data.cpu().numpy()
if len(img_size) == 3:
img_np = np.transpose(img_np, (1, 2, 0))
elif len(img_size) == 4:
img_np = np.transpose(img_np, (2, 3, 1, 0))
img_np = augment_img(img_np, mode=mode)
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
if len(img_size) == 3:
img_tensor = img_tensor.permute(2, 0, 1)
elif len(img_size) == 4:
img_tensor = img_tensor.permute(3, 2, 0, 1)
return img_tensor.type_as(img)
def augment_img_np3(img, mode=0):
if mode == 0:
return img
elif mode == 1:
return img.transpose(1, 0, 2)
elif mode == 2:
return img[::-1, :, :]
elif mode == 3:
img = img[::-1, :, :]
img = img.transpose(1, 0, 2)
return img
elif mode == 4:
return img[:, ::-1, :]
elif mode == 5:
img = img[:, ::-1, :]
img = img.transpose(1, 0, 2)
return img
elif mode == 6:
img = img[:, ::-1, :]
img = img[::-1, :, :]
return img
elif mode == 7:
img = img[:, ::-1, :]
img = img[::-1, :, :]
img = img.transpose(1, 0, 2)
return img
def augment_imgs(img_list, hflip=True, rot=True):
# horizontal flip OR rotate
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
"""
def modcrop(img_in, scale):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
if img.ndim == 2:
H, W = img.shape
H_r, W_r = H % scale, W % scale
img = img[: H - H_r, : W - W_r]
elif img.ndim == 3:
H, W, C = img.shape
H_r, W_r = H % scale, W % scale
img = img[: H - H_r, : W - W_r, :]
else:
raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim))
return img
def shave(img_in, border=0):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
h, w = img.shape[:2]
img = img[border : h - border, border : w - border]
return img
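# Usage sketch (hypothetical helper): modcrop trims H and W down to multiples of the
# scale, and shave removes a fixed border on each side.
def _demo_modcrop_shave():
    img = np.zeros((37, 41, 3))
    assert modcrop(img, scale=4).shape == (36, 40, 3)
    assert shave(img, border=2).shape == (33, 37, 3)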
"""
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
"""
def rgb2ycbcr(img, only_y=True):
"""same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img = img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else:
rlt = np.matmul(
img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.0
return rlt.astype(in_img_type)
def ycbcr2rgb(img):
"""same as matlab ycbcr2rgb
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img = img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.0
# convert
rlt = np.matmul(
img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
) * 255.0 + [-222.921, 135.576, -276.836]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.0
return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
"""bgr version of rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img = img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(
img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.0
return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
# conversion among BGR, gray and y
if in_c == 3 and tar_type == "gray": # BGR to gray
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
return [np.expand_dims(img, axis=2) for img in gray_list]
elif in_c == 3 and tar_type == "y": # BGR to y
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
return [np.expand_dims(img, axis=2) for img in y_list]
elif in_c == 1 and tar_type == "RGB": # gray/y to BGR
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
else:
return img_list
"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
"""
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
# img1 and img2 have range [0, 255]
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError("Input images must have the same dimensions.")
h, w = img1.shape[:2]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2) ** 2)
if mse == 0:
return float("inf")
return 20 * math.log10(255.0 / math.sqrt(mse))
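# Usage sketch (hypothetical helper): PSNR is infinite for identical images and equals
# 20*log10(255/d) for a constant absolute difference d.
def _demo_calculate_psnr():
    a = np.full((16, 16), 100, dtype=np.float64)
    b = np.full((16, 16), 110, dtype=np.float64)
    assert calculate_psnr(a, a) == float("inf")
    assert np.isclose(calculate_psnr(a, b), 20 * math.log10(255.0 / 10.0))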
# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
"""calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
"""
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError("Input images must have the same dimensions.")
h, w = img1.shape[:2]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError("Wrong input image dimensions.")
def ssim(img1, img2):
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
return ssim_map.mean()
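# Usage sketch (hypothetical helper): SSIM of an image with itself is 1.
def _demo_calculate_ssim():
    img = (np.random.rand(32, 32) * 255).astype(np.float64)
    assert np.isclose(calculate_ssim(img, img), 1.0)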
"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
"""
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(
in_length, out_length, scale, kernel, kernel_width, antialiasing
):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias; this gives a larger kernel width
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5+scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
P = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
0, P - 1, P
).view(1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
# apply cubic kernel
if (scale < 1) and (antialiasing):
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, P)
# If a column in weights is all zero, get rid of it. only consider the first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, P - 2)
weights = weights.narrow(1, 1, P - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, P - 2)
weights = weights.narrow(1, 0, P - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: pytorch tensor, CHW or HW [0,1]
# output: CHW or HW [0,1] w/o round
need_squeeze = True if img.dim() == 2 else False
if need_squeeze:
img.unsqueeze_(0)
in_C, in_H, in_W = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
kernel_width = 4
kernel = "cubic"
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:, :sym_len_Hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_He:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_C, out_H, in_W)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[j, i, :] = (
img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
)
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_Ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_We:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_C, out_H, out_W)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i])
if need_squeeze:
out_2.squeeze_()
return out_2
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: Numpy, HWC or HW [0,1]
# output: HWC or HW [0,1] w/o round
img = torch.from_numpy(img)
need_squeeze = True if img.dim() == 2 else False
if need_squeeze:
img.unsqueeze_(2)
in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
kernel_width = 4
kernel = "cubic"
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:sym_len_Hs, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[-sym_len_He:, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(out_H, in_W, in_C)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = (
img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
)
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :sym_len_Ws, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, -sym_len_We:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(out_H, out_W, in_C)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i])
if need_squeeze:
out_2.squeeze_()
return out_2.numpy()
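# Usage sketch (hypothetical helper): MATLAB-style bicubic downscaling of a numpy image
# by 1/4 yields a ceil(H/4) x ceil(W/4) output.
def _demo_imresize_np():
    img = np.random.rand(64, 48, 3).astype(np.float32)
    out = imresize_np(img, scale=1 / 4)
    assert out.shape == (16, 12, 3)
    return out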
if __name__ == "__main__":
print("---")
# img = imread_uint('test.bmp', 3)
# img = uint2single(img)
# img_bicubic = imresize_np(img, 1/4)

View File

@ -0,0 +1 @@
from extern.ldm_zero123.modules.losses.contperceptual import LPIPSWithDiscriminator

View File

@ -0,0 +1,153 @@
import torch
import torch.nn as nn
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
class LPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss="hinge",
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs,
reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split="train",
weights=None,
):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log

View File

@ -0,0 +1,218 @@
import torch
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
from torch import nn
def exists(val):
    # returns True when val is not None (used below for the optional codebook_loss)
    return val is not None
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
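# Usage sketch (hypothetical helper): when all codebook entries are used equally, the
# perplexity equals n_embed and every cluster is counted as used.
def _demo_measure_perplexity():
    codes = torch.arange(8).repeat(4)  # each of the 8 indices appears 4 times
    perplexity, cluster_use = measure_perplexity(codes, n_embed=8)
    assert torch.isclose(perplexity, torch.tensor(8.0), atol=1e-3)
    assert cluster_use.item() == 8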
def l1(x, y):
return torch.abs(x - y)
def l2(x, y):
return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start,
codebook_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss="hinge",
n_classes=None,
perceptual_loss="lpips",
pixel_loss="l1",
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf,
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.n_classes = n_classes
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split="train",
predicted_indices=None,
):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.0]).to(inputs.device)
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(
predicted_indices, self.n_classes
)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log

705
extern/ldm_zero123/modules/x_transformer.py vendored Executable file
View File

@ -0,0 +1,705 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
from collections import namedtuple
from functools import partial
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import einsum, nn
# constants
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"])
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
self.init_()
def init_(self):
nn.init.normal_(self.emb.weight, std=0.02)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
return self.emb(n)[None, :, :]
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = (
torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ offset
)
sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
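# Usage sketch (hypothetical helper): both positional embeddings return a
# (1, seq_len, dim) tensor that broadcasts over the batch dimension.
def _demo_positional_embeddings():
    x = torch.zeros(2, 10, 64)  # batch x seq_len x dim
    assert AbsolutePositionalEmbedding(64, max_seq_len=128)(x).shape == (1, 10, 64)
    assert FixedPositionalEmbedding(64)(x).shape == (1, 10, 64)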
# helpers
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
)
return kwargs_without_prefix, kwargs
# classes
class Scale(nn.Module):
def __init__(self, value, fn):
super().__init__()
self.value = value
self.fn = fn
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.value, *rest)
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.g, *rest)
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class Residual(nn.Module):
def forward(self, x, residual):
return x + residual
class GRUGating(nn.Module):
def __init__(self, dim):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")
)
return gated_output.reshape_as(x)
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
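# Illustrative sketch: the feed-forward block preserves the last dimension; glu=True
# swaps the plain Linear+GELU projection for the GEGLU gate defined above.
def _demo_feedforward():
    x = torch.randn(2, 16, 64)
    assert FeedForward(dim=64)(x).shape == (2, 16, 64)
    assert FeedForward(dim=64, glu=True, dropout=0.1)(x).shape == (2, 16, 64)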
# attention.
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.0,
on_attn=False,
):
super().__init__()
if use_entmax15:
raise NotImplementedError(
"Check out entmax activation instead of softmax activation!"
)
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
inner_dim = dim_head * heads
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_k = nn.Linear(dim, inner_dim, bias=False)
self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# talking heads
self.talking_heads = talking_heads
if talking_heads:
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
# explicit topk sparse attention
self.sparse_topk = sparse_topk
# entmax
# self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax
# add memory key / values
self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
# attention on attention
self.attn_on_attn = on_attn
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(inner_dim, dim)
)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None,
):
b, n, _, h, talking_heads, device = (
*x.shape,
self.heads,
self.talking_heads,
x.device,
)
kv_input = default(context, x)
q_input = x
k_input = kv_input
v_input = kv_input
if exists(mem):
k_input = torch.cat((mem, k_input), dim=-2)
v_input = torch.cat((mem, v_input), dim=-2)
if exists(sinusoidal_emb):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset = k_input.shape[-2] - q_input.shape[-2]
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
k_input = k_input + sinusoidal_emb(k_input)
q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(
k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()
)
q_mask = rearrange(q_mask, "b i -> b () i ()")
k_mask = rearrange(k_mask, "b j -> b () () j")
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(
lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)
)
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
mask_value = max_neg_value(dots)
if exists(prev_attn):
dots = dots + prev_attn
pre_softmax_attn = dots
if talking_heads:
dots = einsum(
"b h i j, h k -> b k i j", dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
if exists(input_mask):
dots.masked_fill_(~input_mask, mask_value)
del input_mask
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
top, _ = dots.topk(self.sparse_topk, dim=-1)
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
mask = dots < vk
dots.masked_fill_(mask, mask_value)
del mask
attn = self.attn_fn(dots, dim=-1)
post_softmax_attn = attn
attn = self.dropout(attn)
if talking_heads:
attn = einsum(
"b h i j, h k -> b k i j", attn, self.post_softmax_proj
).contiguous()
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
)
return self.to_out(out), intermediates
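# Illustrative sketch: a causal self-attention forward pass. The module returns both the
# output and an Intermediates tuple holding the pre-/post-softmax attention maps, which
# AttentionLayers below reuses for residual attention.
def _demo_attention():
    attn = Attention(dim=64, heads=4, causal=True)
    x = torch.randn(2, 16, 64)
    out, inter = attn(x)
    assert out.shape == (2, 16, 64)
    assert inter.post_softmax_attn.shape == (2, 4, 16, 16)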
class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs,
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
self.rotary_pos_emb = always(None)
assert (
rel_pos_num_buckets <= rel_pos_max_distance
), "number of relative position buckets must be less than the relative position max distance"
self.rel_pos = None
self.pre_norm = pre_norm
self.residual_attn = residual_attn
self.cross_residual_attn = cross_residual_attn
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
norm_class = RMSNorm if use_rmsnorm else norm_class
norm_fn = partial(norm_class, dim)
norm_fn = nn.Identity if use_rezero else norm_fn
branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross:
default_block = ("a", "c", "f")
elif cross_attend and only_cross:
default_block = ("c", "f")
else:
default_block = ("a", "f")
if macaron:
default_block = ("f",) + default_block
if exists(custom_layers):
layer_types = custom_layers
elif exists(par_ratio):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, "par ratio out of range"
default_block = tuple(filter(not_equals("f"), default_block))
par_attn = par_depth // par_ratio
depth_cut = (
par_depth * 2 // 3
) # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert (
len(default_block) <= par_width
), "default block is too large for par_ratio"
par_block = default_block + ("f",) * (par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ("f",) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert (
sandwich_coef > 0 and sandwich_coef <= depth
), "sandwich coefficient should be less than the depth"
layer_types = (
("a",) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ("f",) * sandwich_coef
)
else:
layer_types = default_block * depth
self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
for layer_type in self.layer_types:
if layer_type == "a":
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == "c":
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == "f":
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
else:
raise Exception(f"invalid layer type {layer_type}")
if isinstance(layer, Attention) and exists(branch_fn):
layer = branch_fn(layer)
if gate_residual:
residual_fn = GRUGating(dim)
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False,
):
hiddens = []
intermediates = []
prev_attn = None
prev_cross_attn = None
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1)
if layer_type == "a":
hiddens.append(x)
layer_mem = mems.pop(0)
residual = x
if self.pre_norm:
x = norm(x)
if layer_type == "a":
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == "c":
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == "f":
out = block(x)
x = residual_fn(out, residual)
if layer_type in ("a", "c"):
intermediates.append(inter)
if layer_type == "a" and self.residual_attn:
prev_attn = inter.pre_softmax_attn
elif layer_type == "c" and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn
if not self.pre_norm and not is_last:
x = norm(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens, attn_intermediates=intermediates
)
return x, intermediates
return x
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert "causal" not in kwargs, "cannot set causality on encoder"
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.0,
emb_dropout=0.0,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True,
):
super().__init__()
assert isinstance(
attn_layers, AttentionLayers
), "attention layers must be one of Encoder or Decoder"
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = (
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, "num_memory_tokens"):
attn_layers.num_memory_tokens = num_memory_tokens
def init_(self):
nn.init.normal_(self.token_emb.weight, std=0.02)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs,
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
x += self.pos_emb(x)
x = self.emb_dropout(x)
x = self.project_emb(x)
if num_mem > 0:
mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
out = self.to_logits(x) if not return_embeddings else x
if return_mems:
hiddens = intermediates.hiddens
new_mems = (
list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens)))
if exists(mems)
else hiddens
)
new_mems = list(
map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
)
return out, new_mems
if return_attn:
attn_maps = list(
map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
)
return out, attn_maps
return out
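# Illustrative end-to-end sketch (not part of the original module): an encoder-only
# TransformerWrapper over integer token ids.
def _demo_transformer_wrapper():
    model = TransformerWrapper(
        num_tokens=1000,
        max_seq_len=64,
        attn_layers=Encoder(dim=128, depth=2, heads=4),
    )
    tokens = torch.randint(0, 1000, (2, 32))
    logits = model(tokens)
    embeds = model(tokens, return_embeddings=True)
    assert logits.shape == (2, 32, 1000) and embeds.shape == (2, 32, 128)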

144
extern/ldm_zero123/thirdp/psp/helpers.py vendored Executable file
View File

@ -0,0 +1,144 @@
# https://github.com/eladrich/pixel2style2pixel
from collections import namedtuple
import torch
from torch.nn import (
AdaptiveAvgPool2d,
BatchNorm2d,
Conv2d,
MaxPool2d,
Module,
PReLU,
ReLU,
Sequential,
Sigmoid,
)
"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)
def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
"""A named tuple describing a ResNet block."""
def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)] + [
Bottleneck(depth, depth, 1) for i in range(num_units - 1)
]
def get_blocks(num_layers):
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3),
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3),
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3),
]
else:
raise ValueError(
"Invalid number of layers: {}. Must be one of [50, 100, 152]".format(
num_layers
)
)
return blocks
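# Illustrative sketch: get_blocks only describes the architecture; num_layers=50 yields
# 3 + 4 + 14 + 3 = 24 bottleneck units, which Backbone (model_irse.py) expands into
# bottleneck_IR / bottleneck_IR_SE modules.
def _demo_get_blocks():
    blocks = get_blocks(50)
    assert sum(len(block) for block in blocks) == 24
    assert blocks[0][0] == Bottleneck(in_channel=64, depth=64, stride=2)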
class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels, channels // reduction, kernel_size=1, padding=0, bias=False
)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction, channels, kernel_size=1, padding=0, bias=False
)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
class bottleneck_IR(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth),
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class bottleneck_IR_SE(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth),
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
SEModule(depth, 16),
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut

26
extern/ldm_zero123/thirdp/psp/id_loss.py vendored Executable file
View File

@ -0,0 +1,26 @@
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn
from extern.ldm_zero123.thirdp.psp.model_irse import Backbone
class IDFeatures(nn.Module):
def __init__(self, model_path):
super(IDFeatures, self).__init__()
print("Loading ResNet ArcFace")
self.facenet = Backbone(
input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
)
self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
def forward(self, x, crop=False):
# Not sure of the image range here
if crop:
x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
x = x[:, :, 35:223, 32:220]
x = self.face_pool(x)
x_feats = self.facenet(x)
return x_feats
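# Hedged usage sketch (not part of the original file): the checkpoint path below is a
# placeholder for pretrained IR-SE50 ArcFace weights and is an assumption, not a file
# shipped with this repository.
def _demo_id_similarity(img_a, img_b, model_path="path/to/model_ir_se50.pth"):
    # img_a, img_b: (B, 3, H, W) face images in the range expected by the ArcFace backbone
    id_net = IDFeatures(model_path)
    with torch.no_grad():
        feat_a = id_net(img_a, crop=True)  # (B, 512), L2-normalized
        feat_b = id_net(img_b, crop=True)
    return (feat_a * feat_b).sum(dim=-1)   # cosine similarity per batch element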

118
extern/ldm_zero123/thirdp/psp/model_irse.py vendored Executable file
View File

@ -0,0 +1,118 @@
# https://github.com/eladrich/pixel2style2pixel
from torch.nn import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
Dropout,
Linear,
Module,
PReLU,
Sequential,
)
from extern.ldm_zero123.thirdp.psp.helpers import (
Flatten,
bottleneck_IR,
bottleneck_IR_SE,
get_blocks,
l2_norm,
)
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Backbone(Module):
def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True):
super(Backbone, self).__init__()
assert input_size in [112, 224], "input_size should be 112 or 224"
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
assert mode in ["ir", "ir_se"], "mode should be ir or ir_se"
blocks = get_blocks(num_layers)
if mode == "ir":
unit_module = bottleneck_IR
elif mode == "ir_se":
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)
)
if input_size == 112:
self.output_layer = Sequential(
BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 7 * 7, 512),
BatchNorm1d(512, affine=affine),
)
else:
self.output_layer = Sequential(
BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 14 * 14, 512),
BatchNorm1d(512, affine=affine),
)
modules = []
for block in blocks:
for bottleneck in block:
modules.append(
unit_module(
bottleneck.in_channel, bottleneck.depth, bottleneck.stride
)
)
self.body = Sequential(*modules)
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)
def IR_50(input_size):
"""Constructs a ir-50 model."""
model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False)
return model
def IR_101(input_size):
"""Constructs a ir-101 model."""
model = Backbone(
input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False
)
return model
def IR_152(input_size):
"""Constructs a ir-152 model."""
model = Backbone(
input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False
)
return model
def IR_SE_50(input_size):
"""Constructs a ir_se-50 model."""
model = Backbone(
input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False
)
return model
def IR_SE_101(input_size):
"""Constructs a ir_se-101 model."""
model = Backbone(
input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False
)
return model
def IR_SE_152(input_size):
"""Constructs a ir_se-152 model."""
model = Backbone(
input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False
)
return model
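# Illustrative shape check for the 112x112 configuration (not part of the original file).
def _demo_backbone():
    import torch

    net = IR_SE_50(input_size=112)
    net.eval()
    with torch.no_grad():
        feats = net(torch.randn(2, 3, 112, 112))
    assert feats.shape == (2, 512)  # L2-normalized identity embeddings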

249
extern/ldm_zero123/util.py vendored Executable file
View File

@ -0,0 +1,249 @@
import importlib
import os
import time
from inspect import isfunction
import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
from torch import optim
def pil_rectangle_crop(im):
width, height = im.size # Get dimensions
if width <= height:
left = 0
right = width
top = (height - width) / 2
bottom = (height + width) / 2
else:
top = 0
bottom = height
left = (width - height) / 2
right = (width + height) / 2
# Crop the center of the image
im = im.crop((left, top, right, bottom))
return im
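# Illustrative sketch: pil_rectangle_crop keeps a centered square whose side equals the
# shorter image dimension.
def _demo_pil_rectangle_crop():
    assert pil_rectangle_crop(Image.new("RGB", (640, 480))).size == (480, 480)
    assert pil_rectangle_crop(Image.new("RGB", (480, 640))).size == (480, 480)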
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
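# Illustrative sketch: instantiate_from_config resolves the dotted "target" path and
# passes "params" as keyword arguments (the "__is_first_stage__" / "__is_unconditional__"
# sentinels return None). Example with a standard torch class:
def _demo_instantiate_from_config():
    layer = instantiate_from_config(
        {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
    )
    assert isinstance(layer, torch.nn.Linear)
    assert get_obj_from_str("torch.nn.Linear") is torch.nn.Linear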
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(
self,
params,
lr=1.0e-3,
betas=(0.9, 0.999),
eps=1.0e-8, # TODO: check hyperparameters before using
weight_decay=1.0e-2,
amsgrad=False,
ema_decay=0.9999, # ema decay to match previous code
ema_power=1.0,
param_names=(),
):
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
ema_decay=ema_decay,
ema_power=ema_power,
param_names=param_names,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsgrad", False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group["amsgrad"]
beta1, beta2 = group["betas"]
ema_decay = group["ema_decay"]
ema_power = group["ema_power"]
for p in group["params"]:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("AdamW does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of parameter values
state["param_exp_avg"] = p.detach().float().clone()
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
ema_params_with_grad.append(state["param_exp_avg"])
if amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
# update the steps for each param group update
state["step"] += 1
# record the step after step update
state_steps.append(state["step"])
optim._functional.adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=False,
)
cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(
param.float(), alpha=1 - cur_ema_decay
)
return loss

666
extern/zero123.py vendored Normal file
View File

@ -0,0 +1,666 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
import PIL
import torch
import torchvision.transforms.functional as TF
from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate, is_accelerate_available, logging
from diffusers.utils.torch_utils import randn_tensor
from packaging import version
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CLIPCameraProjection(ModelMixin, ConfigMixin):
"""
A Projection layer for CLIP embedding and camera embedding.
Parameters:
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP image embedding `clip_embed`.
additional_embeddings (`int`, *optional*, defaults to 4): The number of camera values appended to
`clip_embed` before projection, so the projection input has dimension `embedding_dim +
additional_embeddings`.
"""
@register_to_config
def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):
super().__init__()
self.embedding_dim = embedding_dim
self.additional_embeddings = additional_embeddings
self.input_dim = self.embedding_dim + self.additional_embeddings
self.output_dim = self.embedding_dim
self.proj = torch.nn.Linear(self.input_dim, self.output_dim)
def forward(
self,
embedding: torch.FloatTensor,
):
"""
The [`CLIPCameraProjection`] forward method.
Args:
embedding (`torch.FloatTensor` of shape `(batch_size, ..., input_dim)`):
The CLIP image embedding concatenated with the camera embedding, to be projected.
Returns:
The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`).
"""
proj_embedding = self.proj(embedding)
return proj_embedding
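# Illustrative shape sketch (not part of the original file): the projection consumes the
# CLIP image embedding concatenated with `additional_embeddings` camera values and maps
# it back to the CLIP width expected by the U-Net cross-attention.
def _demo_clip_camera_projection():
    proj = CLIPCameraProjection(embedding_dim=768, additional_embeddings=4)
    tokens = torch.randn(2, 1, 768 + 4)  # (batch, 1 token, clip + camera dims)
    assert proj(tokens).shape == (2, 1, 768)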
class Zero123Pipeline(DiffusionPipeline):
r"""
Pipeline to generate novel views of an input image at a relative camera pose (Zero-1-to-3), built on Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
# TODO: feature_extractor is required to encode images (if they are in PIL format),
# we should give a descriptive message if the pipeline doesn't have one.
_optional_components = ["safety_checker"]
def __init__(
self,
vae: AutoencoderKL,
image_encoder: CLIPVisionModelWithProjection,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
clip_camera_projection: CLIPCameraProjection,
requires_safety_checker: bool = True,
):
super().__init__()
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(
unet.config, "_diffusers_version"
) and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse(
"0.9.0.dev0"
)
is_unet_sample_size_less_64 = (
hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate(
"sample_size<64", "1.0.0", deprecation_message, standard_warn=False
)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
clip_camera_projection=clip_camera_projection,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
image_encoder, vae and safety checker have their state dicts saved to CPU and are then moved to
`torch.device('meta')`, loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [
self.unet,
self.image_encoder,
self.vae,
self.safety_checker,
]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_image(
self,
image,
elevation,
azimuth,
distance,
device,
num_images_per_prompt,
do_classifier_free_guidance,
clip_image_embeddings=None,
image_camera_embeddings=None,
):
dtype = next(self.image_encoder.parameters()).dtype
if image_camera_embeddings is None:
if image is None:
assert clip_image_embeddings is not None
image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype)
else:
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(
images=image, return_tensors="pt"
).pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
bs_embed, seq_len, _ = image_embeddings.shape
if isinstance(elevation, float):
elevation = torch.as_tensor(
[elevation] * bs_embed, dtype=dtype, device=device
)
if isinstance(azimuth, float):
azimuth = torch.as_tensor(
[azimuth] * bs_embed, dtype=dtype, device=device
)
if isinstance(distance, float):
distance = torch.as_tensor(
[distance] * bs_embed, dtype=dtype, device=device
)
camera_embeddings = torch.stack(
[
torch.deg2rad(elevation),
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
distance,
],
dim=-1,
)[:, None, :]
image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1)
# project (image, camera) embeddings to the same dimension as clip embeddings
image_embeddings = self.clip_camera_projection(image_embeddings)
else:
image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype)
bs_embed, seq_len, _ = image_embeddings.shape
# duplicate image embeddings for each generation per prompt, using mps friendly method
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
if do_classifier_free_guidance:
negative_prompt_embeds = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
return image_embeddings
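# Illustrative note (assumptions hedged): when `image_camera_embeddings` is not supplied,
# the camera conditioning built above is, per batch element,
#     [deg2rad(elevation), sin(deg2rad(azimuth)), cos(deg2rad(azimuth)), distance]
# appended to the 768-d CLIP image embedding and projected back to the CLIP width by
# `clip_camera_projection`, roughly:
#     cam = torch.stack([torch.deg2rad(elev), torch.sin(torch.deg2rad(azim)),
#                        torch.cos(torch.deg2rad(azim)), dist], dim=-1)[:, None, :]
#     cond = self.clip_camera_projection(torch.cat([clip_embed, cam], dim=-1))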
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(
image, output_type="pil"
)
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(
feature_extractor_input, return_tensors="pt"
).to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, image, height, width, callback_steps):
# TODO: check image size or adjust image size to (height, width)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
)
if (callback_steps is None) or (
callback_steps is not None
and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(
shape, generator=generator, device=device, dtype=dtype
)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_latent_model_input(
self,
latents: torch.FloatTensor,
image: Optional[
Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
],
num_images_per_prompt: int,
do_classifier_free_guidance: bool,
image_latents: Optional[torch.FloatTensor] = None,
):
if isinstance(image, PIL.Image.Image):
image_pt = TF.to_tensor(image).unsqueeze(0).to(latents)
elif isinstance(image, list):
image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(
latents
)
elif isinstance(image, torch.Tensor):
image_pt = image
else:
image_pt = None
if image_pt is None:
assert image_latents is not None
image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0)
else:
image_pt = image_pt * 2.0 - 1.0 # scale to [-1, 1]
# FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor
# but zero123 was not trained this way
image_pt = self.vae.encode(image_pt).latent_dist.mode()
image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
latent_model_input = torch.cat(
[
torch.cat([latents, latents], dim=0),
torch.cat([torch.zeros_like(image_pt), image_pt], dim=0),
],
dim=1,
)
else:
latent_model_input = torch.cat([latents, image_pt], dim=1)
return latent_model_input
@torch.no_grad()
def __call__(
self,
image: Optional[
Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
] = None,
elevation: Optional[Union[float, torch.FloatTensor]] = None,
azimuth: Optional[Union[float, torch.FloatTensor]] = None,
distance: Optional[Union[float, torch.FloatTensor]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 3.0,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
clip_image_embeddings: Optional[torch.FloatTensor] = None,
image_camera_embeddings: Optional[torch.FloatTensor] = None,
image_latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPImageProcessor`
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 3.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages images that are closely linked to the conditioning image and camera pose,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
# TODO: check input elevation, azimuth, and distance
# TODO: check image, clip_image_embeddings, image_latents
self.check_inputs(image, height, width, callback_steps)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
elif isinstance(image, torch.Tensor):
batch_size = image.shape[0]
else:
assert image_latents is not None
assert (
clip_image_embeddings is not None or image_camera_embeddings is not None
)
batch_size = image_latents.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input image
if isinstance(image, PIL.Image.Image) or isinstance(image, list):
pil_image = image
elif isinstance(image, torch.Tensor):
pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
else:
pil_image = None
image_embeddings = self._encode_image(
pil_image,
elevation,
azimuth,
distance,
device,
num_images_per_prompt,
do_classifier_free_guidance,
clip_image_embeddings,
image_camera_embeddings,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
# num_channels_latents = self.unet.config.in_channels
num_channels_latents = 4 # FIXME: hard-coded
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = self._get_latent_model_input(
latents,
image,
num_images_per_prompt,
do_classifier_free_guidance,
image_latents,
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=image_embeddings,
cross_attention_kwargs=cross_attention_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
image, has_nsfw_concept = self.run_safety_checker(
image, device, image_embeddings.dtype
)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(
image, output_type=output_type, do_denormalize=do_denormalize
)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)
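# Hedged usage sketch (not part of the original file): the checkpoint identifier below is
# a placeholder; any diffusers-format Zero123 checkpoint that also provides the extra
# `clip_camera_projection` module is assumed. The camera arguments are offsets relative
# to the conditioning view, following this pipeline's `_encode_image` convention.
def _demo_zero123_pipeline(cond_image, checkpoint="path/or/hub-id/of/zero123-diffusers"):
    pipe = Zero123Pipeline.from_pretrained(checkpoint, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    result = pipe(
        image=cond_image,      # PIL.Image of the conditioning view
        elevation=30.0,        # degrees
        azimuth=45.0,          # degrees
        distance=0.0,
        num_inference_steps=50,
        guidance_scale=3.0,
    )
    return result.images[0]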

450
gradio_app.py Normal file
View File

@ -0,0 +1,450 @@
import argparse
import glob
import os
import re
import signal
import subprocess
import tempfile
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
import gradio as gr
import numpy as np
import psutil
import trimesh
def tail(f, window=20):
# Returns the last `window` lines of file `f`.
if window == 0:
return []
BUFSIZ = 1024
f.seek(0, 2)
remaining_bytes = f.tell()
size = window + 1
block = -1
data = []
while size > 0 and remaining_bytes > 0:
if remaining_bytes - BUFSIZ > 0:
# Seek back one whole BUFSIZ
f.seek(block * BUFSIZ, 2)
# read BUFFER
bunch = f.read(BUFSIZ)
else:
# file too small, start from beginning
f.seek(0, 0)
# only read what was not read
bunch = f.read(remaining_bytes)
bunch = bunch.decode("utf-8")
data.insert(0, bunch)
size -= bunch.count("\n")
remaining_bytes -= BUFSIZ
block -= 1
return "\n".join("".join(data).splitlines()[-window:])
@dataclass
class ExperimentStatus:
pid: Optional[int] = None
progress: str = ""
log: str = ""
output_image: Optional[str] = None
output_video: Optional[str] = None
output_mesh: Optional[str] = None
def tolist(self):
return [
self.pid,
self.progress,
self.log,
self.output_image,
self.output_video,
self.output_mesh,
]
EXP_ROOT_DIR = "outputs-gradio"
DEFAULT_PROMPT = "a delicious hamburger"
model_config = [
("DreamFusion (DeepFloyd-IF)", "configs/gradio/dreamfusion-if.yaml"),
("DreamFusion (Stable Diffusion)", "configs/gradio/dreamfusion-sd.yaml"),
("TextMesh (DeepFloyd-IF)", "configs/gradio/textmesh-if.yaml"),
("Fantasia3D (Stable Diffusion, Geometry Only)", "configs/gradio/fantasia3d.yaml"),
("SJC (Stable Diffusion)", "configs/gradio/sjc.yaml"),
("Latent-NeRF (Stable Diffusion)", "configs/gradio/latentnerf.yaml"),
]
model_choices = [m[0] for m in model_config]
model_name_to_config = {m[0]: m[1] for m in model_config}
def load_model_config(model_name):
return open(model_name_to_config[model_name]).read()
def load_model_config_attrs(model_name):
config_str = load_model_config(model_name)
from threestudio.utils.config import load_config
cfg = load_config(
config_str,
cli_args=[
"name=dummy",
"tag=dummy",
"use_timestamp=false",
f"exp_root_dir={EXP_ROOT_DIR}",
"system.prompt_processor.prompt=placeholder",
],
from_string=True,
)
return {
"source": config_str,
"guidance_scale": cfg.system.guidance.guidance_scale,
"max_steps": cfg.trainer.max_steps,
}
def on_model_selector_change(model_name):
cfg = load_model_config_attrs(model_name)
return [cfg["source"], cfg["guidance_scale"]]
def get_current_status(process, trial_dir, alive_path):
status = ExperimentStatus()
status.pid = process.pid
# write the current timestamp to the alive file
# the watcher will know the last active time of this process from this timestamp
if os.path.exists(os.path.dirname(alive_path)):
alive_fp = open(alive_path, "w")
alive_fp.seek(0)
alive_fp.write(str(time.time()))
alive_fp.flush()
log_path = os.path.join(trial_dir, "logs")
progress_path = os.path.join(trial_dir, "progress")
save_path = os.path.join(trial_dir, "save")
# read current progress from the progress file
# the progress file is created by GradioCallback
if os.path.exists(progress_path):
status.progress = open(progress_path).read()
else:
status.progress = "Setting up everything ..."
# read the last 10 lines of the log file
if os.path.exists(log_path):
status.log = tail(open(log_path, "rb"), window=10)
else:
status.log = ""
# get the validation image and testing video if they exist
if os.path.exists(save_path):
images = glob.glob(os.path.join(save_path, "*.png"))
steps = [
int(re.match(r"it(\d+)-0\.png", os.path.basename(f)).group(1))
for f in images
]
images = sorted(list(zip(images, steps)), key=lambda x: x[1])
if len(images) > 0:
status.output_image = images[-1][0]
videos = glob.glob(os.path.join(save_path, "*.mp4"))
steps = [
int(re.match(r"it(\d+)-test\.mp4", os.path.basename(f)).group(1))
for f in videos
]
videos = sorted(list(zip(videos, steps)), key=lambda x: x[1])
if len(videos) > 0:
status.output_video = videos[-1][0]
export_dirs = glob.glob(os.path.join(save_path, "*export"))
steps = [
int(re.match(r"it(\d+)-export", os.path.basename(f)).group(1))
for f in export_dirs
]
export_dirs = sorted(list(zip(export_dirs, steps)), key=lambda x: x[1])
if len(export_dirs) > 0:
obj = glob.glob(os.path.join(export_dirs[-1][0], "*.obj"))
if len(obj) > 0:
# FIXME
# seems the gr.Model3D cannot load our manually saved obj file
# here we load the obj and save it to a temporary file using trimesh
mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
trimesh.load(obj[0]).export(mesh_path.name)
status.output_mesh = mesh_path.name
return status
def run(
model_name: str,
config: str,
prompt: str,
guidance_scale: float,
seed: int,
max_steps: int,
):
# update status every 1 second
status_update_interval = 1
# save the config to a temporary file
config_file = tempfile.NamedTemporaryFile()
with open(config_file.name, "w") as f:
f.write(config)
# manually assign the output directory, name and tag so that we know the trial directory
name = os.path.basename(model_name_to_config[model_name]).split(".")[0]
tag = datetime.now().strftime("@%Y%m%d-%H%M%S")
trial_dir = os.path.join(EXP_ROOT_DIR, name, tag)
alive_path = os.path.join(trial_dir, "alive")
# spawn the training process
process = subprocess.Popen(
f"python launch.py --config {config_file.name} --train --gpu 0 --gradio trainer.enable_progress_bar=false".split()
+ [
f'name="{name}"',
f'tag="{tag}"',
f"exp_root_dir={EXP_ROOT_DIR}",
"use_timestamp=false",
f'system.prompt_processor.prompt="{prompt}"',
f"system.guidance.guidance_scale={guidance_scale}",
f"seed={seed}",
f"trainer.max_steps={max_steps}",
]
)
# spawn the watcher process
watch_process = subprocess.Popen(
"python gradio_app.py watch".split()
+ ["--pid", f"{process.pid}", "--trial-dir", f"{trial_dir}"]
)
# update status (progress, log, image, video) every status_update_interval seconds
# button status: Run -> Stop
while process.poll() is None:
time.sleep(status_update_interval)
yield get_current_status(process, trial_dir, alive_path).tolist() + [
gr.update(visible=False),
gr.update(value="Stop", variant="stop", visible=True),
]
# wait for the processes to finish
process.wait()
watch_process.wait()
# update status one last time
# button status: Stop / Reset -> Run
status = get_current_status(process, trial_dir, alive_path)
status.progress = "Finished."
yield status.tolist() + [
gr.update(value="Run", variant="primary", visible=True),
gr.update(visible=False),
]
def stop_run(pid):
# kill the process
print(f"Trying to kill process {pid} ...")
try:
os.kill(pid, signal.SIGKILL)
except Exception:
print(f"Exception when killing process {pid}.")
# button status: Stop -> Reset
return [
gr.update(value="Reset", variant="secondary", visible=True),
gr.update(visible=False),
]
def launch(port, listen=False):
with gr.Blocks(title="threestudio - Web Demo") as demo:
with gr.Row():
pid = gr.State()
with gr.Column(scale=1):
header = gr.Markdown(
"""
# threestudio
- Select a model from the dropdown menu.
- Input a text prompt.
- Hit Run!
"""
)
# model selection dropdown
model_selector = gr.Dropdown(
value=model_choices[0],
choices=model_choices,
label="Select a model",
)
# prompt input
prompt_input = gr.Textbox(value=DEFAULT_PROMPT, label="Input prompt")
# guidance scale slider
guidance_scale_input = gr.Slider(
minimum=0.0,
maximum=100.0,
value=load_model_config_attrs(model_selector.value)[
"guidance_scale"
],
step=0.5,
label="Guidance scale",
)
# seed slider
seed_input = gr.Slider(
minimum=0, maximum=2147483647, value=0, step=1, label="Seed"
)
max_steps_input = gr.Slider(
minimum=1,
maximum=5000,
value=5000,
step=1,
label="Number of training steps",
)
# full config viewer
with gr.Accordion("See full configurations", open=False):
config_editor = gr.Code(
value=load_model_config(model_selector.value),
language="yaml",
interactive=False,
)
# load config on model selection change
model_selector.change(
fn=on_model_selector_change,
inputs=model_selector,
outputs=[config_editor, guidance_scale_input],
)
run_btn = gr.Button(value="Run", variant="primary")
stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
# generation status
status = gr.Textbox(
value="Hit the Run button to start.",
label="Status",
lines=1,
max_lines=1,
)
with gr.Column(scale=1):
with gr.Accordion("See terminal logs", open=False):
# logs
logs = gr.Textbox(label="Logs", lines=10)
# validation image display
output_image = gr.Image(value=None, label="Image")
# testing video display
output_video = gr.Video(value=None, label="Video")
# export mesh display
output_mesh = gr.Model3D(value=None, label="3D Mesh")
run_event = run_btn.click(
fn=run,
inputs=[
model_selector,
config_editor,
prompt_input,
guidance_scale_input,
seed_input,
max_steps_input,
],
outputs=[
pid,
status,
logs,
output_image,
output_video,
output_mesh,
run_btn,
stop_btn,
],
)
stop_btn.click(
fn=stop_run, inputs=[pid], outputs=[run_btn, stop_btn], cancels=[run_event]
)
launch_args = {"server_port": port}
if listen:
launch_args["server_name"] = "0.0.0.0"
demo.queue().launch(**launch_args)
def watch(
pid: int, trial_dir: str, alive_timeout: int, wait_timeout: int, check_interval: int
) -> None:
print(f"Spawn watcher for process {pid}")
def timeout_handler(signum, frame):
exit(1)
alive_path = os.path.join(trial_dir, "alive")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(wait_timeout)
def loop_find_progress_file():
while True:
if not os.path.exists(alive_path):
time.sleep(check_interval)
else:
signal.alarm(0)
return
def loop_check_alive():
while True:
if not psutil.pid_exists(pid):
print(f"Process {pid} not exists, watcher exits.")
exit(0)
alive_timestamp = float(open(alive_path).read())
if time.time() - alive_timestamp > alive_timeout:
print(f"Alive timeout for process {pid}, killed.")
try:
os.kill(pid, signal.SIGKILL)
except:
print(f"Exception when killing process {pid}.")
exit(0)
time.sleep(check_interval)
# loop until alive file is found, or alive_timeout is reached
loop_find_progress_file()
# kill the process if it is not accessed for alive_timeout seconds
loop_check_alive()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("operation", type=str, choices=["launch", "watch"])
args, extra = parser.parse_known_args()
if args.operation == "launch":
parser.add_argument("--listen", action="store_true")
parser.add_argument("--port", type=int, default=7860)
args = parser.parse_args()
launch(args.port, listen=args.listen)
if args.operation == "watch":
parser.add_argument("--pid", type=int)
parser.add_argument("--trial-dir", type=str)
parser.add_argument("--alive-timeout", type=int, default=10)
parser.add_argument("--wait-timeout", type=int, default=10)
parser.add_argument("--check-interval", type=int, default=1)
args = parser.parse_args()
watch(
args.pid,
args.trial_dir,
alive_timeout=args.alive_timeout,
wait_timeout=args.wait_timeout,
check_interval=args.check_interval,
)
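Note: loop_check_alive() above expects the training side to keep refreshing the alive file with a fresh UNIX timestamp; if the file goes stale for more than alive_timeout seconds, the watcher kills the run. Below is a minimal sketch of such a heartbeat writer, assuming this mirrors what the trainer-side progress callback does (the heartbeat helper itself is hypothetical, for illustration only):

import os
import threading
import time

def heartbeat(alive_path: str, interval: float = 1.0) -> threading.Thread:
    # periodically overwrite the alive file with time.time(), the format
    # loop_check_alive() parses via float(open(alive_path).read())
    os.makedirs(os.path.dirname(alive_path), exist_ok=True)

    def _beat():
        while True:
            with open(alive_path, "w") as f:
                f.write(str(time.time()))
            time.sleep(interval)

    thread = threading.Thread(target=_beat, daemon=True)
    thread.start()
    return thread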

252
launch.py Normal file
View File

@ -0,0 +1,252 @@
import argparse
import contextlib
import importlib
import logging
import os
import sys
class ColoredFilter(logging.Filter):
"""
A logging filter to add color to certain log levels.
"""
RESET = "\033[0m"
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
COLORS = {
"WARNING": YELLOW,
"INFO": GREEN,
"DEBUG": BLUE,
"CRITICAL": MAGENTA,
"ERROR": RED,
}
def __init__(self):
super().__init__()
def filter(self, record):
if record.levelname in self.COLORS:
color_start = self.COLORS[record.levelname]
record.levelname = f"{color_start}[{record.levelname}]"
record.msg = f"{record.msg}{self.RESET}"
return True
def main(args, extras) -> None:
# set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []
selected_gpus = [0]
# Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
# As far as Pytorch Lightning is concerned, we always use all available GPUs
# (possibly filtered by CUDA_VISIBLE_DEVICES).
devices = -1
if len(env_gpus) > 0:
# CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script.
n_gpus = len(env_gpus)
else:
selected_gpus = list(args.gpu.split(","))
n_gpus = len(selected_gpus)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
if args.typecheck:
from jaxtyping import install_import_hook
install_import_hook("threestudio", "typeguard.typechecked")
import threestudio
from threestudio.systems.base import BaseSystem
from threestudio.utils.callbacks import (
CodeSnapshotCallback,
ConfigSnapshotCallback,
CustomProgressBar,
ProgressCallback,
)
from threestudio.utils.config import ExperimentConfig, load_config
from threestudio.utils.misc import get_rank
from threestudio.utils.typing import Optional
logger = logging.getLogger("pytorch_lightning")
if args.verbose:
logger.setLevel(logging.DEBUG)
for handler in logger.handlers:
if handler.stream == sys.stderr: # type: ignore
if not args.gradio:
handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
handler.addFilter(ColoredFilter())
else:
handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
# parse YAML config to OmegaConf
cfg: ExperimentConfig
cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)
if len(cfg.custom_import) > 0:
print(cfg.custom_import)
for extension in cfg.custom_import:
importlib.import_module(extension)
# set a different seed for each device
pl.seed_everything(cfg.seed + get_rank(), workers=True)
dm = threestudio.find(cfg.data_type)(cfg.data)
# Auto check resume files during training
if args.train and cfg.resume is None:
import glob
resume_file_list = glob.glob(f"{cfg.trial_dir}/ckpts/*")
if len(resume_file_list) != 0:
print(sorted(resume_file_list))
cfg.resume = sorted(resume_file_list)[-1]
print(f"Find resume file: {cfg.resume}")
system: BaseSystem = threestudio.find(cfg.system_type)(
cfg.system, resumed=cfg.resume is not None
)
system.set_save_dir(os.path.join(cfg.trial_dir, "save"))
if args.gradio:
fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
fh.setLevel(logging.INFO)
if args.verbose:
fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
logger.addHandler(fh)
callbacks = []
if args.train:
callbacks += [
ModelCheckpoint(
dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
),
LearningRateMonitor(logging_interval="step"),
CodeSnapshotCallback(
os.path.join(cfg.trial_dir, "code"), use_version=False
),
ConfigSnapshotCallback(
args.config,
cfg,
os.path.join(cfg.trial_dir, "configs"),
use_version=False,
),
]
if args.gradio:
callbacks += [
ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
]
else:
callbacks += [CustomProgressBar(refresh_rate=1)]
def write_to_text(file, lines):
with open(file, "w") as f:
for line in lines:
f.write(line + "\n")
loggers = []
if args.train:
# make tensorboard logging dir to suppress warning
rank_zero_only(
lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
)()
loggers += [
TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
CSVLogger(cfg.trial_dir, name="csv_logs"),
] + system.get_loggers()
rank_zero_only(
lambda: write_to_text(
os.path.join(cfg.trial_dir, "cmd.txt"),
["python " + " ".join(sys.argv), str(args)],
)
)()
trainer = Trainer(
callbacks=callbacks,
logger=loggers,
inference_mode=False,
accelerator="gpu",
devices=devices,
**cfg.trainer,
)
def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
if ckpt_path is None:
return
ckpt = torch.load(ckpt_path, map_location="cpu")
system.set_resume_status(ckpt["epoch"], ckpt["global_step"])
if args.train:
trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
trainer.test(system, datamodule=dm)
if args.gradio:
# also export assets if in gradio mode
trainer.predict(system, datamodule=dm)
elif args.validate:
# manually set epoch and global_step as they cannot be automatically resumed
set_system_status(system, cfg.resume)
trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
elif args.test:
# manually set epoch and global_step as they cannot be automatically resumed
set_system_status(system, cfg.resume)
trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
elif args.export:
set_system_status(system, cfg.resume)
trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="path to config file")
parser.add_argument(
"--gpu",
default="0",
help="GPU(s) to be used. 0 means use the 1st available GPU. "
"1,2 means use the 2nd and 3rd available GPU. "
"If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
"this argument is ignored and all available GPUs are always used.",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--train", action="store_true")
group.add_argument("--validate", action="store_true")
group.add_argument("--test", action="store_true")
group.add_argument("--export", action="store_true")
parser.add_argument(
"--gradio", action="store_true", help="if true, run in gradio mode"
)
parser.add_argument(
"--verbose", action="store_true", help="if true, set logging level to DEBUG"
)
parser.add_argument(
"--typecheck",
action="store_true",
help="whether to enable dynamic type checking",
)
args, extras = parser.parse_known_args()
if args.gradio:
# FIXME: no effect, stdout is not captured
with contextlib.redirect_stdout(sys.stderr):
main(args, extras)
else:
main(args, extras)
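Note: the extra CLI arguments collected in `extras` (e.g. trainer.max_steps=5000 as passed by gradio_app.py) are dotted-key overrides merged into the YAML config by load_config. A minimal sketch of that merge with plain OmegaConf, assuming behaviour roughly equivalent to threestudio.utils.config.load_config (the real helper also validates against ExperimentConfig and injects n_gpus):

from omegaconf import OmegaConf

def load_config_sketch(path: str, cli_args: list):
    base = OmegaConf.load(path)                   # parse the YAML file
    overrides = OmegaConf.from_dotlist(cli_args)  # e.g. ["seed=0", "trainer.max_steps=5000"]
    return OmegaConf.merge(base, overrides)       # overrides take precedence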

567
metric_utils.py Normal file
View File

@ -0,0 +1,567 @@
# * evaluation uses laion/CLIP-ViT-H-14-laion2B-s32B-b79K
# best open source clip so far: laion/CLIP-ViT-bigG-14-laion2B-39B-b160k
# code adapted from NeuralLift-360
import torch
import torch.nn as nn
import os
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
# import clip
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import cv2
from PIL import Image
# import torchvision.transforms as transforms
import glob
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
import lpips
from os.path import join as osp
import argparse
import pandas as pd
import contextual_loss as cl
criterion = cl.ContextualLoss(use_vgg=True, vgg_layer='relu5_4')
class CLIP(nn.Module):
def __init__(self,
device,
clip_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
size=224): #'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'):
super().__init__()
self.size = size
self.device = f"cuda:{device}"
clip_name = clip_name
self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
clip_name)
self.clip_model = CLIPModel.from_pretrained(clip_name).to(self.device)
self.tokenizer = CLIPTokenizer.from_pretrained(
'openai/clip-vit-base-patch32')
self.normalize = transforms.Normalize(
mean=self.feature_extractor.image_mean,
std=self.feature_extractor.image_std)
self.resize = transforms.Resize(224)
self.to_tensor = transforms.ToTensor()
# image augmentation
self.aug = T.Compose([
T.Resize((224, 224)),
T.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
# * recommend to use this function for evaluation
@torch.no_grad()
def score_gt(self, ref_img_path, novel_views):
# assert len(novel_views) == 100
clip_scores = []
for novel in novel_views:
clip_scores.append(self.score_from_path(ref_img_path, [novel]))
return np.mean(clip_scores)
# * recommend to use this function for evaluation
# def score_gt(self, ref_paths, novel_paths):
# clip_scores = []
# for img1_path, img2_path in zip(ref_paths, novel_paths):
# clip_scores.append(self.score_from_path(img1_path, img2_path))
# return np.mean(clip_scores)
def similarity(self, image1_features: torch.Tensor,
image2_features: torch.Tensor) -> float:
with torch.no_grad(), torch.cuda.amp.autocast():
y = image1_features.T.view(image1_features.T.shape[1],
image1_features.T.shape[0])
similarity = torch.matmul(y, image2_features.T)
# print(similarity)
return similarity[0][0].item()
def get_img_embeds(self, img):
if img.shape[0] == 4:
img = img[:3, :, :]
img = self.aug(img).to(self.device)
img = img.unsqueeze(0) # b,c,h,w
# plt.imshow(img.cpu().squeeze(0).permute(1, 2, 0).numpy())
# plt.show()
# print(img)
image_z = self.clip_model.get_image_features(img)
image_z = image_z / image_z.norm(dim=-1,
keepdim=True) # normalize features
return image_z
def score_from_feature(self, img1, img2):
img1_feature, img2_feature = self.get_img_embeds(
img1), self.get_img_embeds(img2)
# for debug
return self.similarity(img1_feature, img2_feature)
def read_img_list(self, img_list):
size = self.size
images = []
# white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
for img_path in img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
# print(img_path)
if img.shape[2] == 4: # Handle BGRA images
alpha = img[:, :, 3] # Extract alpha channel
img = cv2.cvtColor(img,cv2.COLOR_BGRA2RGB) # Convert BGRA to BGR
img[np.where(alpha == 0)] = [
255, 255, 255
] # Set transparent pixels to white
else: # Handle other image formats like JPG and PNG
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
# plt.imshow(img)
# plt.show()
images.append(img)
images = np.stack(images, axis=0)
# images[np.where(images == 0)] = 255 # Set black pixels to white
# images = np.where(images == 0, white_background, images) # Set transparent pixels to white
# images = images.astype(np.float32)
return images
def score_from_path(self, img1_path, img2_path):
img1, img2 = self.read_img_list(img1_path), self.read_img_list(img2_path)
img1 = np.squeeze(img1)
img2 = np.squeeze(img2)
# plt.imshow(img1)
# plt.show()
# plt.imshow(img2)
# plt.show()
img1, img2 = self.to_tensor(img1), self.to_tensor(img2)
# print("img1 to tensor ",img1)
return self.score_from_feature(img1, img2)
def numpy_to_torch(images):
images = images * 2.0 - 1.0
images = torch.from_numpy(images.transpose((0, 3, 1, 2))).float()
return images.cuda()
class LPIPSMeter:
def __init__(self,
net='alex',
device=None,
size=224): # or we can use 'alex', 'vgg' as network
self.size = size
self.net = net
self.results = []
self.device = device if device is not None else torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.fn = lpips.LPIPS(net=net).eval().to(self.device)
def measure(self):
return np.mean(self.results)
def report(self):
return f'LPIPS ({self.net}) = {self.measure():.6f}'
def read_img_list(self, img_list):
size = self.size
images = []
white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
for img_path in img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
if img.shape[2] == 4: # Handle BGRA images
alpha = img[:, :, 3] # Extract alpha channel
img = cv2.cvtColor(img,
cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR
img = cv2.cvtColor(img,
cv2.COLOR_BGR2RGB) # Convert BGR to RGB
img[np.where(alpha == 0)] = [
255, 255, 255
] # Set transparent pixels to white
else: # Handle other image formats like JPG and PNG
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
images.append(img)
images = np.stack(images, axis=0)
# images[np.where(images == 0)] = 255 # Set black pixels to white
# images = np.where(images == 0, white_background, images) # Set transparent pixels to white
images = images.astype(np.float32) / 255.0
return images
# * recommend to use this function for evaluation
@torch.no_grad()
def score_gt(self, ref_paths, novel_paths):
self.results = []
for path0, path1 in zip(ref_paths, novel_paths):
# Load images
# img0 = lpips.im2tensor(lpips.load_image(path0)).cuda() # RGB image from [-1,1]
# img1 = lpips.im2tensor(lpips.load_image(path1)).cuda()
img0, img1 = self.read_img_list([path0]), self.read_img_list(
[path1])
img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1)
# print(img0.shape,img1.shape)
img0 = F.interpolate(img0,
size=(self.size, self.size),
mode='area')
img1 = F.interpolate(img1,
size=(self.size, self.size),
mode='area')
# for debug vis
# plt.imshow(img0.cpu().squeeze(0).permute(1, 2, 0).numpy())
# plt.show()
# plt.imshow(img1.cpu().squeeze(0).permute(1, 2, 0).numpy())
# plt.show()
# equivalent to cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA
# print(img0.shape,img1.shape)
self.results.append(self.fn.forward(img0, img1).cpu().numpy())
return self.measure()
class CXMeter:
def __init__(self,
net='vgg',
device=None,
size=512): # or we can use 'alex', 'vgg' as network
self.size = size
self.net = net
self.results = []
self.device = device if device is not None else torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.fn = lpips.LPIPS(net=net).eval().to(self.device)
def measure(self):
return np.mean(self.results)
def report(self):
return f'LPIPS ({self.net}) = {self.measure():.6f}'
def read_img_list(self, img_list):
size = self.size
images = []
white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
for img_path in img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
if img.shape[2] == 4: # Handle BGRA images
alpha = img[:, :, 3] # Extract alpha channel
img = cv2.cvtColor(img,
cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR
img = cv2.cvtColor(img,
cv2.COLOR_BGR2RGB) # Convert BGR to RGB
img[np.where(alpha == 0)] = [
255, 255, 255
] # Set transparent pixels to white
else: # Handle other image formats like JPG and PNG
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
images.append(img)
images = np.stack(images, axis=0)
# images[np.where(images == 0)] = 255 # Set black pixels to white
# images = np.where(images == 0, white_background, images) # Set transparent pixels to white
images = images.astype(np.float32) / 255.0
return images
# * recommend to use this function for evaluation
@torch.no_grad()
def score_gt(self, ref_paths, novel_paths):
self.results = []
path0 = ref_paths[0]
print('calculating CX loss')
for path1 in tqdm(novel_paths):
# Load images
img0, img1 = self.read_img_list([path0]), self.read_img_list(
[path1])
img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1)
img0, img1 = img0 * 0.5 + 0.5, img1 * 0.5 + 0.5
img0 = F.interpolate(img0,
size=(self.size, self.size),
mode='area')
img1 = F.interpolate(img1,
size=(self.size, self.size),
mode='area')
loss = criterion(img0.cpu(), img1.cpu())
self.results.append(loss.cpu().numpy())
return self.measure()
class PSNRMeter:
def __init__(self, size=800):
self.results = []
self.size = size
def read_img_list(self, img_list):
size = self.size
images = []
white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
for img_path in img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
if img.shape[2] == 4: # Handle BGRA images
alpha = img[:, :, 3] # Extract alpha channel
img = cv2.cvtColor(img,
cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR
img = cv2.cvtColor(img,
cv2.COLOR_BGR2RGB) # Convert BGR to RGB
img[np.where(alpha == 0)] = [
255, 255, 255
] # Set transparent pixels to white
else: # Handle other image formats like JPG and PNG
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
images.append(img)
images = np.stack(images, axis=0)
# images[np.where(images == 0)] = 255 # Set black pixels to white
# images = np.where(images == 0, white_background, images) # Set transparent pixels to white
images = images.astype(np.float32) / 255.0
# print(images.shape)
return images
def update(self, preds, truths):
# print(preds.shape)
psnr_values = []
# For each pair of images in the batches
for img1, img2 in zip(preds, truths):
# Compute the PSNR and add it to the list
# print(img1.shape,img2.shape)
# for debug
# plt.imshow(img1)
# plt.show()
# plt.imshow(img2)
# plt.show()
psnr = compare_psnr(
img1, img2,
data_range=1.0) # assuming your images are scaled to [0,1]
# print(f"temp psnr {psnr}")
psnr_values.append(psnr)
# Convert the list of PSNR values to a numpy array
self.results = psnr_values
def measure(self):
return np.mean(self.results)
def report(self):
return f'PSNR = {self.measure():.6f}'
# * recommend to use this function for evaluation
def score_gt(self, ref_paths, novel_paths):
self.results = []
# [B, N, 3] or [B, H, W, 3], range[0, 1]
preds = self.read_img_list(ref_paths)
print('novel_paths', novel_paths)
truths = self.read_img_list(novel_paths)
self.update(preds, truths)
return self.measure()
# all_inputs = 'data'
# nerf_dataset = os.listdir(osp(all_inputs, 'nerf4'))
# realfusion_dataset = os.listdir(osp(all_inputs, 'realfusion15'))
# meta_examples = {
# 'nerf4': nerf_dataset,
# 'realfusion15': realfusion_dataset,
# }
# all_datasets = meta_examples.keys()
# organization 1
def deprecated_score_from_method_for_dataset(my_scorer,
method,
dataset,
input,
output,
score_type='clip',
): # psnr, lpips
# print("\n\n\n")
# print(f"______{method}___{dataset}___{score_type}_________")
scores = {}
final_res = 0
examples = meta_examples[dataset]
for i in range(len(examples)):
# compare entire folder for clip
if score_type == 'clip':
novel_view = osp(pred_path, examples[i], 'colors')
# compare first image for other metrics
else:
if method == '3d_fuse': method = '3d_fuse_0'
novel_view = list(
glob.glob(
osp(pred_path, examples[i], 'colors',
'step_0000*')))[0]
score_i = my_scorer.score_gt(
[], [novel_view])
scores[examples[i]] = score_i
final_res += score_i
# print(scores, " Avg : ", final_res / len(examples))
# print("``````````````````````")
return scores
# results organization 2
def score_from_method_for_dataset(my_scorer,
input_path,
pred_path,
score_type='clip',
rgb_name='lambertian',
result_folder='results/images',
first_str='*0000*'
): # psnr, lpips
scores = {}
final_res = 0
examples = os.listdir(input_path)
for i in range(len(examples)):
# ref path
ref_path = osp(input_path, examples[i], 'rgba.png')
# compare entire folder for clip
print(pred_path,'*'+examples[i]+'*', result_folder, f'*{rgb_name}*')
if score_type == 'clip':
novel_view = glob.glob(osp(pred_path,'*'+examples[i]+'*', result_folder, f'*{rgb_name}*'))
print(f'[INFO] {score_type} loss for example {examples[i]} between 1 GT and {len(novel_view)} predictions')
# compare first image for other metrics
else:
novel_view = glob.glob(osp(pred_path, '*'+examples[i]+'*/', result_folder, f'{first_str}{rgb_name}*'))
print(f'[INFO] {score_type} loss for example {examples[i]} between {ref_path} and {novel_view}')
# breakpoint()
score_i = my_scorer.score_gt([ref_path], novel_view)
scores[examples[i]] = score_i
final_res += score_i
avg_score = final_res / len(examples)
scores['average'] = avg_score
return scores
# results organization 2
def score_from_my_method_for_dataset(my_scorer,
input_path, dataset,
score_type='clip'
): # psnr, lpips
scores = {}
final_res = 0
input_path = osp(input_path, dataset)
ref_path = glob.glob(osp(input_path, "*_rgba.png"))
novel_view = [osp(input_path, '%d.png' % i) for i in range(120)]
# print(ref_path)
# print(novel_view)
for i in tqdm(range(120)):
if os.path.exists(osp(input_path, '%d_color.png' % i)):
continue
img = cv2.imread(novel_view[i])
H = img.shape[0]
img = img[:, :H]
cv2.imwrite(osp(input_path, '%d_color.png' % i), img)
if score_type == 'clip' or score_type == 'cx':
novel_view = [osp(input_path, '%d_color.png' % i) for i in range(120)]
else:
novel_view = [osp(input_path, '%d_color.png' % i) for i in range(1)]
print(novel_view)
scores['%s_average' % dataset] = my_scorer.score_gt(ref_path, novel_view)
return scores
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Script to accept three string arguments")
parser.add_argument("--input_path",
default=None,
help="Specify the input path")
parser.add_argument("--pred_pattern",
default="out/magic123*",
help="Specify the pattern of predition paths")
parser.add_argument("--results_folder",
default="results/images",
help="where are the results under each pred_path")
parser.add_argument("--rgb_name",
default="lambertian",
help="the postfix of the image")
parser.add_argument("--first_str",
default="*0000*",
help="the str to indicate the first view")
parser.add_argument("--datasets",
default=None,
nargs='*',
help="Specify the output path")
parser.add_argument("--device",
type=int,
default=0,
help="Specify the GPU device to be used")
parser.add_argument("--save_dir", type=str, default='all_metrics/results')
args = parser.parse_args()
clip_scorer = CLIP(args.device)
lpips_scorer = LPIPSMeter()
psnr_scorer = PSNRMeter()
CX_scorer = CXMeter()
# criterion = criterion.to(args.device)
os.makedirs(args.save_dir, exist_ok=True)
for dataset in os.listdir(args.input_path):
print(dataset)
results_dict = {}
results_dict['clip'] = score_from_my_method_for_dataset(
clip_scorer, args.input_path, dataset, 'clip')
results_dict['psnr'] = score_from_my_method_for_dataset(
psnr_scorer, args.input_path, dataset, 'psnr')
results_dict['lpips'] = score_from_my_method_for_dataset(
lpips_scorer, args.input_path, dataset, 'lpips')
results_dict['CX'] = score_from_my_method_for_dataset(
CX_scorer, args.input_path, dataset, 'cx')
df = pd.DataFrame(results_dict)
print(df)
df.to_csv(f"{args.save_dir}/result.csv")
# for dataset in args.datasets:
# input_path = osp(args.input_path, dataset)
# # assume the pred_path is organized as: pred_path/methods/dataset
# pred_pattern = osp(args.pred_pattern, dataset)
# pred_paths = glob.glob(pred_pattern)
# print(f"[INFO] Following the pattern {pred_pattern}, find {len(pred_paths)} pred_paths: \n", pred_paths)
# if len(pred_paths) == 0:
# raise IOError
# for pred_path in pred_paths:
# if not os.path.exists(pred_path):
# print(f'[WARN] prediction does not exist for {pred_path}')
# else:
# print(f'[INFO] evaluate {pred_path}')
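Note: CLIP.similarity above reduces to a cosine similarity between the two L2-normalized image embeddings produced by get_img_embeds. An equivalent minimal sketch (assumes both inputs are already-normalized [1, D] feature tensors; the helper name is illustrative only):

import torch

def clip_cosine_similarity(feat_a: torch.Tensor, feat_b: torch.Tensor) -> float:
    # dot product of unit-norm vectors == cosine similarity in [-1, 1]
    return (feat_a * feat_b).sum(dim=-1).item()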

237
preprocess_image.py Normal file
View File

@ -0,0 +1,237 @@
import os
import sys
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
class BackgroundRemoval():
def __init__(self, device='cuda'):
from carvekit.api.high import HiInterface
self.interface = HiInterface(
object_type="object", # Can be "object" or "hairs-like".
batch_size_seg=5,
batch_size_matting=1,
device=device,
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
fp16=True,
)
@torch.no_grad()
def __call__(self, image):
# image: [H, W, 3] array in [0, 255].
image = Image.fromarray(image)
image = self.interface([image])[0]
image = np.array(image)
return image
class BLIP2():
def __init__(self, device='cuda'):
self.device = device
from transformers import AutoProcessor, Blip2ForConditionalGeneration
self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
@torch.no_grad()
def __call__(self, image):
image = Image.fromarray(image)
inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16)
generated_ids = self.model.generate(**inputs, max_new_tokens=20)
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
return generated_text
class DPT():
def __init__(self, task='depth', device='cuda'):
self.task = task
self.device = device
from threestudio.utils.dpt import DPTDepthModel
if task == 'depth':
path = 'load/omnidata/omnidata_dpt_depth_v2.ckpt'
self.model = DPTDepthModel(backbone='vitb_rn50_384')
self.aug = transforms.Compose([
transforms.Resize((384, 384)),
transforms.ToTensor(),
transforms.Normalize(mean=0.5, std=0.5)
])
else: # normal
path = 'load/omnidata/omnidata_dpt_normal_v2.ckpt'
self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
self.aug = transforms.Compose([
transforms.Resize((384, 384)),
transforms.ToTensor()
])
# load model
checkpoint = torch.load(path, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = {}
for k, v in checkpoint['state_dict'].items():
state_dict[k[6:]] = v
else:
state_dict = checkpoint
self.model.load_state_dict(state_dict)
self.model.eval().to(device)
@torch.no_grad()
def __call__(self, image):
# image: np.ndarray, uint8, [H, W, 3]
H, W = image.shape[:2]
image = Image.fromarray(image)
image = self.aug(image).unsqueeze(0).to(self.device)
if self.task == 'depth':
depth = self.model(image).clamp(0, 1)
depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)
depth = depth.squeeze(1).cpu().numpy()
return depth
else:
normal = self.model(image).clamp(0, 1)
normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)
normal = normal.cpu().numpy()
return normal
def preprocess_single_image(img_path, args):
out_dir = os.path.dirname(img_path)
out_rgba = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_rgba.png')
out_depth = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_depth.png')
out_normal = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_normal.png')
out_caption = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_caption.txt')
# load image
print(f'[INFO] loading image {img_path}...')
# skip if the outputs already exist
if os.path.isfile(out_rgba) and os.path.isfile(out_depth) and os.path.isfile(out_normal):
print(f"{img_path} has already been processed, skipping.")
return
print(img_path)
image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
carved_image = None
# debug
if image.shape[-1] == 4:
if args.do_rm_bg_force:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
else:
carved_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
else:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if args.do_seg:
if carved_image is None:
# carve background
print(f'[INFO] background removal...')
carved_image = BackgroundRemoval()(image) # [H, W, 4]
mask = carved_image[..., -1] > 0
# predict depth
print(f'[INFO] depth estimation...')
dpt_depth_model = DPT(task='depth')
depth = dpt_depth_model(image)[0]
depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9)
depth[~mask] = 0
depth = (depth * 255).astype(np.uint8)
del dpt_depth_model
# predict normal
print(f'[INFO] normal estimation...')
dpt_normal_model = DPT(task='normal')
normal = dpt_normal_model(image)[0]
normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0)
normal[~mask] = 0
del dpt_normal_model
# NOTE: recentering is force-disabled here, overriding the --recenter CLI flag
opt.recenter = False
# recenter
if opt.recenter:
print(f'[INFO] recenter...')
final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8)
final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8)
final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8)
coords = np.nonzero(mask)
x_min, x_max = coords[0].min(), coords[0].max()
y_min, y_max = coords[1].min(), coords[1].max()
h = x_max - x_min
w = y_max - y_min
desired_size = int(opt.size * (1 - opt.border_ratio))
scale = desired_size / max(h, w)
h2 = int(h * scale)
w2 = int(w * scale)
x2_min = (opt.size - h2) // 2
x2_max = x2_min + h2
y2_min = (opt.size - w2) // 2
y2_max = y2_min + w2
final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
else:
final_rgba = carved_image
final_depth = depth
final_normal = normal
# write output
cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA))
cv2.imwrite(out_depth, final_depth)
cv2.imwrite(out_normal, final_normal)
if opt.do_caption:
# predict caption (it's too slow... use your brain instead)
print(f'[INFO] captioning...')
blip2 = BLIP2()
caption = blip2(image)
with open(out_caption, 'w') as f:
f.write(caption)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)")
parser.add_argument('--size', default=1024, type=int, help="output resolution")
parser.add_argument('--border_ratio', default=0.1, type=float, help="output border ratio")
parser.add_argument('--recenter', type=bool, default=False, help="recenter, potentially not helpful for multiview zero123")
parser.add_argument('--dont_recenter', dest='recenter', action='store_false')
parser.add_argument('--do_caption', type=bool, default=False, help="do text captioning")
parser.add_argument('--do_seg', type=bool, default=True)
parser.add_argument('--do_rm_bg_force', type=bool, default=False)
opt = parser.parse_args()
if os.path.isdir(opt.path):
img_list = sorted(os.path.join(root, fname) for root, _dirs, files in os.walk(opt.path) for fname in files)
img_list = [img for img in img_list if not img.endswith("rgba.png") and not img.endswith("depth.png") and not img.endswith("normal.png")]
img_list = [img for img in img_list if img.endswith(".png")]
for img in img_list:
# try:
preprocess_single_image(img, opt)
# except:
# with open("preprocess_images_invalid.txt", "a") as f:
# print(img, file=f)
else: # single image file
preprocess_single_image(opt.path, opt)
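Note: the recentering branch in preprocess_single_image crops the object's bounding box, scales its longer side to size * (1 - border_ratio) pixels, and pastes it into the center of a size x size canvas. The same arithmetic as a stand-alone sketch (hypothetical helper, for illustration only):

def recenter_box(h: int, w: int, size: int = 1024, border_ratio: float = 0.1):
    # returns the destination slice (x_min, x_max, y_min, y_max) on the canvas
    # and the resized object size (h2, w2)
    desired = int(size * (1 - border_ratio))
    scale = desired / max(h, w)
    h2, w2 = int(h * scale), int(w * scale)
    x2_min = (size - h2) // 2
    y2_min = (size - w2) // 2
    return (x2_min, x2_min + h2, y2_min, y2_min + w2), (h2, w2)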

35
requirements.txt Normal file
View File

@ -0,0 +1,35 @@
lightning==2.0.0
omegaconf==2.3.0
jaxtyping
typeguard
diffusers<=0.23.0
transformers
accelerate
opencv-python
tensorboard
matplotlib
imageio>=2.28.0
imageio[ffmpeg]
libigl
xatlas
trimesh[easy]
networkx
pysdf
PyMCubes
wandb
gradio
# deepfloyd
xformers
bitsandbytes
sentencepiece
safetensors
huggingface_hub
# for zero123
einops
kornia
taming-transformers-rom1504
#controlnet
controlnet_aux

36
threestudio/__init__.py Normal file
View File

@ -0,0 +1,36 @@
__modules__ = {}
def register(name):
def decorator(cls):
__modules__[name] = cls
return cls
return decorator
def find(name):
return __modules__[name]
### syntactic sugar for logging utilities ###
import logging
logger = logging.getLogger("pytorch_lightning")
from pytorch_lightning.utilities.rank_zero import (
rank_zero_debug,
rank_zero_info,
rank_zero_only,
)
debug = rank_zero_debug
info = rank_zero_info
@rank_zero_only
def warn(*args, **kwargs):
logger.warning(*args, **kwargs)
from . import data, models, systems
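Note: register/find above implement a simple string-keyed class registry; launch.py resolves cfg.data_type and cfg.system_type through it. A hypothetical usage sketch using the register/find defined above (the name "toy-datamodule" is illustrative only, not a real threestudio module):

@register("toy-datamodule")
class ToyDataModule:
    def __init__(self, cfg):
        self.cfg = cfg

# same lookup pattern as threestudio.find(cfg.data_type)(cfg.data) in launch.py
dm = find("toy-datamodule")({"batch_size": 1})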

1
threestudio/data/__init__.py Normal file
View File

@ -0,0 +1 @@
from . import image, uncond

351
threestudio/data/image.py Normal file
View File

@ -0,0 +1,351 @@
import bisect
import math
import os
from dataclasses import dataclass, field
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
import threestudio
from threestudio import register
from threestudio.data.uncond import (
RandomCameraDataModuleConfig,
RandomCameraDataset,
RandomCameraIterableDataset,
)
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import (
get_mvp_matrix,
get_projection_matrix,
get_ray_directions,
get_rays,
)
from threestudio.utils.typing import *
@dataclass
class SingleImageDataModuleConfig:
# height and width should be Union[int, List[int]]
# but OmegaConf does not support Union of containers
height: Any = 96
width: Any = 96
resolution_milestones: List[int] = field(default_factory=lambda: [])
default_elevation_deg: float = 0.0
default_azimuth_deg: float = -180.0
default_camera_distance: float = 1.2
default_fovy_deg: float = 60.0
image_path: str = ""
use_random_camera: bool = True
random_camera: dict = field(default_factory=dict)
rays_noise_scale: float = 2e-3
batch_size: int = 1
requires_depth: bool = False
requires_normal: bool = False
rays_d_normalize: bool = True
use_mixed_camera_config: bool = False
class SingleImageDataBase:
def setup(self, cfg, split):
self.split = split
self.rank = get_rank()
self.cfg: SingleImageDataModuleConfig = cfg
if self.cfg.use_random_camera:
random_camera_cfg = parse_structured(
RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
)
# FIXME:
if self.cfg.use_mixed_camera_config:
if self.rank % 2 == 0:
random_camera_cfg.camera_distance_range=[self.cfg.default_camera_distance, self.cfg.default_camera_distance]
random_camera_cfg.fovy_range=[self.cfg.default_fovy_deg, self.cfg.default_fovy_deg]
self.fixed_camera_intrinsic = True
else:
self.fixed_camera_intrinsic = False
if split == "train":
self.random_pose_generator = RandomCameraIterableDataset(
random_camera_cfg
)
else:
self.random_pose_generator = RandomCameraDataset(
random_camera_cfg, split
)
elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])
elevation = elevation_deg * math.pi / 180
azimuth = azimuth_deg * math.pi / 180
camera_position: Float[Tensor, "1 3"] = torch.stack(
[
camera_distance * torch.cos(elevation) * torch.cos(azimuth),
camera_distance * torch.cos(elevation) * torch.sin(azimuth),
camera_distance * torch.sin(elevation),
],
dim=-1,
)
center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None]
light_position: Float[Tensor, "1 3"] = camera_position
lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
up = F.normalize(torch.cross(right, lookat), dim=-1)
self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
[torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
dim=-1,
)
self.c2w4x4: Float[Tensor, "B 4 4"] = torch.cat(
[self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1
)
self.c2w4x4[:, 3, 3] = 1.0
self.camera_position = camera_position
self.light_position = light_position
self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
self.camera_distance = camera_distance
self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))
self.heights: List[int] = (
[self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
)
self.widths: List[int] = (
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
)
assert len(self.heights) == len(self.widths)
self.resolution_milestones: List[int]
if len(self.heights) == 1 and len(self.widths) == 1:
if len(self.cfg.resolution_milestones) > 0:
threestudio.warn(
"Ignoring resolution_milestones since height and width are not changing"
)
self.resolution_milestones = [-1]
else:
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
self.resolution_milestones = [-1] + self.cfg.resolution_milestones
self.directions_unit_focals = [
get_ray_directions(H=height, W=width, focal=1.0)
for (height, width) in zip(self.heights, self.widths)
]
self.focal_lengths = [
0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
]
self.height: int = self.heights[0]
self.width: int = self.widths[0]
self.directions_unit_focal = self.directions_unit_focals[0]
self.focal_length = self.focal_lengths[0]
self.set_rays()
self.load_images()
self.prev_height = self.height
def set_rays(self):
# get directions by dividing directions_unit_focal by focal length
directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None]
directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length
rays_o, rays_d = get_rays(
directions,
self.c2w,
keepdim=True,
noise_scale=self.cfg.rays_noise_scale,
normalize=self.cfg.rays_d_normalize,
)
proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix(
self.fovy, self.width / self.height, 0.01, 100.0
) # FIXME: hard-coded near and far
mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx)
self.rays_o, self.rays_d = rays_o, rays_d
self.mvp_mtx = mvp_mtx
def load_images(self):
# load image
assert os.path.exists(
self.cfg.image_path
), f"Could not find image {self.cfg.image_path}!"
rgba = cv2.cvtColor(
cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
)
rgba = (
cv2.resize(
rgba, (self.width, self.height), interpolation=cv2.INTER_AREA
).astype(np.float32)
/ 255.0
)
rgb = rgba[..., :3]
self.rgb: Float[Tensor, "1 H W 3"] = (
torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank)
)
self.mask: Float[Tensor, "1 H W 1"] = (
torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank)
)
print(
f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}"
)
# load depth
if self.cfg.requires_depth:
depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png")
assert os.path.exists(depth_path)
depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
depth = cv2.resize(
depth, (self.width, self.height), interpolation=cv2.INTER_AREA
)
self.depth: Float[Tensor, "1 H W 1"] = (
torch.from_numpy(depth.astype(np.float32) / 255.0)
.unsqueeze(0)
.to(self.rank)
)
print(
f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}"
)
else:
self.depth = None
# load normal
if self.cfg.requires_normal:
normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png")
assert os.path.exists(normal_path)
normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
normal = cv2.resize(
normal, (self.width, self.height), interpolation=cv2.INTER_AREA
)
self.normal: Float[Tensor, "1 H W 3"] = (
torch.from_numpy(normal.astype(np.float32) / 255.0)
.unsqueeze(0)
.to(self.rank)
)
print(
f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}"
)
else:
self.normal = None
def get_all_images(self):
return self.rgb
def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False):
size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
self.height = self.heights[size_ind]
if self.height == self.prev_height:
return
self.prev_height = self.height
self.width = self.widths[size_ind]
self.directions_unit_focal = self.directions_unit_focals[size_ind]
self.focal_length = self.focal_lengths[size_ind]
threestudio.debug(f"Training height: {self.height}, width: {self.width}")
self.set_rays()
self.load_images()
class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable):
def __init__(self, cfg: Any, split: str) -> None:
super().__init__()
self.setup(cfg, split)
def collate(self, batch) -> Dict[str, Any]:
batch = {
"rays_o": self.rays_o,
"rays_d": self.rays_d,
"mvp_mtx": self.mvp_mtx,
"camera_positions": self.camera_position,
"light_positions": self.light_position,
"elevation": self.elevation_deg,
"azimuth": self.azimuth_deg,
"camera_distances": self.camera_distance,
"rgb": self.rgb,
"ref_depth": self.depth,
"ref_normal": self.normal,
"mask": self.mask,
"height": self.cfg.height,
"width": self.cfg.width,
"c2w": self.c2w,
"c2w4x4": self.c2w4x4,
}
if self.cfg.use_random_camera:
batch["random_camera"] = self.random_pose_generator.collate(None)
return batch
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
self.update_step_(epoch, global_step, on_load_weights)
self.random_pose_generator.update_step(epoch, global_step, on_load_weights)
def __iter__(self):
while True:
yield {}
class SingleImageDataset(Dataset, SingleImageDataBase):
def __init__(self, cfg: Any, split: str) -> None:
super().__init__()
self.setup(cfg, split)
def __len__(self):
return len(self.random_pose_generator)
def __getitem__(self, index):
batch = self.random_pose_generator[index]
batch.update(
{
"height": self.random_pose_generator.cfg.eval_height,
"width": self.random_pose_generator.cfg.eval_width,
"mvp_mtx_ref": self.mvp_mtx[0],
"c2w_ref": self.c2w4x4,
}
)
return batch
@register("single-image-datamodule")
class SingleImageDataModule(pl.LightningDataModule):
cfg: SingleImageDataModuleConfig
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
super().__init__()
self.cfg = parse_structured(SingleImageDataModuleConfig, cfg)
def setup(self, stage=None) -> None:
if stage in [None, "fit"]:
self.train_dataset = SingleImageIterableDataset(self.cfg, "train")
if stage in [None, "fit", "validate"]:
self.val_dataset = SingleImageDataset(self.cfg, "val")
if stage in [None, "test", "predict"]:
self.test_dataset = SingleImageDataset(self.cfg, "test")
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
return DataLoader(
dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn
)
def train_dataloader(self) -> DataLoader:
return self.general_loader(
self.train_dataset,
batch_size=self.cfg.batch_size,
collate_fn=self.train_dataset.collate,
)
def val_dataloader(self) -> DataLoader:
return self.general_loader(self.val_dataset, batch_size=1)
def test_dataloader(self) -> DataLoader:
return self.general_loader(self.test_dataset, batch_size=1)
def predict_dataloader(self) -> DataLoader:
return self.general_loader(self.test_dataset, batch_size=1)
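Note: SingleImageDataBase.setup builds the reference camera by converting (elevation, azimuth, distance) to a position and assembling a look-at frame with z as the up axis. The construction in isolation, as a minimal sketch matching the code above (the function name is illustrative):

import math

import torch
import torch.nn.functional as F

def lookat_c2w(elev_deg: float, azim_deg: float, dist: float) -> torch.Tensor:
    e, a = math.radians(elev_deg), math.radians(azim_deg)
    pos = torch.tensor(
        [[dist * math.cos(e) * math.cos(a), dist * math.cos(e) * math.sin(a), dist * math.sin(e)]]
    )
    up = torch.tensor([[0.0, 0.0, 1.0]])
    lookat = F.normalize(-pos, dim=-1)  # camera looks at the origin
    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
    # columns: right, up, -lookat; last column: camera position -> [1, 3, 4]
    return torch.cat([torch.stack([right, up, -lookat], dim=-1), pos[:, :, None]], dim=-1)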

351
threestudio/data/images.py Normal file
View File

@ -0,0 +1,351 @@
import bisect
import math
import os
from dataclasses import dataclass, field
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
import threestudio
from threestudio import register
from threestudio.data.uncond import (
RandomCameraDataModuleConfig,
RandomCameraDataset,
RandomCameraIterableDataset,
)
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import (
get_mvp_matrix,
get_projection_matrix,
get_ray_directions,
get_rays,
)
from threestudio.utils.typing import *
@dataclass
class SingleImageDataModuleConfig:
# height and width should be Union[int, List[int]]
# but OmegaConf does not support Union of containers
height: Any = 96
width: Any = 96
resolution_milestones: List[int] = field(default_factory=lambda: [])
default_elevation_deg: float = 0.0
default_azimuth_deg: float = -180.0
default_camera_distance: float = 1.2
default_fovy_deg: float = 60.0
image_path: str = ""
use_random_camera: bool = True
random_camera: dict = field(default_factory=dict)
rays_noise_scale: float = 2e-3
batch_size: int = 1
requires_depth: bool = False
requires_normal: bool = False
rays_d_normalize: bool = True
use_mixed_camera_config: bool = False
class SingleImageDataBase:
def setup(self, cfg, split):
self.split = split
self.rank = get_rank()
self.cfg: SingleImageDataModuleConfig = cfg
if self.cfg.use_random_camera:
random_camera_cfg = parse_structured(
RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
)
# FIXME:
if self.cfg.use_mixed_camera_config:
if self.rank % 2 == 0:
random_camera_cfg.camera_distance_range=[self.cfg.default_camera_distance, self.cfg.default_camera_distance]
random_camera_cfg.fovy_range=[self.cfg.default_fovy_deg, self.cfg.default_fovy_deg]
self.fixed_camera_intrinsic = True
else:
self.fixed_camera_intrinsic = False
if split == "train":
self.random_pose_generator = RandomCameraIterableDataset(
random_camera_cfg
)
else:
self.random_pose_generator = RandomCameraDataset(
random_camera_cfg, split
)
elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])
elevation = elevation_deg * math.pi / 180
azimuth = azimuth_deg * math.pi / 180
camera_position: Float[Tensor, "1 3"] = torch.stack(
[
camera_distance * torch.cos(elevation) * torch.cos(azimuth),
camera_distance * torch.cos(elevation) * torch.sin(azimuth),
camera_distance * torch.sin(elevation),
],
dim=-1,
)
center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None]
light_position: Float[Tensor, "1 3"] = camera_position
lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
up = F.normalize(torch.cross(right, lookat), dim=-1)
self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
[torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
dim=-1,
)
self.c2w4x4: Float[Tensor, "B 4 4"] = torch.cat(
[self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1
)
self.c2w4x4[:, 3, 3] = 1.0
self.camera_position = camera_position
self.light_position = light_position
self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
self.camera_distance = camera_distance
self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))
self.heights: List[int] = (
[self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
)
self.widths: List[int] = (
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
)
assert len(self.heights) == len(self.widths)
self.resolution_milestones: List[int]
if len(self.heights) == 1 and len(self.widths) == 1:
if len(self.cfg.resolution_milestones) > 0:
threestudio.warn(
"Ignoring resolution_milestones since height and width are not changing"
)
self.resolution_milestones = [-1]
else:
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
self.resolution_milestones = [-1] + self.cfg.resolution_milestones
self.directions_unit_focals = [
get_ray_directions(H=height, W=width, focal=1.0)
for (height, width) in zip(self.heights, self.widths)
]
self.focal_lengths = [
0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
]
self.height: int = self.heights[0]
self.width: int = self.widths[0]
self.directions_unit_focal = self.directions_unit_focals[0]
self.focal_length = self.focal_lengths[0]
self.set_rays()
self.load_images()
self.prev_height = self.height
def set_rays(self):
# get directions by dividing directions_unit_focal by focal length
directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None]
directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length
rays_o, rays_d = get_rays(
directions,
self.c2w,
keepdim=True,
noise_scale=self.cfg.rays_noise_scale,
normalize=self.cfg.rays_d_normalize,
)
proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix(
self.fovy, self.width / self.height, 0.01, 100.0
) # FIXME: hard-coded near and far
mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx)
self.rays_o, self.rays_d = rays_o, rays_d
self.mvp_mtx = mvp_mtx
def load_images(self):
# load image
assert os.path.exists(
self.cfg.image_path
), f"Could not find image {self.cfg.image_path}!"
rgba = cv2.cvtColor(
cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
)
rgba = (
cv2.resize(
rgba, (self.width, self.height), interpolation=cv2.INTER_AREA
).astype(np.float32)
/ 255.0
)
rgb = rgba[..., :3]
self.rgb: Float[Tensor, "1 H W 3"] = (
torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank)
)
self.mask: Float[Tensor, "1 H W 1"] = (
torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank)
)
print(
f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}"
)
# load depth
if self.cfg.requires_depth:
depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png")
assert os.path.exists(depth_path)
depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
depth = cv2.resize(
depth, (self.width, self.height), interpolation=cv2.INTER_AREA
)
self.depth: Float[Tensor, "1 H W 1"] = (
torch.from_numpy(depth.astype(np.float32) / 255.0)
.unsqueeze(0)
.to(self.rank)
)
print(
f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}"
)
else:
self.depth = None
# load normal
if self.cfg.requires_normal:
normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png")
assert os.path.exists(normal_path)
normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
normal = cv2.resize(
normal, (self.width, self.height), interpolation=cv2.INTER_AREA
)
self.normal: Float[Tensor, "1 H W 3"] = (
torch.from_numpy(normal.astype(np.float32) / 255.0)
.unsqueeze(0)
.to(self.rank)
)
print(
f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}"
)
else:
self.normal = None
def get_all_images(self):
return self.rgb
def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False):
size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
self.height = self.heights[size_ind]
if self.height == self.prev_height:
return
self.prev_height = self.height
self.width = self.widths[size_ind]
self.directions_unit_focal = self.directions_unit_focals[size_ind]
self.focal_length = self.focal_lengths[size_ind]
threestudio.debug(f"Training height: {self.height}, width: {self.width}")
self.set_rays()
self.load_images()
class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable):
def __init__(self, cfg: Any, split: str) -> None:
super().__init__()
self.setup(cfg, split)
def collate(self, batch) -> Dict[str, Any]:
batch = {
"rays_o": self.rays_o,
"rays_d": self.rays_d,
"mvp_mtx": self.mvp_mtx,
"camera_positions": self.camera_position,
"light_positions": self.light_position,
"elevation": self.elevation_deg,
"azimuth": self.azimuth_deg,
"camera_distances": self.camera_distance,
"rgb": self.rgb,
"ref_depth": self.depth,
"ref_normal": self.normal,
"mask": self.mask,
"height": self.cfg.height,
"width": self.cfg.width,
"c2w": self.c2w,
"c2w4x4": self.c2w4x4,
}
if self.cfg.use_random_camera:
batch["random_camera"] = self.random_pose_generator.collate(None)
return batch
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
self.update_step_(epoch, global_step, on_load_weights)
self.random_pose_generator.update_step(epoch, global_step, on_load_weights)
def __iter__(self):
while True:
yield {}
class SingleImageDataset(Dataset, SingleImageDataBase):
def __init__(self, cfg: Any, split: str) -> None:
super().__init__()
self.setup(cfg, split)
def __len__(self):
return len(self.random_pose_generator)
def __getitem__(self, index):
batch = self.random_pose_generator[index]
batch.update(
{
"height": self.random_pose_generator.cfg.eval_height,
"width": self.random_pose_generator.cfg.eval_width,
"mvp_mtx_ref": self.mvp_mtx[0],
"c2w_ref": self.c2w4x4,
}
)
return batch
@register("single-image-datamodule")
class SingleImageDataModule(pl.LightningDataModule):
cfg: SingleImageDataModuleConfig
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
super().__init__()
self.cfg = parse_structured(SingleImageDataModuleConfig, cfg)
def setup(self, stage=None) -> None:
if stage in [None, "fit"]:
self.train_dataset = SingleImageIterableDataset(self.cfg, "train")
if stage in [None, "fit", "validate"]:
self.val_dataset = SingleImageDataset(self.cfg, "val")
if stage in [None, "test", "predict"]:
self.test_dataset = SingleImageDataset(self.cfg, "test")
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
return DataLoader(
dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn
)
def train_dataloader(self) -> DataLoader:
return self.general_loader(
self.train_dataset,
batch_size=self.cfg.batch_size,
collate_fn=self.train_dataset.collate,
)
def val_dataloader(self) -> DataLoader:
return self.general_loader(self.val_dataset, batch_size=1)
def test_dataloader(self) -> DataLoader:
return self.general_loader(self.test_dataset, batch_size=1)
def predict_dataloader(self) -> DataLoader:
return self.general_loader(self.test_dataset, batch_size=1)

518
threestudio/data/uncond.py Normal file
View File

@ -0,0 +1,518 @@
import bisect
import math
import random
from dataclasses import dataclass, field
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
import threestudio
from threestudio import register
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_device
from threestudio.utils.ops import (
get_full_projection_matrix,
get_mvp_matrix,
get_projection_matrix,
get_ray_directions,
get_rays,
)
from threestudio.utils.typing import *
@dataclass
class RandomCameraDataModuleConfig:
# height, width, and batch_size should be Union[int, List[int]]
# but OmegaConf does not support Union of containers
height: Any = 64
width: Any = 64
batch_size: Any = 1
resolution_milestones: List[int] = field(default_factory=lambda: [])
eval_height: int = 512
eval_width: int = 512
eval_batch_size: int = 1
n_val_views: int = 1
n_test_views: int = 120
elevation_range: Tuple[float, float] = (-10, 90)
azimuth_range: Tuple[float, float] = (-180, 180)
camera_distance_range: Tuple[float, float] = (1, 1.5)
fovy_range: Tuple[float, float] = (
40,
70,
) # in degrees, in vertical direction (along height)
camera_perturb: float = 0.1
center_perturb: float = 0.2
up_perturb: float = 0.02
light_position_perturb: float = 1.0
light_distance_range: Tuple[float, float] = (0.8, 1.5)
eval_elevation_deg: float = 15.0
eval_camera_distance: float = 1.5
eval_fovy_deg: float = 70.0
light_sample_strategy: str = "dreamfusion"
batch_uniform_azimuth: bool = True
progressive_until: int = 0 # progressive ranges for elevation, azimuth, r, fovy
rays_d_normalize: bool = True
class RandomCameraIterableDataset(IterableDataset, Updateable):
def __init__(self, cfg: Any) -> None:
super().__init__()
self.cfg: RandomCameraDataModuleConfig = cfg
self.heights: List[int] = (
[self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
)
self.widths: List[int] = (
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
)
self.batch_sizes: List[int] = (
[self.cfg.batch_size]
if isinstance(self.cfg.batch_size, int)
else self.cfg.batch_size
)
assert len(self.heights) == len(self.widths) == len(self.batch_sizes)
self.resolution_milestones: List[int]
if (
len(self.heights) == 1
and len(self.widths) == 1
and len(self.batch_sizes) == 1
):
if len(self.cfg.resolution_milestones) > 0:
threestudio.warn(
"Ignoring resolution_milestones since height and width are not changing"
)
self.resolution_milestones = [-1]
else:
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
self.resolution_milestones = [-1] + self.cfg.resolution_milestones
self.directions_unit_focals = [
get_ray_directions(H=height, W=width, focal=1.0)
for (height, width) in zip(self.heights, self.widths)
]
self.height: int = self.heights[0]
self.width: int = self.widths[0]
self.batch_size: int = self.batch_sizes[0]
self.directions_unit_focal = self.directions_unit_focals[0]
self.elevation_range = self.cfg.elevation_range
self.azimuth_range = self.cfg.azimuth_range
self.camera_distance_range = self.cfg.camera_distance_range
self.fovy_range = self.cfg.fovy_range
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
self.height = self.heights[size_ind]
self.width = self.widths[size_ind]
self.batch_size = self.batch_sizes[size_ind]
self.directions_unit_focal = self.directions_unit_focals[size_ind]
threestudio.debug(
f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}"
)
# progressive view
self.progressive_view(global_step)
def __iter__(self):
while True:
yield {}
def progressive_view(self, global_step):
r = min(1.0, global_step / (self.cfg.progressive_until + 1))
self.elevation_range = [
(1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[0],
(1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[1],
]
self.azimuth_range = [
(1 - r) * 0.0 + r * self.cfg.azimuth_range[0],
(1 - r) * 0.0 + r * self.cfg.azimuth_range[1],
]
# self.camera_distance_range = [
# (1 - r) * self.cfg.eval_camera_distance
# + r * self.cfg.camera_distance_range[0],
# (1 - r) * self.cfg.eval_camera_distance
# + r * self.cfg.camera_distance_range[1],
# ]
# self.fovy_range = [
# (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[0],
# (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[1],
# ]
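    # --- Illustrative sketch (standalone, not part of this class) ---------------------
    # The annealing rule used in progressive_view above: each sampling range is widened
    # linearly from a single evaluation value to its full extent as global_step
    # approaches progressive_until. A plain-function version with example values:
    #
    # def _progressive_range(step, until, center, full):
    #     r = min(1.0, step / (until + 1))
    #     return ((1 - r) * center + r * full[0], (1 - r) * center + r * full[1])
    #
    # e.g. progressive_until=1000, eval_elevation_deg=15.0, elevation_range=(-10, 90):
    #   _progressive_range(0,    1000, 15.0, (-10, 90)) -> (15.0, 15.0)   fixed at eval value
    #   _progressive_range(500,  1000, 15.0, (-10, 90)) -> (~2.5, ~52.5)  half-way widened
    #   _progressive_range(1001, 1000, 15.0, (-10, 90)) -> (-10.0, 90.0)  full range reached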
def collate(self, batch) -> Dict[str, Any]:
# sample elevation angles
elevation_deg: Float[Tensor, "B"]
elevation: Float[Tensor, "B"]
if random.random() < 0.5:
            # with probability 0.5, sample elevation angles uniformly in degrees
            # (this biases samples towards the poles)
elevation_deg = (
torch.rand(self.batch_size)
* (self.elevation_range[1] - self.elevation_range[0])
+ self.elevation_range[0]
)
elevation = elevation_deg * math.pi / 180
else:
# otherwise sample uniformly on sphere
elevation_range_percent = [
self.elevation_range[0] / 180.0 * math.pi,
self.elevation_range[1] / 180.0 * math.pi,
]
# inverse transform sampling
elevation = torch.asin(
(
torch.rand(self.batch_size)
* (
math.sin(elevation_range_percent[1])
- math.sin(elevation_range_percent[0])
)
+ math.sin(elevation_range_percent[0])
)
)
elevation_deg = elevation / math.pi * 180.0
# sample azimuth angles from a uniform distribution bounded by azimuth_range
azimuth_deg: Float[Tensor, "B"]
if self.cfg.batch_uniform_azimuth:
# ensures sampled azimuth angles in a batch cover the whole range
azimuth_deg = (
torch.rand(self.batch_size) + torch.arange(self.batch_size)
) / self.batch_size * (
self.azimuth_range[1] - self.azimuth_range[0]
) + self.azimuth_range[
0
]
else:
# simple random sampling
azimuth_deg = (
torch.rand(self.batch_size)
* (self.azimuth_range[1] - self.azimuth_range[0])
+ self.azimuth_range[0]
)
azimuth = azimuth_deg * math.pi / 180
# sample distances from a uniform distribution bounded by distance_range
camera_distances: Float[Tensor, "B"] = (
torch.rand(self.batch_size)
* (self.camera_distance_range[1] - self.camera_distance_range[0])
+ self.camera_distance_range[0]
)
# convert spherical coordinates to cartesian coordinates
# right hand coordinate system, x back, y right, z up
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
camera_positions: Float[Tensor, "B 3"] = torch.stack(
[
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
camera_distances * torch.sin(elevation),
],
dim=-1,
)
# default scene center at origin
center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
# default camera up direction as +z
up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
None, :
].repeat(self.batch_size, 1)
# sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb]
camera_perturb: Float[Tensor, "B 3"] = (
torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb
- self.cfg.camera_perturb
)
camera_positions = camera_positions + camera_perturb
# sample center perturbations from a normal distribution with mean 0 and std center_perturb
center_perturb: Float[Tensor, "B 3"] = (
torch.randn(self.batch_size, 3) * self.cfg.center_perturb
)
center = center + center_perturb
# sample up perturbations from a normal distribution with mean 0 and std up_perturb
up_perturb: Float[Tensor, "B 3"] = (
torch.randn(self.batch_size, 3) * self.cfg.up_perturb
)
up = up + up_perturb
# sample fovs from a uniform distribution bounded by fov_range
fovy_deg: Float[Tensor, "B"] = (
torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0])
+ self.fovy_range[0]
)
fovy = fovy_deg * math.pi / 180
# sample light distance from a uniform distribution bounded by light_distance_range
light_distances: Float[Tensor, "B"] = (
torch.rand(self.batch_size)
* (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0])
+ self.cfg.light_distance_range[0]
)
if self.cfg.light_sample_strategy == "dreamfusion":
# sample light direction from a normal distribution with mean camera_position and std light_position_perturb
light_direction: Float[Tensor, "B 3"] = F.normalize(
camera_positions
+ torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb,
dim=-1,
)
# get light position by scaling light direction by light distance
light_positions: Float[Tensor, "B 3"] = (
light_direction * light_distances[:, None]
)
elif self.cfg.light_sample_strategy == "magic3d":
# sample light direction within restricted angle range (pi/3)
local_z = F.normalize(camera_positions, dim=-1)
local_x = F.normalize(
torch.stack(
[local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])],
dim=-1,
),
dim=-1,
)
local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)
rot = torch.stack([local_x, local_y, local_z], dim=-1)
light_azimuth = (
torch.rand(self.batch_size) * math.pi * 2 - math.pi
) # [-pi, pi]
light_elevation = (
torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6
) # [pi/6, pi/2]
light_positions_local = torch.stack(
[
light_distances
* torch.cos(light_elevation)
* torch.cos(light_azimuth),
light_distances
* torch.cos(light_elevation)
* torch.sin(light_azimuth),
light_distances * torch.sin(light_elevation),
],
dim=-1,
)
light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0]
else:
raise ValueError(
f"Unknown light sample strategy: {self.cfg.light_sample_strategy}"
)
lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(
            torch.cross(lookat, up, dim=-1), dim=-1
        )
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
dim=-1,
)
c2w: Float[Tensor, "B 4 4"] = torch.cat(
[c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
)
c2w[:, 3, 3] = 1.0
# get directions by dividing directions_unit_focal by focal length
focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy)
directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[
None, :, :, :
].repeat(self.batch_size, 1, 1, 1)
directions[:, :, :, :2] = (
directions[:, :, :, :2] / focal_length[:, None, None, None]
)
        # Important note: the returned rays_d MUST be normalized!
rays_o, rays_d = get_rays(
directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize
)
self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
fovy, self.width / self.height, 0.1, 1000.0
) # FIXME: hard-coded near and far
mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx)
self.fovy = fovy
return {
"rays_o": rays_o,
"rays_d": rays_d,
"mvp_mtx": mvp_mtx,
"camera_positions": camera_positions,
"c2w": c2w,
"light_positions": light_positions,
"elevation": elevation_deg,
"azimuth": azimuth_deg,
"camera_distances": camera_distances,
"height": self.height,
"width": self.width,
"fovy": self.fovy,
"proj_mtx": self.proj_mtx,
}
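    # --- Illustrative sketch (standalone, not part of this class) ---------------------
    # The "uniform on sphere" elevation branch above is inverse transform sampling:
    # drawing sin(elevation) uniformly between sin(lo) and sin(hi) makes z = sin(elevation)
    # uniform, which is the area-preserving parameterization of a spherical band.
    #
    # def sample_elevation_on_sphere(n, lo_deg, hi_deg):
    #     lo, hi = math.radians(lo_deg), math.radians(hi_deg)
    #     u = torch.rand(n)
    #     return torch.asin(u * (math.sin(hi) - math.sin(lo)) + math.sin(lo))
    #
    # quick check for the default range (-10, 90): z = sin(elevation) is uniform, so
    # torch.sin(sample_elevation_on_sphere(100_000, -10.0, 90.0)).mean()
    # is approximately (math.sin(math.radians(-10.0)) + 1.0) / 2 ~= 0.413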
class RandomCameraDataset(Dataset):
def __init__(self, cfg: Any, split: str) -> None:
super().__init__()
self.cfg: RandomCameraDataModuleConfig = cfg
self.split = split
if split == "val":
self.n_views = self.cfg.n_val_views
else:
self.n_views = self.cfg.n_test_views
azimuth_deg: Float[Tensor, "B"]
if self.split == "val":
# make sure the first and last view are not the same
azimuth_deg = torch.linspace(0, 360.0, self.n_views + 1)[: self.n_views]
else:
azimuth_deg = torch.linspace(0, 360.0, self.n_views)
elevation_deg: Float[Tensor, "B"] = torch.full_like(
azimuth_deg, self.cfg.eval_elevation_deg
)
camera_distances: Float[Tensor, "B"] = torch.full_like(
elevation_deg, self.cfg.eval_camera_distance
)
elevation = elevation_deg * math.pi / 180
azimuth = azimuth_deg * math.pi / 180
# convert spherical coordinates to cartesian coordinates
# right hand coordinate system, x back, y right, z up
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
camera_positions: Float[Tensor, "B 3"] = torch.stack(
[
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
camera_distances * torch.sin(elevation),
],
dim=-1,
)
# default scene center at origin
center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
# default camera up direction as +z
up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
None, :
].repeat(self.cfg.eval_batch_size, 1)
fovy_deg: Float[Tensor, "B"] = torch.full_like(
elevation_deg, self.cfg.eval_fovy_deg
)
fovy = fovy_deg * math.pi / 180
light_positions: Float[Tensor, "B 3"] = camera_positions
lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(
            torch.cross(lookat, up, dim=-1), dim=-1
        )
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
dim=-1,
)
c2w: Float[Tensor, "B 4 4"] = torch.cat(
[c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
)
c2w[:, 3, 3] = 1.0
# get directions by dividing directions_unit_focal by focal length
focal_length: Float[Tensor, "B"] = (
0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
)
directions_unit_focal = get_ray_directions(
H=self.cfg.eval_height, W=self.cfg.eval_width, focal=1.0
)
directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
None, :, :, :
].repeat(self.n_views, 1, 1, 1)
directions[:, :, :, :2] = (
directions[:, :, :, :2] / focal_length[:, None, None, None]
)
rays_o, rays_d = get_rays(
directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize
)
self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
fovy, self.cfg.eval_width / self.cfg.eval_height, 0.1, 1000.0
) # FIXME: hard-coded near and far
mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx)
self.rays_o, self.rays_d = rays_o, rays_d
self.mvp_mtx = mvp_mtx
self.c2w = c2w
self.camera_positions = camera_positions
self.light_positions = light_positions
self.elevation, self.azimuth = elevation, azimuth
self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
self.camera_distances = camera_distances
self.fovy = fovy
def __len__(self):
return self.n_views
def __getitem__(self, index):
return {
"index": index,
"rays_o": self.rays_o[index],
"rays_d": self.rays_d[index],
"mvp_mtx": self.mvp_mtx[index],
"c2w": self.c2w[index],
"camera_positions": self.camera_positions[index],
"light_positions": self.light_positions[index],
"elevation": self.elevation_deg[index],
"azimuth": self.azimuth_deg[index],
"camera_distances": self.camera_distances[index],
"height": self.cfg.eval_height,
"width": self.cfg.eval_width,
"fovy": self.fovy[index],
"proj_mtx": self.proj_mtx[index],
}
def collate(self, batch):
batch = torch.utils.data.default_collate(batch)
batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
return batch
@register("random-camera-datamodule")
class RandomCameraDataModule(pl.LightningDataModule):
cfg: RandomCameraDataModuleConfig
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
super().__init__()
self.cfg = parse_structured(RandomCameraDataModuleConfig, cfg)
def setup(self, stage=None) -> None:
if stage in [None, "fit"]:
self.train_dataset = RandomCameraIterableDataset(self.cfg)
if stage in [None, "fit", "validate"]:
self.val_dataset = RandomCameraDataset(self.cfg, "val")
if stage in [None, "test", "predict"]:
self.test_dataset = RandomCameraDataset(self.cfg, "test")
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
return DataLoader(
dataset,
# very important to disable multi-processing if you want to change self attributes at runtime!
# (for example setting self.width and self.height in update_step)
num_workers=0, # type: ignore
batch_size=batch_size,
collate_fn=collate_fn,
)
def train_dataloader(self) -> DataLoader:
return self.general_loader(
self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate
)
def val_dataloader(self) -> DataLoader:
return self.general_loader(
self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate
)
# return self.general_loader(self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate)
def test_dataloader(self) -> DataLoader:
return self.general_loader(
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
)
def predict_dataloader(self) -> DataLoader:
return self.general_loader(
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
)
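# --- Illustrative usage sketch (not part of the diff) ---------------------------------
# Exercising the random-camera pipeline above in isolation. This assumes parse_structured
# (imported at the top of this file) accepts a plain dict of overrides, as the type hints
# in RandomCameraDataModule.__init__ suggest; the shapes noted are for the values shown.
cfg = parse_structured(
    RandomCameraDataModuleConfig, {"height": 64, "width": 64, "batch_size": 4}
)
train_ds = RandomCameraIterableDataset(cfg)
batch = train_ds.collate(None)  # the `batch` argument is unused; cameras are sampled here
# batch["rays_o"], batch["rays_d"]: (4, 64, 64, 3); batch["mvp_mtx"]: (4, 4, 4)
# batch["elevation"], batch["azimuth"], batch["camera_distances"]: (4,)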

View File

@ -0,0 +1,9 @@
from . import (
background,
exporters,
geometry,
guidance,
materials,
prompt_processors,
renderers,
)

View File

@ -0,0 +1,6 @@
from . import (
base,
neural_environment_map_background,
solid_color_background,
textured_background,
)

View File

@ -0,0 +1,24 @@
import random
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.utils.base import BaseModule
from threestudio.utils.typing import *
class BaseBackground(BaseModule):
@dataclass
class Config(BaseModule.Config):
pass
cfg: Config
def configure(self):
pass
def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
raise NotImplementedError

View File

@ -0,0 +1,71 @@
import random
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *
@threestudio.register("neural-environment-map-background")
class NeuralEnvironmentMapBackground(BaseBackground):
@dataclass
class Config(BaseBackground.Config):
n_output_dims: int = 3
color_activation: str = "sigmoid"
dir_encoding_config: dict = field(
default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3}
)
mlp_network_config: dict = field(
default_factory=lambda: {
"otype": "VanillaMLP",
"activation": "ReLU",
"n_neurons": 16,
"n_hidden_layers": 2,
}
)
random_aug: bool = False
random_aug_prob: float = 0.5
eval_color: Optional[Tuple[float, float, float]] = None
# multi-view diffusion
share_aug_bg: bool = False
cfg: Config
def configure(self) -> None:
self.encoding = get_encoding(3, self.cfg.dir_encoding_config)
self.network = get_mlp(
self.encoding.n_output_dims,
self.cfg.n_output_dims,
self.cfg.mlp_network_config,
)
def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
if not self.training and self.cfg.eval_color is not None:
return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(
dirs
) * torch.as_tensor(self.cfg.eval_color).to(dirs)
# viewdirs must be normalized before passing to this function
dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1)
dirs_embd = self.encoding(dirs.view(-1, 3))
color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims)
color = get_activation(self.cfg.color_activation)(color)
if (
self.training
and self.cfg.random_aug
and random.random() < self.cfg.random_aug_prob
):
# use random background color with probability random_aug_prob
n_color = 1 if self.cfg.share_aug_bg else dirs.shape[0]
color = color * 0 + ( # prevent checking for unused parameters in DDP
torch.rand(n_color, 1, 1, self.cfg.n_output_dims)
.to(dirs)
.expand(*dirs.shape[:-1], -1)
)
return color
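# --- Illustrative sketch (standalone, not part of this module) ------------------------
# The random_aug branch above overrides the predicted color with a random constant but
# keeps `color * 0` in the expression, so gradients (zeros) still reach the background
# network and DDP does not flag its parameters as unused. A minimal pure-torch demo:
net = torch.nn.Linear(3, 3)
color = net(torch.rand(8, 3))
rand_color = torch.rand(1, 3).expand_as(color)
out = color * 0 + rand_color  # values equal rand_color, graph still connected to `net`
out.sum().backward()
print(net.weight.grad is not None)  # True: `net` stays "used" from DDP's point of view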

View File

@ -0,0 +1,51 @@
import random
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.utils.typing import *
@threestudio.register("solid-color-background")
class SolidColorBackground(BaseBackground):
@dataclass
class Config(BaseBackground.Config):
n_output_dims: int = 3
color: Tuple = (1.0, 1.0, 1.0)
learned: bool = False
random_aug: bool = False
random_aug_prob: float = 0.5
cfg: Config
def configure(self) -> None:
self.env_color: Float[Tensor, "Nc"]
if self.cfg.learned:
self.env_color = nn.Parameter(
torch.as_tensor(self.cfg.color, dtype=torch.float32)
)
else:
self.register_buffer(
"env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32)
)
def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(
dirs
) * self.env_color.to(dirs)
if (
self.training
and self.cfg.random_aug
and random.random() < self.cfg.random_aug_prob
):
# use random background color with probability random_aug_prob
color = color * 0 + ( # prevent checking for unused parameters in DDP
torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims)
.to(dirs)
.expand(*dirs.shape[:-1], -1)
)
return color

View File

@ -0,0 +1,54 @@
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *
@threestudio.register("textured-background")
class TexturedBackground(BaseBackground):
@dataclass
class Config(BaseBackground.Config):
n_output_dims: int = 3
height: int = 64
width: int = 64
color_activation: str = "sigmoid"
cfg: Config
def configure(self) -> None:
self.texture = nn.Parameter(
torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width))
)
def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]:
x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2]
xy = (x**2 + y**2) ** 0.5
u = torch.atan2(xy, z) / torch.pi
v = torch.atan2(y, x) / (torch.pi * 2) + 0.5
uv = torch.stack([u, v], -1)
return uv
def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]:
dirs_shape = dirs.shape[:-1]
uv = self.spherical_xyz_to_uv(dirs.reshape(-1, dirs.shape[-1]))
uv = 2 * uv - 1 # rescale to [-1, 1] for grid_sample
uv = uv.reshape(1, -1, 1, 2)
color = (
F.grid_sample(
self.texture,
uv,
mode="bilinear",
padding_mode="reflection",
align_corners=False,
)
.reshape(self.cfg.n_output_dims, -1)
.T.reshape(*dirs_shape, self.cfg.n_output_dims)
)
color = get_activation(self.cfg.color_activation)(color)
return color
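# --- Illustrative sketch (standalone, not part of this module) ------------------------
# spherical_xyz_to_uv above is an equirectangular parameterization: u is the polar angle
# from +z scaled to [0, 1], v is the azimuth around +z scaled to [0, 1]. A quick check of
# the formula on canonical directions (duplicated here so it runs without the class):
def _spherical_xyz_to_uv(dirs):
    x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2]
    xy = (x**2 + y**2) ** 0.5
    u = torch.atan2(xy, z) / torch.pi
    v = torch.atan2(y, x) / (torch.pi * 2) + 0.5
    return torch.stack([u, v], -1)

print(_spherical_xyz_to_uv(torch.tensor([[0.0, 0.0, 1.0]])))   # +z pole    -> u = 0.0
print(_spherical_xyz_to_uv(torch.tensor([[0.0, 0.0, -1.0]])))  # -z pole    -> u = 1.0
print(_spherical_xyz_to_uv(torch.tensor([[1.0, 0.0, 0.0]])))   # +x equator -> (0.5, 0.5)
print(_spherical_xyz_to_uv(torch.tensor([[-1.0, 0.0, 0.0]])))  # -x equator -> (0.5, 1.0)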

View File

@ -0,0 +1,118 @@
from typing import Callable, List, Optional, Tuple
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
from nerfacc.data_specs import RayIntervals
from nerfacc.estimators.base import AbstractEstimator
from nerfacc.pdf import importance_sampling, searchsorted
from nerfacc.volrend import render_transmittance_from_density
from torch import Tensor
class ImportanceEstimator(AbstractEstimator):
def __init__(
self,
) -> None:
super().__init__()
@torch.no_grad()
def sampling(
self,
prop_sigma_fns: List[Callable],
prop_samples: List[int],
num_samples: int,
# rendering options
n_rays: int,
near_plane: float,
far_plane: float,
sampling_type: Literal["uniform", "lindisp"] = "uniform",
# training options
stratified: bool = False,
requires_grad: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Sampling with CDFs from proposal networks.
Args:
prop_sigma_fns: Proposal network evaluate functions. It should be a list
of functions that take in samples {t_starts (n_rays, n_samples),
t_ends (n_rays, n_samples)} and returns the post-activation densities
(n_rays, n_samples).
prop_samples: Number of samples to draw from each proposal network. Should
be the same length as `prop_sigma_fns`.
num_samples: Number of samples to draw in the end.
n_rays: Number of rays.
near_plane: Near plane.
far_plane: Far plane.
            sampling_type: Sampling type. Either "uniform" or "lindisp". Default to
                "uniform".
stratified: Whether to use stratified sampling. Default to `False`.
Returns:
A tuple of {Tensor, Tensor}:
- **t_starts**: The starts of the samples. Shape (n_rays, num_samples).
- **t_ends**: The ends of the samples. Shape (n_rays, num_samples).
"""
assert len(prop_sigma_fns) == len(prop_samples), (
"The number of proposal networks and the number of samples "
"should be the same."
)
cdfs = torch.cat(
[
torch.zeros((n_rays, 1), device=self.device),
torch.ones((n_rays, 1), device=self.device),
],
dim=-1,
)
intervals = RayIntervals(vals=cdfs)
for level_fn, level_samples in zip(prop_sigma_fns, prop_samples):
intervals, _ = importance_sampling(
intervals, cdfs, level_samples, stratified
)
t_vals = _transform_stot(
sampling_type, intervals.vals, near_plane, far_plane
)
t_starts = t_vals[..., :-1]
t_ends = t_vals[..., 1:]
with torch.set_grad_enabled(requires_grad):
sigmas = level_fn(t_starts, t_ends)
assert sigmas.shape == t_starts.shape
trans, _ = render_transmittance_from_density(t_starts, t_ends, sigmas)
cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1)
intervals, _ = importance_sampling(intervals, cdfs, num_samples, stratified)
t_vals_fine = _transform_stot(
sampling_type, intervals.vals, near_plane, far_plane
)
t_vals = torch.cat([t_vals, t_vals_fine], dim=-1)
t_vals, _ = torch.sort(t_vals, dim=-1)
t_starts_ = t_vals[..., :-1]
t_ends_ = t_vals[..., 1:]
return t_starts_, t_ends_
def _transform_stot(
transform_type: Literal["uniform", "lindisp"],
s_vals: torch.Tensor,
t_min: torch.Tensor,
t_max: torch.Tensor,
) -> torch.Tensor:
if transform_type == "uniform":
_contract_fn, _icontract_fn = lambda x: x, lambda x: x
elif transform_type == "lindisp":
_contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x
else:
raise ValueError(f"Unknown transform_type: {transform_type}")
s_min, s_max = _contract_fn(t_min), _contract_fn(t_max)
icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min)
return icontract_fn(s_vals)
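# --- Illustrative sketch (standalone, not part of this module) ------------------------
# _transform_stot maps normalized samples s in [0, 1] to metric distances t between the
# near and far planes: linearly for "uniform", and linearly in inverse depth (disparity)
# for "lindisp", which concentrates samples near the camera. Example with near=0.1 and
# far=1000 (the same hard-coded near/far used by the camera data modules above):
s = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
near, far = torch.tensor(0.1), torch.tensor(1000.0)
print(_transform_stot("uniform", s, near, far))  # [0.1, 250.08, 500.05, 750.03, 1000.0]
print(_transform_stot("lindisp", s, near, far))  # [0.1, ~0.133, ~0.200, ~0.400, 1000.0]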

View File

@ -0,0 +1 @@
from . import base, mesh_exporter

View File

@ -0,0 +1,59 @@
from dataclasses import dataclass
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.geometry.base import BaseImplicitGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.utils.base import BaseObject
from threestudio.utils.typing import *
@dataclass
class ExporterOutput:
save_name: str
save_type: str
params: Dict[str, Any]
class Exporter(BaseObject):
@dataclass
class Config(BaseObject.Config):
save_video: bool = False
cfg: Config
def configure(
self,
geometry: BaseImplicitGeometry,
material: BaseMaterial,
background: BaseBackground,
) -> None:
@dataclass
class SubModules:
geometry: BaseImplicitGeometry
material: BaseMaterial
background: BaseBackground
self.sub_modules = SubModules(geometry, material, background)
@property
def geometry(self) -> BaseImplicitGeometry:
return self.sub_modules.geometry
@property
def material(self) -> BaseMaterial:
return self.sub_modules.material
@property
def background(self) -> BaseBackground:
return self.sub_modules.background
def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
raise NotImplementedError
@threestudio.register("dummy-exporter")
class DummyExporter(Exporter):
def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
# DummyExporter does not export anything
return []

View File

@ -0,0 +1,175 @@
from dataclasses import dataclass, field
import cv2
import numpy as np
import torch
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.exporters.base import Exporter, ExporterOutput
from threestudio.models.geometry.base import BaseImplicitGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.mesh import Mesh
from threestudio.utils.rasterize import NVDiffRasterizerContext
from threestudio.utils.typing import *
@threestudio.register("mesh-exporter")
class MeshExporter(Exporter):
@dataclass
class Config(Exporter.Config):
fmt: str = "obj-mtl" # in ['obj-mtl', 'obj'], TODO: fbx
save_name: str = "model"
save_normal: bool = False
save_uv: bool = True
save_texture: bool = True
texture_size: int = 1024
texture_format: str = "jpg"
xatlas_chart_options: dict = field(default_factory=dict)
xatlas_pack_options: dict = field(default_factory=dict)
context_type: str = "gl"
cfg: Config
def configure(
self,
geometry: BaseImplicitGeometry,
material: BaseMaterial,
background: BaseBackground,
) -> None:
super().configure(geometry, material, background)
self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device)
def __call__(self) -> List[ExporterOutput]:
mesh: Mesh = self.geometry.isosurface()
if self.cfg.fmt == "obj-mtl":
return self.export_obj_with_mtl(mesh)
elif self.cfg.fmt == "obj":
return self.export_obj(mesh)
else:
raise ValueError(f"Unsupported mesh export format: {self.cfg.fmt}")
def export_obj_with_mtl(self, mesh: Mesh) -> List[ExporterOutput]:
params = {
"mesh": mesh,
"save_mat": True,
"save_normal": self.cfg.save_normal,
"save_uv": self.cfg.save_uv,
"save_vertex_color": False,
"map_Kd": None, # Base Color
"map_Ks": None, # Specular
"map_Bump": None, # Normal
# ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
"map_Pm": None, # Metallic
"map_Pr": None, # Roughness
"map_format": self.cfg.texture_format,
}
if self.cfg.save_uv:
mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options)
if self.cfg.save_texture:
threestudio.info("Exporting textures ...")
assert self.cfg.save_uv, "save_uv must be True when save_texture is True"
# clip space transform
uv_clip = mesh.v_tex * 2.0 - 1.0
# pad to four component coordinate
uv_clip4 = torch.cat(
(
uv_clip,
torch.zeros_like(uv_clip[..., 0:1]),
torch.ones_like(uv_clip[..., 0:1]),
),
dim=-1,
)
# rasterize
rast, _ = self.ctx.rasterize_one(
uv_clip4, mesh.t_tex_idx, (self.cfg.texture_size, self.cfg.texture_size)
)
hole_mask = ~(rast[:, :, 3] > 0)
def uv_padding(image):
uv_padding_size = self.cfg.xatlas_pack_options.get("padding", 2)
inpaint_image = (
cv2.inpaint(
(image.detach().cpu().numpy() * 255).astype(np.uint8),
(hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
uv_padding_size,
cv2.INPAINT_TELEA,
)
/ 255.0
)
return torch.from_numpy(inpaint_image).to(image)
# Interpolate world space position
gb_pos, _ = self.ctx.interpolate_one(
mesh.v_pos, rast[None, ...], mesh.t_pos_idx
)
gb_pos = gb_pos[0]
# Sample out textures from MLP
geo_out = self.geometry.export(points=gb_pos)
mat_out = self.material.export(points=gb_pos, **geo_out)
threestudio.info(
"Perform UV padding on texture maps to avoid seams, may take a while ..."
)
if "albedo" in mat_out:
params["map_Kd"] = uv_padding(mat_out["albedo"])
else:
threestudio.warn(
"save_texture is True but no albedo texture found, using default white texture"
)
if "metallic" in mat_out:
params["map_Pm"] = uv_padding(mat_out["metallic"])
if "roughness" in mat_out:
params["map_Pr"] = uv_padding(mat_out["roughness"])
if "bump" in mat_out:
params["map_Bump"] = uv_padding(mat_out["bump"])
# TODO: map_Ks
return [
ExporterOutput(
save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params
)
]
def export_obj(self, mesh: Mesh) -> List[ExporterOutput]:
params = {
"mesh": mesh,
"save_mat": False,
"save_normal": self.cfg.save_normal,
"save_uv": self.cfg.save_uv,
"save_vertex_color": False,
"map_Kd": None, # Base Color
"map_Ks": None, # Specular
"map_Bump": None, # Normal
# ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
"map_Pm": None, # Metallic
"map_Pr": None, # Roughness
"map_format": self.cfg.texture_format,
}
if self.cfg.save_uv:
mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options)
if self.cfg.save_texture:
threestudio.info("Exporting textures ...")
geo_out = self.geometry.export(points=mesh.v_pos)
mat_out = self.material.export(points=mesh.v_pos, **geo_out)
if "albedo" in mat_out:
mesh.set_vertex_color(mat_out["albedo"])
params["save_vertex_color"] = True
else:
threestudio.warn(
"save_texture is True but no albedo texture found, not saving vertex color"
)
return [
ExporterOutput(
save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params
)
]
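# --- Illustrative sketch (standalone, not part of this module) ------------------------
# uv_padding above dilates texture content into the empty texels outside the UV charts
# with OpenCV inpainting, so bilinear filtering at chart borders does not bleed in the
# background color. The same call on synthetic data (no mesh or rasterizer required):
texture = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
hole_mask = np.zeros((64, 64), dtype=np.uint8)
hole_mask[:, 32:] = 255                 # pretend the right half lies outside all charts
texture[hole_mask > 0] = 0              # such texels start out as background (black)
padded = cv2.inpaint(texture, hole_mask, 2, cv2.INPAINT_TELEA) / 255.0
print(padded.shape, float(padded.min()), float(padded.max()))  # (64, 64, 3), values in [0, 1]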

View File

@ -0,0 +1,8 @@
from . import (
base,
custom_mesh,
implicit_sdf,
implicit_volume,
tetrahedra_sdf_grid,
volume_grid,
)

View File

@ -0,0 +1,209 @@
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.isosurface import (
IsosurfaceHelper,
MarchingCubeCPUHelper,
MarchingTetrahedraHelper,
)
from threestudio.models.mesh import Mesh
from threestudio.utils.base import BaseModule
from threestudio.utils.ops import chunk_batch, scale_tensor
from threestudio.utils.typing import *
def contract_to_unisphere(
x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False
) -> Float[Tensor, "... 3"]:
if unbounded:
x = scale_tensor(x, bbox, (0, 1))
x = x * 2 - 1 # aabb is at [-1, 1]
mag = x.norm(dim=-1, keepdim=True)
mask = mag.squeeze(-1) > 1
x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
x = x / 4 + 0.5 # [-inf, inf] is at [0, 1]
else:
x = scale_tensor(x, bbox, (0, 1))
return x
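# --- Illustrative sketch (standalone, not part of this module) ------------------------
# contract_to_unisphere: with unbounded=False points are simply rescaled from the bbox to
# [0, 1]; with unbounded=True the mip-NeRF-360-style contraction maps all of space into
# [0, 1], keeping the bbox interior in [0.25, 0.75] and squeezing far points to the border.
_bbox = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])
_pts = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [100.0, 0.0, 0.0]]  # center / boundary / far away
)
print(contract_to_unisphere(_pts.clone(), _bbox, unbounded=False))
# -> [[0.5, 0.5, 0.5], [1.0, 0.5, 0.5], [50.5, 0.5, 0.5]]     (pure rescaling)
print(contract_to_unisphere(_pts.clone(), _bbox, unbounded=True))
# -> [[0.5, 0.5, 0.5], [0.75, 0.5, 0.5], [~0.9975, 0.5, 0.5]] (everything inside [0, 1])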
class BaseGeometry(BaseModule):
@dataclass
class Config(BaseModule.Config):
pass
cfg: Config
@staticmethod
def create_from(
other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs
) -> "BaseGeometry":
raise TypeError(
f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}"
)
def export(self, *args, **kwargs) -> Dict[str, Any]:
return {}
class BaseImplicitGeometry(BaseGeometry):
@dataclass
class Config(BaseGeometry.Config):
radius: float = 1.0
isosurface: bool = True
isosurface_method: str = "mt"
isosurface_resolution: int = 128
isosurface_threshold: Union[float, str] = 0.0
isosurface_chunk: int = 0
isosurface_coarse_to_fine: bool = True
isosurface_deformable_grid: bool = False
isosurface_remove_outliers: bool = True
isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01
cfg: Config
def configure(self) -> None:
self.bbox: Float[Tensor, "2 3"]
self.register_buffer(
"bbox",
torch.as_tensor(
[
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
],
dtype=torch.float32,
),
)
self.isosurface_helper: Optional[IsosurfaceHelper] = None
self.unbounded: bool = False
    def _initialize_isosurface_helper(self):
if self.cfg.isosurface and self.isosurface_helper is None:
if self.cfg.isosurface_method == "mc-cpu":
self.isosurface_helper = MarchingCubeCPUHelper(
self.cfg.isosurface_resolution
).to(self.device)
elif self.cfg.isosurface_method == "mt":
self.isosurface_helper = MarchingTetrahedraHelper(
self.cfg.isosurface_resolution,
f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
).to(self.device)
else:
raise AttributeError(
"Unknown isosurface method {self.cfg.isosurface_method}"
)
def forward(
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
) -> Dict[str, Float[Tensor, "..."]]:
raise NotImplementedError
def forward_field(
self, points: Float[Tensor, "*N Di"]
) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
# return the value of the implicit field, could be density / signed distance
# also return a deformation field if the grid vertices can be optimized
raise NotImplementedError
def forward_level(
self, field: Float[Tensor, "*N 1"], threshold: float
) -> Float[Tensor, "*N 1"]:
# return the value of the implicit field, where the zero level set represents the surface
raise NotImplementedError
def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh:
def batch_func(x):
# scale to bbox as the input vertices are in [0, 1]
field, deformation = self.forward_field(
scale_tensor(
x.to(bbox.device), self.isosurface_helper.points_range, bbox
),
)
field = field.to(
x.device
) # move to the same device as the input (could be CPU)
if deformation is not None:
deformation = deformation.to(x.device)
return field, deformation
assert self.isosurface_helper is not None
field, deformation = chunk_batch(
batch_func,
self.cfg.isosurface_chunk,
self.isosurface_helper.grid_vertices,
)
threshold: float
if isinstance(self.cfg.isosurface_threshold, float):
threshold = self.cfg.isosurface_threshold
elif self.cfg.isosurface_threshold == "auto":
eps = 1.0e-5
threshold = field[field > eps].mean().item()
threestudio.info(
f"Automatically determined isosurface threshold: {threshold}"
)
else:
raise TypeError(
f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}"
)
level = self.forward_level(field, threshold)
mesh: Mesh = self.isosurface_helper(level, deformation=deformation)
mesh.v_pos = scale_tensor(
mesh.v_pos, self.isosurface_helper.points_range, bbox
) # scale to bbox as the grid vertices are in [0, 1]
mesh.add_extra("bbox", bbox)
if self.cfg.isosurface_remove_outliers:
            # remove outlier components with a small number of faces
# only enabled when the mesh is not differentiable
mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
return mesh
def isosurface(self) -> Mesh:
if not self.cfg.isosurface:
raise NotImplementedError(
"Isosurface is not enabled in the current configuration"
)
        self._initialize_isosurface_helper()
if self.cfg.isosurface_coarse_to_fine:
threestudio.debug("First run isosurface to get a tight bounding box ...")
with torch.no_grad():
mesh_coarse = self._isosurface(self.bbox)
vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0)
vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0])
vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1])
threestudio.debug("Run isosurface again with the tight bounding box ...")
mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True)
else:
mesh = self._isosurface(self.bbox)
return mesh
class BaseExplicitGeometry(BaseGeometry):
@dataclass
class Config(BaseGeometry.Config):
radius: float = 1.0
cfg: Config
def configure(self) -> None:
self.bbox: Float[Tensor, "2 3"]
self.register_buffer(
"bbox",
torch.as_tensor(
[
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
],
dtype=torch.float32,
),
)

View File

@ -0,0 +1,178 @@
import os
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.geometry.base import (
BaseExplicitGeometry,
BaseGeometry,
contract_to_unisphere,
)
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import scale_tensor
from threestudio.utils.typing import *
@threestudio.register("custom-mesh")
class CustomMesh(BaseExplicitGeometry):
@dataclass
class Config(BaseExplicitGeometry.Config):
n_input_dims: int = 3
n_feature_dims: int = 3
pos_encoding_config: dict = field(
default_factory=lambda: {
"otype": "HashGrid",
"n_levels": 16,
"n_features_per_level": 2,
"log2_hashmap_size": 19,
"base_resolution": 16,
"per_level_scale": 1.447269237440378,
}
)
mlp_network_config: dict = field(
default_factory=lambda: {
"otype": "VanillaMLP",
"activation": "ReLU",
"output_activation": "none",
"n_neurons": 64,
"n_hidden_layers": 1,
}
)
shape_init: str = ""
shape_init_params: Optional[Any] = None
shape_init_mesh_up: str = "+z"
shape_init_mesh_front: str = "+x"
cfg: Config
def configure(self) -> None:
super().configure()
self.encoding = get_encoding(
self.cfg.n_input_dims, self.cfg.pos_encoding_config
)
self.feature_network = get_mlp(
self.encoding.n_output_dims,
self.cfg.n_feature_dims,
self.cfg.mlp_network_config,
)
# Initialize custom mesh
if self.cfg.shape_init.startswith("mesh:"):
assert isinstance(self.cfg.shape_init_params, float)
mesh_path = self.cfg.shape_init[5:]
if not os.path.exists(mesh_path):
raise ValueError(f"Mesh file {mesh_path} does not exist.")
import trimesh
scene = trimesh.load(mesh_path)
if isinstance(scene, trimesh.Trimesh):
mesh = scene
elif isinstance(scene, trimesh.scene.Scene):
mesh = trimesh.Trimesh()
for obj in scene.geometry.values():
mesh = trimesh.util.concatenate([mesh, obj])
else:
raise ValueError(f"Unknown mesh type at {mesh_path}.")
# move to center
centroid = mesh.vertices.mean(0)
mesh.vertices = mesh.vertices - centroid
# align to up-z and front-x
dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
dir2vec = {
"+x": np.array([1, 0, 0]),
"+y": np.array([0, 1, 0]),
"+z": np.array([0, 0, 1]),
"-x": np.array([-1, 0, 0]),
"-y": np.array([0, -1, 0]),
"-z": np.array([0, 0, -1]),
}
if (
self.cfg.shape_init_mesh_up not in dirs
or self.cfg.shape_init_mesh_front not in dirs
):
raise ValueError(
f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
)
if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
raise ValueError(
"shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
)
z_, x_ = (
dir2vec[self.cfg.shape_init_mesh_up],
dir2vec[self.cfg.shape_init_mesh_front],
)
y_ = np.cross(z_, x_)
std2mesh = np.stack([x_, y_, z_], axis=0).T
mesh2std = np.linalg.inv(std2mesh)
# scaling
scale = np.abs(mesh.vertices).max()
mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
v_pos = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
self.mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
self.register_buffer(
"v_buffer",
v_pos,
)
self.register_buffer(
"t_buffer",
t_pos_idx,
)
else:
raise ValueError(
f"Unknown shape initialization type: {self.cfg.shape_init}"
)
        threestudio.debug(f"Custom mesh loaded on device {self.mesh.v_pos.device}")
def isosurface(self) -> Mesh:
if hasattr(self, "mesh"):
return self.mesh
elif hasattr(self, "v_buffer"):
self.mesh = Mesh(v_pos=self.v_buffer, t_pos_idx=self.t_buffer)
return self.mesh
else:
raise ValueError(f"custom mesh is not initialized")
def forward(
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
) -> Dict[str, Float[Tensor, "..."]]:
assert (
output_normal == False
), f"Normal output is not supported for {self.__class__.__name__}"
points_unscaled = points # points in the original scale
points = contract_to_unisphere(points, self.bbox) # points normalized to (0, 1)
enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
return {"features": features}
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
out: Dict[str, Any] = {}
if self.cfg.n_feature_dims == 0:
return out
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox)
enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
out.update(
{
"features": features,
}
)
return out

View File

@ -0,0 +1,413 @@
import os
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.misc import broadcast, get_rank
from threestudio.utils.typing import *
@threestudio.register("implicit-sdf")
class ImplicitSDF(BaseImplicitGeometry):
@dataclass
class Config(BaseImplicitGeometry.Config):
n_input_dims: int = 3
n_feature_dims: int = 3
pos_encoding_config: dict = field(
default_factory=lambda: {
"otype": "HashGrid",
"n_levels": 16,
"n_features_per_level": 2,
"log2_hashmap_size": 19,
"base_resolution": 16,
"per_level_scale": 1.447269237440378,
}
)
mlp_network_config: dict = field(
default_factory=lambda: {
"otype": "VanillaMLP",
"activation": "ReLU",
"output_activation": "none",
"n_neurons": 64,
"n_hidden_layers": 1,
}
)
normal_type: Optional[
str
] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian']
finite_difference_normal_eps: Union[
float, str
] = 0.01 # in [float, "progressive"]
shape_init: Optional[str] = None
shape_init_params: Optional[Any] = None
shape_init_mesh_up: str = "+z"
shape_init_mesh_front: str = "+x"
force_shape_init: bool = False
sdf_bias: Union[float, str] = 0.0
sdf_bias_params: Optional[Any] = None
        # no need to remove outliers for SDF
isosurface_remove_outliers: bool = False
cfg: Config
def configure(self) -> None:
super().configure()
self.encoding = get_encoding(
self.cfg.n_input_dims, self.cfg.pos_encoding_config
)
self.sdf_network = get_mlp(
self.encoding.n_output_dims, 1, self.cfg.mlp_network_config
)
if self.cfg.n_feature_dims > 0:
self.feature_network = get_mlp(
self.encoding.n_output_dims,
self.cfg.n_feature_dims,
self.cfg.mlp_network_config,
)
if self.cfg.normal_type == "pred":
self.normal_network = get_mlp(
self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
)
if self.cfg.isosurface_deformable_grid:
assert (
self.cfg.isosurface_method == "mt"
), "isosurface_deformable_grid only works with mt"
self.deformation_network = get_mlp(
self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
)
self.finite_difference_normal_eps: Optional[float] = None
def initialize_shape(self) -> None:
if self.cfg.shape_init is None and not self.cfg.force_shape_init:
return
# do not initialize shape if weights are provided
if self.cfg.weights is not None and not self.cfg.force_shape_init:
return
if self.cfg.sdf_bias != 0.0:
threestudio.warn(
"shape_init and sdf_bias are both specified, which may lead to unexpected results."
)
get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]]
assert isinstance(self.cfg.shape_init, str)
if self.cfg.shape_init == "ellipsoid":
assert (
isinstance(self.cfg.shape_init_params, Sized)
and len(self.cfg.shape_init_params) == 3
)
size = torch.as_tensor(self.cfg.shape_init_params).to(self.device)
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
return ((points_rand / size) ** 2).sum(
dim=-1, keepdim=True
).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
get_gt_sdf = func
elif self.cfg.shape_init == "sphere":
assert isinstance(self.cfg.shape_init_params, float)
radius = self.cfg.shape_init_params
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius
get_gt_sdf = func
elif self.cfg.shape_init.startswith("mesh:"):
assert isinstance(self.cfg.shape_init_params, float)
mesh_path = self.cfg.shape_init[5:]
if not os.path.exists(mesh_path):
raise ValueError(f"Mesh file {mesh_path} does not exist.")
import trimesh
scene = trimesh.load(mesh_path)
if isinstance(scene, trimesh.Trimesh):
mesh = scene
elif isinstance(scene, trimesh.scene.Scene):
mesh = trimesh.Trimesh()
for obj in scene.geometry.values():
mesh = trimesh.util.concatenate([mesh, obj])
else:
raise ValueError(f"Unknown mesh type at {mesh_path}.")
# move to center
centroid = mesh.vertices.mean(0)
mesh.vertices = mesh.vertices - centroid
# align to up-z and front-x
dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
dir2vec = {
"+x": np.array([1, 0, 0]),
"+y": np.array([0, 1, 0]),
"+z": np.array([0, 0, 1]),
"-x": np.array([-1, 0, 0]),
"-y": np.array([0, -1, 0]),
"-z": np.array([0, 0, -1]),
}
if (
self.cfg.shape_init_mesh_up not in dirs
or self.cfg.shape_init_mesh_front not in dirs
):
raise ValueError(
f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
)
if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
raise ValueError(
"shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
)
z_, x_ = (
dir2vec[self.cfg.shape_init_mesh_up],
dir2vec[self.cfg.shape_init_mesh_front],
)
y_ = np.cross(z_, x_)
std2mesh = np.stack([x_, y_, z_], axis=0).T
mesh2std = np.linalg.inv(std2mesh)
# scaling
scale = np.abs(mesh.vertices).max()
mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
from pysdf import SDF
sdf = SDF(mesh.vertices, mesh.faces)
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                # add a negative sign here,
                # as in pysdf the inside of the shape has positive signed distance
return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to(
points_rand
)[..., None]
get_gt_sdf = func
else:
raise ValueError(
f"Unknown shape initialization type: {self.cfg.shape_init}"
)
# Initialize SDF to a given shape when no weights are provided or force_shape_init is True
optim = torch.optim.Adam(self.parameters(), lr=1e-3)
from tqdm import tqdm
for _ in tqdm(
range(1000),
desc=f"Initializing SDF to a(n) {self.cfg.shape_init}:",
disable=get_rank() != 0,
):
points_rand = (
torch.rand((10000, 3), dtype=torch.float32).to(self.device) * 2.0 - 1.0
)
sdf_gt = get_gt_sdf(points_rand)
sdf_pred = self.forward_sdf(points_rand)
loss = F.mse_loss(sdf_pred, sdf_gt)
optim.zero_grad()
loss.backward()
optim.step()
# explicit broadcast to ensure param consistency across ranks
for param in self.parameters():
broadcast(param, src=0)
def get_shifted_sdf(
self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"]
) -> Float[Tensor, "*N 1"]:
sdf_bias: Union[float, Float[Tensor, "*N 1"]]
if self.cfg.sdf_bias == "ellipsoid":
assert (
isinstance(self.cfg.sdf_bias_params, Sized)
and len(self.cfg.sdf_bias_params) == 3
)
size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
sdf_bias = ((points / size) ** 2).sum(
dim=-1, keepdim=True
).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
elif self.cfg.sdf_bias == "sphere":
assert isinstance(self.cfg.sdf_bias_params, float)
radius = self.cfg.sdf_bias_params
sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
elif isinstance(self.cfg.sdf_bias, float):
sdf_bias = self.cfg.sdf_bias
else:
raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
return sdf + sdf_bias
def forward(
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
) -> Dict[str, Float[Tensor, "..."]]:
grad_enabled = torch.is_grad_enabled()
if output_normal and self.cfg.normal_type == "analytic":
torch.set_grad_enabled(True)
points.requires_grad_(True)
points_unscaled = points # points in the original scale
points = contract_to_unisphere(
points, self.bbox, self.unbounded
) # points normalized to (0, 1)
enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
sdf = self.sdf_network(enc).view(*points.shape[:-1], 1)
sdf = self.get_shifted_sdf(points_unscaled, sdf)
output = {"sdf": sdf}
if self.cfg.n_feature_dims > 0:
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
output.update({"features": features})
if output_normal:
if (
self.cfg.normal_type == "finite_difference"
or self.cfg.normal_type == "finite_difference_laplacian"
):
assert self.finite_difference_normal_eps is not None
eps: float = self.finite_difference_normal_eps
if self.cfg.normal_type == "finite_difference_laplacian":
offsets: Float[Tensor, "6 3"] = torch.as_tensor(
[
[eps, 0.0, 0.0],
[-eps, 0.0, 0.0],
[0.0, eps, 0.0],
[0.0, -eps, 0.0],
[0.0, 0.0, eps],
[0.0, 0.0, -eps],
]
).to(points_unscaled)
points_offset: Float[Tensor, "... 6 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
sdf_offset: Float[Tensor, "... 6 1"] = self.forward_sdf(
points_offset
)
sdf_grad = (
0.5
* (sdf_offset[..., 0::2, 0] - sdf_offset[..., 1::2, 0])
/ eps
)
else:
offsets: Float[Tensor, "3 3"] = torch.as_tensor(
[[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
).to(points_unscaled)
points_offset: Float[Tensor, "... 3 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
sdf_offset: Float[Tensor, "... 3 1"] = self.forward_sdf(
points_offset
)
sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps
normal = F.normalize(sdf_grad, dim=-1)
elif self.cfg.normal_type == "pred":
normal = self.normal_network(enc).view(*points.shape[:-1], 3)
normal = F.normalize(normal, dim=-1)
sdf_grad = normal
elif self.cfg.normal_type == "analytic":
sdf_grad = -torch.autograd.grad(
sdf,
points_unscaled,
grad_outputs=torch.ones_like(sdf),
create_graph=True,
)[0]
normal = F.normalize(sdf_grad, dim=-1)
if not grad_enabled:
sdf_grad = sdf_grad.detach()
normal = normal.detach()
else:
raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
output.update(
{"normal": normal, "shading_normal": normal, "sdf_grad": sdf_grad}
)
return output
def forward_sdf(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
sdf = self.sdf_network(
self.encoding(points.reshape(-1, self.cfg.n_input_dims))
).reshape(*points.shape[:-1], 1)
sdf = self.get_shifted_sdf(points_unscaled, sdf)
return sdf
def forward_field(
self, points: Float[Tensor, "*N Di"]
) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
sdf = self.sdf_network(enc).reshape(*points.shape[:-1], 1)
sdf = self.get_shifted_sdf(points_unscaled, sdf)
deformation: Optional[Float[Tensor, "*N 3"]] = None
if self.cfg.isosurface_deformable_grid:
deformation = self.deformation_network(enc).reshape(*points.shape[:-1], 3)
return sdf, deformation
def forward_level(
self, field: Float[Tensor, "*N 1"], threshold: float
) -> Float[Tensor, "*N 1"]:
return field - threshold
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
out: Dict[str, Any] = {}
if self.cfg.n_feature_dims == 0:
return out
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
out.update(
{
"features": features,
}
)
return out
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
if (
self.cfg.normal_type == "finite_difference"
or self.cfg.normal_type == "finite_difference_laplacian"
):
if isinstance(self.cfg.finite_difference_normal_eps, float):
self.finite_difference_normal_eps = (
self.cfg.finite_difference_normal_eps
)
elif self.cfg.finite_difference_normal_eps == "progressive":
# progressive finite difference eps from Neuralangelo
# https://arxiv.org/abs/2306.03092
hg_conf: Any = self.cfg.pos_encoding_config
assert (
hg_conf.otype == "ProgressiveBandHashGrid"
), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid"
current_level = min(
hg_conf.start_level
+ max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
hg_conf.n_levels,
)
grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** (
current_level - 1
)
grid_size = 2 * self.cfg.radius / grid_res
if grid_size != self.finite_difference_normal_eps:
threestudio.info(
f"Update finite_difference_normal_eps to {grid_size}"
)
self.finite_difference_normal_eps = grid_size
else:
raise ValueError(
f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}"
)
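# --- Illustrative sketch (standalone, not part of this module) ------------------------
# The finite_difference_laplacian branch in ImplicitSDF.forward above estimates the SDF
# gradient with central differences along each axis. On an analytic sphere SDF the
# estimate matches the true outward normal (eps here is an arbitrary illustrative value):
def _sphere_sdf(p):
    return p.norm(dim=-1, keepdim=True) - 0.5

_eps = 1e-3
_points = F.normalize(torch.randn(8, 3), dim=-1) * 0.8
_offsets = torch.as_tensor(
    [[_eps, 0.0, 0.0], [-_eps, 0.0, 0.0],
     [0.0, _eps, 0.0], [0.0, -_eps, 0.0],
     [0.0, 0.0, _eps], [0.0, 0.0, -_eps]]
)
_sdf_offset = _sphere_sdf(_points[:, None, :] + _offsets)                  # (8, 6, 1)
_grad = 0.5 * (_sdf_offset[:, 0::2, 0] - _sdf_offset[:, 1::2, 0]) / _eps   # central diff
print(torch.allclose(F.normalize(_grad, dim=-1), F.normalize(_points, dim=-1), atol=1e-4))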

View File

@ -0,0 +1,325 @@
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.geometry.base import (
BaseGeometry,
BaseImplicitGeometry,
contract_to_unisphere,
)
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *
@threestudio.register("implicit-volume")
class ImplicitVolume(BaseImplicitGeometry):
@dataclass
class Config(BaseImplicitGeometry.Config):
n_input_dims: int = 3
n_feature_dims: int = 3
density_activation: Optional[str] = "softplus"
density_bias: Union[float, str] = "blob_magic3d"
density_blob_scale: float = 10.0
density_blob_std: float = 0.5
pos_encoding_config: dict = field(
default_factory=lambda: {
"otype": "HashGrid",
"n_levels": 16,
"n_features_per_level": 2,
"log2_hashmap_size": 19,
"base_resolution": 16,
"per_level_scale": 1.447269237440378,
}
)
mlp_network_config: dict = field(
default_factory=lambda: {
"otype": "VanillaMLP",
"activation": "ReLU",
"output_activation": "none",
"n_neurons": 64,
"n_hidden_layers": 1,
}
)
normal_type: Optional[
str
] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian']
finite_difference_normal_eps: Union[
float, str
] = 0.01 # in [float, "progressive"]
        # set to "auto" to automatically determine the threshold
isosurface_threshold: Union[float, str] = 25.0
# 4D Gaussian Annealing
anneal_density_blob_std_config: Optional[dict] = None
cfg: Config
def configure(self) -> None:
super().configure()
self.encoding = get_encoding(
self.cfg.n_input_dims, self.cfg.pos_encoding_config
)
self.density_network = get_mlp(
self.encoding.n_output_dims, 1, self.cfg.mlp_network_config
)
if self.cfg.n_feature_dims > 0:
self.feature_network = get_mlp(
self.encoding.n_output_dims,
self.cfg.n_feature_dims,
self.cfg.mlp_network_config,
)
if self.cfg.normal_type == "pred":
self.normal_network = get_mlp(
self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
)
self.finite_difference_normal_eps: Optional[float] = None
def get_activated_density(
self, points: Float[Tensor, "*N Di"], density: Float[Tensor, "*N 1"]
) -> Tuple[Float[Tensor, "*N 1"], Float[Tensor, "*N 1"]]:
density_bias: Union[float, Float[Tensor, "*N 1"]]
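# the density bias shapes the initial density into a centered "blob":
# blob_dreamfusion uses a Gaussian bump exp(-||x||^2 / (2 * std^2)), blob_magic3d a linear
# falloff (1 - ||x|| / std); both are scaled by density_blob_scale.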
if self.cfg.density_bias == "blob_dreamfusion":
# pre-activation density bias
density_bias = (
self.cfg.density_blob_scale
* torch.exp(
-0.5 * (points**2).sum(dim=-1) / self.cfg.density_blob_std**2
)[..., None]
)
elif self.cfg.density_bias == "blob_magic3d":
# pre-activation density bias
density_bias = (
self.cfg.density_blob_scale
* (
1
- torch.sqrt((points**2).sum(dim=-1)) / self.cfg.density_blob_std
)[..., None]
)
elif isinstance(self.cfg.density_bias, float):
density_bias = self.cfg.density_bias
else:
raise ValueError(f"Unknown density bias {self.cfg.density_bias}")
raw_density: Float[Tensor, "*N 1"] = density + density_bias
density = get_activation(self.cfg.density_activation)(raw_density)
return raw_density, density
def forward(
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
) -> Dict[str, Float[Tensor, "..."]]:
grad_enabled = torch.is_grad_enabled()
if output_normal and self.cfg.normal_type == "analytic":
torch.set_grad_enabled(True)
points.requires_grad_(True)
points_unscaled = points # points in the original scale
points = contract_to_unisphere(
points, self.bbox, self.unbounded
) # points normalized to (0, 1)
enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
density = self.density_network(enc).view(*points.shape[:-1], 1)
raw_density, density = self.get_activated_density(points_unscaled, density)
output = {
"density": density,
}
if self.cfg.n_feature_dims > 0:
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
output.update({"features": features})
if output_normal:
if (
self.cfg.normal_type == "finite_difference"
or self.cfg.normal_type == "finite_difference_laplacian"
):
# TODO: use raw density
assert self.finite_difference_normal_eps is not None
eps: float = self.finite_difference_normal_eps
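# finite_difference_laplacian: central differences (6 taps, +/- eps per axis);
# finite_difference: forward differences (3 taps), cheaper but with O(eps) bias.
# In both cases the normal is the negated density gradient, normalized below.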
if self.cfg.normal_type == "finite_difference_laplacian":
offsets: Float[Tensor, "6 3"] = torch.as_tensor(
[
[eps, 0.0, 0.0],
[-eps, 0.0, 0.0],
[0.0, eps, 0.0],
[0.0, -eps, 0.0],
[0.0, 0.0, eps],
[0.0, 0.0, -eps],
]
).to(points_unscaled)
points_offset: Float[Tensor, "... 6 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
points_offset
)
normal = (
-0.5
* (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
/ eps
)
else:
offsets: Float[Tensor, "3 3"] = torch.as_tensor(
[[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
).to(points_unscaled)
points_offset: Float[Tensor, "... 3 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
points_offset
)
normal = -(density_offset[..., 0::1, 0] - density) / eps
normal = F.normalize(normal, dim=-1)
elif self.cfg.normal_type == "pred":
normal = self.normal_network(enc).view(*points.shape[:-1], 3)
normal = F.normalize(normal, dim=-1)
elif self.cfg.normal_type == "analytic":
normal = -torch.autograd.grad(
density,
points_unscaled,
grad_outputs=torch.ones_like(density),
create_graph=True,
)[0]
normal = F.normalize(normal, dim=-1)
if not grad_enabled:
normal = normal.detach()
else:
raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
output.update({"normal": normal, "shading_normal": normal})
torch.set_grad_enabled(grad_enabled)
return output
def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
density = self.density_network(
self.encoding(points.reshape(-1, self.cfg.n_input_dims))
).reshape(*points.shape[:-1], 1)
_, density = self.get_activated_density(points_unscaled, density)
return density
def forward_field(
self, points: Float[Tensor, "*N Di"]
) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
if self.cfg.isosurface_deformable_grid:
threestudio.warn(
f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
)
density = self.forward_density(points)
return density, None
def forward_level(
self, field: Float[Tensor, "*N 1"], threshold: float
) -> Float[Tensor, "*N 1"]:
return -(field - threshold)
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
out: Dict[str, Any] = {}
if self.cfg.n_feature_dims == 0:
return out
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
out.update(
{
"features": features,
}
)
return out
@staticmethod
@torch.no_grad()
def create_from(
other: BaseGeometry,
cfg: Optional[Union[dict, DictConfig]] = None,
copy_net: bool = True,
**kwargs,
) -> "ImplicitVolume":
if isinstance(other, ImplicitVolume):
instance = ImplicitVolume(cfg, **kwargs)
instance.encoding.load_state_dict(other.encoding.state_dict())
instance.density_network.load_state_dict(other.density_network.state_dict())
if copy_net:
if (
instance.cfg.n_feature_dims > 0
and other.cfg.n_feature_dims == instance.cfg.n_feature_dims
):
instance.feature_network.load_state_dict(
other.feature_network.state_dict()
)
if (
instance.cfg.normal_type == "pred"
and other.cfg.normal_type == "pred"
):
instance.normal_network.load_state_dict(
other.normal_network.state_dict()
)
return instance
else:
raise TypeError(
f"Cannot create {ImplicitVolume.__name__} from {other.__class__.__name__}"
)
# FIXME: use progressive normal eps
def update_step(
self, epoch: int, global_step: int, on_load_weights: bool = False
) -> None:
if self.cfg.anneal_density_blob_std_config is not None:
min_step = self.cfg.anneal_density_blob_std_config.min_anneal_step
max_step = self.cfg.anneal_density_blob_std_config.max_anneal_step
if global_step >= min_step and global_step <= max_step:
end_val = self.cfg.anneal_density_blob_std_config.end_val
start_val = self.cfg.anneal_density_blob_std_config.start_val
self.density_blob_std = start_val + (global_step - min_step) * (
end_val - start_val
) / (max_step - min_step)
if (
self.cfg.normal_type == "finite_difference"
or self.cfg.normal_type == "finite_difference_laplacian"
):
if isinstance(self.cfg.finite_difference_normal_eps, float):
self.finite_difference_normal_eps = (
self.cfg.finite_difference_normal_eps
)
elif self.cfg.finite_difference_normal_eps == "progressive":
# progressive finite difference eps from Neuralangelo
# https://arxiv.org/abs/2306.03092
hg_conf: Any = self.cfg.pos_encoding_config
assert (
hg_conf.otype == "ProgressiveBandHashGrid"
), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid"
current_level = min(
hg_conf.start_level
+ max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
hg_conf.n_levels,
)
grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** (
current_level - 1
)
grid_size = 2 * self.cfg.radius / grid_res
if grid_size != self.finite_difference_normal_eps:
threestudio.info(
f"Update finite_difference_normal_eps to {grid_size}"
)
self.finite_difference_normal_eps = grid_size
else:
raise ValueError(
f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}"
)

View File

@ -0,0 +1,369 @@
import os
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.geometry.base import (
BaseExplicitGeometry,
BaseGeometry,
contract_to_unisphere,
)
from threestudio.models.geometry.implicit_sdf import ImplicitSDF
from threestudio.models.geometry.implicit_volume import ImplicitVolume
from threestudio.models.isosurface import MarchingTetrahedraHelper
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.misc import broadcast
from threestudio.utils.ops import scale_tensor
from threestudio.utils.typing import *
@threestudio.register("tetrahedra-sdf-grid")
class TetrahedraSDFGrid(BaseExplicitGeometry):
@dataclass
class Config(BaseExplicitGeometry.Config):
isosurface_resolution: int = 128
isosurface_deformable_grid: bool = True
isosurface_remove_outliers: bool = False
isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01
n_input_dims: int = 3
n_feature_dims: int = 3
pos_encoding_config: dict = field(
default_factory=lambda: {
"otype": "HashGrid",
"n_levels": 16,
"n_features_per_level": 2,
"log2_hashmap_size": 19,
"base_resolution": 16,
"per_level_scale": 1.447269237440378,
}
)
mlp_network_config: dict = field(
default_factory=lambda: {
"otype": "VanillaMLP",
"activation": "ReLU",
"output_activation": "none",
"n_neurons": 64,
"n_hidden_layers": 1,
}
)
shape_init: Optional[str] = None
shape_init_params: Optional[Any] = None
shape_init_mesh_up: str = "+z"
shape_init_mesh_front: str = "+x"
force_shape_init: bool = False
geometry_only: bool = False
fix_geometry: bool = False
cfg: Config
def configure(self) -> None:
super().configure()
# this should be saved to state_dict, register as buffer
self.isosurface_bbox: Float[Tensor, "2 3"]
self.register_buffer("isosurface_bbox", self.bbox.clone())
self.isosurface_helper = MarchingTetrahedraHelper(
self.cfg.isosurface_resolution,
f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
)
self.sdf: Float[Tensor, "Nv 1"]
self.deformation: Optional[Float[Tensor, "Nv 3"]]
if not self.cfg.fix_geometry:
self.register_parameter(
"sdf",
nn.Parameter(
torch.zeros(
(self.isosurface_helper.grid_vertices.shape[0], 1),
dtype=torch.float32,
)
),
)
if self.cfg.isosurface_deformable_grid:
self.register_parameter(
"deformation",
nn.Parameter(
torch.zeros_like(self.isosurface_helper.grid_vertices)
),
)
else:
self.deformation = None
else:
self.register_buffer(
"sdf",
torch.zeros(
(self.isosurface_helper.grid_vertices.shape[0], 1),
dtype=torch.float32,
),
)
if self.cfg.isosurface_deformable_grid:
self.register_buffer(
"deformation",
torch.zeros_like(self.isosurface_helper.grid_vertices),
)
else:
self.deformation = None
if not self.cfg.geometry_only:
self.encoding = get_encoding(
self.cfg.n_input_dims, self.cfg.pos_encoding_config
)
self.feature_network = get_mlp(
self.encoding.n_output_dims,
self.cfg.n_feature_dims,
self.cfg.mlp_network_config,
)
self.mesh: Optional[Mesh] = None
def initialize_shape(self) -> None:
if self.cfg.shape_init is None and not self.cfg.force_shape_init:
return
# do not initialize shape if weights are provided
if self.cfg.weights is not None and not self.cfg.force_shape_init:
return
get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]]
assert isinstance(self.cfg.shape_init, str)
if self.cfg.shape_init == "ellipsoid":
assert (
isinstance(self.cfg.shape_init_params, Sized)
and len(self.cfg.shape_init_params) == 3
)
size = torch.as_tensor(self.cfg.shape_init_params).to(self.device)
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
return ((points_rand / size) ** 2).sum(
dim=-1, keepdim=True
).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
get_gt_sdf = func
elif self.cfg.shape_init == "sphere":
assert isinstance(self.cfg.shape_init_params, float)
radius = self.cfg.shape_init_params
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius
get_gt_sdf = func
elif self.cfg.shape_init.startswith("mesh:"):
assert isinstance(self.cfg.shape_init_params, float)
mesh_path = self.cfg.shape_init[5:]
if not os.path.exists(mesh_path):
raise ValueError(f"Mesh file {mesh_path} does not exist.")
import trimesh
mesh = trimesh.load(mesh_path)
# move to center
centroid = mesh.vertices.mean(0)
mesh.vertices = mesh.vertices - centroid
# align to up-z and front-x
dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
dir2vec = {
"+x": np.array([1, 0, 0]),
"+y": np.array([0, 1, 0]),
"+z": np.array([0, 0, 1]),
"-x": np.array([-1, 0, 0]),
"-y": np.array([0, -1, 0]),
"-z": np.array([0, 0, -1]),
}
if (
self.cfg.shape_init_mesh_up not in dirs
or self.cfg.shape_init_mesh_front not in dirs
):
raise ValueError(
f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
)
if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
raise ValueError(
"shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
)
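# build the rotation that maps the mesh's "up"/"front" axes onto the canonical +z/+x frame:
# y is the cross product of up and front, and mesh2std is the inverse of the resulting basis.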
z_, x_ = (
dir2vec[self.cfg.shape_init_mesh_up],
dir2vec[self.cfg.shape_init_mesh_front],
)
y_ = np.cross(z_, x_)
std2mesh = np.stack([x_, y_, z_], axis=0).T
mesh2std = np.linalg.inv(std2mesh)
# scaling
scale = np.abs(mesh.vertices).max()
mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
from pysdf import SDF
sdf = SDF(mesh.vertices, mesh.faces)
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
# add a negative sign here,
# as in pysdf the inside of the shape has a positive signed distance
return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to(
points_rand
)[..., None]
get_gt_sdf = func
else:
raise ValueError(
f"Unknown shape initialization type: {self.cfg.shape_init}"
)
sdf_gt = get_gt_sdf(
scale_tensor(
self.isosurface_helper.grid_vertices,
self.isosurface_helper.points_range,
self.isosurface_bbox,
)
)
self.sdf.data = sdf_gt
# explicit broadcast to ensure param consistency across ranks
for param in self.parameters():
broadcast(param, src=0)
def isosurface(self) -> Mesh:
# return cached mesh if fix_geometry is True to save computation
if self.cfg.fix_geometry and self.mesh is not None:
return self.mesh
mesh = self.isosurface_helper(self.sdf, self.deformation)
mesh.v_pos = scale_tensor(
mesh.v_pos, self.isosurface_helper.points_range, self.isosurface_bbox
)
if self.cfg.isosurface_remove_outliers:
mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
self.mesh = mesh
return mesh
def forward(
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
) -> Dict[str, Float[Tensor, "..."]]:
if self.cfg.geometry_only:
return {}
assert (
output_normal == False
), f"Normal output is not supported for {self.__class__.__name__}"
points_unscaled = points # points in the original scale
points = contract_to_unisphere(points, self.bbox) # points normalized to (0, 1)
enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
return {"features": features}
@staticmethod
@torch.no_grad()
def create_from(
other: BaseGeometry,
cfg: Optional[Union[dict, DictConfig]] = None,
copy_net: bool = True,
**kwargs,
) -> "TetrahedraSDFGrid":
if isinstance(other, TetrahedraSDFGrid):
instance = TetrahedraSDFGrid(cfg, **kwargs)
assert instance.cfg.isosurface_resolution == other.cfg.isosurface_resolution
instance.isosurface_bbox = other.isosurface_bbox.clone()
instance.sdf.data = other.sdf.data.clone()
if (
instance.cfg.isosurface_deformable_grid
and other.cfg.isosurface_deformable_grid
):
assert (
instance.deformation is not None and other.deformation is not None
)
instance.deformation.data = other.deformation.data.clone()
if (
not instance.cfg.geometry_only
and not other.cfg.geometry_only
and copy_net
):
instance.encoding.load_state_dict(other.encoding.state_dict())
instance.feature_network.load_state_dict(
other.feature_network.state_dict()
)
return instance
elif isinstance(other, ImplicitVolume):
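# convert an implicit density field into the explicit tet-grid representation: extract a
# marching-tetrahedra mesh from the source geometry and use its (clamped) grid level values
# to initialize the per-vertex SDF.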
instance = TetrahedraSDFGrid(cfg, **kwargs)
if other.cfg.isosurface_method != "mt":
other.cfg.isosurface_method = "mt"
threestudio.warn(
f"Override isosurface_method of the source geometry to 'mt'"
)
if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution:
other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution
threestudio.warn(
f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}"
)
mesh = other.isosurface()
instance.isosurface_bbox = mesh.extras["bbox"]
instance.sdf.data = (
mesh.extras["grid_level"].to(instance.sdf.data).clamp(-1, 1)
)
if not instance.cfg.geometry_only and copy_net:
instance.encoding.load_state_dict(other.encoding.state_dict())
instance.feature_network.load_state_dict(
other.feature_network.state_dict()
)
return instance
elif isinstance(other, ImplicitSDF):
instance = TetrahedraSDFGrid(cfg, **kwargs)
if other.cfg.isosurface_method != "mt":
other.cfg.isosurface_method = "mt"
threestudio.warn(
f"Override isosurface_method of the source geometry to 'mt'"
)
if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution:
other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution
threestudio.warn(
f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}"
)
mesh = other.isosurface()
instance.isosurface_bbox = mesh.extras["bbox"]
instance.sdf.data = mesh.extras["grid_level"].to(instance.sdf.data)
if (
instance.cfg.isosurface_deformable_grid
and other.cfg.isosurface_deformable_grid
):
assert instance.deformation is not None
instance.deformation.data = mesh.extras["grid_deformation"].to(
instance.deformation.data
)
if not instance.cfg.geometry_only and copy_net:
instance.encoding.load_state_dict(other.encoding.state_dict())
instance.feature_network.load_state_dict(
other.feature_network.state_dict()
)
return instance
else:
raise TypeError(
f"Cannot create {TetrahedraSDFGrid.__name__} from {other.__class__.__name__}"
)
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
out: Dict[str, Any] = {}
if self.cfg.geometry_only or self.cfg.n_feature_dims == 0:
return out
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox)
enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
features = self.feature_network(enc).view(
*points.shape[:-1], self.cfg.n_feature_dims
)
out.update(
{
"features": features,
}
)
return out

View File

@ -0,0 +1,190 @@
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *
@threestudio.register("volume-grid")
class VolumeGrid(BaseImplicitGeometry):
@dataclass
class Config(BaseImplicitGeometry.Config):
grid_size: Tuple[int, int, int] = field(default_factory=lambda: (100, 100, 100))
n_feature_dims: int = 3
density_activation: Optional[str] = "softplus"
density_bias: Union[float, str] = "blob"
density_blob_scale: float = 5.0
density_blob_std: float = 0.5
normal_type: Optional[
str
] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian']
# automatically determine the threshold
isosurface_threshold: Union[float, str] = "auto"
cfg: Config
def configure(self) -> None:
super().configure()
self.grid_size = self.cfg.grid_size
self.grid = nn.Parameter(
torch.zeros(1, self.cfg.n_feature_dims + 1, *self.grid_size)
)
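# density_scale is a log-scale applied to the raw grid density (density * exp(density_scale));
# it stays a fixed 0 buffer when the "blob" bias is used and is learnable otherwise.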
if self.cfg.density_bias == "blob":
self.register_buffer("density_scale", torch.tensor(0.0))
else:
self.density_scale = nn.Parameter(torch.tensor(0.0))
if self.cfg.normal_type == "pred":
self.normal_grid = nn.Parameter(torch.zeros(1, 3, *self.grid_size))
def get_density_bias(self, points: Float[Tensor, "*N Di"]):
if self.cfg.density_bias == "blob":
# density_bias: Float[Tensor, "*N 1"] = self.cfg.density_blob_scale * torch.exp(-0.5 * (points ** 2).sum(dim=-1) / self.cfg.density_blob_std ** 2)[...,None]
density_bias: Float[Tensor, "*N 1"] = (
self.cfg.density_blob_scale
* (
1
- torch.sqrt((points.detach() ** 2).sum(dim=-1))
/ self.cfg.density_blob_std
)[..., None]
)
return density_bias
elif isinstance(self.cfg.density_bias, float):
return self.cfg.density_bias
else:
raise AttributeError(f"Unknown density bias {self.cfg.density_bias}")
def get_trilinear_feature(
self, points: Float[Tensor, "*N Di"], grid: Float[Tensor, "1 Df G1 G2 G3"]
) -> Float[Tensor, "*N Df"]:
points_shape = points.shape[:-1]
df = grid.shape[1]
di = points.shape[-1]
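# F.grid_sample on a 5D input interpolates trilinearly even though mode is spelled "bilinear";
# points must already be in [-1, 1] normalized coordinates (see forward).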
out = F.grid_sample(
grid, points.view(1, 1, 1, -1, di), align_corners=False, mode="bilinear"
)
out = out.reshape(df, -1).T.reshape(*points_shape, df)
return out
def forward(
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
) -> Dict[str, Float[Tensor, "..."]]:
points_unscaled = points # points in the original scale
points = contract_to_unisphere(
points, self.bbox, self.unbounded
) # points normalized to (0, 1)
points = points * 2 - 1 # convert to [-1, 1] for grid sample
out = self.get_trilinear_feature(points, self.grid)
density, features = out[..., 0:1], out[..., 1:]
density = density * torch.exp(self.density_scale) # exp scaling in DreamFusion
# breakpoint()
density = get_activation(self.cfg.density_activation)(
density + self.get_density_bias(points_unscaled)
)
output = {
"density": density,
"features": features,
}
if output_normal:
if (
self.cfg.normal_type == "finite_difference"
or self.cfg.normal_type == "finite_difference_laplacian"
):
eps = 1.0e-3
if self.cfg.normal_type == "finite_difference_laplacian":
offsets: Float[Tensor, "6 3"] = torch.as_tensor(
[
[eps, 0.0, 0.0],
[-eps, 0.0, 0.0],
[0.0, eps, 0.0],
[0.0, -eps, 0.0],
[0.0, 0.0, eps],
[0.0, 0.0, -eps],
]
).to(points_unscaled)
points_offset: Float[Tensor, "... 6 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
points_offset
)
normal = (
-0.5
* (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
/ eps
)
else:
offsets: Float[Tensor, "3 3"] = torch.as_tensor(
[[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
).to(points_unscaled)
points_offset: Float[Tensor, "... 3 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
points_offset
)
normal = -(density_offset[..., 0::1, 0] - density) / eps
normal = F.normalize(normal, dim=-1)
elif self.cfg.normal_type == "pred":
normal = self.get_trilinear_feature(points, self.normal_grid)
normal = F.normalize(normal, dim=-1)
else:
raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
output.update({"normal": normal, "shading_normal": normal})
return output
def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
points = points * 2 - 1 # convert to [-1, 1] for grid sample
out = self.get_trilinear_feature(points, self.grid)
density = out[..., 0:1]
density = density * torch.exp(self.density_scale)
density = get_activation(self.cfg.density_activation)(
density + self.get_density_bias(points_unscaled)
)
return density
def forward_field(
self, points: Float[Tensor, "*N Di"]
) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
if self.cfg.isosurface_deformable_grid:
threestudio.warn(
f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
)
density = self.forward_density(points)
return density, None
def forward_level(
self, field: Float[Tensor, "*N 1"], threshold: float
) -> Float[Tensor, "*N 1"]:
return -(field - threshold)
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
out: Dict[str, Any] = {}
if self.cfg.n_feature_dims == 0:
return out
points_unscaled = points
points = contract_to_unisphere(points, self.bbox, self.unbounded)
points = points * 2 - 1 # convert to [-1, 1] for grid sample
features = self.get_trilinear_feature(points, self.grid)[..., 1:]
out.update(
{
"features": features,
}
)
return out

View File

@ -0,0 +1,13 @@
from . import (
controlnet_guidance,
controlnet_reg_guidance,
deep_floyd_guidance,
stable_diffusion_guidance,
stable_diffusion_unified_guidance,
stable_diffusion_vsd_guidance,
stable_diffusion_bsd_guidance,
stable_zero123_guidance,
zero123_guidance,
zero123_unified_guidance,
clip_guidance,
)

View File

@ -0,0 +1,84 @@
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import clip
import threestudio
from threestudio.utils.base import BaseObject
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.typing import *
@threestudio.register("clip-guidance")
class CLIPGuidance(BaseObject):
@dataclass
class Config(BaseObject.Config):
cache_dir: Optional[str] = None
pretrained_model_name_or_path: str = "ViT-B/16"
view_dependent_prompting: bool = True
cfg: Config
def configure(self) -> None:
threestudio.info(f"Loading CLIP ...")
self.clip_model, self.clip_preprocess = clip.load(
self.cfg.pretrained_model_name_or_path,
device=self.device,
jit=False,
download_root=self.cfg.cache_dir
)
self.aug = T.Compose([
T.Resize((224, 224)),
T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
threestudio.info(f"Loaded CLIP!")
@torch.cuda.amp.autocast(enabled=False)
def get_embedding(self, input_value, is_text=True):
if is_text:
value = clip.tokenize(input_value).to(self.device)
z = self.clip_model.encode_text(value)
else:
input_value = self.aug(input_value)
z = self.clip_model.encode_image(input_value)
return z / z.norm(dim=-1, keepdim=True)
def get_loss(self, image_z, clip_z, loss_type='similarity_score', use_mean=True):
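# similarity_score: negative cosine similarity (embeddings are pre-normalized in get_embedding);
# spherical_dist: squared geodesic distance on the unit hypersphere, as commonly used for CLIP guidance.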
if loss_type == 'similarity_score':
loss = -((image_z * clip_z).sum(-1))
elif loss_type == 'spherical_dist':
image_z, clip_z = F.normalize(image_z, dim=-1), F.normalize(clip_z, dim=-1)
loss = ((image_z - clip_z).norm(dim=-1).div(2).arcsin().pow(2).mul(2))
else:
raise NotImplementedError
return loss.mean() if use_mean else loss
def __call__(
self,
pred_rgb: Float[Tensor, "B H W C"],
gt_rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
embedding_type: str = 'both',
loss_type: Optional[str] = 'similarity_score',
**kwargs,
):
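# returns a scalar CLIP loss: render-vs-text similarity ('text'), render-vs-gt_rgb image
# similarity ('img'), or their sum ('both').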
clip_text_loss, clip_img_loss = 0, 0
if embedding_type in ('both', 'text'):
text_embeddings = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
).chunk(2)[0]
clip_text_loss = self.get_loss(self.get_embedding(pred_rgb, is_text=False), text_embeddings, loss_type=loss_type)
if embedding_type in ('both', 'img'):
clip_img_loss = self.get_loss(self.get_embedding(pred_rgb, is_text=False), self.get_embedding(gt_rgb, is_text=False), loss_type=loss_type)
return clip_text_loss + clip_img_loss

View File

@ -0,0 +1,517 @@
import os
from dataclasses import dataclass
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from controlnet_aux import CannyDetector, NormalBaeDetector
from diffusers import ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm
import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.perceptual import PerceptualLoss
from threestudio.utils.typing import *
@threestudio.register("stable-diffusion-controlnet-guidance")
class ControlNetGuidance(BaseObject):
@dataclass
class Config(BaseObject.Config):
cache_dir: Optional[str] = None
pretrained_model_name_or_path: str = "SG161222/Realistic_Vision_V2.0"
ddim_scheduler_name_or_path: str = "runwayml/stable-diffusion-v1-5"
control_type: str = "normal" # normal/canny
enable_memory_efficient_attention: bool = False
enable_sequential_cpu_offload: bool = False
enable_attention_slicing: bool = False
enable_channels_last_format: bool = False
guidance_scale: float = 7.5
condition_scale: float = 1.5
grad_clip: Optional[Any] = None
half_precision_weights: bool = True
fixed_size: int = -1
min_step_percent: float = 0.02
max_step_percent: float = 0.98
diffusion_steps: int = 20
use_sds: bool = False
use_du: bool = False
per_du_step: int = 10
start_du_step: int = 1000
cache_du: bool = False
# Canny threshold
canny_lower_bound: int = 50
canny_upper_bound: int = 100
cfg: Config
def configure(self) -> None:
threestudio.info(f"Loading ControlNet ...")
controlnet_name_or_path: str
if self.cfg.control_type in ("normal", "input_normal"):
controlnet_name_or_path = "lllyasviel/control_v11p_sd15_normalbae"
elif self.cfg.control_type == "canny":
controlnet_name_or_path = "lllyasviel/control_v11p_sd15_canny"
self.weights_dtype = (
torch.float16 if self.cfg.half_precision_weights else torch.float32
)
pipe_kwargs = {
"safety_checker": None,
"feature_extractor": None,
"requires_safety_checker": False,
"torch_dtype": self.weights_dtype,
"cache_dir": self.cfg.cache_dir,
}
controlnet = ControlNetModel.from_pretrained(
controlnet_name_or_path,
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path, controlnet=controlnet, **pipe_kwargs
).to(self.device)
self.scheduler = DDIMScheduler.from_pretrained(
self.cfg.ddim_scheduler_name_or_path,
subfolder="scheduler",
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
)
self.scheduler.set_timesteps(self.cfg.diffusion_steps)
if self.cfg.enable_memory_efficient_attention:
if parse_version(torch.__version__) >= parse_version("2"):
threestudio.info(
"PyTorch2.0 uses memory efficient attention by default."
)
elif not is_xformers_available():
threestudio.warn(
"xformers is not available, memory efficient attention is not enabled."
)
else:
self.pipe.enable_xformers_memory_efficient_attention()
if self.cfg.enable_sequential_cpu_offload:
self.pipe.enable_sequential_cpu_offload()
if self.cfg.enable_attention_slicing:
self.pipe.enable_attention_slicing(1)
if self.cfg.enable_channels_last_format:
self.pipe.unet.to(memory_format=torch.channels_last)
# Create model
self.vae = self.pipe.vae.eval()
self.unet = self.pipe.unet.eval()
self.controlnet = self.pipe.controlnet.eval()
if self.cfg.control_type == "normal":
self.preprocessor = NormalBaeDetector.from_pretrained(
"lllyasviel/Annotators"
)
self.preprocessor.model.to(self.device)
elif self.cfg.control_type == "canny":
self.preprocessor = CannyDetector()
for p in self.vae.parameters():
p.requires_grad_(False)
for p in self.unet.parameters():
p.requires_grad_(False)
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.set_min_max_steps() # set to default value
self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
self.device
)
self.grad_clip_val: Optional[float] = None
if self.cfg.use_du:
if self.cfg.cache_du:
self.edit_frames = {}
self.perceptual_loss = PerceptualLoss().eval().to(self.device)
threestudio.info(f"Loaded ControlNet!")
@torch.cuda.amp.autocast(enabled=False)
def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
self.min_step = int(self.num_train_timesteps * min_step_percent)
self.max_step = int(self.num_train_timesteps * max_step_percent)
@torch.cuda.amp.autocast(enabled=False)
def forward_controlnet(
self,
latents: Float[Tensor, "..."],
t: Float[Tensor, "..."],
image_cond: Float[Tensor, "..."],
condition_scale: float,
encoder_hidden_states: Float[Tensor, "..."],
) -> Float[Tensor, "..."]:
return self.controlnet(
latents.to(self.weights_dtype),
t.to(self.weights_dtype),
encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
controlnet_cond=image_cond.to(self.weights_dtype),
conditioning_scale=condition_scale,
return_dict=False,
)
@torch.cuda.amp.autocast(enabled=False)
def forward_control_unet(
self,
latents: Float[Tensor, "..."],
t: Float[Tensor, "..."],
encoder_hidden_states: Float[Tensor, "..."],
cross_attention_kwargs,
down_block_additional_residuals,
mid_block_additional_residual,
) -> Float[Tensor, "..."]:
input_dtype = latents.dtype
return self.unet(
latents.to(self.weights_dtype),
t.to(self.weights_dtype),
encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
).sample.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def encode_images(
self, imgs: Float[Tensor, "B 3 H W"]
) -> Float[Tensor, "B 4 DH DW"]:
input_dtype = imgs.dtype
imgs = imgs * 2.0 - 1.0
posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def encode_cond_images(
self, imgs: Float[Tensor, "B 3 H W"]
) -> Float[Tensor, "B 4 DH DW"]:
input_dtype = imgs.dtype
imgs = imgs * 2.0 - 1.0
posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
latents = posterior.mode()
uncond_image_latents = torch.zeros_like(latents)
latents = torch.cat([latents, latents, uncond_image_latents], dim=0)
return latents.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def decode_latents(
self, latents: Float[Tensor, "B 4 DH DW"]
) -> Float[Tensor, "B 3 H W"]:
input_dtype = latents.dtype
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents.to(self.weights_dtype)).sample
image = (image * 0.5 + 0.5).clamp(0, 1)
return image.to(input_dtype)
def edit_latents(
self,
text_embeddings: Float[Tensor, "BB 77 768"],
latents: Float[Tensor, "B 4 DH DW"],
image_cond: Float[Tensor, "B 3 H W"],
t: Int[Tensor, "B"],
mask = None
) -> Float[Tensor, "B 4 DH DW"]:
self.scheduler.config.num_train_timesteps = t.item()
self.scheduler.set_timesteps(self.cfg.diffusion_steps)
if mask is not None:
mask = F.interpolate(mask, (latents.shape[-2], latents.shape[-1]), mode='bilinear')
with torch.no_grad():
# add noise
noise = torch.randn_like(latents)
latents = self.scheduler.add_noise(latents, noise, t) # type: ignore
# sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
threestudio.debug("Start editing...")
for i, t in enumerate(self.scheduler.timesteps):
# predict the noise residual with unet, NO grad!
with torch.no_grad():
# pred noise
latent_model_input = torch.cat([latents] * 2)
(
down_block_res_samples,
mid_block_res_sample,
) = self.forward_controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
image_cond=image_cond,
condition_scale=self.cfg.condition_scale,
)
noise_pred = self.forward_control_unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
# perform classifier-free guidance
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if mask is not None:
noise_pred = mask * noise_pred + (1 - mask) * noise
# get previous sample, continue loop
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
threestudio.debug("Editing finished.")
return latents
def prepare_image_cond(self, cond_rgb: Float[Tensor, "B H W C"]):
if self.cfg.control_type == "normal":
cond_rgb = (
(cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
)
detected_map = self.preprocessor(cond_rgb)
control = (
torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
)
control = control.unsqueeze(0)
control = control.permute(0, 3, 1, 2)
elif self.cfg.control_type == "canny":
cond_rgb = (
(cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
)
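# blurring before Canny suppresses spurious edges from high-frequency render noise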
blurred_img = cv2.blur(cond_rgb, ksize=(5, 5))
detected_map = self.preprocessor(
blurred_img, self.cfg.canny_lower_bound, self.cfg.canny_upper_bound
)
control = (
torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
)
# control = control.unsqueeze(-1).repeat(1, 1, 3)
control = control.unsqueeze(0)
control = control.permute(0, 3, 1, 2)
elif self.cfg.control_type == "input_normal":
cond_rgb[..., 0] = (
1 - cond_rgb[..., 0]
) # Flip the sign on the x-axis to match bae system
control = cond_rgb.permute(0, 3, 1, 2)
else:
raise ValueError(f"Unknown control type: {self.cfg.control_type}")
return control
def compute_grad_sds(
self,
text_embeddings: Float[Tensor, "BB 77 768"],
latents: Float[Tensor, "B 4 DH DW"],
image_cond: Float[Tensor, "B 3 H W"],
t: Int[Tensor, "B"],
):
with torch.no_grad():
# add noise
noise = torch.randn_like(latents) # TODO: use torch generator
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2)
down_block_res_samples, mid_block_res_sample = self.forward_controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
image_cond=image_cond,
condition_scale=self.cfg.condition_scale,
)
noise_pred = self.forward_control_unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
# perform classifier-free guidance
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
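# score distillation (SDS): grad = w(t) * (eps_pred - eps) with w(t) = 1 - alpha_bar_t;
# it is applied via the surrogate loss 0.5 * ||latents - (latents - grad).detach()||^2 in
# __call__, whose gradient w.r.t. latents equals grad.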
w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
grad = w * (noise_pred - noise)
return grad
def compute_grad_du(
self,
latents: Float[Tensor, "B 4 H W"],
rgb_BCHW_HW8: Float[Tensor, "B 3 RH RW"],
cond_feature: Float[Tensor, "B 3 RH RW"],
cond_rgb: Float[Tensor, "B H W 3"],
text_embeddings: Float[Tensor, "BB 77 768"],
mask = None,
**kwargs,
):
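# "du" = denoised update: every cfg.per_du_step steps (after cfg.start_du_step) an edited
# target image is regenerated from the current render via edit_latents, and the render is
# supervised against it with L1 + perceptual losses; with cfg.cache_du the edited targets
# are cached per view index.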
batch_size, _, RH, RW = cond_feature.shape
assert batch_size == 1
origin_gt_rgb = F.interpolate(
cond_rgb.permute(0, 3, 1, 2), (RH, RW), mode="bilinear"
).permute(0, 2, 3, 1)
need_diffusion = (
self.global_step % self.cfg.per_du_step == 0
and self.global_step > self.cfg.start_du_step
)
if self.cfg.cache_du:
if torch.is_tensor(kwargs["index"]):
batch_index = kwargs["index"].item()
else:
batch_index = kwargs["index"]
if (
not (batch_index in self.edit_frames)
) and self.global_step > self.cfg.start_du_step:
need_diffusion = True
need_loss = self.cfg.cache_du or need_diffusion
guidance_out = {}
if need_diffusion:
t = torch.randint(
self.min_step,
self.max_step,
[1],
dtype=torch.long,
device=self.device,
)
print("t:", t)
edit_latents = self.edit_latents(text_embeddings, latents, cond_feature, t, mask)
edit_images = self.decode_latents(edit_latents)
edit_images = F.interpolate(
edit_images, (RH, RW), mode="bilinear"
).permute(0, 2, 3, 1)
self.edit_images = edit_images
if self.cfg.cache_du:
self.edit_frames[batch_index] = edit_images.detach().cpu()
if need_loss:
if self.cfg.cache_du:
if batch_index in self.edit_frames:
gt_rgb = self.edit_frames[batch_index].to(cond_feature.device)
else:
gt_rgb = origin_gt_rgb
else:
gt_rgb = edit_images
# debug dump of the newly edited target image; guard on need_diffusion since
# edit_images only exists when a new edit was generated this step
if need_diffusion:
temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8)
cv2.imwrite(".threestudio_cache/test.jpg", temp[:, :, ::-1])
guidance_out.update(
{
"loss_l1": torch.nn.functional.l1_loss(
rgb_BCHW_HW8, gt_rgb.permute(0, 3, 1, 2), reduction="sum"
),
"loss_p": self.perceptual_loss(
rgb_BCHW_HW8.contiguous(),
gt_rgb.permute(0, 3, 1, 2).contiguous(),
).sum(),
}
)
return guidance_out
def __call__(
self,
rgb: Float[Tensor, "B H W C"],
cond_rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
mask = None,
**kwargs,
):
batch_size, H, W, _ = rgb.shape
assert batch_size == 1
assert rgb.shape[:-1] == cond_rgb.shape[:-1]
rgb_BCHW = rgb.permute(0, 3, 1, 2)
if mask is not None:
mask = mask.permute(0, 3, 1, 2)
latents: Float[Tensor, "B 4 DH DW"]
if self.cfg.fixed_size > 0:
RH, RW = self.cfg.fixed_size, self.cfg.fixed_size
else:
RH, RW = H // 8 * 8, W // 8 * 8
rgb_BCHW_HW8 = F.interpolate(
rgb_BCHW, (RH, RW), mode="bilinear", align_corners=False
)
latents = self.encode_images(rgb_BCHW_HW8)
image_cond = self.prepare_image_cond(cond_rgb)
image_cond = F.interpolate(
image_cond, (RH, RW), mode="bilinear", align_corners=False
)
temp = torch.zeros(1).to(rgb.device)
azimuth = kwargs.get("azimuth", temp)
camera_distance = kwargs.get("camera_distance", temp)
view_dependent_prompt = kwargs.get("view_dependent_prompt", False)
text_embeddings = prompt_utils.get_text_embeddings(temp, azimuth, camera_distance, view_dependent_prompt) # FIXME: change to view-conditioned prompt
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
t = torch.randint(
self.min_step,
self.max_step + 1,
[batch_size],
dtype=torch.long,
device=self.device,
)
guidance_out = {}
if self.cfg.use_sds:
grad = self.compute_grad_sds(text_embeddings, latents, image_cond, t)
grad = torch.nan_to_num(grad)
if self.grad_clip_val is not None:
grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
target = (latents - grad).detach()
loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
guidance_out.update(
{
"loss_sds": loss_sds,
"grad_norm": grad.norm(),
"min_step": self.min_step,
"max_step": self.max_step,
}
)
if self.cfg.use_du:
grad = self.compute_grad_du(
latents, rgb_BCHW_HW8, image_cond, cond_rgb, text_embeddings, mask, **kwargs
)
guidance_out.update(grad)
return guidance_out
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
# clip grad for stable training as demonstrated in
# Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
# http://arxiv.org/abs/2303.15413
if self.cfg.grad_clip is not None:
self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)
self.set_min_max_steps(
min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
)
self.global_step = global_step

View File

@ -0,0 +1,454 @@
import os
from dataclasses import dataclass
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from controlnet_aux import CannyDetector, NormalBaeDetector
from diffusers import ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline, DPMSolverMultistepScheduler
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm
import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.typing import *
@threestudio.register("stable-diffusion-controlnet-reg-guidance")
class ControlNetGuidance(BaseObject):
@dataclass
class Config(BaseObject.Config):
cache_dir: Optional[str] = None
local_files_only: Optional[bool] = False
pretrained_model_name_or_path: str = "SG161222/Realistic_Vision_V2.0"
ddim_scheduler_name_or_path: str = "runwayml/stable-diffusion-v1-5"
control_type: str = "normal" # normal/canny
enable_memory_efficient_attention: bool = False
enable_sequential_cpu_offload: bool = False
enable_attention_slicing: bool = False
enable_channels_last_format: bool = False
guidance_scale: float = 7.5
condition_scale: float = 1.5
grad_clip: Optional[Any] = None
half_precision_weights: bool = True
min_step_percent: float = 0.02
max_step_percent: float = 0.98
diffusion_steps: int = 20
use_sds: bool = False
# Canny threshold
canny_lower_bound: int = 50
canny_upper_bound: int = 100
cfg: Config
def configure(self) -> None:
threestudio.info(f"Loading ControlNet ...")
self.weights_dtype = torch.float16 if self.cfg.half_precision_weights else torch.float32
self.preprocessor, controlnet_name_or_path = self.get_preprocessor_and_controlnet()
pipe_kwargs = self.configure_pipeline()
self.load_models(pipe_kwargs, controlnet_name_or_path)
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.scheduler.set_timesteps(self.cfg.diffusion_steps)
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
self.scheduler = self.pipe.scheduler
self.check_memory_efficiency_conditions()
self.set_min_max_steps()
self.alphas = self.scheduler.alphas_cumprod.to(self.device)
self.grad_clip_val = None
threestudio.info(f"Loaded ControlNet!")
def get_preprocessor_and_controlnet(self):
if self.cfg.control_type in ("normal", "input_normal"):
if self.cfg.pretrained_model_name_or_path == "SG161222/Realistic_Vision_V2.0":
controlnet_name_or_path = "lllyasviel/control_v11p_sd15_normalbae"
else:
controlnet_name_or_path = "thibaud/controlnet-sd21-normalbae-diffusers"
preprocessor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators", cache_dir=self.cfg.cache_dir)
preprocessor.model.to(self.device)
elif self.cfg.control_type == "canny" or self.cfg.control_type == "canny2":
controlnet_name_or_path = self.get_canny_controlnet()
preprocessor = CannyDetector()
else:
raise ValueError(f"Unknown control type: {self.cfg.control_type}")
return preprocessor, controlnet_name_or_path
def get_canny_controlnet(self):
if self.cfg.control_type == "canny":
return "lllyasviel/control_v11p_sd15_canny"
elif self.cfg.control_type == "canny2":
return "thepowefuldeez/sd21-controlnet-canny"
def configure_pipeline(self):
return {
"safety_checker": None,
"feature_extractor": None,
"requires_safety_checker": False,
"torch_dtype": self.weights_dtype,
"cache_dir": self.cfg.cache_dir,
"local_files_only": self.cfg.local_files_only
}
def load_models(self, pipe_kwargs, controlnet_name_or_path):
controlnet = ControlNetModel.from_pretrained(
controlnet_name_or_path,
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
local_files_only=self.cfg.local_files_only
)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path, controlnet=controlnet, **pipe_kwargs
).to(self.device)
self.scheduler = DDIMScheduler.from_pretrained(
self.cfg.ddim_scheduler_name_or_path,
subfolder="scheduler",
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
local_files_only=self.cfg.local_files_only
)
self.vae = self.pipe.vae.eval()
self.unet = self.pipe.unet.eval()
self.controlnet = self.pipe.controlnet.eval()
def check_memory_efficiency_conditions(self):
if self.cfg.enable_memory_efficient_attention:
self.memory_efficiency_status()
if self.cfg.enable_sequential_cpu_offload:
self.pipe.enable_sequential_cpu_offload()
if self.cfg.enable_attention_slicing:
self.pipe.enable_attention_slicing(1)
if self.cfg.enable_channels_last_format:
self.pipe.unet.to(memory_format=torch.channels_last)
def memory_efficiency_status(self):
if parse_version(torch.__version__) >= parse_version("2"):
threestudio.info("PyTorch2.0 uses memory efficient attention by default.")
elif not is_xformers_available():
threestudio.warn("xformers is not available, memory efficient attention is not enabled.")
else:
self.pipe.enable_xformers_memory_efficient_attention()
@torch.cuda.amp.autocast(enabled=False)
def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
self.min_step = int(self.num_train_timesteps * min_step_percent)
self.max_step = int(self.num_train_timesteps * max_step_percent)
@torch.cuda.amp.autocast(enabled=False)
def forward_controlnet(
self,
latents: Float[Tensor, "..."],
t: Float[Tensor, "..."],
image_cond: Float[Tensor, "..."],
condition_scale: float,
encoder_hidden_states: Float[Tensor, "..."],
) -> Float[Tensor, "..."]:
return self.controlnet(
latents.to(self.weights_dtype),
t.to(self.weights_dtype),
encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
controlnet_cond=image_cond.to(self.weights_dtype),
conditioning_scale=condition_scale,
return_dict=False,
)
@torch.cuda.amp.autocast(enabled=False)
def forward_control_unet(
self,
latents: Float[Tensor, "..."],
t: Float[Tensor, "..."],
encoder_hidden_states: Float[Tensor, "..."],
cross_attention_kwargs,
down_block_additional_residuals,
mid_block_additional_residual,
) -> Float[Tensor, "..."]:
input_dtype = latents.dtype
return self.unet(
latents.to(self.weights_dtype),
t.to(self.weights_dtype),
encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
).sample.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def encode_images(
self, imgs: Float[Tensor, "B 3 512 512"]
) -> Float[Tensor, "B 4 64 64"]:
input_dtype = imgs.dtype
imgs = imgs * 2.0 - 1.0
posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def encode_cond_images(
self, imgs: Float[Tensor, "B 3 512 512"]
) -> Float[Tensor, "B 4 64 64"]:
input_dtype = imgs.dtype
imgs = imgs * 2.0 - 1.0
posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
latents = posterior.mode()
uncond_image_latents = torch.zeros_like(latents)
latents = torch.cat([latents, latents, uncond_image_latents], dim=0)
return latents.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def decode_latents(
self,
latents: Float[Tensor, "B 4 H W"],
latent_height: int = 64,
latent_width: int = 64,
) -> Float[Tensor, "B 3 512 512"]:
input_dtype = latents.dtype
latents = F.interpolate(
latents, (latent_height, latent_width), mode="bilinear", align_corners=False
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents.to(self.weights_dtype)).sample
image = (image * 0.5 + 0.5).clamp(0, 1)
return image.to(input_dtype)
def edit_latents(
self,
text_embeddings: Float[Tensor, "BB 77 768"],
latents: Float[Tensor, "B 4 64 64"],
image_cond: Float[Tensor, "B 3 512 512"],
t: Int[Tensor, "B"],
mask=None
) -> Float[Tensor, "B 4 64 64"]:
batch_size = t.shape[0]
self.scheduler.set_timesteps(num_inference_steps=self.cfg.diffusion_steps)
init_timestep = max(1, min(int(self.cfg.diffusion_steps * t[0].item() / self.num_train_timesteps), self.cfg.diffusion_steps))
t_start = max(self.cfg.diffusion_steps - init_timestep, 0)
latent_timestep = self.scheduler.timesteps[t_start : t_start + 1].repeat(batch_size)
B, _, DH, DW = latents.shape
origin_latents = latents.clone()
if mask is not None:
mask = F.interpolate(mask, (DH, DW), mode="bilinear", antialias=True)
with torch.no_grad():
# sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
noise = torch.randn_like(latents)
latents = self.scheduler.add_noise(latents, noise, latent_timestep) # type: ignore
threestudio.debug("Start editing...")
for i, step in enumerate(range(t_start, self.cfg.diffusion_steps)):
timestep = self.scheduler.timesteps[step]
# predict the noise residual with unet, NO grad!
with torch.no_grad():
# pred noise
latent_model_input = torch.cat([latents] * 2)
(
down_block_res_samples,
mid_block_res_sample,
) = self.forward_controlnet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings,
image_cond=image_cond,
condition_scale=self.cfg.condition_scale,
)
noise_pred = self.forward_control_unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
# perform classifier-free guidance
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if mask is not None:
noise_pred = noise_pred * mask + (1-mask) * noise
latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
threestudio.debug("Editing finished.")
return latents
def prepare_image_cond(self, cond_rgb: Float[Tensor, "B H W C"]):
if self.cfg.control_type == "normal":
cond_rgb = (
(cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
)
detected_map = self.preprocessor(cond_rgb)
control = (
torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
)
control = control.unsqueeze(0)
control = control.permute(0, 3, 1, 2)
elif self.cfg.control_type == "canny" or self.cfg.control_type == "canny2":
cond_rgb = (
(cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
)
blurred_img = cv2.blur(cond_rgb, ksize=(5, 5))
detected_map = self.preprocessor(
blurred_img, self.cfg.canny_lower_bound, self.cfg.canny_upper_bound
)
control = (
torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
)
control = control.unsqueeze(-1).repeat(1, 1, 3)
control = control.unsqueeze(0)
control = control.permute(0, 3, 1, 2)
elif self.cfg.control_type == "input_normal":
cond_rgb[..., 0] = (
1 - cond_rgb[..., 0]
) # Flip the sign on the x-axis to match bae system
control = cond_rgb.permute(0, 3, 1, 2)
else:
raise ValueError(f"Unknown control type: {self.cfg.control_type}")
return F.interpolate(control, (512, 512), mode="bilinear", align_corners=False)
def compute_grad_sds(
self,
text_embeddings: Float[Tensor, "BB 77 768"],
latents: Float[Tensor, "B 4 64 64"],
image_cond: Float[Tensor, "B 3 512 512"],
t: Int[Tensor, "B"],
):
with torch.no_grad():
# add noise
noise = torch.randn_like(latents) # TODO: use torch generator
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2)
down_block_res_samples, mid_block_res_sample = self.forward_controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
image_cond=image_cond,
condition_scale=self.cfg.condition_scale,
)
noise_pred = self.forward_control_unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
# perform classifier-free guidance
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
grad = w * (noise_pred - noise)
return grad
def __call__(
self,
rgb: Float[Tensor, "B H W C"],
cond_rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
mask: Float[Tensor, "B H W C"],
**kwargs,
):
batch_size, H, W, _ = rgb.shape
rgb_BCHW = rgb.permute(0, 3, 1, 2)
latents: Float[Tensor, "B 4 64 64"]
rgb_BCHW_512 = F.interpolate(
rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
)
latents = self.encode_images(rgb_BCHW_512)
image_cond = self.prepare_image_cond(cond_rgb)
temp = torch.zeros(1).to(rgb.device)
text_embeddings = prompt_utils.get_text_embeddings(temp, temp, temp, False)
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
t = torch.randint(
self.min_step,
self.max_step + 1,
[batch_size],
dtype=torch.long,
device=self.device,
)
if self.cfg.use_sds:
grad = self.compute_grad_sds(text_embeddings, latents, image_cond, t)
grad = torch.nan_to_num(grad)
if self.grad_clip_val is not None:
grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
target = (latents - grad).detach()
loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
return {
"loss_sds": loss_sds,
"grad_norm": grad.norm(),
"min_step": self.min_step,
"max_step": self.max_step,
}
else:
if mask is not None:
mask = mask.permute(0, 3, 1, 2)
edit_latents = self.edit_latents(text_embeddings, latents, image_cond, t, mask)
edit_images = self.decode_latents(edit_latents)
edit_images = F.interpolate(edit_images, (H, W), mode="bilinear")
return {"edit_images": edit_images.permute(0, 2, 3, 1),
"edit_latents": edit_latents}
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
# clip grad for stable training as demonstrated in
# Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
# http://arxiv.org/abs/2303.15413
if self.cfg.grad_clip is not None:
self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)
self.set_min_max_steps(
min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
)
if __name__ == "__main__":
from threestudio.utils.config import ExperimentConfig, load_config
from threestudio.utils.typing import Optional
cfg = load_config("configs/experimental/controlnet-normal.yaml")
guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)
prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(
cfg.system.prompt_processor
)
rgb_image = cv2.imread("assets/face.jpg")[:, :, ::-1].copy() / 255
rgb_image = cv2.resize(rgb_image, (512, 512))
rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device)
prompt_utils = prompt_processor()
guidance_out = guidance(rgb_image, rgb_image, prompt_utils)
edit_image = (
(guidance_out["edit_images"][0].detach().cpu().clip(0, 1).numpy() * 255)
.astype(np.uint8)[:, :, ::-1]
.copy()
)
os.makedirs(".threestudio_cache", exist_ok=True)
cv2.imwrite(".threestudio_cache/edit_image.jpg", edit_image)

View File

@ -0,0 +1,582 @@
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import IFPipeline, DDPMScheduler
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm
import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.ops import perpendicular_component
from threestudio.utils.typing import *
@threestudio.register("deep-floyd-guidance")
class DeepFloydGuidance(BaseObject):
@dataclass
class Config(BaseObject.Config):
cache_dir: Optional[str] = None
local_files_only: Optional[bool] = False
pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0"
# FIXME: xformers error
enable_memory_efficient_attention: bool = False
enable_sequential_cpu_offload: bool = False
enable_attention_slicing: bool = False
enable_channels_last_format: bool = True
guidance_scale: float = 20.0
grad_clip: Optional[
Any
] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
time_prior: Optional[Any] = None # [w1,w2,s1,s2]
half_precision_weights: bool = True
min_step_percent: float = 0.02
max_step_percent: float = 0.98
weighting_strategy: str = "sds"
view_dependent_prompting: bool = True
"""Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
max_items_eval: int = 4
lora_weights_path: Optional[str] = None
cfg: Config
def configure(self) -> None:
threestudio.info(f"Loading Deep Floyd ...")
self.weights_dtype = (
torch.float16 if self.cfg.half_precision_weights else torch.float32
)
# Create model
self.pipe = IFPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path,
text_encoder=None,
safety_checker=None,
watermarker=None,
feature_extractor=None,
requires_safety_checker=False,
variant="fp16" if self.cfg.half_precision_weights else None,
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
local_files_only=self.cfg.local_files_only
).to(self.device)
# Load lora weights
if self.cfg.lora_weights_path is not None:
self.pipe.load_lora_weights(self.cfg.lora_weights_path)
self.pipe.scheduler = self.pipe.scheduler.__class__.from_config(self.pipe.scheduler.config, variance_type="fixed_small")
if self.cfg.enable_memory_efficient_attention:
if parse_version(torch.__version__) >= parse_version("2"):
threestudio.info(
"PyTorch2.0 uses memory efficient attention by default."
)
elif not is_xformers_available():
threestudio.warn(
"xformers is not available, memory efficient attention is not enabled."
)
else:
threestudio.warn(
f"Use DeepFloyd with xformers may raise error, see https://github.com/deep-floyd/IF/issues/52 to track this problem."
)
self.pipe.enable_xformers_memory_efficient_attention()
if self.cfg.enable_sequential_cpu_offload:
self.pipe.enable_sequential_cpu_offload()
if self.cfg.enable_attention_slicing:
self.pipe.enable_attention_slicing(1)
if self.cfg.enable_channels_last_format:
self.pipe.unet.to(memory_format=torch.channels_last)
self.unet = self.pipe.unet.eval()
for p in self.unet.parameters():
p.requires_grad_(False)
self.scheduler = self.pipe.scheduler
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.set_min_max_steps() # set to default value
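        # time_prior = [m1, m2, s1, s2] defines an annealed timestep schedule: sampling
        # weights are flat on [m2, m1] and decay as Gaussians with std s1 above m1 and
        # std s2 below m2; the cumulative weights are compared against the training
        # progress in __call__ to pick a timestep deterministically.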
if self.cfg.time_prior is not None:
m1, m2, s1, s2 = self.cfg.time_prior
weights = torch.cat(
(
torch.exp(
-((torch.arange(self.num_train_timesteps, m1, -1) - m1) ** 2)
/ (2 * s1**2)
),
torch.ones(m1 - m2 + 1),
torch.exp(
-((torch.arange(m2 - 1, 0, -1) - m2) ** 2) / (2 * s2**2)
),
)
)
weights = weights / torch.sum(weights)
self.time_prior_acc_weights = torch.cumsum(weights, dim=0)
self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
self.device
)
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(
self.device
)
self.grad_clip_val: Optional[float] = None
threestudio.info(f"Loaded Deep Floyd!")
@torch.cuda.amp.autocast(enabled=False)
def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
self.min_step = int(self.num_train_timesteps * min_step_percent)
self.max_step = int(self.num_train_timesteps * max_step_percent)
@torch.cuda.amp.autocast(enabled=False)
def forward_unet(
self,
latents: Float[Tensor, "..."],
t: Float[Tensor, "..."],
encoder_hidden_states: Float[Tensor, "..."],
) -> Float[Tensor, "..."]:
input_dtype = latents.dtype
return self.unet(
latents.to(self.weights_dtype),
t.to(self.weights_dtype),
encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
).sample.to(input_dtype)
def __call__(
self,
rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
current_step_ratio=None,
mask: Float[Tensor, "B H W 1"] = None,
rgb_as_latents=False,
guidance_eval=False,
**kwargs,
):
batch_size = rgb.shape[0]
rgb_BCHW = rgb.permute(0, 3, 1, 2)
if mask is not None:
mask = mask.permute(0, 3, 1, 2)
mask = F.interpolate(
mask, (64, 64), mode="bilinear", align_corners=False
)
assert rgb_as_latents == False, f"No latent space in {self.__class__.__name__}"
rgb_BCHW = rgb_BCHW * 2.0 - 1.0 # scale to [-1, 1] to match the diffusion range
latents = F.interpolate(
rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
)
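        # DeepFloyd IF stage I is a pixel-space model, so the "latents" here are just
        # the RGB image rescaled to [-1, 1] and resized to the model's 64x64 resolution.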
if self.cfg.time_prior is not None:
time_index = torch.where(
(self.time_prior_acc_weights - current_step_ratio) > 0
)[0][0]
if time_index == 0 or torch.abs(
self.time_prior_acc_weights[time_index] - current_step_ratio
) < torch.abs(
self.time_prior_acc_weights[time_index - 1] - current_step_ratio
):
t = self.num_train_timesteps - time_index
else:
t = self.num_train_timesteps - time_index + 1
t = torch.clip(t, self.min_step, self.max_step + 1)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
t = torch.randint(
self.min_step,
self.max_step + 1,
[batch_size],
dtype=torch.long,
device=self.device,
)
if prompt_utils.use_perp_neg:
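            # Perp-Neg guidance: embeddings are stacked as [positive, uncond, negatives];
            # the components of the negative directions perpendicular to e_pos are
            # accumulated with per-prompt weights before classifier-free guidance.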
(
text_embeddings,
neg_guidance_weights,
) = prompt_utils.get_text_embeddings_perp_neg(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
with torch.no_grad():
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
if mask is not None:
latents_noisy = (1 - mask) * latents + mask * latents_noisy
latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t] * 4),
encoder_hidden_states=text_embeddings,
) # (4B, 6, 64, 64)
                noise_pred_text, predicted_variance = noise_pred[:batch_size].split(
                    3, dim=1
                )
noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split(
3, dim=1
)
noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1)
e_pos = noise_pred_text - noise_pred_uncond
accum_grad = 0
n_negative_prompts = neg_guidance_weights.shape[-1]
for i in range(n_negative_prompts):
e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
accum_grad += neg_guidance_weights[:, i].view(
-1, 1, 1, 1
) * perpendicular_component(e_i_neg, e_pos)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
e_pos + accum_grad
)
else:
neg_guidance_weights = None
text_embeddings = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
# predict the noise residual with unet, NO grad!
with torch.no_grad():
# add noise
noise = torch.randn_like(latents) # TODO: use torch generator
latents_noisy = self.scheduler.add_noise(latents, noise, t)
if mask is not None:
latents_noisy = (1 - mask) * latents + mask * latents_noisy
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t] * 2),
encoder_hidden_states=text_embeddings,
) # (2B, 6, 64, 64)
# perform guidance (high scale from paper!)
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1)
noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
noise_pred = noise_pred_text + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
"""
# thresholding, experimental
if self.cfg.thresholding:
assert batch_size == 1
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
noise_pred = custom_ddpm_step(self.scheduler,
noise_pred, int(t.item()), latents_noisy, **self.pipe.prepare_extra_step_kwargs(None, 0.0)
)
"""
if self.cfg.weighting_strategy == "sds":
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
elif self.cfg.weighting_strategy == "uniform":
w = 1
elif self.cfg.weighting_strategy == "fantasia3d":
w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
else:
raise ValueError(
f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# clip grad for stable training?
if self.grad_clip_val is not None:
grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
# loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward, use a reparameterization trick instead
target = (latents - grad).detach()
# d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
guidance_out = {
"loss_sd": loss_sd,
"grad_norm": grad.norm(),
"min_step": self.min_step,
"max_step": self.max_step,
}
# # FIXME: Visualize inpainting results
# self.scheduler.set_timesteps(20)
# latents = latents_noisy
# for t in tqdm(self.scheduler.timesteps):
# # pred noise
# noise_pred = self.get_noise_pred(
# latents, t, text_embeddings, prompt_utils.use_perp_neg, None
# )
# # get prev latent
# prev_latents = latents
# latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"]
# if mask is not None:
# latents = (1 - mask) * prev_latents + mask * latents
# denoised_img = (latents / 2 + 0.5).permute(0, 2, 3, 1)
# guidance_out.update(
# {"denoised_img": denoised_img}
# )
if guidance_eval:
guidance_eval_utils = {
"use_perp_neg": prompt_utils.use_perp_neg,
"neg_guidance_weights": neg_guidance_weights,
"text_embeddings": text_embeddings,
"t_orig": t,
"latents_noisy": latents_noisy,
"noise_pred": torch.cat([noise_pred, predicted_variance], dim=1),
}
guidance_eval_out = self.guidance_eval(**guidance_eval_utils)
texts = []
for n, e, a, c in zip(
guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances
):
texts.append(
f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
)
guidance_eval_out.update({"texts": texts})
guidance_out.update({"eval": guidance_eval_out})
return guidance_out
@torch.cuda.amp.autocast(enabled=False)
@torch.no_grad()
def get_noise_pred(
self,
latents_noisy,
t,
text_embeddings,
use_perp_neg=False,
neg_guidance_weights=None,
):
batch_size = latents_noisy.shape[0]
if use_perp_neg:
latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t.reshape(1)] * 4).to(self.device),
encoder_hidden_states=text_embeddings,
) # (4B, 6, 64, 64)
            noise_pred_text, predicted_variance = noise_pred[:batch_size].split(
                3, dim=1
            )
noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split(
3, dim=1
)
noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1)
e_pos = noise_pred_text - noise_pred_uncond
accum_grad = 0
n_negative_prompts = neg_guidance_weights.shape[-1]
for i in range(n_negative_prompts):
e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
accum_grad += neg_guidance_weights[:, i].view(
-1, 1, 1, 1
) * perpendicular_component(e_i_neg, e_pos)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
e_pos + accum_grad
)
else:
latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t.reshape(1)] * 2).to(self.device),
encoder_hidden_states=text_embeddings,
) # (2B, 6, 64, 64)
# perform guidance (high scale from paper!)
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1)
noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
noise_pred = noise_pred_text + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return torch.cat([noise_pred, predicted_variance], dim=1)
@torch.cuda.amp.autocast(enabled=False)
@torch.no_grad()
def guidance_eval(
self,
t_orig,
text_embeddings,
latents_noisy,
noise_pred,
use_perp_neg=False,
neg_guidance_weights=None,
):
# use only 50 timesteps, and find nearest of those to t
self.scheduler.set_timesteps(50)
self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
bs = (
min(self.cfg.max_items_eval, latents_noisy.shape[0])
if self.cfg.max_items_eval > 0
else latents_noisy.shape[0]
) # batch size
large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
:bs
].unsqueeze(
-1
) # sized [bs,50] > [bs,1]
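        # scheduler.timesteps is in descending order, so the argmin over this boolean
        # mask picks the first timestep that is <= t_orig, i.e. the nearest discrete
        # inference step at or below the sampled timestep.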
idxs = torch.min(large_enough_idxs, dim=1)[1]
t = self.scheduler.timesteps_gpu[idxs]
fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
imgs_noisy = (latents_noisy[:bs] / 2 + 0.5).permute(0, 2, 3, 1)
# get prev latent
latents_1step = []
pred_1orig = []
for b in range(bs):
step_output = self.scheduler.step(
noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1]
)
latents_1step.append(step_output["prev_sample"])
pred_1orig.append(step_output["pred_original_sample"])
latents_1step = torch.cat(latents_1step)
pred_1orig = torch.cat(pred_1orig)
imgs_1step = (latents_1step / 2 + 0.5).permute(0, 2, 3, 1)
imgs_1orig = (pred_1orig / 2 + 0.5).permute(0, 2, 3, 1)
latents_final = []
for b, i in enumerate(idxs):
latents = latents_1step[b : b + 1]
text_emb = (
text_embeddings[
[b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ...
]
if use_perp_neg
else text_embeddings[[b, b + len(idxs)], ...]
)
neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None
for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
# pred noise
noise_pred = self.get_noise_pred(
latents, t, text_emb, use_perp_neg, neg_guid
)
# get prev latent
latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"]
latents_final.append(latents)
latents_final = torch.cat(latents_final)
imgs_final = (latents_final / 2 + 0.5).permute(0, 2, 3, 1)
return {
"bs": bs,
"noise_levels": fracs,
"imgs_noisy": imgs_noisy,
"imgs_1step": imgs_1step,
"imgs_1orig": imgs_1orig,
"imgs_final": imgs_final,
}
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
# clip grad for stable training as demonstrated in
# Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
# http://arxiv.org/abs/2303.15413
if self.cfg.grad_clip is not None:
self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)
self.set_min_max_steps(
min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
)
"""
# used by thresholding, experimental
def custom_ddpm_step(ddpm, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator=None, return_dict: bool = True):
self = ddpm
t = timestep
prev_t = self.previous_timestep(t)
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t].item()
alpha_prod_t_prev = self.alphas_cumprod[prev_t].item() if prev_t >= 0 else 1.0
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
current_beta_t = 1 - current_alpha_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
" `v_prediction` for the DDPMScheduler."
)
# 3. Clip or threshold "predicted x_0"
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
noise_thresholded = (sample - (alpha_prod_t ** 0.5) * pred_original_sample) / (beta_prod_t ** 0.5)
return noise_thresholded
"""
if __name__ == '__main__':
from threestudio.utils.config import load_config
import pytorch_lightning as pl
import numpy as np
import os
import cv2
cfg = load_config("configs/debugging/deepfloyd.yaml")
guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)
prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(cfg.system.prompt_processor)
prompt_utils = prompt_processor()
temp = torch.zeros(1).to(guidance.device)
# rgb_image = guidance.sample(prompt_utils, temp, temp, temp, seed=cfg.seed)
# rgb_image = (rgb_image[0].detach().cpu().clip(0, 1).numpy()*255).astype(np.uint8)[:, :, ::-1].copy()
# os.makedirs('.threestudio_cache', exist_ok=True)
# cv2.imwrite('.threestudio_cache/diffusion_image.jpg', rgb_image)
### inpaint
rgb_image = cv2.imread("assets/test.jpg")[:, :, ::-1].copy() / 255
mask_image = cv2.imread("assets/mask.png")[:, :, :1].copy() / 255
rgb_image = cv2.resize(rgb_image, (512, 512))
mask_image = cv2.resize(mask_image, (512, 512)).reshape(512, 512, 1)
rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device)
mask_image = torch.FloatTensor(mask_image).unsqueeze(0).to(guidance.device)
guidance_out = guidance(rgb_image, prompt_utils, temp, temp, temp, mask=mask_image)
edit_image = (
(guidance_out["denoised_img"][0].detach().cpu().clip(0, 1).numpy() * 255)
.astype(np.uint8)[:, :, ::-1]
.copy()
)
os.makedirs(".threestudio_cache", exist_ok=True)
cv2.imwrite(".threestudio_cache/edit_image.jpg", edit_image)

File diff suppressed because it is too large

View File

@ -0,0 +1,632 @@
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm
import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, cleanup, parse_version
from threestudio.utils.ops import perpendicular_component
from threestudio.utils.typing import *
@threestudio.register("stable-diffusion-guidance")
class StableDiffusionGuidance(BaseObject):
@dataclass
class Config(BaseObject.Config):
cache_dir: Optional[str] = None
local_files_only: Optional[bool] = False
pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
enable_memory_efficient_attention: bool = False
enable_sequential_cpu_offload: bool = False
enable_attention_slicing: bool = False
enable_channels_last_format: bool = False
guidance_scale: float = 100.0
grad_clip: Optional[
Any
] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
time_prior: Optional[Any] = None # [w1,w2,s1,s2]
half_precision_weights: bool = True
min_step_percent: float = 0.02
max_step_percent: float = 0.98
max_step_percent_annealed: float = 0.5
anneal_start_step: Optional[int] = None
use_sjc: bool = False
var_red: bool = True
weighting_strategy: str = "sds"
token_merging: bool = False
token_merging_params: Optional[dict] = field(default_factory=dict)
view_dependent_prompting: bool = True
"""Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
max_items_eval: int = 4
cfg: Config
def configure(self) -> None:
threestudio.info(f"Loading Stable Diffusion ...")
self.weights_dtype = (
torch.float16 if self.cfg.half_precision_weights else torch.float32
)
pipe_kwargs = {
"tokenizer": None,
"safety_checker": None,
"feature_extractor": None,
"requires_safety_checker": False,
"torch_dtype": self.weights_dtype,
"cache_dir": self.cfg.cache_dir,
"local_files_only": self.cfg.local_files_only
}
self.pipe = StableDiffusionPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path,
**pipe_kwargs,
).to(self.device)
if self.cfg.enable_memory_efficient_attention:
if parse_version(torch.__version__) >= parse_version("2"):
threestudio.info(
"PyTorch2.0 uses memory efficient attention by default."
)
elif not is_xformers_available():
threestudio.warn(
"xformers is not available, memory efficient attention is not enabled."
)
else:
self.pipe.enable_xformers_memory_efficient_attention()
if self.cfg.enable_sequential_cpu_offload:
self.pipe.enable_sequential_cpu_offload()
if self.cfg.enable_attention_slicing:
self.pipe.enable_attention_slicing(1)
if self.cfg.enable_channels_last_format:
self.pipe.unet.to(memory_format=torch.channels_last)
del self.pipe.text_encoder
cleanup()
# Create model
self.vae = self.pipe.vae.eval()
self.unet = self.pipe.unet.eval()
for p in self.vae.parameters():
p.requires_grad_(False)
for p in self.unet.parameters():
p.requires_grad_(False)
if self.cfg.token_merging:
import tomesd
tomesd.apply_patch(self.unet, **self.cfg.token_merging_params)
if self.cfg.use_sjc:
            # score jacobian chaining uses DDPM
self.scheduler = DDPMScheduler.from_pretrained(
self.cfg.pretrained_model_name_or_path,
subfolder="scheduler",
torch_dtype=self.weights_dtype,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
cache_dir=self.cfg.cache_dir,
)
else:
self.scheduler = DDIMScheduler.from_pretrained(
self.cfg.pretrained_model_name_or_path,
subfolder="scheduler",
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
local_files_only=self.cfg.local_files_only,
)
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.set_min_max_steps() # set to default value
if self.cfg.time_prior is not None:
m1, m2, s1, s2 = self.cfg.time_prior
weights = torch.cat(
(
torch.exp(
-((torch.arange(self.num_train_timesteps, m1, -1) - m1) ** 2)
/ (2 * s1**2)
),
torch.ones(m1 - m2 + 1),
torch.exp(
-((torch.arange(m2 - 1, 0, -1) - m2) ** 2) / (2 * s2**2)
),
)
)
weights = weights / torch.sum(weights)
self.time_prior_acc_weights = torch.cumsum(weights, dim=0)
self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
self.device
)
if self.cfg.use_sjc:
            # score jacobian chaining needs the noise scales sqrt((1 - alpha) / alpha)
self.us: Float[Tensor, "..."] = torch.sqrt((1 - self.alphas) / self.alphas)
self.grad_clip_val: Optional[float] = None
threestudio.info(f"Loaded Stable Diffusion!")
@torch.cuda.amp.autocast(enabled=False)
def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
self.min_step = int(self.num_train_timesteps * min_step_percent)
self.max_step = int(self.num_train_timesteps * max_step_percent)
@torch.cuda.amp.autocast(enabled=False)
def forward_unet(
self,
latents: Float[Tensor, "..."],
t: Float[Tensor, "..."],
encoder_hidden_states: Float[Tensor, "..."],
) -> Float[Tensor, "..."]:
input_dtype = latents.dtype
return self.unet(
latents.to(self.weights_dtype),
t.to(self.weights_dtype),
encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
).sample.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def encode_images(
self, imgs: Float[Tensor, "B 3 512 512"]
) -> Float[Tensor, "B 4 64 64"]:
input_dtype = imgs.dtype
imgs = imgs * 2.0 - 1.0
posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def decode_latents(
self,
latents: Float[Tensor, "B 4 H W"],
latent_height: int = 64,
latent_width: int = 64,
) -> Float[Tensor, "B 3 512 512"]:
input_dtype = latents.dtype
latents = F.interpolate(
latents, (latent_height, latent_width), mode="bilinear", align_corners=False
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents.to(self.weights_dtype)).sample
image = (image * 0.5 + 0.5).clamp(0, 1)
return image.to(input_dtype)
def compute_grad_sds(
self,
latents: Float[Tensor, "B 4 64 64"],
t: Int[Tensor, "B"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
):
batch_size = elevation.shape[0]
if prompt_utils.use_perp_neg:
(
text_embeddings,
neg_guidance_weights,
) = prompt_utils.get_text_embeddings_perp_neg(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
with torch.no_grad():
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t] * 4),
encoder_hidden_states=text_embeddings,
) # (4B, 3, 64, 64)
noise_pred_text = noise_pred[:batch_size]
noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
noise_pred_neg = noise_pred[batch_size * 2 :]
e_pos = noise_pred_text - noise_pred_uncond
accum_grad = 0
n_negative_prompts = neg_guidance_weights.shape[-1]
for i in range(n_negative_prompts):
e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
accum_grad += neg_guidance_weights[:, i].view(
-1, 1, 1, 1
) * perpendicular_component(e_i_neg, e_pos)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
e_pos + accum_grad
)
else:
neg_guidance_weights = None
text_embeddings = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
# predict the noise residual with unet, NO grad!
with torch.no_grad():
# add noise
noise = torch.randn_like(latents) # TODO: use torch generator
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t] * 2),
encoder_hidden_states=text_embeddings,
)
# perform guidance (high scale from paper!)
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_text + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if self.cfg.weighting_strategy == "sds":
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
elif self.cfg.weighting_strategy == "uniform":
w = 1
elif self.cfg.weighting_strategy == "fantasia3d":
w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
else:
raise ValueError(
f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
)
grad = w * (noise_pred - noise)
guidance_eval_utils = {
"use_perp_neg": prompt_utils.use_perp_neg,
"neg_guidance_weights": neg_guidance_weights,
"text_embeddings": text_embeddings,
"t_orig": t,
"latents_noisy": latents_noisy,
"noise_pred": noise_pred,
}
return grad, guidance_eval_utils
def compute_grad_sjc(
self,
latents: Float[Tensor, "B 4 64 64"],
t: Int[Tensor, "B"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
):
batch_size = elevation.shape[0]
sigma = self.us[t]
sigma = sigma.view(-1, 1, 1, 1)
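        # Score Jacobian Chaining: perturb y = latents with noise of std sigma,
        # zs = y + sigma * noise, denoise the rescaled input to get Ds = zs - sigma * eps_pred,
        # and use grad = -(Ds - y) / sigma (variance-reduced) or -(Ds - zs) / sigma.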
if prompt_utils.use_perp_neg:
(
text_embeddings,
neg_guidance_weights,
) = prompt_utils.get_text_embeddings_perp_neg(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
with torch.no_grad():
noise = torch.randn_like(latents)
y = latents
zs = y + sigma * noise
scaled_zs = zs / torch.sqrt(1 + sigma**2)
# pred noise
latent_model_input = torch.cat([scaled_zs] * 4, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t] * 4),
encoder_hidden_states=text_embeddings,
) # (4B, 3, 64, 64)
noise_pred_text = noise_pred[:batch_size]
noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
noise_pred_neg = noise_pred[batch_size * 2 :]
e_pos = noise_pred_text - noise_pred_uncond
accum_grad = 0
n_negative_prompts = neg_guidance_weights.shape[-1]
for i in range(n_negative_prompts):
e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
accum_grad += neg_guidance_weights[:, i].view(
-1, 1, 1, 1
) * perpendicular_component(e_i_neg, e_pos)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
e_pos + accum_grad
)
else:
neg_guidance_weights = None
text_embeddings = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
# predict the noise residual with unet, NO grad!
with torch.no_grad():
# add noise
noise = torch.randn_like(latents) # TODO: use torch generator
y = latents
zs = y + sigma * noise
scaled_zs = zs / torch.sqrt(1 + sigma**2)
# pred noise
latent_model_input = torch.cat([scaled_zs] * 2, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t] * 2),
encoder_hidden_states=text_embeddings,
)
# perform guidance (high scale from paper!)
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_text + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
Ds = zs - sigma * noise_pred
if self.cfg.var_red:
grad = -(Ds - y) / sigma
else:
grad = -(Ds - zs) / sigma
guidance_eval_utils = {
"use_perp_neg": prompt_utils.use_perp_neg,
"neg_guidance_weights": neg_guidance_weights,
"text_embeddings": text_embeddings,
"t_orig": t,
"latents_noisy": scaled_zs,
"noise_pred": noise_pred,
}
return grad, guidance_eval_utils
def __call__(
self,
rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
rgb_as_latents=False,
guidance_eval=False,
current_step_ratio=None,
**kwargs,
):
batch_size = rgb.shape[0]
rgb_BCHW = rgb.permute(0, 3, 1, 2)
latents: Float[Tensor, "B 4 64 64"]
if rgb_as_latents:
latents = F.interpolate(
rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
)
else:
rgb_BCHW_512 = F.interpolate(
rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
)
# encode image into latents with vae
latents = self.encode_images(rgb_BCHW_512)
if self.cfg.time_prior is not None:
time_index = torch.where(
(self.time_prior_acc_weights - current_step_ratio) > 0
)[0][0]
if time_index == 0 or torch.abs(
self.time_prior_acc_weights[time_index] - current_step_ratio
) < torch.abs(
self.time_prior_acc_weights[time_index - 1] - current_step_ratio
):
t = self.num_train_timesteps - time_index
else:
t = self.num_train_timesteps - time_index + 1
t = torch.clip(t, self.min_step, self.max_step + 1)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
t = torch.randint(
self.min_step,
self.max_step + 1,
[batch_size],
dtype=torch.long,
device=self.device,
)
if self.cfg.use_sjc:
grad, guidance_eval_utils = self.compute_grad_sjc(
latents, t, prompt_utils, elevation, azimuth, camera_distances
)
else:
grad, guidance_eval_utils = self.compute_grad_sds(
latents, t, prompt_utils, elevation, azimuth, camera_distances
)
grad = torch.nan_to_num(grad)
# clip grad for stable training?
if self.grad_clip_val is not None:
grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
# loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward, use a reparameterization trick instead
target = (latents - grad).detach()
# d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
guidance_out = {
"loss_sd": loss_sds,
"grad_norm": grad.norm(),
"min_step": self.min_step,
"max_step": self.max_step,
}
if guidance_eval:
guidance_eval_out = self.guidance_eval(**guidance_eval_utils)
texts = []
for n, e, a, c in zip(
guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances
):
texts.append(
f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
)
guidance_eval_out.update({"texts": texts})
guidance_out.update({"eval": guidance_eval_out})
return guidance_out
@torch.cuda.amp.autocast(enabled=False)
@torch.no_grad()
def get_noise_pred(
self,
latents_noisy,
t,
text_embeddings,
use_perp_neg=False,
neg_guidance_weights=None,
):
batch_size = latents_noisy.shape[0]
if use_perp_neg:
# pred noise
latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t.reshape(1)] * 4).to(self.device),
encoder_hidden_states=text_embeddings,
) # (4B, 3, 64, 64)
noise_pred_text = noise_pred[:batch_size]
noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
noise_pred_neg = noise_pred[batch_size * 2 :]
e_pos = noise_pred_text - noise_pred_uncond
accum_grad = 0
n_negative_prompts = neg_guidance_weights.shape[-1]
for i in range(n_negative_prompts):
e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
accum_grad += neg_guidance_weights[:, i].view(
-1, 1, 1, 1
) * perpendicular_component(e_i_neg, e_pos)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
e_pos + accum_grad
)
else:
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
noise_pred = self.forward_unet(
latent_model_input,
torch.cat([t.reshape(1)] * 2).to(self.device),
encoder_hidden_states=text_embeddings,
)
# perform guidance (high scale from paper!)
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_text + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
@torch.cuda.amp.autocast(enabled=False)
@torch.no_grad()
def guidance_eval(
self,
t_orig,
text_embeddings,
latents_noisy,
noise_pred,
use_perp_neg=False,
neg_guidance_weights=None,
):
# use only 50 timesteps, and find nearest of those to t
self.scheduler.set_timesteps(50)
self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
bs = (
min(self.cfg.max_items_eval, latents_noisy.shape[0])
if self.cfg.max_items_eval > 0
else latents_noisy.shape[0]
) # batch size
large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
:bs
].unsqueeze(
-1
) # sized [bs,50] > [bs,1]
idxs = torch.min(large_enough_idxs, dim=1)[1]
t = self.scheduler.timesteps_gpu[idxs]
fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)
# get prev latent
latents_1step = []
pred_1orig = []
for b in range(bs):
step_output = self.scheduler.step(
noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
)
latents_1step.append(step_output["prev_sample"])
pred_1orig.append(step_output["pred_original_sample"])
latents_1step = torch.cat(latents_1step)
pred_1orig = torch.cat(pred_1orig)
imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)
latents_final = []
for b, i in enumerate(idxs):
latents = latents_1step[b : b + 1]
text_emb = (
text_embeddings[
[b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ...
]
if use_perp_neg
else text_embeddings[[b, b + len(idxs)], ...]
)
neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None
for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
# pred noise
noise_pred = self.get_noise_pred(
latents, t, text_emb, use_perp_neg, neg_guid
)
# get prev latent
latents = self.scheduler.step(noise_pred, t, latents, eta=1)[
"prev_sample"
]
latents_final.append(latents)
latents_final = torch.cat(latents_final)
imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)
return {
"bs": bs,
"noise_levels": fracs,
"imgs_noisy": imgs_noisy,
"imgs_1step": imgs_1step,
"imgs_1orig": imgs_1orig,
"imgs_final": imgs_final,
}
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
# clip grad for stable training as demonstrated in
# Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
# http://arxiv.org/abs/2303.15413
if self.cfg.grad_clip is not None:
self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)
self.set_min_max_steps(
min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
)

View File

@ -0,0 +1,729 @@
import random
from contextlib import contextmanager
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import (
AutoencoderKL,
ControlNetModel,
DDPMScheduler,
DPMSolverSinglestepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm
import threestudio
from threestudio.models.networks import ToDTypeWrapper
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseModule
from threestudio.utils.misc import C, cleanup, enable_gradient, parse_version
from threestudio.utils.ops import perpendicular_component
from threestudio.utils.typing import *
@threestudio.register("stable-diffusion-unified-guidance")
class StableDiffusionUnifiedGuidance(BaseModule):
@dataclass
class Config(BaseModule.Config):
cache_dir: Optional[str] = None
local_files_only: Optional[bool] = False
# guidance type, in ["sds", "vsd"]
guidance_type: str = "sds"
pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
guidance_scale: float = 100.0
weighting_strategy: str = "dreamfusion"
view_dependent_prompting: bool = True
min_step_percent: Any = 0.02
max_step_percent: Any = 0.98
grad_clip: Optional[Any] = None
return_rgb_1step_orig: bool = False
return_rgb_multistep_orig: bool = False
n_rgb_multistep_orig_steps: int = 4
# TODO
# controlnet
controlnet_model_name_or_path: Optional[str] = None
preprocessor: Optional[str] = None
control_scale: float = 1.0
# TODO
# lora
lora_model_name_or_path: Optional[str] = None
# efficiency-related configurations
half_precision_weights: bool = True
enable_memory_efficient_attention: bool = False
enable_sequential_cpu_offload: bool = False
enable_attention_slicing: bool = False
enable_channels_last_format: bool = False
token_merging: bool = False
token_merging_params: Optional[dict] = field(default_factory=dict)
# VSD configurations, only used when guidance_type is "vsd"
vsd_phi_model_name_or_path: Optional[str] = None
vsd_guidance_scale_phi: float = 1.0
vsd_use_lora: bool = True
vsd_lora_cfg_training: bool = False
vsd_lora_n_timestamp_samples: int = 1
vsd_use_camera_condition: bool = True
# camera condition type, in ["extrinsics", "mvp", "spherical"]
vsd_camera_condition_type: Optional[str] = "extrinsics"
cfg: Config
def configure(self) -> None:
self.min_step: Optional[int] = None
self.max_step: Optional[int] = None
self.grad_clip_val: Optional[float] = None
@dataclass
class NonTrainableModules:
pipe: StableDiffusionPipeline
pipe_phi: Optional[StableDiffusionPipeline] = None
controlnet: Optional[ControlNetModel] = None
self.weights_dtype = (
torch.float16 if self.cfg.half_precision_weights else torch.float32
)
threestudio.info(f"Loading Stable Diffusion ...")
pipe_kwargs = {
"tokenizer": None,
"safety_checker": None,
"feature_extractor": None,
"requires_safety_checker": False,
"torch_dtype": self.weights_dtype,
"cache_dir": self.cfg.cache_dir,
"local_files_only": self.cfg.local_files_only,
}
pipe = StableDiffusionPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path,
**pipe_kwargs,
).to(self.device)
self.prepare_pipe(pipe)
self.configure_pipe_token_merging(pipe)
# phi network for VSD
# introduce two trainable modules:
# - self.camera_embedding
# - self.lora_layers
pipe_phi = None
# if the phi network shares the same unet with the pretrain network
# we need to pass additional cross attention kwargs to the unet
self.vsd_share_model = (
self.cfg.guidance_type == "vsd"
and self.cfg.vsd_phi_model_name_or_path is None
)
if self.cfg.guidance_type == "vsd":
if self.cfg.vsd_phi_model_name_or_path is None:
pipe_phi = pipe
else:
pipe_phi = StableDiffusionPipeline.from_pretrained(
self.cfg.vsd_phi_model_name_or_path,
**pipe_kwargs,
).to(self.device)
self.prepare_pipe(pipe_phi)
self.configure_pipe_token_merging(pipe_phi)
# set up camera embedding
if self.cfg.vsd_use_camera_condition:
if self.cfg.vsd_camera_condition_type in ["extrinsics", "mvp"]:
self.camera_embedding_dim = 16
elif self.cfg.vsd_camera_condition_type == "spherical":
self.camera_embedding_dim = 4
else:
raise ValueError("Invalid camera condition type!")
# FIXME: hard-coded output dim
self.camera_embedding = ToDTypeWrapper(
TimestepEmbedding(self.camera_embedding_dim, 1280),
self.weights_dtype,
).to(self.device)
pipe_phi.unet.class_embedding = self.camera_embedding
if self.cfg.vsd_use_lora:
# set up LoRA layers
lora_attn_procs = {}
for name in pipe_phi.unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else pipe_phi.unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = pipe_phi.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(
reversed(pipe_phi.unet.config.block_out_channels)
)[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = pipe_phi.unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
pipe_phi.unet.set_attn_processor(lora_attn_procs)
self.lora_layers = AttnProcsLayers(pipe_phi.unet.attn_processors).to(
self.device
)
self.lora_layers._load_state_dict_pre_hooks.clear()
self.lora_layers._state_dict_hooks.clear()
threestudio.info(f"Loaded Stable Diffusion!")
# controlnet
controlnet = None
if self.cfg.controlnet_model_name_or_path is not None:
threestudio.info(f"Loading ControlNet ...")
controlnet = ControlNetModel.from_pretrained(
self.cfg.controlnet_model_name_or_path,
torch_dtype=self.weights_dtype,
).to(self.device)
controlnet.eval()
enable_gradient(controlnet, enabled=False)
threestudio.info(f"Loaded ControlNet!")
self.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
# q(z_t|x) = N(alpha_t x, sigma_t^2 I)
# in DDPM, alpha_t = sqrt(alphas_cumprod_t), sigma_t^2 = 1 - alphas_cumprod_t
self.alphas_cumprod: Float[Tensor, "T"] = self.scheduler.alphas_cumprod.to(
self.device
)
self.alphas: Float[Tensor, "T"] = self.alphas_cumprod**0.5
self.sigmas: Float[Tensor, "T"] = (1 - self.alphas_cumprod) ** 0.5
        # sigma_t / alpha_t, i.e. the square root of the inverse SNR
        self.lambdas: Float[Tensor, "T"] = self.sigmas / self.alphas
self._non_trainable_modules = NonTrainableModules(
pipe=pipe,
pipe_phi=pipe_phi,
controlnet=controlnet,
)
@property
def pipe(self) -> StableDiffusionPipeline:
return self._non_trainable_modules.pipe
@property
def pipe_phi(self) -> StableDiffusionPipeline:
if self._non_trainable_modules.pipe_phi is None:
raise RuntimeError("phi model is not available.")
return self._non_trainable_modules.pipe_phi
@property
def controlnet(self) -> ControlNetModel:
if self._non_trainable_modules.controlnet is None:
raise RuntimeError("ControlNet model is not available.")
return self._non_trainable_modules.controlnet
def prepare_pipe(self, pipe: StableDiffusionPipeline):
if self.cfg.enable_memory_efficient_attention:
if parse_version(torch.__version__) >= parse_version("2"):
threestudio.info(
"PyTorch2.0 uses memory efficient attention by default."
)
elif not is_xformers_available():
threestudio.warn(
"xformers is not available, memory efficient attention is not enabled."
)
else:
pipe.enable_xformers_memory_efficient_attention()
if self.cfg.enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
if self.cfg.enable_attention_slicing:
pipe.enable_attention_slicing(1)
if self.cfg.enable_channels_last_format:
pipe.unet.to(memory_format=torch.channels_last)
# FIXME: pipe.__call__ requires text_encoder.dtype
# pipe.text_encoder.to("meta")
cleanup()
pipe.vae.eval()
pipe.unet.eval()
enable_gradient(pipe.vae, enabled=False)
enable_gradient(pipe.unet, enabled=False)
# disable progress bar
pipe.set_progress_bar_config(disable=True)
def configure_pipe_token_merging(self, pipe: StableDiffusionPipeline):
if self.cfg.token_merging:
import tomesd
tomesd.apply_patch(pipe.unet, **self.cfg.token_merging_params)
@torch.cuda.amp.autocast(enabled=False)
def forward_unet(
self,
unet: UNet2DConditionModel,
latents: Float[Tensor, "..."],
t: Int[Tensor, "..."],
encoder_hidden_states: Float[Tensor, "..."],
class_labels: Optional[Float[Tensor, "..."]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
down_block_additional_residuals: Optional[Float[Tensor, "..."]] = None,
mid_block_additional_residual: Optional[Float[Tensor, "..."]] = None,
velocity_to_epsilon: bool = False,
) -> Float[Tensor, "..."]:
input_dtype = latents.dtype
pred = unet(
latents.to(unet.dtype),
t.to(unet.dtype),
encoder_hidden_states=encoder_hidden_states.to(unet.dtype),
class_labels=class_labels,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
).sample
if velocity_to_epsilon:
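            # convert a v-prediction output into an epsilon prediction:
            # eps = sigma_t * x_t + alpha_t * v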
pred = latents * self.sigmas[t].view(-1, 1, 1, 1) + pred * self.alphas[
t
].view(-1, 1, 1, 1)
return pred.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def vae_encode(
self, vae: AutoencoderKL, imgs: Float[Tensor, "B 3 H W"], mode=False
) -> Float[Tensor, "B 4 Hl Wl"]:
# expect input in [-1, 1]
input_dtype = imgs.dtype
posterior = vae.encode(imgs.to(vae.dtype)).latent_dist
if mode:
latents = posterior.mode()
else:
latents = posterior.sample()
latents = latents * vae.config.scaling_factor
return latents.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def vae_decode(
self, vae: AutoencoderKL, latents: Float[Tensor, "B 4 Hl Wl"]
) -> Float[Tensor, "B 3 H W"]:
# output in [0, 1]
input_dtype = latents.dtype
latents = 1 / vae.config.scaling_factor * latents
image = vae.decode(latents.to(vae.dtype)).sample
image = (image * 0.5 + 0.5).clamp(0, 1)
return image.to(input_dtype)
@contextmanager
def disable_unet_class_embedding(self, unet: UNet2DConditionModel):
class_embedding = unet.class_embedding
try:
unet.class_embedding = None
yield unet
finally:
unet.class_embedding = class_embedding
@contextmanager
def set_scheduler(
self, pipe: StableDiffusionPipeline, scheduler_class: Any, **kwargs
):
        scheduler_orig = pipe.scheduler
        pipe.scheduler = scheduler_class.from_config(scheduler_orig.config, **kwargs)
        try:
            yield pipe
        finally:
            pipe.scheduler = scheduler_orig
def get_eps_pretrain(
self,
latents_noisy: Float[Tensor, "B 4 Hl Wl"],
t: Int[Tensor, "B"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
) -> Float[Tensor, "B 4 Hl Wl"]:
batch_size = latents_noisy.shape[0]
if prompt_utils.use_perp_neg:
(
text_embeddings,
neg_guidance_weights,
) = prompt_utils.get_text_embeddings_perp_neg(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
with torch.no_grad():
with self.disable_unet_class_embedding(self.pipe.unet) as unet:
noise_pred = self.forward_unet(
unet,
torch.cat([latents_noisy] * 4, dim=0),
torch.cat([t] * 4, dim=0),
encoder_hidden_states=text_embeddings,
cross_attention_kwargs={"scale": 0.0}
if self.vsd_share_model
else None,
velocity_to_epsilon=self.pipe.scheduler.config.prediction_type
== "v_prediction",
) # (4B, 3, Hl, Wl)
noise_pred_text = noise_pred[:batch_size]
noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
noise_pred_neg = noise_pred[batch_size * 2 :]
e_pos = noise_pred_text - noise_pred_uncond
accum_grad = 0
n_negative_prompts = neg_guidance_weights.shape[-1]
for i in range(n_negative_prompts):
e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
accum_grad += neg_guidance_weights[:, i].view(
-1, 1, 1, 1
) * perpendicular_component(e_i_neg, e_pos)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
e_pos + accum_grad
)
else:
text_embeddings = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
)
with torch.no_grad():
with self.disable_unet_class_embedding(self.pipe.unet) as unet:
noise_pred = self.forward_unet(
unet,
torch.cat([latents_noisy] * 2, dim=0),
torch.cat([t] * 2, dim=0),
encoder_hidden_states=text_embeddings,
cross_attention_kwargs={"scale": 0.0}
if self.vsd_share_model
else None,
velocity_to_epsilon=self.pipe.scheduler.config.prediction_type
== "v_prediction",
)
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
def get_eps_phi(
self,
latents_noisy: Float[Tensor, "B 4 Hl Wl"],
t: Int[Tensor, "B"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
camera_condition: Float[Tensor, "B ..."],
) -> Float[Tensor, "B 4 Hl Wl"]:
batch_size = latents_noisy.shape[0]
# not using view-dependent prompting in LoRA
text_embeddings, _ = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, view_dependent_prompting=False
).chunk(2)
with torch.no_grad():
noise_pred = self.forward_unet(
self.pipe_phi.unet,
torch.cat([latents_noisy] * 2, dim=0),
torch.cat([t] * 2, dim=0),
encoder_hidden_states=torch.cat([text_embeddings] * 2, dim=0),
class_labels=torch.cat(
[
camera_condition.view(batch_size, -1),
torch.zeros_like(camera_condition.view(batch_size, -1)),
],
dim=0,
)
if self.cfg.vsd_use_camera_condition
else None,
cross_attention_kwargs={"scale": 1.0},
velocity_to_epsilon=self.pipe_phi.scheduler.config.prediction_type
== "v_prediction",
)
noise_pred_camera, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.vsd_guidance_scale_phi * (
noise_pred_camera - noise_pred_uncond
)
return noise_pred
def train_phi(
self,
latents: Float[Tensor, "B 4 Hl Wl"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
camera_condition: Float[Tensor, "B ..."],
):
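        # a standard denoising (epsilon / v-prediction) loss that fits the LoRA/phi
        # network to the current rendering distribution, conditioned on the camera
        # embedding (the condition is dropped ~10% of the time when cfg training is enabled)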
B = latents.shape[0]
latents = latents.detach().repeat(
self.cfg.vsd_lora_n_timestamp_samples, 1, 1, 1
)
num_train_timesteps = self.pipe_phi.scheduler.config.num_train_timesteps
t = torch.randint(
int(num_train_timesteps * 0.0),
int(num_train_timesteps * 1.0),
[B * self.cfg.vsd_lora_n_timestamp_samples],
dtype=torch.long,
device=self.device,
)
noise = torch.randn_like(latents)
latents_noisy = self.pipe_phi.scheduler.add_noise(latents, noise, t)
if self.pipe_phi.scheduler.config.prediction_type == "epsilon":
target = noise
        elif self.pipe_phi.scheduler.config.prediction_type == "v_prediction":
            target = self.pipe_phi.scheduler.get_velocity(latents, noise, t)
        else:
            raise ValueError(
                f"Unknown prediction type {self.pipe_phi.scheduler.config.prediction_type}"
)
# not using view-dependent prompting in LoRA
text_embeddings, _ = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, view_dependent_prompting=False
).chunk(2)
if (
self.cfg.vsd_use_camera_condition
and self.cfg.vsd_lora_cfg_training
and random.random() < 0.1
):
camera_condition = torch.zeros_like(camera_condition)
noise_pred = self.forward_unet(
self.pipe_phi.unet,
latents_noisy,
t,
encoder_hidden_states=text_embeddings.repeat(
self.cfg.vsd_lora_n_timestamp_samples, 1, 1
),
class_labels=camera_condition.view(B, -1).repeat(
self.cfg.vsd_lora_n_timestamp_samples, 1
)
if self.cfg.vsd_use_camera_condition
else None,
cross_attention_kwargs={"scale": 1.0},
)
return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
def forward(
self,
rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
mvp_mtx: Float[Tensor, "B 4 4"],
c2w: Float[Tensor, "B 4 4"],
rgb_as_latents=False,
**kwargs,
):
batch_size = rgb.shape[0]
rgb_BCHW = rgb.permute(0, 3, 1, 2)
latents: Float[Tensor, "B 4 Hl Wl"]
if rgb_as_latents:
# treat input rgb as latents
# input rgb should be in range [-1, 1]
latents = F.interpolate(
rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
)
else:
# treat input rgb as rgb
# input rgb should be in range [0, 1]
rgb_BCHW = F.interpolate(
rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
)
# encode image into latents with vae
latents = self.vae_encode(self.pipe.vae, rgb_BCHW * 2.0 - 1.0)
# sample timestep
# use the same timestep for each batch
assert self.min_step is not None and self.max_step is not None
t = torch.randint(
self.min_step,
self.max_step + 1,
[1],
dtype=torch.long,
device=self.device,
).repeat(batch_size)
# sample noise
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
eps_pretrain = self.get_eps_pretrain(
latents_noisy, t, prompt_utils, elevation, azimuth, camera_distances
)
latents_1step_orig = (
1
/ self.alphas[t].view(-1, 1, 1, 1)
* (latents_noisy - self.sigmas[t].view(-1, 1, 1, 1) * eps_pretrain)
).detach()
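        # one-step estimate of the clean latents from the noisy latents:
        # x_0 = (x_t - sigma_t * eps_pretrain) / alpha_t, returned for inspection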
if self.cfg.guidance_type == "sds":
eps_phi = noise
elif self.cfg.guidance_type == "vsd":
if self.cfg.vsd_camera_condition_type == "extrinsics":
camera_condition = c2w
elif self.cfg.vsd_camera_condition_type == "mvp":
camera_condition = mvp_mtx
elif self.cfg.vsd_camera_condition_type == "spherical":
camera_condition = torch.stack(
[
torch.deg2rad(elevation),
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
camera_distances,
],
dim=-1,
)
else:
raise ValueError(
f"Unknown camera_condition_type {self.cfg.vsd_camera_condition_type}"
)
eps_phi = self.get_eps_phi(
latents_noisy,
t,
prompt_utils,
elevation,
azimuth,
camera_distances,
camera_condition,
)
loss_train_phi = self.train_phi(
latents,
prompt_utils,
elevation,
azimuth,
camera_distances,
camera_condition,
)
if self.cfg.weighting_strategy == "dreamfusion":
w = (1.0 - self.alphas[t]).view(-1, 1, 1, 1)
elif self.cfg.weighting_strategy == "uniform":
w = 1.0
elif self.cfg.weighting_strategy == "fantasia3d":
w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
else:
raise ValueError(
f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
)
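        # SDS: eps_phi is plain Gaussian noise, so grad = w * (eps_pretrain - noise);
        # VSD: eps_phi comes from the phi (LoRA) model, as in ProlificDreamer.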
grad = w * (eps_pretrain - eps_phi)
if self.grad_clip_val is not None:
grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
# reparameterization trick:
# d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
target = (latents - grad).detach()
loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
guidance_out = {
"loss_sd": loss_sd,
"grad_norm": grad.norm(),
"timesteps": t,
"min_step": self.min_step,
"max_step": self.max_step,
"latents": latents,
"latents_1step_orig": latents_1step_orig,
"rgb": rgb_BCHW.permute(0, 2, 3, 1),
"weights": w,
"lambdas": self.lambdas[t],
}
if self.cfg.return_rgb_1step_orig:
with torch.no_grad():
rgb_1step_orig = self.vae_decode(
self.pipe.vae, latents_1step_orig
).permute(0, 2, 3, 1)
guidance_out.update({"rgb_1step_orig": rgb_1step_orig})
if self.cfg.return_rgb_multistep_orig:
with self.set_scheduler(
self.pipe,
DPMSolverSinglestepScheduler,
solver_order=1,
num_train_timesteps=int(t[0]),
) as pipe:
text_embeddings = prompt_utils.get_text_embeddings(
elevation,
azimuth,
camera_distances,
self.cfg.view_dependent_prompting,
)
text_embeddings_cond, text_embeddings_uncond = text_embeddings.chunk(2)
with torch.cuda.amp.autocast(enabled=False):
latents_multistep_orig = pipe(
num_inference_steps=self.cfg.n_rgb_multistep_orig_steps,
guidance_scale=self.cfg.guidance_scale,
eta=1.0,
latents=latents_noisy.to(pipe.unet.dtype),
prompt_embeds=text_embeddings_cond.to(pipe.unet.dtype),
negative_prompt_embeds=text_embeddings_uncond.to(
pipe.unet.dtype
),
cross_attention_kwargs={"scale": 0.0}
if self.vsd_share_model
else None,
output_type="latent",
).images.to(latents.dtype)
with torch.no_grad():
rgb_multistep_orig = self.vae_decode(
self.pipe.vae, latents_multistep_orig
)
guidance_out.update(
{
"latents_multistep_orig": latents_multistep_orig,
"rgb_multistep_orig": rgb_multistep_orig.permute(0, 2, 3, 1),
}
)
if self.cfg.guidance_type == "vsd":
guidance_out.update(
{
"loss_train_phi": loss_train_phi,
}
)
return guidance_out
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
# clip grad for stable training as demonstrated in
# Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
# http://arxiv.org/abs/2303.15413
if self.cfg.grad_clip is not None:
self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)
self.min_step = int(
self.num_train_timesteps * C(self.cfg.min_step_percent, epoch, global_step)
)
self.max_step = int(
self.num_train_timesteps * C(self.cfg.max_step_percent, epoch, global_step)
)

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff