253 lines
6.6 KiB
Python
253 lines
6.6 KiB
Python
# Copyright 2024 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
"""Pagers for the GenAI List APIs."""
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
import copy
|
|
from typing import Any, AsyncIterator,Awaitable, Callable, Generic, Iterator, Literal, TypeVar
|
|
|
|
# Type variable for the kind of item a pager yields.
T = TypeVar('T')

# Names of the response attributes that can hold a list of paged items.
PagedItem = Literal[
    'batch_jobs',
    'models',
    'tuning_jobs',
    'files',
    'cached_contents',
]
|
|
|
|
|
|
class _BasePager(Generic[T]):
  """Base pager class for iterating through paginated results.

  Holds the items of the current page, the callable used to request further
  pages, and the request configuration (including the next page token).
  Subclasses add the iteration protocol and ``next_page``.
  """

  def _init_page(
      self,
      name: PagedItem,
      request: Callable[..., Any],
      response: Any,
      config: Any,
  ) -> None:
    """(Re)initializes the pager state from an API response.

    Args:
      name: The attribute of ``response`` that holds the list of items (for
        example, ``batch_jobs``).
      request: The callable used to fetch further pages.
      response: The API response for the current page. It must expose the
        ``name`` attribute and a ``next_page_token`` attribute.
      config: The request configuration used for this page: falsy (treated as
        empty), a ``dict``, or any other object accepted by ``dict()``.
    """
    self._name = name
    self._request = request

    # A missing/None item list is treated as an empty page.
    self._page = getattr(response, self._name) or []
    self._idx = 0

    if not config:
      request_config = {}
    elif isinstance(config, dict):
      # Deep-copy so that writing the next page token below never mutates the
      # caller's dictionary.
      request_config = copy.deepcopy(config)
    else:
      request_config = dict(config)
    request_config['page_token'] = getattr(response, 'next_page_token')
    self._config = request_config

    # A config converted from an object may carry an explicit
    # ``page_size=None``; fall back to the current page length so that
    # ``page_size`` is always an int.
    page_size = request_config.get('page_size')
    self._page_size: int = (
        page_size if page_size is not None else len(self._page)
    )

  def __init__(
      self,
      name: PagedItem,
      request: Callable[..., Any],
      response: Any,
      config: Any,
  ):
    """Initializes the pager from the first page's response.

    Args:
      name: The response attribute that holds the list of items.
      request: The callable used to fetch further pages.
      response: The API response for the first page.
      config: The request configuration used for the first page.
    """
    self._init_page(name, request, response, config)

  @property
  def page(self) -> list[T]:
    """Returns a subset of the entire list of items.

    For the number of items returned, see ``page_size``.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"first page: {batch_jobs_pager.page}")
      # first page: [BatchJob(name='projects/.../locations/.../batchPredictionJobs/1
    """

    return self._page

  @property
  def name(self) -> PagedItem:
    """Returns the type of paged item (for example, ``batch_jobs``).

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"name: {batch_jobs_pager.name}")
      # name: batch_jobs
    """

    return self._name

  @property
  def page_size(self) -> int:
    """Returns the maximum number of items fetched by the pager at one time.

    If the request configuration does not specify a page size, this is the
    number of items in the current page.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"page_size: {batch_jobs_pager.page_size}")
      # page_size: 5
    """

    return self._page_size

  @property
  def config(self) -> dict[str, Any]:
    """Returns the configuration when making the API request for the next page.

    A configuration is a set of optional parameters and arguments that can be
    used to customize the API request. For example, the ``page_token`` parameter
    contains the token to request the next page.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"config: {batch_jobs_pager.config}")
      # config: {'page_size': 5, 'page_token': 'AMEw9yO5jnsGnZJLHSKDFHJJu'}
    """

    return self._config

  def __len__(self) -> int:
    """Returns the total number of items in the current page."""
    return len(self.page)

  def __getitem__(self, index: int) -> T:
    """Returns the item at the given index of the current page."""
    return self.page[index]

  def _init_next_page(self, response: Any) -> None:
    """Initializes the next page from the response.

    This is an internal method that should be called by subclasses after
    fetching the next page. The current ``config`` (which already holds the
    token used for that request) is reused as the base configuration.

    Args:
      response: The response object from the API request.
    """
    self._init_page(self.name, self._request, response, self.config)
