From 53764900610cd83541e7490f92c8d86040fff7da Mon Sep 17 00:00:00 2001
From: "che.ender"
Date: Sat, 8 Feb 2025 15:17:29 +0800
Subject: [PATCH 1/4] Add an API server that serves Janus-Pro-7B directly; the
 client can use fastapi_client.py as-is
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add an API server that serves Janus-Pro-7B directly; the client can use
fastapi_client.py as-is.
---
 demo/fastapi_pro.py | 186 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 186 insertions(+)
 create mode 100644 demo/fastapi_pro.py

diff --git a/demo/fastapi_pro.py b/demo/fastapi_pro.py
new file mode 100644
index 0000000..b327dc4
--- /dev/null
+++ b/demo/fastapi_pro.py
@@ -0,0 +1,186 @@
+from fastapi import FastAPI, File, Form, UploadFile, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from PIL import Image
+import numpy as np
+import io
+
+app = FastAPI()
+
+
+# Load model and processor
+model_path = "deepseek-ai/Janus-Pro-7B"
+
+config = AutoConfig.from_pretrained(model_path)
+language_config = config.language_config
+language_config._attn_implementation = 'eager'
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+                                              language_config=language_config,
+                                              trust_remote_code=True)
+if torch.cuda.is_available():
+    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+else:
+    vl_gpt = vl_gpt.to(torch.float16)
+
+vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+@torch.inference_mode()
+def multimodal_understanding(image_data, question, seed, top_p, temperature):
+    torch.cuda.empty_cache()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    conversation = [
+        {
+            "role": "<|User|>",
+            "content": f"<image_placeholder>\n{question}",
+            "images": [image_data],
+        },
+        {"role": "<|Assistant|>", "content": ""},
+    ]
+
+    # load images and prepare for inputs
+    pil_images = [Image.open(io.BytesIO(image_data))]
+    prepare_inputs = vl_chat_processor(
+        conversations=conversation, images=pil_images, force_batchify=True
+    ).to(vl_gpt.device)
+
+    # run image encoder to get the image embeddings
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    # run the model to get the response
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+        do_sample=False,
+        use_cache=True,
+    )
+
+    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+    return answer
+
+
+@app.post("/understand_image_and_question/")
+async def understand_image_and_question(
+    file: UploadFile = File(...),
+    question: str = Form(...),
+    seed: int = Form(42),
+    top_p: float = Form(0.95),
+    temperature: float = Form(0.1)
+):
+    image_data = await file.read()
+    response = multimodal_understanding(image_data, question, seed, top_p, temperature)
+    return JSONResponse({"response": response})
+
+
+def generate(input_ids,
+             width,
+             height,
+             temperature: float = 1,
+             parallel_size: int = 5,
+             cfg_weight: float = 5,
+             image_token_num_per_image: int = 576,
+             patch_size: int = 16):
+    torch.cuda.empty_cache()
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    for i in range(parallel_size * 2):
+        tokens[i, :] = input_ids
+        if i % 2 != 0:
+            tokens[i, 1:-1] = vl_chat_processor.pad_id
+    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+
+    pkv = None
+    for i in range(image_token_num_per_image):
+        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
+        pkv = outputs.past_key_values
+        hidden_states = outputs.last_hidden_state
+        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+        logit_cond = logits[0::2, :]
+        logit_uncond = logits[1::2, :]
+        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+        probs = torch.softmax(logits / temperature, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+        generated_tokens[:, i] = next_token.squeeze(dim=-1)
+        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+        inputs_embeds = img_embeds.unsqueeze(dim=1)
+    patches = vl_gpt.gen_vision_model.decode_code(
+        generated_tokens.to(dtype=torch.int),
+        shape=[parallel_size, 8, width // patch_size, height // patch_size]
+    )
+
+    return generated_tokens.to(dtype=torch.int), patches
+
+
+def unpack(dec, width, height, parallel_size=5):
+    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+    visual_img[:, :, :] = dec
+
+    return visual_img
+
+
+@torch.inference_mode()
+def generate_image(prompt, seed, guidance):
+    torch.cuda.empty_cache()
+    seed = seed if seed is not None else 12345
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    width = 384
+    height = 384
+    parallel_size = 5
+
+    with torch.no_grad():
+        messages = [{'role': '<|User|>', 'content': prompt}, {'role': '<|Assistant|>', 'content': ''}]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+            conversations=messages,
+            sft_format=vl_chat_processor.sft_format,
+            system_prompt=''
+        )
+        text = text + vl_chat_processor.image_start_tag
+        input_ids = torch.LongTensor(tokenizer.encode(text))
+        _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
+        images = unpack(patches, width // 16 * 16, height // 16 * 16)
+
+    return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+
+
+@app.post("/generate_images/")
+async def generate_images(
+    prompt: str = Form(...),
+    seed: int = Form(None),
+    guidance: float = Form(5.0),
+):
+    try:
+        images = generate_image(prompt, seed, guidance)
+        def image_stream():
+            for img in images:
+                buf = io.BytesIO()
+                img.save(buf, format='PNG')
+                buf.seek(0)
+                yield buf.read()
+
+        return StreamingResponse(image_stream(), media_type="multipart/related")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)

From c74d76954f487eec67c5c3fb8322de9ae062239c Mon Sep 17 00:00:00 2001
From: "furanger@sina.com"
Date: Tue, 25 Feb 2025 17:40:40 +0800
Subject: [PATCH 2/4] Fine-tune the Janus-Pro-7B model with ms-swift, add a
 proxy service interface, and integrate Swift with fastapi_client.py; the
 client can keep using fastapi_client.py unchanged
---
 demo/README.md        | 42 ++++++++++++++++++++++++++
 demo/fastapi_swift.py | 68 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 110 insertions(+)
 create mode 100644 demo/README.md
 create mode 100644 demo/fastapi_swift.py

diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 0000000..aabd635
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,42 @@
+
+## 0. Fine-tuning restrictions
+Fine-tuning supports training for image understanding, but not for image generation
+
+## 1. Install ms-swift
+Use ms-swift to fine-tune the Janus-Pro-7B model.
+First, install ms-swift:
+----------------------------------------------
+git clone https://github.com/modelscope/ms-swift.git
+cd ms-swift
+pip install -e .
+--------------------------------------------------------
+## 2. Datasets
+The dataset format is
+{"messages": [{"role": "user", "content": "Does the construction worker in this picture comply with the safety regulations for high-altitude operations?"}, {"role": "assistant", "content": "In the high-altitude work area, people entering the construction site must wear safety helmets, and high-altitude workers should wear safety belts. The other end of the safety belt must be hung higher than the human body, which is called high hanging and low use. The high-altitude workers in the picture did not wear safety belts, which does not meet the safety standards for high-altitude operations."}], "images": ["root/train/train_images/wpd-36.jpg"]}
+
+## 3. Fine-tuning
+LoRA fine-tuning
+swift sft --model_type deepseek_janus_pro --model <model_path> --dataset <dataset_path> --target_modules all-linear
+
+Full fine-tuning
+swift sft --model_type deepseek_janus_pro --model <model_path> --dataset <dataset_path> --train_type full
+
+## 4. swift model Service
+swift deploy --ckpt_dir <checkpoint_dir>
+
+## 5. swift model Proxy Service
+fastapi_swift.py
+
+## 6. Client API fastapi_client.py
+Submit questions and receive responses to the swift model Proxy Service using fastapi_client.py
+
+
+## Other1.
+The parameters (seed, top_p, temperature) are no longer useful, but in order to maintain interface reuse, they are retained
+
+## Other2.
+If the Swift model directory needs to change, the configuration files must be updated:
+adapter_config.json: modify 'base_model_name_or_path'
+args.json: modify 'model'
+Point both to the Janus-Pro-7B directory
diff --git a/demo/fastapi_swift.py b/demo/fastapi_swift.py
new file mode 100644
index 0000000..798bad8
--- /dev/null
+++ b/demo/fastapi_swift.py
@@ -0,0 +1,68 @@
+from fastapi import FastAPI, File, Form, UploadFile, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
+from PIL import Image
+import numpy as np
+import io
+import hashlib
+import traceback
+import json
+import requests
+
+
+app = FastAPI()
+understand_image_and_question_url = "http://localhost:8000/v1/chat/completions"
+
+
+@app.post("/understand_image_and_question/")
+async def understand_image_and_question(
+    file: UploadFile = File(...),
+    question: str = Form(...),
+    seed: int = Form(42),
+    top_p: float = Form(0.95),
+    temperature: float = Form(0.1)
+):
+    # images file max size 8 MB
+    maxfilesize = 8 * 1024 * 1024
+    image_data = await file.read(maxfilesize)
+    try:
+        # Upload file directory
+        imagedirectory = "./uploads/"
+        # Needs to match the version served by the Swift service
+        JanusVersion = "Janus-Pro-7B"
+        # JanusVersion = "Janus-Pro-1B"
+
+        file = Image.open(io.BytesIO(image_data))
+        hash_obj = hashlib.md5()
+        hash_obj.update(image_data)
+        file_hash = hash_obj.hexdigest()
+        filename = imagedirectory + file_hash + ".png"
+        file.save(filename, format='PNG')
+        file.close()
+
+        outjson = {"model": JanusVersion,
+                   "messages": [{"role": "user",
+                                 "content": "<image>" + question}],
+                   "images": [filename]}
+
+        outjson = json.dumps(outjson, ensure_ascii=False)
+        response = requests.post(understand_image_and_question_url, data=outjson, stream=False)
+        response_data = response.json()
+        return response_data
+
+    except Exception as e:
+        print("-----------------------------------------------")
+        error_type = type(e).__name__
+        error_msg = str(e)
+        print(error_type)
+        print(error_msg)
+        traceback.print_exc()
+        print("-----------------------------------------------")
+
+        return "images file bad"
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)

From 5892ea3872fb551ddfa072c87d55e1b46d37f1e0 Mon Sep 17 00:00:00 2001
From: "furanger@sina.com"
Date: Wed, 26 Feb 2025 12:48:19 +0800
Subject: [PATCH 3/4] Revert "Add an API server that serves Janus-Pro-7B
 directly; the client can use fastapi_client.py as-is"
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 53764900610cd83541e7490f92c8d86040fff7da.
---
 demo/fastapi_pro.py | 186 --------------------------------------------
 1 file changed, 186 deletions(-)
 delete mode 100644 demo/fastapi_pro.py

diff --git a/demo/fastapi_pro.py b/demo/fastapi_pro.py
deleted file mode 100644
index b327dc4..0000000
--- a/demo/fastapi_pro.py
+++ /dev/null
@@ -1,186 +0,0 @@
-from fastapi import FastAPI, File, Form, UploadFile, HTTPException
-from fastapi.responses import JSONResponse, StreamingResponse
-import torch
-from transformers import AutoConfig, AutoModelForCausalLM
-from janus.models import MultiModalityCausalLM, VLChatProcessor
-from PIL import Image
-import numpy as np
-import io
-
-app = FastAPI()
-
-
-# Load model and processor
-model_path = "deepseek-ai/Janus-Pro-7B"
-
-config = AutoConfig.from_pretrained(model_path)
-language_config = config.language_config
-language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                              language_config=language_config,
-                                              trust_remote_code=True)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)
-
-vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
-tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-
-@torch.inference_mode()
-def multimodal_understanding(image_data, question, seed, top_p, temperature):
-    torch.cuda.empty_cache()
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
-
-    conversation = [
-        {
-            "role": "<|User|>",
-            "content": f"<image_placeholder>\n{question}",
-            "images": [image_data],
-        },
-        {"role": "<|Assistant|>", "content": ""},
-    ]
-
-    # load images and prepare for inputs
-    pil_images = [Image.open(io.BytesIO(image_data))]
-    prepare_inputs = vl_chat_processor(
-        conversations=conversation, images=pil_images, force_batchify=True
-    ).to(vl_gpt.device)
-
-    # run image encoder to get the image embeddings
-    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-    # run the model to get the response
-    outputs = vl_gpt.language_model.generate(
-        inputs_embeds=inputs_embeds,
-        attention_mask=prepare_inputs.attention_mask,
-        pad_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=512,
-        do_sample=False,
-        use_cache=True,
-    )
-
-    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-    return answer
-
-
-@app.post("/understand_image_and_question/")
-async def understand_image_and_question(
-    file: UploadFile = File(...),
-    question: str = Form(...),
-    seed: int = Form(42),
-    top_p: float = Form(0.95),
-    temperature: float = Form(0.1)
-):
-    image_data = await file.read()
-    response = multimodal_understanding(image_data, question, seed, top_p, temperature)
-    return JSONResponse({"response": response})
-
-
-def generate(input_ids,
-             width,
-             height,
-             temperature: float = 1,
-             parallel_size: int = 5,
-             cfg_weight: float = 5,
-             image_token_num_per_image: int = 576,
-             patch_size: int = 16):
-    torch.cuda.empty_cache()
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-    for i in range(parallel_size * 2):
-        tokens[i, :] = input_ids
-        if i % 2 != 0:
-            tokens[i, 1:-1] = vl_chat_processor.pad_id
-    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
-
-    pkv = None
-    for i in range(image_token_num_per_image):
-        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
-        pkv = outputs.past_key_values
-        hidden_states = outputs.last_hidden_state
-        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-        logit_cond = logits[0::2, :]
-        logit_uncond = logits[1::2, :]
-        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-        probs = torch.softmax(logits / temperature, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1)
-        generated_tokens[:, i] = next_token.squeeze(dim=-1)
-        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-        inputs_embeds = img_embeds.unsqueeze(dim=1)
-    patches = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
-        shape=[parallel_size, 8, width // patch_size, height // patch_size]
-    )
-
-    return generated_tokens.to(dtype=torch.int), patches
-
-
-def unpack(dec, width, height, parallel_size=5):
-    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
-    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
-    visual_img[:, :, :] = dec
-
-    return visual_img
-
-
-@torch.inference_mode()
-def generate_image(prompt, seed, guidance):
-    torch.cuda.empty_cache()
-    seed = seed if seed is not None else 12345
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    np.random.seed(seed)
-    width = 384
-    height = 384
-    parallel_size = 5
-
-    with torch.no_grad():
-        messages = [{'role': '<|User|>', 'content': prompt}, {'role': '<|Assistant|>', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
-            conversations=messages,
-            sft_format=vl_chat_processor.sft_format,
-            system_prompt=''
-        )
-        text = text + vl_chat_processor.image_start_tag
-        input_ids = torch.LongTensor(tokenizer.encode(text))
-        _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
-        images = unpack(patches, width // 16 * 16, height // 16 * 16)
-
-    return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
-
-
-@app.post("/generate_images/")
-async def generate_images(
-    prompt: str = Form(...),
-    seed: int = Form(None),
-    guidance: float = Form(5.0),
-):
-    try:
-        images = generate_image(prompt, seed, guidance)
-        def image_stream():
-            for img in images:
-                buf = io.BytesIO()
-                img.save(buf, format='PNG')
-                buf.seek(0)
-                yield buf.read()
-
-        return StreamingResponse(image_stream(), media_type="multipart/related")
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
-
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)

From ecc599e43e636ca2993ced9d1f96fa46d819e8a7 Mon Sep 17 00:00:00 2001
From: "che.ender"
Date: Wed, 26 Feb 2025 14:50:48 +0800
Subject: [PATCH 4/4] Update README.md: add swift export

---
 demo/README.md | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/demo/README.md b/demo/README.md
index aabd635..248f3b7 100644
--- a/demo/README.md
+++ b/demo/README.md
@@ -21,21 +21,32 @@ swift sft --model_type deepseek_janus_pro --model --d
 Full fine-tuning
 swift sft --model_type deepseek_janus_pro --model <model_path> --dataset <dataset_path> --train_type full
 
-## 4. swift model Service
-swift deploy --ckpt_dir <checkpoint_dir>
-
-## 5. swift model Proxy Service
+## 4. swift model export
+Export can merge the previously separate adapter and base model into a single model
+swift export --ckpt_dir <checkpoint_dir>
+
+## 5. swift model Service
+swift deploy --ckpt_dir <checkpoint_dir>
+
+
+## 6. swift model Proxy Service
+Create an empty uploads directory
 fastapi_swift.py
 
-## 6. Client API fastapi_client.py
-Submit questions and receive responses to the swift model Proxy Service using fastapi_client.py
+## 7. Client API fastapi_client.py
+Submit questions and receive responses to the swift model Proxy Service
+using fastapi_client.py
 
 
-## Other1.
-The parameters (seed, top_p, temperature) are no longer useful, but in order to maintain interface reuse, they are retained
+## Other1.
+fastapi_client.py parameters (seed, top_p, temperature) are no longer useful,
+but in order to maintain interface reuse,
+they are retained
 
-## Other2.
+## Other2.
+If no export is performed, then:
 If the Swift model directory needs to change, the configuration files must be updated:
 adapter_config.json: modify 'base_model_name_or_path'
 args.json: modify 'model'
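
Note: fastapi_client.py is referenced throughout this series but is not included in any of the patches. The sketch below only illustrates how the /understand_image_and_question/ form endpoint defined above could be called; the module layout, proxy URL, helper name, and default parameter values are assumptions, not part of the patches.
----------------------------------------------
# Hypothetical client sketch (not part of this patch series).
# Assumes the proxy from fastapi_pro.py or fastapi_swift.py is listening on localhost:8000.
import requests

PROXY_URL = "http://localhost:8000/understand_image_and_question/"  # assumed host/port

def ask_about_image(image_path: str, question: str) -> dict:
    # Send one image plus a question as multipart form data and return the JSON reply.
    with open(image_path, "rb") as f:
        files = {"file": (image_path, f, "image/png")}
        # seed/top_p/temperature are kept only for interface compatibility (see README "Other1")
        data = {"question": question, "seed": 42, "top_p": 0.95, "temperature": 0.1}
        resp = requests.post(PROXY_URL, files=files, data=data, timeout=300)
    resp.raise_for_status()
    return resp.json()

if __name__ == "__main__":
    print(ask_about_image("uploads/example.png", "Does the worker wear a safety belt?"))
----------------------------------------------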