diff --git a/deployment/vllm/ReadMe.md b/deployment/vllm/ReadMe.md
new file mode 100644
index 0000000..74effd5
--- /dev/null
+++ b/deployment/vllm/ReadMe.md
@@ -0,0 +1,50 @@
+# 🚀 Dolphin vLLM Demo
+
+## ✅ Introduction
+The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In its Hugging Face Transformers [config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json),
+the `architectures` field is specified as `"VisionEncoderDecoderModel"`, which vLLM does not support natively.
+To enable vLLM deployment of Dolphin, we implemented two vLLM plugins: [vllm-dolphin](https://github.com/hanyd2010/vllm-dolphin) [![PyPI version](https://img.shields.io/pypi/v/vllm-dolphin)](https://pypi.org/project/vllm-dolphin/) and [vllm-mbart](https://github.com/hanyd2010/vllm-mbart) [![PyPI version](https://img.shields.io/pypi/v/vllm-mbart)](https://pypi.org/project/vllm-mbart/).
+We also provide Dolphin vLLM demos for both offline inference and online deployment.
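+
+The snippet below sketches what the plugins enable: importing `vllm_dolphin` registers the model implementation with vLLM, and `hf_overrides` remaps the architecture name so vLLM loads Dolphin as an encoder-decoder model. See `deployment/vllm/demo_vllm.py` for the complete, runnable demo.
+
+```
+import vllm_dolphin  # vLLM plugin; imported for its registration side effect
+from vllm import LLM
+
+llm = LLM(
+    model="ByteDance/Dolphin",
+    enforce_eager=True,
+    trust_remote_code=True,
+    hf_overrides={"architectures": ["DolphinForConditionalGeneration"]},
+)
+```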
+""" + +import argparse +import base64 +import json +from argparse import Namespace +from collections.abc import Iterable + +import requests + + +def clear_line(n: int = 1) -> None: + LINE_UP = "\033[1A" + LINE_CLEAR = "\x1b[2K" + for _ in range(n): + print(LINE_UP, end=LINE_CLEAR, flush=True) + + +def encode_image_base64(image_path: str) -> str: + """Encode local image to base64 format.""" + + with open(image_path, "rb") as f: + image_data = f.read() + result = base64.b64encode(image_data).decode("utf-8") + + return result + + +def post_http_request( + prompt: str, image_path: str, api_url: str, stream: bool = False +) -> requests.Response: + headers = {"User-Agent": "Test Client"} + pload = { + "encoder_prompt": "", + "decoder_prompt": prompt, + "image_base64": encode_image_base64(image_path), + "temperature": 0.0, + "max_tokens": 2048, + "stream": stream, + } + response = requests.post(api_url, headers=headers, json=pload, stream=stream) + return response + + +def get_streaming_response(response: requests.Response) -> Iterable[list[str]]: + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\n" + ): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"] + yield output + + +def get_response(response: requests.Response) -> list[str]: + data = json.loads(response.content) + output = data["text"] + return output + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.") + parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg") + parser.add_argument("--stream", action="store_true") + return parser.parse_args() + + +def main(args: Namespace): + prompt = args.prompt + image_path = args.image_path + api_url = f"http://{args.host}:{args.port}/generate" + stream = args.stream + + print(f"Prompt: {prompt!r}\n", flush=True) + response = post_http_request(prompt, image_path, api_url, stream) + + if stream: + num_printed_lines = 0 + for h in get_streaming_response(response): + clear_line(num_printed_lines) + num_printed_lines = 0 + for i, line in enumerate(h): + num_printed_lines += 1 + print(f"Response {i}: {line!r}", flush=True) + else: + output = get_response(response) + print(f"Response: {output[0]!r}", flush=True) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/deployment/vllm/api_server.py b/deployment/vllm/api_server.py new file mode 100644 index 0000000..8a297f1 --- /dev/null +++ b/deployment/vllm/api_server.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. 
+""" + +import asyncio +import base64 +import json +import io +import ssl +from argparse import Namespace +from collections.abc import AsyncGenerator +from PIL import Image +from typing import Any, Optional + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("api_server") + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +async def decode_image(image_base64: str) -> Image.Image: + image_data = base64.b64decode(image_base64) + image = Image.open(io.BytesIO(image_data)) + return image + + +async def custom_process_prompt(encoder_prompt: str, decoder_prompt: str, + image_base64: str) -> ExplicitEncoderDecoderPrompt: + assert engine is not None + tokenizer = engine.engine.get_tokenizer_group().tokenizer + image = await decode_image(image_base64) + + if encoder_prompt == "": + encoder_prompt = "0" * 783 # For Dolphin + + if decoder_prompt == "": + decoder_prompt_ids = tokenizer.bos_token_id + else: + decoder_prompt = f"{decoder_prompt.strip()} " + decoder_prompt_ids = tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"] + + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=TokensPrompt(prompt_token_ids=decoder_prompt_ids), + ) + + return enc_dec_prompt + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: + encoder_prompt = request_dict.pop("encoder_prompt", "") + decoder_prompt = request_dict.pop("decoder_prompt", "") + image_base64 = request_dict.pop("image_base64", "") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + + enc_dec_prompt = await custom_process_prompt(encoder_prompt, decoder_prompt, image_base64) + results_generator = engine.generate(enc_dec_prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\n").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # 
diff --git a/deployment/vllm/api_server.py b/deployment/vllm/api_server.py
new file mode 100644
index 0000000..8a297f1
--- /dev/null
+++ b/deployment/vllm/api_server.py
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+NOTE: This API server is adapted from vLLM's `vllm.entrypoints.api_server` and is
+used only for demonstrating AsyncLLMEngine with Dolphin's encoder-decoder prompts
+and for simple performance benchmarks. It is not intended for production use.
+For production use, we recommend vLLM's OpenAI-compatible server.
+"""
+
+import asyncio
+import base64
+import io
+import json
+import ssl
+from argparse import Namespace
+from collections.abc import AsyncGenerator
+from typing import Any, Optional
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+from PIL import Image
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.utils import with_cancellation
+from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
+from vllm.logger import init_logger
+from vllm.sampling_params import SamplingParams
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
+from vllm.version import __version__ as VLLM_VERSION
+
+logger = init_logger("api_server")
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
+app = FastAPI()
+engine = None
+
+
+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+    """Generate a completion for the request.
+
+    The request should be a JSON object with the following fields:
+    - encoder_prompt: optional encoder prompt (left empty for Dolphin).
+    - decoder_prompt: the prompt to use for the generation.
+    - image_base64: the base64-encoded input image.
+    - stream: whether to stream the results or not.
+    - other fields: the sampling parameters (see `SamplingParams` for details).
+    """
+    request_dict = await request.json()
+    return await _generate(request_dict, raw_request=request)
+
+
+async def decode_image(image_base64: str) -> Image.Image:
+    image_data = base64.b64decode(image_base64)
+    image = Image.open(io.BytesIO(image_data))
+    return image
+
+
+async def custom_process_prompt(encoder_prompt: str, decoder_prompt: str,
+                                image_base64: str) -> ExplicitEncoderDecoderPrompt:
+    assert engine is not None
+    tokenizer = engine.engine.get_tokenizer_group().tokenizer
+    image = await decode_image(image_base64)
+
+    if encoder_prompt == "":
+        # Dolphin needs no real encoder text prompt; this placeholder makes
+        # vLLM allocate enough KV cache (see demo_vllm.py).
+        encoder_prompt = "0" * 783
+
+    if decoder_prompt == "":
+        decoder_prompt_ids = [tokenizer.bos_token_id]
+    else:
+        decoder_prompt = f"<s>{decoder_prompt.strip()} <Answer/>"
+        decoder_prompt_ids = tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]
+
+    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
+        encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
+        decoder_prompt=TokensPrompt(prompt_token_ids=decoder_prompt_ids),
+    )
+
+    return enc_dec_prompt
+
+
+@with_cancellation
+async def _generate(request_dict: dict, raw_request: Request) -> Response:
+    encoder_prompt = request_dict.pop("encoder_prompt", "")
+    decoder_prompt = request_dict.pop("decoder_prompt", "")
+    image_base64 = request_dict.pop("image_base64", "")
+    stream = request_dict.pop("stream", False)
+    sampling_params = SamplingParams(**request_dict)
+    request_id = random_uuid()
+
+    assert engine is not None
+
+    enc_dec_prompt = await custom_process_prompt(encoder_prompt, decoder_prompt, image_base64)
+    results_generator = engine.generate(enc_dec_prompt, sampling_params, request_id)
+
+    # Streaming case
+    async def stream_results() -> AsyncGenerator[bytes, None]:
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            assert prompt is not None
+            text_outputs = [
+                prompt + output.text for output in request_output.outputs
+            ]
+            ret = {"text": text_outputs}
+            yield (json.dumps(ret) + "\n").encode("utf-8")
+
+    if stream:
+        return StreamingResponse(stream_results())
+
+    # Non-streaming case
+    final_output = None
+    try:
+        async for request_output in results_generator:
+            final_output = request_output
+    except asyncio.CancelledError:
+        return Response(status_code=499)
+
+    assert final_output is not None
+    prompt = final_output.prompt
+    assert prompt is not None
+    text_outputs = [prompt + output.text.strip() for output in final_output.outputs]
+    ret = {"text": text_outputs}
+    return JSONResponse(ret)
+
+
+def build_app(args: Namespace) -> FastAPI:
+    global app
+
+    app.root_path = args.root_path
+    return app
+
+
+async def init_app(
+    args: Namespace,
+    llm_engine: Optional[AsyncLLMEngine] = None,
+) -> FastAPI:
+    app = build_app(args)
+
+    global engine
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.API_SERVER))
+    app.state.engine_client = engine
+    return app
+
+
+async def run_server(args: Namespace,
+                     llm_engine: Optional[AsyncLLMEngine] = None,
+                     **uvicorn_kwargs: Any) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    set_ulimit()
+
+    app = await init_app(args, llm_engine)
+    assert engine is not None
+
+    shutdown_task = await serve_http(
+        app,
+        sock=None,
+        enable_ssl_refresh=args.enable_ssl_refresh,
+        host=args.host,
+        port=args.port,
+        log_level=args.log_level,
+        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+        ssl_keyfile=args.ssl_keyfile,
+        ssl_certfile=args.ssl_certfile,
+        ssl_ca_certs=args.ssl_ca_certs,
+        ssl_cert_reqs=args.ssl_cert_reqs,
+        **uvicorn_kwargs,
+    )
+
+    await shutdown_task
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser.add_argument("--host", type=str, default=None)
+    parser.add_argument("--port", type=parser.check_port, default=8000)
+    parser.add_argument("--ssl-keyfile", type=str, default=None)
+    parser.add_argument("--ssl-certfile", type=str, default=None)
+    parser.add_argument("--ssl-ca-certs",
+                        type=str,
+                        default=None,
+                        help="The CA certificates file")
+    parser.add_argument(
+        "--enable-ssl-refresh",
+        action="store_true",
+        default=False,
+        help="Refresh SSL Context when SSL certificate files change")
+    parser.add_argument(
+        "--ssl-cert-reqs",
+        type=int,
+        default=int(ssl.CERT_NONE),
+        help="Whether client certificate is required (see stdlib ssl module's)"
+    )
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument("--log-level", type=str, default="debug")
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    asyncio.run(run_server(args))
diff --git a/deployment/vllm/demo_vllm.py b/deployment/vllm/demo_vllm.py
new file mode 100644
index 0000000..32be1fb
--- /dev/null
+++ b/deployment/vllm/demo_vllm.py
@@ -0,0 +1,91 @@
+"""
+Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+SPDX-License-Identifier: MIT
+"""
+
+import vllm_dolphin  # noqa: F401 - vLLM plugin that registers the Dolphin model
+import argparse
+import os
+from argparse import Namespace
+
+import torch
+from PIL import Image
+from vllm import LLM, SamplingParams
+from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def offline_inference(model_id: str, prompt: str, image_path: str, max_tokens: int = 2048):
+    dtype = "float16" if torch.cuda.is_available() else "float32"
+    # Create an encoder/decoder model instance.
+    llm = LLM(
+        model=model_id,
+        dtype=dtype,
+        enforce_eager=True,
+        trust_remote_code=True,
+        max_num_seqs=8,
+        hf_overrides={"architectures": ["DolphinForConditionalGeneration"]},
+    )
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        logprobs=0,
+        max_tokens=max_tokens,
+        prompt_logprobs=None,
+        skip_special_tokens=False,
+    )
+
+    # Process the prompt.
+    tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer
+
+    # The Dolphin model does not require a real encoder prompt. To ensure vLLM
+    # correctly allocates the KV cache, a placeholder encoder prompt is supplied.
+    encoder_prompt = "0" * 783
+    decoder_prompt = f"<s>{prompt.strip()} <Answer/>"
+
+    image = Image.open(image_path)
+    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
+        encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
+        decoder_prompt=TokensPrompt(
+            prompt_token_ids=tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]
+        ),
+    )
+
+    # Generate output tokens from the prompt. The output is a list of
+    # RequestOutput objects that contain the prompt, generated text, and other information.
+    outputs = llm.generate(enc_dec_prompt, sampling_params)
+
+    print("------" * 8)
+    # Print the outputs.
+    for output in outputs:
+        decoder_prompt_tokens = tokenizer.batch_decode(output.prompt_token_ids, skip_special_tokens=True)
+        decoder_prompt = "".join(decoder_prompt_tokens)
+        generated_text = output.outputs[0].text.strip()
+        print(f"Decoder prompt: {decoder_prompt!r}, "
+              f"\nGenerated text: {generated_text!r}")
+
+    print("------" * 8)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="ByteDance/Dolphin")
+    parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
+    parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
+    return parser.parse_args()
+
+
+def main(args: Namespace):
+    model = args.model
+    prompt = args.prompt
+    image_path = args.image_path
+
+    offline_inference(model, prompt, image_path)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)