diff --git a/.cursorrules b/.cursorrules index 35727f6d..d037a929 100644 --- a/.cursorrules +++ b/.cursorrules @@ -17,9 +17,16 @@ ``` src/ ├── api/ -│ ├── routes.py # API routes definition -│ ├── auth_routes.py # Authentication routes (login, registration, etc.) -│ └── admin_routes.py # Protected admin routes +│ ├── __init__.py # Package initialization +│ ├── admin_routes.py # Admin routes for management interface +│ ├── agent_routes.py # Routes to manage agents +│ ├── auth_routes.py # Authentication routes (login, registration) +│ ├── chat_routes.py # Routes for chat interactions with agents +│ ├── client_routes.py # Routes to manage clients +│ ├── contact_routes.py # Routes to manage contacts +│ ├── mcp_server_routes.py # Routes to manage MCP servers +│ ├── session_routes.py # Routes to manage chat sessions +│ └── tool_routes.py # Routes to manage tools for agents ├── config/ │ ├── database.py # Database configuration │ └── settings.py # General settings @@ -30,18 +37,21 @@ src/ │ └── models.py # SQLAlchemy models ├── schemas/ │ ├── schemas.py # Main Pydantic schemas +│ ├── chat.py # Chat schemas │ ├── user.py # User and authentication schemas │ └── audit.py # Audit logs schemas ├── services/ │ ├── agent_service.py # Business logic for agents +│ ├── agent_runner.py # Agent execution logic +│ ├── auth_service.py # JWT authentication logic +│ ├── audit_service.py # Audit logs logic │ ├── client_service.py # Business logic for clients │ ├── contact_service.py # Business logic for contacts -│ ├── mcp_server_service.py # Business logic for MCP servers -│ ├── tool_service.py # Business logic for tools -│ ├── user_service.py # User and authentication logic -│ ├── auth_service.py # JWT authentication logic │ ├── email_service.py # Email sending service -│ └── audit_service.py # Audit logs logic +│ ├── mcp_server_service.py # Business logic for MCP servers +│ ├── session_service.py # Business logic for chat sessions +│ ├── tool_service.py # Business logic for tools +│ └── user_service.py # User management logic ├── templates/ │ ├── emails/ │ │ ├── base_email.html # Base template with common structure and styles @@ -49,7 +59,21 @@ src/ │ │ ├── password_reset.html # Password reset template │ │ ├── welcome_email.html # Welcome email after verification │ │ └── account_locked.html # Security alert for locked accounts +├── tests/ +│ ├── __init__.py # Package initialization +│ ├── api/ +│ │ ├── __init__.py # Package initialization +│ │ ├── test_auth_routes.py # Test for authentication routes +│ │ └── test_root.py # Test for root endpoint +│ ├── models/ +│ │ ├── __init__.py # Package initialization +│ │ ├── test_models.py # Test for models +│ ├── services/ +│ │ ├── __init__.py # Package initialization +│ │ ├── test_auth_service.py # Test for authentication service +│ │ └── test_user_service.py # Test for user service └── utils/ + ├── logger.py # Logger configuration └── security.py # Security utilities (JWT, hash) ``` @@ -63,6 +87,15 @@ src/ - Code examples in documentation must be in English - Commit messages must be in English +### Project Configuration +- Dependencies managed in `pyproject.toml` using modern Python packaging standards +- Development dependencies specified as optional dependencies in `pyproject.toml` +- Single source of truth for project metadata in `pyproject.toml` +- Build system configured to use setuptools +- Pytest configuration in `pyproject.toml` under `[tool.pytest.ini_options]` +- Code formatting with Black configured in `pyproject.toml` +- Linting with Flake8 configured in `.flake8` + ### Schemas (Pydantic) - Use `BaseModel` as base for all schemas - Define fields with explicit types @@ -136,6 +169,28 @@ src/ - Indentation with 4 spaces - Maximum of 79 characters per line +## Commit Rules +- Use Conventional Commits format for all commit messages +- Format: `(): ` +- Types: + - `feat`: A new feature + - `fix`: A bug fix + - `docs`: Documentation changes + - `style`: Changes that do not affect code meaning (formatting, etc.) + - `refactor`: Code changes that neither fix a bug nor add a feature + - `perf`: Performance improvements + - `test`: Adding or modifying tests + - `chore`: Changes to build process or auxiliary tools +- Scope is optional and should be the module or component affected +- Description should be concise, in the imperative mood, and not capitalized +- Use body for more detailed explanations if needed +- Reference issues in the footer with `Fixes #123` or `Relates to #123` +- Examples: + - `feat(auth): add password reset functionality` + - `fix(api): correct validation error in client registration` + - `docs: update API documentation for new endpoints` + - `refactor(services): improve error handling in authentication` + ## Best Practices - Always validate input data - Implement appropriate logging @@ -163,6 +218,7 @@ src/ ## Useful Commands - `make run`: Start the server +- `make run-prod`: Start the server in production mode - `make alembic-revision message="description"`: Create new migration - `make alembic-upgrade`: Apply pending migrations - `make alembic-downgrade`: Revert last migration @@ -170,3 +226,8 @@ src/ - `make alembic-reset`: Reset database to initial state - `make alembic-upgrade-cascade`: Force upgrade removing dependencies - `make clear-cache`: Clean project cache +- `make seed-all`: Run all database seeders +- `make lint`: Run linting checks with flake8 +- `make format`: Format code with black +- `make install`: Install project for development +- `make install-dev`: Install project with development dependencies diff --git a/.dockerignore b/.dockerignore index bd636738..3bd4b3f1 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,47 @@ +# Environment and IDE +.venv +venv +.env +.idea +.vscode +__pycache__ +*.pyc +*.pyo +*.pyd +.Python +.pytest_cache +.coverage +htmlcov/ +.tox/ + +# Version control +.git +.github +.gitignore + +# Logs and temp files +logs +*.log +tmp +.DS_Store + +# Docker +.dockerignore +Dockerfile* +docker-compose* + +# Documentation +README.md +LICENSE +docs/ + +# Development tools +tests/ +.flake8 +pyproject.toml +requirements-dev.txt +Makefile + # Ambiente virtual venv/ __pycache__/ diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..ba84055c --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +max-line-length = 88 +exclude = .git,__pycache__,venv,alembic/versions/* +ignore = E203, W503, E501 +per-file-ignores = + __init__.py: F401 + src/models/models.py: E712 + alembic/*: E711,E712,F401 \ No newline at end of file diff --git a/.gitignore b/.gitignore index b927ec57..54401fdc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,10 @@ -# Byte-compiled / optimized / DLL files +# Python __pycache__/ *.py[cod] *$py.class - -# C extensions *.so - -# Distribution / packaging .Python +env/ build/ develop-eggs/ dist/ @@ -19,19 +16,10 @@ lib64/ parts/ sdist/ var/ -wheels/ *.egg-info/ .installed.cfg *.egg -# PyInstaller -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - # Unit test / coverage reports htmlcov/ .tox/ @@ -42,14 +30,57 @@ nosetests.xml coverage.xml *.cover .hypothesis/ +.pytest_cache/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.idea/ +.vscode/ +*.sublime-project +*.sublime-workspace +.DS_Store + +# Logs +logs/ +*.log + +# Database +*.db +*.sqlite +*.sqlite3 +backup/ + +# Local +local_settings.py +local.py + +# Docker +.docker/ + +# Alembic versions +# alembic/versions/ + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt # Translations *.mo *.pot # Django stuff: -*.log -local_settings.py db.sqlite3 # Flask stuff: @@ -77,17 +108,6 @@ celerybeat-schedule # SageMath parsed files *.sage.py -# Environments -.env -.venv -env/ -venv/ -.venv/ -.env/ -ENV/ -env.bak/ -venv.bak/ - # Spyder project settings .spyderproject .spyproject @@ -102,11 +122,8 @@ venv.bak/ .mypy_cache/ # IDE -.idea/ -.vscode/ *.swp *.swo # OS -.DS_Store Thumbs.db \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 007c51a7..f56e756e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-slim +FROM python:3.10-slim # Define o diretório de trabalho WORKDIR /app @@ -15,19 +15,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* -# Copia os arquivos de requisitos -COPY requirements.txt . - -# Instala as dependências -RUN pip install --no-cache-dir -r requirements.txt - -# Copia o código-fonte +# Copy project files COPY . . +# Install dependencies +RUN pip install --no-cache-dir -e . + # Configuração para produção ENV PORT=8000 \ HOST=0.0.0.0 \ DEBUG=false +# Expose port +EXPOSE 8000 + # Define o comando de inicialização CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT \ No newline at end of file diff --git a/Makefile b/Makefile index aec63a95..cd211cb2 100644 --- a/Makefile +++ b/Makefile @@ -1,42 +1,42 @@ -.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs +.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs lint format install install-dev venv -# Comandos do Alembic +# Alembic commands init: alembic init alembics -# make alembic-revision message="descrição da migração" +# make alembic-revision message="migration description" alembic-revision: alembic revision --autogenerate -m "$(message)" -# Comando para atualizar o banco de dados +# Command to update database to latest version alembic-upgrade: alembic upgrade head -# Comando para voltar uma versão +# Command to downgrade one version alembic-downgrade: alembic downgrade -1 -# Comando para rodar o servidor +# Command to run the server run: uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src -# Comando para limpar o cache em todas as pastas do projeto pastas pycache +# Command to run the server in production mode +run-prod: + uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4 + +# Command to clean cache in all project folders clear-cache: rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} + -# Comando para criar uma nova migração +# Command to create a new migration and apply it alembic-migrate: alembic revision --autogenerate -m "$(message)" && alembic upgrade head -# Comando para resetar o banco de dados +# Command to reset the database alembic-reset: alembic downgrade base && alembic upgrade head - -# Comando para forçar upgrade com CASCADE -alembic-upgrade-cascade: - psql -U postgres -d a2a_saas -c "DROP TABLE IF EXISTS events CASCADE; DROP TABLE IF EXISTS sessions CASCADE; DROP TABLE IF EXISTS user_states CASCADE; DROP TABLE IF EXISTS app_states CASCADE;" && alembic upgrade head - -# Comandos para executar seeders + +# Commands to run seeders seed-admin: python -m scripts.seeders.admin_seeder @@ -58,7 +58,7 @@ seed-contacts: seed-all: python -m scripts.run_seeders -# Comandos Docker +# Docker commands docker-build: docker-compose build @@ -72,4 +72,21 @@ docker-logs: docker-compose logs -f docker-seed: - docker-compose exec api python -m scripts.run_seeders \ No newline at end of file + docker-compose exec api python -m scripts.run_seeders + +# Testing, linting and formatting commands +lint: + flake8 src/ tests/ + +format: + black src/ tests/ + +# Virtual environment and installation commands +venv: + python -m venv venv + +install: + pip install -e . + +install-dev: + pip install -e ".[dev]" \ No newline at end of file diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..83263505 --- /dev/null +++ b/conftest.py @@ -0,0 +1,53 @@ +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from src.config.database import Base, get_db +from src.main import app + +# Use in-memory SQLite for tests +SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:" + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture(scope="function") +def db_session(): + """Creates a fresh database session for each test.""" + Base.metadata.create_all(bind=engine) # Create tables + + connection = engine.connect() + transaction = connection.begin() + session = TestingSessionLocal(bind=connection) + + # Use our test database instead of the standard one + def override_get_db(): + try: + yield session + session.commit() + finally: + session.close() + + app.dependency_overrides[get_db] = override_get_db + + yield session # The test will run here + + # Teardown + transaction.rollback() + connection.close() + Base.metadata.drop_all(bind=engine) + app.dependency_overrides.clear() + + +@pytest.fixture(scope="function") +def client(db_session): + """Creates a FastAPI TestClient with database session fixture.""" + with TestClient(app) as test_client: + yield test_client \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..f5a3a8e0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,89 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "evo-ai" +version = "1.0.0" +description = "API for executing AI agents" +readme = "README.md" +authors = [ + {name = "EvoAI Team", email = "admin@evoai.com"} +] +requires-python = ">=3.10" +license = {text = "Proprietary"} +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "License :: Other/Proprietary License", + "Operating System :: OS Independent", +] + +# Main dependencies +dependencies = [ + "fastapi==0.115.12", + "uvicorn==0.34.2", + "pydantic==2.11.3", + "sqlalchemy==2.0.40", + "psycopg2-binary==2.9.10", + "google-cloud-aiplatform==1.90.0", + "python-dotenv==1.1.0", + "google-adk==0.3.0", + "litellm==1.67.4.post1", + "python-multipart==0.0.20", + "alembic==1.15.2", + "asyncpg==0.30.0", + "python-jose==3.4.0", + "passlib==1.7.4", + "sendgrid==6.11.0", + "pydantic-settings==2.9.1", + "fastapi_utils==0.8.0", + "bcrypt==4.3.0", + "jinja2==3.1.6", +] + +[project.optional-dependencies] +dev = [ + "black==25.1.0", + "flake8==7.2.0", + "pytest==8.3.5", + "pytest-cov==6.1.1", + "httpx==0.28.1", + "pytest-asyncio==0.26.0", +] + +[tool.setuptools] +packages = ["src"] + +[tool.black] +line-length = 88 +target-version = ["py310"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | venv + | alembic/versions +)/ +''' + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = "test_*.py" +python_functions = "test_*" +python_classes = "Test*" +filterwarnings = [ + "ignore::DeprecationWarning", +] + +[tool.coverage.run] +source = ["src"] +omit = ["tests/*", "alembic/*"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 3ad02cac..00000000 --- a/requirements.txt +++ /dev/null @@ -1,22 +0,0 @@ -fastapi -uvicorn -pydantic -sqlalchemy -psycopg2 -psycopg2-binary -google-cloud-aiplatform -python-dotenv -google-adk -litellm -python-multipart -alembic -asyncpg -# Novas dependências para autenticação -python-jose[cryptography] -passlib[bcrypt] -sendgrid -pydantic[email] -pydantic-settings -fastapi_utils -bcrypt -jinja2 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..cf008a5d --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +"""Setup script for the package.""" + +from setuptools import setup + +if __name__ == "__main__": + setup() \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index 87228db0..134e123e 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,3 +1,3 @@ """ -Pacote principal da aplicação -""" \ No newline at end of file +Main package of the application +""" diff --git a/src/api/admin_routes.py b/src/api/admin_routes.py index d6818421..a7270450 100644 --- a/src/api/admin_routes.py +++ b/src/api/admin_routes.py @@ -1,14 +1,17 @@ +from typing import List from fastapi import APIRouter, Depends, HTTPException, status, Request from sqlalchemy.orm import Session -from typing import List, Optional -from datetime import datetime import uuid from src.config.database import get_db from src.core.jwt_middleware import get_jwt_token, verify_admin from src.schemas.audit import AuditLogResponse, AuditLogFilter from src.services.audit_service import get_audit_logs, create_audit_log -from src.services.user_service import get_admin_users, create_admin_user, deactivate_user +from src.services.user_service import ( + get_admin_users, + create_admin_user, + deactivate_user, +) from src.schemas.user import UserResponse, AdminUserCreate router = APIRouter( @@ -18,6 +21,7 @@ router = APIRouter( responses={403: {"description": "Permission denied"}}, ) + # Audit routes @router.get("/audit-logs", response_model=List[AuditLogResponse]) async def read_audit_logs( @@ -27,12 +31,12 @@ async def read_audit_logs( ): """ Get audit logs with optional filters - + Args: filters: Filters for log search db: Database session payload: JWT token payload - + Returns: List[AuditLogResponse]: List of audit logs """ @@ -45,9 +49,10 @@ async def read_audit_logs( resource_type=filters.resource_type, resource_id=filters.resource_id, start_date=filters.start_date, - end_date=filters.end_date + end_date=filters.end_date, ) + # Admin routes @router.get("/users", response_model=List[UserResponse]) async def read_admin_users( @@ -58,18 +63,19 @@ async def read_admin_users( ): """ List admin users - + Args: skip: Number of records to skip limit: Maximum number of records to return db: Database session payload: JWT token payload - + Returns: List[UserResponse]: List of admin users """ return get_admin_users(db, skip, limit) + @router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_new_admin_user( user_data: AdminUserCreate, @@ -79,16 +85,16 @@ async def create_new_admin_user( ): """ Create a new admin user - + Args: user_data: User data to be created request: FastAPI Request object db: Database session payload: JWT token payload - + Returns: UserResponse: Created user data - + Raises: HTTPException: If there is an error in creation """ @@ -97,17 +103,14 @@ async def create_new_admin_user( if not user_id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Unable to identify the logged in user" + detail="Unable to identify the logged in user", ) - + # Create admin user user, message = create_admin_user(db, user_data) if not user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=message - ) - + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message) + # Register action in audit log create_audit_log( db, @@ -116,11 +119,12 @@ async def create_new_admin_user( resource_type="admin_user", resource_id=str(user.id), details={"email": user.email}, - request=request + request=request, ) - + return user + @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def deactivate_admin_user( user_id: uuid.UUID, @@ -130,13 +134,13 @@ async def deactivate_admin_user( ): """ Deactivate an admin user (does not delete, only deactivates) - + Args: user_id: ID of the user to be deactivated request: FastAPI Request object db: Database session payload: JWT token payload - + Raises: HTTPException: If there is an error in deactivation """ @@ -145,24 +149,21 @@ async def deactivate_admin_user( if not current_user_id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Unable to identify the logged in user" + detail="Unable to identify the logged in user", ) - + # Do not allow deactivating yourself if str(user_id) == current_user_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Unable to deactivate your own user" + detail="Unable to deactivate your own user", ) - + # Deactivate user success, message = deactivate_user(db, user_id) if not success: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=message - ) - + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message) + # Register action in audit log create_audit_log( db, @@ -171,5 +172,5 @@ async def deactivate_admin_user( resource_type="admin_user", resource_id=str(user_id), details=None, - request=request - ) \ No newline at end of file + request=request, + ) diff --git a/src/api/agent_routes.py b/src/api/agent_routes.py index 91efcb74..25845819 100644 --- a/src/api/agent_routes.py +++ b/src/api/agent_routes.py @@ -7,10 +7,6 @@ from src.core.jwt_middleware import ( get_jwt_token, verify_user_client, ) -from src.core.jwt_middleware import ( - get_jwt_token, - verify_user_client, -) from src.schemas.schemas import ( Agent, AgentCreate, diff --git a/src/api/auth_routes.py b/src/api/auth_routes.py index b5a3b8ce..06615f86 100644 --- a/src/api/auth_routes.py +++ b/src/api/auth_routes.py @@ -162,9 +162,7 @@ async def login_for_access_token(form_data: UserLogin, db: Session = Depends(get """ user = authenticate_user(db, form_data.email, form_data.password) if not user: - logger.warning( - f"Login attempt with invalid credentials: {form_data.email}" - ) + logger.warning(f"Login attempt with invalid credentials: {form_data.email}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password", diff --git a/src/api/chat_routes.py b/src/api/chat_routes.py index e977c275..2b88e871 100644 --- a/src/api/chat_routes.py +++ b/src/api/chat_routes.py @@ -11,7 +11,11 @@ from src.services import ( from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse from src.services.agent_runner import run_agent from src.core.exceptions import AgentNotFoundError -from src.main import session_service, artifacts_service, memory_service +from src.services.service_providers import ( + session_service, + artifacts_service, + memory_service, +) from datetime import datetime import logging @@ -71,4 +75,4 @@ async def chat( except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) \ No newline at end of file + ) diff --git a/src/api/session_routes.py b/src/api/session_routes.py index d68be678..170e4fd9 100644 --- a/src/api/session_routes.py +++ b/src/api/session_routes.py @@ -19,7 +19,7 @@ from src.services.session_service import ( get_sessions_by_agent, get_sessions_by_client, ) -from src.main import session_service +from src.services.service_providers import session_service import logging logger = logging.getLogger(__name__) @@ -30,6 +30,7 @@ router = APIRouter( responses={404: {"description": "Not found"}}, ) + # Session Routes @router.get("/client/{client_id}", response_model=List[Adk_Session]) async def get_client_sessions( diff --git a/src/config/database.py b/src/config/database.py index 96efda8d..2a77cb0e 100644 --- a/src/config/database.py +++ b/src/config/database.py @@ -10,9 +10,10 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + def get_db(): db = SessionLocal() try: yield db finally: - db.close() \ No newline at end of file + db.close() diff --git a/src/core/exceptions.py b/src/core/exceptions.py index 9d553e57..165d618c 100644 --- a/src/core/exceptions.py +++ b/src/core/exceptions.py @@ -1,55 +1,66 @@ from fastapi import HTTPException from typing import Optional, Dict, Any + class BaseAPIException(HTTPException): """Base class for API exceptions""" + def __init__( self, status_code: int, message: str, error_code: str, - details: Optional[Dict[str, Any]] = None + details: Optional[Dict[str, Any]] = None, ): - super().__init__(status_code=status_code, detail={ - "error": message, - "error_code": error_code, - "details": details or {} - }) + super().__init__( + status_code=status_code, + detail={ + "error": message, + "error_code": error_code, + "details": details or {}, + }, + ) + class AgentNotFoundError(BaseAPIException): """Exception when the agent is not found""" + def __init__(self, agent_id: str): super().__init__( status_code=404, message=f"Agent with ID {agent_id} not found", - error_code="AGENT_NOT_FOUND" + error_code="AGENT_NOT_FOUND", ) + class InvalidParameterError(BaseAPIException): """Exception for invalid parameters""" + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): super().__init__( status_code=400, message=message, error_code="INVALID_PARAMETER", - details=details + details=details, ) + class InvalidRequestError(BaseAPIException): """Exception for invalid requests""" + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): super().__init__( status_code=400, message=message, error_code="INVALID_REQUEST", - details=details + details=details, ) + class InternalServerError(BaseAPIException): """Exception for server errors""" + def __init__(self, message: str = "Server error"): super().__init__( - status_code=500, - message=message, - error_code="INTERNAL_SERVER_ERROR" - ) \ No newline at end of file + status_code=500, message=message, error_code="INTERNAL_SERVER_ERROR" + ) diff --git a/src/core/jwt_middleware.py b/src/core/jwt_middleware.py index 977e5a45..918e36ba 100644 --- a/src/core/jwt_middleware.py +++ b/src/core/jwt_middleware.py @@ -13,16 +13,17 @@ logger = logging.getLogger(__name__) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") + async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict: """ Extracts and validates the JWT token - + Args: token: Token JWT - + Returns: dict: Token payload data - + Raises: HTTPException: If the token is invalid """ @@ -31,86 +32,90 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict: detail="Invalid credentials", headers={"WWW-Authenticate": "Bearer"}, ) - + try: payload = jwt.decode( - token, - settings.JWT_SECRET_KEY, - algorithms=[settings.JWT_ALGORITHM] + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] ) - + email: str = payload.get("sub") if email is None: logger.warning("Token without email (sub)") raise credentials_exception - + exp = payload.get("exp") if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow(): logger.warning(f"Token expired for {email}") raise credentials_exception - + return payload - + except JWTError as e: logger.error(f"Error decoding JWT token: {str(e)}") raise credentials_exception + async def verify_user_client( payload: dict = Depends(get_jwt_token), db: Session = Depends(get_db), - required_client_id: UUID = None + required_client_id: UUID = None, ) -> bool: """ Verifies if the user is associated with the specified client - + Args: payload: Token JWT payload db: Database session required_client_id: Client ID to be verified - + Returns: bool: True se verificado com sucesso - + Raises: HTTPException: If the user does not have permission """ # Administrators have access to all clients if payload.get("is_admin", False): return True - + # Para não-admins, verificar se o client_id corresponde user_client_id = payload.get("client_id") if not user_client_id: - logger.warning(f"Non-admin user without client_id in token: {payload.get('sub')}") + logger.warning( + f"Non-admin user without client_id in token: {payload.get('sub')}" + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="User not associated with a client" + detail="User not associated with a client", ) - + # If no client_id is specified to verify, any client is valid if not required_client_id: return True - + # Verify if the user's client_id corresponds to the required_client_id if str(user_client_id) != str(required_client_id): - logger.warning(f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}") + logger.warning( + f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}" + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Access denied to access resources of this client" + detail="Access denied to access resources of this client", ) - + return True + async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool: """ Verifies if the user is an administrator - + Args: payload: Token JWT payload - + Returns: bool: True if the user is an administrator - + Raises: HTTPException: If the user is not an administrator """ @@ -118,26 +123,29 @@ async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool: logger.warning(f"Access denied to admin: User {payload.get('sub')}") raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Access denied. Restricted to administrators." + detail="Access denied. Restricted to administrators.", ) - + return True -def get_current_user_client_id(payload: dict = Depends(get_jwt_token)) -> Optional[UUID]: + +def get_current_user_client_id( + payload: dict = Depends(get_jwt_token), +) -> Optional[UUID]: """ Gets the ID of the client associated with the current user - + Args: payload: Token JWT payload - + Returns: Optional[UUID]: Client ID or None if the user is an administrator """ if payload.get("is_admin", False): return None - + client_id = payload.get("client_id") if client_id: return UUID(client_id) - - return None \ No newline at end of file + + return None diff --git a/src/main.py b/src/main.py index 3f191594..0645f172 100644 --- a/src/main.py +++ b/src/main.py @@ -1,35 +1,30 @@ import os import sys from pathlib import Path - -# Add the root directory to PYTHONPATH -root_dir = Path(__file__).parent.parent -sys.path.append(str(root_dir)) - from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from src.config.database import engine, Base from src.config.settings import settings from src.utils.logger import setup_logger -from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService -from google.adk.sessions import DatabaseSessionService -from google.adk.memory import InMemoryMemoryService -# Initialize service instances -session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING) -artifacts_service = InMemoryArtifactService() -memory_service = InMemoryMemoryService() +# Necessary for other modules +from src.services.service_providers import session_service # noqa: F401 +from src.services.service_providers import artifacts_service # noqa: F401 +from src.services.service_providers import memory_service # noqa: F401 -# Import routers after service initialization to avoid circular imports -from src.api.auth_routes import router as auth_router -from src.api.admin_routes import router as admin_router -from src.api.chat_routes import router as chat_router -from src.api.session_routes import router as session_router -from src.api.agent_routes import router as agent_router -from src.api.contact_routes import router as contact_router -from src.api.mcp_server_routes import router as mcp_server_router -from src.api.tool_routes import router as tool_router -from src.api.client_routes import router as client_router +import src.api.auth_routes +import src.api.admin_routes +import src.api.chat_routes +import src.api.session_routes +import src.api.agent_routes +import src.api.contact_routes +import src.api.mcp_server_routes +import src.api.tool_routes +import src.api.client_routes + +# Add the root directory to PYTHONPATH +root_dir = Path(__file__).parent.parent +sys.path.append(str(root_dir)) # Configure logger logger = setup_logger(__name__) @@ -52,8 +47,7 @@ app.add_middleware( # PostgreSQL configuration POSTGRES_CONNECTION_STRING = os.getenv( - "POSTGRES_CONNECTION_STRING", - "postgresql://postgres:root@localhost:5432/evo_ai" + "POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai" ) # Create database tables @@ -61,6 +55,17 @@ Base.metadata.create_all(bind=engine) API_PREFIX = "/api/v1" +# Define router references +auth_router = src.api.auth_routes.router +admin_router = src.api.admin_routes.router +chat_router = src.api.chat_routes.router +session_router = src.api.session_routes.router +agent_router = src.api.agent_routes.router +contact_router = src.api.contact_routes.router +mcp_server_router = src.api.mcp_server_routes.router +tool_router = src.api.tool_routes.router +client_router = src.api.client_routes.router + # Include routes app.include_router(auth_router, prefix=API_PREFIX) app.include_router(admin_router, prefix=API_PREFIX) @@ -79,5 +84,5 @@ def read_root(): "message": "Welcome to Evo AI API", "documentation": "/docs", "version": settings.API_VERSION, - "auth": "To access the API, use JWT authentication via '/api/v1/auth/login'" + "auth": "To access the API, use JWT authentication via '/api/v1/auth/login'", } diff --git a/src/models/models.py b/src/models/models.py index 45a786af..dc6a1364 100644 --- a/src/models/models.py +++ b/src/models/models.py @@ -1,9 +1,20 @@ -from sqlalchemy import Column, String, UUID, DateTime, ForeignKey, JSON, Text, BigInteger, CheckConstraint, Boolean +from sqlalchemy import ( + Column, + String, + UUID, + DateTime, + ForeignKey, + JSON, + Text, + CheckConstraint, + Boolean, +) from sqlalchemy.sql import func from sqlalchemy.orm import relationship, backref from src.config.database import Base import uuid + class Client(Base): __tablename__ = "clients" @@ -13,13 +24,16 @@ class Client(Base): created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + class User(Base): __tablename__ = "users" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) email = Column(String, unique=True, index=True, nullable=False) password_hash = Column(String, nullable=False) - client_id = Column(UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True) + client_id = Column( + UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True + ) is_active = Column(Boolean, default=False) is_admin = Column(Boolean, default=False) email_verified = Column(Boolean, default=False) @@ -29,9 +43,12 @@ class User(Base): password_reset_expiry = Column(DateTime(timezone=True), nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - + # Relationship with Client (One-to-One, optional for administrators) - client = relationship("Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")) + client = relationship( + "Client", backref=backref("user", uselist=False, cascade="all, delete-orphan") + ) + class Contact(Base): __tablename__ = "contacts" @@ -44,6 +61,7 @@ class Contact(Base): created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + class Agent(Base): __tablename__ = "agents" @@ -60,21 +78,30 @@ class Agent(Base): updated_at = Column(DateTime(timezone=True), onupdate=func.now()) __table_args__ = ( - CheckConstraint("type IN ('llm', 'sequential', 'parallel', 'loop')", name='check_agent_type'), + CheckConstraint( + "type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type" + ), ) def to_dict(self): """Converts the object to a dictionary, converting UUIDs to strings""" result = {} for key, value in self.__dict__.items(): - if key.startswith('_'): + if key.startswith("_"): continue if isinstance(value, uuid.UUID): result[key] = str(value) elif isinstance(value, dict): result[key] = self._convert_dict(value) elif isinstance(value, list): - result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value] + result[key] = [ + ( + self._convert_dict(item) + if isinstance(item, dict) + else str(item) if isinstance(item, uuid.UUID) else item + ) + for item in value + ] else: result[key] = value return result @@ -88,11 +115,19 @@ class Agent(Base): elif isinstance(value, dict): result[key] = self._convert_dict(value) elif isinstance(value, list): - result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value] + result[key] = [ + ( + self._convert_dict(item) + if isinstance(item, dict) + else str(item) if isinstance(item, uuid.UUID) else item + ) + for item in value + ] else: result[key] = value return result + class MCPServer(Base): __tablename__ = "mcp_servers" @@ -105,11 +140,14 @@ class MCPServer(Base): type = Column(String, nullable=False, default="official") created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - + __table_args__ = ( - CheckConstraint("type IN ('official', 'community')", name='check_mcp_server_type'), + CheckConstraint( + "type IN ('official', 'community')", name="check_mcp_server_type" + ), ) + class Tool(Base): __tablename__ = "tools" @@ -121,11 +159,12 @@ class Tool(Base): created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + class Session(Base): __tablename__ = "sessions" # The directive below makes Alembic ignore this table in migrations - __table_args__ = {'extend_existing': True, 'info': {'skip_autogenerate': True}} - + __table_args__ = {"extend_existing": True, "info": {"skip_autogenerate": True}} + id = Column(String, primary_key=True) app_name = Column(String) user_id = Column(String) @@ -133,11 +172,14 @@ class Session(Base): create_time = Column(DateTime(timezone=True)) update_time = Column(DateTime(timezone=True)) + class AuditLog(Base): __tablename__ = "audit_logs" - + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True) + user_id = Column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) action = Column(String, nullable=False) resource_type = Column(String, nullable=False) resource_id = Column(String, nullable=True) @@ -145,6 +187,6 @@ class AuditLog(Base): ip_address = Column(String, nullable=True) user_agent = Column(String, nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) - + # Relationship with User - user = relationship("User", backref="audit_logs") \ No newline at end of file + user = relationship("User", backref="audit_logs") diff --git a/src/schemas/agent_config.py b/src/schemas/agent_config.py index 8a7c36d8..360e34c1 100644 --- a/src/schemas/agent_config.py +++ b/src/schemas/agent_config.py @@ -1,26 +1,38 @@ -from typing import List, Optional, Dict, Any, Union +from typing import List, Optional, Dict, Union from pydantic import BaseModel, Field from uuid import UUID + class ToolConfig(BaseModel): """Configuration of a tool""" + id: UUID - envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the tool") + envs: Dict[str, str] = Field( + default_factory=dict, description="Environment variables of the tool" + ) class Config: from_attributes = True + class MCPServerConfig(BaseModel): """Configuration of an MCP server""" + id: UUID - envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the server") - tools: List[str] = Field(default_factory=list, description="List of tools of the server") + envs: Dict[str, str] = Field( + default_factory=dict, description="Environment variables of the server" + ) + tools: List[str] = Field( + default_factory=list, description="List of tools of the server" + ) class Config: from_attributes = True + class HTTPToolParameter(BaseModel): """Parameter of an HTTP tool""" + type: str required: bool description: str @@ -28,8 +40,10 @@ class HTTPToolParameter(BaseModel): class Config: from_attributes = True + class HTTPToolParameters(BaseModel): """Parameters of an HTTP tool""" + path_params: Optional[Dict[str, str]] = None query_params: Optional[Dict[str, Union[str, List[str]]]] = None body_params: Optional[Dict[str, HTTPToolParameter]] = None @@ -37,8 +51,10 @@ class HTTPToolParameters(BaseModel): class Config: from_attributes = True + class HTTPToolErrorHandling(BaseModel): """Configuration of error handling""" + timeout: int retry_count: int fallback_response: Dict[str, str] @@ -46,8 +62,10 @@ class HTTPToolErrorHandling(BaseModel): class Config: from_attributes = True + class HTTPTool(BaseModel): """Configuration of an HTTP tool""" + name: str method: str values: Dict[str, str] @@ -60,42 +78,72 @@ class HTTPTool(BaseModel): class Config: from_attributes = True + class CustomTools(BaseModel): """Configuration of custom tools""" - http_tools: List[HTTPTool] = Field(default_factory=list, description="List of HTTP tools") + + http_tools: List[HTTPTool] = Field( + default_factory=list, description="List of HTTP tools" + ) class Config: from_attributes = True + class LLMConfig(BaseModel): """Configuration for LLM agents""" - tools: Optional[List[ToolConfig]] = Field(default=None, description="List of available tools") - custom_tools: Optional[CustomTools] = Field(default=None, description="Custom tools") - mcp_servers: Optional[List[MCPServerConfig]] = Field(default=None, description="List of MCP servers") - sub_agents: Optional[List[UUID]] = Field(default=None, description="List of IDs of sub-agents") + + tools: Optional[List[ToolConfig]] = Field( + default=None, description="List of available tools" + ) + custom_tools: Optional[CustomTools] = Field( + default=None, description="Custom tools" + ) + mcp_servers: Optional[List[MCPServerConfig]] = Field( + default=None, description="List of MCP servers" + ) + sub_agents: Optional[List[UUID]] = Field( + default=None, description="List of IDs of sub-agents" + ) class Config: from_attributes = True + class SequentialConfig(BaseModel): """Configuration for sequential agents""" - sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents in execution order") + + sub_agents: List[UUID] = Field( + ..., description="List of IDs of sub-agents in execution order" + ) class Config: from_attributes = True + class ParallelConfig(BaseModel): """Configuration for parallel agents""" - sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for parallel execution") + + sub_agents: List[UUID] = Field( + ..., description="List of IDs of sub-agents for parallel execution" + ) class Config: from_attributes = True + class LoopConfig(BaseModel): """Configuration for loop agents""" - sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for loop execution") - max_iterations: Optional[int] = Field(default=None, description="Maximum number of iterations") - condition: Optional[str] = Field(default=None, description="Condition to stop the loop") + + sub_agents: List[UUID] = Field( + ..., description="List of IDs of sub-agents for loop execution" + ) + max_iterations: Optional[int] = Field( + default=None, description="Maximum number of iterations" + ) + condition: Optional[str] = Field( + default=None, description="Condition to stop the loop" + ) class Config: - from_attributes = True \ No newline at end of file + from_attributes = True diff --git a/src/schemas/audit.py b/src/schemas/audit.py index 60129145..2e6de13c 100644 --- a/src/schemas/audit.py +++ b/src/schemas/audit.py @@ -3,30 +3,38 @@ from typing import Optional, Dict, Any from datetime import datetime from uuid import UUID + class AuditLogBase(BaseModel): """Base schema for audit log""" + action: str resource_type: str resource_id: Optional[str] = None details: Optional[Dict[str, Any]] = None + class AuditLogCreate(AuditLogBase): """Schema for creating audit log""" + pass + class AuditLogResponse(AuditLogBase): """Schema for audit log response""" + id: UUID user_id: Optional[UUID] = None ip_address: Optional[str] = None user_agent: Optional[str] = None created_at: datetime - + class Config: from_attributes = True + class AuditLogFilter(BaseModel): """Schema for audit log search filters""" + user_id: Optional[UUID] = None action: Optional[str] = None resource_type: Optional[str] = None @@ -34,4 +42,4 @@ class AuditLogFilter(BaseModel): start_date: Optional[datetime] = None end_date: Optional[datetime] = None skip: Optional[int] = Field(0, ge=0) - limit: Optional[int] = Field(100, ge=1, le=1000) \ No newline at end of file + limit: Optional[int] = Field(100, ge=1, le=1000) diff --git a/src/schemas/chat.py b/src/schemas/chat.py index 8a84fa40..80f68a10 100644 --- a/src/schemas/chat.py +++ b/src/schemas/chat.py @@ -1,21 +1,33 @@ from pydantic import BaseModel, Field from typing import Dict, Any, Optional + class ChatRequest(BaseModel): """Schema for chat requests""" - agent_id: str = Field(..., description="ID of the agent that will process the message") - contact_id: str = Field(..., description="ID of the contact that will process the message") + + agent_id: str = Field( + ..., description="ID of the agent that will process the message" + ) + contact_id: str = Field( + ..., description="ID of the contact that will process the message" + ) message: str = Field(..., description="User message") + class ChatResponse(BaseModel): """Schema for chat responses""" + response: str = Field(..., description="Agent response") status: str = Field(..., description="Operation status") error: Optional[str] = Field(None, description="Error message, if there is one") timestamp: str = Field(..., description="Timestamp of the response") + class ErrorResponse(BaseModel): """Schema for error responses""" + error: str = Field(..., description="Error message") status_code: int = Field(..., description="HTTP status code of the error") - details: Optional[Dict[str, Any]] = Field(None, description="Additional error details") \ No newline at end of file + details: Optional[Dict[str, Any]] = Field( + None, description="Additional error details" + ) diff --git a/src/schemas/schemas.py b/src/schemas/schemas.py index 007cedc8..d894336b 100644 --- a/src/schemas/schemas.py +++ b/src/schemas/schemas.py @@ -4,15 +4,18 @@ from datetime import datetime from uuid import UUID import uuid import re -from .agent_config import LLMConfig, SequentialConfig, ParallelConfig, LoopConfig +from src.schemas.agent_config import LLMConfig + class ClientBase(BaseModel): name: str email: Optional[EmailStr] = None + class ClientCreate(ClientBase): pass + class Client(ClientBase): id: UUID created_at: datetime @@ -20,14 +23,17 @@ class Client(ClientBase): class Config: from_attributes = True + class ContactBase(BaseModel): ext_id: Optional[str] = None name: Optional[str] = None meta: Optional[Dict[str, Any]] = Field(default_factory=dict) + class ContactCreate(ContactBase): client_id: UUID + class Contact(ContactBase): id: UUID client_id: UUID @@ -35,67 +41,80 @@ class Contact(ContactBase): class Config: from_attributes = True + class AgentBase(BaseModel): name: str = Field(..., description="Agent name (no spaces or special characters)") description: Optional[str] = Field(None, description="Agent description") type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)") - model: Optional[str] = Field(None, description="Agent model (required only for llm type)") - api_key: Optional[str] = Field(None, description="Agent API Key (required only for llm type)") + model: Optional[str] = Field( + None, description="Agent model (required only for llm type)" + ) + api_key: Optional[str] = Field( + None, description="Agent API Key (required only for llm type)" + ) instruction: Optional[str] = None - config: Union[LLMConfig, Dict[str, Any]] = Field(..., description="Agent configuration based on type") + config: Union[LLMConfig, Dict[str, Any]] = Field( + ..., description="Agent configuration based on type" + ) - @validator('name') + @validator("name") def validate_name(cls, v): - if not re.match(r'^[a-zA-Z0-9_-]+$', v): - raise ValueError('Agent name cannot contain spaces or special characters') + if not re.match(r"^[a-zA-Z0-9_-]+$", v): + raise ValueError("Agent name cannot contain spaces or special characters") return v - @validator('type') + @validator("type") def validate_type(cls, v): - if v not in ['llm', 'sequential', 'parallel', 'loop']: - raise ValueError('Invalid agent type. Must be: llm, sequential, parallel or loop') + if v not in ["llm", "sequential", "parallel", "loop"]: + raise ValueError( + "Invalid agent type. Must be: llm, sequential, parallel or loop" + ) return v - @validator('model') + @validator("model") def validate_model(cls, v, values): - if 'type' in values and values['type'] == 'llm' and not v: - raise ValueError('Model is required for llm type agents') + if "type" in values and values["type"] == "llm" and not v: + raise ValueError("Model is required for llm type agents") return v - @validator('api_key') + @validator("api_key") def validate_api_key(cls, v, values): - if 'type' in values and values['type'] == 'llm' and not v: - raise ValueError('API Key is required for llm type agents') + if "type" in values and values["type"] == "llm" and not v: + raise ValueError("API Key is required for llm type agents") return v - @validator('config') + @validator("config") def validate_config(cls, v, values): - if 'type' not in values: + if "type" not in values: return v - - if values['type'] == 'llm': + + if values["type"] == "llm": if isinstance(v, dict): try: # Convert the dictionary to LLMConfig v = LLMConfig(**v) except Exception as e: - raise ValueError(f'Invalid LLM configuration for agent: {str(e)}') + raise ValueError(f"Invalid LLM configuration for agent: {str(e)}") elif not isinstance(v, LLMConfig): - raise ValueError('Invalid LLM configuration for agent') - elif values['type'] in ['sequential', 'parallel', 'loop']: + raise ValueError("Invalid LLM configuration for agent") + elif values["type"] in ["sequential", "parallel", "loop"]: if not isinstance(v, dict): raise ValueError(f'Invalid configuration for agent {values["type"]}') - if 'sub_agents' not in v: + if "sub_agents" not in v: raise ValueError(f'Agent {values["type"]} must have sub_agents') - if not isinstance(v['sub_agents'], list): - raise ValueError('sub_agents must be a list') - if not v['sub_agents']: - raise ValueError(f'Agent {values["type"]} must have at least one sub-agent') + if not isinstance(v["sub_agents"], list): + raise ValueError("sub_agents must be a list") + if not v["sub_agents"]: + raise ValueError( + f'Agent {values["type"]} must have at least one sub-agent' + ) return v + class AgentCreate(AgentBase): client_id: UUID + class Agent(AgentBase): id: UUID client_id: UUID @@ -105,6 +124,7 @@ class Agent(AgentBase): class Config: from_attributes = True + class MCPServerBase(BaseModel): name: str description: Optional[str] = None @@ -113,9 +133,11 @@ class MCPServerBase(BaseModel): tools: List[str] = Field(default_factory=list) type: str = Field(default="official") + class MCPServerCreate(MCPServerBase): pass + class MCPServer(MCPServerBase): id: uuid.UUID created_at: datetime @@ -124,19 +146,22 @@ class MCPServer(MCPServerBase): class Config: from_attributes = True + class ToolBase(BaseModel): name: str description: Optional[str] = None config_json: Dict[str, Any] = Field(default_factory=dict) environments: Dict[str, Any] = Field(default_factory=dict) + class ToolCreate(ToolBase): pass + class Tool(ToolBase): id: uuid.UUID created_at: datetime updated_at: Optional[datetime] = None class Config: - from_attributes = True \ No newline at end of file + from_attributes = True diff --git a/src/schemas/user.py b/src/schemas/user.py index 5a07c5b6..949ff19a 100644 --- a/src/schemas/user.py +++ b/src/schemas/user.py @@ -1,23 +1,28 @@ -from pydantic import BaseModel, EmailStr, Field +from pydantic import BaseModel, EmailStr from typing import Optional from datetime import datetime from uuid import UUID + class UserBase(BaseModel): email: EmailStr + class UserCreate(UserBase): password: str name: str # For client creation + class AdminUserCreate(UserBase): password: str name: str + class UserLogin(BaseModel): email: EmailStr password: str + class UserResponse(UserBase): id: UUID client_id: Optional[UUID] = None @@ -25,26 +30,31 @@ class UserResponse(UserBase): email_verified: bool is_admin: bool created_at: datetime - + class Config: from_attributes = True + class TokenResponse(BaseModel): access_token: str token_type: str - + + class TokenData(BaseModel): sub: str # user email exp: datetime is_admin: bool client_id: Optional[UUID] = None - + + class PasswordReset(BaseModel): token: str new_password: str - + + class ForgotPassword(BaseModel): email: EmailStr + class MessageResponse(BaseModel): - message: str \ No newline at end of file + message: str diff --git a/src/services/__init__.py b/src/services/__init__.py index 3e050588..255943f4 100644 --- a/src/services/__init__.py +++ b/src/services/__init__.py @@ -1 +1 @@ -from .agent_runner import run_agent \ No newline at end of file +from .agent_runner import run_agent diff --git a/src/services/agent_builder.py b/src/services/agent_builder.py index d2dae0d4..9f447153 100644 --- a/src/services/agent_builder.py +++ b/src/services/agent_builder.py @@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext from google.adk.models import LlmResponse, LlmRequest from google.adk.tools import load_memory -from typing import Optional -import logging -import os import requests +import os from datetime import datetime + logger = setup_logger(__name__) @@ -83,7 +82,7 @@ def before_model_callback( llm_request.config.system_instruction = modified_text logger.debug( - f"📝 System instruction updated with search results and history" + "📝 System instruction updated with search results and history" ) else: logger.warning("⚠️ No results found in the search") @@ -180,11 +179,13 @@ class AgentBuilder: mcp_tools = [] mcp_exit_stack = None if agent.config.get("mcp_servers"): - mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(agent.config, self.db) + mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools( + agent.config, self.db + ) # Combine all tools all_tools = custom_tools + mcp_tools - + now = datetime.now() current_datetime = now.strftime("%d/%m/%Y %H:%M") current_day_of_week = now.strftime("%A") @@ -201,10 +202,13 @@ class AgentBuilder: # Check if load_memory is enabled # before_model_callback_func = None - if agent.config.get("load_memory") == True: + if agent.config.get("load_memory"): all_tools.append(load_memory) # before_model_callback_func = before_model_callback - formatted_prompt = formatted_prompt + "\n\nALWAYS use the load_memory tool to retrieve knowledge for your context\n\n" + formatted_prompt = ( + formatted_prompt + + "\n\nALWAYS use the load_memory tool to retrieve knowledge for your context\n\n" + ) return ( LlmAgent( diff --git a/src/services/agent_runner.py b/src/services/agent_runner.py index f81867b8..8df8cbad 100644 --- a/src/services/agent_runner.py +++ b/src/services/agent_runner.py @@ -22,9 +22,7 @@ async def run_agent( db: Session, ): try: - logger.info( - f"Starting execution of agent {agent_id} for contact {contact_id}" - ) + logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}") logger.info(f"Received message: {message}") get_root_agent = get_agent(db, agent_id) @@ -77,15 +75,15 @@ async def run_agent( if event.is_final_response() and event.content and event.content.parts: final_response_text = event.content.parts[0].text logger.info(f"Final response received: {final_response_text}") - + completed_session = session_service.get_session( app_name=agent_id, user_id=contact_id, session_id=session_id, ) - + memory_service.add_session_to_memory(completed_session) - + finally: # Ensure the exit_stack is closed correctly if exit_stack: diff --git a/src/services/agent_service.py b/src/services/agent_service.py index 9e48642c..31793edb 100644 --- a/src/services/agent_service.py +++ b/src/services/agent_service.py @@ -216,9 +216,7 @@ async def update_agent( return agent except Exception as e: db.rollback() - raise HTTPException( - status_code=500, detail=f"Error updating agent: {str(e)}" - ) + raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}") def delete_agent(db: Session, agent_id: uuid.UUID) -> bool: diff --git a/src/services/audit_service.py b/src/services/audit_service.py index ecec8910..2ceeb371 100644 --- a/src/services/audit_service.py +++ b/src/services/audit_service.py @@ -1,15 +1,15 @@ from sqlalchemy.orm import Session from sqlalchemy.exc import SQLAlchemyError -from src.models.models import AuditLog, User +from src.models.models import AuditLog from datetime import datetime from fastapi import Request from typing import Optional, Dict, Any, List import uuid import logging -import json logger = logging.getLogger(__name__) + def create_audit_log( db: Session, user_id: Optional[uuid.UUID], @@ -17,11 +17,11 @@ def create_audit_log( resource_type: str, resource_id: Optional[str] = None, details: Optional[Dict[str, Any]] = None, - request: Optional[Request] = None + request: Optional[Request] = None, ) -> Optional[AuditLog]: """ Create a new audit log - + Args: db: Database session user_id: User ID that performed the action (or None if anonymous) @@ -30,25 +30,25 @@ def create_audit_log( resource_id: Resource ID (optional) details: Additional details of the action (optional) request: FastAPI Request object (optional, to get IP and User-Agent) - + Returns: Optional[AuditLog]: Created audit log or None in case of error """ try: ip_address = None user_agent = None - + if request: - ip_address = request.client.host if hasattr(request, 'client') else None + ip_address = request.client.host if hasattr(request, "client") else None user_agent = request.headers.get("user-agent") - + # Convert details to serializable format if details: # Convert UUIDs to strings for key, value in details.items(): if isinstance(value, uuid.UUID): details[key] = str(value) - + audit_log = AuditLog( user_id=user_id, action=action, @@ -56,20 +56,20 @@ def create_audit_log( resource_id=str(resource_id) if resource_id else None, details=details, ip_address=ip_address, - user_agent=user_agent + user_agent=user_agent, ) - + db.add(audit_log) db.commit() db.refresh(audit_log) - + logger.info( - f"Audit log created: {action} in {resource_type}" + - (f" (ID: {resource_id})" if resource_id else "") + f"Audit log created: {action} in {resource_type}" + + (f" (ID: {resource_id})" if resource_id else "") ) - + return audit_log - + except SQLAlchemyError as e: db.rollback() logger.error(f"Error creating audit log: {str(e)}") @@ -78,6 +78,7 @@ def create_audit_log( logger.error(f"Unexpected error creating audit log: {str(e)}") return None + def get_audit_logs( db: Session, skip: int = 0, @@ -87,11 +88,11 @@ def get_audit_logs( resource_type: Optional[str] = None, resource_id: Optional[str] = None, start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None + end_date: Optional[datetime] = None, ) -> List[AuditLog]: """ Get audit logs with optional filters - + Args: db: Database session skip: Number of records to skip @@ -102,35 +103,35 @@ def get_audit_logs( resource_id: Filter by resource ID start_date: Start date end_date: End date - + Returns: List[AuditLog]: List of audit logs """ query = db.query(AuditLog) - + # Apply filters, if provided if user_id: query = query.filter(AuditLog.user_id == user_id) - + if action: query = query.filter(AuditLog.action == action) - + if resource_type: query = query.filter(AuditLog.resource_type == resource_type) - + if resource_id: query = query.filter(AuditLog.resource_id == resource_id) - + if start_date: query = query.filter(AuditLog.created_at >= start_date) - + if end_date: query = query.filter(AuditLog.created_at <= end_date) - + # Order by creation date (most recent first) query = query.order_by(AuditLog.created_at.desc()) - + # Apply pagination query = query.offset(skip).limit(limit) - - return query.all() \ No newline at end of file + + return query.all() diff --git a/src/services/auth_service.py b/src/services/auth_service.py index b1800e3a..63f3b376 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -16,17 +16,20 @@ logger = logging.getLogger(__name__) # Define OAuth2 authentication scheme with password flow oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") -async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User: + +async def get_current_user( + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) +) -> User: """ Get the current user from the JWT token - + Args: token: JWT token db: Database session - + Returns: User: Current user - + Raises: HTTPException: If the token is invalid or the user is not found """ @@ -35,103 +38,108 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De detail="Invalid credentials", headers={"WWW-Authenticate": "Bearer"}, ) - + try: # Decode the token payload = jwt.decode( - token, - settings.JWT_SECRET_KEY, - algorithms=[settings.JWT_ALGORITHM] + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] ) - + # Extract token data email: str = payload.get("sub") if email is None: logger.warning("Token without email (sub)") raise credentials_exception - + # Check if the token has expired exp = payload.get("exp") if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow(): logger.warning(f"Token expired for {email}") raise credentials_exception - + # Create TokenData object token_data = TokenData( sub=email, exp=datetime.fromtimestamp(exp), is_admin=payload.get("is_admin", False), - client_id=payload.get("client_id") + client_id=payload.get("client_id"), ) - + except JWTError as e: logger.error(f"Error decoding JWT token: {str(e)}") raise credentials_exception - + # Search for user in the database user = get_user_by_email(db, email=token_data.sub) if user is None: logger.warning(f"User not found for email: {token_data.sub}") raise credentials_exception - + if not user.is_active: logger.warning(f"Attempt to access inactive user: {user.email}") raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Inactive user" + status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" ) - + return user -async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User: + +async def get_current_active_user( + current_user: User = Depends(get_current_user), +) -> User: """ Check if the current user is active - + Args: current_user: Current user - + Returns: User: Current user if active - + Raises: HTTPException: If the user is not active """ if not current_user.is_active: logger.warning(f"Attempt to access inactive user: {current_user.email}") raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Inactive user" + status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" ) return current_user -async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User: + +async def get_current_admin_user( + current_user: User = Depends(get_current_user), +) -> User: """ Check if the current user is an administrator - + Args: current_user: Current user - + Returns: User: Current user if administrator - + Raises: HTTPException: If the user is not an administrator """ if not current_user.is_admin: - logger.warning(f"Attempt to access admin by non-admin user: {current_user.email}") + logger.warning( + f"Attempt to access admin by non-admin user: {current_user.email}" + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Access denied. Restricted to administrators." + detail="Access denied. Restricted to administrators.", ) return current_user + def create_access_token(user: User) -> str: """ Create a JWT access token for the user - + Args: user: User for which to create the token - + Returns: str: JWT token """ @@ -140,10 +148,10 @@ def create_access_token(user: User) -> str: "sub": user.email, "is_admin": user.is_admin, } - + # Include client_id only if not administrator and client_id is set if not user.is_admin and user.client_id: token_data["client_id"] = str(user.client_id) - + # Create token - return create_jwt_token(token_data) \ No newline at end of file + return create_jwt_token(token_data) diff --git a/src/services/client_service.py b/src/services/client_service.py index 609b00fe..0045db6f 100644 --- a/src/services/client_service.py +++ b/src/services/client_service.py @@ -11,6 +11,7 @@ import logging logger = logging.getLogger(__name__) + def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]: """Search for a client by ID""" try: @@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]: logger.error(f"Error searching for client {client_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for client" + detail="Error searching for client", ) + def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]: """Search for all clients with pagination""" try: @@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]: logger.error(f"Error searching for clients: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for clients" + detail="Error searching for clients", ) + def create_client(db: Session, client: ClientCreate) -> Client: """Create a new client""" try: @@ -51,19 +54,22 @@ def create_client(db: Session, client: ClientCreate) -> Client: logger.error(f"Error creating client: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error creating client" + detail="Error creating client", ) -def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Optional[Client]: + +def update_client( + db: Session, client_id: uuid.UUID, client: ClientCreate +) -> Optional[Client]: """Updates an existing client""" try: db_client = get_client(db, client_id) if not db_client: return None - + for key, value in client.model_dump().items(): setattr(db_client, key, value) - + db.commit() db.refresh(db_client) logger.info(f"Client updated successfully: {client_id}") @@ -73,16 +79,17 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op logger.error(f"Error updating client {client_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error updating client" + detail="Error updating client", ) + def delete_client(db: Session, client_id: uuid.UUID) -> bool: """Removes a client""" try: db_client = get_client(db, client_id) if not db_client: return False - + db.delete(db_client) db.commit() logger.info(f"Client removed successfully: {client_id}") @@ -92,18 +99,21 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool: logger.error(f"Error removing client {client_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error removing client" + detail="Error removing client", ) -def create_client_with_user(db: Session, client_data: ClientCreate, user_data: UserCreate) -> Tuple[Optional[Client], str]: + +def create_client_with_user( + db: Session, client_data: ClientCreate, user_data: UserCreate +) -> Tuple[Optional[Client], str]: """ Creates a new client with an associated user - + Args: db: Database session client_data: Client data to be created user_data: User data to be created - + Returns: Tuple[Optional[Client], str]: Tuple with the created client (or None in case of error) and status message """ @@ -112,27 +122,27 @@ def create_client_with_user(db: Session, client_data: ClientCreate, user_data: U client = Client(**client_data.model_dump()) db.add(client) db.flush() # Get client ID without committing the transaction - + # Use client ID to create the associated user user, message = create_user(db, user_data, is_admin=False, client_id=client.id) - + if not user: # If there was an error creating the user, rollback db.rollback() logger.error(f"Error creating user for client: {message}") return None, f"Error creating user: {message}" - + # If everything went well, commit the transaction db.commit() logger.info(f"Client and user created successfully: {client.id}") return client, "Client and user created successfully" - + except SQLAlchemyError as e: db.rollback() logger.error(f"Error creating client with user: {str(e)}") return None, f"Error creating client with user: {str(e)}" - + except Exception as e: db.rollback() logger.error(f"Unexpected error creating client with user: {str(e)}") - return None, f"Unexpected error: {str(e)}" \ No newline at end of file + return None, f"Unexpected error: {str(e)}" diff --git a/src/services/contact_service.py b/src/services/contact_service.py index 7378a81c..61defd76 100644 --- a/src/services/contact_service.py +++ b/src/services/contact_service.py @@ -9,6 +9,7 @@ import logging logger = logging.getLogger(__name__) + def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]: """Search for a contact by ID""" try: @@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]: logger.error(f"Error searching for contact {contact_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for contact" + detail="Error searching for contact", ) -def get_contacts_by_client(db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[Contact]: + +def get_contacts_by_client( + db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100 +) -> List[Contact]: """Search for contacts of a client with pagination""" try: - return db.query(Contact).filter(Contact.client_id == client_id).offset(skip).limit(limit).all() + return ( + db.query(Contact) + .filter(Contact.client_id == client_id) + .offset(skip) + .limit(limit) + .all() + ) except SQLAlchemyError as e: logger.error(f"Error searching for contacts of client {client_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for contacts" + detail="Error searching for contacts", ) + def create_contact(db: Session, contact: ContactCreate) -> Contact: """Create a new contact""" try: @@ -49,19 +60,22 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact: logger.error(f"Error creating contact: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error creating contact" + detail="Error creating contact", ) -def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -> Optional[Contact]: + +def update_contact( + db: Session, contact_id: uuid.UUID, contact: ContactCreate +) -> Optional[Contact]: """Update an existing contact""" try: db_contact = get_contact(db, contact_id) if not db_contact: return None - + for key, value in contact.model_dump().items(): setattr(db_contact, key, value) - + db.commit() db.refresh(db_contact) logger.info(f"Contact updated successfully: {contact_id}") @@ -71,16 +85,17 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) - logger.error(f"Error updating contact {contact_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error updating contact" + detail="Error updating contact", ) + def delete_contact(db: Session, contact_id: uuid.UUID) -> bool: """Remove a contact""" try: db_contact = get_contact(db, contact_id) if not db_contact: return False - + db.delete(db_contact) db.commit() logger.info(f"Contact removed successfully: {contact_id}") @@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool: logger.error(f"Error removing contact {contact_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error removing contact" - ) \ No newline at end of file + detail="Error removing contact", + ) diff --git a/src/services/custom_tools.py b/src/services/custom_tools.py index 5b11cdb2..2e2dc605 100644 --- a/src/services/custom_tools.py +++ b/src/services/custom_tools.py @@ -6,6 +6,7 @@ from src.utils.logger import setup_logger logger = setup_logger(__name__) + class CustomToolBuilder: def __init__(self): self.tools = [] @@ -53,7 +54,9 @@ class CustomToolBuilder: # Adds default values to query params if they are not present for param, value in values.items(): - if param not in query_params and param not in parameters.get("path_params", {}): + if param not in query_params and param not in parameters.get( + "path_params", {} + ): query_params[param] = value # Processa body parameters @@ -64,7 +67,11 @@ class CustomToolBuilder: # Adds default values to body if they are not present for param, value in values.items(): - if param not in body_data and param not in query_params and param not in parameters.get("path_params", {}): + if ( + param not in body_data + and param not in query_params + and param not in parameters.get("path_params", {}) + ): body_data[param] = value # Makes the HTTP request @@ -74,7 +81,7 @@ class CustomToolBuilder: headers=processed_headers, params=query_params, json=body_data if body_data else None, - timeout=error_handling.get("timeout", 30) + timeout=error_handling.get("timeout", 30), ) if response.status_code >= 400: @@ -87,30 +94,34 @@ class CustomToolBuilder: except Exception as e: logger.error(f"Error executing tool {name}: {str(e)}") - return json.dumps(error_handling.get("fallback_response", { - "error": "tool_execution_error", - "message": str(e) - })) + return json.dumps( + error_handling.get( + "fallback_response", + {"error": "tool_execution_error", "message": str(e)}, + ) + ) # Adds dynamic docstring based on the configuration param_docs = [] - + # Adds path parameters for param, value in parameters.get("path_params", {}).items(): param_docs.append(f"{param}: {value}") - + # Adds query parameters for param, value in parameters.get("query_params", {}).items(): if isinstance(value, list): param_docs.append(f"{param}: List[{', '.join(value)}]") else: param_docs.append(f"{param}: {value}") - + # Adds body parameters for param, param_config in parameters.get("body_params", {}).items(): required = "Required" if param_config.get("required", False) else "Optional" - param_docs.append(f"{param} ({param_config['type']}, {required}): {param_config['description']}") - + param_docs.append( + f"{param} ({param_config['type']}, {required}): {param_config['description']}" + ) + # Adds default values if values: param_docs.append("\nDefault values:") @@ -119,10 +130,10 @@ class CustomToolBuilder: http_tool.__doc__ = f""" {description} - + Parameters: {chr(10).join(param_docs)} - + Returns: String containing the response in JSON format """ @@ -140,4 +151,4 @@ class CustomToolBuilder: for http_tool_config in tools_config.get("http_tools", []): self.tools.append(self._create_http_tool(http_tool_config)) - return self.tools \ No newline at end of file + return self.tools diff --git a/src/services/email_service.py b/src/services/email_service.py index af703e03..6544908f 100644 --- a/src/services/email_service.py +++ b/src/services/email_service.py @@ -16,17 +16,18 @@ os.makedirs(templates_dir, exist_ok=True) # Configure Jinja2 with the templates directory env = Environment( loader=FileSystemLoader(templates_dir), - autoescape=select_autoescape(['html', 'xml']) + autoescape=select_autoescape(["html", "xml"]), ) + def _render_template(template_name: str, context: dict) -> str: """ Render a template with the provided data - + Args: template_name: Template file name context: Data to render in the template - + Returns: str: Rendered HTML """ @@ -37,14 +38,15 @@ def _render_template(template_name: str, context: dict) -> str: logger.error(f"Error rendering template '{template_name}': {str(e)}") return f"

Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}

