Mirror of https://github.com/deepseek-ai/DreamCraft3D.git, synced 2025-02-22 05:48:56 -05:00
commit 50ecd13a88
chores: rebase commits
12
.editorconfig
Normal file
@@ -0,0 +1,12 @@
root = true

[*.py]
charset = utf-8
trim_trailing_whitespace = true
end_of_line = lf
insert_final_newline = true
indent_style = space
indent_size = 4

[*.md]
trim_trailing_whitespace = false
195
.gitignore
vendored
Normal file
@@ -0,0 +1,195 @@
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python

.vscode/
.threestudio_cache/
outputs/
outputs-gradio/

# pretrained model weights
*.ckpt
*.pt
*.pth

# wandb
wandb/

load/tets/256_tets.npz

# dataset
dataset/
load/
34
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
default_language_version:
  python: python3

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: check-ast
      - id: check-merge-conflict
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
        args: [--markdown-linebreak-ext=md]

  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        exclude: README.md
        args: ["--profile", "black"]

  # temporarily disable static type checking
  # - repo: https://github.com/pre-commit/mirrors-mypy
  #   rev: v1.2.0
  #   hooks:
  #     - id: mypy
  #       args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"]
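To activate these hooks locally, the standard pre-commit workflow applies (generic pre-commit usage, not specific to this repository):

```sh
pip install pre-commit
pre-commit install          # run the hooks automatically on every git commit
pre-commit run --all-files  # or check the whole tree once
```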
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 deepseek-ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
21
LICENSE-CODE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
91
LICENSE-MODEL
Normal file
@@ -0,0 +1,91 @@
DEEPSEEK LICENSE AGREEMENT

Version 1.0, 23 October 2023

Copyright (c) 2023 DeepSeek

Section I: PREAMBLE

Large generative models are being widely adopted and used, and have the potential to transform the way individuals conceive and benefit from AI or ML technologies.

Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.

In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for content generation.

Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI.

This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.

NOW THEREFORE, You and DeepSeek agree as follows:

1. Definitions
"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates.
"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc.
"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You.

Section II: INTELLECTUAL PROPERTY RIGHTS

Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.

2. Grant of Copyright License. Subject to the terms and conditions of this License, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.

3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by DeepSeek that are necessarily infringed by its contribution(s). If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or works shall terminate as of the date such litigation is asserted or filed.

Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION

4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
c. You must cause any modified files to carry prominent notices stating that You changed the files;
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. – for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.

5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).

6. The Output You Generate. Except as set forth herein, DeepSeek claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.

Section IV: OTHER PROVISIONS

7. Updates and Runtime Restrictions. To the maximum extent permitted by law, DeepSeek reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.

8. Trademarks and related. Nothing in this License permits You to make use of DeepSeek’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by DeepSeek.

9. Personal information, IP rights and related. This Model may contain personal information and works with IP rights. You commit to complying with applicable laws and regulations in the handling of personal information and the use of such works. Please note that DeepSeek's license granted to you to use the Model does not imply that you have obtained a legitimate basis for processing the related information or works. As an independent personal information processor and IP rights user, you need to ensure full compliance with relevant legal and regulatory requirements when handling personal information and works with IP rights that may be contained in the Model, and are willing to assume solely any risks and consequences that may arise from that.

10. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, DeepSeek provides the Model and the Complementary Material on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.

11. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall DeepSeek be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if DeepSeek has been advised of the possibility of such damages.

12. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of DeepSeek, and only if You agree to indemnify, defend, and hold DeepSeek harmless for any liability incurred by, or claims asserted against, DeepSeek by reason of your accepting any such warranty or additional liability.

13. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

14. Governing Law and Jurisdiction. This agreement will be governed and construed under PRC laws without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this agreement. The courts located in the domicile of Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. shall have exclusive jurisdiction of any dispute arising out of this agreement.

END OF TERMS AND CONDITIONS

Attachment A

Use Restrictions

You agree not to use the Model or Derivatives of the Model:

- In any way that violates any applicable national or international law or regulation or infringes upon the lawful rights and interests of any third party;
- For military use in any way;
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
- To generate or disseminate inappropriate content subject to applicable regulatory requirements;
- To generate or disseminate personal identifiable information without due authorization or for unreasonable use;
- To defame, disparage or otherwise harass others;
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories.
175
README.md
Normal file
@@ -0,0 +1,175 @@
# DreamCraft3D

[**Paper**](https://arxiv.org/abs/2310.16818) | [**Project Page**](https://mrtornado24.github.io/DreamCraft3D/) | [**YouTube video**](https://www.youtube.com/watch?v=0FazXENkQms)

Official implementation of *DreamCraft3D: Hierarchical 3D Generation with Bootstrapped Diffusion Prior*.

[Jingxiang Sun](https://mrtornado24.github.io/), [Bo Zhang](https://bo-zhang.me/), [Ruizhi Shao](https://dsaurus.github.io/saurus/), [Lizhen Wang](https://lizhenwangt.github.io/), [Wen Liu](https://github.com/StevenLiuWen), [Zhenda Xie](https://zdaxie.github.io/), [Yebin Liu](https://liuyebin.com/)

Abstract: *We present DreamCraft3D, a hierarchical 3D content generation method that produces high-fidelity and coherent 3D objects. We tackle the problem by leveraging a 2D reference image to guide the stages of geometry sculpting and texture boosting. A central focus of this work is to address the consistency issue that existing works encounter. To sculpt geometries that render coherently, we perform score distillation sampling via a view-dependent diffusion model. This 3D prior, alongside several training strategies, prioritizes the geometry consistency but compromises the texture fidelity. We further propose **Bootstrapped Score Distillation** to specifically boost the texture. We train a personalized diffusion model, Dreambooth, on the augmented renderings of the scene, imbuing it with 3D knowledge of the scene being optimized. The score distillation from this 3D-aware diffusion prior provides view-consistent guidance for the scene. Notably, through an alternating optimization of the diffusion prior and 3D scene representation, we achieve mutually reinforcing improvements: the optimized 3D scene aids in training the scene-specific diffusion model, which offers increasingly view-consistent guidance for 3D optimization. The optimization is thus bootstrapped and leads to substantial texture boosting. With tailored 3D priors throughout the hierarchical generation, DreamCraft3D generates coherent 3D objects with photorealistic renderings, advancing the state-of-the-art in 3D content generation.*

<p align="center">
<img src="assets/repo_static_v2.png">
</p>

## Method Overview
<p align="center">
<img src="assets/diagram-1.png">
</p>

<!-- https://github.com/MrTornado24/DreamCraft3D/assets/45503891/8e70610c-d812-4544-86bf-7f8764e41067

https://github.com/MrTornado24/DreamCraft3D/assets/45503891/b1e8ae54-1afd-4e0f-88f7-9bd5b70fd44d

https://github.com/MrTornado24/DreamCraft3D/assets/45503891/ead40f9b-d7ee-4ee8-8d98-dbd0b8fbab97 -->

## Installation
### Install threestudio

**This part is the same as the original threestudio. Skip it if you have already installed the environment.**

See [installation.md](docs/installation.md) for additional information, including installation via Docker.

- You must have an NVIDIA graphics card with at least 20GB VRAM and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed.
- Install `Python >= 3.8`.
- (Optional, Recommended) Create a virtual environment:

```sh
python3 -m virtualenv venv
. venv/bin/activate

# Newer pip versions, e.g. pip-23.x, can be much faster than older versions, e.g. pip-20.x.
# For instance, they cache the wheels of git packages to avoid unnecessarily rebuilding them later.
python3 -m pip install --upgrade pip
```

- Install `PyTorch >= 1.12`. We have tested `torch1.12.1+cu113` and `torch2.0.0+cu118`, but other versions should also work fine.

```sh
# torch1.12.1+cu113
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# or torch2.0.0+cu118
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
```

- (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions:

```sh
pip install ninja
```

- Install dependencies:

```sh
pip install -r requirements.txt
```

## Quickstart
Our model is trained in multiple stages. You can run the full pipeline as follows:
```sh
prompt="a brightly colored mushroom growing on a log"
image_path="load/images/mushroom_log_rgba.png"

# --------- Stage 1 (NeRF & NeuS) --------- #
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path"

ckpt=outputs/dreamcraft3d-coarse-nerf/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-coarse-neus.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.weights="$ckpt"

# --------- Stage 2 (Geometry Refinement) --------- #
ckpt=outputs/dreamcraft3d-coarse-neus/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-geometry.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"

# --------- Stage 3 (Texture Refinement) --------- #
ckpt=outputs/dreamcraft3d-geometry/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-texture.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"
```

<details>
<summary>[Optional] If the "Janus problem" arises in Stage 1, consider training a custom Text2Image model.</summary>

First, generate multi-view images from a single reference image with Zero123++:

```sh
python threestudio/scripts/img_to_mv.py --image_path 'load/mushroom.png' --save_path '.cache/temp' --prompt 'a photo of mushroom' --superres
```

Then train a personalized DeepFloyd model with DreamBooth LoRA. Please check that the multi-view images generated above are reasonable before training.

```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR=".cache/temp"
export OUTPUT_DIR=".cache/if_dreambooth_mushroom"

accelerate launch threestudio/scripts/train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a sks mushroom" \
  --resolution=64 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --scale_lr \
  --max_train_steps=1200 \
  --checkpointing_steps=600 \
  --pre_compute_text_embeddings \
  --tokenizer_max_length=77 \
  --text_encoder_use_attention_mask
```

The personalized DeepFloyd LoRA weights are saved at `.cache/if_dreambooth_mushroom`. Now you can point the guidance in the training scripts to them:

```sh
# --------- Stage 1 (NeRF & NeuS) --------- #
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.guidance.lora_weights_path=".cache/if_dreambooth_mushroom"
```
</details>

## Tips
- **Memory usage.** We run the default configs on 40GB A100 GPUs. To reduce memory usage, you can lower the rendering resolution of NeuS with ```data.height=128 data.width=128 data.random_camera.height=128 data.random_camera.width=128```. You can reduce the resolution of the other stages in the same way.
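For example, a lower-memory Stage 1 NeuS run could look like the following (a sketch reusing the `prompt`, `image_path`, and `ckpt` variables from the Quickstart; only the four resolution flags differ from the default command):

```sh
python launch.py --config configs/dreamcraft3d-coarse-neus.yaml --train \
    system.prompt_processor.prompt="$prompt" data.image_path="$image_path" \
    system.weights="$ckpt" \
    data.height=128 data.width=128 \
    data.random_camera.height=128 data.random_camera.width=128
```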

## Todo

- [x] Release the reorganized code.
- [ ] Clean the original DreamBooth training code.
- [ ] Provide some running results and checkpoints.

## Credits
This code is built on the amazing open-source [threestudio-project](https://github.com/threestudio-project/threestudio).

## Related links

- [DreamFusion](https://dreamfusion3d.github.io/)
- [Magic3D](https://research.nvidia.com/labs/dir/magic3d/)
- [Make-it-3D](https://make-it-3d.github.io/)
- [Magic123](https://guochengqian.github.io/project/magic123/)
- [ProlificDreamer](https://ml.cs.tsinghua.edu.cn/prolificdreamer/)
- [DreamBooth](https://dreambooth.github.io/)

## BibTeX

```bibtex
@article{sun2023dreamcraft3d,
  title={Dreamcraft3d: Hierarchical 3d generation with bootstrapped diffusion prior},
  author={Sun, Jingxiang and Zhang, Bo and Shao, Ruizhi and Wang, Lizhen and Liu, Wen and Xie, Zhenda and Liu, Yebin},
  journal={arXiv preprint arXiv:2310.16818},
  year={2023}
}
```
BIN
assets/diagram-1.png
Normal file
Binary file not shown. Size: 396 KiB
BIN
assets/logo.png
Normal file
Binary file not shown. Size: 354 KiB
BIN
assets/repo_demo_0.mp4
Normal file
Binary file not shown.
BIN
assets/repo_demo_01.mp4
Normal file
Binary file not shown.
BIN
assets/repo_demo_02.mp4
Normal file
Binary file not shown.
BIN
assets/repo_static_v2.png
Normal file
Binary file not shown. Size: 21 MiB
BIN
assets/result_mushroom.mp4
Normal file
Binary file not shown.
159
configs/dreamcraft3d-coarse-nerf.yaml
Normal file
@@ -0,0 +1,159 @@
name: "dreamcraft3d-coarse-nerf"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0

data_type: "single-image-datamodule"
data:
  image_path: ./load/images/hamburger_rgba.png
  height: [128, 384]
  width: [128, 384]
  resolution_milestones: [3000]
  default_elevation_deg: 0.0
  default_azimuth_deg: 0.0
  default_camera_distance: 3.8
  default_fovy_deg: 20.0
  requires_depth: true
  requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
  random_camera:
    height: [128, 384]
    width: [128, 384]
    batch_size: [1, 1]
    resolution_milestones: [3000]
    eval_height: 512
    eval_width: 512
    eval_batch_size: 1
    elevation_range: [-10, 45]
    azimuth_range: [-180, 180]
    camera_distance_range: [3.8, 3.8]
    fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
    progressive_until: 200
    camera_perturb: 0.0
    center_perturb: 0.0
    up_perturb: 0.0
    eval_elevation_deg: ${data.default_elevation_deg}
    eval_camera_distance: ${data.default_camera_distance}
    eval_fovy_deg: ${data.default_fovy_deg}
    batch_uniform_azimuth: false
    n_val_views: 40
    n_test_views: 120

system_type: "dreamcraft3d-system"
system:
  stage: coarse
  geometry_type: "implicit-volume"
  geometry:
    radius: 2.0
    normal_type: "finite_difference"

    # the density initialization proposed in the DreamFusion paper
    # does not work very well
    # density_bias: "blob_dreamfusion"
    # density_activation: exp
    # density_blob_scale: 5.
    # density_blob_std: 0.2

    # use Magic3D density initialization instead
    density_bias: "blob_magic3d"
    density_activation: softplus
    density_blob_scale: 10.
    density_blob_std: 0.5

    # coarse-to-fine hash grid encoding
    # to ensure smooth analytic normals
    pos_encoding_config:
      otype: ProgressiveBandHashGrid
      n_levels: 16
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 16
      per_level_scale: 1.447269237440378 # max resolution 4096
      start_level: 8 # resolution ~200
      start_step: 2000
      update_steps: 500

  material_type: "no-material"
  material:
    requires_normal: true

  background_type: "solid-color-background"

  renderer_type: "nerf-volume-renderer"
  renderer:
    radius: ${system.geometry.radius}
    num_samples_per_ray: 512
    return_normal_perturb: true
    return_comp_normal: ${cmaxgt0:${system.loss.lambda_normal_smooth}}

  prompt_processor_type: "deep-floyd-prompt-processor"
  prompt_processor:
    pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
    prompt: ???
    use_perp_neg: true

  guidance_type: "deep-floyd-guidance"
  guidance:
    pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
    guidance_scale: 20
    min_step_percent: [0, 0.7, 0.2, 200]
    max_step_percent: [0, 0.85, 0.5, 200]

  guidance_3d_type: "stable-zero123-guidance"
  guidance_3d:
    pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
    pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
    cond_image_path: ${data.image_path}
    cond_elevation_deg: ${data.default_elevation_deg}
    cond_azimuth_deg: ${data.default_azimuth_deg}
    cond_camera_distance: ${data.default_camera_distance}
    guidance_scale: 5.0
    min_step_percent: [0, 0.7, 0.2, 200] # (start_iter, start_val, end_val, end_iter)
    max_step_percent: [0, 0.85, 0.5, 200]

  freq:
    n_ref: 2
    ref_only_steps: 0
    ref_or_guidance: "alternate"
    no_diff_steps: 0
    guidance_eval: 0

  loggers:
    wandb:
      enable: false
      project: "threestudio"

  loss:
    lambda_sd: 0.1
    lambda_3d_sd: 0.1
    lambda_rgb: 1000.0
    lambda_mask: 100.0
    lambda_mask_binary: 0.0
    lambda_depth: 0.0
    lambda_depth_rel: 0.05
    lambda_normal: 0.0
    lambda_normal_smooth: 1.0
    lambda_3d_normal_smooth: [2000, 5., 1., 2001]
    lambda_orient: [2000, 1., 10., 2001]
    lambda_sparsity: [2000, 0.1, 10., 2001]
    lambda_opaque: [2000, 0.1, 10., 2001]
    lambda_clip: 0.0

  optimizer:
    name: Adam
    args:
      lr: 0.01
      betas: [0.9, 0.99]
      eps: 1.e-8

trainer:
  max_steps: 5000
  log_every_n_steps: 1
  num_sanity_val_steps: 0
  val_check_interval: 200
  enable_progress_bar: true
  precision: 16-mixed

checkpoint:
  save_last: true
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}
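Several schedule entries above use a four-element list annotated as `(start_iter, start_val, end_val, end_iter)`. Under threestudio's usual convention this denotes a linear ramp between two iterations; a minimal Python sketch of that assumed interpretation (the `anneal` helper is hypothetical, written only to illustrate the semantics, not code from this repo):

```python
def anneal(step, schedule):
    # schedule = (start_iter, start_val, end_val, end_iter), as in the
    # min_step_percent comment above (assumed semantics).
    start_iter, start_val, end_val, end_iter = schedule
    if step < start_iter:
        return start_val
    if step >= end_iter:
        return end_val
    t = (step - start_iter) / (end_iter - start_iter)
    return start_val + t * (end_val - start_val)

# min_step_percent: [0, 0.7, 0.2, 200] ramps from 0.7 down to 0.2 over 200 steps
print(anneal(100, (0, 0.7, 0.2, 200)))  # -> 0.45, halfway through the ramp
```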
155
configs/dreamcraft3d-coarse-neus.yaml
Normal file
@@ -0,0 +1,155 @@
name: "dreamcraft3d-coarse-neus"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0

data_type: "single-image-datamodule"
data:
  image_path: ./load/images/hamburger_rgba.png
  height: 256
  width: 256
  default_elevation_deg: 0.0
  default_azimuth_deg: 0.0
  default_camera_distance: 3.8
  default_fovy_deg: 20.0
  requires_depth: true
  requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
  random_camera:
    height: 256
    width: 256
    batch_size: 1
    eval_height: 512
    eval_width: 512
    eval_batch_size: 1
    elevation_range: [-10, 45]
    azimuth_range: [-180, 180]
    camera_distance_range: [3.8, 3.8]
    fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
    progressive_until: 0
    camera_perturb: 0.0
    center_perturb: 0.0
    up_perturb: 0.0
    eval_elevation_deg: ${data.default_elevation_deg}
    eval_camera_distance: ${data.default_camera_distance}
    eval_fovy_deg: ${data.default_fovy_deg}
    batch_uniform_azimuth: false
    n_val_views: 40
    n_test_views: 120

system_type: "dreamcraft3d-system"
system:
  stage: coarse
  geometry_type: "implicit-sdf"
  geometry:
    radius: 2.0
    normal_type: "finite_difference"

    sdf_bias: sphere
    sdf_bias_params: 0.5

    # coarse-to-fine hash grid encoding
    pos_encoding_config:
      otype: HashGrid
      n_levels: 16
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 16
      per_level_scale: 1.447269237440378 # max resolution 4096
      start_level: 8 # resolution ~200
      start_step: 2000
      update_steps: 500

  material_type: "no-material"
  material:
    requires_normal: true

  background_type: "solid-color-background"

  renderer_type: "neus-volume-renderer"
  renderer:
    radius: ${system.geometry.radius}
    num_samples_per_ray: 512
    cos_anneal_end_steps: ${trainer.max_steps}
    eval_chunk_size: 8192

  prompt_processor_type: "deep-floyd-prompt-processor"
  prompt_processor:
    pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
    prompt: ???
    use_perp_neg: true

  guidance_type: "deep-floyd-guidance"
  guidance:
    pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
    guidance_scale: 20
    min_step_percent: 0.2
    max_step_percent: 0.5

  guidance_3d_type: "stable-zero123-guidance"
  guidance_3d:
    pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
    pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
    cond_image_path: ${data.image_path}
    cond_elevation_deg: ${data.default_elevation_deg}
    cond_azimuth_deg: ${data.default_azimuth_deg}
    cond_camera_distance: ${data.default_camera_distance}
    guidance_scale: 5.0
    min_step_percent: 0.2
    max_step_percent: 0.5

  freq:
    n_ref: 2
    ref_only_steps: 0
    ref_or_guidance: "alternate"
    no_diff_steps: 0
    guidance_eval: 0

  loggers:
    wandb:
      enable: false
      project: "threestudio"

  loss:
    lambda_sd: 0.1
    lambda_3d_sd: 0.1
    lambda_rgb: 1000.0
    lambda_mask: 100.0
    lambda_mask_binary: 0.0
    lambda_depth: 0.0
    lambda_depth_rel: 0.05
    lambda_normal: 0.0
    lambda_normal_smooth: 0.0
    lambda_3d_normal_smooth: 0.0
    lambda_orient: 10.0
    lambda_sparsity: 0.1
    lambda_opaque: 0.1
    lambda_clip: 0.0
    lambda_eikonal: 0.0

  optimizer:
    name: Adam
    args:
      betas: [0.9, 0.99]
      eps: 1.e-15
    params:
      geometry.encoding:
        lr: 0.01
      geometry.sdf_network:
        lr: 0.001
      geometry.feature_network:
        lr: 0.001
      renderer:
        lr: 0.001

trainer:
  max_steps: 5000
  log_every_n_steps: 1
  num_sanity_val_steps: 0
  val_check_interval: 200
  enable_progress_bar: true
  precision: 16-mixed

checkpoint:
  save_last: true
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}
133
configs/dreamcraft3d-geometry.yaml
Normal file
@@ -0,0 +1,133 @@
name: "dreamcraft3d-geometry"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0

data_type: "single-image-datamodule"
data:
  image_path: ./load/images/hamburger_rgba.png
  height: 1024
  width: 1024
  default_elevation_deg: 0.0
  default_azimuth_deg: 0.0
  default_camera_distance: 3.8
  default_fovy_deg: 20.0
  requires_depth: ${cmaxgt0orcmaxgt0:${system.loss.lambda_depth},${system.loss.lambda_depth_rel}}
  requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
  use_mixed_camera_config: false
  random_camera:
    height: 1024
    width: 1024
    batch_size: 1
    eval_height: 1024
    eval_width: 1024
    eval_batch_size: 1
    elevation_range: [-10, 45]
    azimuth_range: [-180, 180]
    camera_distance_range: [3.8, 3.8]
    fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
    progressive_until: 0
    camera_perturb: 0.0
    center_perturb: 0.0
    up_perturb: 0.0
    eval_elevation_deg: ${data.default_elevation_deg}
    eval_camera_distance: ${data.default_camera_distance}
    eval_fovy_deg: ${data.default_fovy_deg}
    batch_uniform_azimuth: false
    n_val_views: 40
    n_test_views: 120

system_type: "dreamcraft3d-system"
system:
  stage: geometry
  use_mixed_camera_config: ${data.use_mixed_camera_config}
  geometry_convert_from: ???
  geometry_convert_inherit_texture: true
  geometry_type: "tetrahedra-sdf-grid"
  geometry:
    radius: 2.0 # consistent with coarse
    isosurface_resolution: 128
    isosurface_deformable_grid: true

  material_type: "no-material"
  material:
    n_output_dims: 3

  background_type: "solid-color-background"

  renderer_type: "nvdiff-rasterizer"
  renderer:
    context_type: cuda

  prompt_processor_type: "deep-floyd-prompt-processor"
  prompt_processor:
    pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
    prompt: ???
    use_perp_neg: true

  guidance_type: "deep-floyd-guidance"
  guidance:
    pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
    guidance_scale: 20
    min_step_percent: 0.02
    max_step_percent: 0.5

  guidance_3d_type: "stable-zero123-guidance"
  guidance_3d:
    pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
    pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
    cond_image_path: ${data.image_path}
    cond_elevation_deg: ${data.default_elevation_deg}
    cond_azimuth_deg: ${data.default_azimuth_deg}
    cond_camera_distance: ${data.default_camera_distance}
    guidance_scale: 5.0
    min_step_percent: 0.2
    max_step_percent: 0.5

  freq:
    n_ref: 2
    ref_only_steps: 0
    ref_or_guidance: "accumulate"
    no_diff_steps: 0
    guidance_eval: 0
    n_rgb: 4

  loggers:
    wandb:
      enable: false
      project: "threestudio"

  loss:
    lambda_sd: 0.1
    lambda_3d_sd: 0.1
    lambda_rgb: 1000.0
    lambda_mask: 100.0
    lambda_mask_binary: 0.0
    lambda_depth: 0.0
    lambda_depth_rel: 0.0
    lambda_normal: 0.0
    lambda_normal_smooth: 0.
    lambda_3d_normal_smooth: 0.
    lambda_normal_consistency: [1000, 10.0, 1, 2000]
    lambda_laplacian_smoothness: 0.0

  optimizer:
    name: Adam
    args:
      lr: 0.005
      betas: [0.9, 0.99]
      eps: 1.e-15

trainer:
  max_steps: 5000
  log_every_n_steps: 1
  num_sanity_val_steps: 0
  val_check_interval: 200
  enable_progress_bar: true
  precision: 32
  strategy: "ddp_find_unused_parameters_true"

checkpoint:
  save_last: true
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}
166
configs/dreamcraft3d-texture.yaml
Normal file
@@ -0,0 +1,166 @@
name: "dreamcraft3d-texture"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0

data_type: "single-image-datamodule"
data:
  image_path: ./load/images/hamburger_rgba.png
  height: 1024
  width: 1024
  default_elevation_deg: 0.0
  default_azimuth_deg: 0.0
  default_camera_distance: 3.8
  default_fovy_deg: 20.0
  requires_depth: false
  requires_normal: false
  use_mixed_camera_config: false
  random_camera:
    height: 1024
    width: 1024
    batch_size: 1
    eval_height: 1024
    eval_width: 1024
    eval_batch_size: 1
    elevation_range: [-10, 45]
    azimuth_range: [-180, 180]
    camera_distance_range: [3.8, 3.8]
    fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
    progressive_until: 0
    camera_perturb: 0.0
    center_perturb: 0.0
    up_perturb: 0.0
    eval_elevation_deg: ${data.default_elevation_deg}
    eval_camera_distance: ${data.default_camera_distance}
    eval_fovy_deg: ${data.default_fovy_deg}
    batch_uniform_azimuth: false
    n_val_views: 40
    n_test_views: 120

system_type: "dreamcraft3d-system"
system:
  stage: texture
  use_mixed_camera_config: ${data.use_mixed_camera_config}
  geometry_convert_from: ???
  geometry_convert_inherit_texture: true
  geometry_type: "tetrahedra-sdf-grid"
  geometry:
    radius: 2.0 # consistent with coarse
    isosurface_resolution: 128
    isosurface_deformable_grid: true
    isosurface_remove_outliers: true
    pos_encoding_config:
      otype: HashGrid
      n_levels: 16
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 16
      per_level_scale: 1.447269237440378 # max resolution 4096
    fix_geometry: true

  material_type: "no-material"
  material:
    n_output_dims: 3

  background_type: "solid-color-background"

  renderer_type: "nvdiff-rasterizer"
  renderer:
    context_type: cuda

  prompt_processor_type: "stable-diffusion-prompt-processor"
  prompt_processor:
    pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
    prompt: ???
    front_threshold: 30.
    back_threshold: 30.

  guidance_type: "stable-diffusion-bsd-guidance"
  guidance:
    pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
    pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1-base"
    # pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1"
    guidance_scale: 2.0
    min_step_percent: 0.05
    max_step_percent: [0, 0.5, 0.2, 5000]
    only_pretrain_step: 1000

  # guidance_3d_type: "stable-zero123-guidance"
  # guidance_3d:
  #   pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
  #   pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
  #   cond_image_path: ${data.image_path}
  #   cond_elevation_deg: ${data.default_elevation_deg}
  #   cond_azimuth_deg: ${data.default_azimuth_deg}
  #   cond_camera_distance: ${data.default_camera_distance}
  #   guidance_scale: 5.0
  #   min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter)
  #   max_step_percent: 0.5

  # control_guidance_type: "stable-diffusion-controlnet-reg-guidance"
  # control_guidance:
  #   min_step_percent: 0.1
  #   max_step_percent: 0.5
  # control_prompt_processor_type: "stable-diffusion-prompt-processor"
  # control_prompt_processor:
  #   pretrained_model_name_or_path: "SG161222/Realistic_Vision_V2.0"
  #   prompt: ${system.prompt_processor.prompt}
  #   front_threshold: 30.
  #   back_threshold: 30.

  freq:
    n_ref: 2
    ref_only_steps: 0
    ref_or_guidance: "alternate"
    no_diff_steps: -1
    guidance_eval: 0

  loggers:
    wandb:
      enable: false
      project: "threestudio"

  loss:
    lambda_sd: 0.01
    lambda_lora: 0.1
    lambda_pretrain: 0.1
    lambda_3d_sd: 0.0
    lambda_rgb: 1000.
    lambda_mask: 100.
    lambda_mask_binary: 0.0
    lambda_depth: 0.0
    lambda_depth_rel: 0.0
    lambda_normal: 0.0
    lambda_normal_smooth: 0.0
    lambda_3d_normal_smooth: 0.0
    lambda_z_variance: 0.0
    lambda_reg: 0.0

  optimizer:
    name: AdamW
    args:
      betas: [0.9, 0.99]
      eps: 1.e-4
    params:
      geometry.encoding:
        lr: 0.01
      geometry.feature_network:
        lr: 0.001
      guidance.train_unet:
        lr: 0.00001
      guidance.train_unet_lora:
        lr: 0.00001

trainer:
  max_steps: 5000
  log_every_n_steps: 1
  num_sanity_val_steps: 0
  val_check_interval: 200
  enable_progress_bar: true
  precision: 32
  strategy: "ddp_find_unused_parameters_true"

checkpoint:
  save_last: true
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}
60
docker/Dockerfile
Normal file
@@ -0,0 +1,60 @@
# Reference:
# https://github.com/cvpaperchallenge/Ascender
# https://github.com/nerfstudio-project/nerfstudio

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

ARG USER_NAME=dreamer
ARG GROUP_NAME=dreamers
ARG UID=1000
ARG GID=1000

# Set compute capability for nerfacc and tiny-cuda-nn
# See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60
# Speed-up build for RTX 30xx
# ENV TORCH_CUDA_ARCH_LIST="8.6"
# ENV TCNN_CUDA_ARCHITECTURES=86
# Speed-up build for RTX 40xx
# ENV TORCH_CUDA_ARCH_LIST="8.9"
# ENV TCNN_CUDA_ARCHITECTURES=89

ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}

# apt install by root user
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    git \
    libegl1-mesa-dev \
    libgl1-mesa-dev \
    libgles2-mesa-dev \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender1 \
    python-is-python3 \
    python3.10-dev \
    python3-pip \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Change user to non-root user
RUN groupadd -g ${GID} ${GROUP_NAME} \
    && useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME}
USER ${USER_NAME}

RUN pip install --upgrade pip setuptools ninja
RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
# Install nerfacc and tiny-cuda-nn before installing requirements.txt
# because these two installations are time consuming and error prone
RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2
RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch

COPY requirements.txt /tmp
RUN cd /tmp && pip install -r requirements.txt
WORKDIR /home/${USER_NAME}/threestudio
23
docker/compose.yaml
Normal file
@@ -0,0 +1,23 @@
services:
  threestudio:
    build:
      context: ../
      dockerfile: docker/Dockerfile
      args:
        # you can set environment variables, otherwise default values will be used
        USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER
        GROUP_NAME: ${HOST_GROUP_NAME:-dreamers}
        UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u)
        GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g)
    shm_size: '4gb'
    environment:
      NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error`
    tty: true
    volumes:
      - ../:/home/${HOST_USER_NAME:-dreamer}/threestudio
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              capabilities: [gpu]
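To build the image with your host account mapped into the container, you can export the variables referenced in the comments above before building (a usage sketch of the documented build args):

```sh
export HOST_USER_NAME=$USER
export HOST_GROUP_NAME=$(id -gn)
export HOST_UID=$(id -u)
export HOST_GID=$(id -g)
docker compose build
```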
59
docs/installation.md
Normal file
@ -0,0 +1,59 @@
|
||||
# Installation
|
||||
|
||||
## Prerequisite
|
||||
|
||||
- NVIDIA GPU with at least 6GB VRAM. The more memory you have, the more methods and higher resolutions you can try.
|
||||
- [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) whose version is higher than the [Minimum Required Driver Version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html) of CUDA Toolkit you want to use.
|
||||

## Install CUDA Toolkit

You can skip this step if you have already installed a sufficiently new version or if you use Docker.

Install the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive).

- Example for Ubuntu 22.04:
  - Run the [command for CUDA 11.8 Ubuntu 22.04](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local)
- Example for Ubuntu on WSL2:
  - `sudo apt-key del 7fa2af80`
  - Run the [command for CUDA 11.8 WSL-Ubuntu](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local)
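
After installing, you can confirm that `nvcc` is on your `PATH` and reports the release you expect:

```bash
nvcc --version
```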

## Git Clone

```bash
git clone https://github.com/threestudio-project/threestudio.git
cd threestudio/
```

## Install threestudio via Docker

1. [Install Docker Engine](https://docs.docker.com/engine/install/).
   This document assumes you [install Docker Engine on Ubuntu](https://docs.docker.com/engine/install/ubuntu/).
2. [Create the `docker` group](https://docs.docker.com/engine/install/linux-postinstall/).
   Otherwise, you need to type `sudo docker` instead of `docker`.
3. [Install the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#setting-up-nvidia-container-toolkit).
4. If you use WSL2, [enable systemd](https://learn.microsoft.com/en-us/windows/wsl/wsl-config#systemd-support).
5. Edit the [Dockerfile](../docker/Dockerfile) for your GPU to speed up the build.
   The default Dockerfile takes many types of GPUs into account.
6. Run Docker via `docker compose`, as shown below.

```bash
cd docker/
docker compose build                   # build the Docker image
docker compose up -d                   # create and start a container in the background
docker compose exec threestudio bash  # run bash in the container

# Enjoy threestudio!

exit                  # or Ctrl+D
docker compose stop   # stop the container
docker compose start  # start the container
docker compose down   # stop and remove the container
```
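
If you want the user inside the container to match your host account, export the build arguments referenced in `docker/compose.yaml` before building; any variable left unset falls back to its default (`dreamer`, `dreamers`, `1000`, `1000`). For example:

```bash
export HOST_USER_NAME=$USER
export HOST_GROUP_NAME=$(id -gn)
export HOST_UID=$(id -u)
export HOST_GID=$(id -g)
docker compose build
```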

Note: The current Dockerfile will cause errors when using the OpenGL-based rasterizer of nvdiffrast.
You can use the CUDA-based rasterizer instead, by adding the following options to your commands or by editing the configs (see the example below).

- `system.renderer.context_type=cuda` for training
- `system.exporter.context_type=cuda` for exporting meshes
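
For instance, a training run using the CUDA-based rasterizer might look like the following (a sketch: the config file and the prompt are only illustrative; substitute your own):

```bash
python launch.py --config configs/dreamfusion-sd.yaml --train --gpu 0 \
    system.prompt_processor.prompt="a delicious hamburger" \
    system.renderer.context_type=cuda
```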

[This comment by the nvdiffrast author](https://github.com/NVlabs/nvdiffrast/issues/94#issuecomment-1288566038) could be a guide to resolving this limitation.
1
extern/MVDream
vendored
Submodule
@ -0,0 +1 @@
Subproject commit 853c51b5575e179b25d3aef3d9dbdff950e922ee
1
extern/One-2-3-45
vendored
Submodule
@ -0,0 +1 @@
Subproject commit ea885683ee1a5ad93ba369057dc3d71b7a5ae061
0
extern/__init__.py
vendored
Normal file
78
extern/ldm_zero123/extras.py
vendored
Executable file
@ -0,0 +1,78 @@
import logging
from contextlib import contextmanager
from pathlib import Path

import torch
from omegaconf import OmegaConf

from extern.ldm_zero123.util import instantiate_from_config


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    A context manager that will prevent any logging messages
    triggered during the body from being processed.

    :param highest_level: the maximum logging level in use.
      This would only need to be changed if a custom level greater than CRITICAL
      is defined.

    https://gist.github.com/simon-weber/7853144
    """
    # two kind-of hacks here:
    # * can't get the highest logging level in effect => delegate to the user
    # * can't get the current module-level override => use an undocumented
    #   (but non-private!) interface

    previous_level = logging.root.manager.disable

    logging.disable(highest_level)

    try:
        yield
    finally:
        logging.disable(previous_level)


def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from a training directory"""
    train_dir = Path(train_dir)
    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
    config = list(train_dir.rglob("*-project.yaml"))
    # check the config list here, not the (already validated) ckpt list
    assert len(config) > 0, f"didn't find any config in {train_dir}"
    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        config = sorted(config)[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)


def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Loads a model from a config and a ckpt;
    if config is a path, it will be loaded with OmegaConf
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)  # previously the unexpected keys were never printed
        model.to(device)
        model.eval()
        model.cond_stage_model.device = device
        return model
110
extern/ldm_zero123/guidance.py
vendored
Executable file
@ -0,0 +1,110 @@
import abc

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output
from scipy import interpolate


class GuideModel(torch.nn.Module, abc.ABC):
    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        pass

    @abc.abstractmethod
    def compute_loss(self, inp):
        pass


class Guider(torch.nn.Module):
    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance.

        Specify a guidance scale either as a scalar,
        or as a schedule given by a list of (t, scale) tuples with t in [0, 1], e.g.
        [(0, 10), (0.5, 20), (1, 50)]
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []

        if isinstance(scale, (tuple, list)):
            times = np.array([x[0] for x in scale])
            values = np.array([x[1] for x in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps) * [self.scale_schedule]

        interpolater = interpolate.interp1d(
            self.scale_schedule["times"], self.scale_schedule["values"]
        )
        fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):
        # TODO look up index by t
        scale = self.get_scales()[self.index]

        if scale == 0:
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            grads = torch.autograd.grad(loss.sum(), x_in)[0]
            correction = grads * scale

        if self.show:
            clear_output(wait=True)
            print(
                loss.item(),
                scale,
                correction.abs().max().item(),
                e_t.abs().max().item(),
            )
            self.history.append(
                [
                    loss.item(),
                    scale,
                    correction.min().item(),
                    correction.max().item(),
                ]
            )
            plt.imshow(
                (inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2
            )
            plt.axis("off")
            plt.show()
            plt.imshow(correction[0][0].detach().cpu())
            plt.axis("off")
            plt.show()

        e_t_mod = e_t - sqrt_1ma * correction
        if self.show:
            fig, axs = plt.subplots(1, 3)
            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        self.index += 1
        return e_t_mod
135
extern/ldm_zero123/lr_scheduler.py
vendored
Executable file
@ -0,0 +1,135 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (
                self.lr_max - self.lr_start
            ) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (
                self.lr_max_decay_steps - self.lr_warm_up_steps
            )
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                1 + np.cos(t * np.pi)
            )
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
    ):
        assert (
            len(warm_up_steps)
            == len(f_min)
            == len(f_max)
            == len(f_start)
            == len(cycle_lengths)
        )
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                1 + np.cos(t * np.pi)
            )
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
                self.cycle_lengths[cycle] - n
            ) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
551
extern/ldm_zero123/models/autoencoder.py
vendored
Executable file
@ -0,0 +1,551 @@
from contextlib import contextmanager

import numpy as np  # added: needed by VQModel.get_input (np.random.choice, np.arange)
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from packaging import version  # added: needed by VQModel._validation_step
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from torch.optim.lr_scheduler import LambdaLR  # added: needed by configure_optimizers

from extern.ldm_zero123.modules.diffusionmodules.model import Decoder, Encoder
from extern.ldm_zero123.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

# added: LitEma is referenced when use_ema=True; assumed vendored at the usual ldm path
from extern.ldm_zero123.modules.ema import LitEma
from extern.ldm_zero123.util import instantiate_from_config


class VQModel(pl.LightningModule):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        n_embed,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        batch_resize_range=None,
        scheduler_config=None,
        lr_g_factor=1.0,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        use_ema=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(
            n_embed,
            embed_dim,
            beta=0.25,
            remap=remap,
            sane_index_shape=sane_index_shape,
        )
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(
                f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
            )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(
                    np.arange(lower_size, upper_size + 16, 16)
                )
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(
                qloss,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
                predicted_indices=ind,
            )

            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(
                qloss,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(
            qloss,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + suffix,
            predicted_indices=ind,
        )

        discloss, log_dict_disc = self.loss(
            qloss,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + suffix,
            predicted_indices=ind,
        )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(
            f"val{suffix}/rec_loss",
            rec_loss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            f"val{suffix}/aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        if version.parse(pl.__version__) >= version.parse("1.4.0"):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quantize.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr_g,
            betas=(0.5, 0.9),
        )
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
        )

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                },
                {
                    "scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


class AutoencoderKL(pl.LightningModule):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "aeloss",
                aeloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log(
                "discloss",
                discloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr,
            betas=(0.5, 0.9),
        )
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
0
extern/ldm_zero123/models/diffusion/__init__.py
vendored
Executable file
319
extern/ldm_zero123/models/diffusion/classifier.py
vendored
Executable file
@ -0,0 +1,319 @@
import os
from copy import deepcopy
from glob import glob

import pytorch_lightning as pl
import torch
from einops import rearrange
from natsort import natsorted
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

from extern.ldm_zero123.modules.diffusionmodules.openaimodel import (
    EncoderUNetModel,
    UNetModel,
)
from extern.ldm_zero123.util import (
    default,
    instantiate_from_config,
    ismap,
    log_txt_as_img,
)

__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class NoisyLatentImageClassifier(pl.LightningModule):
    def __init__(
        self,
        diffusion_path,
        num_classes,
        ckpt_path=None,
        pool="attention",
        label_key=None,
        diffusion_ckpt_path=None,
        scheduler_config=None,
        weight_decay=1.0e-2,
        log_steps=10,
        monitor="val/loss",
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(
            glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
        )[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = (
            label_key
            if not hasattr(self.diffusion_model, "cond_stage_key")
            else self.diffusion_model.cond_stage_key
        )

        assert (
            self.label_key is not None
        ), "label_key neither in diffusion model nor in model.params"

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
            if not only_model
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = (
            self.diffusion_config.params.unet_config.params.out_channels
        )
        model_config.out_channels = self.num_classes
        if self.label_key == "class_label":
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print(
                "#####################################################################"
            )
            print(f'load from ckpt "{ckpt_path}"')
            print(
                "#####################################################################"
            )
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = (
                self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            )
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(
            x_start=x,
            t=t,
            noise=noise,
            continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
        )

    def forward(self, x_noisy, t, *args, **kwargs):
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, "b h w c -> b c h w")
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        if k is None:
            k = self.label_key
        assert k is not None, "Needs to provide label key"

        targets = batch[k].to(self.device)

        if self.label_key == "segmentation":
            targets = rearrange(targets, "b h w c -> b c h w")
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == "mean":
            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
        elif reduction == "none":
            return (top_ks == labels[:, None]).float().sum(dim=-1)

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to("cpu")

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        log_prefix = "train" if self.training else "val"
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(
            log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True
        )
        self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log(
            "global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True
        )
        lr = self.optimizers().param_groups[0]["lr"]
        self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        x, *_ = self.diffusion_model.get_input(
            batch, k=self.diffusion_model.first_stage_key
        )
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(
                0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device
            ).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction="none")

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        self.noisy_acc = {
            t: {"acc@1": [], "acc@5": []}
            for t in range(
                0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t
            )
        }

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]["acc@1"].append(
                self.compute_top_k(logits, targets, k=1, reduction="mean")
            )
            self.noisy_acc[t]["acc@5"].append(
                self.compute_top_k(logits, targets, k=5, reduction="mean")
            )

        return loss

    def configure_optimizers(self):
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log["inputs"] = x

        y = self.get_conditioning(batch)

        if self.label_key == "class_label":
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log["labels"] = y

        if ismap(y):
            log["labels"] = self.diffusion_model.to_rgb(y)

            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f"inputs@t{current_time}"] = x_noisy

                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
                pred = rearrange(pred, "b h w c -> b c h w")

                log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]

        return log
488
extern/ldm_zero123/models/diffusion/ddim.py
vendored
Executable file
@ -0,0 +1,488 @@
"""SAMPLING ONLY."""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from extern.ldm_zero123.models.diffusion.sampling_util import (
|
||||
norm_thresholding,
|
||||
renorm_thresholding,
|
||||
spatial_norm_thresholding,
|
||||
)
|
||||
from extern.ldm_zero123.modules.diffusionmodules.util import (
|
||||
extract_into_tensor,
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def to(self, device):
|
||||
"""Same as to in torch module
|
||||
Don't really underestand why this isn't a module in the first place"""
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_v = getattr(self, k).to(device)
|
||||
setattr(self, k, new_v)
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(
|
||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||
):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(self.model.betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
# print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
t_start=-1,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
||||
time_range = (
|
||||
reversed(range(0, timesteps))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
img = callback(i, img, pred_x0)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates["x_inter"].append(img)
|
||||
intermediates["pred_x0"].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
print(t, sqrt_one_minus_at, a_t)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(
        self,
        x0,
        c,
        t_enc,
        use_original_steps=False,
        return_intermediates=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
    ):
        num_reference_steps = (
            self.ddpm_num_timesteps
            if use_original_steps
            else self.ddim_timesteps.shape[0]
        )

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc="Encoding Image"):
            t = torch.full(
                (x0.shape[0],), i, device=self.model.device, dtype=torch.long
            )
            if unconditional_guidance_scale == 1.0:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(
                        torch.cat((x_next, x_next)),
                        torch.cat((t, t)),
                        torch.cat((unconditional_conditioning, c)),
                    ),
                    2,
                )
                noise_pred = e_t_uncond + unconditional_guidance_scale * (
                    noise_pred - e_t_uncond
                )

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = (
                alphas_next[i].sqrt()
                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
                * noise_pred
            )
            x_next = xt_weighted + weighted_noise_pred
            if (
                return_intermediates
                and i % (num_steps // return_intermediates) == 0
                and i < num_steps - 1
            ):
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)

        out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
        if return_intermediates:
            out.update({"intermediates": intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
        )

    @torch.no_grad()
    def decode(
        self,
        x_latent,
        cond,
        t_start,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
    ):
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
            else self.ddim_timesteps
        )
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        # print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full(
                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
            )
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
            )
        return x_dec
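
A minimal usage sketch of the encode/decode pair above (illustrative only): `DDIMSampler` is the class this file defines upstream, and `model`, `z0`, `cond`, and the `make_schedule` call are assumed from context rather than shown in this diff:

sampler = DDIMSampler(model)
sampler.make_schedule(ddim_num_steps=50, ddim_eta=0.0)
t_enc = 25  # re-noise halfway into the 50-step DDIM schedule
t = torch.full((z0.shape[0],), t_enc, device=z0.device, dtype=torch.long)
z_noisy = sampler.stochastic_encode(z0, t)    # single q(x_t | x_0) jump
z_rec = sampler.decode(z_noisy, cond, t_enc)  # deterministic DDIM back toward x_0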
2689 extern/ldm_zero123/models/diffusion/ddpm.py vendored Executable file
File diff suppressed because it is too large
383 extern/ldm_zero123/models/diffusion/plms.py vendored Executable file
@@ -0,0 +1,383 @@
"""SAMPLING ONLY."""

from functools import partial

import numpy as np
import torch
from tqdm import tqdm

from extern.ldm_zero123.models.diffusion.sampling_util import norm_thresholding
from extern.ldm_zero123.modules.diffusionmodules.util import (
    make_ddim_sampling_parameters,
    make_ddim_timesteps,
    noise_like,
)


class PLMSSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        if ddim_eta != 0:
            raise ValueError("ddim_eta must be 0 for PLMS")
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
        dynamic_threshold=None,
        **kwargs,
    ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(
                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
                    )
            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                    )

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f"Data shape for PLMS sampling is {size}")

        samples, intermediates = self.plms_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            dynamic_threshold=dynamic_threshold,
        )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
    ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (
            list(reversed(range(0, timesteps)))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full(
                (b,),
                time_range[min(i + 1, len(time_range) - 1)],
                device=device,
                dtype=torch.long,
            )

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
                img = img_orig * mask + (1.0 - mask) * img

            outs = self.p_sample_plms(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                old_eps=old_eps,
                t_next=ts_next,
                dynamic_threshold=dynamic_threshold,
            )
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        old_eps=None,
        t_next=None,
        dynamic_threshold=None,
    ):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if (
                unconditional_conditioning is None
                or unconditional_guidance_scale == 1.0
            ):
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                if isinstance(c, dict):
                    assert isinstance(unconditional_conditioning, dict)
                    c_in = dict()
                    for k in c:
                        if isinstance(c[k], list):
                            c_in[k] = [
                                torch.cat([unconditional_conditioning[k][i], c[k][i]])
                                for i in range(len(c[k]))
                            ]
                        else:
                            c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
                else:
                    c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(
                    self.model, e_t, x, t, c, **corrector_kwargs
                )

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(
                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
            )

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.0:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (
                55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
            ) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
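
The four branches at the end of p_sample_plms are the standard pseudo linear multistep (Adams-Bashforth) weights; a quick self-contained sanity check (illustrative only, not part of the vendored file) is that each coefficient set sums to its denominator, so a constant noise prediction passes through the update unchanged:

for coeffs, denom in [((3, -1), 2), ((23, -16, 5), 12), ((55, -59, 37, -9), 24)]:
    assert sum(coeffs) == denom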
51 extern/ldm_zero123/models/diffusion/sampling_util.py vendored Executable file
@@ -0,0 +1,51 @@
import numpy as np
import torch
from einops import rearrange


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


def renorm_thresholding(x0, value):
    # renorm
    pred_max = x0.max()
    pred_min = x0.min()
    pred_x0 = (x0 - pred_min) / (pred_max - pred_min)  # 0 ... 1
    pred_x0 = 2 * pred_x0 - 1.0  # -1 ... 1

    s = torch.quantile(rearrange(pred_x0, "b ... -> b (...)").abs(), value, dim=-1)
    s.clamp_(min=1.0)
    s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))

    # clip by threshold
    # pred_x0 = pred_x0.clamp(-s, s) / s  # needs newer pytorch  # TODO bring back to pure-gpu with min/max

    # temporary hack: numpy on cpu
    pred_x0 = (
        np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy())
        / s.cpu().numpy()
    )
    pred_x0 = torch.tensor(pred_x0).to(x0.device)

    # re-renorm
    pred_x0 = (pred_x0 + 1.0) / 2.0  # 0 ... 1
    pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min  # orig range
    return pred_x0


def norm_thresholding(x0, value):
    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
    return x0 * (value / s)


def spatial_norm_thresholding(x0, value):
    # b c h w
    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
    return x0 * (value / s)
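
A self-contained check of norm_thresholding above (illustrative only): samples whose per-sample RMS exceeds the threshold are rescaled onto it, and samples already below it pass through unchanged:

x = torch.randn(4, 3, 8, 8) * 5.0
y = norm_thresholding(x, 1.0)
rms = y.pow(2).flatten(1).mean(1).sqrt()
assert torch.all(rms <= 1.0 + 1e-5)  # every sample now has RMS at most the threshold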
364 extern/ldm_zero123/modules/attention.py vendored Executable file
@@ -0,0 +1,364 @@
import math
from inspect import isfunction

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}"
            )

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

        self.lora = False
        self.query_dim = query_dim
        self.inner_dim = inner_dim
        self.context_dim = context_dim

    def setup_lora(self, rank=4, network_alpha=None):
        self.lora = True
        self.rank = rank
        self.to_q_lora = LoRALinearLayer(self.query_dim, self.inner_dim, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(self.inner_dim, self.query_dim, rank, network_alpha)
        self.lora_layers = nn.ModuleList()
        self.lora_layers.append(self.to_q_lora)
        self.lora_layers.append(self.to_k_lora)
        self.lora_layers.append(self.to_v_lora)
        self.lora_layers.append(self.to_out_lora)

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        if self.lora:
            q += self.to_q_lora(x)
            k += self.to_k_lora(context)
            v += self.to_v_lora(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        # return self.to_out(out)

        # linear proj
        o = self.to_out[0](out)
        if self.lora:
            o += self.to_out_lora(out)
        # dropout
        out = self.to_out[1](o)

        return out


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # return checkpoint(
        #     self._forward, (x, context), self.parameters(), self.checkpoint
        # )
        return self._forward(x, context)

    def _forward(self, x, context=None):
        x = (
            self.attn1(
                self.norm1(x), context=context if self.disable_self_attn else None
            )
            + x
        )
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim,
                    disable_self_attn=disable_self_attn,
                )
                for d in range(depth)
            ]
        )

        self.proj_out = zero_module(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + x_in
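
A shape-level sketch of the LoRA hook above (illustrative only; the dimensions are arbitrary): because LoRALinearLayer zero-initializes its up-projection, calling setup_lora leaves the module's output unchanged until the LoRA weights are trained:

attn = CrossAttention(query_dim=64, context_dim=32, heads=4, dim_head=16)
x = torch.randn(2, 10, 64)   # (batch, tokens, query_dim)
ctx = torch.randn(2, 7, 32)  # (batch, tokens, context_dim)
base = attn(x, context=ctx)
attn.setup_lora(rank=4)      # up-projections start at zero
assert torch.allclose(base, attn(x, context=ctx))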
301 extern/ldm_zero123/modules/attention_ori.py vendored Executable file
@@ -0,0 +1,301 @@
import math
from inspect import isfunction

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None):
        x = (
            self.attn1(
                self.norm1(x), context=context if self.disable_self_attn else None
            )
            + x
        )
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim,
                    disable_self_attn=disable_self_attn,
                )
                for d in range(depth)
            ]
        )

        self.proj_out = zero_module(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + x_in
0 extern/ldm_zero123/modules/diffusionmodules/__init__.py vendored Executable file
1009 extern/ldm_zero123/modules/diffusionmodules/model.py vendored Executable file
File diff suppressed because it is too large
1062 extern/ldm_zero123/modules/diffusionmodules/openaimodel.py vendored Executable file
File diff suppressed because it is too large
296 extern/ldm_zero123/modules/diffusionmodules/util.py vendored Executable file
@@ -0,0 +1,296 @@
# adapted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import math
import os

import numpy as np
import torch
import torch.nn as nn
from einops import repeat

from extern.ldm_zero123.util import instantiate_from_config


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(
    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
    if ddim_discr_method == "uniform":
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == "quad":
        ddim_timesteps = (
            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
        ).astype(int)
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f"Selected timesteps for ddim sampler: {steps_out}")
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt(
        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    )
    if verbose:
        print(
            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
        )
        print(
            f"For the chosen value of eta, which is {eta}, "
            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
        )
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):
    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
        shape[0], *((1,) * (len(shape) - 1))
    )
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()
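
Illustrative only (not part of the vendored file): how the schedule helpers above compose for a 1000-step DDPM reduced to 50 DDIM steps, plus a shape check on timestep_embedding:

steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50, num_ddpm_timesteps=1000, verbose=False)
assert steps.shape == (50,) and steps[0] == 1  # the +1 offset applied above
emb = timestep_embedding(torch.tensor([1, 21, 981]), dim=128)
assert emb.shape == (3, 128)  # [N x dim] sinusoidal embeddings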
0 extern/ldm_zero123/modules/distributions/__init__.py vendored Executable file
102 extern/ldm_zero123/modules/distributions/distributions.py vendored Executable file
@@ -0,0 +1,102 @@
import numpy as np
import torch


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
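
Illustrative only: DiagonalGaussianDistribution above expects the VAE's mean and log-variance stacked along dim 1, so an 8-channel parameter tensor yields a 4-channel latent:

params = torch.randn(2, 8, 4, 4)
post = DiagonalGaussianDistribution(params)
z = post.sample()  # reparameterized draw
kl = post.kl()     # KL against the standard normal, summed over non-batch dims
assert z.shape == (2, 4, 4, 4) and kl.shape == (2,)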
82 extern/ldm_zero123/modules/ema.py vendored Executable file
@@ -0,0 +1,82 @@
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
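
Illustrative only: a typical LitEma lifecycle around a training loop (the optimizer step itself is elided):

net = nn.Linear(4, 4)
ema = LitEma(net)
# ... after each optimizer step on net:
ema(net)                       # update the shadow (EMA) weights
# ... to evaluate with EMA weights:
ema.store(net.parameters())    # stash the raw weights
ema.copy_to(net)               # swap in the EMA weights
ema.restore(net.parameters())  # put the raw weights back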
0 extern/ldm_zero123/modules/encoders/__init__.py vendored Executable file
712 extern/ldm_zero123/modules/encoders/modules.py vendored Executable file
@@ -0,0 +1,712 @@
|
||||
from functools import partial
|
||||
|
||||
import clip
|
||||
import kornia
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from extern.ldm_zero123.modules.x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
Encoder,
|
||||
TransformerWrapper,
|
||||
)
|
||||
from extern.ldm_zero123.util import default
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class FaceClipEncoder(AbstractEncoder):
|
||||
def __init__(self, augment=True, retreival_key=None):
|
||||
super().__init__()
|
||||
self.encoder = FrozenCLIPImageEmbedder()
|
||||
self.augment = augment
|
||||
self.retreival_key = retreival_key
|
||||
|
||||
def forward(self, img):
|
||||
encodings = []
|
||||
with torch.no_grad():
|
||||
x_offset = 125
|
||||
if self.retreival_key:
|
||||
# Assumes retrieved image are packed into the second half of channels
|
||||
face = img[:, 3:, 190:440, x_offset : (512 - x_offset)]
|
||||
other = img[:, :3, ...].clone()
|
||||
else:
|
||||
face = img[:, :, 190:440, x_offset : (512 - x_offset)]
|
||||
other = img.clone()
|
||||
|
||||
if self.augment:
|
||||
face = K.RandomHorizontalFlip()(face)
|
||||
|
||||
other[:, :, 190:440, x_offset : (512 - x_offset)] *= 0
|
||||
encodings = [
|
||||
self.encoder.encode(face),
|
||||
self.encoder.encode(other),
|
||||
]
|
||||
|
||||
return torch.cat(encodings, dim=1)
|
||||
|
||||
def encode(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros(
|
||||
(1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
|
||||
)
|
||||
|
||||
return self(img)
|
||||
|
||||
|
||||
class FaceIdClipEncoder(AbstractEncoder):
    def __init__(self):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.id = FrozenFaceEncoder(
            "/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True
        )

    def forward(self, img):
        encodings = []
        with torch.no_grad():
            face = kornia.geometry.resize(
                img, (256, 256), interpolation="bilinear", align_corners=True
            )

            other = img.clone()
            other[:, :, 184:452, 122:396] *= 0
        encodings = [
            self.id.encode(face),
            self.encoder.encode(other),
        ]

        return torch.cat(encodings, dim=1)

    def encode(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros(
                (1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
            )

        return self(img)

class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class"):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=n_embed, depth=n_layer),
        )

    def forward(self, tokens):
        tokens = tokens.to(self.device)  # meh
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)

class BERTTokenizer(AbstractEncoder):
    """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements

        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text

class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer and adds some transformer encoder layers"""

    def __init__(
        self,
        n_embed,
        n_layer,
        vocab_size=30522,
        max_seq_len=77,
        device="cuda",
        use_tokenizer=True,
        embedding_dropout=0.0,
    ):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=n_embed, depth=n_layer),
            emb_dropout=embedding_dropout,
        )

    def forward(self, text):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)  # .to(self.device)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)

from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-large", device="cuda", max_length=77
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

import kornia.augmentation as K

from extern.ldm_zero123.thirdp.psp.id_loss import IDFeatures


class FrozenFaceEncoder(AbstractEncoder):
    def __init__(self, model_path, augment=False):
        super().__init__()
        self.loss_fn = IDFeatures(model_path)
        # face encoder is frozen
        for p in self.loss_fn.parameters():
            p.requires_grad = False
        # Mapper is trainable
        self.mapper = torch.nn.Linear(512, 768)
        p = 0.25
        if augment:
            self.augment = K.AugmentationSequential(
                K.RandomHorizontalFlip(p=0.5),
                K.RandomEqualize(p=p),
                # K.RandomPlanckianJitter(p=p),
                # K.RandomPlasmaBrightness(p=p),
                # K.RandomPlasmaContrast(p=p),
                # K.ColorJiggle(0.02, 0.2, 0.2, p=p),
            )
        else:
            # None (not False) so the `is not None` check below actually skips
            # augmentation instead of calling False(...) and crashing.
            self.augment = None

    def forward(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 1, 768), device=self.mapper.weight.device)

        if self.augment is not None:
            # Transforms require 0-1
            img = self.augment((img + 1) / 2)
            img = 2 * img - 1

        feat = self.loss_fn(img, crop=True)
        feat = self.mapper(feat.unsqueeze(1))
        return feat

    def encode(self, img):
        return self(img)

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    def __init__(
        self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
    ):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

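
# Illustrative usage sketch (not part of the vendored file): FrozenCLIPEmbedder
# pads/truncates prompts to max_length tokens and returns the CLIP text
# transformer's last hidden states, shape [batch, 77, 768] for
# clip-vit-large-patch14. The prompts below are hypothetical.
#
#     embedder = FrozenCLIPEmbedder(device="cuda").cuda()
#     z = embedder.encode(["a photo of a cat", "a photo of a dog"])
#     assert z.shape == (2, 77, 768)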
import torch.nn.functional as F
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
|
||||
class ClipImageProjector(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, version="openai/clip-vit-large-patch14", max_length=77
|
||||
): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.model = CLIPVisionModel.from_pretrained(version)
|
||||
self.model.train()
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.antialias = True
|
||||
self.mapper = torch.nn.Linear(1024, 768)
|
||||
self.register_buffer(
|
||||
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
|
||||
)
|
||||
null_cond = self.get_null_cond(version, max_length)
|
||||
self.register_buffer("null_cond", null_cond)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_null_cond(self, version, max_length):
|
||||
device = self.mean.device
|
||||
embedder = FrozenCLIPEmbedder(
|
||||
version=version, device=device, max_length=max_length
|
||||
)
|
||||
null_cond = embedder([""])
|
||||
return null_cond
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
x = kornia.geometry.resize(
|
||||
x,
|
||||
(224, 224),
|
||||
interpolation="bicubic",
|
||||
align_corners=True,
|
||||
antialias=self.antialias,
|
||||
)
|
||||
x = (x + 1.0) / 2.0
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
if isinstance(x, list):
|
||||
return self.null_cond
|
||||
# x is assumed to be in range [-1,1]
|
||||
x = self.preprocess(x)
|
||||
outputs = self.model(pixel_values=x)
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
last_hidden_state = self.mapper(last_hidden_state)
|
||||
return F.pad(
|
||||
last_hidden_state,
|
||||
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0],
|
||||
)
|
||||
|
||||
def encode(self, im):
|
||||
return self(im)
|
||||
|
||||
|
||||
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
    def __init__(
        self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
    ):  # clip-vit-base-patch32
        super().__init__()
        self.embedder = FrozenCLIPEmbedder(
            version=version, device=device, max_length=max_length
        )
        self.projection = torch.nn.Linear(768, 768)

    def forward(self, text):
        z = self.embedder(text)
        return self.projection(z)

    def encode(self, text):
        return self(text)

class FrozenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    Not actually frozen... If you want that, set cond_stage_trainable=False in cfg
    """

    def __init__(
        self,
        model="ViT-L/14",
        jit=False,
        device="cpu",
        antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit, download_root=None)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, 768, device=device)
        return self.model.encode_image(self.preprocess(x)).float()

    def encode(self, im):
        return self(im).unsqueeze(1)

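
# Illustrative usage sketch (not part of the vendored file): the image embedder
# maps a batch of images in [-1, 1] to one CLIP embedding per image, and
# encode() adds a sequence dimension so the result can serve as a single
# cross-attention conditioning token. The random input is a placeholder.
#
#     embedder = FrozenCLIPImageEmbedder()
#     imgs = torch.rand(4, 3, 512, 512) * 2 - 1  # [-1, 1] range
#     cond = embedder.encode(imgs)               # shape [4, 1, 768]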
import random

from torchvision import transforms


class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    Not actually frozen... If you want that, set cond_stage_trainable=False in cfg
    """

    def __init__(
        self,
        model="ViT-L/14",
        jit=False,
        device="cpu",
        antialias=True,
        max_crops=5,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.max_crops = max_crops

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1, 1))
        max_crops = self.max_crops
        patches = []
        crops = [randcrop(x) for _ in range(max_crops)]
        patches.extend(crops)
        x = torch.cat(patches, dim=0)
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, self.max_crops, 768, device=device)
        batch_tokens = []
        for im in x:
            patches = self.preprocess(im.unsqueeze(0))
            tokens = self.model.encode_image(patches).float()
            for t in tokens:
                if random.random() < 0.1:
                    t *= 0
            batch_tokens.append(tokens.unsqueeze(0))

        return torch.cat(batch_tokens, dim=0)

    def encode(self, im):
        return self(im)

class SpatialRescaler(nn.Module):
    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
    ):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)

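
# Illustrative usage sketch (not part of the vendored file): each stage scales
# the spatial dims by `multiplier`, so n_stages=2 with multiplier=0.5 maps
# 256x256 down to 64x64, and out_channels adds a 1x1 conv after resizing.
#
#     rescaler = SpatialRescaler(n_stages=2, multiplier=0.5, out_channels=8)
#     y = rescaler(torch.randn(1, 3, 256, 256))  # shape [1, 8, 64, 64]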
from extern.ldm_zero123.modules.diffusionmodules.util import (
    extract_into_tensor,
    make_beta_schedule,
    noise_like,
)
from extern.ldm_zero123.util import instantiate_from_config


class LowScaleEncoder(nn.Module):
    def __init__(
        self,
        model_config,
        linear_start,
        linear_end,
        timesteps=1000,
        max_noise_level=250,
        output_size=64,
        scale_factor=1.0,
    ):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
        )
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def forward(self, x):
        z = self.model.encode(x).sample()
        z = z * self.scale_factor
        noise_level = torch.randint(
            0, self.max_noise_level, (x.shape[0],), device=x.device
        ).long()
        z = self.q_sample(z, noise_level)
        if self.out_size is not None:
            z = torch.nn.functional.interpolate(
                z, size=self.out_size, mode="nearest"
            )  # TODO: experiment with mode
        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
        return z, noise_level

    def decode(self, z):
        z = z / self.scale_factor
        return self.model.decode(z)

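
# Illustrative note (not part of the vendored file): q_sample above implements
# the standard DDPM forward-noising closed form
#     z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
# so LowScaleEncoder returns a noise-augmented latent together with the sampled
# noise level, which can be fed to the model as extra conditioning.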
if __name__ == "__main__":
    from extern.ldm_zero123.util import count_params

    sentences = [
        "a hedgehog drinking a whiskey",
        "der mond ist aufgegangen",
        "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'",
    ]
    model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
    count_params(model, True)
    z = model(sentences)
    print(z.shape)

    model = FrozenCLIPEmbedder().cuda()
    count_params(model, True)
    z = model(sentences)
    print(z.shape)

    print("done.")
703
extern/ldm_zero123/modules/evaluate/adm_evaluator.py
vendored
Executable file
@@ -0,0 +1,703 @@
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple

import numpy as np
import requests
import tensorflow.compat.v1 as tf
import yaml
from scipy import linalg
from tqdm.auto import tqdm

INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"

FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"

REQUIREMENTS = (
    "This script has the following requirements: \n"
    "tensorflow-gpu>=2.0" + "\n" + "scipy" + "\n" + "requests" + "\n" + "tqdm"
)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_batch", help="path to reference batch npz file")
    parser.add_argument("--sample_batch", help="path to sample batch npz file")
    args = parser.parse_args()

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(
        args.sample_batch, sample_acts
    )

    print("Computing evaluations...")
    is_ = evaluator.compute_inception_score(sample_acts[0])
    print("Inception Score:", is_)
    fid = sample_stats.frechet_distance(ref_stats)
    print("FID:", fid)
    sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
    print("sFID:", sfid)
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)

    savepath = "/".join(args.sample_batch.split("/")[:-1])
    results_file = os.path.join(savepath, "evaluation_metrics.yaml")
    print(f'Saving evaluation results to "{results_file}"')

    results = {
        "IS": is_,
        "FID": fid,
        "sFID": sfid,
        "Precision": prec,
        "Recall": recall,
    }

    with open(results_file, "w") as f:
        yaml.dump(results, f, default_flow_style=False)


class InvalidFIDException(Exception):
    pass

class FIDStatistics:
    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mu1, sigma1 = self.mu, self.sigma
        mu2, sigma2 = other.mu, other.sigma

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert (
            mu1.shape == mu2.shape
        ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
        assert (
            sigma1.shape == sigma2.shape
        ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"

        diff = mu1 - mu2

        # product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = (
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            warnings.warn(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

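
# Illustrative note (not part of the vendored file): frechet_distance above is
# the standard FID between Gaussians fitted to activation statistics,
#     FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)),
# which is exactly what the trace and matrix-square-root terms compute.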
class Evaluator:
    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(
                self.image_input
            )
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        with open_npz_array(npz_path, "arr_0") as reader:
            return self.compute_activations(reader.read_batches(self.batch_size))

    def compute_activations(
        self, batches: Iterable[np.ndarray], silent=False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: an iterator over NHWC numpy arrays in [0, 255].
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        it = batches if silent else tqdm(batches)
        for batch in it:
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        obj = np.load(npz_path)
        if "mu" in list(obj.keys()):
            return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                obj["mu_s"], obj["sigma_s"]
            )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(
        self, activations: np.ndarray, split_size: int = 5000
    ) -> float:
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(
                self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})
            )
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        return (float(pr[0][0]), float(pr[1][0]))

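
# Illustrative note (not part of the vendored file): compute_inception_score
# above evaluates the classic Inception Score over splits,
#     IS = exp( E_x[ KL( p(y|x) || p(y) ) ] ),
# where p(y|x) are the Inception softmax outputs for one sample and p(y) is
# their mean over the split; the inner loop accumulates the KL term batch-wise.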
class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(
                        distance_batch[0 : end1 - begin1, :], seq, axis=1
                    )
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(
        self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray
    ):
        """
        Evaluate if new feature vectors are at the manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros(
            [self.row_batch_size, num_ref_images], dtype=np.float32
        )
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(
                np.int32
            )

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(
                distance_batch[0 : end1 - begin1, :], axis=1
            )

        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        # plain bool: np.bool was removed in recent NumPy and is equivalent here
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )

class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(
                    self._features_batch1, self._features_batch2
                ),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(
                dist32 <= self._radii1[:, None], axis=0
            )

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )


def _batch_pairwise_distances(U, V):
    """
    Compute pairwise distances between two batches of feature vectors.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column and norm_v as a row vectors.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

    return D

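
# Illustrative note (not part of the vendored file): _batch_pairwise_distances
# uses the expansion ||u - v||^2 = ||u||^2 - 2 * u.v + ||v||^2 to compute all
# pairwise squared distances with a single matrix product; tf.maximum(..., 0.0)
# clips the small negative values that floating-point cancellation can produce.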
class NpzArrayReader(ABC):
    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        pass

    @abstractmethod
    def remaining(self) -> int:
        pass

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        def gen_fn():
            while True:
                batch = self.read_batch(batch_size)
                if batch is None:
                    break
                yield batch

        rem = self.remaining()
        num_batches = rem // batch_size + int(rem % batch_size != 0)
        return BatchIterator(gen_fn, num_batches)


class BatchIterator:
    def __init__(self, gen_fn, length):
        self.gen_fn = gen_fn
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen_fn()

class StreamingNpzArrayReader(NpzArrayReader):
    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.shape[0]:
            return None

        bs = min(batch_size, self.shape[0] - self.idx)
        self.idx += bs

        if self.dtype.itemsize == 0:
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        read_count = bs * np.prod(self.shape[1:])
        read_size = int(read_count * self.dtype.itemsize)
        data = _read_bytes(self.arr_f, read_size, "array data")
        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)


class MemoryNpzArrayReader(NpzArrayReader):
    def __init__(self, arr):
        self.arr = arr
        self.idx = 0

    @classmethod
    def load(cls, path: str, arr_name: str):
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None

        res = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return res

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)


@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)
        else:
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = header
        if fortran or dtype.hasobject:
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)

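
# Illustrative usage sketch (not part of the vendored file): open_npz_array
# streams an array out of an .npz file batch-by-batch, falling back to loading
# it fully in memory only when streaming is not possible. "samples.npz" is a
# hypothetical path.
#
#     with open_npz_array("samples.npz", "arr_0") as reader:
#         for batch in reader.read_batches(64):
#             print(batch.shape)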
def _read_bytes(fp, size, error_template="ran out of data"):
    """
    Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886

    Read from file-like object until size bytes are read.
    Raises ValueError if EOF is encountered before size bytes are read.
    Non-blocking objects only supported if they derive from io objects.
    Required as e.g. ZipExtFile in python 2.6 can return less data than
    requested.
    """
    data = bytes()
    while True:
        # io files (default in python3) return None or raise on
        # would-block, python2 file will truncate, probably nothing can be
        # done about that. note that regular files can't be non-blocking
        try:
            r = fp.read(size - len(data))
            data += r
            if len(r) == 0 or len(data) == size:
                break
        except io.BlockingIOError:
            pass
    if len(data) != size:
        msg = "EOF: reading %s, expected %d bytes got %d"
        raise ValueError(msg % (error_template, size, len(data)))
    else:
        return data


@contextmanager
def _open_npy_file(path: str, arr_name: str):
    with open(path, "rb") as f:
        with zipfile.ZipFile(f, "r") as zip_f:
            if f"{arr_name}.npy" not in zip_f.namelist():
                raise ValueError(f"missing {arr_name} in npz file")
            with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
                yield arr_f

def _download_inception_model():
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    with requests.get(INCEPTION_V3_URL, stream=True) as r:
        r.raise_for_status()
        tmp_path = INCEPTION_V3_PATH + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=8192)):
                f.write(chunk)
        os.rename(tmp_path, INCEPTION_V3_PATH)


def _create_feature_graph(input_batch):
    _download_inception_model()
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        input_map={"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=prefix,
    )
    _update_shapes(pool3)
    spatial = spatial[..., :7]
    return pool3, spatial


def _create_softmax_graph(input_batch):
    _download_inception_model()
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    (matmul,) = tf.import_graph_def(
        graph_def, return_elements=["softmax/logits/MatMul"], name=prefix
    )
    w = matmul.inputs[1]
    logits = tf.matmul(input_batch, w)
    return tf.nn.softmax(logits)


def _update_shapes(pool3):
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims is not None:  # pylint: disable=protected-access
                # shape = [s.value for s in shape]  # TF 1.x
                shape = [s for s in shape]  # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3


def _numpy_partition(arr, kth, **kwargs):
    num_workers = min(cpu_count(), len(arr))
    chunk_size = len(arr) // num_workers
    extra = len(arr) % num_workers

    start_idx = 0
    batches = []
    for i in range(num_workers):
        size = chunk_size + (1 if i < extra else 0)
        batches.append(arr[start_idx : start_idx + size])
        start_idx += size

    with ThreadPool(num_workers) as pool:
        return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))


if __name__ == "__main__":
    print(REQUIREMENTS)
    main()
606
extern/ldm_zero123/modules/evaluate/evaluate_perceptualsim.py
vendored
Executable file
@@ -0,0 +1,606 @@
import argparse
import glob
import os
from collections import namedtuple

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
from tqdm import tqdm

from extern.ldm_zero123.modules.evaluate.ssim import ssim

transform = transforms.Compose([transforms.ToTensor()])


def normalize_tensor(in_feat, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view(
        in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
    )
    return in_feat / (norm_factor.expand_as(in_feat) + eps)

def cos_sim(in0, in1):
    in0_norm = normalize_tensor(in0)
    in1_norm = normalize_tensor(in1)
    N = in0.size()[0]
    X = in0.size()[2]
    Y = in0.size()[3]

    return torch.mean(
        torch.mean(torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2).view(
            N, 1, 1, Y
        ),
        dim=3,
    ).view(N)

class squeezenet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        self.slice7 = torch.nn.Sequential()
        self.N_slices = 7
        for x in range(2):
            self.slice1.add_module(str(x), pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), pretrained_features[x])
        for x in range(10, 11):
            self.slice5.add_module(str(x), pretrained_features[x])
        for x in range(11, 12):
            self.slice6.add_module(str(x), pretrained_features[x])
        for x in range(12, 13):
            self.slice7.add_module(str(x), pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        h = self.slice6(h)
        h_relu6 = h
        h = self.slice7(h)
        h_relu7 = h
        vgg_outputs = namedtuple(
            "SqueezeOutputs",
            ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
        )
        out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)

        return out

class alexnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(2):
            self.slice1.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(10, 12):
            self.slice5.add_module(str(x), alexnet_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        alexnet_outputs = namedtuple(
            "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
        )
        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

        return out

class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs",
            ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out

class resnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if num == 18:
            self.net = models.resnet18(pretrained=pretrained)
        elif num == 34:
            self.net = models.resnet34(pretrained=pretrained)
        elif num == 50:
            self.net = models.resnet50(pretrained=pretrained)
        elif num == 101:
            self.net = models.resnet101(pretrained=pretrained)
        elif num == 152:
            self.net = models.resnet152(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out

# Off-the-shelf deep network
class PNet(torch.nn.Module):
    """Pre-trained network with all channels equally weighted by default"""

    def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
        super(PNet, self).__init__()

        self.use_gpu = use_gpu

        self.pnet_type = pnet_type
        self.pnet_rand = pnet_rand

        self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
        self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)

        if self.pnet_type in ["vgg", "vgg16"]:
            self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)
        elif self.pnet_type == "alex":
            self.net = alexnet(pretrained=not self.pnet_rand, requires_grad=False)
        elif self.pnet_type[:-2] == "resnet":
            self.net = resnet(
                pretrained=not self.pnet_rand,
                requires_grad=False,
                num=int(self.pnet_type[-2:]),
            )
        elif self.pnet_type == "squeeze":
            self.net = squeezenet(pretrained=not self.pnet_rand, requires_grad=False)

        self.L = self.net.N_slices

        if use_gpu:
            self.net.cuda()
            self.shift = self.shift.cuda()
            self.scale = self.scale.cuda()

    def forward(self, in0, in1, retPerLayer=False):
        in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
        in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)

        outs0 = self.net.forward(in0_sc)
        outs1 = self.net.forward(in1_sc)

        if retPerLayer:
            all_scores = []
        for kk, out0 in enumerate(outs0):
            cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])
            if kk == 0:
                val = 1.0 * cur_score
            else:
                val = val + cur_score
            if retPerLayer:
                all_scores += [cur_score]

        if retPerLayer:
            return (val, all_scores)
        else:
            return val

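
# Illustrative usage sketch (not part of the vendored file): PNet sums the
# per-layer (1 - cosine similarity) between feature maps of the two inputs,
# giving a perceptual distance per batch element (lower = more similar).
# The random tensors are placeholders; inputs are expected in [-1, 1].
#
#     pnet = PNet(pnet_type="vgg")
#     a = torch.rand(1, 3, 256, 256).cuda() * 2 - 1
#     b = torch.rand(1, 3, 256, 256).cuda() * 2 - 1
#     dist = pnet(a, b)  # tensor of shape [1]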
# The SSIM metric
def ssim_metric(img1, img2, mask=None):
    return ssim(img1, img2, mask=mask, size_average=False)


# The PSNR metric
def psnr(img1, img2, mask=None, reshape=False):
    b = img1.size(0)
    if not (mask is None):
        b = img1.size(0)
        mse_err = (img1 - img2).pow(2) * mask
        if reshape:
            mse_err = mse_err.reshape(b, -1).sum(dim=1) / (
                3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
            )
        else:
            mse_err = mse_err.view(b, -1).sum(dim=1) / (
                3 * mask.view(b, -1).sum(dim=1).clamp(min=1)
            )
    else:
        if reshape:
            mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)
        else:
            mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)

    psnr = 10 * (1 / mse_err).log10()
    return psnr

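
# Illustrative note (not part of the vendored file): for images in [0, 1] the
# peak value is 1, so the function above reduces to the usual
#     PSNR = 10 * log10(MAX^2 / MSE) = 10 * log10(1 / MSE),
# computed per batch element; the mask variant averages the squared error over
# valid pixels only.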

# The perceptual similarity metric
def perceptual_sim(img1, img2, vgg16):
    # First extract features; the network expects inputs scaled to [-1, 1]
    dist = vgg16(img1 * 2 - 1, img2 * 2 - 1)

    return dist


def load_img(img_name, size=None):
    try:
        img = Image.open(img_name)

        if type(size) == int:
            img = img.resize((size, size))
        elif size is not None:
            # PIL expects (width, height); size arrives as (height, width)
            img = img.resize((size[1], size[0]))

        img = transform(img).cuda()
        img = img.unsqueeze(0)
    except Exception as e:
        print("Failed at loading %s " % img_name)
        print(e)
        img = torch.zeros(1, 3, 256, 256).cuda()
        raise
    return img


def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
    # Load VGG16 for feature similarity
    vgg16 = PNet().to("cuda")
    vgg16.eval()
    vgg16.cuda()

    values_percsim = []
    values_ssim = []
    values_psnr = []
    folders = os.listdir(folder)
    for i, f in tqdm(enumerate(sorted(folders))):
        # NOTE: plain string concatenation, so `folder` must end with "/"
        pred_imgs = glob.glob(folder + f + "/" + pred_img)
        tgt_imgs = glob.glob(folder + f + "/" + tgt_img)
        assert len(tgt_imgs) == 1

        perc_sim = 10000
        ssim_sim = -10
        psnr_sim = -10
        for p_img in pred_imgs:
            t_img = load_img(tgt_imgs[0])
            p_img = load_img(p_img, size=t_img.shape[2:])
            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
            perc_sim = min(perc_sim, t_perc_sim)

            ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
            psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())

        values_percsim += [perc_sim]
        values_ssim += [ssim_sim]
        values_psnr += [psnr_sim]

    if take_every_other:
        n_valuespercsim = []
        n_valuesssim = []
        n_valuespsnr = []
        for i in range(0, len(values_percsim) // 2):
            n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]

        values_percsim = n_valuespercsim
        values_ssim = n_valuesssim
        values_psnr = n_valuespsnr

    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))

    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))

    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))

    return {
        "Perceptual similarity": (avg_percsim, std_percsim),
        "PSNR": (avg_psnr, std_psnr),
        "SSIM": (avg_ssim, std_ssim),
    }

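
# Added, hedged usage sketch (not in the original file): the evaluator expects
# one sub-directory per scene under `folder`, with images matching the two
# glob patterns. All paths below are purely illustrative.
def _example_folder_eval():
    return compute_perceptual_similarity(
        folder="results/",       # hypothetical root; note the trailing slash
        pred_img="pred_*.png",   # glob within each scene sub-folder
        tgt_img="tgt_*.png",
        take_every_other=False,
    )
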

def compute_perceptual_similarity_from_list(
    pred_imgs_list, tgt_imgs_list, take_every_other, simple_format=True
):
    # Load VGG16 for feature similarity
    vgg16 = PNet().to("cuda")
    vgg16.eval()
    vgg16.cuda()

    values_percsim = []
    values_ssim = []
    values_psnr = []
    equal_count = 0
    ambig_count = 0
    for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
        pred_imgs = pred_imgs_list[i]
        tgt_imgs = [tgt_img]
        assert len(tgt_imgs) == 1

        if type(pred_imgs) != list:
            pred_imgs = [pred_imgs]

        perc_sim = 10000
        ssim_sim = -10
        psnr_sim = -10
        assert len(pred_imgs) > 0
        for p_img in pred_imgs:
            t_img = load_img(tgt_imgs[0])
            p_img = load_img(p_img, size=t_img.shape[2:])
            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
            perc_sim = min(perc_sim, t_perc_sim)

            ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
            psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())

        values_percsim += [perc_sim]
        values_ssim += [ssim_sim]
        # fix: np.float was removed from NumPy; the builtin float is equivalent
        if psnr_sim != float("inf"):
            values_psnr += [psnr_sim]
        else:
            if torch.allclose(p_img, t_img):
                equal_count += 1
                print("{} equal src and wrp images.".format(equal_count))
            else:
                ambig_count += 1
                print("{} ambiguous src and wrp images.".format(ambig_count))

    if take_every_other:
        n_valuespercsim = []
        n_valuesssim = []
        n_valuespsnr = []
        for i in range(0, len(values_percsim) // 2):
            n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]

        values_percsim = n_valuespercsim
        values_ssim = n_valuesssim
        values_psnr = n_valuespsnr

    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))

    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))

    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))

    if simple_format:
        # just to make yaml formatting readable
        return {
            "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
            "PSNR": [float(avg_psnr), float(std_psnr)],
            "SSIM": [float(avg_ssim), float(std_ssim)],
        }
    else:
        return {
            "Perceptual similarity": (avg_percsim, std_percsim),
            "PSNR": (avg_psnr, std_psnr),
            "SSIM": (avg_ssim, std_ssim),
        }


def compute_perceptual_similarity_from_list_topk(
    pred_imgs_list, tgt_imgs_list, take_every_other, resize=False
):
    # Load VGG16 for feature similarity
    vgg16 = PNet().to("cuda")
    vgg16.eval()
    vgg16.cuda()

    values_percsim = []
    values_ssim = []
    values_psnr = []
    individual_percsim = []
    individual_ssim = []
    individual_psnr = []
    for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
        pred_imgs = pred_imgs_list[i]
        tgt_imgs = [tgt_img]
        assert len(tgt_imgs) == 1

        if type(pred_imgs) != list:
            # top-k evaluation only makes sense for a list of predictions
            assert False
            pred_imgs = [pred_imgs]

        perc_sim = 10000
        ssim_sim = -10
        psnr_sim = -10
        sample_percsim = list()
        sample_ssim = list()
        sample_psnr = list()
        for p_img in pred_imgs:
            if resize:
                t_img = load_img(tgt_imgs[0], size=(256, 256))
            else:
                t_img = load_img(tgt_imgs[0])
            p_img = load_img(p_img, size=t_img.shape[2:])

            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
            sample_percsim.append(t_perc_sim)
            perc_sim = min(perc_sim, t_perc_sim)

            t_ssim = ssim_metric(p_img, t_img).item()
            sample_ssim.append(t_ssim)
            ssim_sim = max(ssim_sim, t_ssim)

            t_psnr = psnr(p_img, t_img).item()
            sample_psnr.append(t_psnr)
            psnr_sim = max(psnr_sim, t_psnr)

        values_percsim += [perc_sim]
        values_ssim += [ssim_sim]
        values_psnr += [psnr_sim]
        individual_percsim.append(sample_percsim)
        individual_ssim.append(sample_ssim)
        individual_psnr.append(sample_psnr)

    if take_every_other:
        assert False, "Do this later, after specifying topk to get proper results"
        n_valuespercsim = []
        n_valuesssim = []
        n_valuespsnr = []
        for i in range(0, len(values_percsim) // 2):
            n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]

        values_percsim = n_valuespercsim
        values_ssim = n_valuesssim
        values_psnr = n_valuespsnr

    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))

    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))

    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))

    individual_percsim = np.array(individual_percsim)
    individual_psnr = np.array(individual_psnr)
    individual_ssim = np.array(individual_ssim)

    return {
        "avg_of_best": {
            "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
            "PSNR": [float(avg_psnr), float(std_psnr)],
            "SSIM": [float(avg_ssim), float(std_ssim)],
        },
        "individual": {
            "PSIM": individual_percsim,
            "PSNR": individual_psnr,
            "SSIM": individual_ssim,
        },
    }


if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--folder", type=str, default="")
    args.add_argument("--pred_image", type=str, default="")
    args.add_argument("--target_image", type=str, default="")
    args.add_argument("--take_every_other", action="store_true", default=False)
    args.add_argument("--output_file", type=str, default="")

    opts = args.parse_args()

    folder = opts.folder
    pred_img = opts.pred_image
    tgt_img = opts.target_image

    results = compute_perceptual_similarity(
        folder, pred_img, tgt_img, opts.take_every_other
    )

    f = open(opts.output_file, "w")
    for key in results:
        print("%s for %s: \n" % (key, opts.folder))
        print("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))

        f.write("%s for %s: \n" % (key, opts.folder))
        f.write("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))

    f.close()
147
extern/ldm_zero123/modules/evaluate/frechet_video_distance.py
vendored
Executable file
@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Minimal Reference implementation for the Frechet Video Distance (FVD).

FVD is a metric for the quality of video generation models. It is inspired by
the FID (Frechet Inception Distance) used for images, but uses a different
embedding to be better suited for videos.
"""

from __future__ import absolute_import, division, print_function

import six
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub


def preprocess(videos, target_resolution):
    """Runs some preprocessing on the videos for I3D model.

    Args:
      videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
        preprocessed. We don't care about the specific dtype of the videos, it can
        be anything that tf.image.resize_bilinear accepts. Values are expected to
        be in the range 0-255.
      target_resolution: (width, height): target video resolution

    Returns:
      videos: <float32>[batch_size, num_frames, height, width, depth]
    """
    videos_shape = list(videos.shape)
    all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
    resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
    target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
    output_videos = tf.reshape(resized_videos, target_shape)
    scaled_videos = 2.0 * tf.cast(output_videos, tf.float32) / 255.0 - 1
    return scaled_videos


def _is_in_graph(tensor_name):
    """Checks whether a given tensor exists in the graph."""
    try:
        tf.get_default_graph().get_tensor_by_name(tensor_name)
    except KeyError:
        return False
    return True


def create_id3_embedding(videos, warmup=False, batch_size=16):
    """Embeds the given videos using the Inflated 3D Convolution network.

    Downloads the graph of the I3D from tf.hub and adds it to the graph on the
    first call.

    Args:
      videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
        Expected range is [-1, 1].

    Returns:
      embedding: <float32>[batch_size, embedding_size]. embedding_size depends
        on the model used.

    Raises:
      ValueError: when a provided embedding_layer is not supported.
    """

    # batch_size = 16
    module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"

    # Making sure that we import the graph separately for
    # each different input video tensor.
    module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(videos.name).replace(
        ":", "_"
    )

    assert_ops = [
        tf.Assert(
            tf.reduce_max(videos) <= 1.001, ["max value in frame is > 1", videos]
        ),
        tf.Assert(
            tf.reduce_min(videos) >= -1.001, ["min value in frame is < -1", videos]
        ),
        tf.assert_equal(
            tf.shape(videos)[0],
            batch_size,
            ["invalid frame batch size: ", tf.shape(videos)],
            summarize=6,
        ),
    ]
    with tf.control_dependencies(assert_ops):
        videos = tf.identity(videos)

    module_scope = "%s_apply_default/" % module_name

    # To check whether the module has already been loaded into the graph, we look
    # for a given tensor name. If this tensor name exists, we assume the function
    # has been called before and the graph was imported. Otherwise we import it.
    # Note: in theory, the tensor could exist, but have wrong shapes.
    # This will happen if create_id3_embedding is called with a frames_placeholder
    # of wrong size/batch size, because even though that will throw a tf.Assert
    # on graph-execution time, it will insert the tensor (with wrong shape) into
    # the graph. This is why we need the following assert.
    if warmup:
        video_batch_size = int(videos.shape[0])
        assert video_batch_size in [
            batch_size,
            -1,
            None,
        ], f"Invalid batch size {video_batch_size}"
    tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
    if not _is_in_graph(tensor_name):
        i3d_model = hub.Module(module_spec, name=module_name)
        i3d_model(videos)

    # gets the kinetics-i3d-400-logits layer
    tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
    tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
    return tensor


def calculate_fvd(real_activations, generated_activations):
    """Returns a list of ops that compute metrics as funcs of activations.

    Args:
      real_activations: <float32>[num_samples, embedding_size]
      generated_activations: <float32>[num_samples, embedding_size]

    Returns:
      A scalar that contains the requested FVD.
    """
    return tfgan.eval.frechet_classifier_distance_from_activations(
        real_activations, generated_activations
    )
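

# Added, hedged end-to-end sketch (not in the original file): how the helpers
# above compose, mirroring the reference FVD example. The shapes (16 videos,
# 15 frames, 64x64) and the TF1-style session setup are assumptions.
def _fvd_demo():
    with tf.Graph().as_default():
        real = tf.zeros([16, 15, 64, 64, 3])
        fake = tf.ones([16, 15, 64, 64, 3]) * 255.0
        result = calculate_fvd(
            create_id3_embedding(preprocess(real, (224, 224))),
            create_id3_embedding(preprocess(fake, (224, 224))),
        )
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.tables_initializer())
            return sess.run(result)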
118
extern/ldm_zero123/modules/evaluate/ssim.py
vendored
Executable file
@ -0,0 +1,118 @@
# MIT Licence

# Methods to predict the SSIM, taken from
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py

from math import exp

import torch
import torch.nn.functional as F
from torch.autograd import Variable


def gaussian(window_size, sigma):
    gauss = torch.Tensor(
        [
            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
            for x in range(window_size)
        ]
    )
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(
        _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    )
    return window


def _ssim(img1, img2, window, window_size, channel, mask=None, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = (
        F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    )
    sigma2_sq = (
        F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    )
    sigma12 = (
        F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
        - mu1_mu2
    )

    C1 = (0.01) ** 2
    C2 = (0.03) ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    if not (mask is None):
        b = mask.size(0)
        ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
        ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(dim=1).clamp(
            min=1
        )
        return ssim_map

    # (removed a stray `import pdb` / `pdb.set_trace` debugging leftover here)

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2, mask=None):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(
            img1,
            img2,
            window,
            self.window_size,
            channel,
            mask,
            self.size_average,
        )


def ssim(img1, img2, window_size=11, mask=None, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, mask, size_average)
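

# Added, hedged sanity sketch (not in the original file): the functional form
# and the module form should agree, and identical inputs score ~1.0.
def _ssim_demo():
    x = torch.rand(1, 3, 64, 64)
    return ssim(x, x.clone()), SSIM()(x, x.clone())  # both close to 1.0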
331
extern/ldm_zero123/modules/evaluate/torch_frechet_video_distance.py
vendored
Executable file
@ -0,0 +1,331 @@
# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
import glob
import hashlib
import html
import io
import multiprocessing as mp
import os
import re
import urllib
import urllib.request
from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
import requests
import scipy.linalg
import torch
from torchvision.io import read_video
from tqdm import tqdm

torch.set_grad_enabled(False)
from einops import rearrange
from nitro.util import isvideo


def compute_frechet_distance(mu_sample, sigma_sample, mu_ref, sigma_ref) -> float:
    print("Calculate frechet distance...")
    m = np.square(mu_sample - mu_ref).sum()
    s, _ = scipy.linalg.sqrtm(
        np.dot(sigma_sample, sigma_ref), disp=False
    )  # pylint: disable=no-member
    fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))

    return float(fid)


def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    mu = feats.mean(axis=0)  # [d]
    sigma = np.cov(feats, rowvar=False)  # [d, d]

    return mu, sigma

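
# Added, hedged sanity check (not in the original file): the distance is 0 for
# identical statistics and reduces to ||mu1 - mu2||^2 when only means differ.
def _frechet_sanity_check():
    feats = np.random.RandomState(0).randn(512, 8)
    mu, sigma = compute_stats(feats)
    assert abs(compute_frechet_distance(mu, sigma, mu, sigma)) < 1e-6
    return compute_frechet_distance(mu + 1.0, sigma, mu, sigma)  # ~8.0 (8 dims * 1^2)
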

def open_url(
    url: str,
    num_attempts: int = 10,
    verbose: bool = True,
    return_filename: bool = False,
) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert num_attempts >= 1

    # Doesn't look like a URL scheme, so interpret it as a local filename.
    if not re.match("^[a-z]+://", url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith("file://"):
        filename = urllib.parse.urlparse(url).path
        if re.match(r"^/[a-zA-Z]:", filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [
                                html.unescape(link)
                                for link in content_str.split('"')
                                if "export=download" in link
                            ]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError(
                                "Google Drive download quota exceeded -- please try again later"
                            )

                    match = re.search(
                        r'filename="([^"]*)"',
                        res.headers.get("Content-Disposition", ""),
                    )
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)


def load_video(ip):
    vid, *_ = read_video(ip)
    vid = rearrange(vid, "t h w c -> t c h w").to(torch.uint8)
    return vid


def get_data_from_str(input_str, nprc=None):
    assert os.path.isdir(
        input_str
    ), f'Specified input folder "{input_str}" is not a directory'
    vid_filelist = glob.glob(os.path.join(input_str, "*.mp4"))
    print(f"Found {len(vid_filelist)} videos in dir {input_str}")

    if nprc is None:
        try:
            nprc = mp.cpu_count()
        except NotImplementedError:
            print(
                "WARNING: cpu_count() not available, using only 1 cpu for video loading"
            )
            nprc = 1

    pool = mp.Pool(processes=nprc)

    vids = []
    for v in tqdm(
        pool.imap_unordered(load_video, vid_filelist),
        total=len(vid_filelist),
        desc="Loading videos...",
    ):
        vids.append(v)

    vids = torch.stack(vids, dim=0).float()

    return vids


def get_stats(stats):
    assert os.path.isfile(stats) and stats.endswith(
        ".npz"
    ), f"no stats found under {stats}"

    print(f"Using precomputed statistics under {stats}")
    stats = np.load(stats)
    stats = {key: stats[key] for key in stats.files}

    return stats


@torch.no_grad()
def compute_fvd(
    ref_input, sample_input, bs=32, ref_stats=None, sample_stats=None, nprc_load=None
):
    calc_stats = ref_stats is None or sample_stats is None

    if calc_stats:
        only_ref = sample_stats is not None
        only_sample = ref_stats is not None

        if isinstance(ref_input, str) and not only_sample:
            ref_input = get_data_from_str(ref_input, nprc_load)

        if isinstance(sample_input, str) and not only_ref:
            sample_input = get_data_from_str(sample_input, nprc_load)

        stats = compute_statistics(
            sample_input,
            ref_input,
            device="cuda" if torch.cuda.is_available() else "cpu",
            bs=bs,
            only_ref=only_ref,
            only_sample=only_sample,
        )

        if only_ref:
            stats.update(get_stats(sample_stats))
        elif only_sample:
            stats.update(get_stats(ref_stats))

    else:
        stats = get_stats(sample_stats)
        stats.update(get_stats(ref_stats))

    fvd = compute_frechet_distance(**stats)

    return {
        "FVD": fvd,
    }

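
# Added, hedged usage sketch (not in the original file): both inputs may be
# directories of .mp4 files; the paths and batch size are illustrative only.
def _compute_fvd_example():
    return compute_fvd("data/real_videos/", "data/generated_videos/", bs=16)["FVD"]
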

@torch.no_grad()
def compute_statistics(
    videos_fake,
    videos_real,
    device: str = "cuda",
    bs=32,
    only_ref=False,
    only_sample=False,
) -> Dict:
    detector_url = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
    detector_kwargs = dict(
        rescale=True, resize=True, return_features=True
    )  # Return raw features before the softmax layer.

    with open_url(detector_url, verbose=False) as f:
        detector = torch.jit.load(f).eval().to(device)

    assert not (
        only_sample and only_ref
    ), "only_ref and only_sample arguments are mutually exclusive"

    ref_embed, sample_embed = [], []

    info = f"Computing I3D activations for FVD score with batch size {bs}"

    if only_ref:
        if not isvideo(videos_real):
            # if not a video, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
            videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
            print(videos_real.shape)

        if videos_real.shape[0] % bs == 0:
            n_secs = videos_real.shape[0] // bs
        else:
            n_secs = videos_real.shape[0] // bs + 1

        videos_real = torch.tensor_split(videos_real, n_secs, dim=0)

        for ref_v in tqdm(videos_real, total=len(videos_real), desc=info):
            feats_ref = (
                detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
            )
            ref_embed.append(feats_ref)

    elif only_sample:
        if not isvideo(videos_fake):
            # if not a video, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
            videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
            print(videos_fake.shape)

        if videos_fake.shape[0] % bs == 0:
            n_secs = videos_fake.shape[0] // bs
        else:
            n_secs = videos_fake.shape[0] // bs + 1

        # bugfix: the original split and iterated videos_real here, which is
        # unused (and may not even be a tensor) in the only_sample branch
        videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)

        for sample_v in tqdm(videos_fake, total=len(videos_fake), desc=info):
            feats_sample = (
                detector(sample_v.to(device).contiguous(), **detector_kwargs)
                .cpu()
                .numpy()
            )
            sample_embed.append(feats_sample)

    else:
        if not isvideo(videos_real):
            # if not a video, we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0, 255]
            videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()

        if not isvideo(videos_fake):
            videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()

        if videos_fake.shape[0] % bs == 0:
            n_secs = videos_fake.shape[0] // bs
        else:
            n_secs = videos_fake.shape[0] // bs + 1

        videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
        videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)

        for ref_v, sample_v in tqdm(
            zip(videos_real, videos_fake), total=len(videos_fake), desc=info
        ):
            # print(ref_v.shape)
            # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
            # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)

            feats_sample = (
                detector(sample_v.to(device).contiguous(), **detector_kwargs)
                .cpu()
                .numpy()
            )
            feats_ref = (
                detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
            )
            sample_embed.append(feats_sample)
            ref_embed.append(feats_ref)

    out = dict()
    if len(sample_embed) > 0:
        sample_embed = np.concatenate(sample_embed, axis=0)
        mu_sample, sigma_sample = compute_stats(sample_embed)
        out.update({"mu_sample": mu_sample, "sigma_sample": sigma_sample})

    if len(ref_embed) > 0:
        ref_embed = np.concatenate(ref_embed, axis=0)
        mu_ref, sigma_ref = compute_stats(ref_embed)
        out.update({"mu_ref": mu_ref, "sigma_ref": sigma_ref})

    return out
6
extern/ldm_zero123/modules/image_degradation/__init__.py
vendored
Executable file
@ -0,0 +1,6 @@
from extern.ldm_zero123.modules.image_degradation.bsrgan import (
    degradation_bsrgan_variant as degradation_fn_bsr,
)
from extern.ldm_zero123.modules.image_degradation.bsrgan_light import (
    degradation_bsrgan_variant as degradation_fn_bsr_light,
)
809
extern/ldm_zero123/modules/image_degradation/bsrgan.py
vendored
Executable file
@ -0,0 +1,809 @@
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""

import random
from functools import partial

import albumentations
import cv2
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth

import extern.ldm_zero123.modules.image_degradation.utils_image as util


def modcrop_np(img, sf):
    """
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    """
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[: w - w % sf, : h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """

    v = np.dot(
        np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
        np.array([1.0, 0.0]),
    )
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k

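
# Added, hedged illustration (not in the original file): with l1 != l2 the
# kernel is elongated along the theta direction; l1 == l2 recovers the
# isotropic case. The kernel is normalized to sum to 1.
def _anisotropic_kernel_demo():
    k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=10, l2=2)
    return k.shape, float(k.sum())  # (15, 15), ~1.0
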

def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def blur(x, k):
    """
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    """
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x


def gen_kernel(
    k_size=np.array([15, 15]),
    scale_factor=np.array([4, 4]),
    min_var=0.6,
    max_var=10.0,
    noise_level=0,
):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel


def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # np.finfo replaces the removed scipy.finfo alias
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    """
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    """
    if filter_type == "gaussian":
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == "laplacian":
        return fspecial_laplacian(*args, **kwargs)

"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||

def bicubic_degradation(x, sf=3):
    """
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    """
    x = util.imresize_np(x, scale=1 / sf)
    return x


def srmd_degradation(x, k, sf=3):
    """blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    """
    x = ndimage.filters.convolve(
        x, np.expand_dims(k, axis=2), mode="wrap"
    )  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x


def dpsr_degradation(x, k, sf=3):
    """bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    """
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
    return x


def classical_degradation(x, k, sf=3):
    """blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    """
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]


def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Residual threshold on the 0-255 scale. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype("float32")
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img


def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=2 * random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2,
        )
    else:
        k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random())
    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")

    return img


def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(
        img,
        (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
        interpolation=random.choice([1, 2, 3]),
    )
    img = np.clip(img, 0.0, 1.0)

    return img


# Older in-place (+=) version, kept commented out; see the pickle note below.
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
#     noise_level = random.randint(noise_level1, noise_level2)
#     rnum = np.random.rand()
#     if rnum > 0.6:  # add color Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
#     elif rnum < 0.4:  # add grayscale Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
#     else:  # add noise
#         L = noise_level2 / 255.
#         D = np.diag(np.random.rand(3))
#         U = orth(np.random.rand(3, 3))
#         conv = np.dot(np.dot(np.transpose(U), D), U)
#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
#     img = np.clip(img, 0.0, 1.0)
#     return img


def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
            np.float32
        )
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(
            0, noise_level / 255.0, (*img.shape[:2], 1)
        ).astype(np.float32)
    else:  # add noise
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal(
            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
        ).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
            np.float32
        )
    elif rnum < 0.4:
        img += img * np.random.normal(
            0, noise_level / 255.0, (*img.shape[:2], 1)
        ).astype(np.float32)
    else:
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal(
            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
        ).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img

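
# Added, hedged note (not in the original file): the three branches above are
# per-pixel color noise, grayscale noise broadcast over channels, and
# channel-correlated noise with covariance |L^2 * U^T D U|. A tiny sketch:
def _gaussian_noise_demo():
    img = np.full((32, 32, 3), 0.5, dtype=np.float32)
    noisy = add_Gaussian_noise(img.copy(), noise_level1=5, noise_level2=5)
    return float(np.abs(noisy - img).mean())  # small, on the order of 5/255
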

def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.0
    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
        noise_gray = (
            np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        )
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img


def add_JPEG_noise(img):
    quality_factor = random.randint(30, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode(
        ".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
    )
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img


def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[
        rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :
    ]
    return lq, hq


def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize X lq_patchsize X C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) X (lq_patchsize x sf) X C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f"img size ({h1}X{w1}) is too small!")

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(
                img,
                (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(
                    img,
                    (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.filters.convolve(
                    img, np.expand_dims(k_shifted, axis=2), mode="mirror"
                )
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(
                img,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq


# todo: no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: HXWXC, [0, 255], uint8
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    example: dict with key "image": the low-quality image, uint8, downscaled by sf
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image,
                    (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.filters.convolve(
                    image, np.expand_dims(k_shifted, axis=2), mode="mirror"
                )
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(
                image,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example

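
# Added, hedged usage sketch (not in the original file): the variant takes a
# uint8 HWC RGB image and returns {"image": uint8 LR image}; the input size
# is illustrative only.
def _degradation_variant_demo():
    hr = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    lr = degradation_bsrgan_variant(hr, sf=4)["image"]
    return lr.shape  # roughly (64, 64, 3)
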

# TODO: in case there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(
    img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None
):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    shuffle_prob: probability of shuffling the degradation order
    use_sharp: sharpening the img
    Returns
    -------
    img: low-quality patch, size: lq_patchsize X lq_patchsize X C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) X (lq_patchsize x sf) X C, range: [0, 1]
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f"img size ({h1}X{w1}) is too small!")

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print("check the shuffle!")

    # resize to desired size
    img = cv2.resize(
        img,
        (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
        interpolation=random.choice([1, 2, 3]),
    )

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq


if __name__ == "__main__":
    print("hey")
    img = util.imread_uint("utils/test.png", 3)
    print(img)
    img = util.uint2single(img)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img  # fix: img_hq was referenced below but never defined
        # fix: degradation_bsrgan_variant returns a dict, so unpack "image";
        # it also expects uint8 input, hence the round trip through single2uint
        img_lq = deg_fn(util.single2uint(img))["image"]
        img_lq = util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(
            max_size=h, interpolation=cv2.INTER_CUBIC
        )(image=img)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(
            util.single2uint(img_lq),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        lq_bicubic_nearest = cv2.resize(
            util.single2uint(img_lq_bicubic),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        img_concat = np.concatenate(
            [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
        )
        util.imsave(img_concat, str(i) + ".png")
720
extern/ldm_zero123/modules/image_degradation/bsrgan_light.py
vendored
Executable file
@ -0,0 +1,720 @@
# -*- coding: utf-8 -*-
import random
from functools import partial

import albumentations
import cv2
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth

import extern.ldm_zero123.modules.image_degradation.utils_image as util

"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""


def modcrop_np(img, sf):
    """
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    """
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[: w - w % sf, : h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """

    v = np.dot(
        np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
        np.array([1.0, 0.0]),
    )
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k

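# Example (illustrative): a random anisotropic blur kernel; gm_blur_kernel
# (below) normalizes it, so it sums to 1.
#   k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=2)
#   assert k.shape == (15, 15) and abs(k.sum() - 1.0) < 1e-6
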
def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def blur(x, k):
    """
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    """
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x

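# Example (illustrative): blur() expects batched tensors; the kernel is
# broadcast over channels via a grouped convolution, so spatial size is kept.
#   k = torch.ones(1, 1, 5, 5) / 25.0          # Nx1xhxw box kernel
#   y = blur(torch.rand(1, 3, 64, 64), k)      # -> shape (1, 3, 64, 64)
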
def gen_kernel(
    k_size=np.array([15, 15]),
    scale_factor=np.array([4, 4]),
    min_var=0.6,
    max_var=10.0,
    noise_level=0,
):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel


def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    """
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    """
    if filter_type == "gaussian":
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == "laplacian":
        return fspecial_laplacian(*args, **kwargs)

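# Example (illustrative): MATLAB-style filter construction via fspecial.
#   g = fspecial("gaussian", 15, 1.6)   # 15x15 Gaussian, sums to 1
#   l = fspecial("laplacian", 0.2)      # 3x3 discrete Laplacian
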
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
"""
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
"""
|
||||
x = util.imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
"""blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
"""
|
||||
x = ndimage.convolve(
|
||||
x, np.expand_dims(k, axis=2), mode="wrap"
|
||||
) # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
"""bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
"""
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
"""blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
"""
|
||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
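# The three pipelines above differ only in operator order (illustrative):
#   y_srmd      = srmd_degradation(x, k, sf)       # blur, then bicubic down
#   y_dpsr      = dpsr_degradation(x, k, sf)       # bicubic down, then blur
#   y_classical = classical_degradation(x, k, sf)  # blur, then sf-fold subsample
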
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening, borrowed from Real-ESRGAN.
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask.
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): residual threshold on the [0, 255] scale. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype("float32")
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img


def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf

    wd2 = wd2 / 4
    wd = wd / 4

    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2
        )
    else:
        k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random())
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror")

    return img


def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(
        img,
        (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
        interpolation=random.choice([1, 2, 3]),
    )
    img = np.clip(img, 0.0, 1.0)

    return img


# In-place variant kept for reference; the active version below avoids +=
# so the input array is never mutated in place.
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
#     noise_level = random.randint(noise_level1, noise_level2)
#     rnum = np.random.rand()
#     if rnum > 0.6:  # add color Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
#     elif rnum < 0.4:  # add grayscale Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
#     else:  # add noise
#         L = noise_level2 / 255.
#         D = np.diag(np.random.rand(3))
#         U = orth(np.random.rand(3, 3))
#         conv = np.dot(np.dot(np.transpose(U), D), U)
#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
#     img = np.clip(img, 0.0, 1.0)
#     return img


def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
            np.float32
        )
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(
            0, noise_level / 255.0, (*img.shape[:2], 1)
        ).astype(np.float32)
    else:  # add channel-correlated noise
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal(
            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
        ).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img

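# add_Gaussian_noise draws one of three noise types per call: color noise
# (p = 0.4), grayscale noise shared across channels (p = 0.4), or
# channel-correlated noise with covariance U^T D U (p = 0.2). Example
# (illustrative):
#   noisy = add_Gaussian_noise(np.random.rand(64, 64, 3).astype(np.float32))
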
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
            np.float32
        )
    elif rnum < 0.4:
        img += img * np.random.normal(
            0, noise_level / 255.0, (*img.shape[:2], 1)
        ).astype(np.float32)
    else:
        L = noise_level2 / 255.0
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal(
            [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
        ).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img


def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.0
    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
        noise_gray = (
            np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        )
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img


def add_JPEG_noise(img):
    quality_factor = random.randint(80, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode(
        ".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
    )
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img

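# Example (illustrative): one JPEG round-trip at a random quality in [80, 95].
# The input is assumed RGB in [0, 1]; the helper converts to uint8 BGR for
# cv2.imencode and back.
#   jpeg = add_JPEG_noise(np.random.rand(64, 64, 3).astype(np.float32))
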
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[
        rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :
    ]
    return lq, hq


def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f"img size ({h1}X{w1}) is too small!")

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(
                img,
                (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(
                    img,
                    (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(
                    img, np.expand_dims(k_shifted, axis=2), mode="mirror"
                )
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3
            img = cv2.resize(
                img,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq

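# Usage sketch (illustrative): an LQ/HQ training pair from the full model.
#   hq = util.uint2single(util.imread_uint("utils/test.png", 3))
#   lq, hq_patch = degradation_bsrgan(hq, sf=4, lq_patchsize=72)
#   # lq: 72x72x3 in [0, 1]; hq_patch: 288x288x3
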
# TODO: no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]),
            )
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = (
            shuffle_order[idx2],
            shuffle_order[idx1],
        )

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)

        # elif i == 1:  # the second blur pass is disabled in this light variant
        #     image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image,
                    (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                    interpolation=random.choice([1, 2, 3]),
                )
            else:
                k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.convolve(
                    image, np.expand_dims(k_shifted, axis=2), mode="mirror"
                )
                image = image[0::sf, 0::sf, ...]  # nearest downsampling

            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(
                image,
                (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]),
            )
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # elif i == 6:  # processed camera sensor noise, disabled here
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example

if __name__ == "__main__":
    print("hey")
    img = util.imread_uint("utils/test.png", 3)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img
        img_lq = deg_fn(img)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(
            max_size=h, interpolation=cv2.INTER_CUBIC
        )(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(
            util.single2uint(img_lq),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        lq_bicubic_nearest = cv2.resize(
            util.single2uint(img_lq_bicubic),
            (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
            interpolation=0,
        )
        img_concat = np.concatenate(
            [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
        )
        util.imsave(img_concat, str(i) + ".png")
BIN
extern/ldm_zero123/modules/image_degradation/utils/test.png
vendored
Executable file
Binary file not shown. (Size: 431 KiB)
988
extern/ldm_zero123/modules/image_degradation/utils_image.py
vendored
Executable file
@ -0,0 +1,988 @@
import math
import os
import random
from datetime import datetime

import cv2
import numpy as np
import torch
from torchvision.utils import make_grid

# import matplotlib.pyplot as plt  # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


"""
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
"""


IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tif",
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_timestamp():
    return datetime.now().strftime("%y%m%d-%H%M%S")


def imshow(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt  # lazy import; the module-level one is disabled above

    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


def surf(Z, cmap="rainbow", figsize=None):
    import matplotlib.pyplot as plt  # lazy import; the module-level one is disabled above

    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection="3d")

    w, h = Z.shape[:2]
    xx = np.arange(0, w, 1)
    yy = np.arange(0, h, 1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    # ax3.contour(X, Y, Z, zdim='z', offset=-2, cmap=cmap)
    plt.show()


"""
# --------------------------------------------
# get image paths
# --------------------------------------------
"""


def get_image_paths(dataroot):
    paths = None  # return None if dataroot is None
    if dataroot is not None:
        paths = sorted(_get_paths_from_images(dataroot))
    return paths


def _get_paths_from_images(path):
    assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, "{:s} has no valid image file".format(path)
    return images


"""
# --------------------------------------------
# split large images into small images
# --------------------------------------------
"""


def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    w, h = img.shape[:2]
    patches = []
    if w > p_max and h > p_max:
        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))
        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
        w1.append(w - p_size)
        h1.append(h - p_size)
        for i in w1:
            for j in h1:
                patches.append(img[i : i + p_size, j : j + p_size, :])
    else:
        patches.append(img)

    return patches


def imssave(imgs, img_path):
    """
    imgs: list, N images of size WxHxC
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))

    for i, img in enumerate(imgs):
        if img.ndim == 3:
            img = img[:, :, [2, 1, 0]]
        new_path = os.path.join(
            os.path.dirname(img_path), img_name + str("_s{:04d}".format(i)) + ".png"
        )
        cv2.imwrite(new_path, img)


def split_imageset(
    original_dataroot,
    taget_dataroot,
    n_channels=3,
    p_size=800,
    p_overlap=96,
    p_max=1000,
):
    """
    Split the large images from original_dataroot into small overlapped images of size
    (p_size)x(p_size) and save them into taget_dataroot; only images larger than
    (p_max)x(p_max) will be split.
    Args:
        original_dataroot: source image folder
        taget_dataroot: destination folder
        p_size: size of small images
        p_overlap: overlap size; the patch size used in training is a good choice
        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
    """
    paths = get_image_paths(original_dataroot)
    for img_path in paths:
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
        # if original_dataroot == taget_dataroot:
        #     del img_path


"""
# --------------------------------------------
# makedir
# --------------------------------------------
"""


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def mkdirs(paths):
    if isinstance(paths, str):
        mkdir(paths)
    else:
        for path in paths:
            mkdir(path)


def mkdir_and_rename(path):
    if os.path.exists(path):
        new_name = path + "_archived_" + get_timestamp()
        print("Path already exists. Renaming it to [{:s}]".format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)


"""
# --------------------------------------------
# read image from path
# opencv is fast, but reads BGR numpy images
# --------------------------------------------
"""


# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    # input: path
    # output: HxWx3 (RGB or GGG), or HxWx1 (G)
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img


# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


def imwrite(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


# --------------------------------------------
# get single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    # read image by cv2
    # return: Numpy float32, HWC, BGR, [0, 1]
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = img.astype(np.float32) / 255.0
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <--->  numpy(uint)
# numpy(single) <--->  tensor
# numpy(uint)   <--->  tensor
# --------------------------------------------
"""


# --------------------------------------------
# numpy(single) [0, 1] <--->  numpy(uint)
# --------------------------------------------


def uint2single(img):
    return np.float32(img / 255.0)


def single2uint(img):
    return np.uint8((img.clip(0, 1) * 255.0).round())

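# Round-trip example (illustrative):
#   x = uint2single(np.array([0, 128, 255], dtype=np.uint8))  # [0., 0.502, 1.]
#   single2uint(x)                                            # array([0, 128, 255], dtype=uint8)
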
def uint162single(img):
    return np.float32(img / 65535.0)


def single2uint16(img):
    return np.uint16((img.clip(0, 1) * 65535.0).round())


# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <--->  tensor
# --------------------------------------------


# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .permute(2, 0, 1)
        .float()
        .div(255.0)
        .unsqueeze(0)
    )


# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return (
        torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
    )


# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())


# --------------------------------------------
# numpy(single) (HxWxC) <--->  tensor
# --------------------------------------------


# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()


# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .permute(2, 0, 1)
        .float()
        .unsqueeze(0)
    )


# convert torch tensor to single
def tensor2single(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))

    return img


# convert torch tensor to single
def tensor2single3(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    elif img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return img


def single2tensor5(img):
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .permute(2, 0, 1, 3)
        .float()
        .unsqueeze(0)
    )


def single32tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)


def single42tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()


# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    """
    tensor = (
        tensor.squeeze().float().cpu().clamp_(*min_max)
    )  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0, 1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            "Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(
                n_dim
            )
        )
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)


"""
# --------------------------------------------
# Augmentation, flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
"""


def augment_img(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)"""
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))


def augment_img_tensor4(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)"""
    if mode == 0:
        return img
    elif mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    elif mode == 2:
        return img.flip([2])
    elif mode == 3:
        return img.rot90(3, [2, 3])
    elif mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    elif mode == 5:
        return img.rot90(1, [2, 3])
    elif mode == 6:
        return img.rot90(2, [2, 3])
    elif mode == 7:
        return img.rot90(3, [2, 3]).flip([2])


def augment_img_tensor(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)"""
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)

    return img_tensor.type_as(img)


def augment_img_np3(img, mode=0):
    if mode == 0:
        return img
    elif mode == 1:
        return img.transpose(1, 0, 2)
    elif mode == 2:
        return img[::-1, :, :]
    elif mode == 3:
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 4:
        return img[:, ::-1, :]
    elif mode == 5:
        img = img[:, ::-1, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 6:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        return img
    elif mode == 7:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img


def augment_imgs(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]


"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
"""


def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r, :]
    else:
        raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim))
    return img


def shave(img_in, border=0):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    h, w = img.shape[:2]
    img = img[border : h - border, border : w - border]
    return img


"""
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
"""


def rgb2ycbcr(img, only_y=True):
    """same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [65.481, -37.797, 112.0],
                [128.553, -74.203, -93.786],
                [24.966, 112.0, -18.214],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    """same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    rlt = np.matmul(
        img,
        [
            [0.00456621, 0.00456621, 0.00456621],
            [0, -0.00153632, 0.00791071],
            [0.00625893, -0.00318811, 0],
        ],
    ) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    """bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [24.966, 112.0, -18.214],
                [128.553, -74.203, -93.786],
                [65.481, -37.797, 112.0],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == "gray":  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == "y":  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == "RGB":  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list


"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
"""


# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    # img1 and img2 have range [0, 255]
    # img1 = img1.squeeze()
    # img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError("Input images must have the same dimensions.")
    h, w = img1.shape[:2]
    img1 = img1[border : h - border, border : w - border]
    img2 = img2[border : h - border, border : w - border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float("inf")
    return 20 * math.log10(255.0 / math.sqrt(mse))

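# Worked example (illustrative): a uniform error of one gray level gives
# mse = 1, so PSNR = 20 * log10(255 / 1) ≈ 48.13 dB.
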
# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    """calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    """
    # img1 = img1.squeeze()
    # img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError("Input images must have the same dimensions.")
    h, w = img1.shape[:2]
    img1 = img1[border : h - border, border : w - border]
    img2 = img2[border : h - border, border : w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError("Wrong input image dimensions.")


def ssim(img1, img2):
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return ssim_map.mean()

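# Sanity check (illustrative): SSIM of an image with itself is 1.
#   x = (np.random.rand(32, 32) * 255).astype(np.float64)
#   assert abs(calculate_ssim(x, x) - 1.0) < 1e-12
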
"""
|
||||
# --------------------------------------------
|
||||
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
# matlab 'imresize' function, now only support 'bicubic'
|
||||
def cubic(x):
|
||||
absx = torch.abs(x)
|
||||
absx2 = absx**2
|
||||
absx3 = absx**3
|
||||
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
|
||||
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
|
||||
) * (((absx > 1) * (absx <= 2)).type_as(absx))
|
||||
|
||||
|
||||
def calculate_weights_indices(
|
||||
in_length, out_length, scale, kernel, kernel_width, antialiasing
|
||||
):
|
||||
if (scale < 1) and (antialiasing):
|
||||
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
||||
kernel_width = kernel_width / scale
|
||||
|
||||
# Output-space coordinates
|
||||
x = torch.linspace(1, out_length, out_length)
|
||||
|
||||
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
||||
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
||||
# space maps to 1.5 in input space.
|
||||
u = x / scale + 0.5 * (1 - 1 / scale)
|
||||
|
||||
# What is the left-most pixel that can be involved in the computation?
|
||||
left = torch.floor(u - kernel_width / 2)
|
||||
|
||||
# What is the maximum number of pixels that can be involved in the
|
||||
# computation? Note: it's OK to use an extra pixel here; if the
|
||||
# corresponding weights are all zero, it will be eliminated at the end
|
||||
# of this function.
|
||||
P = math.ceil(kernel_width) + 2
|
||||
|
||||
# The indices of the input pixels involved in computing the k-th output
|
||||
# pixel are in row k of the indices matrix.
|
||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
|
||||
0, P - 1, P
|
||||
).view(1, P).expand(out_length, P)
|
||||
|
||||
# The weights used to compute the k-th output pixel are in row k of the
|
||||
# weights matrix.
|
||||
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
||||
# apply cubic kernel
|
||||
if (scale < 1) and (antialiasing):
|
||||
weights = scale * cubic(distance_to_center * scale)
|
||||
else:
|
||||
weights = cubic(distance_to_center)
|
||||
# Normalize the weights matrix so that each row sums to 1.
|
||||
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
||||
weights = weights / weights_sum.expand(out_length, P)
|
||||
|
||||
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
||||
weights_zero_tmp = torch.sum((weights == 0), 0)
|
||||
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 1, P - 2)
|
||||
weights = weights.narrow(1, 1, P - 2)
|
||||
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 0, P - 2)
|
||||
weights = weights.narrow(1, 0, P - 2)
|
||||
weights = weights.contiguous()
|
||||
indices = indices.contiguous()
|
||||
sym_len_s = -indices.min() + 1
|
||||
sym_len_e = indices.max() - in_length
|
||||
indices = indices + sym_len_s - 1
|
||||
return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for tensor image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: pytorch tensor, CHW or HW [0,1]
|
||||
# output: CHW or HW [0,1] w/o round
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(0)
|
||||
in_C, in_H, in_W = img.size()
|
||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
||||
kernel_width = 4
|
||||
kernel = "cubic"
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing
|
||||
)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing
|
||||
)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
||||
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:, :sym_len_Hs, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[:, -sym_len_He:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = (
                img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
            )

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2


# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = "cubic"

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing
    )
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing
    )
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = (
                img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
            )

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()


if __name__ == "__main__":
    print("---")
    # img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)
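
A minimal usage sketch for `imresize_np`, assuming only numpy and the function above; the array sizes and scale factors are illustrative:

# Quick check of imresize_np on a random HWC float image in [0, 1].
import numpy as np

img = np.random.rand(64, 48, 3).astype(np.float32)  # HWC, values in [0, 1]
small = imresize_np(img, 1 / 4)  # -> (16, 12, 3), antialiased bicubic downscale
big = imresize_np(img, 2)        # -> (128, 96, 3) upscale
print(small.shape, big.shape)
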
1
extern/ldm_zero123/modules/losses/__init__.py
vendored
Executable file
@@ -0,0 +1 @@
from extern.ldm_zero123.modules.losses.contperceptual import LPIPSWithDiscriminator
153
extern/ldm_zero123/modules/losses/contperceptual.py
vendored
Executable file
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?


class LPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start,
        logvar_init=0.0,
        kl_weight=1.0,
        pixelloss_weight=1.0,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=1.0,
        perceptual_weight=1.0,
        use_actnorm=False,
        disc_conditional=False,
        disc_loss="hinge",
    ):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(
            input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
        ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(
                nll_loss, self.last_layer[0], retain_graph=True
            )[0]
            g_grads = torch.autograd.grad(
                g_loss, self.last_layer[0], retain_graph=True
            )[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        last_layer=None,
        cond=None,
        split="train",
        weights=None,
    ):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous(), cond), dim=1)
                )
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            loss = (
                weighted_nll_loss
                + self.kl_weight * kl_loss
                + d_weight * disc_factor * g_loss
            )

            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(
                    torch.cat((inputs.contiguous().detach(), cond), dim=1)
                )
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
                )

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log
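
The adaptive weight in `calculate_adaptive_weight` scales the GAN term so its gradient at the decoder's last layer matches the reconstruction gradient: d_weight = ||grad(nll)|| / (||grad(g)|| + 1e-4), clamped to [0, 1e4]. A self-contained toy sketch of the same rule, with a single random matrix standing in for the decoder's last layer:

import torch

last_layer = torch.nn.Parameter(torch.randn(4, 4))
x = torch.randn(4)
nll_loss = (last_layer @ x).pow(2).mean()  # stand-in reconstruction loss
g_loss = (last_layer @ x).abs().mean()     # stand-in generator loss

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.clamp(
    torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4), 0.0, 1e4
).detach()
print(d_weight)  # balances GAN gradient magnitude against reconstruction gradient
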
218
extern/ldm_zero123/modules/losses/vqperceptual.py
vendored
Executable file
@@ -0,0 +1,218 @@
import torch
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
from torch import nn

from extern.ldm_zero123.util import exists  # used by forward() below


def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
    loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
    loss_real = (weights * loss_real).sum() / weights.sum()
    loss_fake = (weights * loss_fake).sum() / weights.sum()
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight


def measure_perplexity(predicted_indices, n_embed):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use


def l1(x, y):
    return torch.abs(x - y)


def l2(x, y):
    return torch.pow((x - y), 2)


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start,
        codebook_weight=1.0,
        pixelloss_weight=1.0,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=1.0,
        perceptual_weight=1.0,
        use_actnorm=False,
        disc_conditional=False,
        disc_ndf=64,
        disc_loss="hinge",
        n_classes=None,
        perceptual_loss="lpips",
        pixel_loss="l1",
    ):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(
            input_nc=disc_in_channels,
            n_layers=disc_num_layers,
            use_actnorm=use_actnorm,
            ndf=disc_ndf,
        ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(
                nll_loss, self.last_layer[0], retain_graph=True
            )[0]
            g_grads = torch.autograd.grad(
                g_loss, self.last_layer[0], retain_graph=True
            )[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        codebook_loss,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        last_layer=None,
        cond=None,
        split="train",
        predicted_indices=None,
    ):
        if not exists(codebook_loss):
            codebook_loss = torch.tensor([0.0]).to(inputs.device)
        # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous(), cond), dim=1)
                )
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(
                    nll_loss, g_loss, last_layer=last_layer
                )
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            loss = (
                nll_loss
                + d_weight * disc_factor * g_loss
                + self.codebook_weight * codebook_loss.mean()
            )

            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/p_loss".format(split): p_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(
                        predicted_indices, self.n_classes
                    )
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(
                    torch.cat((inputs.contiguous().detach(), cond), dim=1)
                )
                logits_fake = self.discriminator(
                    torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
                )

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log
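
`adopt_weight` and `measure_perplexity` are pure functions, so their behavior is easy to verify in isolation; a short self-contained check against the definitions above (values illustrative):

import torch

# adopt_weight: the discriminator term is disabled (0.0) until `threshold` steps.
print(adopt_weight(0.8, global_step=100, threshold=500))   # 0.0 during warm-up
print(adopt_weight(0.8, global_step=1000, threshold=500))  # 0.8 afterwards

# measure_perplexity: perfectly uniform codebook usage gives perplexity ~ n_embed.
idx = torch.arange(8).repeat(16)  # each of 8 codes used equally often
perplexity, cluster_use = measure_perplexity(idx, n_embed=8)
print(perplexity, cluster_use)  # ~8.0, 8
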
705
extern/ldm_zero123/modules/x_transformer.py
vendored
Executable file
@@ -0,0 +1,705 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
from collections import namedtuple
from functools import partial
from inspect import isfunction

import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import einsum, nn

# constants

DEFAULT_DIM_HEAD = 64

Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])

LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"])


class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, x):
        n = torch.arange(x.shape[1], device=x.device)
        return self.emb(n)[None, :, :]


class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        t = (
            torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            + offset
        )
        sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb[None, :, :]


# helpers


def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def always(val):
    def inner(*args, **kwargs):
        return val

    return inner


def not_equals(val):
    def inner(x):
        return x != val

    return inner


def equals(val):
    def inner(x):
        return x == val

    return inner


def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max


# keyword argument helpers


def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))


def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, str):
    return str.startswith(prefix)


def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)


def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(
        partial(string_begins_with, prefix), d
    )
    kwargs_without_prefix = dict(
        map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
    )
    return kwargs_without_prefix, kwargs


# classes
class Scale(nn.Module):
    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        x, *rest = self.fn(x, **kwargs)
        return (x * self.value, *rest)


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        x, *rest = self.fn(x, **kwargs)
        return (x * self.g, *rest)


class ScaleNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class Residual(nn.Module):
    def forward(self, x, residual):
        return x + residual


class GRUGating(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        gated_output = self.gru(
            rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")
        )

        return gated_output.reshape_as(x)


# feedforward


class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


# attention.
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=DEFAULT_DIM_HEAD,
        heads=8,
        causal=False,
        mask=None,
        talking_heads=False,
        sparse_topk=None,
        use_entmax15=False,
        num_mem_kv=0,
        dropout=0.0,
        on_attn=False,
    ):
        super().__init__()
        if use_entmax15:
            raise NotImplementedError(
                "Check out entmax activation instead of softmax activation!"
            )
        self.scale = dim_head**-0.5
        self.heads = heads
        self.causal = causal
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        # self.attn_fn = entmax15 if use_entmax15 else F.softmax
        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
            if on_attn
            else nn.Linear(inner_dim, dim)
        )

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        rel_pos=None,
        sinusoidal_emb=None,
        prev_attn=None,
        mem=None,
    ):
        b, n, _, h, talking_heads, device = (
            *x.shape,
            self.heads,
            self.talking_heads,
            x.device,
        )
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(
                k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()
            )
            q_mask = rearrange(q_mask, "b i -> b () i ()")
            k_mask = rearrange(k_mask, "b j -> b () () j")
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(
                lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)
            )
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum(
                "b h i j, h k -> b k i j", dots, self.pre_softmax_proj
            ).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum(
                "b h i j, h k -> b k i j", attn, self.post_softmax_proj
            ).contiguous()

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates


class AttentionLayers(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads=8,
        causal=False,
        cross_attend=False,
        only_cross=False,
        use_scalenorm=False,
        use_rmsnorm=False,
        use_rezero=False,
        rel_pos_num_buckets=32,
        rel_pos_max_distance=128,
        position_infused_attn=False,
        custom_layers=None,
        sandwich_coef=None,
        par_ratio=None,
        residual_attn=False,
        cross_residual_attn=False,
        macaron=False,
        pre_norm=True,
        gate_residual=False,
        **kwargs,
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)

        dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        self.has_pos_emb = position_infused_attn
        self.pia_pos_emb = (
            FixedPositionalEmbedding(dim) if position_infused_attn else None
        )
        self.rotary_pos_emb = always(None)

        assert (
            rel_pos_num_buckets <= rel_pos_max_distance
        ), "number of relative position buckets must be less than the relative position max distance"
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ("a", "c", "f")
        elif cross_attend and only_cross:
            default_block = ("c", "f")
        else:
            default_block = ("a", "f")

        if macaron:
            default_block = ("f",) + default_block

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, "par ratio out of range"
            default_block = tuple(filter(not_equals("f"), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = (
                par_depth * 2 // 3
            )  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert (
                len(default_block) <= par_width
            ), "default block is too large for par_ratio"
            par_block = default_block + ("f",) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ("f",) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert (
                sandwich_coef > 0 and sandwich_coef <= depth
            ), "sandwich coefficient should be less than the depth"
            layer_types = (
                ("a",) * sandwich_coef
                + default_block * (depth - sandwich_coef)
                + ("f",) * sandwich_coef
            )
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals("a"), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == "a":
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == "c":
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == "f":
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f"invalid layer type {layer_type}")

            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            if gate_residual:
                residual_fn = GRUGating(dim)
            else:
                residual_fn = Residual()

            self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        mems=None,
        return_hiddens=False,
    ):
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
            zip(self.layer_types, self.layers)
        ):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == "a":
                hiddens.append(x)
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == "a":
                out, inter = block(
                    x,
                    mask=mask,
                    sinusoidal_emb=self.pia_pos_emb,
                    rel_pos=self.rel_pos,
                    prev_attn=prev_attn,
                    mem=layer_mem,
                )
            elif layer_type == "c":
                out, inter = block(
                    x,
                    context=context,
                    mask=mask,
                    context_mask=context_mask,
                    prev_attn=prev_cross_attn,
                )
            elif layer_type == "f":
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ("a", "c"):
                intermediates.append(inter)

            if layer_type == "a" and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == "c" and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens, attn_intermediates=intermediates
            )

            return x, intermediates

        return x


class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert "causal" not in kwargs, "cannot set causality on encoder"
        super().__init__(causal=False, **kwargs)


class TransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers,
        emb_dim=None,
        max_mem_len=0.0,
        emb_dropout=0.0,
        num_memory_tokens=None,
        tie_embedding=False,
        use_pos_emb=True,
    ):
        super().__init__()
        assert isinstance(
            attn_layers, AttentionLayers
        ), "attention layers must be one of Encoder or Decoder"

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = (
            AbsolutePositionalEmbedding(emb_dim, max_seq_len)
            if (use_pos_emb and not attn_layers.has_pos_emb)
            else always(0)
        )
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = (
            nn.Linear(dim, num_tokens)
            if not tie_embedding
            else lambda t: t @ self.token_emb.weight.t()
        )

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, "num_memory_tokens"):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
        self,
        x,
        return_embeddings=False,
        mask=None,
        return_mems=False,
        return_attn=False,
        mems=None,
        **kwargs,
    ):
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(
            x, mask=mask, mems=mems, return_hiddens=True, **kwargs
        )
        x = self.norm(x)

        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = (
                list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens)))
                if exists(mems)
                else hiddens
            )
            new_mems = list(
                map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
            )
            return out, new_mems

        if return_attn:
            attn_maps = list(
                map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
            )
            return out, attn_maps

        return out
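
`TransformerWrapper` combines token and positional embeddings with an `AttentionLayers` stack, which is roughly how the latent-diffusion text encoder drives it. A minimal smoke-test sketch against the classes above (all sizes illustrative):

import torch

model = TransformerWrapper(
    num_tokens=1000,
    max_seq_len=64,
    attn_layers=Encoder(dim=128, depth=2, heads=4),
)
tokens = torch.randint(0, 1000, (2, 64))
logits = model(tokens)                        # (2, 64, 1000) token logits
emb = model(tokens, return_embeddings=True)   # (2, 64, 128) final hidden states
print(logits.shape, emb.shape)
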
144
extern/ldm_zero123/thirdp/psp/helpers.py
vendored
Executable file
@@ -0,0 +1,144 @@
# https://github.com/eladrich/pixel2style2pixel

from collections import namedtuple

import torch
from torch.nn import (
    AdaptiveAvgPool2d,
    BatchNorm2d,
    Conv2d,
    MaxPool2d,
    Module,
    PReLU,
    ReLU,
    Sequential,
    Sigmoid,
)

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
    """A named tuple describing a ResNet block."""


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [
        Bottleneck(depth, depth, 1) for i in range(num_units - 1)
    ]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
    else:
        raise ValueError(
            "Invalid number of layers: {}. Must be one of [50, 100, 152]".format(
                num_layers
            )
        )
    return blocks


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels, channels // reduction, kernel_size=1, padding=0, bias=False
        )
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction, channels, kernel_size=1, padding=0, bias=False
        )
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
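
A quick way to see what `get_blocks` produces, run against the definitions above; for the 50-layer backbone the four stages hold 3 + 4 + 14 + 3 = 24 bottlenecks:

blocks = get_blocks(50)
print([len(stage) for stage in blocks])  # [3, 4, 14, 3]
print(blocks[1][0])  # Bottleneck(in_channel=64, depth=128, stride=2) - downsampling unit
print(blocks[1][1])  # Bottleneck(in_channel=128, depth=128, stride=1) - repeated unit
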
26
extern/ldm_zero123/thirdp/psp/id_loss.py
vendored
Executable file
@@ -0,0 +1,26 @@
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn

from extern.ldm_zero123.thirdp.psp.model_irse import Backbone


class IDFeatures(nn.Module):
    def __init__(self, model_path):
        super(IDFeatures, self).__init__()
        print("Loading ResNet ArcFace")
        self.facenet = Backbone(
            input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
        )
        self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def forward(self, x, crop=False):
        # Not sure of the image range here
        if crop:
            x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
            x = x[:, :, 35:223, 32:220]
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats
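
A hedged usage sketch for `IDFeatures`; the checkpoint path is a placeholder for the pretrained InsightFace ir_se-50 weights this module expects, not a file shipped with the repo:

import torch

id_net = IDFeatures("path/to/model_ir_se50.pth")  # placeholder checkpoint path
faces = torch.randn(2, 3, 256, 256)               # aligned face crops, NCHW
feats = id_net(faces, crop=True)                  # (2, 512) L2-normalized embeddings
sim = torch.nn.functional.cosine_similarity(feats[0], feats[1], dim=0)
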
118
extern/ldm_zero123/thirdp/psp/model_irse.py
vendored
Executable file
@@ -0,0 +1,118 @@
# https://github.com/eladrich/pixel2style2pixel

from torch.nn import (
    BatchNorm1d,
    BatchNorm2d,
    Conv2d,
    Dropout,
    Linear,
    Module,
    PReLU,
    Sequential,
)

from extern.ldm_zero123.thirdp.psp.helpers import (
    Flatten,
    bottleneck_IR,
    bottleneck_IR_SE,
    get_blocks,
    l2_norm,
)

"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ["ir", "ir_se"], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == "ir":
            unit_module = bottleneck_IR
        elif mode == "ir_se":
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)
        )
        if input_size == 112:
            self.output_layer = Sequential(
                BatchNorm2d(512),
                Dropout(drop_ratio),
                Flatten(),
                Linear(512 * 7 * 7, 512),
                BatchNorm1d(512, affine=affine),
            )
        else:
            self.output_layer = Sequential(
                BatchNorm2d(512),
                Dropout(drop_ratio),
                Flatten(),
                Linear(512 * 14 * 14, 512),
                BatchNorm1d(512, affine=affine),
            )

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(
                        bottleneck.in_channel, bottleneck.depth, bottleneck.stride
                    )
                )
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs a ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs a ir-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False
    )
    return model


def IR_152(input_size):
    """Constructs a ir-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_50(input_size):
    """Constructs a ir_se-50 model."""
    model = Backbone(
        input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_101(input_size):
    """Constructs a ir_se-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model


def IR_SE_152(input_size):
    """Constructs a ir_se-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False
    )
    return model
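
The factory functions only vary depth and mode; a self-contained shape check for the ir_se-50 variant with random weights (each 112x112 input maps to a unit-norm 512-d embedding):

import torch

net = IR_SE_50(112).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))
print(emb.shape, emb.norm())  # torch.Size([1, 512]), norm ~1 after l2_norm
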
249
extern/ldm_zero123/util.py
vendored
Executable file
@@ -0,0 +1,249 @@
import importlib
import os
import time
from inspect import isfunction

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
from torch import optim


def pil_rectangle_crop(im):
    width, height = im.size  # Get dimensions

    if width <= height:
        left = 0
        right = width
        top = (height - width) / 2
        bottom = (height + width) / 2
    else:
        top = 0
        bottom = height
        left = (width - height) / 2
        right = (width + height) / 2

    # Crop the center of the image
    im = im.crop((left, top, right, bottom))
    return im


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(
            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
        )

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


class AdamWwithEMAandWings(optim.Optimizer):
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(
        self,
        params,
        lr=1.0e-3,
        betas=(0.9, 0.999),
        eps=1.0e-8,  # TODO: check hyperparameters before using
        weight_decay=1.0e-2,
        amsgrad=False,
        ema_decay=0.9999,  # ema decay to match previous code
        ema_power=1.0,
        param_names=(),
    ):
        """AdamW that saves EMA versions of the parameters."""
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            ema_decay=ema_decay,
            ema_power=ema_power,
            param_names=param_names,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]
            ema_decay = group["ema_decay"]
            ema_power = group["ema_power"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )
                    # Exponential moving average of parameter values
                    state["param_exp_avg"] = p.detach().float().clone()

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                ema_params_with_grad.append(state["param_exp_avg"])

                if amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

            optim._functional.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=False,
            )

            cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(
                    param.float(), alpha=1 - cur_ema_decay
                )

        return loss
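
`instantiate_from_config` resolves a dotted "target" string with `get_obj_from_str` and forwards "params" as keyword arguments; a minimal self-contained example, with a stock class standing in for a real model config:

config = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
layer = instantiate_from_config(config)
print(layer)  # Linear(in_features=4, out_features=2, bias=True)
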
666
extern/zero123.py
vendored
Normal file
666
extern/zero123.py
vendored
Normal file
@ -0,0 +1,666 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import deprecate, is_accelerate_available, logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CLIPCameraProjection(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Projection layer for CLIP embedding and camera embedding.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
self.input_dim = self.embedding_dim + self.additional_embeddings
|
||||
self.output_dim = self.embedding_dim
|
||||
|
||||
self.proj = torch.nn.Linear(self.input_dim, self.output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
embedding: torch.FloatTensor,
|
||||
):
|
||||
"""
|
||||
The [`PriorTransformer`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, input_dim)`):
|
||||
The currently input embeddings.
|
||||
|
||||
Returns:
|
||||
The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`).
|
||||
"""
|
||||
proj_embedding = self.proj(embedding)
|
||||
return proj_embedding


class Zero123Pipeline(DiffusionPipeline):
    r"""
    Pipeline to generate variations from an input image using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    # TODO: feature_extractor is required to encode images (if they are in PIL format),
    # we should give a descriptive message if the pipeline doesn't have one.
    _optional_components = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        clip_camera_projection: CLIPCameraProjection,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warn(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(
            unet.config, "_diffusers_version"
        ) and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse(
            "0.9.0.dev0"
        )
        is_unet_sample_size_less_64 = (
            hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        )
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file."
            )
            deprecate(
                "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            clip_camera_projection=clip_camera_projection,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [
            self.unet,
            self.image_encoder,
            self.vae,
            self.safety_checker,
        ]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_image(
        self,
        image,
        elevation,
        azimuth,
        distance,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        clip_image_embeddings=None,
        image_camera_embeddings=None,
    ):
        dtype = next(self.image_encoder.parameters()).dtype

        if image_camera_embeddings is None:
            if image is None:
                assert clip_image_embeddings is not None
                image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype)
            else:
                if not isinstance(image, torch.Tensor):
                    image = self.feature_extractor(
                        images=image, return_tensors="pt"
                    ).pixel_values

                image = image.to(device=device, dtype=dtype)
                image_embeddings = self.image_encoder(image).image_embeds
                image_embeddings = image_embeddings.unsqueeze(1)

            bs_embed, seq_len, _ = image_embeddings.shape

            if isinstance(elevation, float):
                elevation = torch.as_tensor(
                    [elevation] * bs_embed, dtype=dtype, device=device
                )
            if isinstance(azimuth, float):
                azimuth = torch.as_tensor(
                    [azimuth] * bs_embed, dtype=dtype, device=device
                )
            if isinstance(distance, float):
                distance = torch.as_tensor(
                    [distance] * bs_embed, dtype=dtype, device=device
                )

            camera_embeddings = torch.stack(
                [
                    torch.deg2rad(elevation),
                    torch.sin(torch.deg2rad(azimuth)),
                    torch.cos(torch.deg2rad(azimuth)),
                    distance,
                ],
                dim=-1,
            )[:, None, :]
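
            # The camera conditioning is a 4-vector per image:
            # (elevation in radians, sin(azimuth), cos(azimuth), distance).
            # E.g. elevation=30.0, azimuth=90.0, distance=1.2 gives roughly
            # (0.5236, 1.0, 0.0, 1.2); encoding azimuth via sin/cos keeps the
            # conditioning continuous across the 0/360 degree wrap-around.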

            image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1)

            # project (image, camera) embeddings to the same dimension as clip embeddings
            image_embeddings = self.clip_camera_projection(image_embeddings)
        else:
            image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype)
            bs_embed, seq_len, _ = image_embeddings.shape

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(
                    image, output_type="pil"
                )
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(
                feature_extractor_input, return_tensors="pt"
            ).to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, image, height, width, callback_steps):
        # TODO: check image size or adjust image size to (height, width)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if (callback_steps is None) or (
            callback_steps is not None
            and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _get_latent_model_input(
        self,
        latents: torch.FloatTensor,
        image: Optional[
            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
        ],
        num_images_per_prompt: int,
        do_classifier_free_guidance: bool,
        image_latents: Optional[torch.FloatTensor] = None,
    ):
        if isinstance(image, PIL.Image.Image):
            image_pt = TF.to_tensor(image).unsqueeze(0).to(latents)
        elif isinstance(image, list):
            image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(
                latents
            )
        elif isinstance(image, torch.Tensor):
            image_pt = image
        else:
            image_pt = None

        if image_pt is None:
            assert image_latents is not None
            image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0)
        else:
            image_pt = image_pt * 2.0 - 1.0  # scale to [-1, 1]
            # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor
            # but zero123 was not trained this way
            image_pt = self.vae.encode(image_pt).latent_dist.mode()
            image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            latent_model_input = torch.cat(
                [
                    torch.cat([latents, latents], dim=0),
                    torch.cat([torch.zeros_like(image_pt), image_pt], dim=0),
                ],
                dim=1,
            )
        else:
            latent_model_input = torch.cat([latents, image_pt], dim=1)

        return latent_model_input
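
    # Note on the layout above: the U-Net consumes an 8-channel input, i.e. the
    # 4 noise-latent channels concatenated (dim=1) with 4 channels of the clean
    # image latents. Under classifier-free guidance the batch is doubled
    # (dim=0), with zeroed image latents serving as the unconditional half.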

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[
            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
        ] = None,
        elevation: Optional[Union[float, torch.FloatTensor]] = None,
        azimuth: Optional[Union[float, torch.FloatTensor]] = None,
        distance: Optional[Union[float, torch.FloatTensor]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        clip_image_embeddings: Optional[torch.FloatTensor] = None,
        image_camera_embeddings: Optional[torch.FloatTensor] = None,
        image_latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                configuration of
                [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor`.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the input `image`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        # TODO: check input elevation, azimuth, and distance
        # TODO: check image, clip_image_embeddings, image_latents
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        else:
            assert image_latents is not None
            assert (
                clip_image_embeddings is not None or image_camera_embeddings is not None
            )
            batch_size = image_latents.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        if isinstance(image, PIL.Image.Image) or isinstance(image, list):
            pil_image = image
        elif isinstance(image, torch.Tensor):
            pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
        else:
            pil_image = None
        image_embeddings = self._encode_image(
            pil_image,
            elevation,
            azimuth,
            distance,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            clip_image_embeddings,
            image_camera_embeddings,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        # num_channels_latents = self.unet.config.in_channels
        num_channels_latents = 4  # FIXME: hard-coded
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = self._get_latent_model_input(
                    latents,
                    image,
                    num_images_per_prompt,
                    do_classifier_free_guidance,
                    image_latents,
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type != "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, image_embeddings.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
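
# A minimal usage sketch (not part of the original file). The checkpoint id is
# a placeholder; any Zero123 export in diffusers format that bundles the
# `clip_camera_projection` sub-module should work:
#   pipe = Zero123Pipeline.from_pretrained("<zero123-diffusers-checkpoint>")  # hypothetical id
#   out = pipe(image=cond_image, elevation=30.0, azimuth=90.0, distance=1.2)
#   out.images[0].save("novel_view.png")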
450
gradio_app.py
Normal file
@ -0,0 +1,450 @@
import argparse
import glob
import os
import re
import signal
import subprocess
import tempfile
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import gradio as gr
import numpy as np
import psutil
import trimesh


def tail(f, window=20):
    # Returns the last `window` lines of file `f`.
    if window == 0:
        return []

    BUFSIZ = 1024
    f.seek(0, 2)
    remaining_bytes = f.tell()
    size = window + 1
    block = -1
    data = []

    while size > 0 and remaining_bytes > 0:
        if remaining_bytes - BUFSIZ > 0:
            # Seek back one whole BUFSIZ
            f.seek(block * BUFSIZ, 2)
            # read BUFFER
            bunch = f.read(BUFSIZ)
        else:
            # file too small, start from beginning
            f.seek(0, 0)
            # only read what was not read
            bunch = f.read(remaining_bytes)

        bunch = bunch.decode("utf-8")
        data.insert(0, bunch)
        size -= bunch.count("\n")
        remaining_bytes -= BUFSIZ
        block -= 1

    return "\n".join("".join(data).splitlines()[-window:])
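
# e.g. tail(open(log_file_path, "rb"), window=10) returns the last 10 lines of
# the (binary-mode) log file joined into a single string; this is how the UI
# below surfaces recent training output.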


@dataclass
class ExperimentStatus:
    pid: Optional[int] = None
    progress: str = ""
    log: str = ""
    output_image: Optional[str] = None
    output_video: Optional[str] = None
    output_mesh: Optional[str] = None

    def tolist(self):
        return [
            self.pid,
            self.progress,
            self.log,
            self.output_image,
            self.output_video,
            self.output_mesh,
        ]


EXP_ROOT_DIR = "outputs-gradio"
DEFAULT_PROMPT = "a delicious hamburger"
model_config = [
    ("DreamFusion (DeepFloyd-IF)", "configs/gradio/dreamfusion-if.yaml"),
    ("DreamFusion (Stable Diffusion)", "configs/gradio/dreamfusion-sd.yaml"),
    ("TextMesh (DeepFloyd-IF)", "configs/gradio/textmesh-if.yaml"),
    ("Fantasia3D (Stable Diffusion, Geometry Only)", "configs/gradio/fantasia3d.yaml"),
    ("SJC (Stable Diffusion)", "configs/gradio/sjc.yaml"),
    ("Latent-NeRF (Stable Diffusion)", "configs/gradio/latentnerf.yaml"),
]
model_choices = [m[0] for m in model_config]
model_name_to_config = {m[0]: m[1] for m in model_config}


def load_model_config(model_name):
    return open(model_name_to_config[model_name]).read()


def load_model_config_attrs(model_name):
    config_str = load_model_config(model_name)
    from threestudio.utils.config import load_config

    cfg = load_config(
        config_str,
        cli_args=[
            "name=dummy",
            "tag=dummy",
            "use_timestamp=false",
            f"exp_root_dir={EXP_ROOT_DIR}",
            "system.prompt_processor.prompt=placeholder",
        ],
        from_string=True,
    )
    return {
        "source": config_str,
        "guidance_scale": cfg.system.guidance.guidance_scale,
        "max_steps": cfg.trainer.max_steps,
    }


def on_model_selector_change(model_name):
    cfg = load_model_config_attrs(model_name)
    return [cfg["source"], cfg["guidance_scale"]]
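
# For example, selecting "DreamFusion (DeepFloyd-IF)" returns the raw YAML
# source plus that config's system.guidance.guidance_scale and
# trainer.max_steps; the concrete values depend on the YAML file.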


def get_current_status(process, trial_dir, alive_path):
    status = ExperimentStatus()

    status.pid = process.pid

    # write the current timestamp to the alive file
    # the watcher will know the last active time of this process from this timestamp
    if os.path.exists(os.path.dirname(alive_path)):
        with open(alive_path, "w") as alive_fp:
            alive_fp.write(str(time.time()))

    log_path = os.path.join(trial_dir, "logs")
    progress_path = os.path.join(trial_dir, "progress")
    save_path = os.path.join(trial_dir, "save")

    # read current progress from the progress file
    # the progress file is created by GradioCallback
    if os.path.exists(progress_path):
        status.progress = open(progress_path).read()
    else:
        status.progress = "Setting up everything ..."

    # read the last 10 lines of the log file
    if os.path.exists(log_path):
        status.log = tail(open(log_path, "rb"), window=10)
    else:
        status.log = ""

    # get the validation image and testing video if they exist
    if os.path.exists(save_path):
        images = glob.glob(os.path.join(save_path, "*.png"))
        steps = [
            int(re.match(r"it(\d+)-0\.png", os.path.basename(f)).group(1))
            for f in images
        ]
        images = sorted(list(zip(images, steps)), key=lambda x: x[1])
        if len(images) > 0:
            status.output_image = images[-1][0]

        videos = glob.glob(os.path.join(save_path, "*.mp4"))
        steps = [
            int(re.match(r"it(\d+)-test\.mp4", os.path.basename(f)).group(1))
            for f in videos
        ]
        videos = sorted(list(zip(videos, steps)), key=lambda x: x[1])
        if len(videos) > 0:
            status.output_video = videos[-1][0]

        export_dirs = glob.glob(os.path.join(save_path, "*export"))
        steps = [
            int(re.match(r"it(\d+)-export", os.path.basename(f)).group(1))
            for f in export_dirs
        ]
        export_dirs = sorted(list(zip(export_dirs, steps)), key=lambda x: x[1])
        if len(export_dirs) > 0:
            obj = glob.glob(os.path.join(export_dirs[-1][0], "*.obj"))
            if len(obj) > 0:
                # FIXME
                # seems the gr.Model3D cannot load our manually saved obj file
                # here we load the obj and save it to a temporary file using trimesh
                mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
                trimesh.load(obj[0]).export(mesh_path.name)
                status.output_mesh = mesh_path.name

    return status


def run(
    model_name: str,
    config: str,
    prompt: str,
    guidance_scale: float,
    seed: int,
    max_steps: int,
):
    # update status every 1 second
    status_update_interval = 1

    # save the config to a temporary file
    config_file = tempfile.NamedTemporaryFile()

    with open(config_file.name, "w") as f:
        f.write(config)

    # manually assign the output directory, name and tag so that we know the trial directory
    name = os.path.basename(model_name_to_config[model_name]).split(".")[0]
    tag = datetime.now().strftime("@%Y%m%d-%H%M%S")
    trial_dir = os.path.join(EXP_ROOT_DIR, name, tag)
    alive_path = os.path.join(trial_dir, "alive")

    # spawn the training process
    process = subprocess.Popen(
        f"python launch.py --config {config_file.name} --train --gpu 0 --gradio trainer.enable_progress_bar=false".split()
        + [
            f'name="{name}"',
            f'tag="{tag}"',
            f"exp_root_dir={EXP_ROOT_DIR}",
            "use_timestamp=false",
            f'system.prompt_processor.prompt="{prompt}"',
            f"system.guidance.guidance_scale={guidance_scale}",
            f"seed={seed}",
            f"trainer.max_steps={max_steps}",
        ]
    )

    # spawn the watcher process
    watch_process = subprocess.Popen(
        "python gradio_app.py watch".split()
        + ["--pid", f"{process.pid}", "--trial-dir", f"{trial_dir}"]
    )

    # update status (progress, log, image, video) every status_update_interval seconds
    # button status: Run -> Stop
    while process.poll() is None:
        time.sleep(status_update_interval)
        yield get_current_status(process, trial_dir, alive_path).tolist() + [
            gr.update(visible=False),
            gr.update(value="Stop", variant="stop", visible=True),
        ]

    # wait for the processes to finish
    process.wait()
    watch_process.wait()

    # update status one last time
    # button status: Stop / Reset -> Run
    status = get_current_status(process, trial_dir, alive_path)
    status.progress = "Finished."
    yield status.tolist() + [
        gr.update(value="Run", variant="primary", visible=True),
        gr.update(visible=False),
    ]


def stop_run(pid):
    # kill the process
    print(f"Trying to kill process {pid} ...")
    try:
        os.kill(pid, signal.SIGKILL)
    except Exception:
        print(f"Exception when killing process {pid}.")
    # button status: Stop -> Reset
    return [
        gr.update(value="Reset", variant="secondary", visible=True),
        gr.update(visible=False),
    ]


def launch(port, listen=False):
    with gr.Blocks(title="threestudio - Web Demo") as demo:
        with gr.Row():
            pid = gr.State()
            with gr.Column(scale=1):
                header = gr.Markdown(
                    """
                    # threestudio

                    - Select a model from the dropdown menu.
                    - Input a text prompt.
                    - Hit Run!
                    """
                )

                # model selection dropdown
                model_selector = gr.Dropdown(
                    value=model_choices[0],
                    choices=model_choices,
                    label="Select a model",
                )

                # prompt input
                prompt_input = gr.Textbox(value=DEFAULT_PROMPT, label="Input prompt")

                # guidance scale slider
                guidance_scale_input = gr.Slider(
                    minimum=0.0,
                    maximum=100.0,
                    value=load_model_config_attrs(model_selector.value)[
                        "guidance_scale"
                    ],
                    step=0.5,
                    label="Guidance scale",
                )

                # seed slider
                seed_input = gr.Slider(
                    minimum=0, maximum=2147483647, value=0, step=1, label="Seed"
                )

                max_steps_input = gr.Slider(
                    minimum=1,
                    maximum=5000,
                    value=5000,
                    step=1,
                    label="Number of training steps",
                )

                # full config viewer
                with gr.Accordion("See full configurations", open=False):
                    config_editor = gr.Code(
                        value=load_model_config(model_selector.value),
                        language="yaml",
                        interactive=False,
                    )

                # load config on model selection change
                model_selector.change(
                    fn=on_model_selector_change,
                    inputs=model_selector,
                    outputs=[config_editor, guidance_scale_input],
                )

                run_btn = gr.Button(value="Run", variant="primary")
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)

                # generation status
                status = gr.Textbox(
                    value="Hit the Run button to start.",
                    label="Status",
                    lines=1,
                    max_lines=1,
                )

            with gr.Column(scale=1):
                with gr.Accordion("See terminal logs", open=False):
                    # logs
                    logs = gr.Textbox(label="Logs", lines=10)

                # validation image display
                output_image = gr.Image(value=None, label="Image")

                # testing video display
                output_video = gr.Video(value=None, label="Video")

                # export mesh display
                output_mesh = gr.Model3D(value=None, label="3D Mesh")

        run_event = run_btn.click(
            fn=run,
            inputs=[
                model_selector,
                config_editor,
                prompt_input,
                guidance_scale_input,
                seed_input,
                max_steps_input,
            ],
            outputs=[
                pid,
                status,
                logs,
                output_image,
                output_video,
                output_mesh,
                run_btn,
                stop_btn,
            ],
        )
        stop_btn.click(
            fn=stop_run, inputs=[pid], outputs=[run_btn, stop_btn], cancels=[run_event]
        )

    launch_args = {"server_port": port}
    if listen:
        launch_args["server_name"] = "0.0.0.0"
    demo.queue().launch(**launch_args)
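
# launch(7860) serves on localhost only; launch(7860, listen=True) binds
# 0.0.0.0 so the demo is reachable from other machines.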


def watch(
    pid: int, trial_dir: str, alive_timeout: int, wait_timeout: int, check_interval: int
) -> None:
    print(f"Spawn watcher for process {pid}")

    def timeout_handler(signum, frame):
        exit(1)

    alive_path = os.path.join(trial_dir, "alive")
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(wait_timeout)

    def loop_find_progress_file():
        while True:
            if not os.path.exists(alive_path):
                time.sleep(check_interval)
            else:
                signal.alarm(0)
                return

    def loop_check_alive():
        while True:
            if not psutil.pid_exists(pid):
                print(f"Process {pid} does not exist, watcher exits.")
                exit(0)
            alive_timestamp = float(open(alive_path).read())
            if time.time() - alive_timestamp > alive_timeout:
                print(f"Alive timeout for process {pid}, killed.")
                try:
                    os.kill(pid, signal.SIGKILL)
                except Exception:
                    print(f"Exception when killing process {pid}.")
                exit(0)
            time.sleep(check_interval)

    # loop until the alive file is found, or wait_timeout is reached
    loop_find_progress_file()
    # kill the process if it is not accessed for alive_timeout seconds
    loop_check_alive()
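
# The watcher implements a simple heartbeat protocol: the Gradio frontend
# refreshes the `alive` file on every status poll (see get_current_status),
# and the watcher SIGKILLs the training run once the file goes stale, e.g.
# after the browser tab has been closed.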


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("operation", type=str, choices=["launch", "watch"])
    args, extra = parser.parse_known_args()
    if args.operation == "launch":
        parser.add_argument("--listen", action="store_true")
        parser.add_argument("--port", type=int, default=7860)
        args = parser.parse_args()
        launch(args.port, listen=args.listen)
    if args.operation == "watch":
        parser.add_argument("--pid", type=int)
        parser.add_argument("--trial-dir", type=str)
        parser.add_argument("--alive-timeout", type=int, default=10)
        parser.add_argument("--wait-timeout", type=int, default=10)
        parser.add_argument("--check-interval", type=int, default=1)
        args = parser.parse_args()
        watch(
            args.pid,
            args.trial_dir,
            alive_timeout=args.alive_timeout,
            wait_timeout=args.wait_timeout,
            check_interval=args.check_interval,
        )
252
launch.py
Normal file
@ -0,0 +1,252 @@
import argparse
import contextlib
import importlib
import logging
import os
import sys


class ColoredFilter(logging.Filter):
    """
    A logging filter to add color to certain log levels.
    """

    RESET = "\033[0m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    RESET = "\x1b[0m"

    def __init__(self):
        super().__init__()

    def filter(self, record):
        if record.levelname in self.COLORS:
            color_start = self.COLORS[record.levelname]
            record.levelname = f"{color_start}[{record.levelname}]"
            record.msg = f"{record.msg}{self.RESET}"
        return True


def main(args, extras) -> None:
    # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []
    selected_gpus = [0]

    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available GPUs
    # (possibly filtered by CUDA_VISIBLE_DEVICES).
    devices = -1
    if len(env_gpus) > 0:
        # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script.
        n_gpus = len(env_gpus)
    else:
        selected_gpus = list(args.gpu.split(","))
        n_gpus = len(selected_gpus)
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    import pytorch_lightning as pl
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    if args.typecheck:
        from jaxtyping import install_import_hook

        install_import_hook("threestudio", "typeguard.typechecked")

    import threestudio
    from threestudio.systems.base import BaseSystem
    from threestudio.utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
        ProgressCallback,
    )
    from threestudio.utils.config import ExperimentConfig, load_config
    from threestudio.utils.misc import get_rank
    from threestudio.utils.typing import Optional

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    for handler in logger.handlers:
        if handler.stream == sys.stderr:  # type: ignore
            if not args.gradio:
                handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
                handler.addFilter(ColoredFilter())
            else:
                handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))

    # parse YAML config to OmegaConf
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    if len(cfg.custom_import) > 0:
        print(cfg.custom_import)
        for extension in cfg.custom_import:
            importlib.import_module(extension)
    # set a different seed for each device
    pl.seed_everything(cfg.seed + get_rank(), workers=True)

    dm = threestudio.find(cfg.data_type)(cfg.data)

    # automatically check for resume files during training
    if args.train and cfg.resume is None:
        import glob

        resume_file_list = glob.glob(f"{cfg.trial_dir}/ckpts/*")
        if len(resume_file_list) != 0:
            print(sorted(resume_file_list))
            cfg.resume = sorted(resume_file_list)[-1]
            print(f"Found resume file: {cfg.resume}")

    system: BaseSystem = threestudio.find(cfg.system_type)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.INFO)
        if args.verbose:
            fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            CodeSnapshotCallback(
                os.path.join(cfg.trial_dir, "code"), use_version=False
            ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
        ]
        if args.gradio:
            callbacks += [
                ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
            ]
        else:
            callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        with open(file, "w") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = []
    if args.train:
        # make tensorboard logging dir to suppress warning
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        loggers += [
            TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
            CSVLogger(cfg.trial_dir, name="csv_logs"),
        ] + system.get_loggers()
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()

    trainer = Trainer(
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        **cfg.trainer,
    )

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        if ckpt_path is None:
            return
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # also export assets if in gradio mode
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPUs. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--export", action="store_true")

    parser.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )

    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )

    parser.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )

    args, extras = parser.parse_known_args()

    if args.gradio:
        # FIXME: no effect, stdout is not captured
        with contextlib.redirect_stdout(sys.stderr):
            main(args, extras)
    else:
        main(args, extras)
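
# Typical invocation, mirroring how gradio_app.py spawns this script
# (extra key=value pairs are forwarded to the YAML config as overrides):
#   python launch.py --config <config.yaml> --train --gpu 0 \
#       system.prompt_processor.prompt="a delicious hamburger"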
567
metric_utils.py
Normal file
@ -0,0 +1,567 @@
# * evaluation uses laion/CLIP-ViT-H-14-laion2B-s32B-b79K
# best open-source CLIP so far: laion/CLIP-ViT-bigG-14-laion2B-39B-b160k
# code adapted from NeuralLift-360

import torch
import torch.nn as nn
import os
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt

# import clip
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import cv2
from PIL import Image

# import torchvision.transforms as transforms
import glob
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
import lpips
from os.path import join as osp
import argparse
import pandas as pd
import contextual_loss as cl

criterion = cl.ContextualLoss(use_vgg=True, vgg_layer='relu5_4')


class CLIP(nn.Module):

    def __init__(self,
                 device,
                 clip_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
                 size=224):  # alternative: 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
        super().__init__()
        self.size = size
        self.device = f"cuda:{device}"

        clip_name = clip_name

        self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
            clip_name)
        self.clip_model = CLIPModel.from_pretrained(clip_name).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-base-patch32')

        self.normalize = transforms.Normalize(
            mean=self.feature_extractor.image_mean,
            std=self.feature_extractor.image_std)

        self.resize = transforms.Resize(224)
        self.to_tensor = transforms.ToTensor()

        # image augmentation
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ])

    # * recommend to use this function for evaluation
    @torch.no_grad()
    def score_gt(self, ref_img_path, novel_views):
        # assert len(novel_views) == 100
        clip_scores = []
        for novel in novel_views:
            clip_scores.append(self.score_from_path(ref_img_path, [novel]))
        return np.mean(clip_scores)

    # * alternative paired variant, kept for reference
    # def score_gt(self, ref_paths, novel_paths):
    #     clip_scores = []
    #     for img1_path, img2_path in zip(ref_paths, novel_paths):
    #         clip_scores.append(self.score_from_path(img1_path, img2_path))
    #     return np.mean(clip_scores)

    def similarity(self, image1_features: torch.Tensor,
                   image2_features: torch.Tensor) -> float:
        with torch.no_grad(), torch.cuda.amp.autocast():
            y = image1_features.T.view(image1_features.T.shape[1],
                                       image1_features.T.shape[0])
            similarity = torch.matmul(y, image2_features.T)
            # print(similarity)
            return similarity[0][0].item()
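
    # Since get_img_embeds below L2-normalizes the features, this inner
    # product is exactly the cosine similarity between the two embeddings.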

    def get_img_embeds(self, img):
        if img.shape[0] == 4:
            img = img[:3, :, :]

        img = self.aug(img).to(self.device)
        img = img.unsqueeze(0)  # b,c,h,w

        # plt.imshow(img.cpu().squeeze(0).permute(1, 2, 0).numpy())
        # plt.show()
        # print(img)

        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1,
                                         keepdim=True)  # normalize features
        return image_z

    def score_from_feature(self, img1, img2):
        img1_feature, img2_feature = self.get_img_embeds(
            img1), self.get_img_embeds(img2)
        # for debug
        return self.similarity(img1_feature, img2_feature)

    def read_img_list(self, img_list):
        size = self.size
        images = []
        # white_background = np.ones((size, size, 3), dtype=np.uint8) * 255

        for img_path in img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            # print(img_path)
            if img.shape[2] == 4:  # Handle BGRA images
                alpha = img[:, :, 3]  # Extract alpha channel
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)  # Convert BGRA to RGB
                img[np.where(alpha == 0)] = [
                    255, 255, 255
                ]  # Set transparent pixels to white
            else:  # Handle other image formats like JPG and PNG
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)

            # plt.imshow(img)
            # plt.show()

            images.append(img)

        images = np.stack(images, axis=0)
        # images[np.where(images == 0)] = 255  # Set black pixels to white
        # images = np.where(images == 0, white_background, images)  # Set transparent pixels to white
        # images = images.astype(np.float32)

        return images

    def score_from_path(self, img1_path, img2_path):
        img1, img2 = self.read_img_list(img1_path), self.read_img_list(img2_path)
        img1 = np.squeeze(img1)
        img2 = np.squeeze(img2)
        # plt.imshow(img1)
        # plt.show()
        # plt.imshow(img2)
        # plt.show()

        img1, img2 = self.to_tensor(img1), self.to_tensor(img2)
        # print("img1 to tensor ", img1)
        return self.score_from_feature(img1, img2)


def numpy_to_torch(images):
    images = images * 2.0 - 1.0
    images = torch.from_numpy(images.transpose((0, 3, 1, 2))).float()
    return images.cuda()
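
# numpy_to_torch maps NHWC float arrays in [0, 1] to NCHW CUDA tensors in
# [-1, 1], the input range expected by the LPIPS network used below.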


class LPIPSMeter:

    def __init__(self,
                 net='alex',
                 device=None,
                 size=224):  # or we can use 'alex', 'vgg' as network
        self.size = size
        self.net = net
        self.results = []
        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.fn = lpips.LPIPS(net=net).eval().to(self.device)

    def measure(self):
        return np.mean(self.results)

    def report(self):
        return f'LPIPS ({self.net}) = {self.measure():.6f}'

    def read_img_list(self, img_list):
        size = self.size
        images = []
        white_background = np.ones((size, size, 3), dtype=np.uint8) * 255

        for img_path in img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)

            if img.shape[2] == 4:  # Handle BGRA images
                alpha = img[:, :, 3]  # Extract alpha channel
                img = cv2.cvtColor(img,
                                   cv2.COLOR_BGRA2BGR)  # Convert BGRA to BGR
                img = cv2.cvtColor(img,
                                   cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
                img[np.where(alpha == 0)] = [
                    255, 255, 255
                ]  # Set transparent pixels to white
            else:  # Handle other image formats like JPG and PNG
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
            images.append(img)

        images = np.stack(images, axis=0)
        # images[np.where(images == 0)] = 255  # Set black pixels to white
        # images = np.where(images == 0, white_background, images)  # Set transparent pixels to white
        images = images.astype(np.float32) / 255.0

        return images

    # * recommend to use this function for evaluation
    @torch.no_grad()
    def score_gt(self, ref_paths, novel_paths):
        self.results = []
        for path0, path1 in zip(ref_paths, novel_paths):
            # Load images
            # img0 = lpips.im2tensor(lpips.load_image(path0)).cuda()  # RGB image from [-1,1]
            # img1 = lpips.im2tensor(lpips.load_image(path1)).cuda()
            img0, img1 = self.read_img_list([path0]), self.read_img_list(
                [path1])
            img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1)
            # print(img0.shape, img1.shape)
            img0 = F.interpolate(img0,
                                 size=(self.size, self.size),
                                 mode='area')
            img1 = F.interpolate(img1,
                                 size=(self.size, self.size),
                                 mode='area')

            # for debug vis
            # plt.imshow(img0.cpu().squeeze(0).permute(1, 2, 0).numpy())
            # plt.show()
            # plt.imshow(img1.cpu().squeeze(0).permute(1, 2, 0).numpy())
            # plt.show()
            # equivalent to cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA)

            # print(img0.shape, img1.shape)

            self.results.append(self.fn.forward(img0, img1).cpu().numpy())

        return self.measure()
|
||||
|
||||
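# Minimal usage sketch for LPIPSMeter (hypothetical paths; numpy_to_torch
# assumes a CUDA device):
#   meter = LPIPSMeter(net='alex', size=224)
#   meter.score_gt(['gt_rgba.png'], ['render_0000.png'])
#   print(meter.report())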
class CXMeter:
|
||||
|
||||
def __init__(self,
|
||||
net='vgg',
|
||||
device=None,
|
||||
size=512): # or we can use 'alex', 'vgg' as network
|
||||
self.size = size
|
||||
self.net = net
|
||||
self.results = []
|
||||
self.device = device if device is not None else torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.fn = lpips.LPIPS(net=net).eval().to(self.device)
|
||||
|
||||
def measure(self):
|
||||
return np.mean(self.results)
|
||||
|
||||
def report(self):
|
||||
return f'CX ({self.net}) = {self.measure():.6f}'
|
||||
|
||||
def read_img_list(self, img_list):
|
||||
size = self.size
|
||||
images = []
|
||||
white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
|
||||
|
||||
for img_path in img_list:
|
||||
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if img.shape[2] == 4: # Handle BGRA images
|
||||
alpha = img[:, :, 3] # Extract alpha channel
|
||||
img = cv2.cvtColor(img,
|
||||
cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR
|
||||
|
||||
img = cv2.cvtColor(img,
|
||||
cv2.COLOR_BGR2RGB) # Convert BGR to RGB
|
||||
img[np.where(alpha == 0)] = [
|
||||
255, 255, 255
|
||||
] # Set transparent pixels to white
|
||||
else: # Handle other image formats like JPG and PNG
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
|
||||
images.append(img)
|
||||
|
||||
images = np.stack(images, axis=0)
|
||||
# images[np.where(images == 0)] = 255 # Set black pixels to white
|
||||
# images = np.where(images == 0, white_background, images) # Set transparent pixels to white
|
||||
images = images.astype(np.float32) / 255.0
|
||||
|
||||
return images
|
||||
|
||||
# * recommended function for evaluation
|
||||
@torch.no_grad()
|
||||
def score_gt(self, ref_paths, novel_paths):
|
||||
self.results = []
|
||||
path0 = ref_paths[0]
|
||||
print('calculating CX loss')
|
||||
for path1 in tqdm(novel_paths):
|
||||
# Load images
|
||||
img0, img1 = self.read_img_list([path0]), self.read_img_list(
|
||||
[path1])
|
||||
img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1)
|
||||
img0, img1 = img0 * 0.5 + 0.5, img1 * 0.5 + 0.5
|
||||
img0 = F.interpolate(img0,
|
||||
size=(self.size, self.size),
|
||||
mode='area')
|
||||
img1 = F.interpolate(img1,
|
||||
size=(self.size, self.size),
|
||||
mode='area')
|
||||
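# `criterion` is assumed to be the contextual (CX) loss instantiated earlier
# in this file (cf. the commented-out `criterion = criterion.to(args.device)`
# in __main__).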
loss = criterion(img0.cpu(), img1.cpu())
|
||||
self.results.append(loss.cpu().numpy())
|
||||
|
||||
return self.measure()
|
||||
|
||||
class PSNRMeter:
|
||||
|
||||
def __init__(self, size=800):
|
||||
self.results = []
|
||||
self.size = size
|
||||
|
||||
def read_img_list(self, img_list):
|
||||
size = self.size
|
||||
images = []
|
||||
white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
|
||||
for img_path in img_list:
|
||||
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if img.shape[2] == 4: # Handle BGRA images
|
||||
alpha = img[:, :, 3] # Extract alpha channel
|
||||
img = cv2.cvtColor(img,
|
||||
cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR
|
||||
|
||||
img = cv2.cvtColor(img,
|
||||
cv2.COLOR_BGR2RGB) # Convert BGR to RGB
|
||||
img[np.where(alpha == 0)] = [
|
||||
255, 255, 255
|
||||
] # Set transparent pixels to white
|
||||
else: # Handle other image formats like JPG and PNG
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
|
||||
images.append(img)
|
||||
|
||||
images = np.stack(images, axis=0)
|
||||
# images[np.where(images == 0)] = 255 # Set black pixels to white
|
||||
# images = np.where(images == 0, white_background, images) # Set transparent pixels to white
|
||||
images = images.astype(np.float32) / 255.0
|
||||
# print(images.shape)
|
||||
return images
|
||||
|
||||
def update(self, preds, truths):
|
||||
# print(preds.shape)
|
||||
|
||||
psnr_values = []
|
||||
# For each pair of images in the batches
|
||||
for img1, img2 in zip(preds, truths):
|
||||
# Compute the PSNR and add it to the list
|
||||
# print(img1.shape,img2.shape)
|
||||
|
||||
# for debug
|
||||
# plt.imshow(img1)
|
||||
# plt.show()
|
||||
# plt.imshow(img2)
|
||||
# plt.show()
|
||||
|
||||
psnr = compare_psnr(
|
||||
img1, img2,
|
||||
data_range=1.0) # assuming your images are scaled to [0,1]
|
||||
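# compare_psnr computes 10 * log10(data_range**2 / MSE); with data_range=1.0
# this reduces to -10 * log10(MSE).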
# print(f"temp psnr {psnr}")
|
||||
psnr_values.append(psnr)
|
||||
|
||||
# Convert the list of PSNR values to a numpy array
|
||||
self.results = psnr_values
|
||||
|
||||
def measure(self):
|
||||
return np.mean(self.results)
|
||||
|
||||
def report(self):
|
||||
return f'PSNR = {self.measure():.6f}'
|
||||
|
||||
# * recommended function for evaluation
|
||||
def score_gt(self, ref_paths, novel_paths):
|
||||
self.results = []
|
||||
# [B, N, 3] or [B, H, W, 3], range[0, 1]
|
||||
preds = self.read_img_list(ref_paths)
|
||||
truths = self.read_img_list(novel_paths)
|
||||
self.update(preds, truths)
|
||||
return self.measure()
|
||||
|
||||
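# Minimal usage sketch for PSNRMeter (hypothetical paths):
#   psnr = PSNRMeter(size=800)
#   psnr.score_gt(['gt_rgba.png'], ['render_0000.png'])
#   print(psnr.report())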
# all_inputs = 'data'
|
||||
# nerf_dataset = os.listdir(osp(all_inputs, 'nerf4'))
|
||||
# realfusion_dataset = os.listdir(osp(all_inputs, 'realfusion15'))
|
||||
# meta_examples = {
|
||||
# 'nerf4': nerf_dataset,
|
||||
# 'realfusion15': realfusion_dataset,
|
||||
# }
|
||||
# all_datasets = meta_examples.keys()
|
||||
|
||||
# organization 1
|
||||
def deprecated_score_from_method_for_dataset(my_scorer,
|
||||
method,
|
||||
dataset,
|
||||
input,
|
||||
output,
|
||||
score_type='clip',
|
||||
): # psnr, lpips
|
||||
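# NOTE: kept for reference only; this relies on the module-level
# `meta_examples` and `pred_path` defined in the commented-out block above.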
# print("\n\n\n")
|
||||
# print(f"______{method}___{dataset}___{score_type}_________")
|
||||
scores = {}
|
||||
final_res = 0
|
||||
examples = meta_examples[dataset]
|
||||
for i in range(len(examples)):
|
||||
|
||||
# compare entire folder for clip
|
||||
if score_type == 'clip':
|
||||
novel_view = osp(pred_path, examples[i], 'colors')
|
||||
# compare first image for other metrics
|
||||
else:
|
||||
if method == '3d_fuse':
    method = '3d_fuse_0'
|
||||
novel_view = list(
|
||||
glob.glob(
|
||||
osp(pred_path, examples[i], 'colors',
|
||||
'step_0000*')))[0]
|
||||
|
||||
score_i = my_scorer.score_gt(
|
||||
[], [novel_view])
|
||||
scores[examples[i]] = score_i
|
||||
final_res += score_i
|
||||
# print(scores, " Avg : ", final_res / len(examples))
|
||||
# print("``````````````````````")
|
||||
return scores
|
||||
|
||||
# results organization 2
|
||||
def score_from_method_for_dataset(my_scorer,
|
||||
input_path,
|
||||
pred_path,
|
||||
score_type='clip',
|
||||
rgb_name='lambertian',
|
||||
result_folder='results/images',
|
||||
first_str='*0000*'
|
||||
): # psnr, lpips
|
||||
scores = {}
|
||||
final_res = 0
|
||||
examples = os.listdir(input_path)
|
||||
for i in range(len(examples)):
|
||||
# ref path
|
||||
ref_path = osp(input_path, examples[i], 'rgba.png')
|
||||
# compare entire folder for clip
|
||||
if score_type == 'clip':
|
||||
novel_view = glob.glob(osp(pred_path,'*'+examples[i]+'*', result_folder, f'*{rgb_name}*'))
|
||||
print(f'[INFO] {score_type} loss for example {examples[i]} between 1 GT and {len(novel_view)} predictions')
|
||||
# compare first image for other metrics
|
||||
else:
|
||||
novel_view = glob.glob(osp(pred_path, '*'+examples[i]+'*/', result_folder, f'{first_str}{rgb_name}*'))
|
||||
print(f'[INFO] {score_type} loss for example {examples[i]} between {ref_path} and {novel_view}')
|
||||
# breakpoint()
|
||||
score_i = my_scorer.score_gt([ref_path], novel_view)
|
||||
scores[examples[i]] = score_i
|
||||
final_res += score_i
|
||||
avg_score = final_res / len(examples)
|
||||
scores['average'] = avg_score
|
||||
return scores
|
||||
|
||||
|
||||
# results organization 2
|
||||
def score_from_my_method_for_dataset(my_scorer,
|
||||
input_path, dataset,
|
||||
score_type='clip'
|
||||
): # psnr, lpips
|
||||
scores = {}
|
||||
final_res = 0
|
||||
input_path = osp(input_path, dataset)
|
||||
ref_path = glob.glob(osp(input_path, "*_rgba.png"))
|
||||
novel_view = [osp(input_path, '%d.png' % i) for i in range(120)]
|
||||
# print(ref_path)
|
||||
# print(novel_view)
|
||||
for i in tqdm(range(120)):
|
||||
if os.path.exists(osp(input_path, '%d_color.png' % i)):
|
||||
continue
|
||||
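# Each rendered frame is assumed to be a horizontal strip of buffers;
# keep the leftmost HxH square as the color view.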
img = cv2.imread(novel_view[i])
|
||||
H = img.shape[0]
|
||||
img = img[:, :H]
|
||||
cv2.imwrite(osp(input_path, '%d_color.png' % i), img)
|
||||
if score_type in ('clip', 'cx'):
|
||||
novel_view = [osp(input_path, '%d_color.png' % i) for i in range(120)]
|
||||
else:
|
||||
novel_view = [osp(input_path, '%d_color.png' % i) for i in range(1)]
|
||||
print(novel_view)
|
||||
scores['%s_average' % dataset] = my_scorer.score_gt(ref_path, novel_view)
|
||||
return scores
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Script to accept three string arguments")
|
||||
parser.add_argument("--input_path",
|
||||
default=None,
|
||||
help="Specify the input path")
|
||||
parser.add_argument("--pred_pattern",
|
||||
default="out/magic123*",
|
||||
help="Specify the pattern of predition paths")
|
||||
parser.add_argument("--results_folder",
|
||||
default="results/images",
|
||||
help="where are the results under each pred_path")
|
||||
parser.add_argument("--rgb_name",
|
||||
default="lambertian",
|
||||
help="the postfix of the image")
|
||||
parser.add_argument("--first_str",
|
||||
default="*0000*",
|
||||
help="the str to indicate the first view")
|
||||
parser.add_argument("--datasets",
|
||||
default=None,
|
||||
nargs='*',
|
||||
help="Specify the output path")
|
||||
parser.add_argument("--device",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Specify the GPU device to be used")
|
||||
parser.add_argument("--save_dir", type=str, default='all_metrics/results')
|
||||
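# Example invocation (hypothetical layout: one folder per dataset under
# --input_path, each holding *_rgba.png plus 120 rendered frames):
#   python this_script.py --input_path out/renders --save_dir all_metrics/results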
args = parser.parse_args()
|
||||
|
||||
clip_scorer = CLIP(args.device)
|
||||
lpips_scorer = LPIPSMeter()
|
||||
psnr_scorer = PSNRMeter()
|
||||
CX_scorer = CXMeter()
|
||||
# criterion = criterion.to(args.device)
|
||||
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
|
||||
for dataset in os.listdir(args.input_path):
|
||||
print(dataset)
|
||||
results_dict = {}
|
||||
results_dict['clip'] = score_from_my_method_for_dataset(
|
||||
clip_scorer, args.input_path, dataset, 'clip')
|
||||
|
||||
results_dict['psnr'] = score_from_my_method_for_dataset(
|
||||
psnr_scorer, args.input_path, dataset, 'psnr')
|
||||
|
||||
results_dict['lpips'] = score_from_my_method_for_dataset(
|
||||
lpips_scorer, args.input_path, dataset, 'lpips')
|
||||
|
||||
results_dict['CX'] = score_from_my_method_for_dataset(
|
||||
CX_scorer, args.input_path, dataset, 'cx')
|
||||
|
||||
df = pd.DataFrame(results_dict)
|
||||
print(df)
|
||||
df.to_csv(f"{args.save_dir}/result.csv")
|
||||
|
||||
|
||||
# for dataset in args.datasets:
|
||||
# input_path = osp(args.input_path, dataset)
|
||||
|
||||
# # assume the pred_path is organized as: pred_path/methods/dataset
|
||||
# pred_pattern = osp(args.pred_pattern, dataset)
|
||||
# pred_paths = glob.glob(pred_pattern)
|
||||
# print(f"[INFO] Following the pattern {pred_pattern}, find {len(pred_paths)} pred_paths: \n", pred_paths)
|
||||
# if len(pred_paths) == 0:
|
||||
# raise IOError
|
||||
# for pred_path in pred_paths:
|
||||
# if not os.path.exists(pred_path):
|
||||
# print(f'[WARN] prediction does not exit for {pred_path}')
|
||||
# else:
|
||||
# print(f'[INFO] evaluate {pred_path}')
|
||||
|
237
preprocess_image.py
Normal file
237
preprocess_image.py
Normal file
@ -0,0 +1,237 @@
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import argparse
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import glob
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
class BackgroundRemoval():
|
||||
def __init__(self, device='cuda'):
|
||||
|
||||
from carvekit.api.high import HiInterface
|
||||
self.interface = HiInterface(
|
||||
object_type="object", # Can be "object" or "hairs-like".
|
||||
batch_size_seg=5,
|
||||
batch_size_matting=1,
|
||||
device=device,
|
||||
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
|
||||
matting_mask_size=2048,
|
||||
trimap_prob_threshold=231,
|
||||
trimap_dilation=30,
|
||||
trimap_erosion_iters=5,
|
||||
fp16=True,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image):
|
||||
# image: [H, W, 3] array in [0, 255].
|
||||
image = Image.fromarray(image)
|
||||
|
||||
image = self.interface([image])[0]
|
||||
image = np.array(image)
|
||||
|
||||
return image
|
||||
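# Sketch (default device='cuda'): BackgroundRemoval()(uint8 HxWx3 RGB array)
# returns an HxWx4 RGBA array whose alpha channel isolates the foreground object.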
|
||||
class BLIP2():
|
||||
def __init__(self, device='cuda'):
|
||||
self.device = device
|
||||
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
||||
self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image):
|
||||
image = Image.fromarray(image)
|
||||
inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16)
|
||||
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=20)
|
||||
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
|
||||
return generated_text
|
||||
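# Sketch: BLIP2()(uint8 HxWx3 RGB array) returns a short caption string,
# e.g. to seed the text prompt for diffusion guidance.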
|
||||
|
||||
class DPT():
|
||||
def __init__(self, task='depth', device='cuda'):
|
||||
|
||||
self.task = task
|
||||
self.device = device
|
||||
|
||||
from threestudio.utils.dpt import DPTDepthModel
|
||||
|
||||
if task == 'depth':
|
||||
path = 'load/omnidata/omnidata_dpt_depth_v2.ckpt'
|
||||
self.model = DPTDepthModel(backbone='vitb_rn50_384')
|
||||
self.aug = transforms.Compose([
|
||||
transforms.Resize((384, 384)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=0.5, std=0.5)
|
||||
])
|
||||
|
||||
else: # normal
|
||||
path = 'load/omnidata/omnidata_dpt_normal_v2.ckpt'
|
||||
self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
|
||||
self.aug = transforms.Compose([
|
||||
transforms.Resize((384, 384)),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
# load model
|
||||
checkpoint = torch.load(path, map_location='cpu')
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = {}
|
||||
for k, v in checkpoint['state_dict'].items():
|
||||
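# strip the leading checkpoint-key prefix (presumably "model.", 6 characters)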
state_dict[k[6:]] = v
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.model.eval().to(device)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image):
|
||||
# image: np.ndarray, uint8, [H, W, 3]
|
||||
H, W = image.shape[:2]
|
||||
image = Image.fromarray(image)
|
||||
|
||||
image = self.aug(image).unsqueeze(0).to(self.device)
|
||||
|
||||
if self.task == 'depth':
|
||||
depth = self.model(image).clamp(0, 1)
|
||||
depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)
|
||||
depth = depth.squeeze(1).cpu().numpy()
|
||||
return depth
|
||||
else:
|
||||
normal = self.model(image).clamp(0, 1)
|
||||
normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)
|
||||
normal = normal.cpu().numpy()
|
||||
return normal
|
||||
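# Sketch: DPT(task='depth')(rgb_uint8) -> np.ndarray of shape (1, H, W) in [0, 1];
#         DPT(task='normal')(rgb_uint8) -> np.ndarray of shape (1, 3, H, W) in [0, 1].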
|
||||
def preprocess_single_image(img_path, opt):
|
||||
out_dir = os.path.dirname(img_path)
|
||||
out_rgba = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_rgba.png')
|
||||
out_depth = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_depth.png')
|
||||
out_normal = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_normal.png')
|
||||
out_caption = os.path.join(out_dir, os.path.basename(img_path).split('.')[0] + '_caption.txt')
|
||||
|
||||
# load image
|
||||
print(f'[INFO] loading image {img_path}...')
|
||||
|
||||
# check for existing output files
|
||||
if os.path.isfile(out_rgba) and os.path.isfile(out_depth) and os.path.isfile(out_normal):
|
||||
print(f"{img_path} has already been here!")
|
||||
return
|
||||
print(img_path)
|
||||
image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||
carved_image = None
|
||||
# debug
|
||||
if image.shape[-1] == 4:
|
||||
if opt.do_rm_bg_force:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
||||
else:
|
||||
carved_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
||||
|
||||
else:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if opt.do_seg:
|
||||
if carved_image is None:
|
||||
# carve background
|
||||
print(f'[INFO] background removal...')
|
||||
carved_image = BackgroundRemoval()(image) # [H, W, 4]
|
||||
mask = carved_image[..., -1] > 0
|
||||
|
||||
# predict depth
|
||||
print(f'[INFO] depth estimation...')
|
||||
dpt_depth_model = DPT(task='depth')
|
||||
depth = dpt_depth_model(image)[0]
|
||||
depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9)
|
||||
depth[~mask] = 0
|
||||
depth = (depth * 255).astype(np.uint8)
|
||||
del dpt_depth_model
|
||||
|
||||
# predict normal
|
||||
print(f'[INFO] normal estimation...')
|
||||
dpt_normal_model = DPT(task='normal')
|
||||
normal = dpt_normal_model(image)[0]
|
||||
normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0)
|
||||
normal[~mask] = 0
|
||||
del dpt_normal_model
|
||||
|
||||
opt.recenter = False  # NOTE: force-disables recentering here, overriding the --recenter flag
|
||||
# recenter
|
||||
if opt.recenter:
|
||||
print(f'[INFO] recenter...')
|
||||
final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8)
|
||||
final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8)
|
||||
final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8)
|
||||
|
||||
coords = np.nonzero(mask)
|
||||
x_min, x_max = coords[0].min(), coords[0].max()
|
||||
y_min, y_max = coords[1].min(), coords[1].max()
|
||||
h = x_max - x_min
|
||||
w = y_max - y_min
|
||||
desired_size = int(opt.size * (1 - opt.border_ratio))
|
||||
scale = desired_size / max(h, w)
|
||||
h2 = int(h * scale)
|
||||
w2 = int(w * scale)
|
||||
x2_min = (opt.size - h2) // 2
|
||||
x2_max = x2_min + h2
|
||||
y2_min = (opt.size - w2) // 2
|
||||
y2_max = y2_min + w2
|
||||
final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
|
||||
final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
|
||||
final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
|
||||
|
||||
else:
|
||||
final_rgba = carved_image
|
||||
final_depth = depth
|
||||
final_normal = normal
|
||||
|
||||
# write output
|
||||
cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA))
|
||||
cv2.imwrite(out_depth, final_depth)
|
||||
cv2.imwrite(out_normal, final_normal)
|
||||
|
||||
if opt.do_caption:
|
||||
# predict caption (it's too slow... use your brain instead)
|
||||
print(f'[INFO] captioning...')
|
||||
blip2 = BLIP2()
|
||||
caption = blip2(image)
|
||||
with open(out_caption, 'w') as f:
|
||||
f.write(caption)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)")
|
||||
parser.add_argument('--size', default=1024, type=int, help="output resolution")
|
||||
parser.add_argument('--border_ratio', default=0.1, type=float, help="output border ratio")
|
||||
parser.add_argument('--recenter', type=bool, default=False, help="recenter, potentially not helpful for multiview zero123")
|
||||
parser.add_argument('--dont_recenter', dest='recenter', action='store_false')
|
||||
parser.add_argument('--do_caption', type=bool, default=False, help="do text captioning")
|
||||
parser.add_argument('--do_seg', type=bool, default=True)
|
||||
parser.add_argument('--do_rm_bg_force', type=bool, default=False)
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
if os.path.isdir(opt.path):
|
||||
img_list = sorted(os.path.join(root, fname) for root, _dirs, files in os.walk(opt.path) for fname in files)
|
||||
img_list = [img for img in img_list if not img.endswith("rgba.png") and not img.endswith("depth.png") and not img.endswith("normal.png")]
|
||||
img_list = [img for img in img_list if img.endswith(".png")]
|
||||
for img in img_list:
|
||||
# try:
|
||||
preprocess_single_image(img, opt)
|
||||
# except:
|
||||
# with open("preprocess_images_invalid.txt", "a") as f:
|
||||
# print(img, file=f)
|
||||
else: # single image file
|
||||
preprocess_single_image(opt.path, opt)
|
35
requirements.txt
Normal file
35
requirements.txt
Normal file
@ -0,0 +1,35 @@
|
||||
lightning==2.0.0
|
||||
omegaconf==2.3.0
|
||||
jaxtyping
|
||||
typeguard
|
||||
diffusers<=0.23.0
|
||||
transformers
|
||||
accelerate
|
||||
opencv-python
|
||||
tensorboard
|
||||
matplotlib
|
||||
imageio>=2.28.0
|
||||
imageio[ffmpeg]
|
||||
libigl
|
||||
xatlas
|
||||
trimesh[easy]
|
||||
networkx
|
||||
pysdf
|
||||
PyMCubes
|
||||
wandb
|
||||
gradio
|
||||
|
||||
# deepfloyd
|
||||
xformers
|
||||
bitsandbytes
|
||||
sentencepiece
|
||||
safetensors
|
||||
huggingface_hub
|
||||
|
||||
# for zero123
|
||||
einops
|
||||
kornia
|
||||
taming-transformers-rom1504
|
||||
|
||||
# controlnet
|
||||
controlnet_aux
|
36
threestudio/__init__.py
Normal file
36
threestudio/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
__modules__ = {}
|
||||
|
||||
|
||||
def register(name):
|
||||
def decorator(cls):
|
||||
__modules__[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def find(name):
|
||||
return __modules__[name]
|
||||
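# Usage sketch:
#   @register("my-datamodule")
#   class MyDataModule: ...
#   cls = find("my-datamodule")  # -> MyDataModule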
|
||||
|
||||
### syntactic sugar for logging utilities ###
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("pytorch_lightning")
|
||||
|
||||
from pytorch_lightning.utilities.rank_zero import (
|
||||
rank_zero_debug,
|
||||
rank_zero_info,
|
||||
rank_zero_only,
|
||||
)
|
||||
|
||||
debug = rank_zero_debug
|
||||
info = rank_zero_info
|
||||
|
||||
|
||||
@rank_zero_only
|
||||
def warn(*args, **kwargs):
|
||||
logger.warning(*args, **kwargs)
|
||||
|
||||
|
||||
from . import data, models, systems
|
1
threestudio/data/__init__.py
Normal file
1
threestudio/data/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import image, uncond
|
351
threestudio/data/image.py
Normal file
351
threestudio/data/image.py
Normal file
@ -0,0 +1,351 @@
|
||||
import bisect
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
|
||||
import threestudio
|
||||
from threestudio import register
|
||||
from threestudio.data.uncond import (
|
||||
RandomCameraDataModuleConfig,
|
||||
RandomCameraDataset,
|
||||
RandomCameraIterableDataset,
|
||||
)
|
||||
from threestudio.utils.base import Updateable
|
||||
from threestudio.utils.config import parse_structured
|
||||
from threestudio.utils.misc import get_rank
|
||||
from threestudio.utils.ops import (
|
||||
get_mvp_matrix,
|
||||
get_projection_matrix,
|
||||
get_ray_directions,
|
||||
get_rays,
|
||||
)
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleImageDataModuleConfig:
|
||||
# height and width should be Union[int, List[int]]
|
||||
# but OmegaConf does not support Union of containers
|
||||
height: Any = 96
|
||||
width: Any = 96
|
||||
resolution_milestones: List[int] = field(default_factory=lambda: [])
|
||||
default_elevation_deg: float = 0.0
|
||||
default_azimuth_deg: float = -180.0
|
||||
default_camera_distance: float = 1.2
|
||||
default_fovy_deg: float = 60.0
|
||||
image_path: str = ""
|
||||
use_random_camera: bool = True
|
||||
random_camera: dict = field(default_factory=dict)
|
||||
rays_noise_scale: float = 2e-3
|
||||
batch_size: int = 1
|
||||
requires_depth: bool = False
|
||||
requires_normal: bool = False
|
||||
rays_d_normalize: bool = True
|
||||
use_mixed_camera_config: bool = False
|
||||
|
||||
|
||||
class SingleImageDataBase:
|
||||
def setup(self, cfg, split):
|
||||
self.split = split
|
||||
self.rank = get_rank()
|
||||
self.cfg: SingleImageDataModuleConfig = cfg
|
||||
|
||||
if self.cfg.use_random_camera:
|
||||
random_camera_cfg = parse_structured(
|
||||
RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
|
||||
)
|
||||
# FIXME:
|
||||
if self.cfg.use_mixed_camera_config:
|
||||
if self.rank % 2 == 0:
|
||||
random_camera_cfg.camera_distance_range=[self.cfg.default_camera_distance, self.cfg.default_camera_distance]
|
||||
random_camera_cfg.fovy_range=[self.cfg.default_fovy_deg, self.cfg.default_fovy_deg]
|
||||
self.fixed_camera_intrinsic = True
|
||||
else:
|
||||
self.fixed_camera_intrinsic = False
|
||||
if split == "train":
|
||||
self.random_pose_generator = RandomCameraIterableDataset(
|
||||
random_camera_cfg
|
||||
)
|
||||
else:
|
||||
self.random_pose_generator = RandomCameraDataset(
|
||||
random_camera_cfg, split
|
||||
)
|
||||
|
||||
elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
|
||||
azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
|
||||
camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])
|
||||
|
||||
elevation = elevation_deg * math.pi / 180
|
||||
azimuth = azimuth_deg * math.pi / 180
|
||||
camera_position: Float[Tensor, "1 3"] = torch.stack(
|
||||
[
|
||||
camera_distance * torch.cos(elevation) * torch.cos(azimuth),
|
||||
camera_distance * torch.cos(elevation) * torch.sin(azimuth),
|
||||
camera_distance * torch.sin(elevation),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
|
||||
up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None]
|
||||
|
||||
light_position: Float[Tensor, "1 3"] = camera_position
|
||||
lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
|
||||
right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
|
||||
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
|
||||
self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
|
||||
[torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
|
||||
dim=-1,
|
||||
)
|
||||
self.c2w4x4: Float[Tensor, "B 4 4"] = torch.cat(
|
||||
[self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1
|
||||
)
|
||||
self.c2w4x4[:, 3, 3] = 1.0
|
||||
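# c2w packs [right | up | -lookat | position] as a 3x4 camera-to-world matrix
# (-z is the viewing direction); c2w4x4 appends the homogeneous bottom row.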
|
||||
self.camera_position = camera_position
|
||||
self.light_position = light_position
|
||||
self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
|
||||
self.camera_distance = camera_distance
|
||||
self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))
|
||||
|
||||
self.heights: List[int] = (
|
||||
[self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
|
||||
)
|
||||
self.widths: List[int] = (
|
||||
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
|
||||
)
|
||||
assert len(self.heights) == len(self.widths)
|
||||
self.resolution_milestones: List[int]
|
||||
if len(self.heights) == 1 and len(self.widths) == 1:
|
||||
if len(self.cfg.resolution_milestones) > 0:
|
||||
threestudio.warn(
|
||||
"Ignoring resolution_milestones since height and width are not changing"
|
||||
)
|
||||
self.resolution_milestones = [-1]
|
||||
else:
|
||||
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
|
||||
self.resolution_milestones = [-1] + self.cfg.resolution_milestones
|
||||
|
||||
self.directions_unit_focals = [
|
||||
get_ray_directions(H=height, W=width, focal=1.0)
|
||||
for (height, width) in zip(self.heights, self.widths)
|
||||
]
|
||||
self.focal_lengths = [
|
||||
0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
|
||||
]
|
||||
|
||||
self.height: int = self.heights[0]
|
||||
self.width: int = self.widths[0]
|
||||
self.directions_unit_focal = self.directions_unit_focals[0]
|
||||
self.focal_length = self.focal_lengths[0]
|
||||
self.set_rays()
|
||||
self.load_images()
|
||||
self.prev_height = self.height
|
||||
|
||||
def set_rays(self):
|
||||
# get directions by dividing directions_unit_focal by focal length
|
||||
directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None]
|
||||
directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length
|
||||
|
||||
rays_o, rays_d = get_rays(
|
||||
directions,
|
||||
self.c2w,
|
||||
keepdim=True,
|
||||
noise_scale=self.cfg.rays_noise_scale,
|
||||
normalize=self.cfg.rays_d_normalize,
|
||||
)
|
||||
|
||||
proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix(
|
||||
self.fovy, self.width / self.height, 0.01, 100.0
|
||||
) # FIXME: hard-coded near and far
|
||||
mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx)
|
||||
|
||||
self.rays_o, self.rays_d = rays_o, rays_d
|
||||
self.mvp_mtx = mvp_mtx
|
||||
|
||||
def load_images(self):
|
||||
# load image
|
||||
assert os.path.exists(
|
||||
self.cfg.image_path
|
||||
), f"Could not find image {self.cfg.image_path}!"
|
||||
rgba = cv2.cvtColor(
|
||||
cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
|
||||
)
|
||||
rgba = (
|
||||
cv2.resize(
|
||||
rgba, (self.width, self.height), interpolation=cv2.INTER_AREA
|
||||
).astype(np.float32)
|
||||
/ 255.0
|
||||
)
|
||||
rgb = rgba[..., :3]
|
||||
self.rgb: Float[Tensor, "1 H W 3"] = (
|
||||
torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank)
|
||||
)
|
||||
self.mask: Float[Tensor, "1 H W 1"] = (
|
||||
torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank)
|
||||
)
|
||||
print(
|
||||
f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}"
|
||||
)
|
||||
|
||||
# load depth
|
||||
if self.cfg.requires_depth:
|
||||
depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png")
|
||||
assert os.path.exists(depth_path)
|
||||
depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
|
||||
depth = cv2.resize(
|
||||
depth, (self.width, self.height), interpolation=cv2.INTER_AREA
|
||||
)
|
||||
self.depth: Float[Tensor, "1 H W 1"] = (
|
||||
torch.from_numpy(depth.astype(np.float32) / 255.0)
|
||||
.unsqueeze(0)
|
||||
.to(self.rank)
|
||||
)
|
||||
print(
|
||||
f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}"
|
||||
)
|
||||
else:
|
||||
self.depth = None
|
||||
|
||||
# load normal
|
||||
if self.cfg.requires_normal:
|
||||
normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png")
|
||||
assert os.path.exists(normal_path)
|
||||
normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
|
||||
normal = cv2.resize(
|
||||
normal, (self.width, self.height), interpolation=cv2.INTER_AREA
|
||||
)
|
||||
self.normal: Float[Tensor, "1 H W 3"] = (
|
||||
torch.from_numpy(normal.astype(np.float32) / 255.0)
|
||||
.unsqueeze(0)
|
||||
.to(self.rank)
|
||||
)
|
||||
print(
|
||||
f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}"
|
||||
)
|
||||
else:
|
||||
self.normal = None
|
||||
|
||||
def get_all_images(self):
|
||||
return self.rgb
|
||||
|
||||
def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
||||
size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
|
||||
self.height = self.heights[size_ind]
|
||||
if self.height == self.prev_height:
|
||||
return
|
||||
|
||||
self.prev_height = self.height
|
||||
self.width = self.widths[size_ind]
|
||||
self.directions_unit_focal = self.directions_unit_focals[size_ind]
|
||||
self.focal_length = self.focal_lengths[size_ind]
|
||||
threestudio.debug(f"Training height: {self.height}, width: {self.width}")
|
||||
self.set_rays()
|
||||
self.load_images()
|
||||
|
||||
|
||||
class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable):
|
||||
def __init__(self, cfg: Any, split: str) -> None:
|
||||
super().__init__()
|
||||
self.setup(cfg, split)
|
||||
|
||||
def collate(self, batch) -> Dict[str, Any]:
|
||||
batch = {
|
||||
"rays_o": self.rays_o,
|
||||
"rays_d": self.rays_d,
|
||||
"mvp_mtx": self.mvp_mtx,
|
||||
"camera_positions": self.camera_position,
|
||||
"light_positions": self.light_position,
|
||||
"elevation": self.elevation_deg,
|
||||
"azimuth": self.azimuth_deg,
|
||||
"camera_distances": self.camera_distance,
|
||||
"rgb": self.rgb,
|
||||
"ref_depth": self.depth,
|
||||
"ref_normal": self.normal,
|
||||
"mask": self.mask,
|
||||
"height": self.cfg.height,
|
||||
"width": self.cfg.width,
|
||||
"c2w": self.c2w,
|
||||
"c2w4x4": self.c2w4x4,
|
||||
}
|
||||
if self.cfg.use_random_camera:
|
||||
batch["random_camera"] = self.random_pose_generator.collate(None)
|
||||
|
||||
return batch
|
||||
|
||||
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
||||
self.update_step_(epoch, global_step, on_load_weights)
|
||||
self.random_pose_generator.update_step(epoch, global_step, on_load_weights)
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield {}
|
||||
|
||||
|
||||
class SingleImageDataset(Dataset, SingleImageDataBase):
|
||||
def __init__(self, cfg: Any, split: str) -> None:
|
||||
super().__init__()
|
||||
self.setup(cfg, split)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.random_pose_generator)
|
||||
|
||||
def __getitem__(self, index):
|
||||
batch = self.random_pose_generator[index]
|
||||
batch.update(
|
||||
{
|
||||
"height": self.random_pose_generator.cfg.eval_height,
|
||||
"width": self.random_pose_generator.cfg.eval_width,
|
||||
"mvp_mtx_ref": self.mvp_mtx[0],
|
||||
"c2w_ref": self.c2w4x4,
|
||||
}
|
||||
)
|
||||
return batch
|
||||
|
||||
|
||||
@register("single-image-datamodule")
|
||||
class SingleImageDataModule(pl.LightningDataModule):
|
||||
cfg: SingleImageDataModuleConfig
|
||||
|
||||
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
|
||||
super().__init__()
|
||||
self.cfg = parse_structured(SingleImageDataModuleConfig, cfg)
|
||||
|
||||
def setup(self, stage=None) -> None:
|
||||
if stage in [None, "fit"]:
|
||||
self.train_dataset = SingleImageIterableDataset(self.cfg, "train")
|
||||
if stage in [None, "fit", "validate"]:
|
||||
self.val_dataset = SingleImageDataset(self.cfg, "val")
|
||||
if stage in [None, "test", "predict"]:
|
||||
self.test_dataset = SingleImageDataset(self.cfg, "test")
|
||||
|
||||
def prepare_data(self):
|
||||
pass
|
||||
|
||||
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
|
||||
return DataLoader(
|
||||
dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(
|
||||
self.train_dataset,
|
||||
batch_size=self.cfg.batch_size,
|
||||
collate_fn=self.train_dataset.collate,
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(self.val_dataset, batch_size=1)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(self.test_dataset, batch_size=1)
|
||||
|
||||
def predict_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(self.test_dataset, batch_size=1)
|
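# Minimal sketch (hypothetical config; expects an RGBA image such as one
# produced by preprocess_image.py):
#   dm = SingleImageDataModule({"image_path": "load/images/example_rgba.png",
#                               "height": 128, "width": 128})
#   dm.setup("fit")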
351
threestudio/data/images.py
Normal file
351
threestudio/data/images.py
Normal file
@ -0,0 +1,351 @@
|
||||
import bisect
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
|
||||
import threestudio
|
||||
from threestudio import register
|
||||
from threestudio.data.uncond import (
|
||||
RandomCameraDataModuleConfig,
|
||||
RandomCameraDataset,
|
||||
RandomCameraIterableDataset,
|
||||
)
|
||||
from threestudio.utils.base import Updateable
|
||||
from threestudio.utils.config import parse_structured
|
||||
from threestudio.utils.misc import get_rank
|
||||
from threestudio.utils.ops import (
|
||||
get_mvp_matrix,
|
||||
get_projection_matrix,
|
||||
get_ray_directions,
|
||||
get_rays,
|
||||
)
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleImageDataModuleConfig:
|
||||
# height and width should be Union[int, List[int]]
|
||||
# but OmegaConf does not support Union of containers
|
||||
height: Any = 96
|
||||
width: Any = 96
|
||||
resolution_milestones: List[int] = field(default_factory=lambda: [])
|
||||
default_elevation_deg: float = 0.0
|
||||
default_azimuth_deg: float = -180.0
|
||||
default_camera_distance: float = 1.2
|
||||
default_fovy_deg: float = 60.0
|
||||
image_path: str = ""
|
||||
use_random_camera: bool = True
|
||||
random_camera: dict = field(default_factory=dict)
|
||||
rays_noise_scale: float = 2e-3
|
||||
batch_size: int = 1
|
||||
requires_depth: bool = False
|
||||
requires_normal: bool = False
|
||||
rays_d_normalize: bool = True
|
||||
use_mixed_camera_config: bool = False
|
||||
|
||||
|
||||
class SingleImageDataBase:
|
||||
def setup(self, cfg, split):
|
||||
self.split = split
|
||||
self.rank = get_rank()
|
||||
self.cfg: SingleImageDataModuleConfig = cfg
|
||||
|
||||
if self.cfg.use_random_camera:
|
||||
random_camera_cfg = parse_structured(
|
||||
RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
|
||||
)
|
||||
# FIXME:
|
||||
if self.cfg.use_mixed_camera_config:
|
||||
if self.rank % 2 == 0:
|
||||
random_camera_cfg.camera_distance_range=[self.cfg.default_camera_distance, self.cfg.default_camera_distance]
|
||||
random_camera_cfg.fovy_range=[self.cfg.default_fovy_deg, self.cfg.default_fovy_deg]
|
||||
self.fixed_camera_intrinsic = True
|
||||
else:
|
||||
self.fixed_camera_intrinsic = False
|
||||
if split == "train":
|
||||
self.random_pose_generator = RandomCameraIterableDataset(
|
||||
random_camera_cfg
|
||||
)
|
||||
else:
|
||||
self.random_pose_generator = RandomCameraDataset(
|
||||
random_camera_cfg, split
|
||||
)
|
||||
|
||||
elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
|
||||
azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
|
||||
camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])
|
||||
|
||||
elevation = elevation_deg * math.pi / 180
|
||||
azimuth = azimuth_deg * math.pi / 180
|
||||
camera_position: Float[Tensor, "1 3"] = torch.stack(
|
||||
[
|
||||
camera_distance * torch.cos(elevation) * torch.cos(azimuth),
|
||||
camera_distance * torch.cos(elevation) * torch.sin(azimuth),
|
||||
camera_distance * torch.sin(elevation),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
|
||||
up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None]
|
||||
|
||||
light_position: Float[Tensor, "1 3"] = camera_position
|
||||
lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
|
||||
right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
|
||||
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
|
||||
self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
|
||||
[torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
|
||||
dim=-1,
|
||||
)
|
||||
self.c2w4x4: Float[Tensor, "B 4 4"] = torch.cat(
|
||||
[self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1
|
||||
)
|
||||
self.c2w4x4[:, 3, 3] = 1.0
|
||||
|
||||
self.camera_position = camera_position
|
||||
self.light_position = light_position
|
||||
self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
|
||||
self.camera_distance = camera_distance
|
||||
self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))
|
||||
|
||||
self.heights: List[int] = (
|
||||
[self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
|
||||
)
|
||||
self.widths: List[int] = (
|
||||
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
|
||||
)
|
||||
assert len(self.heights) == len(self.widths)
|
||||
self.resolution_milestones: List[int]
|
||||
if len(self.heights) == 1 and len(self.widths) == 1:
|
||||
if len(self.cfg.resolution_milestones) > 0:
|
||||
threestudio.warn(
|
||||
"Ignoring resolution_milestones since height and width are not changing"
|
||||
)
|
||||
self.resolution_milestones = [-1]
|
||||
else:
|
||||
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
|
||||
self.resolution_milestones = [-1] + self.cfg.resolution_milestones
|
||||
|
||||
self.directions_unit_focals = [
|
||||
get_ray_directions(H=height, W=width, focal=1.0)
|
||||
for (height, width) in zip(self.heights, self.widths)
|
||||
]
|
||||
self.focal_lengths = [
|
||||
0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
|
||||
]
|
||||
|
||||
self.height: int = self.heights[0]
|
||||
self.width: int = self.widths[0]
|
||||
self.directions_unit_focal = self.directions_unit_focals[0]
|
||||
self.focal_length = self.focal_lengths[0]
|
||||
self.set_rays()
|
||||
self.load_images()
|
||||
self.prev_height = self.height
|
||||
|
||||
def set_rays(self):
|
||||
# get directions by dividing directions_unit_focal by focal length
|
||||
directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None]
|
||||
directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length
|
||||
|
||||
rays_o, rays_d = get_rays(
|
||||
directions,
|
||||
self.c2w,
|
||||
keepdim=True,
|
||||
noise_scale=self.cfg.rays_noise_scale,
|
||||
normalize=self.cfg.rays_d_normalize,
|
||||
)
|
||||
|
||||
proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix(
|
||||
self.fovy, self.width / self.height, 0.01, 100.0
|
||||
) # FIXME: hard-coded near and far
|
||||
mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx)
|
||||
|
||||
self.rays_o, self.rays_d = rays_o, rays_d
|
||||
self.mvp_mtx = mvp_mtx
|
||||
|
||||
def load_images(self):
|
||||
# load image
|
||||
assert os.path.exists(
|
||||
self.cfg.image_path
|
||||
), f"Could not find image {self.cfg.image_path}!"
|
||||
rgba = cv2.cvtColor(
|
||||
cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
|
||||
)
|
||||
rgba = (
|
||||
cv2.resize(
|
||||
rgba, (self.width, self.height), interpolation=cv2.INTER_AREA
|
||||
).astype(np.float32)
|
||||
/ 255.0
|
||||
)
|
||||
rgb = rgba[..., :3]
|
||||
self.rgb: Float[Tensor, "1 H W 3"] = (
|
||||
torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank)
|
||||
)
|
||||
self.mask: Float[Tensor, "1 H W 1"] = (
|
||||
torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank)
|
||||
)
|
||||
print(
|
||||
f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}"
|
||||
)
|
||||
|
||||
# load depth
|
||||
if self.cfg.requires_depth:
|
||||
depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png")
|
||||
assert os.path.exists(depth_path)
|
||||
depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
|
||||
depth = cv2.resize(
|
||||
depth, (self.width, self.height), interpolation=cv2.INTER_AREA
|
||||
)
|
||||
self.depth: Float[Tensor, "1 H W 1"] = (
|
||||
torch.from_numpy(depth.astype(np.float32) / 255.0)
|
||||
.unsqueeze(0)
|
||||
.to(self.rank)
|
||||
)
|
||||
print(
|
||||
f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}"
|
||||
)
|
||||
else:
|
||||
self.depth = None
|
||||
|
||||
# load normal
|
||||
if self.cfg.requires_normal:
|
||||
normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png")
|
||||
assert os.path.exists(normal_path)
|
||||
normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
|
||||
normal = cv2.resize(
|
||||
normal, (self.width, self.height), interpolation=cv2.INTER_AREA
|
||||
)
|
||||
self.normal: Float[Tensor, "1 H W 3"] = (
|
||||
torch.from_numpy(normal.astype(np.float32) / 255.0)
|
||||
.unsqueeze(0)
|
||||
.to(self.rank)
|
||||
)
|
||||
print(
|
||||
f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}"
|
||||
)
|
||||
else:
|
||||
self.normal = None
|
||||
|
||||
def get_all_images(self):
|
||||
return self.rgb
|
||||
|
||||
def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
||||
size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
|
||||
self.height = self.heights[size_ind]
|
||||
if self.height == self.prev_height:
|
||||
return
|
||||
|
||||
self.prev_height = self.height
|
||||
self.width = self.widths[size_ind]
|
||||
self.directions_unit_focal = self.directions_unit_focals[size_ind]
|
||||
self.focal_length = self.focal_lengths[size_ind]
|
||||
threestudio.debug(f"Training height: {self.height}, width: {self.width}")
|
||||
self.set_rays()
|
||||
self.load_images()
|
||||
|
||||
|
||||
class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable):
|
||||
def __init__(self, cfg: Any, split: str) -> None:
|
||||
super().__init__()
|
||||
self.setup(cfg, split)
|
||||
|
||||
def collate(self, batch) -> Dict[str, Any]:
|
||||
batch = {
|
||||
"rays_o": self.rays_o,
|
||||
"rays_d": self.rays_d,
|
||||
"mvp_mtx": self.mvp_mtx,
|
||||
"camera_positions": self.camera_position,
|
||||
"light_positions": self.light_position,
|
||||
"elevation": self.elevation_deg,
|
||||
"azimuth": self.azimuth_deg,
|
||||
"camera_distances": self.camera_distance,
|
||||
"rgb": self.rgb,
|
||||
"ref_depth": self.depth,
|
||||
"ref_normal": self.normal,
|
||||
"mask": self.mask,
|
||||
"height": self.cfg.height,
|
||||
"width": self.cfg.width,
|
||||
"c2w": self.c2w,
|
||||
"c2w4x4": self.c2w4x4,
|
||||
}
|
||||
if self.cfg.use_random_camera:
|
||||
batch["random_camera"] = self.random_pose_generator.collate(None)
|
||||
|
||||
return batch
|
||||
|
||||
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
||||
self.update_step_(epoch, global_step, on_load_weights)
|
||||
self.random_pose_generator.update_step(epoch, global_step, on_load_weights)
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield {}
|
||||
|
||||
|
||||
class SingleImageDataset(Dataset, SingleImageDataBase):
|
||||
def __init__(self, cfg: Any, split: str) -> None:
|
||||
super().__init__()
|
||||
self.setup(cfg, split)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.random_pose_generator)
|
||||
|
||||
def __getitem__(self, index):
|
||||
batch = self.random_pose_generator[index]
|
||||
batch.update(
|
||||
{
|
||||
"height": self.random_pose_generator.cfg.eval_height,
|
||||
"width": self.random_pose_generator.cfg.eval_width,
|
||||
"mvp_mtx_ref": self.mvp_mtx[0],
|
||||
"c2w_ref": self.c2w4x4,
|
||||
}
|
||||
)
|
||||
return batch
|
||||
|
||||
|
||||
@register("single-image-datamodule")
|
||||
class SingleImageDataModule(pl.LightningDataModule):
|
||||
cfg: SingleImageDataModuleConfig
|
||||
|
||||
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
|
||||
super().__init__()
|
||||
self.cfg = parse_structured(SingleImageDataModuleConfig, cfg)
|
||||
|
||||
def setup(self, stage=None) -> None:
|
||||
if stage in [None, "fit"]:
|
||||
self.train_dataset = SingleImageIterableDataset(self.cfg, "train")
|
||||
if stage in [None, "fit", "validate"]:
|
||||
self.val_dataset = SingleImageDataset(self.cfg, "val")
|
||||
if stage in [None, "test", "predict"]:
|
||||
self.test_dataset = SingleImageDataset(self.cfg, "test")
|
||||
|
||||
def prepare_data(self):
|
||||
pass
|
||||
|
||||
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
|
||||
return DataLoader(
|
||||
dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(
|
||||
self.train_dataset,
|
||||
batch_size=self.cfg.batch_size,
|
||||
collate_fn=self.train_dataset.collate,
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(self.val_dataset, batch_size=1)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(self.test_dataset, batch_size=1)
|
||||
|
||||
def predict_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(self.test_dataset, batch_size=1)
|
518
threestudio/data/uncond.py
Normal file
518
threestudio/data/uncond.py
Normal file
@ -0,0 +1,518 @@
|
||||
import bisect
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
|
||||
import threestudio
|
||||
from threestudio import register
|
||||
from threestudio.utils.base import Updateable
|
||||
from threestudio.utils.config import parse_structured
|
||||
from threestudio.utils.misc import get_device
|
||||
from threestudio.utils.ops import (
|
||||
get_full_projection_matrix,
|
||||
get_mvp_matrix,
|
||||
get_projection_matrix,
|
||||
get_ray_directions,
|
||||
get_rays,
|
||||
)
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class RandomCameraDataModuleConfig:
|
||||
# height, width, and batch_size should be Union[int, List[int]]
|
||||
# but OmegaConf does not support Union of containers
|
||||
height: Any = 64
|
||||
width: Any = 64
|
||||
batch_size: Any = 1
|
||||
resolution_milestones: List[int] = field(default_factory=lambda: [])
|
||||
eval_height: int = 512
|
||||
eval_width: int = 512
|
||||
eval_batch_size: int = 1
|
||||
n_val_views: int = 1
|
||||
n_test_views: int = 120
|
||||
elevation_range: Tuple[float, float] = (-10, 90)
|
||||
azimuth_range: Tuple[float, float] = (-180, 180)
|
||||
camera_distance_range: Tuple[float, float] = (1, 1.5)
|
||||
fovy_range: Tuple[float, float] = (
|
||||
40,
|
||||
70,
|
||||
) # in degrees, in vertical direction (along height)
|
||||
camera_perturb: float = 0.1
|
||||
center_perturb: float = 0.2
|
||||
up_perturb: float = 0.02
|
||||
light_position_perturb: float = 1.0
|
||||
light_distance_range: Tuple[float, float] = (0.8, 1.5)
|
||||
eval_elevation_deg: float = 15.0
|
||||
eval_camera_distance: float = 1.5
|
||||
eval_fovy_deg: float = 70.0
|
||||
light_sample_strategy: str = "dreamfusion"
|
||||
batch_uniform_azimuth: bool = True
|
||||
progressive_until: int = 0 # progressive ranges for elevation, azimuth, r, fovy
|
||||
|
||||
rays_d_normalize: bool = True
|
||||
|
||||
|
||||
class RandomCameraIterableDataset(IterableDataset, Updateable):
|
||||
def __init__(self, cfg: Any) -> None:
|
||||
super().__init__()
|
||||
self.cfg: RandomCameraDataModuleConfig = cfg
|
||||
self.heights: List[int] = (
|
||||
[self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
|
||||
)
|
||||
self.widths: List[int] = (
|
||||
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
|
||||
)
|
||||
self.batch_sizes: List[int] = (
|
||||
[self.cfg.batch_size]
|
||||
if isinstance(self.cfg.batch_size, int)
|
||||
else self.cfg.batch_size
|
||||
)
|
||||
assert len(self.heights) == len(self.widths) == len(self.batch_sizes)
|
||||
self.resolution_milestones: List[int]
|
||||
if (
|
||||
len(self.heights) == 1
|
||||
and len(self.widths) == 1
|
||||
and len(self.batch_sizes) == 1
|
||||
):
|
||||
if len(self.cfg.resolution_milestones) > 0:
|
||||
threestudio.warn(
|
||||
"Ignoring resolution_milestones since height and width are not changing"
|
||||
)
|
||||
self.resolution_milestones = [-1]
|
||||
else:
|
||||
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
|
||||
self.resolution_milestones = [-1] + self.cfg.resolution_milestones
|
||||
|
||||
self.directions_unit_focals = [
|
||||
get_ray_directions(H=height, W=width, focal=1.0)
|
||||
for (height, width) in zip(self.heights, self.widths)
|
||||
]
|
||||
self.height: int = self.heights[0]
|
||||
self.width: int = self.widths[0]
|
||||
self.batch_size: int = self.batch_sizes[0]
|
||||
self.directions_unit_focal = self.directions_unit_focals[0]
|
||||
self.elevation_range = self.cfg.elevation_range
|
||||
self.azimuth_range = self.cfg.azimuth_range
|
||||
self.camera_distance_range = self.cfg.camera_distance_range
|
||||
self.fovy_range = self.cfg.fovy_range
|
||||
|
||||
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
||||
size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
|
||||
self.height = self.heights[size_ind]
|
||||
self.width = self.widths[size_ind]
|
||||
self.batch_size = self.batch_sizes[size_ind]
|
||||
self.directions_unit_focal = self.directions_unit_focals[size_ind]
|
||||
threestudio.debug(
|
||||
f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}"
|
||||
)
|
||||
# progressive view
|
||||
self.progressive_view(global_step)
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield {}
|
||||
|
||||
def progressive_view(self, global_step):
|
||||
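# r ramps linearly from 0 to 1 over the first (progressive_until + 1) steps,
# widening elevation/azimuth from the fixed eval pose toward the full ranges.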
r = min(1.0, global_step / (self.cfg.progressive_until + 1))
|
||||
self.elevation_range = [
|
||||
(1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[0],
|
||||
(1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[1],
|
||||
]
|
||||
self.azimuth_range = [
|
||||
(1 - r) * 0.0 + r * self.cfg.azimuth_range[0],
|
||||
(1 - r) * 0.0 + r * self.cfg.azimuth_range[1],
|
||||
]
|
||||
# self.camera_distance_range = [
|
||||
# (1 - r) * self.cfg.eval_camera_distance
|
||||
# + r * self.cfg.camera_distance_range[0],
|
||||
# (1 - r) * self.cfg.eval_camera_distance
|
||||
# + r * self.cfg.camera_distance_range[1],
|
||||
# ]
|
||||
# self.fovy_range = [
|
||||
# (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[0],
|
||||
# (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[1],
|
||||
# ]
|
||||
|
||||
def collate(self, batch) -> Dict[str, Any]:
|
||||
# sample elevation angles
|
||||
elevation_deg: Float[Tensor, "B"]
|
||||
elevation: Float[Tensor, "B"]
|
||||
if random.random() < 0.5:
|
||||
# sample elevation angles uniformly with a probability 0.5 (biased towards poles)
|
||||
elevation_deg = (
|
||||
torch.rand(self.batch_size)
|
||||
* (self.elevation_range[1] - self.elevation_range[0])
|
||||
+ self.elevation_range[0]
|
||||
)
|
||||
elevation = elevation_deg * math.pi / 180
|
||||
else:
|
||||
# otherwise sample uniformly on sphere
|
||||
elevation_range_percent = [
|
||||
self.elevation_range[0] / 180.0 * math.pi,
|
||||
self.elevation_range[1] / 180.0 * math.pi,
|
||||
]
|
||||
# inverse transform sampling
|
||||
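# drawing sin(elevation) uniformly in [sin(lo), sin(hi)] yields directions
# uniformly distributed on the corresponding spherical cap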
elevation = torch.asin(
|
||||
(
|
||||
torch.rand(self.batch_size)
|
||||
* (
|
||||
math.sin(elevation_range_percent[1])
|
||||
- math.sin(elevation_range_percent[0])
|
||||
)
|
||||
+ math.sin(elevation_range_percent[0])
|
||||
)
|
||||
)
|
||||
elevation_deg = elevation / math.pi * 180.0
|
||||
|
||||
# sample azimuth angles from a uniform distribution bounded by azimuth_range
|
||||
azimuth_deg: Float[Tensor, "B"]
|
||||
if self.cfg.batch_uniform_azimuth:
|
||||
# ensures sampled azimuth angles in a batch cover the whole range
|
||||
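# stratified sampling: view i draws uniformly from the i-th of batch_size
# equal azimuth bins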
azimuth_deg = (
|
||||
torch.rand(self.batch_size) + torch.arange(self.batch_size)
|
||||
) / self.batch_size * (
|
||||
self.azimuth_range[1] - self.azimuth_range[0]
|
||||
) + self.azimuth_range[
|
||||
0
|
||||
]
|
||||
else:
|
||||
# simple random sampling
|
||||
azimuth_deg = (
|
||||
torch.rand(self.batch_size)
|
||||
* (self.azimuth_range[1] - self.azimuth_range[0])
|
||||
+ self.azimuth_range[0]
|
||||
)
|
||||
azimuth = azimuth_deg * math.pi / 180
|
||||
|
||||
# sample distances from a uniform distribution bounded by distance_range
|
||||
camera_distances: Float[Tensor, "B"] = (
|
||||
torch.rand(self.batch_size)
|
||||
* (self.camera_distance_range[1] - self.camera_distance_range[0])
|
||||
+ self.camera_distance_range[0]
|
||||
)
|
||||
|
||||
# convert spherical coordinates to cartesian coordinates
|
||||
# right hand coordinate system, x back, y right, z up
|
||||
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
|
||||
camera_positions: Float[Tensor, "B 3"] = torch.stack(
|
||||
[
|
||||
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
|
||||
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
|
||||
camera_distances * torch.sin(elevation),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# default scene center at origin
|
||||
center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
|
||||
# default camera up direction as +z
|
||||
up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
|
||||
None, :
|
||||
].repeat(self.batch_size, 1)
|
||||
|
||||
# sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb]
|
||||
camera_perturb: Float[Tensor, "B 3"] = (
|
||||
torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb
|
||||
- self.cfg.camera_perturb
|
||||
)
|
||||
camera_positions = camera_positions + camera_perturb
|
||||
# sample center perturbations from a normal distribution with mean 0 and std center_perturb
|
||||
center_perturb: Float[Tensor, "B 3"] = (
|
||||
torch.randn(self.batch_size, 3) * self.cfg.center_perturb
|
||||
)
|
||||
center = center + center_perturb
|
||||
# sample up perturbations from a normal distribution with mean 0 and std up_perturb
|
||||
up_perturb: Float[Tensor, "B 3"] = (
|
||||
torch.randn(self.batch_size, 3) * self.cfg.up_perturb
|
||||
)
|
||||
up = up + up_perturb
|
||||
|
||||
# sample fovs from a uniform distribution bounded by fov_range
|
||||
fovy_deg: Float[Tensor, "B"] = (
|
||||
torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0])
|
||||
+ self.fovy_range[0]
|
||||
)
|
||||
fovy = fovy_deg * math.pi / 180
|
||||
|
||||
# sample light distance from a uniform distribution bounded by light_distance_range
|
||||
light_distances: Float[Tensor, "B"] = (
|
||||
torch.rand(self.batch_size)
|
||||
* (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0])
|
||||
+ self.cfg.light_distance_range[0]
|
||||
)
|
||||
|
||||
if self.cfg.light_sample_strategy == "dreamfusion":
|
||||
# sample light direction from a normal distribution with mean camera_position and std light_position_perturb
|
||||
light_direction: Float[Tensor, "B 3"] = F.normalize(
|
||||
camera_positions
|
||||
+ torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb,
|
||||
dim=-1,
|
||||
)
|
||||
# get light position by scaling light direction by light distance
|
||||
light_positions: Float[Tensor, "B 3"] = (
|
||||
light_direction * light_distances[:, None]
|
||||
)
|
||||
elif self.cfg.light_sample_strategy == "magic3d":
|
||||
# sample light direction within restricted angle range (pi/3)
|
||||
local_z = F.normalize(camera_positions, dim=-1)
|
||||
local_x = F.normalize(
|
||||
torch.stack(
|
||||
[local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])],
|
||||
dim=-1,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)
|
||||
rot = torch.stack([local_x, local_y, local_z], dim=-1)
|
||||
light_azimuth = (
|
||||
torch.rand(self.batch_size) * math.pi * 2 - math.pi
|
||||
) # [-pi, pi]
|
||||
light_elevation = (
|
||||
torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6
|
||||
) # [pi/6, pi/2]
|
||||
light_positions_local = torch.stack(
|
||||
[
|
||||
light_distances
|
||||
* torch.cos(light_elevation)
|
||||
* torch.cos(light_azimuth),
|
||||
light_distances
|
||||
* torch.cos(light_elevation)
|
||||
* torch.sin(light_azimuth),
|
||||
light_distances * torch.sin(light_elevation),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown light sample strategy: {self.cfg.light_sample_strategy}"
|
||||
)
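# Illustrative sketch (standalone, not part of this commit): the "magic3d"
# branch builds a per-camera orthonormal frame with local_z pointing from the
# origin to the camera, then samples the light in that frame so it stays
# within a restricted cone around the view direction.
import torch
import torch.nn.functional as F

cam = torch.tensor([[2.0, 0.0, 0.0]])
local_z = F.normalize(cam, dim=-1)
local_x = F.normalize(
    torch.stack([local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])], dim=-1),
    dim=-1,
)
local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)
rot = torch.stack([local_x, local_y, local_z], dim=-1)
print(rot @ rot.transpose(-1, -2))  # identity matrix: the frame is orthonormal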
|
||||
|
||||
lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
|
||||
right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
|
||||
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
|
||||
c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
|
||||
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
|
||||
dim=-1,
|
||||
)
|
||||
c2w: Float[Tensor, "B 4 4"] = torch.cat(
|
||||
[c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
|
||||
)
|
||||
c2w[:, 3, 3] = 1.0
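# Illustrative sketch (standalone, not part of this commit): the 4x4 matrix
# assembled above is a standard look-at camera-to-world transform with columns
# (right, up, -lookat) and the camera position as translation.
import torch
import torch.nn.functional as F

eye = torch.tensor([[2.0, 0.0, 0.0]])
center = torch.zeros_like(eye)
up = torch.tensor([[0.0, 0.0, 1.0]])
lookat = F.normalize(center - eye, dim=-1)
right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
c2w3x4 = torch.cat([torch.stack([right, up, -lookat], dim=-1), eye[:, :, None]], dim=-1)
print(c2w3x4[0])  # camera at +x looking at the origin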
|
||||
|
||||
# get directions by dividing directions_unit_focal by focal length
|
||||
focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy)
|
||||
directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[
|
||||
None, :, :, :
|
||||
].repeat(self.batch_size, 1, 1, 1)
|
||||
directions[:, :, :, :2] = (
|
||||
directions[:, :, :, :2] / focal_length[:, None, None, None]
|
||||
)
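# Illustrative sketch (standalone, not part of this commit): the pinhole
# relation used above. For a vertical field of view fovy, the focal length in
# pixels is 0.5 * H / tan(0.5 * fovy); dividing the unit-focal ray directions
# by it rescales them to that field of view.
import math

H, fovy_deg = 256, 60.0
focal = 0.5 * H / math.tan(0.5 * math.radians(fovy_deg))
print(focal)  # ~221.7 pixels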
|
||||
|
||||
# Important note: the returned rays_d MUST be normalized!
|
||||
rays_o, rays_d = get_rays(
|
||||
directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize
|
||||
)
|
||||
|
||||
self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
|
||||
fovy, self.width / self.height, 0.1, 1000.0
|
||||
) # FIXME: hard-coded near and far
|
||||
mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx)
|
||||
self.fovy = fovy
|
||||
|
||||
return {
|
||||
"rays_o": rays_o,
|
||||
"rays_d": rays_d,
|
||||
"mvp_mtx": mvp_mtx,
|
||||
"camera_positions": camera_positions,
|
||||
"c2w": c2w,
|
||||
"light_positions": light_positions,
|
||||
"elevation": elevation_deg,
|
||||
"azimuth": azimuth_deg,
|
||||
"camera_distances": camera_distances,
|
||||
"height": self.height,
|
||||
"width": self.width,
|
||||
"fovy": self.fovy,
|
||||
"proj_mtx": self.proj_mtx,
|
||||
}
|
||||
|
||||
|
||||
class RandomCameraDataset(Dataset):
|
||||
def __init__(self, cfg: Any, split: str) -> None:
|
||||
super().__init__()
|
||||
self.cfg: RandomCameraDataModuleConfig = cfg
|
||||
self.split = split
|
||||
|
||||
if split == "val":
|
||||
self.n_views = self.cfg.n_val_views
|
||||
else:
|
||||
self.n_views = self.cfg.n_test_views
|
||||
|
||||
azimuth_deg: Float[Tensor, "B"]
|
||||
if self.split == "val":
|
||||
# make sure the first and last view are not the same
|
||||
azimuth_deg = torch.linspace(0, 360.0, self.n_views + 1)[: self.n_views]
|
||||
else:
|
||||
azimuth_deg = torch.linspace(0, 360.0, self.n_views)
|
||||
elevation_deg: Float[Tensor, "B"] = torch.full_like(
|
||||
azimuth_deg, self.cfg.eval_elevation_deg
|
||||
)
|
||||
camera_distances: Float[Tensor, "B"] = torch.full_like(
|
||||
elevation_deg, self.cfg.eval_camera_distance
|
||||
)
|
||||
|
||||
elevation = elevation_deg * math.pi / 180
|
||||
azimuth = azimuth_deg * math.pi / 180
|
||||
|
||||
# convert spherical coordinates to cartesian coordinates
|
||||
# right hand coordinate system, x back, y right, z up
|
||||
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
|
||||
camera_positions: Float[Tensor, "B 3"] = torch.stack(
|
||||
[
|
||||
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
|
||||
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
|
||||
camera_distances * torch.sin(elevation),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# default scene center at origin
|
||||
center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
|
||||
# default camera up direction as +z
|
||||
up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
|
||||
None, :
|
||||
].repeat(self.n_views, 1)
|
||||
|
||||
fovy_deg: Float[Tensor, "B"] = torch.full_like(
|
||||
elevation_deg, self.cfg.eval_fovy_deg
|
||||
)
|
||||
fovy = fovy_deg * math.pi / 180
|
||||
light_positions: Float[Tensor, "B 3"] = camera_positions
|
||||
|
||||
lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
|
||||
right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
|
||||
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
|
||||
c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
|
||||
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
|
||||
dim=-1,
|
||||
)
|
||||
c2w: Float[Tensor, "B 4 4"] = torch.cat(
|
||||
[c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
|
||||
)
|
||||
c2w[:, 3, 3] = 1.0
|
||||
|
||||
# get directions by dividing directions_unit_focal by focal length
|
||||
focal_length: Float[Tensor, "B"] = (
|
||||
0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
|
||||
)
|
||||
directions_unit_focal = get_ray_directions(
|
||||
H=self.cfg.eval_height, W=self.cfg.eval_width, focal=1.0
|
||||
)
|
||||
directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
|
||||
None, :, :, :
|
||||
].repeat(self.n_views, 1, 1, 1)
|
||||
directions[:, :, :, :2] = (
|
||||
directions[:, :, :, :2] / focal_length[:, None, None, None]
|
||||
)
|
||||
|
||||
rays_o, rays_d = get_rays(
|
||||
directions, c2w, keepdim=True, normalize=self.cfg.rays_d_normalize
|
||||
)
|
||||
self.proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
|
||||
fovy, self.cfg.eval_width / self.cfg.eval_height, 0.1, 1000.0
|
||||
) # FIXME: hard-coded near and far
|
||||
mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, self.proj_mtx)
|
||||
|
||||
self.rays_o, self.rays_d = rays_o, rays_d
|
||||
self.mvp_mtx = mvp_mtx
|
||||
self.c2w = c2w
|
||||
self.camera_positions = camera_positions
|
||||
self.light_positions = light_positions
|
||||
self.elevation, self.azimuth = elevation, azimuth
|
||||
self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
|
||||
self.camera_distances = camera_distances
|
||||
self.fovy = fovy
|
||||
|
||||
def __len__(self):
|
||||
return self.n_views
|
||||
|
||||
def __getitem__(self, index):
|
||||
return {
|
||||
"index": index,
|
||||
"rays_o": self.rays_o[index],
|
||||
"rays_d": self.rays_d[index],
|
||||
"mvp_mtx": self.mvp_mtx[index],
|
||||
"c2w": self.c2w[index],
|
||||
"camera_positions": self.camera_positions[index],
|
||||
"light_positions": self.light_positions[index],
|
||||
"elevation": self.elevation_deg[index],
|
||||
"azimuth": self.azimuth_deg[index],
|
||||
"camera_distances": self.camera_distances[index],
|
||||
"height": self.cfg.eval_height,
|
||||
"width": self.cfg.eval_width,
|
||||
"fovy": self.fovy[index],
|
||||
"proj_mtx": self.proj_mtx[index],
|
||||
}
|
||||
|
||||
def collate(self, batch):
|
||||
batch = torch.utils.data.default_collate(batch)
|
||||
batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
|
||||
return batch
|
||||
|
||||
|
||||
@register("random-camera-datamodule")
|
||||
class RandomCameraDataModule(pl.LightningDataModule):
|
||||
cfg: RandomCameraDataModuleConfig
|
||||
|
||||
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
|
||||
super().__init__()
|
||||
self.cfg = parse_structured(RandomCameraDataModuleConfig, cfg)
|
||||
|
||||
def setup(self, stage=None) -> None:
|
||||
if stage in [None, "fit"]:
|
||||
self.train_dataset = RandomCameraIterableDataset(self.cfg)
|
||||
if stage in [None, "fit", "validate"]:
|
||||
self.val_dataset = RandomCameraDataset(self.cfg, "val")
|
||||
if stage in [None, "test", "predict"]:
|
||||
self.test_dataset = RandomCameraDataset(self.cfg, "test")
|
||||
|
||||
def prepare_data(self):
|
||||
pass
|
||||
|
||||
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
|
||||
return DataLoader(
|
||||
dataset,
|
||||
# very important to disable multi-processing if you want to change self attributes at runtime!
|
||||
# (for example setting self.width and self.height in update_step)
|
||||
num_workers=0, # type: ignore
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(
|
||||
self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(
|
||||
self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate
|
||||
)
|
||||
# return self.general_loader(self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(
|
||||
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
|
||||
)
|
||||
|
||||
def predict_dataloader(self) -> DataLoader:
|
||||
return self.general_loader(
|
||||
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
|
||||
)
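# Illustrative sketch (standalone, assumes a working threestudio environment;
# the config keys are hypothetical examples): typical use of the datamodule
# defined above.
dm = RandomCameraDataModule({"batch_size": 1, "width": 64, "height": 64})
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
print(batch["rays_o"].shape, batch["rays_d"].shape)  # torch.Size([1, 64, 64, 3]) each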
|
9
threestudio/models/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
from . import (
|
||||
background,
|
||||
exporters,
|
||||
geometry,
|
||||
guidance,
|
||||
materials,
|
||||
prompt_processors,
|
||||
renderers,
|
||||
)
|
6
threestudio/models/background/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from . import (
|
||||
base,
|
||||
neural_environment_map_background,
|
||||
solid_color_background,
|
||||
textured_background,
|
||||
)
|
24
threestudio/models/background/base.py
Normal file
@ -0,0 +1,24 @@
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import threestudio
|
||||
from threestudio.utils.base import BaseModule
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
class BaseBackground(BaseModule):
|
||||
@dataclass
|
||||
class Config(BaseModule.Config):
|
||||
pass
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
|
||||
raise NotImplementedError
|
71
threestudio/models/background/neural_environment_map_background.py
Normal file
@ -0,0 +1,71 @@
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.background.base import BaseBackground
|
||||
from threestudio.models.networks import get_encoding, get_mlp
|
||||
from threestudio.utils.ops import get_activation
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@threestudio.register("neural-environment-map-background")
|
||||
class NeuralEnvironmentMapBackground(BaseBackground):
|
||||
@dataclass
|
||||
class Config(BaseBackground.Config):
|
||||
n_output_dims: int = 3
|
||||
color_activation: str = "sigmoid"
|
||||
dir_encoding_config: dict = field(
|
||||
default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3}
|
||||
)
|
||||
mlp_network_config: dict = field(
|
||||
default_factory=lambda: {
|
||||
"otype": "VanillaMLP",
|
||||
"activation": "ReLU",
|
||||
"n_neurons": 16,
|
||||
"n_hidden_layers": 2,
|
||||
}
|
||||
)
|
||||
random_aug: bool = False
|
||||
random_aug_prob: float = 0.5
|
||||
eval_color: Optional[Tuple[float, float, float]] = None
|
||||
|
||||
# multi-view diffusion
|
||||
share_aug_bg: bool = False
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self) -> None:
|
||||
self.encoding = get_encoding(3, self.cfg.dir_encoding_config)
|
||||
self.network = get_mlp(
|
||||
self.encoding.n_output_dims,
|
||||
self.cfg.n_output_dims,
|
||||
self.cfg.mlp_network_config,
|
||||
)
|
||||
|
||||
def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
|
||||
if not self.training and self.cfg.eval_color is not None:
|
||||
return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(
|
||||
dirs
|
||||
) * torch.as_tensor(self.cfg.eval_color).to(dirs)
|
||||
# viewdirs must be normalized before passing to this function
|
||||
dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1)
|
||||
dirs_embd = self.encoding(dirs.view(-1, 3))
|
||||
color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims)
|
||||
color = get_activation(self.cfg.color_activation)(color)
|
||||
if (
|
||||
self.training
|
||||
and self.cfg.random_aug
|
||||
and random.random() < self.cfg.random_aug_prob
|
||||
):
|
||||
# use random background color with probability random_aug_prob
|
||||
n_color = 1 if self.cfg.share_aug_bg else dirs.shape[0]
|
||||
color = color * 0 + ( # prevent checking for unused parameters in DDP
|
||||
torch.rand(n_color, 1, 1, self.cfg.n_output_dims)
|
||||
.to(dirs)
|
||||
.expand(*dirs.shape[:-1], -1)
|
||||
)
|
||||
return color
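# Illustrative sketch (standalone, not part of this commit): the
# `color * 0 + (...)` pattern above keeps the MLP output inside the autograd
# graph even when the random background replaces it, so DDP does not flag the
# network parameters as unused.
import torch

mlp_out = torch.randn(4, requires_grad=True)
random_bg = torch.rand(4)
color = mlp_out * 0 + random_bg
color.sum().backward()
print(mlp_out.grad)  # zero gradients, but the graph still reaches mlp_out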
|
51
threestudio/models/background/solid_color_background.py
Normal file
@ -0,0 +1,51 @@
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.background.base import BaseBackground
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@threestudio.register("solid-color-background")
|
||||
class SolidColorBackground(BaseBackground):
|
||||
@dataclass
|
||||
class Config(BaseBackground.Config):
|
||||
n_output_dims: int = 3
|
||||
color: Tuple = (1.0, 1.0, 1.0)
|
||||
learned: bool = False
|
||||
random_aug: bool = False
|
||||
random_aug_prob: float = 0.5
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self) -> None:
|
||||
self.env_color: Float[Tensor, "Nc"]
|
||||
if self.cfg.learned:
|
||||
self.env_color = nn.Parameter(
|
||||
torch.as_tensor(self.cfg.color, dtype=torch.float32)
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32)
|
||||
)
|
||||
|
||||
def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
|
||||
color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(
|
||||
dirs
|
||||
) * self.env_color.to(dirs)
|
||||
if (
|
||||
self.training
|
||||
and self.cfg.random_aug
|
||||
and random.random() < self.cfg.random_aug_prob
|
||||
):
|
||||
# use random background color with probability random_aug_prob
|
||||
color = color * 0 + ( # prevent checking for unused parameters in DDP
|
||||
torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims)
|
||||
.to(dirs)
|
||||
.expand(*dirs.shape[:-1], -1)
|
||||
)
|
||||
return color
|
54
threestudio/models/background/textured_background.py
Normal file
@ -0,0 +1,54 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.background.base import BaseBackground
|
||||
from threestudio.utils.ops import get_activation
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@threestudio.register("textured-background")
|
||||
class TexturedBackground(BaseBackground):
|
||||
@dataclass
|
||||
class Config(BaseBackground.Config):
|
||||
n_output_dims: int = 3
|
||||
height: int = 64
|
||||
width: int = 64
|
||||
color_activation: str = "sigmoid"
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self) -> None:
|
||||
self.texture = nn.Parameter(
|
||||
torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width))
|
||||
)
|
||||
|
||||
def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]:
|
||||
x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2]
|
||||
xy = (x**2 + y**2) ** 0.5
|
||||
u = torch.atan2(xy, z) / torch.pi
|
||||
v = torch.atan2(y, x) / (torch.pi * 2) + 0.5
|
||||
uv = torch.stack([u, v], -1)
|
||||
return uv
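# Illustrative sketch (standalone copy of the method above, not part of this
# commit): sanity-checking the equirectangular mapping. +z maps to u=0 (the
# top of the texture) and the azimuth wraps into v in [0, 1).
import torch

def spherical_xyz_to_uv(dirs):
    x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2]
    xy = (x**2 + y**2) ** 0.5
    u = torch.atan2(xy, z) / torch.pi
    v = torch.atan2(y, x) / (torch.pi * 2) + 0.5
    return torch.stack([u, v], -1)

print(spherical_xyz_to_uv(torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])))
# -> [[0.0, 0.5], [0.5, 0.5]]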
|
||||
|
||||
def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]:
|
||||
dirs_shape = dirs.shape[:-1]
|
||||
uv = self.spherical_xyz_to_uv(dirs.reshape(-1, dirs.shape[-1]))
|
||||
uv = 2 * uv - 1 # rescale to [-1, 1] for grid_sample
|
||||
uv = uv.reshape(1, -1, 1, 2)
|
||||
color = (
|
||||
F.grid_sample(
|
||||
self.texture,
|
||||
uv,
|
||||
mode="bilinear",
|
||||
padding_mode="reflection",
|
||||
align_corners=False,
|
||||
)
|
||||
.reshape(self.cfg.n_output_dims, -1)
|
||||
.T.reshape(*dirs_shape, self.cfg.n_output_dims)
|
||||
)
|
||||
color = get_activation(self.cfg.color_activation)(color)
|
||||
return color
|
118
threestudio/models/estimators.py
Normal file
@ -0,0 +1,118 @@
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
from typing_extensions import Literal
|
||||
|
||||
import torch
|
||||
from nerfacc.data_specs import RayIntervals
|
||||
from nerfacc.estimators.base import AbstractEstimator
|
||||
from nerfacc.pdf import importance_sampling, searchsorted
|
||||
from nerfacc.volrend import render_transmittance_from_density
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ImportanceEstimator(AbstractEstimator):
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@torch.no_grad()
|
||||
def sampling(
|
||||
self,
|
||||
prop_sigma_fns: List[Callable],
|
||||
prop_samples: List[int],
|
||||
num_samples: int,
|
||||
# rendering options
|
||||
n_rays: int,
|
||||
near_plane: float,
|
||||
far_plane: float,
|
||||
sampling_type: Literal["uniform", "lindisp"] = "uniform",
|
||||
# training options
|
||||
stratified: bool = False,
|
||||
requires_grad: bool = False,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Sampling with CDFs from proposal networks.
|
||||
|
||||
Args:
|
||||
prop_sigma_fns: Proposal network evaluate functions. It should be a list
|
||||
of functions that take in samples {t_starts (n_rays, n_samples),
|
||||
t_ends (n_rays, n_samples)} and returns the post-activation densities
|
||||
(n_rays, n_samples).
|
||||
prop_samples: Number of samples to draw from each proposal network. Should
|
||||
be the same length as `prop_sigma_fns`.
|
||||
num_samples: Number of samples to draw in the end.
|
||||
n_rays: Number of rays.
|
||||
near_plane: Near plane.
|
||||
far_plane: Far plane.
|
||||
sampling_type: Sampling type. Either "uniform" or "lindisp". Defaults to
|
||||
"lindisp".
|
||||
stratified: Whether to use stratified sampling. Defaults to `False`.
|
||||
|
||||
Returns:
|
||||
A tuple of {Tensor, Tensor}:
|
||||
|
||||
- **t_starts**: The starts of the samples. Shape (n_rays, num_samples).
|
||||
- **t_ends**: The ends of the samples. Shape (n_rays, num_samples).
|
||||
|
||||
"""
|
||||
assert len(prop_sigma_fns) == len(prop_samples), (
|
||||
"The number of proposal networks and the number of samples "
|
||||
"should be the same."
|
||||
)
|
||||
cdfs = torch.cat(
|
||||
[
|
||||
torch.zeros((n_rays, 1), device=self.device),
|
||||
torch.ones((n_rays, 1), device=self.device),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
intervals = RayIntervals(vals=cdfs)
|
||||
|
||||
for level_fn, level_samples in zip(prop_sigma_fns, prop_samples):
|
||||
intervals, _ = importance_sampling(
|
||||
intervals, cdfs, level_samples, stratified
|
||||
)
|
||||
t_vals = _transform_stot(
|
||||
sampling_type, intervals.vals, near_plane, far_plane
|
||||
)
|
||||
t_starts = t_vals[..., :-1]
|
||||
t_ends = t_vals[..., 1:]
|
||||
|
||||
with torch.set_grad_enabled(requires_grad):
|
||||
sigmas = level_fn(t_starts, t_ends)
|
||||
assert sigmas.shape == t_starts.shape
|
||||
trans, _ = render_transmittance_from_density(t_starts, t_ends, sigmas)
|
||||
cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1)
|
||||
|
||||
intervals, _ = importance_sampling(intervals, cdfs, num_samples, stratified)
|
||||
t_vals_fine = _transform_stot(
|
||||
sampling_type, intervals.vals, near_plane, far_plane
|
||||
)
|
||||
|
||||
t_vals = torch.cat([t_vals, t_vals_fine], dim=-1)
|
||||
t_vals, _ = torch.sort(t_vals, dim=-1)
|
||||
|
||||
t_starts_ = t_vals[..., :-1]
|
||||
t_ends_ = t_vals[..., 1:]
|
||||
|
||||
return t_starts_, t_ends_
|
||||
|
||||
|
||||
def _transform_stot(
|
||||
transform_type: Literal["uniform", "lindisp"],
|
||||
s_vals: torch.Tensor,
|
||||
t_min: torch.Tensor,
|
||||
t_max: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if transform_type == "uniform":
|
||||
_contract_fn, _icontract_fn = lambda x: x, lambda x: x
|
||||
elif transform_type == "lindisp":
|
||||
_contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x
|
||||
else:
|
||||
raise ValueError(f"Unknown transform_type: {transform_type}")
|
||||
s_min, s_max = _contract_fn(t_min), _contract_fn(t_max)
|
||||
icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min)
|
||||
return icontract_fn(s_vals)
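# Illustrative sketch (standalone, not part of this commit): for "lindisp" the
# helper above makes disparity (1/t) linear in s, which concentrates samples
# near the camera.
near, far = 0.1, 100.0
for s in (0.0, 0.5, 1.0):
    t = 1.0 / (s * (1.0 / far) + (1.0 - s) * (1.0 / near))
    print(s, t)  # 0 -> 0.1, 0.5 -> ~0.2, 1 -> 100.0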
|
1
threestudio/models/exporters/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import base, mesh_exporter
|
59
threestudio/models/exporters/base.py
Normal file
@ -0,0 +1,59 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.background.base import BaseBackground
|
||||
from threestudio.models.geometry.base import BaseImplicitGeometry
|
||||
from threestudio.models.materials.base import BaseMaterial
|
||||
from threestudio.utils.base import BaseObject
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExporterOutput:
|
||||
save_name: str
|
||||
save_type: str
|
||||
params: Dict[str, Any]
|
||||
|
||||
|
||||
class Exporter(BaseObject):
|
||||
@dataclass
|
||||
class Config(BaseObject.Config):
|
||||
save_video: bool = False
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(
|
||||
self,
|
||||
geometry: BaseImplicitGeometry,
|
||||
material: BaseMaterial,
|
||||
background: BaseBackground,
|
||||
) -> None:
|
||||
@dataclass
|
||||
class SubModules:
|
||||
geometry: BaseImplicitGeometry
|
||||
material: BaseMaterial
|
||||
background: BaseBackground
|
||||
|
||||
self.sub_modules = SubModules(geometry, material, background)
|
||||
|
||||
@property
|
||||
def geometry(self) -> BaseImplicitGeometry:
|
||||
return self.sub_modules.geometry
|
||||
|
||||
@property
|
||||
def material(self) -> BaseMaterial:
|
||||
return self.sub_modules.material
|
||||
|
||||
@property
|
||||
def background(self) -> BaseBackground:
|
||||
return self.sub_modules.background
|
||||
|
||||
def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@threestudio.register("dummy-exporter")
|
||||
class DummyExporter(Exporter):
|
||||
def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
|
||||
# DummyExporter does not export anything
|
||||
return []
|
175
threestudio/models/exporters/mesh_exporter.py
Normal file
@ -0,0 +1,175 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.background.base import BaseBackground
|
||||
from threestudio.models.exporters.base import Exporter, ExporterOutput
|
||||
from threestudio.models.geometry.base import BaseImplicitGeometry
|
||||
from threestudio.models.materials.base import BaseMaterial
|
||||
from threestudio.models.mesh import Mesh
|
||||
from threestudio.utils.rasterize import NVDiffRasterizerContext
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@threestudio.register("mesh-exporter")
|
||||
class MeshExporter(Exporter):
|
||||
@dataclass
|
||||
class Config(Exporter.Config):
|
||||
fmt: str = "obj-mtl" # in ['obj-mtl', 'obj'], TODO: fbx
|
||||
save_name: str = "model"
|
||||
save_normal: bool = False
|
||||
save_uv: bool = True
|
||||
save_texture: bool = True
|
||||
texture_size: int = 1024
|
||||
texture_format: str = "jpg"
|
||||
xatlas_chart_options: dict = field(default_factory=dict)
|
||||
xatlas_pack_options: dict = field(default_factory=dict)
|
||||
context_type: str = "gl"
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(
|
||||
self,
|
||||
geometry: BaseImplicitGeometry,
|
||||
material: BaseMaterial,
|
||||
background: BaseBackground,
|
||||
) -> None:
|
||||
super().configure(geometry, material, background)
|
||||
self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device)
|
||||
|
||||
def __call__(self) -> List[ExporterOutput]:
|
||||
mesh: Mesh = self.geometry.isosurface()
|
||||
|
||||
if self.cfg.fmt == "obj-mtl":
|
||||
return self.export_obj_with_mtl(mesh)
|
||||
elif self.cfg.fmt == "obj":
|
||||
return self.export_obj(mesh)
|
||||
else:
|
||||
raise ValueError(f"Unsupported mesh export format: {self.cfg.fmt}")
|
||||
|
||||
def export_obj_with_mtl(self, mesh: Mesh) -> List[ExporterOutput]:
|
||||
params = {
|
||||
"mesh": mesh,
|
||||
"save_mat": True,
|
||||
"save_normal": self.cfg.save_normal,
|
||||
"save_uv": self.cfg.save_uv,
|
||||
"save_vertex_color": False,
|
||||
"map_Kd": None, # Base Color
|
||||
"map_Ks": None, # Specular
|
||||
"map_Bump": None, # Normal
|
||||
# ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
|
||||
"map_Pm": None, # Metallic
|
||||
"map_Pr": None, # Roughness
|
||||
"map_format": self.cfg.texture_format,
|
||||
}
|
||||
|
||||
if self.cfg.save_uv:
|
||||
mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options)
|
||||
|
||||
if self.cfg.save_texture:
|
||||
threestudio.info("Exporting textures ...")
|
||||
assert self.cfg.save_uv, "save_uv must be True when save_texture is True"
|
||||
# clip space transform
|
||||
uv_clip = mesh.v_tex * 2.0 - 1.0
|
||||
# pad to four component coordinate
|
||||
uv_clip4 = torch.cat(
|
||||
(
|
||||
uv_clip,
|
||||
torch.zeros_like(uv_clip[..., 0:1]),
|
||||
torch.ones_like(uv_clip[..., 0:1]),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
# rasterize
|
||||
rast, _ = self.ctx.rasterize_one(
|
||||
uv_clip4, mesh.t_tex_idx, (self.cfg.texture_size, self.cfg.texture_size)
|
||||
)
|
||||
|
||||
hole_mask = ~(rast[:, :, 3] > 0)
|
||||
|
||||
def uv_padding(image):
|
||||
uv_padding_size = self.cfg.xatlas_pack_options.get("padding", 2)
|
||||
inpaint_image = (
|
||||
cv2.inpaint(
|
||||
(image.detach().cpu().numpy() * 255).astype(np.uint8),
|
||||
(hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
|
||||
uv_padding_size,
|
||||
cv2.INPAINT_TELEA,
|
||||
)
|
||||
/ 255.0
|
||||
)
|
||||
return torch.from_numpy(inpaint_image).to(image)
|
||||
|
||||
# Interpolate world space position
|
||||
gb_pos, _ = self.ctx.interpolate_one(
|
||||
mesh.v_pos, rast[None, ...], mesh.t_pos_idx
|
||||
)
|
||||
gb_pos = gb_pos[0]
|
||||
|
||||
# Sample out textures from MLP
|
||||
geo_out = self.geometry.export(points=gb_pos)
|
||||
mat_out = self.material.export(points=gb_pos, **geo_out)
|
||||
|
||||
threestudio.info(
|
||||
"Perform UV padding on texture maps to avoid seams, may take a while ..."
|
||||
)
|
||||
|
||||
if "albedo" in mat_out:
|
||||
params["map_Kd"] = uv_padding(mat_out["albedo"])
|
||||
else:
|
||||
threestudio.warn(
|
||||
"save_texture is True but no albedo texture found, using default white texture"
|
||||
)
|
||||
if "metallic" in mat_out:
|
||||
params["map_Pm"] = uv_padding(mat_out["metallic"])
|
||||
if "roughness" in mat_out:
|
||||
params["map_Pr"] = uv_padding(mat_out["roughness"])
|
||||
if "bump" in mat_out:
|
||||
params["map_Bump"] = uv_padding(mat_out["bump"])
|
||||
# TODO: map_Ks
|
||||
return [
|
||||
ExporterOutput(
|
||||
save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params
|
||||
)
|
||||
]
|
||||
|
||||
def export_obj(self, mesh: Mesh) -> List[ExporterOutput]:
|
||||
params = {
|
||||
"mesh": mesh,
|
||||
"save_mat": False,
|
||||
"save_normal": self.cfg.save_normal,
|
||||
"save_uv": self.cfg.save_uv,
|
||||
"save_vertex_color": False,
|
||||
"map_Kd": None, # Base Color
|
||||
"map_Ks": None, # Specular
|
||||
"map_Bump": None, # Normal
|
||||
# ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
|
||||
"map_Pm": None, # Metallic
|
||||
"map_Pr": None, # Roughness
|
||||
"map_format": self.cfg.texture_format,
|
||||
}
|
||||
|
||||
if self.cfg.save_uv:
|
||||
mesh.unwrap_uv(self.cfg.xatlas_chart_options, self.cfg.xatlas_pack_options)
|
||||
|
||||
if self.cfg.save_texture:
|
||||
threestudio.info("Exporting textures ...")
|
||||
geo_out = self.geometry.export(points=mesh.v_pos)
|
||||
mat_out = self.material.export(points=mesh.v_pos, **geo_out)
|
||||
|
||||
if "albedo" in mat_out:
|
||||
mesh.set_vertex_color(mat_out["albedo"])
|
||||
params["save_vertex_color"] = True
|
||||
else:
|
||||
threestudio.warn(
|
||||
"save_texture is True but no albedo texture found, not saving vertex color"
|
||||
)
|
||||
|
||||
return [
|
||||
ExporterOutput(
|
||||
save_name=f"{self.cfg.save_name}.obj", save_type="obj", params=params
|
||||
)
|
||||
]
|
8
threestudio/models/geometry/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from . import (
|
||||
base,
|
||||
custom_mesh,
|
||||
implicit_sdf,
|
||||
implicit_volume,
|
||||
tetrahedra_sdf_grid,
|
||||
volume_grid,
|
||||
)
|
209
threestudio/models/geometry/base.py
Normal file
@ -0,0 +1,209 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.isosurface import (
|
||||
IsosurfaceHelper,
|
||||
MarchingCubeCPUHelper,
|
||||
MarchingTetrahedraHelper,
|
||||
)
|
||||
from threestudio.models.mesh import Mesh
|
||||
from threestudio.utils.base import BaseModule
|
||||
from threestudio.utils.ops import chunk_batch, scale_tensor
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
def contract_to_unisphere(
|
||||
x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False
|
||||
) -> Float[Tensor, "... 3"]:
|
||||
if unbounded:
|
||||
x = scale_tensor(x, bbox, (0, 1))
|
||||
x = x * 2 - 1 # aabb is at [-1, 1]
|
||||
mag = x.norm(dim=-1, keepdim=True)
|
||||
mask = mag.squeeze(-1) > 1
|
||||
x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
|
||||
x = x / 4 + 0.5 # [-inf, inf] is at [0, 1]
|
||||
else:
|
||||
x = scale_tensor(x, bbox, (0, 1))
|
||||
return x
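# Illustrative sketch (standalone, not part of this commit): with
# unbounded=True the function above applies a mip-NeRF 360 style contraction:
# points with norm <= 1 are kept, points outside are squashed into a sphere of
# radius 2, and the result is rescaled into [0, 1] for the grid encoding.
import torch

x = torch.tensor([[100.0, 0.0, 0.0]])
mag = x.norm(dim=-1, keepdim=True)
print((2 - 1 / mag) * (x / mag))  # [[1.99, 0., 0.]], always inside radius 2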
|
||||
|
||||
|
||||
class BaseGeometry(BaseModule):
|
||||
@dataclass
|
||||
class Config(BaseModule.Config):
|
||||
pass
|
||||
|
||||
cfg: Config
|
||||
|
||||
@staticmethod
|
||||
def create_from(
|
||||
other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs
|
||||
) -> "BaseGeometry":
|
||||
raise TypeError(
|
||||
f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}"
|
||||
)
|
||||
|
||||
def export(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
class BaseImplicitGeometry(BaseGeometry):
|
||||
@dataclass
|
||||
class Config(BaseGeometry.Config):
|
||||
radius: float = 1.0
|
||||
isosurface: bool = True
|
||||
isosurface_method: str = "mt"
|
||||
isosurface_resolution: int = 128
|
||||
isosurface_threshold: Union[float, str] = 0.0
|
||||
isosurface_chunk: int = 0
|
||||
isosurface_coarse_to_fine: bool = True
|
||||
isosurface_deformable_grid: bool = False
|
||||
isosurface_remove_outliers: bool = True
|
||||
isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bbox: Float[Tensor, "2 3"]
|
||||
self.register_buffer(
|
||||
"bbox",
|
||||
torch.as_tensor(
|
||||
[
|
||||
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
|
||||
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
self.isosurface_helper: Optional[IsosurfaceHelper] = None
|
||||
self.unbounded: bool = False
|
||||
|
||||
def _initialize_isosurface_helper(self):
|
||||
if self.cfg.isosurface and self.isosurface_helper is None:
|
||||
if self.cfg.isosurface_method == "mc-cpu":
|
||||
self.isosurface_helper = MarchingCubeCPUHelper(
|
||||
self.cfg.isosurface_resolution
|
||||
).to(self.device)
|
||||
elif self.cfg.isosurface_method == "mt":
|
||||
self.isosurface_helper = MarchingTetrahedraHelper(
|
||||
self.cfg.isosurface_resolution,
|
||||
f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
|
||||
).to(self.device)
|
||||
else:
|
||||
raise AttributeError(
|
||||
"Unknown isosurface method {self.cfg.isosurface_method}"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
|
||||
) -> Dict[str, Float[Tensor, "..."]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_field(
|
||||
self, points: Float[Tensor, "*N Di"]
|
||||
) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
|
||||
# return the value of the implicit field, could be density / signed distance
|
||||
# also return a deformation field if the grid vertices can be optimized
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_level(
|
||||
self, field: Float[Tensor, "*N 1"], threshold: float
|
||||
) -> Float[Tensor, "*N 1"]:
|
||||
# return the value of the implicit field, where the zero level set represents the surface
|
||||
raise NotImplementedError
|
||||
|
||||
def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh:
|
||||
def batch_func(x):
|
||||
# scale to bbox as the input vertices are in [0, 1]
|
||||
field, deformation = self.forward_field(
|
||||
scale_tensor(
|
||||
x.to(bbox.device), self.isosurface_helper.points_range, bbox
|
||||
),
|
||||
)
|
||||
field = field.to(
|
||||
x.device
|
||||
) # move to the same device as the input (could be CPU)
|
||||
if deformation is not None:
|
||||
deformation = deformation.to(x.device)
|
||||
return field, deformation
|
||||
|
||||
assert self.isosurface_helper is not None
|
||||
|
||||
field, deformation = chunk_batch(
|
||||
batch_func,
|
||||
self.cfg.isosurface_chunk,
|
||||
self.isosurface_helper.grid_vertices,
|
||||
)
|
||||
|
||||
threshold: float
|
||||
|
||||
if isinstance(self.cfg.isosurface_threshold, float):
|
||||
threshold = self.cfg.isosurface_threshold
|
||||
elif self.cfg.isosurface_threshold == "auto":
|
||||
eps = 1.0e-5
|
||||
threshold = field[field > eps].mean().item()
|
||||
threestudio.info(
|
||||
f"Automatically determined isosurface threshold: {threshold}"
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}"
|
||||
)
|
||||
|
||||
level = self.forward_level(field, threshold)
|
||||
mesh: Mesh = self.isosurface_helper(level, deformation=deformation)
|
||||
mesh.v_pos = scale_tensor(
|
||||
mesh.v_pos, self.isosurface_helper.points_range, bbox
|
||||
) # scale to bbox as the grid vertices are in [0, 1]
|
||||
mesh.add_extra("bbox", bbox)
|
||||
|
||||
if self.cfg.isosurface_remove_outliers:
|
||||
# remove outlier components with a small number of faces
|
||||
# only enabled when the mesh is not differentiable
|
||||
mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
|
||||
|
||||
return mesh
|
||||
|
||||
def isosurface(self) -> Mesh:
|
||||
if not self.cfg.isosurface:
|
||||
raise NotImplementedError(
|
||||
"Isosurface is not enabled in the current configuration"
|
||||
)
|
||||
self._initialize_isosurface_helper()
|
||||
if self.cfg.isosurface_coarse_to_fine:
|
||||
threestudio.debug("First run isosurface to get a tight bounding box ...")
|
||||
with torch.no_grad():
|
||||
mesh_coarse = self._isosurface(self.bbox)
|
||||
vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0)
|
||||
vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0])
|
||||
vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1])
|
||||
threestudio.debug("Run isosurface again with the tight bounding box ...")
|
||||
mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True)
|
||||
else:
|
||||
mesh = self._isosurface(self.bbox)
|
||||
return mesh
|
||||
|
||||
|
||||
class BaseExplicitGeometry(BaseGeometry):
|
||||
@dataclass
|
||||
class Config(BaseGeometry.Config):
|
||||
radius: float = 1.0
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bbox: Float[Tensor, "2 3"]
|
||||
self.register_buffer(
|
||||
"bbox",
|
||||
torch.as_tensor(
|
||||
[
|
||||
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
|
||||
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
178
threestudio/models/geometry/custom_mesh.py
Normal file
@ -0,0 +1,178 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.geometry.base import (
|
||||
BaseExplicitGeometry,
|
||||
BaseGeometry,
|
||||
contract_to_unisphere,
|
||||
)
|
||||
from threestudio.models.mesh import Mesh
|
||||
from threestudio.models.networks import get_encoding, get_mlp
|
||||
from threestudio.utils.ops import scale_tensor
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@threestudio.register("custom-mesh")
|
||||
class CustomMesh(BaseExplicitGeometry):
|
||||
@dataclass
|
||||
class Config(BaseExplicitGeometry.Config):
|
||||
n_input_dims: int = 3
|
||||
n_feature_dims: int = 3
|
||||
pos_encoding_config: dict = field(
|
||||
default_factory=lambda: {
|
||||
"otype": "HashGrid",
|
||||
"n_levels": 16,
|
||||
"n_features_per_level": 2,
|
||||
"log2_hashmap_size": 19,
|
||||
"base_resolution": 16,
|
||||
"per_level_scale": 1.447269237440378,
|
||||
}
|
||||
)
|
||||
mlp_network_config: dict = field(
|
||||
default_factory=lambda: {
|
||||
"otype": "VanillaMLP",
|
||||
"activation": "ReLU",
|
||||
"output_activation": "none",
|
||||
"n_neurons": 64,
|
||||
"n_hidden_layers": 1,
|
||||
}
|
||||
)
|
||||
shape_init: str = ""
|
||||
shape_init_params: Optional[Any] = None
|
||||
shape_init_mesh_up: str = "+z"
|
||||
shape_init_mesh_front: str = "+x"
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self) -> None:
|
||||
super().configure()
|
||||
|
||||
self.encoding = get_encoding(
|
||||
self.cfg.n_input_dims, self.cfg.pos_encoding_config
|
||||
)
|
||||
self.feature_network = get_mlp(
|
||||
self.encoding.n_output_dims,
|
||||
self.cfg.n_feature_dims,
|
||||
self.cfg.mlp_network_config,
|
||||
)
|
||||
|
||||
# Initialize custom mesh
|
||||
if self.cfg.shape_init.startswith("mesh:"):
|
||||
assert isinstance(self.cfg.shape_init_params, float)
|
||||
mesh_path = self.cfg.shape_init[5:]
|
||||
if not os.path.exists(mesh_path):
|
||||
raise ValueError(f"Mesh file {mesh_path} does not exist.")
|
||||
|
||||
import trimesh
|
||||
|
||||
scene = trimesh.load(mesh_path)
|
||||
if isinstance(scene, trimesh.Trimesh):
|
||||
mesh = scene
|
||||
elif isinstance(scene, trimesh.scene.Scene):
|
||||
mesh = trimesh.Trimesh()
|
||||
for obj in scene.geometry.values():
|
||||
mesh = trimesh.util.concatenate([mesh, obj])
|
||||
else:
|
||||
raise ValueError(f"Unknown mesh type at {mesh_path}.")
|
||||
|
||||
# move to center
|
||||
centroid = mesh.vertices.mean(0)
|
||||
mesh.vertices = mesh.vertices - centroid
|
||||
|
||||
# align to up-z and front-x
|
||||
dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
|
||||
dir2vec = {
|
||||
"+x": np.array([1, 0, 0]),
|
||||
"+y": np.array([0, 1, 0]),
|
||||
"+z": np.array([0, 0, 1]),
|
||||
"-x": np.array([-1, 0, 0]),
|
||||
"-y": np.array([0, -1, 0]),
|
||||
"-z": np.array([0, 0, -1]),
|
||||
}
|
||||
if (
|
||||
self.cfg.shape_init_mesh_up not in dirs
|
||||
or self.cfg.shape_init_mesh_front not in dirs
|
||||
):
|
||||
raise ValueError(
|
||||
f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
|
||||
)
|
||||
if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
|
||||
raise ValueError(
|
||||
"shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
|
||||
)
|
||||
z_, x_ = (
|
||||
dir2vec[self.cfg.shape_init_mesh_up],
|
||||
dir2vec[self.cfg.shape_init_mesh_front],
|
||||
)
|
||||
y_ = np.cross(z_, x_)
|
||||
std2mesh = np.stack([x_, y_, z_], axis=0).T
|
||||
mesh2std = np.linalg.inv(std2mesh)
|
||||
|
||||
# scaling
|
||||
scale = np.abs(mesh.vertices).max()
|
||||
mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
|
||||
mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
|
||||
|
||||
v_pos = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
|
||||
t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
|
||||
self.mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
|
||||
self.register_buffer(
|
||||
"v_buffer",
|
||||
v_pos,
|
||||
)
|
||||
self.register_buffer(
|
||||
"t_buffer",
|
||||
t_pos_idx,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown shape initialization type: {self.cfg.shape_init}"
|
||||
)
|
||||
|
||||
|
||||
def isosurface(self) -> Mesh:
|
||||
if hasattr(self, "mesh"):
|
||||
return self.mesh
|
||||
elif hasattr(self, "v_buffer"):
|
||||
self.mesh = Mesh(v_pos=self.v_buffer, t_pos_idx=self.t_buffer)
|
||||
return self.mesh
|
||||
else:
|
||||
raise ValueError(f"custom mesh is not initialized")
|
||||
|
||||
def forward(
|
||||
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
|
||||
) -> Dict[str, Float[Tensor, "..."]]:
|
||||
assert (
|
||||
not output_normal
|
||||
), f"Normal output is not supported for {self.__class__.__name__}"
|
||||
points_unscaled = points # points in the original scale
|
||||
points = contract_to_unisphere(points, self.bbox) # points normalized to (0, 1)
|
||||
enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
|
||||
features = self.feature_network(enc).view(
|
||||
*points.shape[:-1], self.cfg.n_feature_dims
|
||||
)
|
||||
return {"features": features}
|
||||
|
||||
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
|
||||
out: Dict[str, Any] = {}
|
||||
if self.cfg.n_feature_dims == 0:
|
||||
return out
|
||||
points_unscaled = points
|
||||
points = contract_to_unisphere(points_unscaled, self.bbox)
|
||||
enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
|
||||
features = self.feature_network(enc).view(
|
||||
*points.shape[:-1], self.cfg.n_feature_dims
|
||||
)
|
||||
out.update(
|
||||
{
|
||||
"features": features,
|
||||
}
|
||||
)
|
||||
return out
|
413
threestudio/models/geometry/implicit_sdf.py
Normal file
@ -0,0 +1,413 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import threestudio
|
||||
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
|
||||
from threestudio.models.mesh import Mesh
|
||||
from threestudio.models.networks import get_encoding, get_mlp
|
||||
from threestudio.utils.misc import broadcast, get_rank
|
||||
from threestudio.utils.typing import *
|
||||
|
||||
|
||||
@threestudio.register("implicit-sdf")
|
||||
class ImplicitSDF(BaseImplicitGeometry):
|
||||
@dataclass
|
||||
class Config(BaseImplicitGeometry.Config):
|
||||
n_input_dims: int = 3
|
||||
n_feature_dims: int = 3
|
||||
pos_encoding_config: dict = field(
|
||||
default_factory=lambda: {
|
||||
"otype": "HashGrid",
|
||||
"n_levels": 16,
|
||||
"n_features_per_level": 2,
|
||||
"log2_hashmap_size": 19,
|
||||
"base_resolution": 16,
|
||||
"per_level_scale": 1.447269237440378,
|
||||
}
|
||||
)
|
||||
mlp_network_config: dict = field(
|
||||
default_factory=lambda: {
|
||||
"otype": "VanillaMLP",
|
||||
"activation": "ReLU",
|
||||
"output_activation": "none",
|
||||
"n_neurons": 64,
|
||||
"n_hidden_layers": 1,
|
||||
}
|
||||
)
|
||||
normal_type: Optional[
|
||||
str
|
||||
] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian']
|
||||
finite_difference_normal_eps: Union[
|
||||
float, str
|
||||
] = 0.01 # in [float, "progressive"]
|
||||
shape_init: Optional[str] = None
|
||||
shape_init_params: Optional[Any] = None
|
||||
shape_init_mesh_up: str = "+z"
|
||||
shape_init_mesh_front: str = "+x"
|
||||
force_shape_init: bool = False
|
||||
sdf_bias: Union[float, str] = 0.0
|
||||
sdf_bias_params: Optional[Any] = None
|
||||
|
||||
# no need to remove outliers for SDF
|
||||
isosurface_remove_outliers: bool = False
|
||||
|
||||
cfg: Config
|
||||
|
||||
def configure(self) -> None:
|
||||
super().configure()
|
||||
self.encoding = get_encoding(
|
||||
self.cfg.n_input_dims, self.cfg.pos_encoding_config
|
||||
)
|
||||
self.sdf_network = get_mlp(
|
||||
self.encoding.n_output_dims, 1, self.cfg.mlp_network_config
|
||||
)
|
||||
|
||||
if self.cfg.n_feature_dims > 0:
|
||||
self.feature_network = get_mlp(
|
||||
self.encoding.n_output_dims,
|
||||
self.cfg.n_feature_dims,
|
||||
self.cfg.mlp_network_config,
|
||||
)
|
||||
|
||||
if self.cfg.normal_type == "pred":
|
||||
self.normal_network = get_mlp(
|
||||
self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
|
||||
)
|
||||
if self.cfg.isosurface_deformable_grid:
|
||||
assert (
|
||||
self.cfg.isosurface_method == "mt"
|
||||
), "isosurface_deformable_grid only works with mt"
|
||||
self.deformation_network = get_mlp(
|
||||
self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
|
||||
)
|
||||
|
||||
self.finite_difference_normal_eps: Optional[float] = None
|
||||
|
||||
def initialize_shape(self) -> None:
|
||||
if self.cfg.shape_init is None and not self.cfg.force_shape_init:
|
||||
return
|
||||
|
||||
# do not initialize shape if weights are provided
|
||||
if self.cfg.weights is not None and not self.cfg.force_shape_init:
|
||||
return
|
||||
|
||||
if self.cfg.sdf_bias != 0.0:
|
||||
threestudio.warn(
|
||||
"shape_init and sdf_bias are both specified, which may lead to unexpected results."
|
||||
)
|
||||
|
||||
get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]]
|
||||
assert isinstance(self.cfg.shape_init, str)
|
||||
if self.cfg.shape_init == "ellipsoid":
|
||||
assert (
|
||||
isinstance(self.cfg.shape_init_params, Sized)
|
||||
and len(self.cfg.shape_init_params) == 3
|
||||
)
|
||||
size = torch.as_tensor(self.cfg.shape_init_params).to(self.device)
|
||||
|
||||
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
|
||||
return ((points_rand / size) ** 2).sum(
|
||||
dim=-1, keepdim=True
|
||||
).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
|
||||
|
||||
get_gt_sdf = func
|
||||
elif self.cfg.shape_init == "sphere":
|
||||
assert isinstance(self.cfg.shape_init_params, float)
|
||||
radius = self.cfg.shape_init_params
|
||||
|
||||
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
|
||||
return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius
|
||||
|
||||
get_gt_sdf = func
|
||||
elif self.cfg.shape_init.startswith("mesh:"):
|
||||
assert isinstance(self.cfg.shape_init_params, float)
|
||||
mesh_path = self.cfg.shape_init[5:]
|
||||
if not os.path.exists(mesh_path):
|
||||
raise ValueError(f"Mesh file {mesh_path} does not exist.")
|
||||
|
||||
import trimesh
|
||||
|
||||
scene = trimesh.load(mesh_path)
|
||||
if isinstance(scene, trimesh.Trimesh):
|
||||
mesh = scene
|
||||
elif isinstance(scene, trimesh.scene.Scene):
|
||||
mesh = trimesh.Trimesh()
|
||||
for obj in scene.geometry.values():
|
||||
mesh = trimesh.util.concatenate([mesh, obj])
|
||||
else:
|
||||
raise ValueError(f"Unknown mesh type at {mesh_path}.")
|
||||
|
||||
# move to center
|
||||
centroid = mesh.vertices.mean(0)
|
||||
mesh.vertices = mesh.vertices - centroid
|
||||
|
||||
# align to up-z and front-x
|
||||
dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
|
||||
dir2vec = {
|
||||
"+x": np.array([1, 0, 0]),
|
||||
"+y": np.array([0, 1, 0]),
|
||||
"+z": np.array([0, 0, 1]),
|
||||
"-x": np.array([-1, 0, 0]),
|
||||
"-y": np.array([0, -1, 0]),
|
||||
"-z": np.array([0, 0, -1]),
|
||||
}
|
||||
if (
|
||||
self.cfg.shape_init_mesh_up not in dirs
|
||||
or self.cfg.shape_init_mesh_front not in dirs
|
||||
):
|
||||
raise ValueError(
|
||||
f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
|
||||
)
|
||||
if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
|
||||
raise ValueError(
|
||||
"shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
|
||||
)
|
||||
z_, x_ = (
|
||||
dir2vec[self.cfg.shape_init_mesh_up],
|
||||
dir2vec[self.cfg.shape_init_mesh_front],
|
||||
)
|
||||
y_ = np.cross(z_, x_)
|
||||
std2mesh = np.stack([x_, y_, z_], axis=0).T
|
||||
mesh2std = np.linalg.inv(std2mesh)
|
||||
|
||||
# scaling
|
||||
scale = np.abs(mesh.vertices).max()
|
||||
mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
|
||||
mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
|
||||
|
||||
from pysdf import SDF
|
||||
|
||||
sdf = SDF(mesh.vertices, mesh.faces)
|
||||
|
||||
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
|
||||
# add a negative sign here
|
||||
# as in pysdf the inside of the shape has positive signed distance
|
||||
return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to(
|
||||
points_rand
|
||||
)[..., None]
|
||||
|
||||
get_gt_sdf = func
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown shape initialization type: {self.cfg.shape_init}"
|
||||
)
|
||||
|
||||
# Initialize SDF to a given shape when no weights are provided or force_shape_init is True
|
||||
optim = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
from tqdm import tqdm
|
||||
|
||||
for _ in tqdm(
|
||||
range(1000),
|
||||
desc=f"Initializing SDF to a(n) {self.cfg.shape_init}:",
|
||||
disable=get_rank() != 0,
|
||||
):
|
||||
points_rand = (
|
||||
torch.rand((10000, 3), dtype=torch.float32).to(self.device) * 2.0 - 1.0
|
||||
)
|
||||
sdf_gt = get_gt_sdf(points_rand)
|
||||
sdf_pred = self.forward_sdf(points_rand)
|
||||
loss = F.mse_loss(sdf_pred, sdf_gt)
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
# explicit broadcast to ensure param consistency across ranks
|
||||
for param in self.parameters():
|
||||
broadcast(param, src=0)
|
||||
|
||||
def get_shifted_sdf(
|
||||
self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"]
|
||||
) -> Float[Tensor, "*N 1"]:
|
||||
sdf_bias: Union[float, Float[Tensor, "*N 1"]]
|
||||
if self.cfg.sdf_bias == "ellipsoid":
|
||||
assert (
|
||||
isinstance(self.cfg.sdf_bias_params, Sized)
|
||||
and len(self.cfg.sdf_bias_params) == 3
|
||||
)
|
||||
size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
|
||||
sdf_bias = ((points / size) ** 2).sum(
|
||||
dim=-1, keepdim=True
|
||||
).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
|
||||
elif self.cfg.sdf_bias == "sphere":
|
||||
assert isinstance(self.cfg.sdf_bias_params, float)
|
||||
radius = self.cfg.sdf_bias_params
|
||||
sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
|
||||
elif isinstance(self.cfg.sdf_bias, float):
|
||||
sdf_bias = self.cfg.sdf_bias
|
||||
else:
|
||||
raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
|
||||
return sdf + sdf_bias
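# Illustrative sketch (standalone, not part of this commit): with
# sdf_bias="sphere" the raw network output is shifted by the analytic sphere
# SDF, so even an untrained network starts out representing an approximate
# sphere of the given radius.
import torch

points = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
radius = 0.5
sphere_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
print(sphere_bias)  # [[-0.5], [0.5]]: negative inside, positive outside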
|
||||
|
||||
def forward(
|
||||
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
|
||||
) -> Dict[str, Float[Tensor, "..."]]:
|
||||
grad_enabled = torch.is_grad_enabled()
|
||||
|
||||
if output_normal and self.cfg.normal_type == "analytic":
|
||||
torch.set_grad_enabled(True)
|
||||
points.requires_grad_(True)
|
||||
|
||||
points_unscaled = points # points in the original scale
|
||||
points = contract_to_unisphere(
|
||||
points, self.bbox, self.unbounded
|
||||
) # points normalized to (0, 1)
|
||||
|
||||
enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
|
||||
sdf = self.sdf_network(enc).view(*points.shape[:-1], 1)
|
||||
sdf = self.get_shifted_sdf(points_unscaled, sdf)
|
||||
output = {"sdf": sdf}
|
||||
|
||||
if self.cfg.n_feature_dims > 0:
|
||||
features = self.feature_network(enc).view(
|
||||
*points.shape[:-1], self.cfg.n_feature_dims
|
||||
)
|
||||
output.update({"features": features})
|
||||
|
||||
if output_normal:
|
||||
            if (
                self.cfg.normal_type == "finite_difference"
                or self.cfg.normal_type == "finite_difference_laplacian"
            ):
                assert self.finite_difference_normal_eps is not None
                eps: float = self.finite_difference_normal_eps
                if self.cfg.normal_type == "finite_difference_laplacian":
                    offsets: Float[Tensor, "6 3"] = torch.as_tensor(
                        [
                            [eps, 0.0, 0.0],
                            [-eps, 0.0, 0.0],
                            [0.0, eps, 0.0],
                            [0.0, -eps, 0.0],
                            [0.0, 0.0, eps],
                            [0.0, 0.0, -eps],
                        ]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 6 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    sdf_offset: Float[Tensor, "... 6 1"] = self.forward_sdf(
                        points_offset
                    )
                    sdf_grad = (
                        0.5
                        * (sdf_offset[..., 0::2, 0] - sdf_offset[..., 1::2, 0])
                        / eps
                    )
                else:
                    offsets: Float[Tensor, "3 3"] = torch.as_tensor(
                        [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 3 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    sdf_offset: Float[Tensor, "... 3 1"] = self.forward_sdf(
                        points_offset
                    )
                    sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps
                normal = F.normalize(sdf_grad, dim=-1)
            elif self.cfg.normal_type == "pred":
                normal = self.normal_network(enc).view(*points.shape[:-1], 3)
                normal = F.normalize(normal, dim=-1)
                sdf_grad = normal
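            # Analytic normals: differentiate the SDF w.r.t. the query points via
            # autograd; this relies on grad being enabled for the inputs (see above).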
            elif self.cfg.normal_type == "analytic":
                sdf_grad = -torch.autograd.grad(
                    sdf,
                    points_unscaled,
                    grad_outputs=torch.ones_like(sdf),
                    create_graph=True,
                )[0]
                normal = F.normalize(sdf_grad, dim=-1)
                if not grad_enabled:
                    sdf_grad = sdf_grad.detach()
                    normal = normal.detach()
            else:
                raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
            output.update(
                {"normal": normal, "shading_normal": normal, "sdf_grad": sdf_grad}
            )
        return output

    def forward_sdf(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)

        sdf = self.sdf_network(
            self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        ).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        return sdf

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        sdf = self.sdf_network(enc).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        deformation: Optional[Float[Tensor, "*N 3"]] = None
        if self.cfg.isosurface_deformable_grid:
            deformation = self.deformation_network(enc).reshape(*points.shape[:-1], 3)
        return sdf, deformation

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        return field - threshold

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features,
            }
        )
        return out

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        if (
            self.cfg.normal_type == "finite_difference"
            or self.cfg.normal_type == "finite_difference_laplacian"
        ):
            if isinstance(self.cfg.finite_difference_normal_eps, float):
                self.finite_difference_normal_eps = (
                    self.cfg.finite_difference_normal_eps
                )
            elif self.cfg.finite_difference_normal_eps == "progressive":
                # progressive finite difference eps from Neuralangelo
                # https://arxiv.org/abs/2306.03092
                hg_conf: Any = self.cfg.pos_encoding_config
                assert (
                    hg_conf.otype == "ProgressiveBandHashGrid"
                ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid"
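                # eps tracks the cell size of the finest active hash-grid level, so
                # the finite differences stay consistent with the coarse-to-fine
                # encoding schedule.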
                current_level = min(
                    hg_conf.start_level
                    + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels,
                )
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** (
                    current_level - 1
                )
                grid_size = 2 * self.cfg.radius / grid_res
                if grid_size != self.finite_difference_normal_eps:
                    threestudio.info(
                        f"Update finite_difference_normal_eps to {grid_size}"
                    )
                self.finite_difference_normal_eps = grid_size
            else:
                raise ValueError(
                    f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}"
                )
325
threestudio/models/geometry/implicit_volume.py
Normal file
@ -0,0 +1,325 @@
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import (
    BaseGeometry,
    BaseImplicitGeometry,
    contract_to_unisphere,
)
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *


@threestudio.register("implicit-volume")
class ImplicitVolume(BaseImplicitGeometry):
    @dataclass
    class Config(BaseImplicitGeometry.Config):
        n_input_dims: int = 3
        n_feature_dims: int = 3
        density_activation: Optional[str] = "softplus"
        density_bias: Union[float, str] = "blob_magic3d"
        density_blob_scale: float = 10.0
        density_blob_std: float = 0.5
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )
        normal_type: Optional[
            str
        ] = "finite_difference"  # in ['pred', 'finite_difference', 'finite_difference_laplacian']
        finite_difference_normal_eps: Union[
            float, str
        ] = 0.01  # in [float, "progressive"]

        # automatically determine the threshold
        isosurface_threshold: Union[float, str] = 25.0

        # 4D Gaussian Annealing
        anneal_density_blob_std_config: Optional[dict] = None

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.encoding = get_encoding(
            self.cfg.n_input_dims, self.cfg.pos_encoding_config
        )
        self.density_network = get_mlp(
            self.encoding.n_output_dims, 1, self.cfg.mlp_network_config
        )
        if self.cfg.n_feature_dims > 0:
            self.feature_network = get_mlp(
                self.encoding.n_output_dims,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )
        if self.cfg.normal_type == "pred":
            self.normal_network = get_mlp(
                self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
            )

        self.finite_difference_normal_eps: Optional[float] = None

    def get_activated_density(
        self, points: Float[Tensor, "*N Di"], density: Float[Tensor, "*N 1"]
    ) -> Tuple[Float[Tensor, "*N 1"], Float[Tensor, "*N 1"]]:
        density_bias: Union[float, Float[Tensor, "*N 1"]]
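        # Density blob bias: a pre-activation bump centered at the origin so the
        # field starts as a solid blob, with Gaussian falloff for 'blob_dreamfusion'
        # and linear falloff for 'blob_magic3d'.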
        if self.cfg.density_bias == "blob_dreamfusion":
            # pre-activation density bias
            density_bias = (
                self.cfg.density_blob_scale
                * torch.exp(
                    -0.5 * (points**2).sum(dim=-1) / self.cfg.density_blob_std**2
                )[..., None]
            )
        elif self.cfg.density_bias == "blob_magic3d":
            # pre-activation density bias
            density_bias = (
                self.cfg.density_blob_scale
                * (
                    1
                    - torch.sqrt((points**2).sum(dim=-1)) / self.cfg.density_blob_std
                )[..., None]
            )
        elif isinstance(self.cfg.density_bias, float):
            density_bias = self.cfg.density_bias
        else:
            raise ValueError(f"Unknown density bias {self.cfg.density_bias}")
        raw_density: Float[Tensor, "*N 1"] = density + density_bias
        density = get_activation(self.cfg.density_activation)(raw_density)
        return raw_density, density

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        grad_enabled = torch.is_grad_enabled()

        if output_normal and self.cfg.normal_type == "analytic":
            torch.set_grad_enabled(True)
            points.requires_grad_(True)

        points_unscaled = points  # points in the original scale
        points = contract_to_unisphere(
            points, self.bbox, self.unbounded
        )  # points normalized to (0, 1)

        enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
        density = self.density_network(enc).view(*points.shape[:-1], 1)
        raw_density, density = self.get_activated_density(points_unscaled, density)

        output = {
            "density": density,
        }

        if self.cfg.n_feature_dims > 0:
            features = self.feature_network(enc).view(
                *points.shape[:-1], self.cfg.n_feature_dims
            )
            output.update({"features": features})

        if output_normal:
            if (
                self.cfg.normal_type == "finite_difference"
                or self.cfg.normal_type == "finite_difference_laplacian"
            ):
                # TODO: use raw density
                assert self.finite_difference_normal_eps is not None
                eps: float = self.finite_difference_normal_eps
                if self.cfg.normal_type == "finite_difference_laplacian":
                    offsets: Float[Tensor, "6 3"] = torch.as_tensor(
                        [
                            [eps, 0.0, 0.0],
                            [-eps, 0.0, 0.0],
                            [0.0, eps, 0.0],
                            [0.0, -eps, 0.0],
                            [0.0, 0.0, eps],
                            [0.0, 0.0, -eps],
                        ]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 6 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
                        points_offset
                    )
                    normal = (
                        -0.5
                        * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
                        / eps
                    )
                else:
                    offsets: Float[Tensor, "3 3"] = torch.as_tensor(
                        [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 3 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
                        points_offset
                    )
                    normal = -(density_offset[..., 0::1, 0] - density) / eps
                normal = F.normalize(normal, dim=-1)
            elif self.cfg.normal_type == "pred":
                normal = self.normal_network(enc).view(*points.shape[:-1], 3)
                normal = F.normalize(normal, dim=-1)
            elif self.cfg.normal_type == "analytic":
                normal = -torch.autograd.grad(
                    density,
                    points_unscaled,
                    grad_outputs=torch.ones_like(density),
                    create_graph=True,
                )[0]
                normal = F.normalize(normal, dim=-1)
                if not grad_enabled:
                    normal = normal.detach()
            else:
                raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
            output.update({"normal": normal, "shading_normal": normal})

        torch.set_grad_enabled(grad_enabled)
        return output

    def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)

        density = self.density_network(
            self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        ).reshape(*points.shape[:-1], 1)

        _, density = self.get_activated_density(points_unscaled, density)
        return density

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        if self.cfg.isosurface_deformable_grid:
            threestudio.warn(
                f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
            )
        density = self.forward_density(points)
        return density, None

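    # Density grows toward the inside of the surface, so the level function is
    # negated to keep the SDF-like convention used by forward_level above:
    # negative inside, positive outside.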
    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        return -(field - threshold)

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features,
            }
        )
        return out

    @staticmethod
    @torch.no_grad()
    def create_from(
        other: BaseGeometry,
        cfg: Optional[Union[dict, DictConfig]] = None,
        copy_net: bool = True,
        **kwargs,
    ) -> "ImplicitVolume":
        if isinstance(other, ImplicitVolume):
            instance = ImplicitVolume(cfg, **kwargs)
            instance.encoding.load_state_dict(other.encoding.state_dict())
            instance.density_network.load_state_dict(other.density_network.state_dict())
            if copy_net:
                if (
                    instance.cfg.n_feature_dims > 0
                    and other.cfg.n_feature_dims == instance.cfg.n_feature_dims
                ):
                    instance.feature_network.load_state_dict(
                        other.feature_network.state_dict()
                    )
                if (
                    instance.cfg.normal_type == "pred"
                    and other.cfg.normal_type == "pred"
                ):
                    instance.normal_network.load_state_dict(
                        other.normal_network.state_dict()
                    )
            return instance
        else:
            raise TypeError(
                f"Cannot create {ImplicitVolume.__name__} from {other.__class__.__name__}"
            )

    # FIXME: use progressive normal eps
    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        if self.cfg.anneal_density_blob_std_config is not None:
            min_step = self.cfg.anneal_density_blob_std_config.min_anneal_step
            max_step = self.cfg.anneal_density_blob_std_config.max_anneal_step
            if global_step >= min_step and global_step <= max_step:
                end_val = self.cfg.anneal_density_blob_std_config.end_val
                start_val = self.cfg.anneal_density_blob_std_config.start_val
                self.density_blob_std = start_val + (global_step - min_step) * (
                    end_val - start_val
                ) / (max_step - min_step)

        if (
            self.cfg.normal_type == "finite_difference"
            or self.cfg.normal_type == "finite_difference_laplacian"
        ):
            if isinstance(self.cfg.finite_difference_normal_eps, float):
                self.finite_difference_normal_eps = (
                    self.cfg.finite_difference_normal_eps
                )
            elif self.cfg.finite_difference_normal_eps == "progressive":
                # progressive finite difference eps from Neuralangelo
                # https://arxiv.org/abs/2306.03092
                hg_conf: Any = self.cfg.pos_encoding_config
                assert (
                    hg_conf.otype == "ProgressiveBandHashGrid"
                ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid"
                current_level = min(
                    hg_conf.start_level
                    + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels,
                )
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** (
                    current_level - 1
                )
                grid_size = 2 * self.cfg.radius / grid_res
                if grid_size != self.finite_difference_normal_eps:
                    threestudio.info(
                        f"Update finite_difference_normal_eps to {grid_size}"
                    )
                self.finite_difference_normal_eps = grid_size
            else:
                raise ValueError(
                    f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}"
                )
369
threestudio/models/geometry/tetrahedra_sdf_grid.py
Normal file
@ -0,0 +1,369 @@
import os
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import (
    BaseExplicitGeometry,
    BaseGeometry,
    contract_to_unisphere,
)
from threestudio.models.geometry.implicit_sdf import ImplicitSDF
from threestudio.models.geometry.implicit_volume import ImplicitVolume
from threestudio.models.isosurface import MarchingTetrahedraHelper
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.misc import broadcast
from threestudio.utils.ops import scale_tensor
from threestudio.utils.typing import *


@threestudio.register("tetrahedra-sdf-grid")
class TetrahedraSDFGrid(BaseExplicitGeometry):
    @dataclass
    class Config(BaseExplicitGeometry.Config):
        isosurface_resolution: int = 128
        isosurface_deformable_grid: bool = True
        isosurface_remove_outliers: bool = False
        isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01

        n_input_dims: int = 3
        n_feature_dims: int = 3
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )
        shape_init: Optional[str] = None
        shape_init_params: Optional[Any] = None
        shape_init_mesh_up: str = "+z"
        shape_init_mesh_front: str = "+x"
        force_shape_init: bool = False
        geometry_only: bool = False
        fix_geometry: bool = False

    cfg: Config

    def configure(self) -> None:
        super().configure()

        # this should be saved to state_dict, register as buffer
        self.isosurface_bbox: Float[Tensor, "2 3"]
        self.register_buffer("isosurface_bbox", self.bbox.clone())

        self.isosurface_helper = MarchingTetrahedraHelper(
            self.cfg.isosurface_resolution,
            f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
        )

        self.sdf: Float[Tensor, "Nv 1"]
        self.deformation: Optional[Float[Tensor, "Nv 3"]]

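        # Per-vertex SDF values (and optional vertex deformations) are trainable
        # parameters unless fix_geometry is set, in which case they are frozen
        # buffers.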
        if not self.cfg.fix_geometry:
            self.register_parameter(
                "sdf",
                nn.Parameter(
                    torch.zeros(
                        (self.isosurface_helper.grid_vertices.shape[0], 1),
                        dtype=torch.float32,
                    )
                ),
            )
            if self.cfg.isosurface_deformable_grid:
                self.register_parameter(
                    "deformation",
                    nn.Parameter(
                        torch.zeros_like(self.isosurface_helper.grid_vertices)
                    ),
                )
            else:
                self.deformation = None
        else:
            self.register_buffer(
                "sdf",
                torch.zeros(
                    (self.isosurface_helper.grid_vertices.shape[0], 1),
                    dtype=torch.float32,
                ),
            )
            if self.cfg.isosurface_deformable_grid:
                self.register_buffer(
                    "deformation",
                    torch.zeros_like(self.isosurface_helper.grid_vertices),
                )
            else:
                self.deformation = None

        if not self.cfg.geometry_only:
            self.encoding = get_encoding(
                self.cfg.n_input_dims, self.cfg.pos_encoding_config
            )
            self.feature_network = get_mlp(
                self.encoding.n_output_dims,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )

        self.mesh: Optional[Mesh] = None

    def initialize_shape(self) -> None:
        if self.cfg.shape_init is None and not self.cfg.force_shape_init:
            return

        # do not initialize shape if weights are provided
        if self.cfg.weights is not None and not self.cfg.force_shape_init:
            return

        get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]]
        assert isinstance(self.cfg.shape_init, str)
        if self.cfg.shape_init == "ellipsoid":
            assert (
                isinstance(self.cfg.shape_init_params, Sized)
                and len(self.cfg.shape_init_params) == 3
            )
            size = torch.as_tensor(self.cfg.shape_init_params).to(self.device)

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                return ((points_rand / size) ** 2).sum(
                    dim=-1, keepdim=True
                ).sqrt() - 1.0  # pseudo signed distance of an ellipsoid

            get_gt_sdf = func
        elif self.cfg.shape_init == "sphere":
            assert isinstance(self.cfg.shape_init_params, float)
            radius = self.cfg.shape_init_params

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius

            get_gt_sdf = func
        elif self.cfg.shape_init.startswith("mesh:"):
            assert isinstance(self.cfg.shape_init_params, float)
            mesh_path = self.cfg.shape_init[5:]
            if not os.path.exists(mesh_path):
                raise ValueError(f"Mesh file {mesh_path} does not exist.")

            import trimesh

            mesh = trimesh.load(mesh_path)

            # move to center
            centroid = mesh.vertices.mean(0)
            mesh.vertices = mesh.vertices - centroid

            # align to up-z and front-x
            dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
            dir2vec = {
                "+x": np.array([1, 0, 0]),
                "+y": np.array([0, 1, 0]),
                "+z": np.array([0, 0, 1]),
                "-x": np.array([-1, 0, 0]),
                "-y": np.array([0, -1, 0]),
                "-z": np.array([0, 0, -1]),
            }
            if (
                self.cfg.shape_init_mesh_up not in dirs
                or self.cfg.shape_init_mesh_front not in dirs
            ):
                raise ValueError(
                    f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
                )
            if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
                raise ValueError(
                    "shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
                )
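            # Build the rotation mapping the mesh's up/front axes onto the canonical
            # +z/+x frame; y completes the right-handed basis via the cross product.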
            z_, x_ = (
                dir2vec[self.cfg.shape_init_mesh_up],
                dir2vec[self.cfg.shape_init_mesh_front],
            )
            y_ = np.cross(z_, x_)
            std2mesh = np.stack([x_, y_, z_], axis=0).T
            mesh2std = np.linalg.inv(std2mesh)

            # scaling
            scale = np.abs(mesh.vertices).max()
            mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
            mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T

            from pysdf import SDF

            sdf = SDF(mesh.vertices, mesh.faces)

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                # add a negative signed here
                # as in pysdf the inside of the shape has positive signed distance
                return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to(
                    points_rand
                )[..., None]

            get_gt_sdf = func

        else:
            raise ValueError(
                f"Unknown shape initialization type: {self.cfg.shape_init}"
            )

        sdf_gt = get_gt_sdf(
            scale_tensor(
                self.isosurface_helper.grid_vertices,
                self.isosurface_helper.points_range,
                self.isosurface_bbox,
            )
        )
        self.sdf.data = sdf_gt

        # explicit broadcast to ensure param consistency across ranks
        for param in self.parameters():
            broadcast(param, src=0)

    def isosurface(self) -> Mesh:
        # return cached mesh if fix_geometry is True to save computation
        if self.cfg.fix_geometry and self.mesh is not None:
            return self.mesh
        mesh = self.isosurface_helper(self.sdf, self.deformation)
        mesh.v_pos = scale_tensor(
            mesh.v_pos, self.isosurface_helper.points_range, self.isosurface_bbox
        )
        if self.cfg.isosurface_remove_outliers:
            mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
        self.mesh = mesh
        return mesh

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        if self.cfg.geometry_only:
            return {}
        assert (
            output_normal == False
        ), f"Normal output is not supported for {self.__class__.__name__}"
        points_unscaled = points  # points in the original scale
        points = contract_to_unisphere(points, self.bbox)  # points normalized to (0, 1)
        enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        return {"features": features}

    @staticmethod
    @torch.no_grad()
    def create_from(
        other: BaseGeometry,
        cfg: Optional[Union[dict, DictConfig]] = None,
        copy_net: bool = True,
        **kwargs,
    ) -> "TetrahedraSDFGrid":
        if isinstance(other, TetrahedraSDFGrid):
            instance = TetrahedraSDFGrid(cfg, **kwargs)
            assert instance.cfg.isosurface_resolution == other.cfg.isosurface_resolution
            instance.isosurface_bbox = other.isosurface_bbox.clone()
            instance.sdf.data = other.sdf.data.clone()
            if (
                instance.cfg.isosurface_deformable_grid
                and other.cfg.isosurface_deformable_grid
            ):
                assert (
                    instance.deformation is not None and other.deformation is not None
                )
                instance.deformation.data = other.deformation.data.clone()
            if (
                not instance.cfg.geometry_only
                and not other.cfg.geometry_only
                and copy_net
            ):
                instance.encoding.load_state_dict(other.encoding.state_dict())
                instance.feature_network.load_state_dict(
                    other.feature_network.state_dict()
                )
            return instance
        elif isinstance(other, ImplicitVolume):
            instance = TetrahedraSDFGrid(cfg, **kwargs)
            if other.cfg.isosurface_method != "mt":
                other.cfg.isosurface_method = "mt"
                threestudio.warn(
                    f"Override isosurface_method of the source geometry to 'mt'"
                )
            if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution:
                other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution
                threestudio.warn(
                    f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}"
                )
            mesh = other.isosurface()
            instance.isosurface_bbox = mesh.extras["bbox"]
            instance.sdf.data = (
                mesh.extras["grid_level"].to(instance.sdf.data).clamp(-1, 1)
            )
            if not instance.cfg.geometry_only and copy_net:
                instance.encoding.load_state_dict(other.encoding.state_dict())
                instance.feature_network.load_state_dict(
                    other.feature_network.state_dict()
                )
            return instance
        elif isinstance(other, ImplicitSDF):
            instance = TetrahedraSDFGrid(cfg, **kwargs)
            if other.cfg.isosurface_method != "mt":
                other.cfg.isosurface_method = "mt"
                threestudio.warn(
                    f"Override isosurface_method of the source geometry to 'mt'"
                )
            if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution:
                other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution
                threestudio.warn(
                    f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}"
                )
            mesh = other.isosurface()
            instance.isosurface_bbox = mesh.extras["bbox"]
            instance.sdf.data = mesh.extras["grid_level"].to(instance.sdf.data)
            if (
                instance.cfg.isosurface_deformable_grid
                and other.cfg.isosurface_deformable_grid
            ):
                assert instance.deformation is not None
                instance.deformation.data = mesh.extras["grid_deformation"].to(
                    instance.deformation.data
                )
            if not instance.cfg.geometry_only and copy_net:
                instance.encoding.load_state_dict(other.encoding.state_dict())
                instance.feature_network.load_state_dict(
                    other.feature_network.state_dict()
                )
            return instance
        else:
            raise TypeError(
                f"Cannot create {TetrahedraSDFGrid.__name__} from {other.__class__.__name__}"
            )

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.geometry_only or self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features,
            }
        )
        return out
190
threestudio/models/geometry/volume_grid.py
Normal file
@ -0,0 +1,190 @@
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *


@threestudio.register("volume-grid")
class VolumeGrid(BaseImplicitGeometry):
    @dataclass
    class Config(BaseImplicitGeometry.Config):
        grid_size: Tuple[int, int, int] = field(default_factory=lambda: (100, 100, 100))
        n_feature_dims: int = 3
        density_activation: Optional[str] = "softplus"
        density_bias: Union[float, str] = "blob"
        density_blob_scale: float = 5.0
        density_blob_std: float = 0.5
        normal_type: Optional[
            str
        ] = "finite_difference"  # in ['pred', 'finite_difference', 'finite_difference_laplacian']

        # automatically determine the threshold
        isosurface_threshold: Union[float, str] = "auto"

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.grid_size = self.cfg.grid_size

        self.grid = nn.Parameter(
            torch.zeros(1, self.cfg.n_feature_dims + 1, *self.grid_size)
        )
        if self.cfg.density_bias == "blob":
            self.register_buffer("density_scale", torch.tensor(0.0))
        else:
            self.density_scale = nn.Parameter(torch.tensor(0.0))

        if self.cfg.normal_type == "pred":
            self.normal_grid = nn.Parameter(torch.zeros(1, 3, *self.grid_size))

    def get_density_bias(self, points: Float[Tensor, "*N Di"]):
        if self.cfg.density_bias == "blob":
            # density_bias: Float[Tensor, "*N 1"] = self.cfg.density_blob_scale * torch.exp(-0.5 * (points ** 2).sum(dim=-1) / self.cfg.density_blob_std ** 2)[..., None]
            density_bias: Float[Tensor, "*N 1"] = (
                self.cfg.density_blob_scale
                * (
                    1
                    - torch.sqrt((points.detach() ** 2).sum(dim=-1))
                    / self.cfg.density_blob_std
                )[..., None]
            )
            return density_bias
        elif isinstance(self.cfg.density_bias, float):
            return self.cfg.density_bias
        else:
            raise AttributeError(f"Unknown density bias {self.cfg.density_bias}")

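    # Sample per-point features from the dense grid with F.grid_sample; for 5D
    # inputs, 'bilinear' mode actually performs trilinear interpolation.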
    def get_trilinear_feature(
        self, points: Float[Tensor, "*N Di"], grid: Float[Tensor, "1 Df G1 G2 G3"]
    ) -> Float[Tensor, "*N Df"]:
        points_shape = points.shape[:-1]
        df = grid.shape[1]
        di = points.shape[-1]
        out = F.grid_sample(
            grid, points.view(1, 1, 1, -1, di), align_corners=False, mode="bilinear"
        )
        out = out.reshape(df, -1).T.reshape(*points_shape, df)
        return out

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        points_unscaled = points  # points in the original scale
        points = contract_to_unisphere(
            points, self.bbox, self.unbounded
        )  # points normalized to (0, 1)
        points = points * 2 - 1  # convert to [-1, 1] for grid sample

        out = self.get_trilinear_feature(points, self.grid)
        density, features = out[..., 0:1], out[..., 1:]
        density = density * torch.exp(self.density_scale)  # exp scaling in DreamFusion

        # breakpoint()
        density = get_activation(self.cfg.density_activation)(
            density + self.get_density_bias(points_unscaled)
        )

        output = {
            "density": density,
            "features": features,
        }

        if output_normal:
            if (
                self.cfg.normal_type == "finite_difference"
                or self.cfg.normal_type == "finite_difference_laplacian"
            ):
                eps = 1.0e-3
                if self.cfg.normal_type == "finite_difference_laplacian":
                    offsets: Float[Tensor, "6 3"] = torch.as_tensor(
                        [
                            [eps, 0.0, 0.0],
                            [-eps, 0.0, 0.0],
                            [0.0, eps, 0.0],
                            [0.0, -eps, 0.0],
                            [0.0, 0.0, eps],
                            [0.0, 0.0, -eps],
                        ]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 6 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
                        points_offset
                    )
                    normal = (
                        -0.5
                        * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
                        / eps
                    )
                else:
                    offsets: Float[Tensor, "3 3"] = torch.as_tensor(
                        [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 3 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
                        points_offset
                    )
                    normal = -(density_offset[..., 0::1, 0] - density) / eps
                normal = F.normalize(normal, dim=-1)
            elif self.cfg.normal_type == "pred":
                normal = self.get_trilinear_feature(points, self.normal_grid)
                normal = F.normalize(normal, dim=-1)
            else:
                raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
            output.update({"normal": normal, "shading_normal": normal})
        return output

    def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        points = points * 2 - 1  # convert to [-1, 1] for grid sample

        out = self.get_trilinear_feature(points, self.grid)
        density = out[..., 0:1]
        density = density * torch.exp(self.density_scale)

        density = get_activation(self.cfg.density_activation)(
            density + self.get_density_bias(points_unscaled)
        )
        return density

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        if self.cfg.isosurface_deformable_grid:
            threestudio.warn(
                f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
            )
        density = self.forward_density(points)
        return density, None

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        return -(field - threshold)

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points, self.bbox, self.unbounded)
        points = points * 2 - 1  # convert to [-1, 1] for grid sample
        features = self.get_trilinear_feature(points, self.grid)[..., 1:]
        out.update(
            {
                "features": features,
            }
        )
        return out
13
threestudio/models/guidance/__init__.py
Normal file
@ -0,0 +1,13 @@
from . import (
    controlnet_guidance,
    controlnet_reg_guidance,
    deep_floyd_guidance,
    stable_diffusion_guidance,
    stable_diffusion_unified_guidance,
    stable_diffusion_vsd_guidance,
    stable_diffusion_bsd_guidance,
    stable_zero123_guidance,
    zero123_guidance,
    zero123_unified_guidance,
    clip_guidance,
)
84
threestudio/models/guidance/clip_guidance.py
Normal file
@ -0,0 +1,84 @@
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import clip

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.typing import *


@threestudio.register("clip-guidance")
class CLIPGuidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        cache_dir: Optional[str] = None
        pretrained_model_name_or_path: str = "ViT-B/16"
        view_dependent_prompting: bool = True

    cfg: Config

    def configure(self) -> None:
        threestudio.info("Loading CLIP ...")
        self.clip_model, self.clip_preprocess = clip.load(
            self.cfg.pretrained_model_name_or_path,
            device=self.device,
            jit=False,
            download_root=self.cfg.cache_dir,
        )

        self.aug = T.Compose(
            [
                T.Resize((224, 224)),
                T.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

        threestudio.info("Loaded CLIP!")

    @torch.cuda.amp.autocast(enabled=False)
    def get_embedding(self, input_value, is_text=True):
        if is_text:
            value = clip.tokenize(input_value).to(self.device)
            z = self.clip_model.encode_text(value)
        else:
            input_value = self.aug(input_value)
            z = self.clip_model.encode_image(input_value)

        return z / z.norm(dim=-1, keepdim=True)

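    # 'similarity_score' is the negative cosine similarity (embeddings are unit-norm
    # by construction above); 'spherical_dist' is proportional to the squared
    # great-circle (geodesic) distance on the unit hypersphere.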
    def get_loss(self, image_z, clip_z, loss_type="similarity_score", use_mean=True):
        if loss_type == "similarity_score":
            loss = -((image_z * clip_z).sum(-1))
        elif loss_type == "spherical_dist":
            image_z, clip_z = F.normalize(image_z, dim=-1), F.normalize(clip_z, dim=-1)
            loss = (image_z - clip_z).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
        else:
            raise NotImplementedError

        return loss.mean() if use_mean else loss

    def __call__(
        self,
        pred_rgb: Float[Tensor, "B H W C"],
        gt_rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        embedding_type: str = "both",
        loss_type: Optional[str] = "similarity_score",
        **kwargs,
    ):
        clip_text_loss, clip_img_loss = 0, 0

        if embedding_type in ("both", "text"):
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            ).chunk(2)[0]
            clip_text_loss = self.get_loss(
                self.get_embedding(pred_rgb, is_text=False),
                text_embeddings,
                loss_type=loss_type,
            )

        if embedding_type in ("both", "img"):
            clip_img_loss = self.get_loss(
                self.get_embedding(pred_rgb, is_text=False),
                self.get_embedding(gt_rgb, is_text=False),
                loss_type=loss_type,
            )

        return clip_text_loss + clip_img_loss
517
threestudio/models/guidance/controlnet_guidance.py
Normal file
@ -0,0 +1,517 @@
import os
from dataclasses import dataclass

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from controlnet_aux import CannyDetector, NormalBaeDetector
from diffusers import ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.perceptual import PerceptualLoss
from threestudio.utils.typing import *


@threestudio.register("stable-diffusion-controlnet-guidance")
class ControlNetGuidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        cache_dir: Optional[str] = None
        pretrained_model_name_or_path: str = "SG161222/Realistic_Vision_V2.0"
        ddim_scheduler_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        control_type: str = "normal"  # normal/canny

        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = False
        guidance_scale: float = 7.5
        condition_scale: float = 1.5
        grad_clip: Optional[Any] = None
        half_precision_weights: bool = True

        fixed_size: int = -1

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        diffusion_steps: int = 20

        use_sds: bool = False

        use_du: bool = False
        per_du_step: int = 10
        start_du_step: int = 1000
        cache_du: bool = False

        # Canny threshold
        canny_lower_bound: int = 50
        canny_upper_bound: int = 100

    cfg: Config

    def configure(self) -> None:
        threestudio.info("Loading ControlNet ...")

        controlnet_name_or_path: str
        if self.cfg.control_type in ("normal", "input_normal"):
            controlnet_name_or_path = "lllyasviel/control_v11p_sd15_normalbae"
        elif self.cfg.control_type == "canny":
            controlnet_name_or_path = "lllyasviel/control_v11p_sd15_canny"

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        pipe_kwargs = {
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
        }

        controlnet = ControlNetModel.from_pretrained(
            controlnet_name_or_path,
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path, controlnet=controlnet, **pipe_kwargs
        ).to(self.device)
        self.scheduler = DDIMScheduler.from_pretrained(
            self.cfg.ddim_scheduler_name_or_path,
            subfolder="scheduler",
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
        )
        self.scheduler.set_timesteps(self.cfg.diffusion_steps)

        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch2.0 uses memory efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory efficient attention is not enabled."
                )
            else:
                self.pipe.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)

        # Create model
        self.vae = self.pipe.vae.eval()
        self.unet = self.pipe.unet.eval()
        self.controlnet = self.pipe.controlnet.eval()

        if self.cfg.control_type == "normal":
            self.preprocessor = NormalBaeDetector.from_pretrained(
                "lllyasviel/Annotators"
            )
            self.preprocessor.model.to(self.device)
        elif self.cfg.control_type == "canny":
            self.preprocessor = CannyDetector()

        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.unet.parameters():
            p.requires_grad_(False)

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default value

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.grad_clip_val: Optional[float] = None

        if self.cfg.use_du:
            if self.cfg.cache_du:
                self.edit_frames = {}
            self.perceptual_loss = PerceptualLoss().eval().to(self.device)

        threestudio.info("Loaded ControlNet!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def forward_controlnet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        image_cond: Float[Tensor, "..."],
        condition_scale: float,
        encoder_hidden_states: Float[Tensor, "..."],
    ) -> Float[Tensor, "..."]:
        return self.controlnet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
            controlnet_cond=image_cond.to(self.weights_dtype),
            conditioning_scale=condition_scale,
            return_dict=False,
        )

    @torch.cuda.amp.autocast(enabled=False)
    def forward_control_unet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
        cross_attention_kwargs,
        down_block_additional_residuals,
        mid_block_additional_residual,
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        return self.unet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
        ).sample.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 H W"]
    ) -> Float[Tensor, "B 4 DH DW"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_cond_images(
        self, imgs: Float[Tensor, "B 3 H W"]
    ) -> Float[Tensor, "B 4 DH DW"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.mode()
        uncond_image_latents = torch.zeros_like(latents)
        latents = torch.cat([latents, latents, uncond_image_latents], dim=0)
        return latents.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self, latents: Float[Tensor, "B 4 DH DW"]
    ) -> Float[Tensor, "B 3 H W"]:
        input_dtype = latents.dtype
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.weights_dtype)).sample
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

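    # Img2img-style (SDEdit-like) update: noise the latents up to level t, then run
    # the remaining scheduler steps with ControlNet conditioning and classifier-free
    # guidance to produce an edited latent.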
    def edit_latents(
        self,
        text_embeddings: Float[Tensor, "BB 77 768"],
        latents: Float[Tensor, "B 4 DH DW"],
        image_cond: Float[Tensor, "B 3 H W"],
        t: Int[Tensor, "B"],
        mask=None,
    ) -> Float[Tensor, "B 4 DH DW"]:
        self.scheduler.config.num_train_timesteps = t.item()
        self.scheduler.set_timesteps(self.cfg.diffusion_steps)
        if mask is not None:
            mask = F.interpolate(
                mask, (latents.shape[-2], latents.shape[-1]), mode="bilinear"
            )
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents = self.scheduler.add_noise(latents, noise, t)  # type: ignore

            # sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
            threestudio.debug("Start editing...")
            for i, t in enumerate(self.scheduler.timesteps):
                # predict the noise residual with unet, NO grad!
                with torch.no_grad():
                    # pred noise
                    latent_model_input = torch.cat([latents] * 2)
                    (
                        down_block_res_samples,
                        mid_block_res_sample,
                    ) = self.forward_controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        image_cond=image_cond,
                        condition_scale=self.cfg.condition_scale,
                    )

                    noise_pred = self.forward_control_unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs=None,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    )
                # perform classifier-free guidance
                noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                if mask is not None:
                    noise_pred = mask * noise_pred + (1 - mask) * noise
                # get previous sample, continue loop
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
            threestudio.debug("Editing finished.")
        return latents

    def prepare_image_cond(self, cond_rgb: Float[Tensor, "B H W C"]):
        if self.cfg.control_type == "normal":
            cond_rgb = (
                (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
            )
            detected_map = self.preprocessor(cond_rgb)
            control = (
                torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
            )
            control = control.unsqueeze(0)
            control = control.permute(0, 3, 1, 2)
        elif self.cfg.control_type == "canny":
            cond_rgb = (
                (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
            )
            blurred_img = cv2.blur(cond_rgb, ksize=(5, 5))
            detected_map = self.preprocessor(
                blurred_img, self.cfg.canny_lower_bound, self.cfg.canny_upper_bound
            )
            control = (
                torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
            )
            # control = control.unsqueeze(-1).repeat(1, 1, 3)
            control = control.unsqueeze(0)
            control = control.permute(0, 3, 1, 2)
        elif self.cfg.control_type == "input_normal":
            cond_rgb[..., 0] = (
                1 - cond_rgb[..., 0]
            )  # Flip the sign on the x-axis to match bae system
            control = cond_rgb.permute(0, 3, 1, 2)
        else:
            raise ValueError(f"Unknown control type: {self.cfg.control_type}")

        return control

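    # Score Distillation Sampling: the noise residual is weighted by
    # w(t) = 1 - alpha_bar_t, and the gradient w * (eps_pred - eps) is later
    # injected through an MSE surrogate loss in __call__.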
def compute_grad_sds(
|
||||
self,
|
||||
text_embeddings: Float[Tensor, "BB 77 768"],
|
||||
latents: Float[Tensor, "B 4 DH DW"],
|
||||
image_cond: Float[Tensor, "B 3 H W"],
|
||||
t: Int[Tensor, "B"],
|
||||
):
|
||||
with torch.no_grad():
|
||||
# add noise
|
||||
noise = torch.randn_like(latents) # TODO: use torch generator
|
||||
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
||||
# pred noise
|
||||
latent_model_input = torch.cat([latents_noisy] * 2)
|
||||
down_block_res_samples, mid_block_res_sample = self.forward_controlnet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
image_cond=image_cond,
|
||||
condition_scale=self.cfg.condition_scale,
|
||||
)
|
||||
|
||||
noise_pred = self.forward_control_unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
cross_attention_kwargs=None,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
)
|
||||
|
||||
# perform classifier-free guidance
|
||||
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
|
||||
grad = w * (noise_pred - noise)
|
||||
return grad
|
||||
|
||||
def compute_grad_du(
|
||||
self,
|
||||
latents: Float[Tensor, "B 4 H W"],
|
||||
rgb_BCHW_HW8: Float[Tensor, "B 3 RH RW"],
|
||||
cond_feature: Float[Tensor, "B 3 RH RW"],
|
||||
cond_rgb: Float[Tensor, "B H W 3"],
|
||||
text_embeddings: Float[Tensor, "BB 77 768"],
|
||||
mask = None,
|
||||
**kwargs,
|
||||
):
|
||||
batch_size, _, RH, RW = cond_feature.shape
|
||||
assert batch_size == 1
|
||||
|
||||
origin_gt_rgb = F.interpolate(
|
||||
cond_rgb.permute(0, 3, 1, 2), (RH, RW), mode="bilinear"
|
||||
).permute(0, 2, 3, 1)
|
||||
need_diffusion = (
|
||||
self.global_step % self.cfg.per_du_step == 0
|
||||
and self.global_step > self.cfg.start_du_step
|
||||
)
|
||||
if self.cfg.cache_du:
|
||||
if torch.is_tensor(kwargs["index"]):
|
||||
batch_index = kwargs["index"].item()
|
||||
else:
|
||||
batch_index = kwargs["index"]
|
||||
if (
|
||||
not (batch_index in self.edit_frames)
|
||||
) and self.global_step > self.cfg.start_du_step:
|
||||
need_diffusion = True
|
||||
need_loss = self.cfg.cache_du or need_diffusion
|
||||
guidance_out = {}
|
||||
|
||||
if need_diffusion:
|
||||
t = torch.randint(
|
||||
self.min_step,
|
||||
self.max_step,
|
||||
[1],
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
print("t:", t)
|
||||
edit_latents = self.edit_latents(text_embeddings, latents, cond_feature, t, mask)
|
||||
edit_images = self.decode_latents(edit_latents)
|
||||
edit_images = F.interpolate(
|
||||
edit_images, (RH, RW), mode="bilinear"
|
||||
).permute(0, 2, 3, 1)
|
||||
self.edit_images = edit_images
|
||||
if self.cfg.cache_du:
|
||||
self.edit_frames[batch_index] = edit_images.detach().cpu()
|
||||
|
||||
if need_loss:
|
||||
if self.cfg.cache_du:
|
||||
if batch_index in self.edit_frames:
|
||||
gt_rgb = self.edit_frames[batch_index].to(cond_feature.device)
|
||||
else:
|
||||
gt_rgb = origin_gt_rgb
|
||||
else:
|
||||
gt_rgb = edit_images
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8)
|
||||
cv2.imwrite(".threestudio_cache/test.jpg", temp[:, :, ::-1])
            guidance_out.update(
                {
                    "loss_l1": torch.nn.functional.l1_loss(
                        rgb_BCHW_HW8, gt_rgb.permute(0, 3, 1, 2), reduction="sum"
                    ),
                    "loss_p": self.perceptual_loss(
                        rgb_BCHW_HW8.contiguous(),
                        gt_rgb.permute(0, 3, 1, 2).contiguous(),
                    ).sum(),
                }
            )

        return guidance_out

    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        cond_rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        mask=None,
        **kwargs,
    ):
        batch_size, H, W, _ = rgb.shape
        assert batch_size == 1
        assert rgb.shape[:-1] == cond_rgb.shape[:-1]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        if mask is not None:
            mask = mask.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 DH DW"]
        if self.cfg.fixed_size > 0:
            RH, RW = self.cfg.fixed_size, self.cfg.fixed_size
        else:
            RH, RW = H // 8 * 8, W // 8 * 8
        rgb_BCHW_HW8 = F.interpolate(
            rgb_BCHW, (RH, RW), mode="bilinear", align_corners=False
        )
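        # encode with the VAE; RH/RW are multiples of 8, so the 8x-downsampled
        # latent grid has integer size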
        latents = self.encode_images(rgb_BCHW_HW8)

        image_cond = self.prepare_image_cond(cond_rgb)
        image_cond = F.interpolate(
            image_cond, (RH, RW), mode="bilinear", align_corners=False
        )

        temp = torch.zeros(1).to(rgb.device)
        azimuth = kwargs.get("azimuth", temp)
        camera_distance = kwargs.get("camera_distance", temp)
        view_dependent_prompt = kwargs.get("view_dependent_prompt", False)
        # FIXME: change to view-conditioned prompt
        text_embeddings = prompt_utils.get_text_embeddings(
            temp, azimuth, camera_distance, view_dependent_prompt
        )

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise levels
        t = torch.randint(
            self.min_step,
            self.max_step + 1,
            [batch_size],
            dtype=torch.long,
            device=self.device,
        )

        guidance_out = {}
        if self.cfg.use_sds:
            grad = self.compute_grad_sds(text_embeddings, latents, image_cond, t)
            grad = torch.nan_to_num(grad)
            if self.grad_clip_val is not None:
                grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
            target = (latents - grad).detach()
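            # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad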
            loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
            guidance_out.update(
                {
                    "loss_sds": loss_sds,
                    "grad_norm": grad.norm(),
                    "min_step": self.min_step,
                    "max_step": self.max_step,
                }
            )

        if self.cfg.use_du:
            grad = self.compute_grad_du(
                latents, rgb_BCHW_HW8, image_cond, cond_rgb, text_embeddings, mask, **kwargs
            )
            guidance_out.update(grad)

        return guidance_out

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )

        self.global_step = global_step
454 threestudio/models/guidance/controlnet_reg_guidance.py Normal file
@@ -0,0 +1,454 @@
import os
from dataclasses import dataclass

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from controlnet_aux import CannyDetector, NormalBaeDetector
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    StableDiffusionControlNetPipeline,
)
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.typing import *


@threestudio.register("stable-diffusion-controlnet-reg-guidance")
class ControlNetGuidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        cache_dir: Optional[str] = None
        local_files_only: Optional[bool] = False
        pretrained_model_name_or_path: str = "SG161222/Realistic_Vision_V2.0"
        ddim_scheduler_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        control_type: str = "normal"  # normal/canny

        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = False
        guidance_scale: float = 7.5
        condition_scale: float = 1.5
        grad_clip: Optional[Any] = None
        half_precision_weights: bool = True

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        diffusion_steps: int = 20

        use_sds: bool = False

        # Canny thresholds
        canny_lower_bound: int = 50
        canny_upper_bound: int = 100

    cfg: Config

    def configure(self) -> None:
        threestudio.info("Loading ControlNet ...")

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        self.preprocessor, controlnet_name_or_path = self.get_preprocessor_and_controlnet()

        pipe_kwargs = self.configure_pipeline()

        self.load_models(pipe_kwargs, controlnet_name_or_path)

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.scheduler.set_timesteps(self.cfg.diffusion_steps)
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        self.scheduler = self.pipe.scheduler

        self.check_memory_efficiency_conditions()

        self.set_min_max_steps()
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)
        self.grad_clip_val = None

        threestudio.info("Loaded ControlNet!")

    def get_preprocessor_and_controlnet(self):
        if self.cfg.control_type in ("normal", "input_normal"):
            if self.cfg.pretrained_model_name_or_path == "SG161222/Realistic_Vision_V2.0":
                controlnet_name_or_path = "lllyasviel/control_v11p_sd15_normalbae"
            else:
                controlnet_name_or_path = "thibaud/controlnet-sd21-normalbae-diffusers"
            preprocessor = NormalBaeDetector.from_pretrained(
                "lllyasviel/Annotators", cache_dir=self.cfg.cache_dir
            )
            preprocessor.model.to(self.device)
        elif self.cfg.control_type in ("canny", "canny2"):
            controlnet_name_or_path = self.get_canny_controlnet()
            preprocessor = CannyDetector()
        else:
            raise ValueError(f"Unknown control type: {self.cfg.control_type}")
        return preprocessor, controlnet_name_or_path

    def get_canny_controlnet(self):
        if self.cfg.control_type == "canny":
            return "lllyasviel/control_v11p_sd15_canny"
        elif self.cfg.control_type == "canny2":
            return "thepowefuldeez/sd21-controlnet-canny"

    def configure_pipeline(self):
        return {
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
            "local_files_only": self.cfg.local_files_only,
        }

    def load_models(self, pipe_kwargs, controlnet_name_or_path):
        controlnet = ControlNetModel.from_pretrained(
            controlnet_name_or_path,
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path, controlnet=controlnet, **pipe_kwargs
        ).to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(
            self.cfg.ddim_scheduler_name_or_path,
            subfolder="scheduler",
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        )

        self.vae = self.pipe.vae.eval()
        self.unet = self.pipe.unet.eval()
        self.controlnet = self.pipe.controlnet.eval()

    def check_memory_efficiency_conditions(self):
        if self.cfg.enable_memory_efficient_attention:
            self.memory_efficiency_status()
        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()
        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)
        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)

    def memory_efficiency_status(self):
        if parse_version(torch.__version__) >= parse_version("2"):
            threestudio.info("PyTorch 2.0 uses memory-efficient attention by default.")
        elif not is_xformers_available():
            threestudio.warn(
                "xformers is not available, memory-efficient attention is not enabled."
            )
        else:
            self.pipe.enable_xformers_memory_efficient_attention()

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def forward_controlnet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        image_cond: Float[Tensor, "..."],
        condition_scale: float,
        encoder_hidden_states: Float[Tensor, "..."],
    ) -> Float[Tensor, "..."]:
        return self.controlnet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
            controlnet_cond=image_cond.to(self.weights_dtype),
            conditioning_scale=condition_scale,
            return_dict=False,
        )

    @torch.cuda.amp.autocast(enabled=False)
    def forward_control_unet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
        cross_attention_kwargs,
        down_block_additional_residuals,
        mid_block_additional_residual,
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        return self.unet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
        ).sample.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 512 512"]
    ) -> Float[Tensor, "B 4 64 64"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
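        # map [0, 1] RGB into the [-1, 1] range expected by the VAE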
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_cond_images(
        self, imgs: Float[Tensor, "B 3 512 512"]
    ) -> Float[Tensor, "B 4 64 64"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.mode()
        uncond_image_latents = torch.zeros_like(latents)
        latents = torch.cat([latents, latents, uncond_image_latents], dim=0)
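        # [cond, cond, zeros] mirrors the InstructPix2Pix-style triple
        # (text+image, image-only, unconditional) this file borrows from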
        return latents.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
        latent_height: int = 64,
        latent_width: int = 64,
    ) -> Float[Tensor, "B 3 512 512"]:
        input_dtype = latents.dtype
        latents = F.interpolate(
            latents, (latent_height, latent_width), mode="bilinear", align_corners=False
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.weights_dtype)).sample
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

    def edit_latents(
        self,
        text_embeddings: Float[Tensor, "BB 77 768"],
        latents: Float[Tensor, "B 4 64 64"],
        image_cond: Float[Tensor, "B 3 512 512"],
        t: Int[Tensor, "B"],
        mask=None,
    ) -> Float[Tensor, "B 4 64 64"]:
        batch_size = t.shape[0]
        self.scheduler.set_timesteps(num_inference_steps=self.cfg.diffusion_steps)
        init_timestep = max(
            1,
            min(
                int(self.cfg.diffusion_steps * t[0].item() / self.num_train_timesteps),
                self.cfg.diffusion_steps,
            ),
        )
        t_start = max(self.cfg.diffusion_steps - init_timestep, 0)
        latent_timestep = self.scheduler.timesteps[t_start : t_start + 1].repeat(
            batch_size
        )
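        # SDEdit-style partial denoising: noise the latents to roughly level t, then run
        # only the remaining diffusion_steps - t_start solver steps back to a clean sample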
        B, _, DH, DW = latents.shape
        origin_latents = latents.clone()
        if mask is not None:
            mask = F.interpolate(mask, (DH, DW), mode="bilinear", antialias=True)

        with torch.no_grad():
            # sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
            noise = torch.randn_like(latents)
            latents = self.scheduler.add_noise(latents, noise, latent_timestep)  # type: ignore
            threestudio.debug("Start editing...")
            for i, step in enumerate(range(t_start, self.cfg.diffusion_steps)):
                timestep = self.scheduler.timesteps[step]
                # predict the noise residual with unet, NO grad!
                with torch.no_grad():
                    # pred noise
                    latent_model_input = torch.cat([latents] * 2)
                    (
                        down_block_res_samples,
                        mid_block_res_sample,
                    ) = self.forward_controlnet(
                        latent_model_input,
                        timestep,
                        encoder_hidden_states=text_embeddings,
                        image_cond=image_cond,
                        condition_scale=self.cfg.condition_scale,
                    )

                    noise_pred = self.forward_control_unet(
                        latent_model_input,
                        timestep,
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs=None,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    )
                    # perform classifier-free guidance
                    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                    if mask is not None:
                        noise_pred = noise_pred * mask + (1 - mask) * noise
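                        # inpainting-style blend: outside the mask the predicted noise is
                        # replaced by the true noise, so those regions denoise back toward
                        # the original content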

                    latents = self.scheduler.step(
                        noise_pred, timestep, latents
                    ).prev_sample
            threestudio.debug("Editing finished.")
        return latents

    def prepare_image_cond(self, cond_rgb: Float[Tensor, "B H W C"]):
        if self.cfg.control_type == "normal":
            cond_rgb = (
                (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
            )
            detected_map = self.preprocessor(cond_rgb)
            control = (
                torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
            )
            control = control.unsqueeze(0)
            control = control.permute(0, 3, 1, 2)
        elif self.cfg.control_type in ("canny", "canny2"):
            cond_rgb = (
                (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
            )
            blurred_img = cv2.blur(cond_rgb, ksize=(5, 5))
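            # a light box blur suppresses pixel-level speckle before edge detection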
            detected_map = self.preprocessor(
                blurred_img, self.cfg.canny_lower_bound, self.cfg.canny_upper_bound
            )
            control = (
                torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
            )
            control = control.unsqueeze(-1).repeat(1, 1, 3)
            control = control.unsqueeze(0)
            control = control.permute(0, 3, 1, 2)
        elif self.cfg.control_type == "input_normal":
            cond_rgb[..., 0] = (
                1 - cond_rgb[..., 0]
            )  # flip the sign on the x-axis to match the BAE system
            control = cond_rgb.permute(0, 3, 1, 2)
        else:
            raise ValueError(f"Unknown control type: {self.cfg.control_type}")

        return F.interpolate(control, (512, 512), mode="bilinear", align_corners=False)

    def compute_grad_sds(
        self,
        text_embeddings: Float[Tensor, "BB 77 768"],
        latents: Float[Tensor, "B 4 64 64"],
        image_cond: Float[Tensor, "B 3 512 512"],
        t: Int[Tensor, "B"],
    ):
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)  # TODO: use torch generator
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            down_block_res_samples, mid_block_res_sample = self.forward_controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                image_cond=image_cond,
                condition_scale=self.cfg.condition_scale,
            )

            noise_pred = self.forward_control_unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            )

        # perform classifier-free guidance
        noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
        grad = w * (noise_pred - noise)
        return grad

    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        cond_rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        mask: Optional[Float[Tensor, "B H W C"]] = None,
        **kwargs,
    ):
        batch_size, H, W, _ = rgb.shape

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 64 64"]
        rgb_BCHW_512 = F.interpolate(
            rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
        )
        latents = self.encode_images(rgb_BCHW_512)

        image_cond = self.prepare_image_cond(cond_rgb)

        temp = torch.zeros(1).to(rgb.device)
        text_embeddings = prompt_utils.get_text_embeddings(temp, temp, temp, False)

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise levels
        t = torch.randint(
            self.min_step,
            self.max_step + 1,
            [batch_size],
            dtype=torch.long,
            device=self.device,
        )

        if self.cfg.use_sds:
            grad = self.compute_grad_sds(text_embeddings, latents, image_cond, t)
            grad = torch.nan_to_num(grad)
            if self.grad_clip_val is not None:
                grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
            target = (latents - grad).detach()
            loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
            return {
                "loss_sds": loss_sds,
                "grad_norm": grad.norm(),
                "min_step": self.min_step,
                "max_step": self.max_step,
            }
        else:
            if mask is not None:
                mask = mask.permute(0, 3, 1, 2)
            edit_latents = self.edit_latents(text_embeddings, latents, image_cond, t, mask)
            edit_images = self.decode_latents(edit_latents)
            edit_images = F.interpolate(edit_images, (H, W), mode="bilinear")

            return {
                "edit_images": edit_images.permute(0, 2, 3, 1),
                "edit_latents": edit_latents,
            }

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )


if __name__ == "__main__":
    from threestudio.utils.config import ExperimentConfig, load_config
    from threestudio.utils.typing import Optional

    cfg = load_config("configs/experimental/controlnet-normal.yaml")
    guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)
    prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(
        cfg.system.prompt_processor
    )

    rgb_image = cv2.imread("assets/face.jpg")[:, :, ::-1].copy() / 255
    rgb_image = cv2.resize(rgb_image, (512, 512))
    rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device)
    prompt_utils = prompt_processor()
    guidance_out = guidance(rgb_image, rgb_image, prompt_utils)
    edit_image = (
        (guidance_out["edit_images"][0].detach().cpu().clip(0, 1).numpy() * 255)
        .astype(np.uint8)[:, :, ::-1]
        .copy()
    )
    os.makedirs(".threestudio_cache", exist_ok=True)
    cv2.imwrite(".threestudio_cache/edit_image.jpg", edit_image)
582 threestudio/models/guidance/deep_floyd_guidance.py Normal file
@@ -0,0 +1,582 @@
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMScheduler, IFPipeline
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.ops import perpendicular_component
from threestudio.utils.typing import *


@threestudio.register("deep-floyd-guidance")
class DeepFloydGuidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        cache_dir: Optional[str] = None
        local_files_only: Optional[bool] = False
        pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0"
        # FIXME: xformers error
        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = True
        guidance_scale: float = 20.0
        grad_clip: Optional[
            Any
        ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        time_prior: Optional[Any] = None  # [w1, w2, s1, s2]
        half_precision_weights: bool = True

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        weighting_strategy: str = "sds"

        view_dependent_prompting: bool = True

        """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
        max_items_eval: int = 4

        lora_weights_path: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
        threestudio.info("Loading Deep Floyd ...")

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        # Create model
        self.pipe = IFPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            text_encoder=None,
            safety_checker=None,
            watermarker=None,
            feature_extractor=None,
            requires_safety_checker=False,
            variant="fp16" if self.cfg.half_precision_weights else None,
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        ).to(self.device)

        # Load LoRA weights
        if self.cfg.lora_weights_path is not None:
            self.pipe.load_lora_weights(self.cfg.lora_weights_path)
            self.pipe.scheduler = self.pipe.scheduler.__class__.from_config(
                self.pipe.scheduler.config, variance_type="fixed_small"
            )

        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch 2.0 uses memory-efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory-efficient attention is not enabled."
                )
            else:
                threestudio.warn(
                    "Using DeepFloyd with xformers may raise an error; see https://github.com/deep-floyd/IF/issues/52 to track this problem."
                )
                self.pipe.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)

        self.unet = self.pipe.unet.eval()

        for p in self.unet.parameters():
            p.requires_grad_(False)

        self.scheduler = self.pipe.scheduler

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default values
        if self.cfg.time_prior is not None:
            m1, m2, s1, s2 = self.cfg.time_prior
            weights = torch.cat(
                (
                    torch.exp(
                        -((torch.arange(self.num_train_timesteps, m1, -1) - m1) ** 2)
                        / (2 * s1**2)
                    ),
                    torch.ones(m1 - m2 + 1),
                    torch.exp(
                        -((torch.arange(m2 - 1, 0, -1) - m2) ** 2) / (2 * s2**2)
                    ),
                )
            )
            weights = weights / torch.sum(weights)
            self.time_prior_acc_weights = torch.cumsum(weights, dim=0)
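            # the weights form a plateau over [m2, m1] with Gaussian tails (widths s1, s2);
            # the cumulative sum is inverted in __call__ to map training progress to a
            # roughly annealed diffusion timestep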

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        self.grad_clip_val: Optional[float] = None

        threestudio.info("Loaded Deep Floyd!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def forward_unet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        return self.unet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
        ).sample.to(input_dtype)

    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        current_step_ratio=None,
        mask: Optional[Float[Tensor, "B H W 1"]] = None,
        rgb_as_latents=False,
        guidance_eval=False,
        **kwargs,
    ):
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)

        if mask is not None:
            mask = mask.permute(0, 3, 1, 2)
            mask = F.interpolate(mask, (64, 64), mode="bilinear", align_corners=False)

        assert not rgb_as_latents, f"No latent space in {self.__class__.__name__}"
        rgb_BCHW = rgb_BCHW * 2.0 - 1.0  # scale to [-1, 1] to match the diffusion range
        latents = F.interpolate(
            rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
        )
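        # DeepFloyd IF is a pixel-space diffusion model, so the "latents" here are just
        # the RGB render resized to the model's 64x64 base resolution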

        if self.cfg.time_prior is not None:
            time_index = torch.where(
                (self.time_prior_acc_weights - current_step_ratio) > 0
            )[0][0]
            if time_index == 0 or torch.abs(
                self.time_prior_acc_weights[time_index] - current_step_ratio
            ) < torch.abs(
                self.time_prior_acc_weights[time_index - 1] - current_step_ratio
            ):
                t = self.num_train_timesteps - time_index
            else:
                t = self.num_train_timesteps - time_index + 1
            t = torch.clip(t, self.min_step, self.max_step + 1)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)

        else:
            # timestep ~ U(0.02, 0.98) to avoid very high/low noise levels
            t = torch.randint(
                self.min_step,
                self.max_step + 1,
                [batch_size],
                dtype=torch.long,
                device=self.device,
            )

        if prompt_utils.use_perp_neg:
            (
                text_embeddings,
                neg_guidance_weights,
            ) = prompt_utils.get_text_embeddings_perp_neg(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            with torch.no_grad():
                noise = torch.randn_like(latents)
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                if mask is not None:
                    latents_noisy = (1 - mask) * latents + mask * latents_noisy
                latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 4),
                    encoder_hidden_states=text_embeddings,
                )  # (4B, 6, 64, 64)

            noise_pred_text, predicted_variance = noise_pred[:batch_size].split(3, dim=1)
            noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split(
                3, dim=1
            )
            noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1)

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
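            # Perp-Neg: each negative-prompt direction is first projected onto the
            # component perpendicular to e_pos, so negatives can steer the update
            # without cancelling the positive guidance direction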
        else:
            neg_guidance_weights = None
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            # predict the noise residual with unet, NO grad!
            with torch.no_grad():
                # add noise
                noise = torch.randn_like(latents)  # TODO: use torch generator
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                if mask is not None:
                    latents_noisy = (1 - mask) * latents + mask * latents_noisy
                # pred noise
                latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 2),
                    encoder_hidden_states=text_embeddings,
                )  # (2B, 6, 64, 64)

            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1)
            noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        """
        # thresholding, experimental
        if self.cfg.thresholding:
            assert batch_size == 1
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
            noise_pred = custom_ddpm_step(self.scheduler,
                noise_pred, int(t.item()), latents_noisy, **self.pipe.prepare_extra_step_kwargs(None, 0.0)
            )
        """

        if self.cfg.weighting_strategy == "sds":
            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
        elif self.cfg.weighting_strategy == "uniform":
            w = 1
        elif self.cfg.weighting_strategy == "fantasia3d":
            w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
        else:
            raise ValueError(
                f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
            )

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)
        # clip grad for stable training?
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward, use a reparameterization trick instead
        target = (latents - grad).detach()
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        guidance_out = {
            "loss_sd": loss_sd,
            "grad_norm": grad.norm(),
            "min_step": self.min_step,
            "max_step": self.max_step,
        }

        # # FIXME: Visualize inpainting results
        # self.scheduler.set_timesteps(20)
        # latents = latents_noisy
        # for t in tqdm(self.scheduler.timesteps):
        #     # pred noise
        #     noise_pred = self.get_noise_pred(
        #         latents, t, text_embeddings, prompt_utils.use_perp_neg, None
        #     )
        #     # get prev latent
        #     prev_latents = latents
        #     latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"]
        #     if mask is not None:
        #         latents = (1 - mask) * prev_latents + mask * latents

        # denoised_img = (latents / 2 + 0.5).permute(0, 2, 3, 1)
        # guidance_out.update(
        #     {"denoised_img": denoised_img}
        # )

        if guidance_eval:
            guidance_eval_utils = {
                "use_perp_neg": prompt_utils.use_perp_neg,
                "neg_guidance_weights": neg_guidance_weights,
                "text_embeddings": text_embeddings,
                "t_orig": t,
                "latents_noisy": latents_noisy,
                "noise_pred": torch.cat([noise_pred, predicted_variance], dim=1),
            }
            guidance_eval_out = self.guidance_eval(**guidance_eval_utils)
            texts = []
            for n, e, a, c in zip(
                guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances
            ):
                texts.append(
                    f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
                )
            guidance_eval_out.update({"texts": texts})
            guidance_out.update({"eval": guidance_eval_out})

        return guidance_out

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_noise_pred(
        self,
        latents_noisy,
        t,
        text_embeddings,
        use_perp_neg=False,
        neg_guidance_weights=None,
    ):
        batch_size = latents_noisy.shape[0]
        if use_perp_neg:
            latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
            noise_pred = self.forward_unet(
                latent_model_input,
                torch.cat([t.reshape(1)] * 4).to(self.device),
                encoder_hidden_states=text_embeddings,
            )  # (4B, 6, 64, 64)

            noise_pred_text, predicted_variance = noise_pred[:batch_size].split(3, dim=1)
            noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split(
                3, dim=1
            )
            noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1)

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
            noise_pred = self.forward_unet(
                latent_model_input,
                torch.cat([t.reshape(1)] * 2).to(self.device),
                encoder_hidden_states=text_embeddings,
            )  # (2B, 6, 64, 64)

            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1)
            noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        return torch.cat([noise_pred, predicted_variance], dim=1)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def guidance_eval(
        self,
        t_orig,
        text_embeddings,
        latents_noisy,
        noise_pred,
        use_perp_neg=False,
        neg_guidance_weights=None,
    ):
        # use only 50 timesteps, and find the nearest of those to t
        self.scheduler.set_timesteps(50)
        self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
        bs = (
            min(self.cfg.max_items_eval, latents_noisy.shape[0])
            if self.cfg.max_items_eval > 0
            else latents_noisy.shape[0]
        )  # batch size
        large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
            :bs
        ].unsqueeze(
            -1
        )  # sized [bs, 50] > [bs, 1]
        idxs = torch.min(large_enough_idxs, dim=1)[1]
        t = self.scheduler.timesteps_gpu[idxs]
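        # scheduler timesteps are in descending order, so the boolean argmin above picks
        # the first entry that is <= t_orig: the nearest discretized step not exceeding
        # the original training timestep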

        fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
        imgs_noisy = (latents_noisy[:bs] / 2 + 0.5).permute(0, 2, 3, 1)

        # get prev latent
        latents_1step = []
        pred_1orig = []
        for b in range(bs):
            step_output = self.scheduler.step(
                noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1]
            )
            latents_1step.append(step_output["prev_sample"])
            pred_1orig.append(step_output["pred_original_sample"])
        latents_1step = torch.cat(latents_1step)
        pred_1orig = torch.cat(pred_1orig)
        imgs_1step = (latents_1step / 2 + 0.5).permute(0, 2, 3, 1)
        imgs_1orig = (pred_1orig / 2 + 0.5).permute(0, 2, 3, 1)

        latents_final = []
        for b, i in enumerate(idxs):
            latents = latents_1step[b : b + 1]
            text_emb = (
                text_embeddings[
                    [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ...
                ]
                if use_perp_neg
                else text_embeddings[[b, b + len(idxs)], ...]
            )
            neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None
            for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
                # pred noise
                noise_pred = self.get_noise_pred(
                    latents, t, text_emb, use_perp_neg, neg_guid
                )
                # get prev latent
                latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"]
            latents_final.append(latents)

        latents_final = torch.cat(latents_final)
        imgs_final = (latents_final / 2 + 0.5).permute(0, 2, 3, 1)

        return {
            "bs": bs,
            "noise_levels": fracs,
            "imgs_noisy": imgs_noisy,
            "imgs_1step": imgs_1step,
            "imgs_1orig": imgs_1orig,
            "imgs_final": imgs_final,
        }

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )


"""
# used by thresholding, experimental
def custom_ddpm_step(ddpm, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator=None, return_dict: bool = True):
    self = ddpm
    t = timestep

    prev_t = self.previous_timestep(t)

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t].item()
    alpha_prod_t_prev = self.alphas_cumprod[prev_t].item() if prev_t >= 0 else 1.0
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    noise_thresholded = (sample - (alpha_prod_t ** 0.5) * pred_original_sample) / (beta_prod_t ** 0.5)
    return noise_thresholded
"""


if __name__ == '__main__':
    from threestudio.utils.config import load_config
    import pytorch_lightning as pl
    import numpy as np
    import os
    import cv2

    cfg = load_config("configs/debugging/deepfloyd.yaml")
    guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)
    prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(
        cfg.system.prompt_processor
    )
    prompt_utils = prompt_processor()
    temp = torch.zeros(1).to(guidance.device)
    # rgb_image = guidance.sample(prompt_utils, temp, temp, temp, seed=cfg.seed)
    # rgb_image = (rgb_image[0].detach().cpu().clip(0, 1).numpy()*255).astype(np.uint8)[:, :, ::-1].copy()
    # os.makedirs('.threestudio_cache', exist_ok=True)
    # cv2.imwrite('.threestudio_cache/diffusion_image.jpg', rgb_image)

    ### inpaint
    rgb_image = cv2.imread("assets/test.jpg")[:, :, ::-1].copy() / 255
    mask_image = cv2.imread("assets/mask.png")[:, :, :1].copy() / 255
    rgb_image = cv2.resize(rgb_image, (512, 512))
    mask_image = cv2.resize(mask_image, (512, 512)).reshape(512, 512, 1)
    rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device)
    mask_image = torch.FloatTensor(mask_image).unsqueeze(0).to(guidance.device)

    guidance_out = guidance(rgb_image, prompt_utils, temp, temp, temp, mask=mask_image)
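    # NOTE: this inpainting test reads guidance_out["denoised_img"], which is only
    # produced by the visualization block currently commented out in __call__;
    # re-enable that block before running this script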
    edit_image = (
        (guidance_out["denoised_img"][0].detach().cpu().clip(0, 1).numpy() * 255)
        .astype(np.uint8)[:, :, ::-1]
        .copy()
    )
    os.makedirs(".threestudio_cache", exist_ok=True)
    cv2.imwrite(".threestudio_cache/edit_image.jpg", edit_image)
1134 threestudio/models/guidance/stable_diffusion_bsd_guidance.py Normal file
File diff suppressed because it is too large
632 threestudio/models/guidance/stable_diffusion_guidance.py Normal file
@@ -0,0 +1,632 @@
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, cleanup, parse_version
from threestudio.utils.ops import perpendicular_component
from threestudio.utils.typing import *


@threestudio.register("stable-diffusion-guidance")
class StableDiffusionGuidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        cache_dir: Optional[str] = None
        local_files_only: Optional[bool] = False
        pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = False
        guidance_scale: float = 100.0
        grad_clip: Optional[
            Any
        ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        time_prior: Optional[Any] = None  # [w1, w2, s1, s2]
        half_precision_weights: bool = True

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98
        max_step_percent_annealed: float = 0.5
        anneal_start_step: Optional[int] = None

        use_sjc: bool = False
        var_red: bool = True
        weighting_strategy: str = "sds"

        token_merging: bool = False
        token_merging_params: Optional[dict] = field(default_factory=dict)

        view_dependent_prompting: bool = True

        """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
        max_items_eval: int = 4

    cfg: Config

    def configure(self) -> None:
        threestudio.info("Loading Stable Diffusion ...")

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        pipe_kwargs = {
            "tokenizer": None,
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
            "local_files_only": self.cfg.local_files_only,
        }
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            **pipe_kwargs,
        ).to(self.device)

        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch 2.0 uses memory-efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory-efficient attention is not enabled."
                )
            else:
                self.pipe.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)

        del self.pipe.text_encoder
        cleanup()

        # Create model
        self.vae = self.pipe.vae.eval()
        self.unet = self.pipe.unet.eval()

        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.unet.parameters():
            p.requires_grad_(False)

        if self.cfg.token_merging:
            import tomesd

            tomesd.apply_patch(self.unet, **self.cfg.token_merging_params)

        if self.cfg.use_sjc:
            # score jacobian chaining uses DDPM
            self.scheduler = DDPMScheduler.from_pretrained(
                self.cfg.pretrained_model_name_or_path,
                subfolder="scheduler",
                torch_dtype=self.weights_dtype,
                beta_start=0.00085,
                beta_end=0.0120,
                beta_schedule="scaled_linear",
                cache_dir=self.cfg.cache_dir,
            )
        else:
            self.scheduler = DDIMScheduler.from_pretrained(
                self.cfg.pretrained_model_name_or_path,
                subfolder="scheduler",
                torch_dtype=self.weights_dtype,
                cache_dir=self.cfg.cache_dir,
                local_files_only=self.cfg.local_files_only,
            )

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default values
        if self.cfg.time_prior is not None:
            m1, m2, s1, s2 = self.cfg.time_prior
            weights = torch.cat(
                (
                    torch.exp(
                        -((torch.arange(self.num_train_timesteps, m1, -1) - m1) ** 2)
                        / (2 * s1**2)
                    ),
                    torch.ones(m1 - m2 + 1),
                    torch.exp(
                        -((torch.arange(m2 - 1, 0, -1) - m2) ** 2) / (2 * s2**2)
                    ),
                )
            )
            weights = weights / torch.sum(weights)
            self.time_prior_acc_weights = torch.cumsum(weights, dim=0)

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )
        if self.cfg.use_sjc:
            # score jacobian chaining needs mu
            self.us: Float[Tensor, "..."] = torch.sqrt((1 - self.alphas) / self.alphas)
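            # us[t] = sqrt((1 - alpha_bar_t) / alpha_bar_t): the sigma(t) noise scale of
            # the variance-exploding parameterization used by SJC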

        self.grad_clip_val: Optional[float] = None

        threestudio.info("Loaded Stable Diffusion!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def forward_unet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        return self.unet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
        ).sample.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 512 512"]
    ) -> Float[Tensor, "B 4 64 64"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
        latent_height: int = 64,
        latent_width: int = 64,
    ) -> Float[Tensor, "B 3 512 512"]:
        input_dtype = latents.dtype
        latents = F.interpolate(
            latents, (latent_height, latent_width), mode="bilinear", align_corners=False
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.weights_dtype)).sample
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

    def compute_grad_sds(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        t: Int[Tensor, "B"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
    ):
        batch_size = elevation.shape[0]

        if prompt_utils.use_perp_neg:
            (
                text_embeddings,
                neg_guidance_weights,
            ) = prompt_utils.get_text_embeddings_perp_neg(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            with torch.no_grad():
                noise = torch.randn_like(latents)
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 4),
                    encoder_hidden_states=text_embeddings,
                )  # (4B, 3, 64, 64)

            noise_pred_text = noise_pred[:batch_size]
            noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
            noise_pred_neg = noise_pred[batch_size * 2 :]

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            neg_guidance_weights = None
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            # predict the noise residual with unet, NO grad!
            with torch.no_grad():
                # add noise
                noise = torch.randn_like(latents)  # TODO: use torch generator
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                # pred noise
                latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 2),
                    encoder_hidden_states=text_embeddings,
                )

            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        if self.cfg.weighting_strategy == "sds":
            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
        elif self.cfg.weighting_strategy == "uniform":
            w = 1
        elif self.cfg.weighting_strategy == "fantasia3d":
            w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
        else:
            raise ValueError(
                f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
            )

        grad = w * (noise_pred - noise)

        guidance_eval_utils = {
            "use_perp_neg": prompt_utils.use_perp_neg,
            "neg_guidance_weights": neg_guidance_weights,
            "text_embeddings": text_embeddings,
            "t_orig": t,
            "latents_noisy": latents_noisy,
            "noise_pred": noise_pred,
        }

        return grad, guidance_eval_utils

    def compute_grad_sjc(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        t: Int[Tensor, "B"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
    ):
        batch_size = elevation.shape[0]

        sigma = self.us[t]
        sigma = sigma.view(-1, 1, 1, 1)

        if prompt_utils.use_perp_neg:
            (
                text_embeddings,
                neg_guidance_weights,
            ) = prompt_utils.get_text_embeddings_perp_neg(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            with torch.no_grad():
                noise = torch.randn_like(latents)
                y = latents
                zs = y + sigma * noise
                scaled_zs = zs / torch.sqrt(1 + sigma**2)
                # pred noise
                latent_model_input = torch.cat([scaled_zs] * 4, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 4),
                    encoder_hidden_states=text_embeddings,
                )  # (4B, 3, 64, 64)

            noise_pred_text = noise_pred[:batch_size]
            noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
            noise_pred_neg = noise_pred[batch_size * 2 :]

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            neg_guidance_weights = None
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            # predict the noise residual with unet, NO grad!
            with torch.no_grad():
                # add noise
                noise = torch.randn_like(latents)  # TODO: use torch generator
                y = latents

                zs = y + sigma * noise
                scaled_zs = zs / torch.sqrt(1 + sigma**2)

                # pred noise
                latent_model_input = torch.cat([scaled_zs] * 2, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 2),
                    encoder_hidden_states=text_embeddings,
                )

            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        Ds = zs - sigma * noise_pred

        if self.cfg.var_red:
            grad = -(Ds - y) / sigma
        else:
            grad = -(Ds - zs) / sigma
|
||||
|
||||
guidance_eval_utils = {
|
||||
"use_perp_neg": prompt_utils.use_perp_neg,
|
||||
"neg_guidance_weights": neg_guidance_weights,
|
||||
"text_embeddings": text_embeddings,
|
||||
"t_orig": t,
|
||||
"latents_noisy": scaled_zs,
|
||||
"noise_pred": noise_pred,
|
||||
}
|
||||
|
||||
return grad, guidance_eval_utils
|
||||
|
||||
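The rescaling `zs / torch.sqrt(1 + sigma**2)` above maps the variance-exploding SJC latent back to the scale the DDPM UNet expects. A minimal sketch of a sigma schedule consistent with that rescaling (an assumption about how `self.us` is precomputed; the numeric values are illustrative only):

import torch

# hypothetical construction of the noise scales used as self.us[t]
alphas_cumprod = torch.linspace(0.9991, 0.0047, 1000)   # illustrative schedule
us = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5     # sigma_t
# with this choice, 1 / sqrt(1 + sigma_t^2) = sqrt(alphas_cumprod_t), so
# zs / sqrt(1 + sigma_t^2) matches the scale of a DDPM latent at step t
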
    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        rgb_as_latents=False,
        guidance_eval=False,
        current_step_ratio=None,
        **kwargs,
    ):
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 64 64"]
        if rgb_as_latents:
            latents = F.interpolate(
                rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
            )
        else:
            rgb_BCHW_512 = F.interpolate(
                rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
            )
            # encode image into latents with vae
            latents = self.encode_images(rgb_BCHW_512)

        if self.cfg.time_prior is not None:
            time_index = torch.where(
                (self.time_prior_acc_weights - current_step_ratio) > 0
            )[0][0]
            if time_index == 0 or torch.abs(
                self.time_prior_acc_weights[time_index] - current_step_ratio
            ) < torch.abs(
                self.time_prior_acc_weights[time_index - 1] - current_step_ratio
            ):
                t = self.num_train_timesteps - time_index
            else:
                t = self.num_train_timesteps - time_index + 1
            t = torch.clip(t, self.min_step, self.max_step + 1)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)

        else:
            # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
            t = torch.randint(
                self.min_step,
                self.max_step + 1,
                [batch_size],
                dtype=torch.long,
                device=self.device,
            )

        if self.cfg.use_sjc:
            grad, guidance_eval_utils = self.compute_grad_sjc(
                latents, t, prompt_utils, elevation, azimuth, camera_distances
            )
        else:
            grad, guidance_eval_utils = self.compute_grad_sds(
                latents, t, prompt_utils, elevation, azimuth, camera_distances
            )

        grad = torch.nan_to_num(grad)
        # clip grad for stable training?
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward, use a reparameterization trick instead
        target = (latents - grad).detach()
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        guidance_out = {
            "loss_sd": loss_sds,
            "grad_norm": grad.norm(),
            "min_step": self.min_step,
            "max_step": self.max_step,
        }

        if guidance_eval:
            guidance_eval_out = self.guidance_eval(**guidance_eval_utils)
            texts = []
            for n, e, a, c in zip(
                guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances
            ):
                texts.append(
                    f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
                )
            guidance_eval_out.update({"texts": texts})
            guidance_out.update({"eval": guidance_eval_out})

        return guidance_out

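In training, the returned `loss_sd` is what backpropagates the SDS/SJC gradient into the differentiable renderer via the reparameterization trick above. A hypothetical invocation sketch (names such as `renderer_out`, `prompt_processor`, `batch`, and `max_steps` are illustrative, not from this diff):

# hypothetical training-loop excerpt
out = guidance(
    rgb=renderer_out["comp_rgb"],        # (B, H, W, 3) in [0, 1]
    prompt_utils=prompt_processor(),
    elevation=batch["elevation"],
    azimuth=batch["azimuth"],
    camera_distances=batch["camera_distances"],
    rgb_as_latents=False,
    current_step_ratio=global_step / max_steps,
)
loss = out["loss_sd"]                    # gradient reaches the 3D model through rgb
loss.backward()
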
    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_noise_pred(
        self,
        latents_noisy,
        t,
        text_embeddings,
        use_perp_neg=False,
        neg_guidance_weights=None,
    ):
        batch_size = latents_noisy.shape[0]

        if use_perp_neg:
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
            noise_pred = self.forward_unet(
                latent_model_input,
                torch.cat([t.reshape(1)] * 4).to(self.device),
                encoder_hidden_states=text_embeddings,
            )  # (4B, 4, 64, 64)

            noise_pred_text = noise_pred[:batch_size]
            noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
            noise_pred_neg = noise_pred[batch_size * 2 :]

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
            noise_pred = self.forward_unet(
                latent_model_input,
                torch.cat([t.reshape(1)] * 2).to(self.device),
                encoder_hidden_states=text_embeddings,
            )
            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        return noise_pred

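Both branches above are variants of classifier-free guidance. A minimal standalone sketch of the plain two-branch form (note this file uses `noise_pred_text` rather than `noise_pred_uncond` as the base term, matching the DreamFusion formulation; the two conventions coincide up to a shift of the scale):

import torch

def classifier_free_guidance(eps_text: torch.Tensor, eps_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    # e_text + s * (e_text - e_uncond); equals e_uncond + (s + 1) * (e_text - e_uncond),
    # i.e. the common base-on-uncond convention with scale s + 1
    return eps_text + scale * (eps_text - eps_uncond)
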
    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def guidance_eval(
        self,
        t_orig,
        text_embeddings,
        latents_noisy,
        noise_pred,
        use_perp_neg=False,
        neg_guidance_weights=None,
    ):
        # use only 50 timesteps, and find nearest of those to t
        self.scheduler.set_timesteps(50)
        self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
        bs = (
            min(self.cfg.max_items_eval, latents_noisy.shape[0])
            if self.cfg.max_items_eval > 0
            else latents_noisy.shape[0]
        )  # batch size
        large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
            :bs
        ].unsqueeze(
            -1
        )  # sized [bs,50] > [bs,1]
        idxs = torch.min(large_enough_idxs, dim=1)[1]
        t = self.scheduler.timesteps_gpu[idxs]

        fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
        imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)

        # get prev latent
        latents_1step = []
        pred_1orig = []
        for b in range(bs):
            step_output = self.scheduler.step(
                noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
            )
            latents_1step.append(step_output["prev_sample"])
            pred_1orig.append(step_output["pred_original_sample"])
        latents_1step = torch.cat(latents_1step)
        pred_1orig = torch.cat(pred_1orig)
        imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
        imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)

        latents_final = []
        for b, i in enumerate(idxs):
            latents = latents_1step[b : b + 1]
            text_emb = (
                text_embeddings[
                    [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ...
                ]
                if use_perp_neg
                else text_embeddings[[b, b + len(idxs)], ...]
            )
            neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None
            for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
                # pred noise
                noise_pred = self.get_noise_pred(
                    latents, t, text_emb, use_perp_neg, neg_guid
                )
                # get prev latent
                latents = self.scheduler.step(noise_pred, t, latents, eta=1)[
                    "prev_sample"
                ]
            latents_final.append(latents)

        latents_final = torch.cat(latents_final)
        imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)

        return {
            "bs": bs,
            "noise_levels": fracs,
            "imgs_noisy": imgs_noisy,
            "imgs_1step": imgs_1step,
            "imgs_1orig": imgs_1orig,
            "imgs_final": imgs_final,
        }

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )

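`C(...)` resolves a config value that may be scheduled over training. A minimal sketch of one common format (a guess at the behavior of `threestudio.utils.misc.C`, written out for readability; the real helper supports more variants):

# hypothetical reimplementation for illustration only
def C(value, epoch: int, global_step: int):
    if isinstance(value, (int, float)):
        return value  # constant schedule
    # [start_step, start_value, end_value, end_step] -> linear interpolation
    start_step, start_value, end_value, end_step = value
    r = min(max((global_step - start_step) / (end_step - start_step), 0.0), 1.0)
    return start_value + (end_value - start_value) * r
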
729 threestudio/models/guidance/stable_diffusion_unified_guidance.py Normal file
@@ -0,0 +1,729 @@
import random
from contextlib import contextmanager
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    DPMSolverSinglestepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm

import threestudio
from threestudio.models.networks import ToDTypeWrapper
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseModule
from threestudio.utils.misc import C, cleanup, enable_gradient, parse_version
from threestudio.utils.ops import perpendicular_component
from threestudio.utils.typing import *

@threestudio.register("stable-diffusion-unified-guidance")
class StableDiffusionUnifiedGuidance(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        cache_dir: Optional[str] = None
        local_files_only: Optional[bool] = False

        # guidance type, in ["sds", "vsd"]
        guidance_type: str = "sds"

        pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        guidance_scale: float = 100.0
        weighting_strategy: str = "dreamfusion"
        view_dependent_prompting: bool = True

        min_step_percent: Any = 0.02
        max_step_percent: Any = 0.98
        grad_clip: Optional[Any] = None

        return_rgb_1step_orig: bool = False
        return_rgb_multistep_orig: bool = False
        n_rgb_multistep_orig_steps: int = 4

        # TODO
        # controlnet
        controlnet_model_name_or_path: Optional[str] = None
        preprocessor: Optional[str] = None
        control_scale: float = 1.0

        # TODO
        # lora
        lora_model_name_or_path: Optional[str] = None

        # efficiency-related configurations
        half_precision_weights: bool = True
        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = False
        token_merging: bool = False
        token_merging_params: Optional[dict] = field(default_factory=dict)

        # VSD configurations, only used when guidance_type is "vsd"
        vsd_phi_model_name_or_path: Optional[str] = None
        vsd_guidance_scale_phi: float = 1.0
        vsd_use_lora: bool = True
        vsd_lora_cfg_training: bool = False
        vsd_lora_n_timestamp_samples: int = 1
        vsd_use_camera_condition: bool = True
        # camera condition type, in ["extrinsics", "mvp", "spherical"]
        vsd_camera_condition_type: Optional[str] = "extrinsics"

    cfg: Config

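For orientation, a hypothetical way this Config might be filled in for VSD-style guidance (field values are illustrative, not recommendations from this diff):

# hypothetical instantiation for illustration
cfg = StableDiffusionUnifiedGuidance.Config(
    guidance_type="vsd",                 # variational score distillation
    guidance_scale=7.5,                  # VSD typically uses CFG-like scales
    vsd_use_camera_condition=True,
    vsd_camera_condition_type="extrinsics",
    min_step_percent=0.02,
    max_step_percent=0.98,
)
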
    def configure(self) -> None:
        self.min_step: Optional[int] = None
        self.max_step: Optional[int] = None
        self.grad_clip_val: Optional[float] = None

        @dataclass
        class NonTrainableModules:
            pipe: StableDiffusionPipeline
            pipe_phi: Optional[StableDiffusionPipeline] = None
            controlnet: Optional[ControlNetModel] = None

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        threestudio.info(f"Loading Stable Diffusion ...")

        pipe_kwargs = {
            "tokenizer": None,
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
            "local_files_only": self.cfg.local_files_only,
        }
        pipe = StableDiffusionPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            **pipe_kwargs,
        ).to(self.device)
        self.prepare_pipe(pipe)
        self.configure_pipe_token_merging(pipe)

        # phi network for VSD
        # introduce two trainable modules:
        # - self.camera_embedding
        # - self.lora_layers
        pipe_phi = None

        # if the phi network shares the same unet with the pretrained network,
        # we need to pass additional cross attention kwargs to the unet
        self.vsd_share_model = (
            self.cfg.guidance_type == "vsd"
            and self.cfg.vsd_phi_model_name_or_path is None
        )
        if self.cfg.guidance_type == "vsd":
            if self.cfg.vsd_phi_model_name_or_path is None:
                pipe_phi = pipe
            else:
                pipe_phi = StableDiffusionPipeline.from_pretrained(
                    self.cfg.vsd_phi_model_name_or_path,
                    **pipe_kwargs,
                ).to(self.device)
                self.prepare_pipe(pipe_phi)
                self.configure_pipe_token_merging(pipe_phi)

            # set up camera embedding
            if self.cfg.vsd_use_camera_condition:
                if self.cfg.vsd_camera_condition_type in ["extrinsics", "mvp"]:
                    self.camera_embedding_dim = 16
                elif self.cfg.vsd_camera_condition_type == "spherical":
                    self.camera_embedding_dim = 4
                else:
                    raise ValueError("Invalid camera condition type!")

                # FIXME: hard-coded output dim
                self.camera_embedding = ToDTypeWrapper(
                    TimestepEmbedding(self.camera_embedding_dim, 1280),
                    self.weights_dtype,
                ).to(self.device)
                pipe_phi.unet.class_embedding = self.camera_embedding

            if self.cfg.vsd_use_lora:
                # set up LoRA layers
                lora_attn_procs = {}
                for name in pipe_phi.unet.attn_processors.keys():
                    cross_attention_dim = (
                        None
                        if name.endswith("attn1.processor")
                        else pipe_phi.unet.config.cross_attention_dim
                    )
                    if name.startswith("mid_block"):
                        hidden_size = pipe_phi.unet.config.block_out_channels[-1]
                    elif name.startswith("up_blocks"):
                        block_id = int(name[len("up_blocks.")])
                        hidden_size = list(
                            reversed(pipe_phi.unet.config.block_out_channels)
                        )[block_id]
                    elif name.startswith("down_blocks"):
                        block_id = int(name[len("down_blocks.")])
                        hidden_size = pipe_phi.unet.config.block_out_channels[block_id]

                    lora_attn_procs[name] = LoRAAttnProcessor(
                        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
                    )

                pipe_phi.unet.set_attn_processor(lora_attn_procs)

                self.lora_layers = AttnProcsLayers(pipe_phi.unet.attn_processors).to(
                    self.device
                )
                self.lora_layers._load_state_dict_pre_hooks.clear()
                self.lora_layers._state_dict_hooks.clear()

        threestudio.info(f"Loaded Stable Diffusion!")

        # controlnet
        controlnet = None
        if self.cfg.controlnet_model_name_or_path is not None:
            threestudio.info(f"Loading ControlNet ...")

            controlnet = ControlNetModel.from_pretrained(
                self.cfg.controlnet_model_name_or_path,
                torch_dtype=self.weights_dtype,
            ).to(self.device)
            controlnet.eval()
            enable_gradient(controlnet, enabled=False)

            threestudio.info(f"Loaded ControlNet!")

        self.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        # q(z_t|x) = N(alpha_t x, sigma_t^2 I)
        # in DDPM, alpha_t = sqrt(alphas_cumprod_t), sigma_t^2 = 1 - alphas_cumprod_t
        self.alphas_cumprod: Float[Tensor, "T"] = self.scheduler.alphas_cumprod.to(
            self.device
        )
        self.alphas: Float[Tensor, "T"] = self.alphas_cumprod**0.5
        self.sigmas: Float[Tensor, "T"] = (1 - self.alphas_cumprod) ** 0.5
        # sigma_t / alpha_t, i.e. the inverse square-root SNR (not the log SNR)
        self.lambdas: Float[Tensor, "T"] = self.sigmas / self.alphas

        self._non_trainable_modules = NonTrainableModules(
            pipe=pipe,
            pipe_phi=pipe_phi,
            controlnet=controlnet,
        )

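The comment block above fixes the notation used throughout the rest of the file; spelled out in LaTeX:

% DDPM forward process in the (alpha_t, sigma_t) parameterization used above
q(z_t \mid x) = \mathcal{N}\!\left(z_t;\ \alpha_t x,\ \sigma_t^2 I\right),
\qquad
z_t = \alpha_t x + \sigma_t \epsilon,\quad \epsilon \sim \mathcal{N}(0, I),
\qquad
\alpha_t = \sqrt{\bar\alpha_t},\ \ \sigma_t = \sqrt{1 - \bar\alpha_t}.
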
    @property
    def pipe(self) -> StableDiffusionPipeline:
        return self._non_trainable_modules.pipe

    @property
    def pipe_phi(self) -> StableDiffusionPipeline:
        if self._non_trainable_modules.pipe_phi is None:
            raise RuntimeError("phi model is not available.")
        return self._non_trainable_modules.pipe_phi

    @property
    def controlnet(self) -> ControlNetModel:
        if self._non_trainable_modules.controlnet is None:
            raise RuntimeError("ControlNet model is not available.")
        return self._non_trainable_modules.controlnet

    def prepare_pipe(self, pipe: StableDiffusionPipeline):
        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch2.0 uses memory efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory efficient attention is not enabled."
                )
            else:
                pipe.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            pipe.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            pipe.unet.to(memory_format=torch.channels_last)

        # FIXME: pipe.__call__ requires text_encoder.dtype
        # pipe.text_encoder.to("meta")
        cleanup()

        pipe.vae.eval()
        pipe.unet.eval()

        enable_gradient(pipe.vae, enabled=False)
        enable_gradient(pipe.unet, enabled=False)

        # disable progress bar
        pipe.set_progress_bar_config(disable=True)

    def configure_pipe_token_merging(self, pipe: StableDiffusionPipeline):
        if self.cfg.token_merging:
            import tomesd

            tomesd.apply_patch(pipe.unet, **self.cfg.token_merging_params)

    @torch.cuda.amp.autocast(enabled=False)
    def forward_unet(
        self,
        unet: UNet2DConditionModel,
        latents: Float[Tensor, "..."],
        t: Int[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
        class_labels: Optional[Float[Tensor, "..."]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Float[Tensor, "..."]] = None,
        mid_block_additional_residual: Optional[Float[Tensor, "..."]] = None,
        velocity_to_epsilon: bool = False,
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        pred = unet(
            latents.to(unet.dtype),
            t.to(unet.dtype),
            encoder_hidden_states=encoder_hidden_states.to(unet.dtype),
            class_labels=class_labels,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
        ).sample
        if velocity_to_epsilon:
            pred = latents * self.sigmas[t].view(-1, 1, 1, 1) + pred * self.alphas[
                t
            ].view(-1, 1, 1, 1)
        return pred.to(input_dtype)

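The `velocity_to_epsilon` branch converts a v-prediction UNet output into an epsilon-prediction. The identity it relies on (with the standard v-parameterization v_t = alpha_t * eps - sigma_t * x), checked in LaTeX:

% consistency check for the velocity_to_epsilon conversion above
v_t = \alpha_t \epsilon - \sigma_t x
\ \Rightarrow\
\sigma_t z_t + \alpha_t v_t
= \sigma_t(\alpha_t x + \sigma_t \epsilon) + \alpha_t(\alpha_t \epsilon - \sigma_t x)
= (\alpha_t^2 + \sigma_t^2)\,\epsilon = \epsilon .
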
    @torch.cuda.amp.autocast(enabled=False)
    def vae_encode(
        self, vae: AutoencoderKL, imgs: Float[Tensor, "B 3 H W"], mode=False
    ) -> Float[Tensor, "B 4 Hl Wl"]:
        # expect input in [-1, 1]
        input_dtype = imgs.dtype
        posterior = vae.encode(imgs.to(vae.dtype)).latent_dist
        if mode:
            latents = posterior.mode()
        else:
            latents = posterior.sample()
        latents = latents * vae.config.scaling_factor
        return latents.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def vae_decode(
        self, vae: AutoencoderKL, latents: Float[Tensor, "B 4 Hl Wl"]
    ) -> Float[Tensor, "B 3 H W"]:
        # output in [0, 1]
        input_dtype = latents.dtype
        latents = 1 / vae.config.scaling_factor * latents
        image = vae.decode(latents.to(vae.dtype)).sample
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

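Note the asymmetric ranges: `vae_encode` expects images in [-1, 1] while `vae_decode` returns [0, 1], which is why `forward` below rescales with `rgb_BCHW * 2.0 - 1.0`. A roundtrip sketch (`guidance` is assumed to be an instance of this class):

# hypothetical roundtrip, matching the range conventions documented above
import torch

rgb01 = torch.rand(1, 3, 512, 512)                         # [0, 1] image
latents = guidance.vae_encode(guidance.pipe.vae, rgb01 * 2.0 - 1.0)
recon01 = guidance.vae_decode(guidance.pipe.vae, latents)  # back in [0, 1]
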
    @contextmanager
    def disable_unet_class_embedding(self, unet: UNet2DConditionModel):
        class_embedding = unet.class_embedding
        try:
            unet.class_embedding = None
            yield unet
        finally:
            unet.class_embedding = class_embedding

    @contextmanager
    def set_scheduler(
        self, pipe: StableDiffusionPipeline, scheduler_class: Any, **kwargs
    ):
        scheduler_orig = pipe.scheduler
        pipe.scheduler = scheduler_class.from_config(scheduler_orig.config, **kwargs)
        yield pipe
        pipe.scheduler = scheduler_orig

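Both helpers temporarily mutate the pipeline and restore it on exit, which is how they are used later in this file. A usage sketch (`guidance` is assumed to be an instance of this class):

# hypothetical usage of the context managers above
with guidance.set_scheduler(guidance.pipe, DPMSolverSinglestepScheduler, solver_order=1) as pipe:
    ...  # run fast multi-step sampling; the original scheduler is restored on exit

with guidance.disable_unet_class_embedding(guidance.pipe.unet) as unet:
    ...  # query the pretrained unet without the camera-conditioned class embedding
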
    def get_eps_pretrain(
        self,
        latents_noisy: Float[Tensor, "B 4 Hl Wl"],
        t: Int[Tensor, "B"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
    ) -> Float[Tensor, "B 4 Hl Wl"]:
        batch_size = latents_noisy.shape[0]

        if prompt_utils.use_perp_neg:
            (
                text_embeddings,
                neg_guidance_weights,
            ) = prompt_utils.get_text_embeddings_perp_neg(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            with torch.no_grad():
                with self.disable_unet_class_embedding(self.pipe.unet) as unet:
                    noise_pred = self.forward_unet(
                        unet,
                        torch.cat([latents_noisy] * 4, dim=0),
                        torch.cat([t] * 4, dim=0),
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs={"scale": 0.0}
                        if self.vsd_share_model
                        else None,
                        velocity_to_epsilon=self.pipe.scheduler.config.prediction_type
                        == "v_prediction",
                    )  # (4B, 4, Hl, Wl)

            noise_pred_text = noise_pred[:batch_size]
            noise_pred_uncond = noise_pred[batch_size : batch_size * 2]
            noise_pred_neg = noise_pred[batch_size * 2 :]

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            with torch.no_grad():
                with self.disable_unet_class_embedding(self.pipe.unet) as unet:
                    noise_pred = self.forward_unet(
                        unet,
                        torch.cat([latents_noisy] * 2, dim=0),
                        torch.cat([t] * 2, dim=0),
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs={"scale": 0.0}
                        if self.vsd_share_model
                        else None,
                        velocity_to_epsilon=self.pipe.scheduler.config.prediction_type
                        == "v_prediction",
                    )

            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        return noise_pred

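`perpendicular_component(e_i_neg, e_pos)` removes from each negative-prompt direction its projection onto the positive direction, so negative prompts steer away without cancelling the main guidance. A minimal sketch of that operation (an assumption about `threestudio.utils.ops.perpendicular_component`, written out for readability; the real helper may guard against division by zero):

import torch

def perpendicular_component_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # component of x orthogonal to y, computed per batch element
    B = x.shape[0]
    proj = (x * y).sum(dim=[1, 2, 3]) / (y * y).sum(dim=[1, 2, 3])
    return x - proj.view(B, 1, 1, 1) * y
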
    def get_eps_phi(
        self,
        latents_noisy: Float[Tensor, "B 4 Hl Wl"],
        t: Int[Tensor, "B"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        camera_condition: Float[Tensor, "B ..."],
    ) -> Float[Tensor, "B 4 Hl Wl"]:
        batch_size = latents_noisy.shape[0]

        # not using view-dependent prompting in LoRA
        text_embeddings, _ = prompt_utils.get_text_embeddings(
            elevation, azimuth, camera_distances, view_dependent_prompting=False
        ).chunk(2)
        with torch.no_grad():
            noise_pred = self.forward_unet(
                self.pipe_phi.unet,
                torch.cat([latents_noisy] * 2, dim=0),
                torch.cat([t] * 2, dim=0),
                encoder_hidden_states=torch.cat([text_embeddings] * 2, dim=0),
                class_labels=torch.cat(
                    [
                        camera_condition.view(batch_size, -1),
                        torch.zeros_like(camera_condition.view(batch_size, -1)),
                    ],
                    dim=0,
                )
                if self.cfg.vsd_use_camera_condition
                else None,
                cross_attention_kwargs={"scale": 1.0},
                velocity_to_epsilon=self.pipe_phi.scheduler.config.prediction_type
                == "v_prediction",
            )

        noise_pred_camera, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.cfg.vsd_guidance_scale_phi * (
            noise_pred_camera - noise_pred_uncond
        )

        return noise_pred

    def train_phi(
        self,
        latents: Float[Tensor, "B 4 Hl Wl"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        camera_condition: Float[Tensor, "B ..."],
    ):
        B = latents.shape[0]
        latents = latents.detach().repeat(
            self.cfg.vsd_lora_n_timestamp_samples, 1, 1, 1
        )

        num_train_timesteps = self.pipe_phi.scheduler.config.num_train_timesteps
        t = torch.randint(
            int(num_train_timesteps * 0.0),
            int(num_train_timesteps * 1.0),
            [B * self.cfg.vsd_lora_n_timestamp_samples],
            dtype=torch.long,
            device=self.device,
        )

        noise = torch.randn_like(latents)
        latents_noisy = self.pipe_phi.scheduler.add_noise(latents, noise, t)
        if self.pipe_phi.scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.pipe_phi.scheduler.config.prediction_type == "v_prediction":
            target = self.pipe_phi.scheduler.get_velocity(latents, noise, t)
        else:
            raise ValueError(
                f"Unknown prediction type {self.pipe_phi.scheduler.config.prediction_type}"
            )

        # not using view-dependent prompting in LoRA
        text_embeddings, _ = prompt_utils.get_text_embeddings(
            elevation, azimuth, camera_distances, view_dependent_prompting=False
        ).chunk(2)

        if (
            self.cfg.vsd_use_camera_condition
            and self.cfg.vsd_lora_cfg_training
            and random.random() < 0.1
        ):
            camera_condition = torch.zeros_like(camera_condition)

        noise_pred = self.forward_unet(
            self.pipe_phi.unet,
            latents_noisy,
            t,
            encoder_hidden_states=text_embeddings.repeat(
                self.cfg.vsd_lora_n_timestamp_samples, 1, 1
            ),
            class_labels=camera_condition.view(B, -1).repeat(
                self.cfg.vsd_lora_n_timestamp_samples, 1
            )
            if self.cfg.vsd_use_camera_condition
            else None,
            cross_attention_kwargs={"scale": 1.0},
        )
        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

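Together, `get_eps_pretrain`, `get_eps_phi`, and `train_phi` implement the two sides of variational score distillation: the phi (LoRA) network is fitted to the current renderings with a standard diffusion loss, while the VSD update direction uses the difference of the two score estimates. In LaTeX, the gradient assembled in `forward` below is, up to the weighting w(t):

% VSD update direction formed in forward():
\nabla_\theta \mathcal{L}_{\mathrm{VSD}}
\ \propto\
\mathbb{E}_{t,\epsilon}\!\left[
  w(t)\,\bigl(\epsilon_{\mathrm{pretrain}}(z_t; y, t) - \epsilon_\phi(z_t; y, c, t)\bigr)
  \frac{\partial z_t}{\partial \theta}
\right]
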
    def forward(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        mvp_mtx: Float[Tensor, "B 4 4"],
        c2w: Float[Tensor, "B 4 4"],
        rgb_as_latents=False,
        **kwargs,
    ):
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 Hl Wl"]
        if rgb_as_latents:
            # treat input rgb as latents
            # input rgb should be in range [-1, 1]
            latents = F.interpolate(
                rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
            )
        else:
            # treat input rgb as rgb
            # input rgb should be in range [0, 1]
            rgb_BCHW = F.interpolate(
                rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
            )
            # encode image into latents with vae
            latents = self.vae_encode(self.pipe.vae, rgb_BCHW * 2.0 - 1.0)

        # sample timestep
        # use the same timestep for each batch
        assert self.min_step is not None and self.max_step is not None
        t = torch.randint(
            self.min_step,
            self.max_step + 1,
            [1],
            dtype=torch.long,
            device=self.device,
        ).repeat(batch_size)

        # sample noise
        noise = torch.randn_like(latents)
        latents_noisy = self.scheduler.add_noise(latents, noise, t)

        eps_pretrain = self.get_eps_pretrain(
            latents_noisy, t, prompt_utils, elevation, azimuth, camera_distances
        )

        latents_1step_orig = (
            1
            / self.alphas[t].view(-1, 1, 1, 1)
            * (latents_noisy - self.sigmas[t].view(-1, 1, 1, 1) * eps_pretrain)
        ).detach()

        if self.cfg.guidance_type == "sds":
            eps_phi = noise
        elif self.cfg.guidance_type == "vsd":
            if self.cfg.vsd_camera_condition_type == "extrinsics":
                camera_condition = c2w
            elif self.cfg.vsd_camera_condition_type == "mvp":
                camera_condition = mvp_mtx
            elif self.cfg.vsd_camera_condition_type == "spherical":
                camera_condition = torch.stack(
                    [
                        torch.deg2rad(elevation),
                        torch.sin(torch.deg2rad(azimuth)),
                        torch.cos(torch.deg2rad(azimuth)),
                        camera_distances,
                    ],
                    dim=-1,
                )
            else:
                raise ValueError(
                    f"Unknown camera_condition_type {self.cfg.vsd_camera_condition_type}"
                )
            eps_phi = self.get_eps_phi(
                latents_noisy,
                t,
                prompt_utils,
                elevation,
                azimuth,
                camera_distances,
                camera_condition,
            )

            loss_train_phi = self.train_phi(
                latents,
                prompt_utils,
                elevation,
                azimuth,
                camera_distances,
                camera_condition,
            )

        if self.cfg.weighting_strategy == "dreamfusion":
            w = (1.0 - self.alphas[t]).view(-1, 1, 1, 1)
        elif self.cfg.weighting_strategy == "uniform":
            w = 1.0
        elif self.cfg.weighting_strategy == "fantasia3d":
            w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
        else:
            raise ValueError(
                f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
            )

        grad = w * (eps_pretrain - eps_phi)

        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # reparameterization trick:
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        target = (latents - grad).detach()
        loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        guidance_out = {
            "loss_sd": loss_sd,
            "grad_norm": grad.norm(),
            "timesteps": t,
            "min_step": self.min_step,
            "max_step": self.max_step,
            "latents": latents,
            "latents_1step_orig": latents_1step_orig,
            "rgb": rgb_BCHW.permute(0, 2, 3, 1),
            "weights": w,
            "lambdas": self.lambdas[t],
        }

        if self.cfg.return_rgb_1step_orig:
            with torch.no_grad():
                rgb_1step_orig = self.vae_decode(
                    self.pipe.vae, latents_1step_orig
                ).permute(0, 2, 3, 1)
            guidance_out.update({"rgb_1step_orig": rgb_1step_orig})

        if self.cfg.return_rgb_multistep_orig:
            with self.set_scheduler(
                self.pipe,
                DPMSolverSinglestepScheduler,
                solver_order=1,
                num_train_timesteps=int(t[0]),
            ) as pipe:
                text_embeddings = prompt_utils.get_text_embeddings(
                    elevation,
                    azimuth,
                    camera_distances,
                    self.cfg.view_dependent_prompting,
                )
                text_embeddings_cond, text_embeddings_uncond = text_embeddings.chunk(2)
                with torch.cuda.amp.autocast(enabled=False):
                    latents_multistep_orig = pipe(
                        num_inference_steps=self.cfg.n_rgb_multistep_orig_steps,
                        guidance_scale=self.cfg.guidance_scale,
                        eta=1.0,
                        latents=latents_noisy.to(pipe.unet.dtype),
                        prompt_embeds=text_embeddings_cond.to(pipe.unet.dtype),
                        negative_prompt_embeds=text_embeddings_uncond.to(
                            pipe.unet.dtype
                        ),
                        cross_attention_kwargs={"scale": 0.0}
                        if self.vsd_share_model
                        else None,
                        output_type="latent",
                    ).images.to(latents.dtype)
            with torch.no_grad():
                rgb_multistep_orig = self.vae_decode(
                    self.pipe.vae, latents_multistep_orig
                )
            guidance_out.update(
                {
                    "latents_multistep_orig": latents_multistep_orig,
                    "rgb_multistep_orig": rgb_multistep_orig.permute(0, 2, 3, 1),
                }
            )

        if self.cfg.guidance_type == "vsd":
            guidance_out.update(
                {
                    "loss_train_phi": loss_train_phi,
                }
            )

        return guidance_out

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.min_step = int(
            self.num_train_timesteps * C(self.cfg.min_step_percent, epoch, global_step)
        )
        self.max_step = int(
            self.num_train_timesteps * C(self.cfg.max_step_percent, epoch, global_step)
        )

1003 threestudio/models/guidance/stable_diffusion_vsd_guidance.py Normal file
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff