From a0a75d0692d3d524823f0130de505b3159e75e4b Mon Sep 17 00:00:00 2001
From: nikola
Date: Wed, 29 Jan 2025 16:50:22 +0000
Subject: [PATCH] Added optional GPU Memory Logging

---
 inference/generate.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/inference/generate.py b/inference/generate.py
index fbf3ab8..d7f16c6 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -10,6 +10,17 @@ from safetensors.torch import load_model
 
 from model import Transformer, ModelArgs
 
+def log_gpu_memory(prefix: str = ""):
+    """
+    Logs the current allocated and reserved GPU memory in MB.
+    Prints only on rank=0 due to the project's existing print override.
+
+    Args:
+        prefix (str, optional): A label for the log, e.g. 'After loading:'.
+    """
+    allocated = torch.cuda.memory_allocated() / (1024 ** 2)
+    reserved = torch.cuda.memory_reserved() / (1024 ** 2)
+    print(f"{prefix} GPU memory allocated: {allocated:.2f} MB | reserved: {reserved:.2f} MB")
 
 def sample(logits, temperature: float = 1.0):
     """
@@ -117,6 +128,9 @@ def main(
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
     tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    if args.log_gpu_memory:
+        log_gpu_memory("After loading model:")
+
 
     if interactive:
         messages = []
@@ -138,7 +152,11 @@
                 continue
             messages.append({"role": "user", "content": prompt})
             prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+            if args.log_gpu_memory:
+                log_gpu_memory("Before generation (interactive):")
             completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
+            if args.log_gpu_memory:
+                log_gpu_memory("After generation (interactive):")
             completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
             print(completion)
             messages.append({"role": "assistant", "content": completion})
@@ -147,7 +165,11 @@
             prompts = [line.strip() for line in f.readlines()]
         assert len(prompts) <= args.max_batch_size
         prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
+        if args.log_gpu_memory:
+            log_gpu_memory("Before batch generation:")
         completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
+        if args.log_gpu_memory:
+            log_gpu_memory("After batch generation:")
         completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
@@ -180,6 +202,8 @@ if __name__ == "__main__":
     parser.add_argument("--interactive", action="store_true")
     parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--log-gpu-memory", action="store_true",
+                        help="Log GPU memory usage at key points (model load, before/after generation).")
     args = parser.parse_args()
     assert args.input_file or args.interactive
     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
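
For reference, the helper introduced above can be exercised on its own; the following is a minimal sketch, assuming a CUDA-capable machine. The file name, the dummy tensor allocation used as a stand-in workload, and the __main__ driver are illustrative and not part of the patch; only the two torch.cuda measurement calls mirror what the patch adds.

    # sketch_gpu_memory_logging.py -- illustrative driver, not part of the patch
    import torch

    def log_gpu_memory(prefix: str = ""):
        # Same measurement as the patched helper: allocated vs. reserved CUDA memory, in MB.
        allocated = torch.cuda.memory_allocated() / (1024 ** 2)
        reserved = torch.cuda.memory_reserved() / (1024 ** 2)
        print(f"{prefix} GPU memory allocated: {allocated:.2f} MB | reserved: {reserved:.2f} MB")

    if __name__ == "__main__":
        if not torch.cuda.is_available():
            raise SystemExit("This sketch requires a CUDA device.")
        log_gpu_memory("Before allocation:")
        # A 1024x1024 bf16 tensor (~2 MB) stands in for model weights.
        x = torch.empty(1024, 1024, dtype=torch.bfloat16, device="cuda")
        log_gpu_memory("After allocation:")

With the patch applied, the same logging is switched on in inference/generate.py by passing the new --log-gpu-memory flag alongside the script's existing options.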