diff --git a/inference/kernel.py b/inference/kernel.py index ba18dca..f46d384 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -1,9 +1,21 @@ +import sys +import platform +if platform.system() == "Windows": + sys.exit("Triton is not supported on Windows. Please use Linux or a Linux-based Docker container.") + from typing import Tuple -import torch -import triton -import triton.language as tl -from triton import Config +try: + import torch +except ImportError: + sys.exit("PyTorch is required for this project. Please install torch >=2.1.0.") + +try: + import triton + import triton.language as tl + from triton import Config +except ImportError: + sys.exit("Triton is required for this project. Please install it on a supported Linux system with CUDA.") @triton.jit