structure saas with tools
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.bedrock import BedrockPreparedRequest
|
||||
from litellm.types.rerank import RerankRequest
|
||||
from litellm.types.utils import RerankResponse
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
from .transformation import BedrockRerankConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockRerankHandler(BaseAWSLLM):
|
||||
async def arerank(
|
||||
self,
|
||||
prepared_request: BedrockPreparedRequest,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
):
|
||||
if client is None:
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
|
||||
try:
|
||||
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response.json())
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
optional_params: dict,
|
||||
logging_obj: LitellmLogging,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
_is_async: Optional[bool] = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> RerankResponse:
|
||||
request_data = RerankRequest(
|
||||
model=model,
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
)
|
||||
data = BedrockRerankConfig()._transform_request(request_data)
|
||||
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
data=cast(dict, data),
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=data,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepared_request["endpoint_url"],
|
||||
"headers": prepared_request["prepped"].headers,
|
||||
},
|
||||
)
|
||||
|
||||
if _is_async:
|
||||
return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
logging_obj.post_call(
|
||||
original_response=response.text,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response_json)
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
) -> BedrockPreparedRequest:
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params, model
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
_, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
proxy_endpoint_url = proxy_endpoint_url.replace(
|
||||
"bedrock-runtime", "bedrock-agent-runtime"
|
||||
)
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"
|
||||
sigv4 = SigV4Auth(
|
||||
boto3_credentials_info.credentials,
|
||||
"bedrock",
|
||||
boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
# Make POST Request
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=proxy_endpoint_url, data=body, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
return BedrockPreparedRequest(
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
prepped=prepped,
|
||||
body=body,
|
||||
data=data,
|
||||
)
|
||||
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
BedrockRerankBedrockRerankingConfiguration,
|
||||
BedrockRerankConfiguration,
|
||||
BedrockRerankInlineDocumentSource,
|
||||
BedrockRerankModelConfiguration,
|
||||
BedrockRerankQuery,
|
||||
BedrockRerankRequest,
|
||||
BedrockRerankSource,
|
||||
BedrockRerankTextDocument,
|
||||
BedrockRerankTextQuery,
|
||||
)
|
||||
from litellm.types.rerank import (
|
||||
RerankBilledUnits,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResponseMeta,
|
||||
RerankResponseResult,
|
||||
RerankTokens,
|
||||
)
|
||||
|
||||
|
||||
class BedrockRerankConfig:
|
||||
def _transform_sources(
|
||||
self, documents: List[Union[str, dict]]
|
||||
) -> List[BedrockRerankSource]:
|
||||
"""
|
||||
Transform the sources from RerankRequest format to Bedrock format.
|
||||
"""
|
||||
_sources = []
|
||||
for document in documents:
|
||||
if isinstance(document, str):
|
||||
_sources.append(
|
||||
BedrockRerankSource(
|
||||
inlineDocumentSource=BedrockRerankInlineDocumentSource(
|
||||
textDocument=BedrockRerankTextDocument(text=document),
|
||||
type="TEXT",
|
||||
),
|
||||
type="INLINE",
|
||||
)
|
||||
)
|
||||
else:
|
||||
_sources.append(
|
||||
BedrockRerankSource(
|
||||
inlineDocumentSource=BedrockRerankInlineDocumentSource(
|
||||
jsonDocument=document, type="JSON"
|
||||
),
|
||||
type="INLINE",
|
||||
)
|
||||
)
|
||||
return _sources
|
||||
|
||||
def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
|
||||
"""
|
||||
Transform the request from RerankRequest format to Bedrock format.
|
||||
"""
|
||||
_sources = self._transform_sources(request_data.documents)
|
||||
|
||||
return BedrockRerankRequest(
|
||||
queries=[
|
||||
BedrockRerankQuery(
|
||||
textQuery=BedrockRerankTextQuery(text=request_data.query),
|
||||
type="TEXT",
|
||||
)
|
||||
],
|
||||
rerankingConfiguration=BedrockRerankConfiguration(
|
||||
bedrockRerankingConfiguration=BedrockRerankBedrockRerankingConfiguration(
|
||||
modelConfiguration=BedrockRerankModelConfiguration(
|
||||
modelArn=request_data.model
|
||||
),
|
||||
numberOfResults=request_data.top_n or len(request_data.documents),
|
||||
),
|
||||
type="BEDROCK_RERANKING_MODEL",
|
||||
),
|
||||
sources=_sources,
|
||||
)
|
||||
|
||||
def _transform_response(self, response: dict) -> RerankResponse:
|
||||
"""
|
||||
Transform the response from Bedrock into the RerankResponse format.
|
||||
|
||||
example input:
|
||||
{"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
|
||||
"""
|
||||
_billed_units = RerankBilledUnits(
|
||||
**response.get("usage", {"search_units": 1})
|
||||
) # by default 1 search unit
|
||||
_tokens = RerankTokens(**response.get("usage", {}))
|
||||
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||||
|
||||
_results: Optional[List[RerankResponseResult]] = None
|
||||
|
||||
bedrock_results = response.get("results")
|
||||
if bedrock_results:
|
||||
_results = [
|
||||
RerankResponseResult(
|
||||
index=result.get("index"),
|
||||
relevance_score=result.get("relevanceScore"),
|
||||
)
|
||||
for result in bedrock_results
|
||||
]
|
||||
|
||||
if _results is None:
|
||||
raise ValueError(f"No results found in the response={response}")
|
||||
|
||||
return RerankResponse(
|
||||
id=response.get("id") or str(uuid.uuid4()),
|
||||
results=_results,
|
||||
meta=rerank_meta,
|
||||
) # Return response
|
||||
Reference in New Issue
Block a user