import os
import random
import shutil
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from torchmetrics import PearsonCorrCoef

import threestudio

# NOTE(review): the module path was missing from this import in the source
# (`from import BaseLift3DSystem`); restored to the package's base-system
# module -- confirm against the repository layout.
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *


@threestudio.register("zero123-system")
class Zero123(BaseLift3DSystem):
    """Training system registered under the name ``"zero123-system"``.

    A training step is made of up to two sub-steps (see ``training_step``):

    * ``"ref"``: supervision against the reference image carried by the batch
      (rgb, mask and optionally depth / relative depth / normal losses).
    * ``"zero123"``: guidance loss computed on the ``"random_camera"`` view
      of the batch.
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        # Scheduling options. Keys read by this class: "guidance_eval",
        # "ref_or_zero123", "ref_only_steps" and "n_ref".
        freq: dict = field(default_factory=dict)
        # When True, the mesh regularizers (normal consistency, laplacian
        # smoothness) replace the orient / sparsity / opaque losses.
        refinement: bool = False
        # Lower bound of the random ambient ratio used by the "zero123"
        # sub-step; the ratio is drawn uniformly from [ambient_ratio_min, 1).
        ambient_ratio_min: float = 0.5

    cfg: Config

    def configure(self):
        """Delegate to the base class configuration."""
        # create geometry, material, background, renderer
        super().configure()

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Render ``batch`` and return a shallow copy of the renderer output."""
        render_out = self.renderer(**batch)
        return {
            **render_out,
        }

    def on_fit_start(self) -> None:
        """Create the guidance module, dump the training images and set up
        the Pearson correlation metric used by the relative depth loss."""
        super().on_fit_start()
        # no prompt processor
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

        # visualize all training images
        all_images = self.trainer.datamodule.train_dataloader().dataset.get_all_images()
        self.save_image_grid(
            "all_training_images.png",
            [
                {"type": "rgb", "img": image, "kwargs": {"data_format": "HWC"}}
                for image in all_images
            ],
            name="on_fit_start",
            step=self.true_global_step,
        )

        self.pearson = PearsonCorrCoef().to(self.device)

    def training_substep(self, batch, batch_idx, guidance: str):
        """Run one rendering pass and compute the weighted loss for it.

        Args:
            batch: training batch. For ``guidance == "zero123"`` only its
                ``"random_camera"`` entry is used. The batch that is rendered
                is mutated in place (``shading``, ``bg_color`` and
                ``ambient_ratio`` keys are written).
            batch_idx: index of the batch (unused here).
            guidance: one of "ref" (reference image supervision), "zero123"

        Returns:
            ``{"loss": loss}`` where ``loss`` is the sum of every collected
            loss term multiplied by its ``cfg.loss.lambda_*`` weight.

        Raises:
            ValueError: if ``guidance`` is not a supported value, or if a
                smoothness / orientation loss is enabled but the renderer
                output lacks the normals it needs.
        """
        if guidance == "ref":
            # bg_color = torch.rand_like(batch['rays_o'])
            ambient_ratio = 1.0
            batch["shading"] = "diffuse"
        elif guidance == "zero123":
            batch = batch["random_camera"]
            ambient_ratio = (
                self.cfg.ambient_ratio_min
                + (1 - self.cfg.ambient_ratio_min) * random.random()
            )
        else:
            # Fail early (before mutating the batch) instead of hitting an
            # unbound `ambient_ratio` a few lines below.
            raise ValueError(
                f"Unknown guidance {guidance!r}; expected 'ref' or 'zero123'."
            )

        batch["bg_color"] = None
        batch["ambient_ratio"] = ambient_ratio

        out = self(batch)
        loss_prefix = f"loss_{guidance}_"

        loss_terms = {}

        def set_loss(name, value):
            loss_terms[f"{loss_prefix}{name}"] = value

        # Guidance evaluation only applies to the zero123 sub-step, every
        # `cfg.freq.guidance_eval` steps (disabled when that value is <= 0).
        guidance_eval = (
            guidance == "zero123"
            and self.cfg.freq.guidance_eval > 0
            and self.true_global_step % self.cfg.freq.guidance_eval == 0
        )

        if guidance == "ref":
            gt_mask = batch["mask"]
            gt_rgb = batch["rgb"]

            # color loss: outside the mask the target is the rendered
            # background, so only masked pixels are compared to the reference
            gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * (
                1 - gt_mask.float()
            )
            set_loss("rgb", F.mse_loss(gt_rgb, out["comp_rgb"]))

            # mask loss
            set_loss("mask", F.mse_loss(gt_mask.float(), out["opacity"]))

            # depth loss
            if self.C(self.cfg.loss.lambda_depth) > 0:
                valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)].unsqueeze(1)
                valid_pred_depth = out["depth"][gt_mask].unsqueeze(1)
                # Least-squares fit of a scale and shift that map the
                # reference depth onto the predicted depth; done without
                # gradients so only the prediction is optimized.
                with torch.no_grad():
                    A = torch.cat(
                        [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1
                    )  # [B, 2]
                    X = torch.linalg.lstsq(A, valid_pred_depth).solution  # [2, 1]
                    valid_gt_depth = A @ X  # [B, 1]
                set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth))

            # relative depth loss
            if self.C(self.cfg.loss.lambda_depth_rel) > 0:
                valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)]  # [B,]
                valid_pred_depth = out["depth"][gt_mask]  # [B,]
                set_loss(
                    "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth)
                )

            # normal loss
            if self.C(self.cfg.loss.lambda_normal) > 0:
                valid_gt_normal = (
                    1 - 2 * batch["ref_normal"][gt_mask.squeeze(-1)]
                )  # [B, 3]
                valid_pred_normal = (
                    2 * out["comp_normal"][gt_mask.squeeze(-1)] - 1
                )  # [B, 3]
                set_loss(
                    "normal",
                    1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(),
                )
        elif guidance == "zero123":
            # zero123
            guidance_out = self.guidance(
                out["comp_rgb"],
                **batch,
                rgb_as_latents=False,
                guidance_eval=guidance_eval,
            )
            # claforte: TODO: rename the loss_terms keys
            set_loss("sds", guidance_out["loss_sds"])

        if self.C(self.cfg.loss.lambda_normal_smooth) > 0:
            if "comp_normal" not in out:
                raise ValueError(
                    "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output."
                )
            normal = out["comp_normal"]
            # squared differences between neighbours along dims 1 and 2
            set_loss(
                "normal_smooth",
                (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean()
                + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(),
            )

        if self.C(self.cfg.loss.lambda_3d_normal_smooth) > 0:
            if "normal" not in out:
                raise ValueError(
                    "Normal is required for normal smooth loss, no normal is found in the output."
                )
            if "normal_perturb" not in out:
                raise ValueError(
                    "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output."
                )
            normals = out["normal"]
            normals_perturb = out["normal_perturb"]
            set_loss("3d_normal_smooth", (normals - normals_perturb).abs().mean())

        if not self.cfg.refinement:
            if self.C(self.cfg.loss.lambda_orient) > 0:
                if "normal" not in out:
                    raise ValueError(
                        "Normal is required for orientation loss, no normal is found in the output."
                    )
                set_loss(
                    "orient",
                    (
                        out["weights"].detach()
                        * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
                    ).sum()
                    / (out["opacity"] > 0).sum(),
                )

            if guidance != "ref" and self.C(self.cfg.loss.lambda_sparsity) > 0:
                set_loss("sparsity", (out["opacity"] ** 2 + 0.01).sqrt().mean())

            if self.C(self.cfg.loss.lambda_opaque) > 0:
                opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
                # The clamped opacity is passed as both arguments on purpose
                # here (kept exactly as in the original code).
                set_loss(
                    "opaque", binary_cross_entropy(opacity_clamped, opacity_clamped)
                )
        else:
            if self.C(self.cfg.loss.lambda_normal_consistency) > 0:
                set_loss("normal_consistency", out["mesh"].normal_consistency())
            if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0:
                set_loss("laplacian_smoothness", out["mesh"].laplacian())

        loss = 0.0
        for name, value in loss_terms.items():
            self.log(f"train/{name}", value)
            if name.startswith(loss_prefix):
                # "loss_<guidance>_<term>" is weighted by cfg.loss.lambda_<term>
                loss_weighted = value * self.C(
                    self.cfg.loss[name.replace(loss_prefix, "lambda_")]
                )
                self.log(f"train/{name}_w", loss_weighted)
                loss += loss_weighted

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        self.log(f"train/loss_{guidance}", loss)

        if guidance_eval:
            self.guidance_evaluation_save(
                out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                guidance_out["eval"],
            )

        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        """Run the "zero123" and/or "ref" sub-steps and sum their losses.

        ``cfg.freq["ref_or_zero123"]`` (default ``"accumulate"``) selects the
        schedule:

        * ``"accumulate"``: both sub-steps run on every step.
        * ``"alternate"``: the "ref" sub-step runs alone during the first
          ``cfg.freq.ref_only_steps`` steps and on every step that is a
          multiple of ``cfg.freq.n_ref``; otherwise "zero123" runs alone.

        Returns:
            ``{"loss": total_loss}``.

        Raises:
            ValueError: if ``ref_or_zero123`` is neither of the values above.
        """
        mode = self.cfg.freq.get("ref_or_zero123", "accumulate")
        if mode == "accumulate":
            do_ref = True
            do_zero123 = True
        elif mode == "alternate":
            do_ref = (
                self.true_global_step < self.cfg.freq.ref_only_steps
                or self.true_global_step % self.cfg.freq.n_ref == 0
            )
            do_zero123 = not do_ref
        else:
            # Previously this fell through and crashed on an unbound local.
            raise ValueError(
                f"Unknown freq.ref_or_zero123 mode {mode!r}; "
                "expected 'accumulate' or 'alternate'."
            )

        total_loss = 0.0
        if do_zero123:
            out = self.training_substep(batch, batch_idx, guidance="zero123")
            total_loss += out["loss"]

        if do_ref:
            out = self.training_substep(batch, batch_idx, guidance="ref")
            total_loss += out["loss"]

        self.log("train/loss", total_loss, prog_bar=True)

        # sch = self.lr_schedulers()
        # sch.step()

        return {"loss": total_loss}

    @staticmethod
    def _visualization_grid(batch, out):
        """Build the image-grid entries shared by validation and test steps.

        Order: reference rgb (if the batch has one), rendered rgb, rendered
        normal (if present), depth (if present), opacity. Only the first
        element of each batched tensor is used.
        """
        grid = []
        if "rgb" in batch:
            grid.append(
                {
                    "type": "rgb",
                    "img": batch["rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                }
            )
        grid.append(
            {
                "type": "rgb",
                "img": out["comp_rgb"][0],
                "kwargs": {"data_format": "HWC"},
            }
        )
        if "comp_normal" in out:
            grid.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )
        if "depth" in out:
            grid.append(
                {
                    "type": "grayscale",
                    "img": out["depth"][0],
                    "kwargs": {},
                }
            )
        grid.append(
            {
                "type": "grayscale",
                "img": out["opacity"][0, :, :, 0],
                "kwargs": {"cmap": None, "data_range": (0, 1)},
            }
        )
        return grid

    def validation_step(self, batch, batch_idx):
        """Render the batch and save its image grid under
        ``it<step>-val/<index>.png``."""
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-val/{batch['index'][0]}.png",
            self._visualization_grid(batch, out),
            # claforte: TODO: don't hardcode the frame numbers to record... read them from cfg instead.
            name=f"validation_step_batchidx_{batch_idx}"
            if batch_idx in [0, 7, 15, 23, 29]
            else None,
            step=self.true_global_step,
        )

    def on_validation_epoch_end(self):
        """Assemble the validation frames into an mp4, then delete the
        per-frame directory."""
        filestem = f"it{self.true_global_step}-val"
        self.save_img_sequence(
            filestem,
            filestem,
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="validation_epoch_end",
            step=self.true_global_step,
        )
        shutil.rmtree(os.path.join(self.get_save_dir(), filestem))

    def test_step(self, batch, batch_idx):
        """Render the batch and save its image grid under
        ``it<step>-test/<index>.png``."""
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-test/{batch['index'][0]}.png",
            self._visualization_grid(batch, out),
            name="test_step",
            step=self.true_global_step,
        )

    def on_test_epoch_end(self):
        """Assemble the test frames into an mp4, then delete the per-frame
        directory."""
        filestem = f"it{self.true_global_step}-test"
        self.save_img_sequence(
            filestem,
            filestem,
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.true_global_step,
        )
        shutil.rmtree(os.path.join(self.get_save_dir(), filestem))