Support inference by vLLM

yingdong.han 2025-06-27 15:01:22 +08:00
parent eb1737ae95
commit 6177c2686b
4 changed files with 460 additions and 0 deletions

deployment/vllm/ReadMe.md (new file, 50 lines)
<h1 align="center">
🚀 Dolphin vLLM Demo
</h1>
## ✅ Introduction
The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In its HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json), the `architectures` field is specified as `"VisionEncoderDecoderModel"`, an architecture that vLLM does not natively support.
To enable vLLM deployment of the Dolphin model, we implemented two vLLM plugins: [vllm-dolphin](https://github.com/hanyd2010/vllm-dolphin)[![PyPI version](https://img.shields.io/pypi/v/vllm-dolphin)](https://pypi.org/project/vllm-dolphin/) and [vllm-mbart](https://github.com/hanyd2010/vllm-mbart)[![PyPI version](https://img.shields.io/pypi/v/vllm-mbart)](https://pypi.org/project/vllm-mbart/).
We also provide Dolphin vLLM demos for both offline inference and online deployment.
## 🛠️ Installation
```
# Install vllm
pip install "vllm>=0.9.0"
# Install vllm-dolphin
pip install vllm-dolphin==0.1
```
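To verify the installation, you can import the plugin the same way the offline demo does. This is only a quick sanity check (a minimal sketch, not one of the demo scripts):
```
import vllm_dolphin  # the plugin the demos import to make vLLM aware of Dolphin
from vllm.version import __version__ as VLLM_VERSION

print("vLLM version:", VLLM_VERSION)
print("vllm-dolphin plugin imported successfully")
```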
## ⚡ Offline Inference
```
# predict the reading order of document elements
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."
# recognize text/latex
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
# recognize table
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
```
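These commands wrap vLLM's offline Python API; the full logic lives in `deployment/vllm/demo_vllm.py`. Below is a condensed sketch of what that script does (the model ID, prompt template, padded encoder prompt, and `hf_overrides` are taken from the script; some engine and sampling options are trimmed):
```
import vllm_dolphin  # noqa: F401  # the plugin must be imported so vLLM knows the Dolphin architecture
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt

llm = LLM(
    model="ByteDance/Dolphin",
    trust_remote_code=True,
    enforce_eager=True,
    hf_overrides={"architectures": ["DolphinForConditionalGeneration"]},
)
tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer

# Dolphin does not use an encoder text prompt, but vLLM needs one to allocate KV cache,
# so the demo pads it with a fixed-length dummy string.
prompt = "Parse the table in the image."
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(
        prompt="0" * 783,
        multi_modal_data={"image": Image.open("./demo/element_imgs/table_1.jpeg")},
    ),
    decoder_prompt=TokensPrompt(
        prompt_token_ids=tokenizer(f"<s>{prompt} <Answer/>", add_special_tokens=False)["input_ids"]
    ),
)
outputs = llm.generate(enc_dec_prompt, SamplingParams(temperature=0.0, max_tokens=2048))
print(outputs[0].outputs[0].text.strip())
```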
## ⚡ Online Inference
```
# 1. Start the API server
python deployment/vllm/api_server.py --model="ByteDance/Dolphin" --hf-overrides "{\"architectures\": [\"DolphinForConditionalGeneration\"]}"
# 2. Predict
# predict the reading order of document elements
python deployment/vllm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."
# recognize text/latex
python deployment/vllm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
python deployment/vllm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
# recognize table
python deployment/vllm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
```
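`api_client.py` simply posts a JSON payload to the server's `/generate` endpoint. If you prefer to call the endpoint from your own code, a minimal sketch of the same request (field names and defaults taken from `api_client.py`) looks like this:
```
import base64
import requests

with open("./demo/element_imgs/table_1.jpeg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "encoder_prompt": "",  # left empty; the server pads it for Dolphin
    "decoder_prompt": "Parse the table in the image.",
    "image_base64": image_base64,
    "temperature": 0.0,
    "max_tokens": 2048,
    "stream": False,
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(response.json()["text"][0])
```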

deployment/vllm/api_client.py (new file, 104 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for `vllm.entrypoints.api_server`

Start the demo server:
    python -m vllm.entrypoints.api_server --model <model_name>

NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""

import argparse
import base64
import json
from argparse import Namespace
from collections.abc import Iterable

import requests


def clear_line(n: int = 1) -> None:
    LINE_UP = "\033[1A"
    LINE_CLEAR = "\x1b[2K"
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def encode_image_base64(image_path: str) -> str:
    """Encode local image to base64 format."""
    with open(image_path, "rb") as f:
        image_data = f.read()
    result = base64.b64encode(image_data).decode("utf-8")
    return result


def post_http_request(
    prompt: str, image_path: str, api_url: str, stream: bool = False
) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "encoder_prompt": "",
        "decoder_prompt": prompt,
        "image_base64": encode_image_base64(image_path),
        "temperature": 0.0,
        "max_tokens": 2048,
        "stream": stream,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=stream)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\n"
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> list[str]:
    data = json.loads(response.content)
    output = data["text"]
    return output


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
    parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
    parser.add_argument("--stream", action="store_true")
    return parser.parse_args()


def main(args: Namespace):
    prompt = args.prompt
    image_path = args.image_path
    api_url = f"http://{args.host}:{args.port}/generate"
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, image_path, api_url, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Response {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        print(f"Response: {output[0]!r}", flush=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)

deployment/vllm/api_server.py (new file, 215 lines)
# SPDX-License-Identifier: Apache-2.0
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import asyncio
import base64
import json
import io
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any, Optional

from PIL import Image
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("api_server")

TIMEOUT_KEEP_ALIVE = 5  # seconds.
app = FastAPI()
engine = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - encoder_prompt: optional encoder prompt (left empty for Dolphin).
    - decoder_prompt: the prompt to use for the generation.
    - image_base64: the input image encoded as base64.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    return await _generate(request_dict, raw_request=request)


async def decode_image(image_base64: str) -> Image.Image:
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image


async def custom_process_prompt(encoder_prompt: str, decoder_prompt: str,
                                image_base64: str) -> ExplicitEncoderDecoderPrompt:
    assert engine is not None
    tokenizer = engine.engine.get_tokenizer_group().tokenizer
    image = await decode_image(image_base64)

    if encoder_prompt == "":
        encoder_prompt = "0" * 783  # For Dolphin

    if decoder_prompt == "":
        # Fall back to decoding from BOS; wrap in a list so TokensPrompt receives token ids.
        decoder_prompt_ids = [tokenizer.bos_token_id]
    else:
        decoder_prompt = f"<s>{decoder_prompt.strip()} <Answer/>"
        decoder_prompt_ids = tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]

    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
        decoder_prompt=TokensPrompt(prompt_token_ids=decoder_prompt_ids),
    )
    return enc_dec_prompt


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    encoder_prompt = request_dict.pop("encoder_prompt", "")
    decoder_prompt = request_dict.pop("decoder_prompt", "")
    image_base64 = request_dict.pop("image_base64", "")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    enc_dec_prompt = await custom_process_prompt(encoder_prompt, decoder_prompt, image_base64)
    results_generator = engine.generate(enc_dec_prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            assert prompt is not None
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    prompt = final_output.prompt
    assert prompt is not None
    text_outputs = [prompt + output.text.strip() for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)


def build_app(args: Namespace) -> FastAPI:
    global app
    app.root_path = args.root_path
    return app


async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    app = build_app(args)

    global engine
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = (llm_engine
              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
                  engine_args, usage_context=UsageContext.API_SERVER))
    app.state.engine_client = engine
    return app


async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    app = await init_app(args, llm_engine)
    assert engine is not None

    shutdown_task = await serve_http(
        app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    await shutdown_task


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument("--ssl-ca-certs",
                        type=str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))

deployment/vllm/demo_vllm.py (new file, 91 lines)
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import vllm_dolphin # vllm_dolphin plugin
import argparse
from argparse import Namespace
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def offline_inference(model_id: str, prompt: str, image_path: str, max_tokens: int = 2048):
dtype = "float16" if torch.cuda.is_available() else "float32"
# Create an encoder/decoder model instance
llm = LLM(
model=model_id,
dtype=dtype,
enforce_eager=True,
trust_remote_code=True,
max_num_seqs=8,
hf_overrides={"architectures": ["DolphinForConditionalGeneration"]},
)
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.0,
logprobs=0,
max_tokens=max_tokens,
prompt_logprobs=None,
skip_special_tokens=False,
)
# process prompt
tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer
# The Dolphin model does not require an Encoder Prompt. To ensure vllm correctly allocates KV Cache,
# it is necessary to simulate an Encoder Prompt.
encoder_prompt = "0" * 783
decoder_prompt = f"<s>{prompt.strip()} <Answer/>"
image = Image.open(image_path)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
decoder_prompt=TokensPrompt(
prompt_token_ids=tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]
),
)
# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated text, and other information.
outputs = llm.generate(enc_dec_prompt, sampling_params)
print("------" * 8)
# Print the outputs.
for output in outputs:
decoder_prompt_tokens = tokenizer.batch_decode(output.prompt_token_ids, skip_special_tokens=True)
decoder_prompt = "".join(decoder_prompt_tokens)
generated_text = output.outputs[0].text.strip()
print(f"Decoder prompt: {decoder_prompt!r}, "
f"\nGenerated text: {generated_text!r}")
print("------" * 8)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="ByteDance/Dolphin")
parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
return parser.parse_args()
def main(args: Namespace):
model = args.model
prompt = args.prompt
image_path = args.image_path
offline_inference(model, prompt, image_path)
if __name__ == "__main__":
args = parse_args()
main(args)