""" Copyright (c) 2025 Bytedance Ltd. and/or its affiliates SPDX-License-Identifier: MIT """ import vllm_dolphin # vllm_dolphin plugin import argparse from argparse import Namespace from PIL import Image from vllm import LLM, SamplingParams from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt import torch import os os.environ["TOKENIZERS_PARALLELISM"] = "false" def offline_inference(model_id: str, prompt: str, image_path: str, max_tokens: int = 2048): dtype = "float16" if torch.cuda.is_available() else "float32" # Create an encoder/decoder model instance llm = LLM( model=model_id, dtype=dtype, enforce_eager=True, trust_remote_code=True, max_num_seqs=8, hf_overrides={"architectures": ["DolphinForConditionalGeneration"]}, ) # Create a sampling params object. sampling_params = SamplingParams( temperature=0.0, logprobs=0, max_tokens=max_tokens, prompt_logprobs=None, skip_special_tokens=False, ) # process prompt tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer # The Dolphin model does not require an Encoder Prompt. To ensure vllm correctly allocates KV Cache, # it is necessary to simulate an Encoder Prompt. encoder_prompt = "0" * 783 decoder_prompt = f"{prompt.strip()} " image = Image.open(image_path) enc_dec_prompt = ExplicitEncoderDecoderPrompt( encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), decoder_prompt=TokensPrompt( prompt_token_ids=tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"] ), ) # Generate output tokens from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated text, and other information. outputs = llm.generate(enc_dec_prompt, sampling_params) print("------" * 8) # Print the outputs. for output in outputs: decoder_prompt_tokens = tokenizer.batch_decode(output.prompt_token_ids, skip_special_tokens=True) decoder_prompt = "".join(decoder_prompt_tokens) generated_text = output.outputs[0].text.strip() print(f"Decoder prompt: {decoder_prompt!r}, " f"\nGenerated text: {generated_text!r}") print("------" * 8) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="ByteDance/Dolphin") parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg") parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.") return parser.parse_args() def main(args: Namespace): model = args.model prompt = args.prompt image_path = args.image_path offline_inference(model, prompt, image_path) if __name__ == "__main__": args = parse_args() main(args)