" + def send_verification_email(email: str, token: str) -> bool: """ Send a verification email to the user - + Args: email: Recipient's email token: Email verification token - + Returns: bool: True if the email was sent successfully, False otherwise """ @@ -53,39 +55,47 @@ def send_verification_email(email: str, token: str) -> bool: from_email = Email(settings.EMAIL_FROM) to_email = To(email) subject = "Email Verification - Evo AI" - + verification_link = f"{settings.APP_URL}/auth/verify-email/{token}" - - html_content = _render_template('verification_email', { - 'verification_link': verification_link, - 'user_name': email.split('@')[0], # Use part of the email as temporary name - 'current_year': datetime.now().year - }) - + + html_content = _render_template( + "verification_email", + { + "verification_link": verification_link, + "user_name": email.split("@")[ + 0 + ], # Use part of the email as temporary name + "current_year": datetime.now().year, + }, + ) + content = Content("text/html", html_content) - + mail = Mail(from_email, to_email, subject, content) response = sg.client.mail.send.post(request_body=mail.get()) - + if response.status_code >= 200 and response.status_code < 300: logger.info(f"Verification email sent to {email}") return True else: - logger.error(f"Failed to send verification email to {email}. Status: {response.status_code}") + logger.error( + f"Failed to send verification email to {email}. Status: {response.status_code}" + ) return False - + except Exception as e: logger.error(f"Error sending verification email to {email}: {str(e)}") return False + def send_password_reset_email(email: str, token: str) -> bool: """ Send a password reset email to the user - + Args: email: Recipient's email token: Password reset token - + Returns: bool: True if the email was sent successfully, False otherwise """ @@ -94,39 +104,47 @@ def send_password_reset_email(email: str, token: str) -> bool: from_email = Email(settings.EMAIL_FROM) to_email = To(email) subject = "Password Reset - Evo AI" - + reset_link = f"{settings.APP_URL}/reset-password?token={token}" - - html_content = _render_template('password_reset', { - 'reset_link': reset_link, - 'user_name': email.split('@')[0], # Use part of the email as temporary name - 'current_year': datetime.now().year - }) - + + html_content = _render_template( + "password_reset", + { + "reset_link": reset_link, + "user_name": email.split("@")[ + 0 + ], # Use part of the email as temporary name + "current_year": datetime.now().year, + }, + ) + content = Content("text/html", html_content) - + mail = Mail(from_email, to_email, subject, content) response = sg.client.mail.send.post(request_body=mail.get()) - + if response.status_code >= 200 and response.status_code < 300: logger.info(f"Password reset email sent to {email}") return True else: - logger.error(f"Failed to send password reset email to {email}. Status: {response.status_code}") + logger.error( + f"Failed to send password reset email to {email}. Status: {response.status_code}" + ) return False - + except Exception as e: logger.error(f"Error sending password reset email to {email}: {str(e)}") return False + def send_welcome_email(email: str, user_name: str = None) -> bool: """ Send a welcome email to the user after verification - + Args: email: Recipient's email user_name: User's name (optional) - + Returns: bool: True if the email was sent successfully, False otherwise """ @@ -135,41 +153,49 @@ def send_welcome_email(email: str, user_name: str = None) -> bool: from_email = Email(settings.EMAIL_FROM) to_email = To(email) subject = "Welcome to Evo AI" - + dashboard_link = f"{settings.APP_URL}/dashboard" - - html_content = _render_template('welcome_email', { - 'dashboard_link': dashboard_link, - 'user_name': user_name or email.split('@')[0], - 'current_year': datetime.now().year - }) - + + html_content = _render_template( + "welcome_email", + { + "dashboard_link": dashboard_link, + "user_name": user_name or email.split("@")[0], + "current_year": datetime.now().year, + }, + ) + content = Content("text/html", html_content) - + mail = Mail(from_email, to_email, subject, content) response = sg.client.mail.send.post(request_body=mail.get()) - + if response.status_code >= 200 and response.status_code < 300: logger.info(f"Welcome email sent to {email}") return True else: - logger.error(f"Failed to send welcome email to {email}. Status: {response.status_code}") + logger.error( + f"Failed to send welcome email to {email}. Status: {response.status_code}" + ) return False - + except Exception as e: logger.error(f"Error sending welcome email to {email}: {str(e)}") return False -def send_account_locked_email(email: str, reset_token: str, failed_attempts: int, time_period: str) -> bool: + +def send_account_locked_email( + email: str, reset_token: str, failed_attempts: int, time_period: str +) -> bool: """ Send an email informing that the account has been locked after login attempts - + Args: email: Recipient's email reset_token: Token to reset the password failed_attempts: Number of failed attempts time_period: Time period of the attempts - + Returns: bool: True if the email was sent successfully, False otherwise """ @@ -178,29 +204,34 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int from_email = Email(settings.EMAIL_FROM) to_email = To(email) subject = "Security Alert - Account Locked" - + reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}" - - html_content = _render_template('account_locked', { - 'reset_link': reset_link, - 'user_name': email.split('@')[0], - 'failed_attempts': failed_attempts, - 'time_period': time_period, - 'current_year': datetime.now().year - }) - + + html_content = _render_template( + "account_locked", + { + "reset_link": reset_link, + "user_name": email.split("@")[0], + "failed_attempts": failed_attempts, + "time_period": time_period, + "current_year": datetime.now().year, + }, + ) + content = Content("text/html", html_content) - + mail = Mail(from_email, to_email, subject, content) response = sg.client.mail.send.post(request_body=mail.get()) - + if response.status_code >= 200 and response.status_code < 300: logger.info(f"Account locked email sent to {email}") return True else: - logger.error(f"Failed to send account locked email to {email}. Status: {response.status_code}") + logger.error( + f"Failed to send account locked email to {email}. Status: {response.status_code}" + ) return False - + except Exception as e: logger.error(f"Error sending account locked email to {email}: {str(e)}") - return False \ No newline at end of file + return False diff --git a/src/services/mcp_server_service.py b/src/services/mcp_server_service.py index 2f89abaa..cc9f54bf 100644 --- a/src/services/mcp_server_service.py +++ b/src/services/mcp_server_service.py @@ -9,6 +9,7 @@ import logging logger = logging.getLogger(__name__) + def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]: """Search for an MCP server by ID""" try: @@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]: logger.error(f"Error searching for MCP server {server_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for MCP server" + detail="Error searching for MCP server", ) + def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]: """Search for all MCP servers with pagination""" try: @@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer logger.error(f"Error searching for MCP servers: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for MCP servers" + detail="Error searching for MCP servers", ) + def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer: """Create a new MCP server""" try: @@ -49,19 +52,22 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer: logger.error(f"Error creating MCP server: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error creating MCP server" + detail="Error creating MCP server", ) -def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate) -> Optional[MCPServer]: + +def update_mcp_server( + db: Session, server_id: uuid.UUID, server: MCPServerCreate +) -> Optional[MCPServer]: """Update an existing MCP server""" try: db_server = get_mcp_server(db, server_id) if not db_server: return None - + for key, value in server.model_dump().items(): setattr(db_server, key, value) - + db.commit() db.refresh(db_server) logger.info(f"MCP server updated successfully: {server_id}") @@ -71,16 +77,17 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate logger.error(f"Error updating MCP server {server_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error updating MCP server" + detail="Error updating MCP server", ) + def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool: """Remove an MCP server""" try: db_server = get_mcp_server(db, server_id) if not db_server: return False - + db.delete(db_server) db.commit() logger.info(f"MCP server removed successfully: {server_id}") @@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool: logger.error(f"Error removing MCP server {server_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error removing MCP server" - ) \ No newline at end of file + detail="Error removing MCP server", + ) diff --git a/src/services/mcp_service.py b/src/services/mcp_service.py index c65a68e7..eaccfd2e 100644 --- a/src/services/mcp_service.py +++ b/src/services/mcp_service.py @@ -12,26 +12,28 @@ from sqlalchemy.orm import Session logger = setup_logger(__name__) + class MCPService: def __init__(self): self.tools = [] self.exit_stack = AsyncExitStack() - async def _connect_to_mcp_server(self, server_config: Dict[str, Any]) -> Tuple[List[Any], Optional[AsyncExitStack]]: + async def _connect_to_mcp_server( + self, server_config: Dict[str, Any] + ) -> Tuple[List[Any], Optional[AsyncExitStack]]: """Connect to a specific MCP server and return its tools.""" try: # Determines the type of server (local or remote) if "url" in server_config: # Remote server (SSE) connection_params = SseServerParams( - url=server_config["url"], - headers=server_config.get("headers", {}) + url=server_config["url"], headers=server_config.get("headers", {}) ) else: # Local server (Stdio) command = server_config.get("command", "npx") args = server_config.get("args", []) - + # Adds environment variables if specified env = server_config.get("env", {}) if env: @@ -39,9 +41,7 @@ class MCPService: os.environ[key] = value connection_params = StdioServerParameters( - command=command, - args=args, - env=env + command=command, args=args, env=env ) tools, exit_stack = await MCPToolset.from_server( @@ -73,8 +73,10 @@ class MCPService: logger.warning(f"Removed {removed_count} incompatible tools.") return filtered_tools - - def _filter_tools_by_agent(self, tools: List[Any], agent_tools: List[str]) -> List[Any]: + + def _filter_tools_by_agent( + self, tools: List[Any], agent_tools: List[str] + ) -> List[Any]: """Filters tools compatible with the agent.""" filtered_tools = [] for tool in tools: @@ -83,7 +85,9 @@ class MCPService: filtered_tools.append(tool) return filtered_tools - async def build_tools(self, mcp_config: Dict[str, Any], db: Session) -> Tuple[List[Any], AsyncExitStack]: + async def build_tools( + self, mcp_config: Dict[str, Any], db: Session + ) -> Tuple[List[Any], AsyncExitStack]: """Builds a list of tools from multiple MCP servers.""" self.tools = [] self.exit_stack = AsyncExitStack() @@ -92,23 +96,25 @@ class MCPService: for server in mcp_config.get("mcp_servers", []): try: # Search for the MCP server in the database - mcp_server = get_mcp_server(db, server['id']) + mcp_server = get_mcp_server(db, server["id"]) if not mcp_server: logger.warning(f"Servidor MCP não encontrado: {server['id']}") continue # Prepares the server configuration server_config = mcp_server.config_json.copy() - + # Replaces the environment variables in the config_json - if 'env' in server_config: - for key, value in server_config['env'].items(): - if value.startswith('env@@'): - env_key = value.replace('env@@', '') - if env_key in server.get('envs', {}): - server_config['env'][key] = server['envs'][env_key] + if "env" in server_config: + for key, value in server_config["env"].items(): + if value.startswith("env@@"): + env_key = value.replace("env@@", "") + if env_key in server.get("envs", {}): + server_config["env"][key] = server["envs"][env_key] else: - logger.warning(f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}") + logger.warning( + f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}" + ) continue logger.info(f"Connecting to MCP server: {mcp_server.name}") @@ -117,22 +123,30 @@ class MCPService: if tools and exit_stack: # Filters incompatible tools filtered_tools = self._filter_incompatible_tools(tools) - + # Filters tools compatible with the agent - agent_tools = server.get('tools', []) - filtered_tools = self._filter_tools_by_agent(filtered_tools, agent_tools) + agent_tools = server.get("tools", []) + filtered_tools = self._filter_tools_by_agent( + filtered_tools, agent_tools + ) self.tools.extend(filtered_tools) - + # Registers the exit_stack with the AsyncExitStack await self.exit_stack.enter_async_context(exit_stack) - logger.info(f"Connected successfully. Added {len(filtered_tools)} tools.") + logger.info( + f"Connected successfully. Added {len(filtered_tools)} tools." + ) else: - logger.warning(f"Failed to connect or no tools available for {mcp_server.name}") + logger.warning( + f"Failed to connect or no tools available for {mcp_server.name}" + ) except Exception as e: logger.error(f"Error connecting to MCP server {server['id']}: {e}") continue - logger.info(f"MCP Toolset created successfully. Total of {len(self.tools)} tools.") + logger.info( + f"MCP Toolset created successfully. Total of {len(self.tools)} tools." + ) - return self.tools, self.exit_stack \ No newline at end of file + return self.tools, self.exit_stack diff --git a/src/services/service_providers.py b/src/services/service_providers.py new file mode 100644 index 00000000..56cd33c3 --- /dev/null +++ b/src/services/service_providers.py @@ -0,0 +1,9 @@ +from src.config.settings import settings +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.sessions import DatabaseSessionService +from google.adk.memory import InMemoryMemoryService + +# Initialize service instances +session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING) +artifacts_service = InMemoryArtifactService() +memory_service = InMemoryMemoryService() diff --git a/src/services/session_service.py b/src/services/session_service.py index f4095677..cb14bfef 100644 --- a/src/services/session_service.py +++ b/src/services/session_service.py @@ -66,7 +66,7 @@ def get_session_by_id( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid session ID. Expected format: app_name_user_id", ) - + parts = session_id.split("_", 1) if len(parts) != 2: logger.error(f"Invalid session ID format: {session_id}") @@ -74,22 +74,22 @@ def get_session_by_id( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid session ID format. Expected format: app_name_user_id", ) - + user_id, app_name = parts - + session = session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id, ) - + if session is None: logger.error(f"Session not found: {session_id}") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Session not found: {session_id}", ) - + return session except Exception as e: logger.error(f"Error searching for session {session_id}: {str(e)}") @@ -106,7 +106,7 @@ def delete_session(session_service: DatabaseSessionService, session_id: str) -> try: session = get_session_by_id(session_service, session_id) # If we get here, the session exists (get_session_by_id already validates) - + session_service.delete_session( app_name=session.app_name, user_id=session.user_id, @@ -131,10 +131,10 @@ def get_session_events( try: session = get_session_by_id(session_service, session_id) # If we get here, the session exists (get_session_by_id already validates) - - if not hasattr(session, 'events') or session.events is None: + + if not hasattr(session, "events") or session.events is None: return [] - + return session.events except HTTPException: # Passes HTTP exceptions from get_session_by_id diff --git a/src/services/tool_service.py b/src/services/tool_service.py index 00af61c6..d4e3ef3f 100644 --- a/src/services/tool_service.py +++ b/src/services/tool_service.py @@ -9,6 +9,7 @@ import logging logger = logging.getLogger(__name__) + def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]: """Search for a tool by ID""" try: @@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]: logger.error(f"Error searching for tool {tool_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for tool" + detail="Error searching for tool", ) + def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]: """Search for all tools with pagination""" try: @@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]: logger.error(f"Error searching for tools: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error searching for tools" + detail="Error searching for tools", ) + def create_tool(db: Session, tool: ToolCreate) -> Tool: """Creates a new tool""" try: @@ -49,19 +52,20 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool: logger.error(f"Error creating tool: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error creating tool" + detail="Error creating tool", ) + def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]: """Updates an existing tool""" try: db_tool = get_tool(db, tool_id) if not db_tool: return None - + for key, value in tool.model_dump().items(): setattr(db_tool, key, value) - + db.commit() db.refresh(db_tool) logger.info(f"Tool updated successfully: {tool_id}") @@ -71,16 +75,17 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T logger.error(f"Error updating tool {tool_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error updating tool" + detail="Error updating tool", ) + def delete_tool(db: Session, tool_id: uuid.UUID) -> bool: """Remove a tool""" try: db_tool = get_tool(db, tool_id) if not db_tool: return False - + db.delete(db_tool) db.commit() logger.info(f"Tool removed successfully: {tool_id}") @@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool: logger.error(f"Error removing tool {tool_id}: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error removing tool" - ) \ No newline at end of file + detail="Error removing tool", + ) diff --git a/src/services/user_service.py b/src/services/user_service.py index c87ead6f..1d622eef 100644 --- a/src/services/user_service.py +++ b/src/services/user_service.py @@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError from src.models.models import User, Client from src.schemas.user import UserCreate from src.utils.security import get_password_hash, verify_password, generate_token -from src.services.email_service import send_verification_email, send_password_reset_email +from src.services.email_service import ( + send_verification_email, + send_password_reset_email, +) from datetime import datetime, timedelta import uuid import logging @@ -11,16 +14,22 @@ from typing import Optional, Tuple logger = logging.getLogger(__name__) -def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, client_id: Optional[uuid.UUID] = None) -> Tuple[Optional[User], str]: + +def create_user( + db: Session, + user_data: UserCreate, + is_admin: bool = False, + client_id: Optional[uuid.UUID] = None, +) -> Tuple[Optional[User], str]: """ - Creates a new user in the system - + Creates a new user in the system + Args: db: Database session user_data: User data to be created is_admin: If the user is an administrator client_id: Associated client ID (optional, a new one will be created if not provided) - + Returns: Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message """ @@ -28,17 +37,19 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie # Check if email already exists db_user = db.query(User).filter(User.email == user_data.email).first() if db_user: - logger.warning(f"Attempt to register with existing email: {user_data.email}") + logger.warning( + f"Attempt to register with existing email: {user_data.email}" + ) return None, "Email already registered" - + # Create verification token verification_token = generate_token() token_expiry = datetime.utcnow() + timedelta(hours=24) - + # Start transaction user = None local_client_id = client_id - + try: # If not admin and no client_id, create an associated client if not is_admin and local_client_id is None: @@ -46,7 +57,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie db.add(client) db.flush() # Get the client ID local_client_id = client.id - + # Create user user = User( email=user_data.email, @@ -56,52 +67,56 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie is_active=False, # Inactive until email is verified email_verified=False, verification_token=verification_token, - verification_token_expiry=token_expiry + verification_token_expiry=token_expiry, ) db.add(user) db.commit() - + # Send verification email email_sent = send_verification_email(user.email, verification_token) if not email_sent: logger.error(f"Failed to send verification email to {user.email}") # We don't do rollback here, we just log the error - + logger.info(f"User created successfully: {user.email}") - return user, "User created successfully. Check your email to activate your account." - + return ( + user, + "User created successfully. Check your email to activate your account.", + ) + except SQLAlchemyError as e: db.rollback() logger.error(f"Error creating user: {str(e)}") return None, f"Error creating user: {str(e)}" - + except Exception as e: logger.error(f"Unexpected error creating user: {str(e)}") return None, f"Unexpected error: {str(e)}" + def verify_email(db: Session, token: str) -> Tuple[bool, str]: """ Verify the user's email using the provided token - + Args: db: Database session token: Verification token - + Returns: Tuple[bool, str]: Tuple with verification status and message """ try: # Search for user by token user = db.query(User).filter(User.verification_token == token).first() - + if not user: logger.warning(f"Attempt to verify with invalid token: {token}") return False, "Invalid verification token" - + # Check if the token has expired now = datetime.utcnow() expiry = user.verification_token_expiry - + # Ensure both dates are of the same type (aware or naive) if expiry.tzinfo is not None and now.tzinfo is None: # If expiry has timezone and now doesn't, convert now to have timezone @@ -109,180 +124,201 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]: elif now.tzinfo is not None and expiry.tzinfo is None: # If now has timezone and expiry doesn't, convert expiry to have timezone expiry = expiry.replace(tzinfo=now.tzinfo) - + if expiry < now: - logger.warning(f"Attempt to verify with expired token for user: {user.email}") + logger.warning( + f"Attempt to verify with expired token for user: {user.email}" + ) return False, "Verification token expired" - + # Update user user.email_verified = True user.is_active = True user.verification_token = None user.verification_token_expiry = None - + db.commit() logger.info(f"Email verified successfully for user: {user.email}") return True, "Email verified successfully. Your account is active." - + except SQLAlchemyError as e: db.rollback() logger.error(f"Error verifying email: {str(e)}") return False, f"Error verifying email: {str(e)}" - + except Exception as e: logger.error(f"Unexpected error verifying email: {str(e)}") return False, f"Unexpected error: {str(e)}" + def resend_verification(db: Session, email: str) -> Tuple[bool, str]: """ Resend the verification email - + Args: db: Database session email: User email - + Returns: Tuple[bool, str]: Tuple with operation status and message """ try: # Search for user by email user = db.query(User).filter(User.email == email).first() - + if not user: - logger.warning(f"Attempt to resend verification email for non-existent email: {email}") + logger.warning( + f"Attempt to resend verification email for non-existent email: {email}" + ) return False, "Email not found" - + if user.email_verified: - logger.info(f"Attempt to resend verification email for already verified email: {email}") + logger.info( + f"Attempt to resend verification email for already verified email: {email}" + ) return False, "Email already verified" - + # Generate new token verification_token = generate_token() token_expiry = datetime.utcnow() + timedelta(hours=24) - + # Update user user.verification_token = verification_token user.verification_token_expiry = token_expiry - + db.commit() - + # Send email email_sent = send_verification_email(user.email, verification_token) if not email_sent: logger.error(f"Failed to resend verification email to {user.email}") return False, "Failed to send verification email" - + logger.info(f"Verification email resent successfully to: {user.email}") return True, "Verification email resent. Check your inbox." - + except SQLAlchemyError as e: db.rollback() logger.error(f"Error resending verification: {str(e)}") return False, f"Error resending verification: {str(e)}" - + except Exception as e: logger.error(f"Unexpected error resending verification: {str(e)}") return False, f"Unexpected error: {str(e)}" + def forgot_password(db: Session, email: str) -> Tuple[bool, str]: """ Initiates the password recovery process - + Args: db: Database session email: User email - + Returns: Tuple[bool, str]: Tuple with operation status and message """ try: # Search for user by email user = db.query(User).filter(User.email == email).first() - + if not user: # For security, we don't inform if the email exists or not logger.info(f"Attempt to recover password for non-existent email: {email}") - return True, "If the email is registered, you will receive instructions to reset your password." - + return ( + True, + "If the email is registered, you will receive instructions to reset your password.", + ) + # Generate reset token reset_token = generate_token() token_expiry = datetime.utcnow() + timedelta(hours=1) # Token valid for 1 hour - + # Update user user.password_reset_token = reset_token user.password_reset_expiry = token_expiry - + db.commit() - + # Send email email_sent = send_password_reset_email(user.email, reset_token) if not email_sent: logger.error(f"Failed to send password reset email to {user.email}") return False, "Failed to send password reset email" - + logger.info(f"Password reset email sent successfully to: {user.email}") - return True, "If the email is registered, you will receive instructions to reset your password." - + return ( + True, + "If the email is registered, you will receive instructions to reset your password.", + ) + except SQLAlchemyError as e: db.rollback() logger.error(f"Error processing password recovery: {str(e)}") return False, f"Error processing password recovery: {str(e)}" - + except Exception as e: logger.error(f"Unexpected error processing password recovery: {str(e)}") return False, f"Unexpected error: {str(e)}" + def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]: """ Resets the user's password using the provided token - + Args: db: Database session token: Password reset token new_password: New password - + Returns: Tuple[bool, str]: Tuple with operation status and message """ try: # Search for user by token user = db.query(User).filter(User.password_reset_token == token).first() - + if not user: logger.warning(f"Attempt to reset password with invalid token: {token}") return False, "Invalid password reset token" - + # Check if the token has expired if user.password_reset_expiry < datetime.utcnow(): - logger.warning(f"Attempt to reset password with expired token for user: {user.email}") + logger.warning( + f"Attempt to reset password with expired token for user: {user.email}" + ) return False, "Password reset token expired" - + # Update password user.password_hash = get_password_hash(new_password) user.password_reset_token = None user.password_reset_expiry = None - + db.commit() logger.info(f"Password reset successfully for user: {user.email}") - return True, "Password reset successfully. You can now login with your new password." - + return ( + True, + "Password reset successfully. You can now login with your new password.", + ) + except SQLAlchemyError as e: db.rollback() logger.error(f"Error resetting password: {str(e)}") return False, f"Error resetting password: {str(e)}" - + except Exception as e: logger.error(f"Unexpected error resetting password: {str(e)}") return False, f"Unexpected error: {str(e)}" + def get_user_by_email(db: Session, email: str) -> Optional[User]: """ Searches for a user by email - + Args: db: Database session email: User email - + Returns: Optional[User]: User found or None """ @@ -292,15 +328,16 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]: logger.error(f"Error searching for user by email: {str(e)}") return None + def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: """ Authenticates a user with email and password - + Args: db: Database session email: User email password: User password - + Returns: Optional[User]: Authenticated user or None """ @@ -313,75 +350,78 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: return None return user + def get_admin_users(db: Session, skip: int = 0, limit: int = 100): """ Lists the admin users - + Args: db: Database session skip: Number of records to skip limit: Maximum number of records to return - + Returns: List[User]: List of admin users """ try: - users = db.query(User).filter(User.is_admin == True).offset(skip).limit(limit).all() + users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all() logger.info(f"List of admins: {len(users)} found") return users - + except SQLAlchemyError as e: logger.error(f"Error listing admins: {str(e)}") return [] - + except Exception as e: logger.error(f"Unexpected error listing admins: {str(e)}") return [] + def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]: """ Creates a new admin user - + Args: db: Database session user_data: User data to be created - + Returns: Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message """ return create_user(db, user_data, is_admin=True) + def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]: """ Deactivates a user (does not delete, only marks as inactive) - + Args: db: Database session user_id: ID of the user to be deactivated - + Returns: Tuple[bool, str]: Tuple with operation status and message """ try: # Search for user by ID user = db.query(User).filter(User.id == user_id).first() - + if not user: logger.warning(f"Attempt to deactivate non-existent user: {user_id}") return False, "User not found" - + # Deactivate user user.is_active = False - + db.commit() logger.info(f"User deactivated successfully: {user.email}") return True, "User deactivated successfully" - + except SQLAlchemyError as e: db.rollback() logger.error(f"Error deactivating user: {str(e)}") return False, f"Error deactivating user: {str(e)}" - + except Exception as e: logger.error(f"Unexpected error deactivating user: {str(e)}") - return False, f"Unexpected error: {str(e)}" \ No newline at end of file + return False, f"Unexpected error: {str(e)}" diff --git a/src/utils/logger.py b/src/utils/logger.py index 879e47a4..fac7b36f 100644 --- a/src/utils/logger.py +++ b/src/utils/logger.py @@ -3,23 +3,26 @@ import os import sys from src.config.settings import settings + class CustomFormatter(logging.Formatter): """Custom formatter for logs""" - + grey = "\x1b[38;20m" yellow = "\x1b[33;20m" red = "\x1b[31;20m" bold_red = "\x1b[31;1m" reset = "\x1b[0m" - - format_template = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" + + format_template = ( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" + ) FORMATS = { logging.DEBUG: grey + format_template + reset, logging.INFO: grey + format_template + reset, logging.WARNING: yellow + format_template + reset, logging.ERROR: red + format_template + reset, - logging.CRITICAL: bold_red + format_template + reset + logging.CRITICAL: bold_red + format_template + reset, } def format(self, record): @@ -27,33 +30,34 @@ class CustomFormatter(logging.Formatter): formatter = logging.Formatter(log_fmt) return formatter.format(record) + def setup_logger(name: str) -> logging.Logger: """ Configures a custom logger - + Args: name: Logger name - + Returns: logging.Logger: Logger configurado """ logger = logging.getLogger(name) - + # Remove existing handlers to avoid duplication if logger.handlers: logger.handlers.clear() - + # Configure the logger level based on the environment variable or configuration log_level = getattr(logging, os.getenv("LOG_LEVEL", settings.LOG_LEVEL).upper()) logger.setLevel(log_level) - + # Console handler console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(CustomFormatter()) console_handler.setLevel(log_level) logger.addHandler(console_handler) - + # Prevent logs from being propagated to the root logger logger.propagate = False - - return logger \ No newline at end of file + + return logger diff --git a/src/utils/security.py b/src/utils/security.py index 1f77016f..1d804ed4 100644 --- a/src/utils/security.py +++ b/src/utils/security.py @@ -11,41 +11,44 @@ from dataclasses import dataclass logger = logging.getLogger(__name__) # Fix bcrypt error with passlib -if not hasattr(bcrypt, '__about__'): +if not hasattr(bcrypt, "__about__"): + @dataclass class BcryptAbout: __version__: str = getattr(bcrypt, "__version__") - + setattr(bcrypt, "__about__", BcryptAbout()) # Context for password hashing using bcrypt pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + def get_password_hash(password: str) -> str: """Creates a password hash""" return pwd_context.hash(password) + def verify_password(plain_password: str, hashed_password: str) -> bool: """Verifies if the provided password matches the stored hash""" return pwd_context.verify(plain_password, hashed_password) + def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str: """Creates a JWT token""" to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta else: - expire = datetime.utcnow() + timedelta( - minutes=settings.JWT_EXPIRATION_TIME - ) + expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRATION_TIME) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode( to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM ) return encoded_jwt + def generate_token(length: int = 32) -> str: """Generates a secure token for email verification or password reset""" alphabet = string.ascii_letters + string.digits - token = ''.join(secrets.choice(alphabet) for _ in range(length)) - return token \ No newline at end of file + token = "".join(secrets.choice(alphabet) for _ in range(length)) + return token diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..2c854ab4 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Package initialization for tests diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 00000000..42b18b86 --- /dev/null +++ b/tests/api/__init__.py @@ -0,0 +1 @@ +# API tests package diff --git a/tests/api/test_root.py b/tests/api/test_root.py new file mode 100644 index 00000000..95d42b0f --- /dev/null +++ b/tests/api/test_root.py @@ -0,0 +1,11 @@ +def test_read_root(client): + """ + Test that the root endpoint returns the correct response. + """ + response = client.get("/") + assert response.status_code == 200 + data = response.json() + assert "message" in data + assert "documentation" in data + assert "version" in data + assert "auth" in data diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 00000000..20068376 --- /dev/null +++ b/tests/services/__init__.py @@ -0,0 +1 @@ +# Services tests package diff --git a/tests/services/test_auth_service.py b/tests/services/test_auth_service.py new file mode 100644 index 00000000..5015c31b --- /dev/null +++ b/tests/services/test_auth_service.py @@ -0,0 +1,27 @@ +from src.services.auth_service import create_access_token +from src.models.models import User +import uuid + + +def test_create_access_token(): + """ + Test that an access token is created with the correct data. + """ + # Create a mock user + user = User( + id=uuid.uuid4(), + email="test@example.com", + hashed_password="hashed_password", + is_active=True, + is_admin=False, + name="Test User", + client_id=uuid.uuid4(), + ) + + # Create token + token = create_access_token(user) + + # Simple validation + assert token is not None + assert isinstance(token, str) + assert len(token) > 0