diff --git a/janus/models/siglip_vit.py b/janus/models/siglip_vit.py index a93707b..ba426d6 100644 --- a/janus/models/siglip_vit.py +++ b/janus/models/siglip_vit.py @@ -92,7 +92,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): # type: (torch.Tensor, float, float, float, float) -> torch.Tensor r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first - convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype. + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype. Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within