Dolphin/deployment/tensorrt_llm/api_server.py
# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py
#!/usr/bin/env python
import asyncio
import base64
import io
import logging
import signal
from http import HTTPStatus
from typing import Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from PIL import Image
from tensorrt_llm.executor import CppExecutorError, RequestError

from dolphin_runner import DolphinRunner, InferenceConfig

TIMEOUT_KEEP_ALIVE = 5  # seconds.


async def decode_image(image_base64: str) -> Image.Image:
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image


class LlmServer:
    def __init__(self, runner: DolphinRunner):
        self.runner = runner
        self.app = FastAPI()
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def generate(self, request: Request) -> Response:
        """Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - image_base64: the base64-encoded image to use for the generation.

        An example payload sketch follows the class definition below.
        """
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        logging.info(f"request prompt: {prompt}")
        image_base64 = request_dict.pop("image_base64", "")
        image = await decode_image(image_base64)

        try:
            output_texts = self.runner.run([prompt], [image], 4024)
            output_texts = [texts[0] for texts in output_texts]
            return JSONResponse({"text": output_texts[0]})
        except RequestError as e:
            return JSONResponse(content=str(e),
                                status_code=HTTPStatus.BAD_REQUEST)
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server
            signal.raise_signal(signal.SIGINT)

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()
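

# --- Example request sketch --------------------------------------------------
# A minimal client-side sketch for POST /generate. The field names match the
# docstring in LlmServer.generate; the prompt text and image file name below
# are illustrative assumptions, not values defined in this repository, and
# `requests` is not a dependency of this module.
#
#   import base64, requests
#
#   with open("page.png", "rb") as f:          # hypothetical input image
#       image_base64 = base64.b64encode(f.read()).decode("utf-8")
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "Read text in the image.", "image_base64": image_base64},
#   )
#   print(resp.json()["text"])                 # generated text for the image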


@click.command()
@click.option("--hf_model_dir", type=str, required=True)
@click.option("--visual_engine_dir", type=str, required=True)
@click.option("--llm_engine_dir", type=str, required=True)
@click.option("--max_batch_size", type=int, default=16)
@click.option("--max_new_tokens", type=int, default=4024)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
def entrypoint(hf_model_dir: str,
               visual_engine_dir: str,
               llm_engine_dir: str,
               max_batch_size: int,
               max_new_tokens: int,
               host: Optional[str] = None,
               port: int = 8000):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    config = InferenceConfig(
        max_new_tokens=max_new_tokens,
        batch_size=max_batch_size,
        log_level="info",
        hf_model_dir=hf_model_dir,
        visual_engine_dir=visual_engine_dir,
        llm_engine_dir=llm_engine_dir,
    )

    dolphin_runner = DolphinRunner(config)
    server = LlmServer(runner=dolphin_runner)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
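
# --- Example launch sketch ---------------------------------------------------
# A minimal sketch for starting the server; the directory paths below are
# assumptions, not paths shipped with this repository. Substitute the HF model
# directory and the TensorRT-LLM engine directories produced by your own build:
#
#   python api_server.py \
#       --hf_model_dir ./hf_model \
#       --visual_engine_dir ./visual_engine \
#       --llm_engine_dir ./llm_engine \
#       --max_batch_size 16 \
#       --port 8000
#
# The server then listens on 0.0.0.0:8000 and exposes GET /health (liveness
# check) and POST /generate (see the request sketch above).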