Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-02-23 14:18:57 -05:00
Changes:
- init_distributed function: extracted the distributed setup logic from main into a separate function.
- sample function: switched from the exponentiation-based sampling trick to torch.multinomial.
- Argument validation: replaced the bare assert with a more user-friendly check in main that at least one of --input-file or --interactive is provided.
- Interactive code refactoring: the user-interaction logic is unchanged, but init_distributed is now called once at the beginning of main.
parent: b5d872ead0
commit: 89882a94f6
@@ -24,7 +24,27 @@ def sample(logits, temperature: float = 1.0):
     """
     logits = logits / max(temperature, 1e-5)
     probs = torch.softmax(logits, dim=-1)
-    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+    return torch.multinomial(probs, 1)  # Using a probability distribution
+
+
+def init_distributed(rank: int, world_size: int, local_rank: int):
+    """
+    Initialize the distributed process group and set device configurations.
+
+    Args:
+        rank (int): The rank of the current process.
+        world_size (int): Total number of processes.
+        local_rank (int): The local rank for multi-GPU configurations.
+    """
+    if world_size > 1:
+        dist.init_process_group("nccl")
+    global print
+    if rank != 0:
+        print = lambda *_, **__: None
+    torch.cuda.set_device(local_rank)
+    torch.set_default_dtype(torch.bfloat16)
+    torch.set_num_threads(8)
+    torch.manual_seed(965)
 
 
 @torch.inference_mode()
@@ -100,20 +120,17 @@ def main(
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
-    if world_size > 1:
-        dist.init_process_group("nccl")
-    global print
-    if rank != 0:
-        print = lambda *_, **__: None
-    torch.cuda.set_device(local_rank)
-    torch.set_default_dtype(torch.bfloat16)
-    torch.set_num_threads(8)
-    torch.manual_seed(965)
+    # Initialize distributed configuration
+    init_distributed(rank, world_size, local_rank)
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
 
     with torch.device("cuda"):
         model = Transformer(args)
 
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
     tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
@@ -181,5 +198,10 @@ if __name__ == "__main__":
     parser.add_argument("--max-new-tokens", type=int, default=200)
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
-    assert args.input_file or args.interactive
+
+    # Validate input
+    if not (args.input_file or args.interactive):
+        print("Error: You must specify --input-file or enable --interactive.")
+        exit(1)
+
     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)