support inference by vllm
This commit is contained in: parent eb1737ae95, commit 6177c2686b
50 deployment/vllm/ReadMe.md Normal file
@@ -0,0 +1,50 @@
<h1 align="center">
🚀 Dolphin vLLM Demo
</h1>

## ✅ Introduction

The Dolphin model uses a **Swin Encoder + MBart Decoder** architecture. In its HuggingFace Transformers [config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json), the `architectures` field is set to `"VisionEncoderDecoderModel"`, an architecture that vLLM does not support natively.

To enable vLLM deployment of Dolphin, we implemented two vLLM plugins: [vllm-dolphin](https://github.com/hanyd2010/vllm-dolphin) ([PyPI](https://pypi.org/project/vllm-dolphin/)) and [vllm-mbart](https://github.com/hanyd2010/vllm-mbart) ([PyPI](https://pypi.org/project/vllm-mbart/)).

We also provide Dolphin vLLM demos for both offline inference and online deployment.
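
For reference, the plugin route boils down to two things (condensed from `demo_vllm.py` in this commit): import `vllm_dolphin` so the plugin registers its model implementation, and override the architecture reported by the HF config so vLLM selects it. A minimal sketch:

```
import vllm_dolphin  # importing the plugin makes the Dolphin implementation available to vLLM

from vllm import LLM

# Override the architecture from the HF config ("VisionEncoderDecoderModel")
# with the plugin-provided one, as demo_vllm.py does.
llm = LLM(
    model="ByteDance/Dolphin",
    hf_overrides={"architectures": ["DolphinForConditionalGeneration"]},
    trust_remote_code=True,
    enforce_eager=True,
)
```

The online demo applies the same override through the `--hf-overrides` flag (see Online Inference below).
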
## 🛠️ Installation

```
# Install vllm
pip install "vllm>=0.9.0"

# Install vllm-dolphin
pip install vllm-dolphin==0.1
```

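As an optional sanity check, confirm the plugin can be imported before running the demos:

```
python -c "import vllm_dolphin"
```
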
## ⚡ Offline Inference

```
# predict elements reading order
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."

# recognize text/latex
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."

# recognize table
python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
```

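Under the hood, `demo_vllm.py` builds an explicit encoder/decoder prompt: a fixed-length dummy encoder prompt (Dolphin only needs the image, but vLLM still has to allocate the encoder KV cache) plus the instruction wrapped as `<s>... <Answer/>` on the decoder side. A condensed sketch of that flow, using the table prompt as an example:

```
import vllm_dolphin  # plugin import, as above
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt

llm = LLM(model="ByteDance/Dolphin", trust_remote_code=True, enforce_eager=True,
          hf_overrides={"architectures": ["DolphinForConditionalGeneration"]})
tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer

encoder_prompt = "0" * 783  # dummy encoder prompt used by the demo
decoder_prompt = "<s>Parse the table in the image. <Answer/>"

prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt=encoder_prompt,
                              multi_modal_data={"image": Image.open("./demo/element_imgs/table_1.jpeg")}),
    decoder_prompt=TokensPrompt(
        prompt_token_ids=tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]),
)
outputs = llm.generate(prompt, SamplingParams(temperature=0.0, max_tokens=2048, skip_special_tokens=False))
print(outputs[0].outputs[0].text.strip())
```
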
## ⚡ Online Inference

```
# 1. Start the API server
python deployment/vllm/api_server.py --model="ByteDance/Dolphin" --hf-overrides "{\"architectures\": [\"DolphinForConditionalGeneration\"]}"

# 2. Predict
# predict elements reading order
python deployment/vllm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."

# recognize text/latex
python deployment/vllm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
python deployment/vllm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."

# recognize table
python deployment/vllm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
```
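
The server exposes `POST /generate` and returns a JSON object of the form `{"text": [...]}`. If you prefer to call it without `api_client.py`, a request looks roughly like this (the field names mirror what the client in this commit sends):

```
import base64
import requests

# Encode the input image exactly as api_client.py does.
with open("./demo/page_imgs/page_1.jpeg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "encoder_prompt": "",  # left empty; the server substitutes Dolphin's dummy encoder prompt
    "decoder_prompt": "Parse the reading order of this document.",
    "image_base64": image_base64,
    "temperature": 0.0,
    "max_tokens": 2048,
    "stream": False,
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(resp.json()["text"][0])
```
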
104 deployment/vllm/api_client.py Normal file
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for `vllm.entrypoints.api_server`
Start the demo server:
    python -m vllm.entrypoints.api_server --model <model_name>

NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""

import argparse
import base64
import json
from argparse import Namespace
from collections.abc import Iterable

import requests


def clear_line(n: int = 1) -> None:
    LINE_UP = "\033[1A"
    LINE_CLEAR = "\x1b[2K"
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def encode_image_base64(image_path: str) -> str:
    """Encode local image to base64 format."""

    with open(image_path, "rb") as f:
        image_data = f.read()
        result = base64.b64encode(image_data).decode("utf-8")

    return result


def post_http_request(
    prompt: str, image_path: str, api_url: str, stream: bool = False
) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "encoder_prompt": "",
        "decoder_prompt": prompt,
        "image_base64": encode_image_base64(image_path),
        "temperature": 0.0,
        "max_tokens": 2048,
        "stream": stream,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=stream)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\n"
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> list[str]:
    data = json.loads(response.content)
    output = data["text"]
    return output


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
    parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
    parser.add_argument("--stream", action="store_true")
    return parser.parse_args()


def main(args: Namespace):
    prompt = args.prompt
    image_path = args.image_path
    api_url = f"http://{args.host}:{args.port}/generate"
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, image_path, api_url, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Response {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        print(f"Response: {output[0]!r}", flush=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)
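
# Example usage (with api_server.py from this commit running on localhost:8000):
#   python deployment/vllm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg \
#       --prompt "Parse the reading order of this document."
# Pass --stream to print the response incrementally as it is generated.
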
215 deployment/vllm/api_server.py Normal file
@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import asyncio
import base64
import json
import io
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from PIL import Image
from typing import Any, Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("api_server")

TIMEOUT_KEEP_ALIVE = 5  # seconds.
app = FastAPI()
engine = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - encoder_prompt: optional encoder prompt (leave empty to use the dummy prompt).
    - decoder_prompt: the instruction for the decoder.
    - image_base64: the input image encoded as base64.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    return await _generate(request_dict, raw_request=request)


async def decode_image(image_base64: str) -> Image.Image:
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image


async def custom_process_prompt(encoder_prompt: str, decoder_prompt: str,
                                image_base64: str) -> ExplicitEncoderDecoderPrompt:
    assert engine is not None
    tokenizer = engine.engine.get_tokenizer_group().tokenizer
    image = await decode_image(image_base64)

    if encoder_prompt == "":
        # Dolphin does not need a real encoder prompt; a fixed-length dummy prompt
        # is used so that vLLM allocates the encoder KV cache correctly.
        encoder_prompt = "0" * 783  # For Dolphin

    if decoder_prompt == "":
        decoder_prompt_ids = tokenizer.bos_token_id
    else:
        decoder_prompt = f"<s>{decoder_prompt.strip()} <Answer/>"
        decoder_prompt_ids = tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]

    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
        decoder_prompt=TokensPrompt(prompt_token_ids=decoder_prompt_ids),
    )

    return enc_dec_prompt


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    encoder_prompt = request_dict.pop("encoder_prompt", "")
    decoder_prompt = request_dict.pop("decoder_prompt", "")
    image_base64 = request_dict.pop("image_base64", "")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None

    enc_dec_prompt = await custom_process_prompt(encoder_prompt, decoder_prompt, image_base64)
    results_generator = engine.generate(enc_dec_prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            assert prompt is not None
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    prompt = final_output.prompt
    assert prompt is not None
    text_outputs = [prompt + output.text.strip() for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)


def build_app(args: Namespace) -> FastAPI:
    global app

    app.root_path = args.root_path
    return app


async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    app = build_app(args)

    global engine

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = (llm_engine
              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
                  engine_args, usage_context=UsageContext.API_SERVER))
    app.state.engine_client = engine
    return app


async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    app = await init_app(args, llm_engine)
    assert engine is not None

    shutdown_task = await serve_http(
        app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    await shutdown_task


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument("--ssl-ca-certs",
                        type=str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)"
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))
91 deployment/vllm/demo_vllm.py Normal file
@@ -0,0 +1,91 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import vllm_dolphin  # vllm_dolphin plugin
import argparse
from argparse import Namespace
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt

import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def offline_inference(model_id: str, prompt: str, image_path: str, max_tokens: int = 2048):
    dtype = "float16" if torch.cuda.is_available() else "float32"
    # Create an encoder/decoder model instance
    llm = LLM(
        model=model_id,
        dtype=dtype,
        enforce_eager=True,
        trust_remote_code=True,
        max_num_seqs=8,
        hf_overrides={"architectures": ["DolphinForConditionalGeneration"]},
    )

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0.0,
        logprobs=0,
        max_tokens=max_tokens,
        prompt_logprobs=None,
        skip_special_tokens=False,
    )

    # process prompt
    tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer

    # The Dolphin model does not require an Encoder Prompt. To ensure vllm correctly allocates KV Cache,
    # it is necessary to simulate an Encoder Prompt.
    encoder_prompt = "0" * 783
    decoder_prompt = f"<s>{prompt.strip()} <Answer/>"

    image = Image.open(image_path)
    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
        decoder_prompt=TokensPrompt(
            prompt_token_ids=tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]
        ),
    )

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated text, and other information.
    outputs = llm.generate(enc_dec_prompt, sampling_params)

    print("------" * 8)
    # Print the outputs.
    for output in outputs:
        decoder_prompt_tokens = tokenizer.batch_decode(output.prompt_token_ids, skip_special_tokens=True)
        decoder_prompt = "".join(decoder_prompt_tokens)
        generated_text = output.outputs[0].text.strip()
        print(f"Decoder prompt: {decoder_prompt!r}, "
              f"\nGenerated text: {generated_text!r}")

    print("------" * 8)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="ByteDance/Dolphin")
    parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
    parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
    return parser.parse_args()


def main(args: Namespace):
    model = args.model
    prompt = args.prompt
    image_path = args.image_path

    offline_inference(model, prompt, image_path)


if __name__ == "__main__":
    args = parse_args()
    main(args)