# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py

#!/usr/bin/env python
import asyncio
import base64
import io
import logging
import signal
from http import HTTPStatus
from PIL import Image
from typing import Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response

from tensorrt_llm.executor import CppExecutorError, RequestError
from dolphin_runner import DolphinRunner, InferenceConfig

TIMEOUT_KEEP_ALIVE = 5  # seconds.


async def decode_image(image_base64: str) -> Image.Image:
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image


class LlmServer:
    def __init__(self, runner: DolphinRunner):
        self.runner = runner
        self.app = FastAPI()
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def generate(self, request: Request) -> Response:
        """Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - image_base64: the base64-encoded input image to use for the generation.
        """
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        logging.info(f"request prompt: {prompt}")
        image_base64 = request_dict.pop("image_base64", "")
        image = await decode_image(image_base64)

        try:
            output_texts = self.runner.run([prompt], [image], 4024)
            output_texts = [texts[0] for texts in output_texts]
            return JSONResponse({"text": output_texts[0]})
        except RequestError as e:
            return JSONResponse(content=str(e),
                                status_code=HTTPStatus.BAD_REQUEST)
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server
            signal.raise_signal(signal.SIGINT)

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


@click.command()
@click.option("--hf_model_dir", type=str, required=True)
@click.option("--visual_engine_dir", type=str, required=True)
@click.option("--llm_engine_dir", type=str, required=True)
@click.option("--max_batch_size", type=int, default=16)
@click.option("--max_new_tokens", type=int, default=4024)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
def entrypoint(hf_model_dir: str,
               visual_engine_dir: str,
               llm_engine_dir: str,
               max_batch_size: int,
               max_new_tokens: int,
               host: Optional[str] = None,
               port: int = 8000):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    config = InferenceConfig(
        max_new_tokens=max_new_tokens,
        batch_size=max_batch_size,
        log_level="info",
        hf_model_dir=hf_model_dir,
        visual_engine_dir=visual_engine_dir,
        llm_engine_dir=llm_engine_dir,
    )

    dolphin_runner = DolphinRunner(config)
    server = LlmServer(runner=dolphin_runner)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
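

# Example usage (illustrative sketch; the file name, directory paths, and prompt
# below are placeholders, not taken from the source):
#
#   python api_server.py \
#       --hf_model_dir ./hf_models/Dolphin \
#       --visual_engine_dir ./engines/vision \
#       --llm_engine_dir ./engines/llm \
#       --port 8000
#
# Then POST a prompt and a base64-encoded image to the /generate route:
#
#   import base64
#   import requests
#
#   with open("page.png", "rb") as f:
#       image_base64 = base64.b64encode(f.read()).decode("utf-8")
#   resp = requests.post("http://localhost:8000/generate",
#                        json={"prompt": "Read text in the image.",
#                              "image_base64": image_base64})
#   print(resp.json()["text"])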