103 lines
2.5 KiB
Python
103 lines
2.5 KiB
Python
# Copyright 2025 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Registry mapping model-name regexes to the LLM classes that implement them."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
import logging
|
|
import re
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from .base_llm import BaseLlm
|
|
|
|
logger = logging.getLogger(__name__)


# Registry for LLMs.
#
# Key is the regex that matches the model name; value is the class that
# implements the model. `LLMRegistry.resolve` tries the keys in insertion
# order, so the first regex that fully matches a model name wins.
_llm_registry_dict: dict[str, type[BaseLlm]] = {}


class LLMRegistry:
  """Registry for LLMs."""

  @staticmethod
  def new_llm(model: str) -> BaseLlm:
    """Creates a new LLM instance.

    Args:
      model: The model name.

    Returns:
      The LLM instance, constructed as ``cls(model=model)`` where ``cls`` is
      the class returned by ``LLMRegistry.resolve(model)``.

    Raises:
      ValueError: If no registered regex matches the model name.
    """
    llm_cls = LLMRegistry.resolve(model)
    return llm_cls(model=model)

  @staticmethod
  def _register(model_name_regex: str, llm_cls: type[BaseLlm]):
    """Registers a new LLM class.

    Registering a regex that is already present replaces its class. The
    regex keeps its original position in the match order. The memoized
    results of ``resolve`` are discarded afterwards, so later lookups see
    the change.

    Args:
      model_name_regex: The regex that matches the model name.
      llm_cls: The class that implements the model.
    """
    if model_name_regex in _llm_registry_dict:
      logger.info(
          'Updating LLM class for %s from %s to %s',
          model_name_regex,
          _llm_registry_dict[model_name_regex],
          llm_cls,
      )

    _llm_registry_dict[model_name_regex] = llm_cls
    # resolve() is memoized by model name. Without this, a model that was
    # resolved before a re-registration would keep returning the old class.
    LLMRegistry.resolve.cache_clear()

  @staticmethod
  def register(llm_cls: type[BaseLlm]):
    """Registers a new LLM class.

    Args:
      llm_cls: The class that implements the model. Every regex returned by
        its ``supported_models()`` is registered to this class.
    """
    for regex in llm_cls.supported_models():
      LLMRegistry._register(regex, llm_cls)

  @staticmethod
  @lru_cache(maxsize=32)
  def resolve(model: str) -> type[BaseLlm]:
    """Resolves the model to a BaseLlm subclass.

    Registered regexes are tried in registration order, and the first one
    that matches the whole model name wins. Successful results are memoized
    (up to 32 model names). ``_register`` clears that cache whenever the
    registry changes.

    Args:
      model: The model name.

    Returns:
      The BaseLlm subclass.

    Raises:
      ValueError: If the model is not found.
    """
    for regex, llm_class in _llm_registry_dict.items():
      if re.fullmatch(regex, model):
        return llm_class

    raise ValueError(f'Model {model} not found.')
|