from attrdict import AttrDict
import torch
import torch.nn as nn
from typing import Union, Tuple


class MlpProjector(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            # Each encoder stream is first projected to half the embedding width,
            # so concatenating the two halves yields n_embed channels.
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = modules

    def forward(self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]):
        """
        Args:
            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]):
                If it is a tuple of torch.Tensor, it comes from the hybrid vision encoder
                and x = (high_res_x, low_res_x); otherwise it is the feature from a single
                vision encoder.

        Returns:
            x (torch.Tensor): [b, s, c]
        """

        if isinstance(x_or_tuple, tuple):
            # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu"
            high_x, low_x = x_or_tuple
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
        else:
            x = x_or_tuple

        return self.layers(x)


if __name__ == "__main__":
    cfg = AttrDict(
        input_dim=1024,
        n_embed=2048,
        depth=2,
        projector_type="low_high_hybrid_split_mlp_gelu",
    )
    inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))

    m = MlpProjector(cfg)
    out = m(inputs)
    print(out.shape)
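    # Illustrative sketch (an assumption added here, not part of the original demo):
    # the projector also accepts a plain tensor from a single vision encoder when
    # configured with the "mlp_gelu" projector type; cfg_single and the input shape
    # below are hypothetical example values.
    cfg_single = AttrDict(
        input_dim=1024,
        n_embed=2048,
        depth=2,
        projector_type="mlp_gelu",
    )
    single_features = torch.rand(4, 576, 1024)  # [batch, seq, input_dim]
    m_single = MlpProjector(cfg_single)
    print(m_single(single_features).shape)  # expected: torch.Size([4, 576, 2048])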