|
|
|
|
|
|
class Pager(_BasePager[T]):
  """Pager class for iterating through paginated results.

  Iterating over a ``Pager`` yields every item of the current page and then
  fetches following pages with ``next_page`` until no page token remains.
  """

  def __next__(self) -> T:
    """Returns the next item, fetching further pages as needed.

    Raises:
      StopIteration: When there are no more items and no more pages.
    """
    # Keep fetching (rather than fetching once): if a fetched page is empty
    # but still carries a page token, indexing it would otherwise leak an
    # IndexError instead of moving on to the following page.
    while self._idx >= len(self):
      try:
        self.next_page()
      except IndexError:
        raise StopIteration from None

    item = self.page[self._idx]
    self._idx += 1
    return item

  def __iter__(self) -> Iterator[T]:
    """Returns an iterator over the items, starting at the current page."""
    self._idx = 0
    return self

  def next_page(self) -> list[T]:
    """Fetches the next page of items. This makes a new API request.

    Returns:
      The next page of items.

    Raises:
      IndexError: No more pages to fetch.

    Usage:

    .. code-block:: python

      batch_jobs_pager = client.batches.list(config={'page_size': 5})
      print(f"current page: {batch_jobs_pager.page}")
      batch_jobs_pager.next_page()
      print(f"next page: {batch_jobs_pager.page}")
      # current page: [BatchJob(name='projects/.../batchPredictionJobs/1
      # next page: [BatchJob(name='projects/.../batchPredictionJobs/6
    """

    if not self.config.get('page_token'):
      raise IndexError('No more pages to fetch.')

    response = self._request(config=self.config)
    self._init_next_page(response)
    return self.page
|
|
|
|
|
|
class AsyncPager(_BasePager[T]):
  """AsyncPager class for iterating through paginated results.

  The asynchronous counterpart of ``Pager``: ``async for`` yields every item
  of the current page and awaits ``next_page`` to fetch following pages until
  no page token remains.
  """

  def __init__(
      self,
      name: PagedItem,
      request: Callable[..., Awaitable[Any]],
      response: Any,
      config: Any,
  ):
    """Initializes the pager from the first page's response.

    Args:
      name: The response attribute that holds the list of items.
      request: An async callable used to fetch further pages; it is awaited
        with a ``config`` keyword argument.
      response: The API response for the first page.
      config: The request configuration used for the first page.
    """
    super().__init__(name, request, response, config)

  def __aiter__(self) -> AsyncIterator[T]:
    """Returns an async iterator over the items, starting at the current page."""
    self._idx = 0
    return self

  async def __anext__(self) -> T:
    """Returns the next item asynchronously, fetching further pages as needed.

    Raises:
      StopAsyncIteration: When there are no more items and no more pages.
    """
    # Keep fetching (rather than fetching once): if a fetched page is empty
    # but still carries a page token, indexing it would otherwise leak an
    # IndexError instead of moving on to the following page.
    while self._idx >= len(self):
      try:
        await self.next_page()
      except IndexError:
        raise StopAsyncIteration from None

    item = self.page[self._idx]
    self._idx += 1
    return item

  async def next_page(self) -> list[T]:
    """Fetches the next page of items asynchronously.

    This makes a new API request.

    Returns:
      The next page of items.

    Raises:
      IndexError: No more pages to fetch.

    Usage:

    .. code-block:: python

      batch_jobs_pager = await client.aio.batches.list(config={'page_size': 5})
      print(f"current page: {batch_jobs_pager.page}")
      await batch_jobs_pager.next_page()
      print(f"next page: {batch_jobs_pager.page}")
      # current page: [BatchJob(name='projects/.../batchPredictionJobs/1
      # next page: [BatchJob(name='projects/.../batchPredictionJobs/6
    """

    if not self.config.get('page_token'):
      raise IndexError('No more pages to fetch.')

    response = await self._request(config=self.config)
    self._init_next_page(response)
    return self.page
|