from typing import Tuple, Union, List, Dict, Optional, Literal

import torch
import torch.nn as nn
import torchvision.transforms
from einops import rearrange

from deepseek_vlm.models.siglip_vit import create_siglip_vit
from deepseek_vlm.models.sam import create_sam_vit


class CLIPVisionTower(nn.Module):
    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: Optional[list] = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        if self.model_name.startswith("siglip"):
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()

        else:  # huggingface
            from transformers import CLIPVisionModel

            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        if isinstance(image_forward_outs, torch.Tensor):
            # the output is already the features of self.select_layer
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # drop the cls_token if the output contains one
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "same":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        return image_features

    def forward(self, images):
        """
        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """

        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)

        return image_features


class HybridVisionTower(nn.Module):
    def __init__(
        self,
        high_res_cfg: Dict,
        low_res_cfg: Dict,
        freeze_high: bool = False,
        freeze_low: bool = False,
        concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
        **ignore_kwargs,
    ):
        super().__init__()

        self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
        self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
        self.low_res_size = low_res_cfg["image_size"]
        self.concat_type = concat_type

        self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
        self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))

        if freeze_high:
            for p_name, p in self.vision_tower_high.named_parameters():
                p.requires_grad = False
            self.vision_tower_high = self.vision_tower_high.eval()
        else:
            # train only the downsamples and neck of the high-res tower
            for p_name, p in self.vision_tower_high.named_parameters():
                if "downsamples" in p_name or "neck" in p_name:
                    p.requires_grad = True
                else:
                    p.requires_grad = False

        if freeze_low:
            for p in self.vision_tower_low.parameters():
                p.requires_grad = False
            self.vision_tower_low = self.vision_tower_low.eval()

        self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)

    def forward(self, images: torch.Tensor):
        """
        Args:
            images (torch.Tensor): [bs, 3, H, W]

        Returns:
            res (torch.Tensor): [bs, t, c]
        """

        # [bs, c, h, w]
        high_images = images

        # [bs, c, h_low, w_low]
        low_images = self.resize(images)

        # separately run the two vision towers
        # run the high_res vision tower
        high_res = self.vision_tower_high(high_images)
        # [bs, c, h, w] -> [bs, h*w, c]
        high_res = rearrange(high_res, "b c h w -> b (h w) c")
        # run the low_res vision tower
        low_res = self.vision_tower_low(low_images)

        if self.concat_type == "feature":
            images_features = torch.cat([high_res, low_res], dim=-1)
        elif self.concat_type == "sequence":
            images_features = torch.cat([high_res, low_res], dim=1)
        elif self.concat_type == "add":
            images_features = high_res + low_res
        elif self.concat_type == "tuple":
            images_features = (high_res, low_res)
        else:
            raise ValueError(
                "Currently only support `feature`, `sequence`, `add` and `tuple` concat types."
            )

        return images_features


if __name__ == "__main__":
    image_size = 1024
    x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda()

    high_res_cfg = dict(
        model_name="sam_b_downsample",
        select_feature="same",
        image_size=image_size,
        pixel_mean=(0.48145466, 0.4578275, 0.40821073),
        pixel_std=(0.26862954, 0.26130258, 0.27577711),
        select_layer=-1,
        ckpt_path="",
    )

    low_res_cfg = dict(
        model_name="siglip_large_patch16_384",
        select_feature="same",
        image_size=384,
        pixel_mean=(0.5, 0.5, 0.5),
        pixel_std=(0.5, 0.5, 0.5),
        select_layer=-1,
        ckpt_path="",
    )

    net = HybridVisionTower(
        high_res_cfg=high_res_cfg,
        low_res_cfg=low_res_cfg,
        freeze_high=True,
        freeze_low=True,
        concat_type="tuple",
    ).bfloat16().cuda()

    high_x, low_x = net(x)
    print(x.shape, high_x.shape, low_x.shape)