Merge pull request #8 from eltociear/patch-1

chore: update siglip_vit.py
This commit is contained in:
Wu Chengyue 2024-10-19 17:44:59 +08:00 committed by GitHub
commit 0214867df2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -92,7 +92,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype.
convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within