Added optional GPU Memory Logging

nikola 2025-01-29 16:50:22 +00:00 committed by Krish Gera
parent b5d872ead0
commit a0a75d0692


@@ -10,6 +10,17 @@ from safetensors.torch import load_model
from model import Transformer, ModelArgs
def log_gpu_memory(prefix: str = ""):
    """
    Logs the current allocated and reserved GPU memory in MB.
    Prints only on rank 0 due to the project's existing print override.

    Args:
        prefix (str, optional): A label for the log, e.g. 'After loading:'.
    """
    allocated = torch.cuda.memory_allocated() / (1024 ** 2)
    reserved = torch.cuda.memory_reserved() / (1024 ** 2)
    print(f"{prefix} GPU memory allocated: {allocated:.2f} MB | reserved: {reserved:.2f} MB")
def sample(logits, temperature: float = 1.0):
    """
@@ -117,6 +128,9 @@ def main(
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
    if args.log_gpu_memory:
        log_gpu_memory("After loading model:")
    if interactive:
        messages = []
@@ -138,7 +152,11 @@ def main(
                continue
            messages.append({"role": "user", "content": prompt})
            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
            if args.log_gpu_memory:
                log_gpu_memory("Before generation (interactive):")
            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
            if args.log_gpu_memory:
                log_gpu_memory("After generation (interactive):")
            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
@@ -147,7 +165,11 @@ def main(
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= args.max_batch_size
        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
        if args.log_gpu_memory:
            log_gpu_memory("Before batch generation:")
        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        if args.log_gpu_memory:
            log_gpu_memory("After batch generation:")
        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
@@ -180,6 +202,8 @@ if __name__ == "__main__":
parser.add_argument("--interactive", action="store_true") parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200) parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--log-gpu-memory", action="store_true",
help="Log GPU memory usage at key points (model load, before/after generation).")
args = parser.parse_args() args = parser.parse_args()
assert args.input_file or args.interactive assert args.input_file or args.interactive
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
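The paired before/after logs are meant to be read as deltas. A minimal sketch (not part of the commit) of doing that subtraction in code, where run_generation is a hypothetical stand-in for the script's generate(...) call:

import torch

def memory_delta_mb(fn, *args, **kwargs):
    """Run fn and return (result, allocated-memory delta in MB)."""
    before = torch.cuda.memory_allocated()
    result = fn(*args, **kwargs)
    after = torch.cuda.memory_allocated()
    return result, (after - before) / (1024 ** 2)

# Hypothetical usage:
# completion_tokens, delta = memory_delta_mb(run_generation)
# print(f"Generation added {delta:.2f} MB of allocated GPU memory")

This measures live allocations only; transient peaks inside fn would need torch.cuda.max_memory_allocated() together with torch.cuda.reset_peak_memory_stats(), which the commit does not use.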