mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-02-23 14:18:57 -05:00
23 lines
709 B
Python
23 lines
709 B
Python
|
import torch
|
||
|
from fix_moe_symbolic_shapes import RobustMoE
|
||
|
|
||
|
def test_moe():
|
||
|
# Test with both default behavior and compiled version
|
||
|
model = RobustMoE(num_experts=4, d_model=256)
|
||
|
x = torch.randn(32, 256) # batch_size=32, d_model=256
|
||
|
|
||
|
# Test 1: Regular forward pass
|
||
|
print("Testing regular forward pass...")
|
||
|
output = model(x)
|
||
|
print(f"Output shape: {output.shape}")
|
||
|
|
||
|
# Test 2: Compiled version
|
||
|
print("\nTesting compiled version...")
|
||
|
compiled_model = torch.compile(model)
|
||
|
compiled_output = compiled_model(x)
|
||
|
print(f"Compiled output shape: {compiled_output.shape}")
|
||
|
|
||
|
print("\nAll tests passed successfully!")
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
test_moe()
|