structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,6 @@
# Vertex AI Batch Prediction Jobs
Implementation for calling Vertex AI batch prediction endpoints using the OpenAI Batch API spec.
Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
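
For orientation, a minimal usage sketch (assumed entry point: LiteLLM's `litellm.create_batch`; the bucket and file names below are illustrative, and the input JSONL file is assumed to already exist in GCS, e.g. uploaded via the /v1/files route):

import litellm

# Illustrative only: create a Vertex AI batch job through the OpenAI-compatible Batch API.
# The input_file_id is a GCS URI; all names here are placeholders.
batch = litellm.create_batch(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id="gs://my-bucket/publishers/google/models/gemini-1.5-flash-001/input.jsonl",
    custom_llm_provider="vertex_ai",
)
print(batch.id, batch.status)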


@@ -0,0 +1,215 @@
import json
from typing import Any, Coroutine, Dict, Optional, Union

import httpx

import litellm
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.openai import CreateBatchRequest
from litellm.types.llms.vertex_ai import (
    VERTEX_CREDENTIALS_TYPES,
    VertexAIBatchPredictionJob,
)
from litellm.types.utils import LiteLLMBatch

from .transformation import VertexAIBatchTransformation


class VertexAIBatchPrediction(VertexLLM):
    def __init__(self, gcs_bucket_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gcs_bucket_name = gcs_bucket_name
    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        sync_handler = _get_httpx_client()

        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )

        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }

        vertex_batch_request: VertexAIBatchPredictionJob = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
            request=create_batch_data
        )

        if _is_async is True:
            return self._async_create_batch(
                vertex_batch_request=vertex_batch_request,
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response
    async def _async_create_batch(
        self,
        vertex_batch_request: VertexAIBatchPredictionJob,
        api_base: str,
        headers: Dict[str, str],
    ) -> LiteLLMBatch:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response
    def create_vertex_url(
        self,
        vertex_location: str,
        vertex_project: str,
    ) -> str:
        """Return the base url for the vertex garden models"""
        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        sync_handler = _get_httpx_client()

        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )

        # Append batch_id to the URL
        default_api_base = f"{default_api_base}/{batch_id}"

        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }

        if _is_async is True:
            return self._async_retrieve_batch(
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.get(
            url=api_base,
            headers=headers,
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response
    async def _async_retrieve_batch(
        self,
        api_base: str,
        headers: Dict[str, str],
    ) -> LiteLLMBatch:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.get(
            url=api_base,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response
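
A rough sketch of driving this handler directly (values are made up; normally the class is invoked from LiteLLM's batches entry point rather than instantiated by hand):

# Hypothetical direct invocation of VertexAIBatchPrediction.
request: CreateBatchRequest = {
    "completion_window": "24h",
    "endpoint": "/v1/chat/completions",
    "input_file_id": "gs://my-bucket/publishers/google/models/gemini-1.5-flash-001/input.jsonl",
}
handler = VertexAIBatchPrediction(gcs_bucket_name="my-bucket")  # assumed bucket name
# _is_async=False returns a LiteLLMBatch directly; _is_async=True returns a coroutine
# the caller must await (see _async_create_batch above).
batch = handler.create_batch(
    _is_async=False,
    create_batch_data=request,
    api_base=None,
    vertex_credentials=None,  # assumption: None falls back to application-default credentials
    vertex_project="my-project",
    vertex_location="us-central1",
    timeout=600.0,
    max_retries=0,
)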


@@ -0,0 +1,193 @@
import uuid
from typing import Dict

from litellm.llms.vertex_ai.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
from litellm.types.llms.vertex_ai import *
from litellm.types.utils import LiteLLMBatch


class VertexAIBatchTransformation:
    """
    Transforms OpenAI Batch requests to Vertex AI Batch requests

    API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
    """
    @classmethod
    def transform_openai_batch_request_to_vertex_ai_batch_request(
        cls,
        request: CreateBatchRequest,
    ) -> VertexAIBatchPredictionJob:
        """
        Transforms OpenAI Batch requests to Vertex AI Batch requests
        """
        request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
        input_file_id = request.get("input_file_id")
        if input_file_id is None:
            raise ValueError("input_file_id is required, but not provided")
        input_config: InputConfig = InputConfig(
            gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
        )
        model: str = cls._get_model_from_gcs_file(input_file_id)
        output_config: OutputConfig = OutputConfig(
            predictionsFormat="jsonl",
            gcsDestination=GcsDestination(
                outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
            ),
        )
        return VertexAIBatchPredictionJob(
            inputConfig=input_config,
            outputConfig=output_config,
            model=model,
            displayName=request_display_name,
        )
    @classmethod
    def transform_vertex_ai_batch_response_to_openai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> LiteLLMBatch:
        return LiteLLMBatch(
            id=cls._get_batch_id_from_vertex_ai_batch_response(response),
            completion_window="24hrs",
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response.get("createTime", "")
            ),
            endpoint="",
            input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
                response
            ),
            object="batch",
            status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
            error_file_id=None,  # Vertex AI doesn't seem to have a direct equivalent
            output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
                response
            ),
        )
    @classmethod
    def _get_batch_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the batch id from the Vertex AI Batch response safely

        vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
        returns: `3814889423749775360`
        """
        _name = response.get("name", "")
        if not _name:
            return ""

        # Split by '/' and get the last part if it exists
        parts = _name.split("/")
        return parts[-1] if parts else _name
    @classmethod
    def _get_input_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the input file id from the Vertex AI Batch response
        """
        input_file_id: str = ""
        input_config = response.get("inputConfig")
        if input_config is None:
            return input_file_id

        gcs_source = input_config.get("gcsSource")
        if gcs_source is None:
            return input_file_id

        uris = gcs_source.get("uris", "")
        if len(uris) == 0:
            return input_file_id

        return uris[0]
    @classmethod
    def _get_output_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the output file id from the Vertex AI Batch response
        """
        output_file_id: str = ""
        output_config = response.get("outputConfig")
        if output_config is None:
            return output_file_id

        gcs_destination = output_config.get("gcsDestination")
        if gcs_destination is None:
            return output_file_id

        output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
        return output_uri_prefix
    @classmethod
    def _get_batch_job_status_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> BatchJobStatus:
        """
        Gets the batch job status from the Vertex AI Batch response

        ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
        """
        state_mapping: Dict[str, BatchJobStatus] = {
            "JOB_STATE_UNSPECIFIED": "failed",
            "JOB_STATE_QUEUED": "validating",
            "JOB_STATE_PENDING": "validating",
            "JOB_STATE_RUNNING": "in_progress",
            "JOB_STATE_SUCCEEDED": "completed",
            "JOB_STATE_FAILED": "failed",
            "JOB_STATE_CANCELLING": "cancelling",
            "JOB_STATE_CANCELLED": "cancelled",
            "JOB_STATE_PAUSED": "in_progress",
            "JOB_STATE_EXPIRED": "expired",
            "JOB_STATE_UPDATING": "in_progress",
            "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
        }

        vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
        return state_mapping[vertex_state]
    @classmethod
    def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
        """
        Gets the gcs uri prefix from the input file id

        Example:
        input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket"

        input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket/batches"
        """
        # Split the path and remove the filename
        path_parts = input_file_id.rsplit("/", 1)
        return path_parts[0]
    @classmethod
    def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
        """
        Extracts the model from the gcs file uri

        When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri

        Why?
        - Because Vertex requires the `model` param in create batch jobs request, but OpenAI does not require this

        gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
        returns: "publishers/google/models/gemini-1.5-flash-001"
        """
        from urllib.parse import unquote

        decoded_uri = unquote(gcs_file_uri)

        model_path = decoded_uri.split("publishers/")[1]
        parts = model_path.split("/")
        model = f"publishers/{'/'.join(parts[:3])}"
        return model
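
As a worked example of the request transformation (assuming a file uploaded through LiteLLM, so the model path is embedded in the GCS URI as described in _get_model_from_gcs_file; the URI below is illustrative):

# Illustrative input/output of transform_openai_batch_request_to_vertex_ai_batch_request.
openai_request: CreateBatchRequest = {
    "completion_window": "24h",
    "endpoint": "/v1/chat/completions",
    "input_file_id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/input.jsonl",
}
vertex_job = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
    request=openai_request
)
# Expected shape (displayName carries a random uuid):
# {
#   "inputConfig": {"gcsSource": {"uris": "gs://.../gemini-1.5-flash-001/input.jsonl"},
#                   "instancesFormat": "jsonl"},
#   "outputConfig": {"predictionsFormat": "jsonl",
#                    "gcsDestination": {"outputUriPrefix": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001"}},
#   "model": "publishers/google/models/gemini-1.5-flash-001",
#   "displayName": "litellm-vertex-batch-<uuid>",
# }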