diff --git a/demo/fastapi_app.py b/demo/fastapi_app.py
index ffe11bb..9368466 100644
--- a/demo/fastapi_app.py
+++ b/demo/fastapi_app.py
@@ -12,9 +12,9 @@ app = FastAPI()
 # Device and dtype configuration
 def get_device_and_dtype():
     if torch.cuda.is_available():
-        return 'cuda', torch.bfloat16
+        return 'cuda', torch.float32
     elif torch.backends.mps.is_available():
-        return 'mps', torch.float16
+        return 'mps', torch.float32
     return 'cpu', torch.float32
 
 device, dtype = get_device_and_dtype()
@@ -35,11 +35,17 @@ tokenizer = vl_chat_processor.tokenizer
 
 @torch.inference_mode()
 def multimodal_understanding(image_data, question, seed, top_p, temperature):
-    torch.cuda.empty_cache() if device == 'cuda' else None
+    # Clear CUDA cache if using CUDA
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+
+    # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
     if device == 'cuda':
         torch.cuda.manual_seed(seed)
+    elif device == 'mps':
+        torch.mps.manual_seed(seed)
 
     conversation = [
         {
@@ -94,36 +100,39 @@ def generate(input_ids,
              cfg_weight: float = 5,
              image_token_num_per_image: int = 576,
              patch_size: int = 16):
-    torch.cuda.empty_cache() if device == 'cuda' else None
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
-    for i in range(parallel_size * 2):
-        tokens[i, :] = input_ids
-        if i % 2 != 0:
-            tokens[i, 1:-1] = vl_chat_processor.pad_id
-    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
+    try:
+        torch.cuda.empty_cache() if device == 'cuda' else None
+        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
+        for i in range(parallel_size * 2):
+            tokens[i, :] = input_ids
+            if i % 2 != 0:
+                tokens[i, 1:-1] = vl_chat_processor.pad_id
+        inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
-    pkv = None
-    for i in range(image_token_num_per_image):
-        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
-        pkv = outputs.past_key_values
-        hidden_states = outputs.last_hidden_state
-        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-        logit_cond = logits[0::2, :]
-        logit_uncond = logits[1::2, :]
-        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-        probs = torch.softmax(logits / temperature, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1)
-        generated_tokens[:, i] = next_token.squeeze(dim=-1)
-        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-        inputs_embeds = img_embeds.unsqueeze(dim=1)
-    patches = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
-        shape=[parallel_size, 8, width // patch_size, height // patch_size]
-    )
+        pkv = None
+        for i in range(image_token_num_per_image):
+            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
+            pkv = outputs.past_key_values
+            hidden_states = outputs.last_hidden_state
+            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+        patches = vl_gpt.gen_vision_model.decode_code(
+            generated_tokens.to(dtype=torch.int),
+            shape=[parallel_size, 8, width // patch_size, height // patch_size]
+        )
 
-    return generated_tokens.to(dtype=torch.int), patches
+        return generated_tokens.to(dtype=torch.int), patches
+    except Exception as e:
+        raise Exception(f"Error in generate function: {str(e)}")
 
 
 def unpack(dec, width, height, parallel_size=5):
@@ -138,29 +147,37 @@ def unpack(dec, width, height, parallel_size=5):
 
 @torch.inference_mode()
 def generate_image(prompt, seed, guidance):
-    torch.cuda.empty_cache() if device == 'cuda' else None
-    seed = seed if seed is not None else 12345
-    torch.manual_seed(seed)
-    if device == 'cuda':
-        torch.cuda.manual_seed(seed)
-    np.random.seed(seed)
-    width = 384
-    height = 384
-    parallel_size = 5
-
-    with torch.no_grad():
-        messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
-            conversations=messages,
-            sft_format=vl_chat_processor.sft_format,
-            system_prompt=''
-        )
-        text = text + vl_chat_processor.image_start_tag
-        input_ids = torch.LongTensor(tokenizer.encode(text))
-        _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
-        images = unpack(patches, width // 16 * 16, height // 16 * 16)
+    try:
+        # Clear CUDA cache if using CUDA
+        if device == 'cuda':
+            torch.cuda.empty_cache()
+        # Set the seed for reproducible results
+        if seed is not None:
+            torch.manual_seed(seed)
+            if device == 'cuda':
+                torch.cuda.manual_seed(seed)
+            elif device == 'mps':
+                torch.mps.manual_seed(seed)
+            np.random.seed(seed)
+        width = 384
+        height = 384
+        parallel_size = 5
+
+        with torch.no_grad():
+            messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
+            text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+                conversations=messages,
+                sft_format=vl_chat_processor.sft_format,
+                system_prompt=''
+            )
+            text = text + vl_chat_processor.image_start_tag
+            input_ids = torch.LongTensor(tokenizer.encode(text))
+            _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
+            images = unpack(patches, width // 16 * 16, height // 16 * 16)
 
-    return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+        return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+    except Exception as e:
+        raise Exception(f"Error in generate_image function: {str(e)}")
 
 
 @app.post("/generate_images/")