chore: update project structure and add testing framework

Author: Davidson Gomes
Date: 2025-04-28 20:41:10 -03:00
Parent: 7af234ef48
Commit: e7e030dfd5
49 changed files with 1261 additions and 619 deletions

@@ -17,9 +17,16 @@
 ```
 src/
 ├── api/
-│   ├── routes.py              # API routes definition
-│   ├── auth_routes.py         # Authentication routes (login, registration, etc.)
-│   └── admin_routes.py        # Protected admin routes
+│   ├── __init__.py            # Package initialization
+│   ├── admin_routes.py        # Admin routes for management interface
+│   ├── agent_routes.py        # Routes to manage agents
+│   ├── auth_routes.py         # Authentication routes (login, registration)
+│   ├── chat_routes.py         # Routes for chat interactions with agents
+│   ├── client_routes.py       # Routes to manage clients
+│   ├── contact_routes.py      # Routes to manage contacts
+│   ├── mcp_server_routes.py   # Routes to manage MCP servers
+│   ├── session_routes.py      # Routes to manage chat sessions
+│   └── tool_routes.py         # Routes to manage tools for agents
 ├── config/
 │   ├── database.py            # Database configuration
 │   └── settings.py            # General settings
@@ -30,18 +37,21 @@ src/
 │   └── models.py              # SQLAlchemy models
 ├── schemas/
 │   ├── schemas.py             # Main Pydantic schemas
+│   ├── chat.py                # Chat schemas
 │   ├── user.py                # User and authentication schemas
 │   └── audit.py               # Audit logs schemas
 ├── services/
 │   ├── agent_service.py       # Business logic for agents
+│   ├── agent_runner.py        # Agent execution logic
+│   ├── auth_service.py        # JWT authentication logic
+│   ├── audit_service.py       # Audit logs logic
 │   ├── client_service.py      # Business logic for clients
 │   ├── contact_service.py     # Business logic for contacts
-│   ├── mcp_server_service.py  # Business logic for MCP servers
-│   ├── tool_service.py        # Business logic for tools
-│   ├── user_service.py        # User and authentication logic
-│   ├── auth_service.py        # JWT authentication logic
 │   ├── email_service.py       # Email sending service
-│   └── audit_service.py       # Audit logs logic
+│   ├── mcp_server_service.py  # Business logic for MCP servers
+│   ├── session_service.py     # Business logic for chat sessions
+│   ├── tool_service.py        # Business logic for tools
+│   └── user_service.py        # User management logic
 ├── templates/
 │   ├── emails/
 │   │   ├── base_email.html    # Base template with common structure and styles
@@ -49,7 +59,21 @@ src/
 │   │   ├── password_reset.html   # Password reset template
 │   │   ├── welcome_email.html    # Welcome email after verification
 │   │   └── account_locked.html   # Security alert for locked accounts
+├── tests/
+│   ├── __init__.py            # Package initialization
+│   ├── api/
+│   │   ├── __init__.py        # Package initialization
+│   │   ├── test_auth_routes.py   # Test for authentication routes
+│   │   └── test_root.py       # Test for root endpoint
+│   ├── models/
+│   │   ├── __init__.py        # Package initialization
+│   │   └── test_models.py     # Test for models
+│   ├── services/
+│   │   ├── __init__.py        # Package initialization
+│   │   ├── test_auth_service.py  # Test for authentication service
+│   │   └── test_user_service.py  # Test for user service
 └── utils/
+    ├── logger.py              # Logger configuration
     └── security.py            # Security utilities (JWT, hash)
 ```
@@ -63,6 +87,15 @@ src/
 - Code examples in documentation must be in English
 - Commit messages must be in English
 
+### Project Configuration
+- Dependencies managed in `pyproject.toml` using modern Python packaging standards
+- Development dependencies specified as optional dependencies in `pyproject.toml`
+- Single source of truth for project metadata in `pyproject.toml`
+- Build system configured to use setuptools
+- Pytest configuration in `pyproject.toml` under `[tool.pytest.ini_options]`
+- Code formatting with Black configured in `pyproject.toml`
+- Linting with Flake8 configured in `.flake8`
+
 ### Schemas (Pydantic)
 - Use `BaseModel` as base for all schemas
 - Define fields with explicit types
@@ -136,6 +169,28 @@ src/
 - Indentation with 4 spaces
 - Maximum of 79 characters per line
 
+## Commit Rules
+- Use Conventional Commits format for all commit messages
+- Format: `<type>(<scope>): <description>`
+- Types:
+  - `feat`: A new feature
+  - `fix`: A bug fix
+  - `docs`: Documentation changes
+  - `style`: Changes that do not affect code meaning (formatting, etc.)
+  - `refactor`: Code changes that neither fix a bug nor add a feature
+  - `perf`: Performance improvements
+  - `test`: Adding or modifying tests
+  - `chore`: Changes to build process or auxiliary tools
+- Scope is optional and should be the module or component affected
+- Description should be concise, in the imperative mood, and not capitalized
+- Use body for more detailed explanations if needed
+- Reference issues in the footer with `Fixes #123` or `Relates to #123`
+- Examples:
+  - `feat(auth): add password reset functionality`
+  - `fix(api): correct validation error in client registration`
+  - `docs: update API documentation for new endpoints`
+  - `refactor(services): improve error handling in authentication`
+
 ## Best Practices
 - Always validate input data
 - Implement appropriate logging
@@ -163,6 +218,7 @@ src/
 ## Useful Commands
 - `make run`: Start the server
+- `make run-prod`: Start the server in production mode
 - `make alembic-revision message="description"`: Create new migration
 - `make alembic-upgrade`: Apply pending migrations
 - `make alembic-downgrade`: Revert last migration
@@ -170,3 +226,8 @@ src/
 - `make alembic-reset`: Reset database to initial state
 - `make alembic-upgrade-cascade`: Force upgrade removing dependencies
 - `make clear-cache`: Clean project cache
+- `make seed-all`: Run all database seeders
+- `make lint`: Run linting checks with flake8
+- `make format`: Format code with black
+- `make install`: Install project for development
+- `make install-dev`: Install project with development dependencies

@@ -1,3 +1,47 @@
+# Environment and IDE
+.venv
+venv
+.env
+.idea
+.vscode
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+.pytest_cache
+.coverage
+htmlcov/
+.tox/
+
+# Version control
+.git
+.github
+.gitignore
+
+# Logs and temp files
+logs
+*.log
+tmp
+.DS_Store
+
+# Docker
+.dockerignore
+Dockerfile*
+docker-compose*
+
+# Documentation
+README.md
+LICENSE
+docs/
+
+# Development tools
+tests/
+.flake8
+pyproject.toml
+requirements-dev.txt
+Makefile
+
 # Ambiente virtual
 venv/
 __pycache__/

.flake8 (new file, 8 lines)
@@ -0,0 +1,8 @@
+[flake8]
+max-line-length = 88
+exclude = .git,__pycache__,venv,alembic/versions/*
+ignore = E203, W503, E501
+per-file-ignores =
+    __init__.py: F401
+    src/models/models.py: E712
+    alembic/*: E711,E712,F401

.gitignore (77 lines)
@@ -1,13 +1,10 @@
-# Byte-compiled / optimized / DLL files
+# Python
 __pycache__/
 *.py[cod]
 *$py.class
 
-# C extensions
 *.so
 
-# Distribution / packaging
 .Python
-env/
 build/
 develop-eggs/
 dist/
@@ -19,19 +16,10 @@ lib64/
 parts/
 sdist/
 var/
-wheels/
 *.egg-info/
 .installed.cfg
 *.egg
 
-# PyInstaller
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
 # Unit test / coverage reports
 htmlcov/
 .tox/
@@ -42,14 +30,57 @@ nosetests.xml
 coverage.xml
 *.cover
 .hypothesis/
+.pytest_cache/
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# IDE
+.idea/
+.vscode/
+*.sublime-project
+*.sublime-workspace
+.DS_Store
+
+# Logs
+logs/
+*.log
+
+# Database
+*.db
+*.sqlite
+*.sqlite3
+backup/
+
+# Local
+local_settings.py
+local.py
+
+# Docker
+.docker/
+
+# Alembic versions
+# alembic/versions/
+
+# PyInstaller
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
 
 # Translations
 *.mo
 *.pot
 
 # Django stuff:
-*.log
-local_settings.py
 db.sqlite3
 
 # Flask stuff:
@@ -77,17 +108,6 @@ celerybeat-schedule
 # SageMath parsed files
 *.sage.py
 
-# Environments
-.env
-.venv
-env/
-venv/
-.venv/
-.env/
-ENV/
-env.bak/
-venv.bak/
-
 # Spyder project settings
 .spyderproject
 .spyproject
@@ -102,11 +122,8 @@ venv.bak/
 .mypy_cache/
 
 # IDE
-.idea/
-.vscode/
 *.swp
 *.swo
 
 # OS
-.DS_Store
 Thumbs.db

@@ -1,4 +1,4 @@
-FROM python:3.11-slim
+FROM python:3.10-slim
 
 # Define o diretório de trabalho
 WORKDIR /app
@@ -15,19 +15,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 
-# Copia os arquivos de requisitos
-COPY requirements.txt .
-
-# Instala as dependências
-RUN pip install --no-cache-dir -r requirements.txt
-
-# Copia o código-fonte
+# Copy project files
 COPY . .
 
+# Install dependencies
+RUN pip install --no-cache-dir -e .
+
 # Configuração para produção
 ENV PORT=8000 \
     HOST=0.0.0.0 \
     DEBUG=false
 
+# Expose port
+EXPOSE 8000
+
 # Define o comando de inicialização
 CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT

@@ -1,42 +1,42 @@
-.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs
+.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs lint format install install-dev venv
 
-# Comandos do Alembic
+# Alembic commands
 init:
 	alembic init alembics
 
-# make alembic-revision message="descrição da migração"
+# make alembic-revision message="migration description"
 alembic-revision:
 	alembic revision --autogenerate -m "$(message)"
 
-# Comando para atualizar o banco de dados
+# Command to update database to latest version
 alembic-upgrade:
 	alembic upgrade head
 
-# Comando para voltar uma versão
+# Command to downgrade one version
 alembic-downgrade:
 	alembic downgrade -1
 
-# Comando para rodar o servidor
+# Command to run the server
 run:
 	uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src
 
-# Comando para limpar o cache em todas as pastas do projeto pastas pycache
+# Command to run the server in production mode
+run-prod:
+	uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4
+
+# Command to clean cache in all project folders
 clear-cache:
 	rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} +
 
-# Comando para criar uma nova migração
+# Command to create a new migration and apply it
 alembic-migrate:
 	alembic revision --autogenerate -m "$(message)" && alembic upgrade head
 
-# Comando para resetar o banco de dados
+# Command to reset the database
 alembic-reset:
 	alembic downgrade base && alembic upgrade head
 
-# Comando para forçar upgrade com CASCADE
-alembic-upgrade-cascade:
-	psql -U postgres -d a2a_saas -c "DROP TABLE IF EXISTS events CASCADE; DROP TABLE IF EXISTS sessions CASCADE; DROP TABLE IF EXISTS user_states CASCADE; DROP TABLE IF EXISTS app_states CASCADE;" && alembic upgrade head
-
-# Comandos para executar seeders
+# Commands to run seeders
 seed-admin:
 	python -m scripts.seeders.admin_seeder
@@ -58,7 +58,7 @@ seed-contacts:
 seed-all:
 	python -m scripts.run_seeders
 
-# Comandos Docker
+# Docker commands
 docker-build:
 	docker-compose build
@@ -73,3 +73,20 @@ docker-logs:
 docker-seed:
 	docker-compose exec api python -m scripts.run_seeders
+
+# Testing, linting and formatting commands
+lint:
+	flake8 src/ tests/
+
+format:
+	black src/ tests/
+
+# Virtual environment and installation commands
+venv:
+	python -m venv venv
+
+install:
+	pip install -e .
+
+install-dev:
+	pip install -e ".[dev]"

conftest.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import pytest
+from fastapi.testclient import TestClient
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import StaticPool
+
+from src.config.database import Base, get_db
+from src.main import app
+
+# Use in-memory SQLite for tests
+SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
+
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URL,
+    connect_args={"check_same_thread": False},
+    poolclass=StaticPool,
+)
+TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+
+@pytest.fixture(scope="function")
+def db_session():
+    """Creates a fresh database session for each test."""
+    Base.metadata.create_all(bind=engine)  # Create tables
+    connection = engine.connect()
+    transaction = connection.begin()
+    session = TestingSessionLocal(bind=connection)
+
+    # Use our test database instead of the standard one
+    def override_get_db():
+        try:
+            yield session
+            session.commit()
+        finally:
+            session.close()
+
+    app.dependency_overrides[get_db] = override_get_db
+
+    yield session  # The test will run here
+
+    # Teardown
+    transaction.rollback()
+    connection.close()
+    Base.metadata.drop_all(bind=engine)
+    app.dependency_overrides.clear()
+
+
+@pytest.fixture(scope="function")
+def client(db_session):
+    """Creates a FastAPI TestClient with database session fixture."""
+    with TestClient(app) as test_client:
+        yield test_client

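With these fixtures in place, tests in the new tests/ tree can exercise the API against the in-memory SQLite database. The snippet below is only an illustrative sketch of how the `client` fixture is used, not the actual contents of `tests/api/test_root.py`, and it assumes the root route from `src/main.py` is mounted at `/` and returns the welcome payload shown later in this commit.

```python
# Illustrative sketch - relies on the client fixture defined in conftest.py
def test_read_root(client):
    response = client.get("/")
    assert response.status_code == 200
    body = response.json()
    # Keys taken from the read_root() payload in src/main.py
    assert body["documentation"] == "/docs"
    assert "message" in body and "version" in body
```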
pyproject.toml (new file, 89 lines)
@@ -0,0 +1,89 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "evo-ai"
+version = "1.0.0"
+description = "API for executing AI agents"
+readme = "README.md"
+authors = [
+    {name = "EvoAI Team", email = "admin@evoai.com"}
+]
+requires-python = ">=3.10"
+license = {text = "Proprietary"}
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "License :: Other/Proprietary License",
+    "Operating System :: OS Independent",
+]
+
+# Main dependencies
+dependencies = [
+    "fastapi==0.115.12",
+    "uvicorn==0.34.2",
+    "pydantic==2.11.3",
+    "sqlalchemy==2.0.40",
+    "psycopg2-binary==2.9.10",
+    "google-cloud-aiplatform==1.90.0",
+    "python-dotenv==1.1.0",
+    "google-adk==0.3.0",
+    "litellm==1.67.4.post1",
+    "python-multipart==0.0.20",
+    "alembic==1.15.2",
+    "asyncpg==0.30.0",
+    "python-jose==3.4.0",
+    "passlib==1.7.4",
+    "sendgrid==6.11.0",
+    "pydantic-settings==2.9.1",
+    "fastapi_utils==0.8.0",
+    "bcrypt==4.3.0",
+    "jinja2==3.1.6",
+]
+
+[project.optional-dependencies]
+dev = [
+    "black==25.1.0",
+    "flake8==7.2.0",
+    "pytest==8.3.5",
+    "pytest-cov==6.1.1",
+    "httpx==0.28.1",
+    "pytest-asyncio==0.26.0",
+]
+
+[tool.setuptools]
+packages = ["src"]
+
+[tool.black]
+line-length = 88
+target-version = ["py310"]
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+  | venv
+  | alembic/versions
+)/
+'''
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = "test_*.py"
+python_functions = "test_*"
+python_classes = "Test*"
+filterwarnings = [
+    "ignore::DeprecationWarning",
+]
+
+[tool.coverage.run]
+source = ["src"]
+omit = ["tests/*", "alembic/*"]

@@ -1,22 +0,0 @@
-fastapi
-uvicorn
-pydantic
-sqlalchemy
-psycopg2
-psycopg2-binary
-google-cloud-aiplatform
-python-dotenv
-google-adk
-litellm
-python-multipart
-alembic
-asyncpg
-# Novas dependências para autenticação
-python-jose[cryptography]
-passlib[bcrypt]
-sendgrid
-pydantic[email]
-pydantic-settings
-fastapi_utils
-bcrypt
-jinja2

setup.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+"""Setup script for the package."""
+
+from setuptools import setup
+
+if __name__ == "__main__":
+    setup()

@@ -1,3 +1,3 @@
 """
-Pacote principal da aplicação
+Main package of the application
 """

@@ -1,14 +1,17 @@
+from typing import List
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from sqlalchemy.orm import Session
-from typing import List, Optional
-from datetime import datetime
 import uuid
 
 from src.config.database import get_db
 from src.core.jwt_middleware import get_jwt_token, verify_admin
 from src.schemas.audit import AuditLogResponse, AuditLogFilter
 from src.services.audit_service import get_audit_logs, create_audit_log
-from src.services.user_service import get_admin_users, create_admin_user, deactivate_user
+from src.services.user_service import (
+    get_admin_users,
+    create_admin_user,
+    deactivate_user,
+)
 from src.schemas.user import UserResponse, AdminUserCreate
 
 router = APIRouter(
@@ -18,6 +21,7 @@ router = APIRouter(
     responses={403: {"description": "Permission denied"}},
 )
 
+
 # Audit routes
 @router.get("/audit-logs", response_model=List[AuditLogResponse])
 async def read_audit_logs(
@@ -45,9 +49,10 @@ async def read_audit_logs(
         resource_type=filters.resource_type,
         resource_id=filters.resource_id,
         start_date=filters.start_date,
-        end_date=filters.end_date
+        end_date=filters.end_date,
     )
 
+
 # Admin routes
 @router.get("/users", response_model=List[UserResponse])
 async def read_admin_users(
@@ -70,6 +75,7 @@ async def read_admin_users(
     """
     return get_admin_users(db, skip, limit)
 
+
 @router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
 async def create_new_admin_user(
     user_data: AdminUserCreate,
@@ -97,16 +103,13 @@ async def create_new_admin_user(
     if not user_id:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Unable to identify the logged in user"
+            detail="Unable to identify the logged in user",
         )
 
     # Create admin user
     user, message = create_admin_user(db, user_data)
     if not user:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=message
-        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
 
     # Register action in audit log
     create_audit_log(
@@ -116,11 +119,12 @@ async def create_new_admin_user(
         resource_type="admin_user",
         resource_id=str(user.id),
         details={"email": user.email},
-        request=request
+        request=request,
     )
 
     return user
 
+
 @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
 async def deactivate_admin_user(
     user_id: uuid.UUID,
@@ -145,23 +149,20 @@ async def deactivate_admin_user(
     if not current_user_id:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Unable to identify the logged in user"
+            detail="Unable to identify the logged in user",
        )
 
     # Do not allow deactivating yourself
     if str(user_id) == current_user_id:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Unable to deactivate your own user"
+            detail="Unable to deactivate your own user",
         )
 
     # Deactivate user
     success, message = deactivate_user(db, user_id)
     if not success:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=message
-        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
 
     # Register action in audit log
     create_audit_log(
@@ -171,5 +172,5 @@ async def deactivate_admin_user(
         resource_type="admin_user",
         resource_id=str(user_id),
         details=None,
-        request=request
+        request=request,
     )

@@ -7,10 +7,6 @@ from src.core.jwt_middleware import (
     get_jwt_token,
     verify_user_client,
 )
-from src.core.jwt_middleware import (
-    get_jwt_token,
-    verify_user_client,
-)
 from src.schemas.schemas import (
     Agent,
     AgentCreate,

@@ -162,9 +162,7 @@ async def login_for_access_token(form_data: UserLogin, db: Session = Depends(get
     """
     user = authenticate_user(db, form_data.email, form_data.password)
     if not user:
-        logger.warning(
-            f"Login attempt with invalid credentials: {form_data.email}"
-        )
+        logger.warning(f"Login attempt with invalid credentials: {form_data.email}")
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail="Invalid email or password",

@@ -11,7 +11,11 @@ from src.services import (
 from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
 from src.services.agent_runner import run_agent
 from src.core.exceptions import AgentNotFoundError
-from src.main import session_service, artifacts_service, memory_service
+from src.services.service_providers import (
+    session_service,
+    artifacts_service,
+    memory_service,
+)
 from datetime import datetime
 import logging

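The chat and session routers now import the shared ADK services from `src/services/service_providers.py` instead of `src.main`, which removes the circular import. That new module is not shown in this commit view; based on the service instances previously created in `src/main.py` (removed later in this commit), it presumably looks roughly like the sketch below — the module layout is an assumption, only the three instance names and constructors are taken from the old `main.py`.

```python
# src/services/service_providers.py - assumed layout, inferred from the old src/main.py
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.memory import InMemoryMemoryService
from google.adk.sessions import DatabaseSessionService

from src.config.settings import settings

# Shared service instances that routers can import without touching src.main
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
```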
@@ -19,7 +19,7 @@ from src.services.session_service import (
     get_sessions_by_agent,
     get_sessions_by_client,
 )
-from src.main import session_service
+from src.services.service_providers import session_service
 import logging
 
 logger = logging.getLogger(__name__)
@@ -30,6 +30,7 @@ router = APIRouter(
     responses={404: {"description": "Not found"}},
 )
 
+
 # Session Routes
 @router.get("/client/{client_id}", response_model=List[Adk_Session])
 async def get_client_sessions(

@@ -10,6 +10,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 Base = declarative_base()
 
+
 def get_db():
     db = SessionLocal()
     try:

@@ -1,55 +1,66 @@
 from fastapi import HTTPException
 from typing import Optional, Dict, Any
 
+
 class BaseAPIException(HTTPException):
     """Base class for API exceptions"""
+
     def __init__(
         self,
         status_code: int,
         message: str,
         error_code: str,
-        details: Optional[Dict[str, Any]] = None
+        details: Optional[Dict[str, Any]] = None,
     ):
-        super().__init__(status_code=status_code, detail={
-            "error": message,
-            "error_code": error_code,
-            "details": details or {}
-        })
+        super().__init__(
+            status_code=status_code,
+            detail={
+                "error": message,
+                "error_code": error_code,
+                "details": details or {},
+            },
+        )
+
 
 class AgentNotFoundError(BaseAPIException):
     """Exception when the agent is not found"""
+
     def __init__(self, agent_id: str):
         super().__init__(
             status_code=404,
             message=f"Agent with ID {agent_id} not found",
-            error_code="AGENT_NOT_FOUND"
+            error_code="AGENT_NOT_FOUND",
         )
 
+
 class InvalidParameterError(BaseAPIException):
     """Exception for invalid parameters"""
+
     def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
         super().__init__(
             status_code=400,
             message=message,
             error_code="INVALID_PARAMETER",
-            details=details
+            details=details,
         )
 
+
 class InvalidRequestError(BaseAPIException):
     """Exception for invalid requests"""
+
     def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
         super().__init__(
             status_code=400,
             message=message,
             error_code="INVALID_REQUEST",
-            details=details
+            details=details,
         )
 
+
 class InternalServerError(BaseAPIException):
     """Exception for server errors"""
+
     def __init__(self, message: str = "Server error"):
         super().__init__(
-            status_code=500,
-            message=message,
-            error_code="INTERNAL_SERVER_ERROR"
+            status_code=500, message=message, error_code="INTERNAL_SERVER_ERROR"
         )

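For reference, a route handler only needs to raise these exception classes; FastAPI serializes the detail dict built by `BaseAPIException` into the error response. The sketch below is illustrative and not taken from this commit — the in-memory store stands in for the real `agent_service` lookup.

```python
# Illustrative sketch of raising the custom API exceptions from a route
from fastapi import FastAPI

from src.core.exceptions import AgentNotFoundError, InvalidParameterError

app = FastAPI()

# Stand-in store; a real handler would query agent_service instead
FAKE_AGENTS = {"123": {"name": "support-bot"}}


@app.get("/agents/{agent_id}")
async def read_agent(agent_id: str):
    if not agent_id.isalnum():
        # Returns 400 with error_code "INVALID_PARAMETER"
        raise InvalidParameterError("agent_id must be alphanumeric", details={"agent_id": agent_id})
    agent = FAKE_AGENTS.get(agent_id)
    if agent is None:
        # Returns 404 with {"error": ..., "error_code": "AGENT_NOT_FOUND", "details": {}}
        raise AgentNotFoundError(agent_id)
    return agent
```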
@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
 
 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
 
+
 async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
     """
     Extracts and validates the JWT token
@@ -34,9 +35,7 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
     try:
         payload = jwt.decode(
-            token,
-            settings.JWT_SECRET_KEY,
-            algorithms=[settings.JWT_ALGORITHM]
+            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
         )
 
         email: str = payload.get("sub")
@@ -55,10 +54,11 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
         logger.error(f"Error decoding JWT token: {str(e)}")
         raise credentials_exception
 
+
 async def verify_user_client(
     payload: dict = Depends(get_jwt_token),
     db: Session = Depends(get_db),
-    required_client_id: UUID = None
+    required_client_id: UUID = None,
 ) -> bool:
     """
     Verifies if the user is associated with the specified client
@@ -81,10 +81,12 @@ async def verify_user_client(
     # Para não-admins, verificar se o client_id corresponde
     user_client_id = payload.get("client_id")
     if not user_client_id:
-        logger.warning(f"Non-admin user without client_id in token: {payload.get('sub')}")
+        logger.warning(
+            f"Non-admin user without client_id in token: {payload.get('sub')}"
+        )
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
-            detail="User not associated with a client"
+            detail="User not associated with a client",
         )
 
     # If no client_id is specified to verify, any client is valid
@@ -93,14 +95,17 @@ async def verify_user_client(
     # Verify if the user's client_id corresponds to the required_client_id
     if str(user_client_id) != str(required_client_id):
-        logger.warning(f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}")
+        logger.warning(
+            f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}"
+        )
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
-            detail="Access denied to access resources of this client"
+            detail="Access denied to access resources of this client",
         )
 
     return True
 
+
 async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
     """
     Verifies if the user is an administrator
@@ -118,12 +123,15 @@ async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
         logger.warning(f"Access denied to admin: User {payload.get('sub')}")
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
-            detail="Access denied. Restricted to administrators."
+            detail="Access denied. Restricted to administrators.",
         )
 
     return True
 
-def get_current_user_client_id(payload: dict = Depends(get_jwt_token)) -> Optional[UUID]:
+
+def get_current_user_client_id(
+    payload: dict = Depends(get_jwt_token),
+) -> Optional[UUID]:
     """
     Gets the ID of the client associated with the current user

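As a usage sketch (not taken from this commit), a router can lock an endpoint down with these dependencies: `verify_admin` raises a 403 for non-admin tokens, and `get_jwt_token` supplies the decoded payload, whose `sub` claim carries the user's email.

```python
# Illustrative sketch of wiring the JWT dependencies into a router
from fastapi import APIRouter, Depends

from src.core.jwt_middleware import get_jwt_token, verify_admin

# The prefix and tag are placeholders, not a router defined in this commit
router = APIRouter(prefix="/reports", tags=["reports"])


@router.get("/summary", dependencies=[Depends(verify_admin)])
async def admin_summary(payload: dict = Depends(get_jwt_token)):
    # payload is the decoded JWT; "sub" holds the user's email per get_jwt_token
    return {"requested_by": payload.get("sub")}
```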
@@ -1,35 +1,30 @@
 import os
 import sys
 from pathlib import Path
 
-# Add the root directory to PYTHONPATH
-root_dir = Path(__file__).parent.parent
-sys.path.append(str(root_dir))
-
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from src.config.database import engine, Base
 from src.config.settings import settings
 from src.utils.logger import setup_logger
-from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
-from google.adk.sessions import DatabaseSessionService
-from google.adk.memory import InMemoryMemoryService
 
-# Initialize service instances
-session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
-artifacts_service = InMemoryArtifactService()
-memory_service = InMemoryMemoryService()
+# Necessary for other modules
+from src.services.service_providers import session_service  # noqa: F401
+from src.services.service_providers import artifacts_service  # noqa: F401
+from src.services.service_providers import memory_service  # noqa: F401
 
-# Import routers after service initialization to avoid circular imports
-from src.api.auth_routes import router as auth_router
-from src.api.admin_routes import router as admin_router
-from src.api.chat_routes import router as chat_router
-from src.api.session_routes import router as session_router
-from src.api.agent_routes import router as agent_router
-from src.api.contact_routes import router as contact_router
-from src.api.mcp_server_routes import router as mcp_server_router
-from src.api.tool_routes import router as tool_router
-from src.api.client_routes import router as client_router
+import src.api.auth_routes
+import src.api.admin_routes
+import src.api.chat_routes
+import src.api.session_routes
+import src.api.agent_routes
+import src.api.contact_routes
+import src.api.mcp_server_routes
+import src.api.tool_routes
+import src.api.client_routes
+
+# Add the root directory to PYTHONPATH
+root_dir = Path(__file__).parent.parent
+sys.path.append(str(root_dir))
 
 # Configure logger
 logger = setup_logger(__name__)
@@ -52,8 +47,7 @@ app.add_middleware(
 # PostgreSQL configuration
 POSTGRES_CONNECTION_STRING = os.getenv(
-    "POSTGRES_CONNECTION_STRING",
-    "postgresql://postgres:root@localhost:5432/evo_ai"
+    "POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
 )
 
 # Create database tables
@@ -61,6 +55,17 @@ Base.metadata.create_all(bind=engine)
 
 API_PREFIX = "/api/v1"
 
+# Define router references
+auth_router = src.api.auth_routes.router
+admin_router = src.api.admin_routes.router
+chat_router = src.api.chat_routes.router
+session_router = src.api.session_routes.router
+agent_router = src.api.agent_routes.router
+contact_router = src.api.contact_routes.router
+mcp_server_router = src.api.mcp_server_routes.router
+tool_router = src.api.tool_routes.router
+client_router = src.api.client_routes.router
+
 # Include routes
 app.include_router(auth_router, prefix=API_PREFIX)
 app.include_router(admin_router, prefix=API_PREFIX)
@@ -79,5 +84,5 @@ def read_root():
         "message": "Welcome to Evo AI API",
         "documentation": "/docs",
         "version": settings.API_VERSION,
-        "auth": "To access the API, use JWT authentication via '/api/v1/auth/login'"
+        "auth": "To access the API, use JWT authentication via '/api/v1/auth/login'",
     }

@@ -1,9 +1,20 @@
-from sqlalchemy import Column, String, UUID, DateTime, ForeignKey, JSON, Text, BigInteger, CheckConstraint, Boolean
+from sqlalchemy import (
+    Column,
+    String,
+    UUID,
+    DateTime,
+    ForeignKey,
+    JSON,
+    Text,
+    CheckConstraint,
+    Boolean,
+)
 from sqlalchemy.sql import func
 from sqlalchemy.orm import relationship, backref
 from src.config.database import Base
 import uuid
 
+
 class Client(Base):
     __tablename__ = "clients"
@@ -13,13 +24,16 @@ class Client(Base):
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     updated_at = Column(DateTime(timezone=True), onupdate=func.now())
 
+
 class User(Base):
     __tablename__ = "users"
 
     id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
     email = Column(String, unique=True, index=True, nullable=False)
     password_hash = Column(String, nullable=False)
-    client_id = Column(UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True)
+    client_id = Column(
+        UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True
+    )
     is_active = Column(Boolean, default=False)
     is_admin = Column(Boolean, default=False)
     email_verified = Column(Boolean, default=False)
@@ -31,7 +45,10 @@ class User(Base):
     updated_at = Column(DateTime(timezone=True), onupdate=func.now())
 
     # Relationship with Client (One-to-One, optional for administrators)
-    client = relationship("Client", backref=backref("user", uselist=False, cascade="all, delete-orphan"))
+    client = relationship(
+        "Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")
+    )
+
 
 class Contact(Base):
     __tablename__ = "contacts"
@@ -44,6 +61,7 @@ class Contact(Base):
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     updated_at = Column(DateTime(timezone=True), onupdate=func.now())
 
+
 class Agent(Base):
     __tablename__ = "agents"
@@ -60,21 +78,30 @@ class Agent(Base):
     updated_at = Column(DateTime(timezone=True), onupdate=func.now())
 
     __table_args__ = (
-        CheckConstraint("type IN ('llm', 'sequential', 'parallel', 'loop')", name='check_agent_type'),
+        CheckConstraint(
+            "type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type"
+        ),
     )
 
     def to_dict(self):
         """Converts the object to a dictionary, converting UUIDs to strings"""
         result = {}
         for key, value in self.__dict__.items():
-            if key.startswith('_'):
+            if key.startswith("_"):
                 continue
             if isinstance(value, uuid.UUID):
                 result[key] = str(value)
             elif isinstance(value, dict):
                 result[key] = self._convert_dict(value)
             elif isinstance(value, list):
-                result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
+                result[key] = [
+                    (
+                        self._convert_dict(item)
+                        if isinstance(item, dict)
+                        else str(item) if isinstance(item, uuid.UUID) else item
+                    )
+                    for item in value
+                ]
             else:
                 result[key] = value
         return result
@@ -88,11 +115,19 @@ class Agent(Base):
             elif isinstance(value, dict):
                 result[key] = self._convert_dict(value)
             elif isinstance(value, list):
-                result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
+                result[key] = [
+                    (
+                        self._convert_dict(item)
+                        if isinstance(item, dict)
+                        else str(item) if isinstance(item, uuid.UUID) else item
+                    )
+                    for item in value
+                ]
             else:
                 result[key] = value
         return result
 
+
 class MCPServer(Base):
     __tablename__ = "mcp_servers"
@@ -107,9 +142,12 @@ class MCPServer(Base):
     updated_at = Column(DateTime(timezone=True), onupdate=func.now())
 
     __table_args__ = (
-        CheckConstraint("type IN ('official', 'community')", name='check_mcp_server_type'),
+        CheckConstraint(
+            "type IN ('official', 'community')", name="check_mcp_server_type"
+        ),
     )
 
+
 class Tool(Base):
     __tablename__ = "tools"
@@ -121,10 +159,11 @@ class Tool(Base):
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     updated_at = Column(DateTime(timezone=True), onupdate=func.now())
 
+
 class Session(Base):
     __tablename__ = "sessions"
     # The directive below makes Alembic ignore this table in migrations
-    __table_args__ = {'extend_existing': True, 'info': {'skip_autogenerate': True}}
+    __table_args__ = {"extend_existing": True, "info": {"skip_autogenerate": True}}
 
     id = Column(String, primary_key=True)
     app_name = Column(String)
@@ -133,11 +172,14 @@ class Session(Base):
     create_time = Column(DateTime(timezone=True))
     update_time = Column(DateTime(timezone=True))
 
+
 class AuditLog(Base):
     __tablename__ = "audit_logs"
 
     id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
-    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
+    user_id = Column(
+        UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
+    )
     action = Column(String, nullable=False)
     resource_type = Column(String, nullable=False)
     resource_id = Column(String, nullable=True)

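A quick note on why `Agent.to_dict` exists: rows holding UUID and JSON columns are not directly JSON-serializable, so the helper stringifies UUIDs recursively, including ones nested in dict and list values. The sketch below is illustrative only; it assumes an Agent row already exists and uses the `SessionLocal` factory from `src/config/database.py`.

```python
# Illustrative only - serializing an existing Agent row with to_dict()
import json

from src.config.database import SessionLocal
from src.models.models import Agent

db = SessionLocal()
agent = db.query(Agent).first()
if agent is not None:
    # UUIDs (top-level or nested in dict/list values) come back as strings,
    # so the payload can be passed straight to json.dumps
    print(json.dumps(agent.to_dict(), default=str))
db.close()
```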
@@ -1,26 +1,38 @@
-from typing import List, Optional, Dict, Any, Union
+from typing import List, Optional, Dict, Union
 from pydantic import BaseModel, Field
 from uuid import UUID
 
+
 class ToolConfig(BaseModel):
     """Configuration of a tool"""
 
     id: UUID
-    envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the tool")
+    envs: Dict[str, str] = Field(
+        default_factory=dict, description="Environment variables of the tool"
+    )
 
     class Config:
         from_attributes = True
 
+
 class MCPServerConfig(BaseModel):
     """Configuration of an MCP server"""
 
     id: UUID
-    envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the server")
-    tools: List[str] = Field(default_factory=list, description="List of tools of the server")
+    envs: Dict[str, str] = Field(
+        default_factory=dict, description="Environment variables of the server"
+    )
+    tools: List[str] = Field(
+        default_factory=list, description="List of tools of the server"
+    )
 
     class Config:
         from_attributes = True
 
+
 class HTTPToolParameter(BaseModel):
     """Parameter of an HTTP tool"""
 
     type: str
     required: bool
     description: str
@@ -28,8 +40,10 @@ class HTTPToolParameter(BaseModel):
     class Config:
         from_attributes = True
 
+
 class HTTPToolParameters(BaseModel):
     """Parameters of an HTTP tool"""
 
     path_params: Optional[Dict[str, str]] = None
     query_params: Optional[Dict[str, Union[str, List[str]]]] = None
     body_params: Optional[Dict[str, HTTPToolParameter]] = None
@@ -37,8 +51,10 @@ class HTTPToolParameters(BaseModel):
     class Config:
         from_attributes = True
 
+
 class HTTPToolErrorHandling(BaseModel):
     """Configuration of error handling"""
 
     timeout: int
     retry_count: int
     fallback_response: Dict[str, str]
@@ -46,8 +62,10 @@ class HTTPToolErrorHandling(BaseModel):
     class Config:
         from_attributes = True
 
+
 class HTTPTool(BaseModel):
     """Configuration of an HTTP tool"""
 
     name: str
     method: str
     values: Dict[str, str]
@@ -60,42 +78,72 @@ class HTTPTool(BaseModel):
     class Config:
         from_attributes = True
 
+
 class CustomTools(BaseModel):
     """Configuration of custom tools"""
-    http_tools: List[HTTPTool] = Field(default_factory=list, description="List of HTTP tools")
+
+    http_tools: List[HTTPTool] = Field(
+        default_factory=list, description="List of HTTP tools"
+    )
 
     class Config:
         from_attributes = True
 
+
 class LLMConfig(BaseModel):
     """Configuration for LLM agents"""
-    tools: Optional[List[ToolConfig]] = Field(default=None, description="List of available tools")
-    custom_tools: Optional[CustomTools] = Field(default=None, description="Custom tools")
-    mcp_servers: Optional[List[MCPServerConfig]] = Field(default=None, description="List of MCP servers")
-    sub_agents: Optional[List[UUID]] = Field(default=None, description="List of IDs of sub-agents")
+
+    tools: Optional[List[ToolConfig]] = Field(
+        default=None, description="List of available tools"
+    )
+    custom_tools: Optional[CustomTools] = Field(
+        default=None, description="Custom tools"
+    )
+    mcp_servers: Optional[List[MCPServerConfig]] = Field(
+        default=None, description="List of MCP servers"
+    )
+    sub_agents: Optional[List[UUID]] = Field(
+        default=None, description="List of IDs of sub-agents"
+    )
 
     class Config:
         from_attributes = True
 
+
 class SequentialConfig(BaseModel):
     """Configuration for sequential agents"""
-    sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents in execution order")
+
+    sub_agents: List[UUID] = Field(
+        ..., description="List of IDs of sub-agents in execution order"
+    )
 
     class Config:
         from_attributes = True
 
+
 class ParallelConfig(BaseModel):
     """Configuration for parallel agents"""
-    sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for parallel execution")
+
+    sub_agents: List[UUID] = Field(
+        ..., description="List of IDs of sub-agents for parallel execution"
+    )
 
     class Config:
         from_attributes = True
 
+
 class LoopConfig(BaseModel):
     """Configuration for loop agents"""
-    sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for loop execution")
-    max_iterations: Optional[int] = Field(default=None, description="Maximum number of iterations")
-    condition: Optional[str] = Field(default=None, description="Condition to stop the loop")
+
+    sub_agents: List[UUID] = Field(
+        ..., description="List of IDs of sub-agents for loop execution"
+    )
+    max_iterations: Optional[int] = Field(
+        default=None, description="Maximum number of iterations"
+    )
+    condition: Optional[str] = Field(
+        default=None, description="Condition to stop the loop"
+    )
 
     class Config:
         from_attributes = True

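These config schemas nest into one another; for example, an LLM agent configuration that wires in one tool and one MCP server could be built as in the sketch below. The IDs and environment values are placeholders, not data from the project.

```python
# Illustrative construction of the nested agent configuration schemas
from uuid import uuid4

from src.schemas.agent_config import LLMConfig, MCPServerConfig, ToolConfig

config = LLMConfig(
    tools=[ToolConfig(id=uuid4(), envs={"API_URL": "https://example.com"})],
    mcp_servers=[MCPServerConfig(id=uuid4(), envs={}, tools=["search"])],
    sub_agents=[uuid4()],
)

# model_dump() is the Pydantic v2 API (pydantic==2.11.3 in pyproject.toml)
print(config.model_dump())
```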
@@ -3,19 +3,25 @@ from typing import Optional, Dict, Any
 from datetime import datetime
 from uuid import UUID
 
+
 class AuditLogBase(BaseModel):
     """Base schema for audit log"""
+
     action: str
     resource_type: str
     resource_id: Optional[str] = None
     details: Optional[Dict[str, Any]] = None
 
+
 class AuditLogCreate(AuditLogBase):
     """Schema for creating audit log"""
+
     pass
 
+
 class AuditLogResponse(AuditLogBase):
     """Schema for audit log response"""
+
     id: UUID
     user_id: Optional[UUID] = None
     ip_address: Optional[str] = None
@@ -25,8 +31,10 @@ class AuditLogResponse(AuditLogBase):
     class Config:
         from_attributes = True
 
+
 class AuditLogFilter(BaseModel):
     """Schema for audit log search filters"""
+
     user_id: Optional[UUID] = None
     action: Optional[str] = None
     resource_type: Optional[str] = None

@@ -1,21 +1,33 @@
 from pydantic import BaseModel, Field
 from typing import Dict, Any, Optional
 
+
 class ChatRequest(BaseModel):
     """Schema for chat requests"""
+
-    agent_id: str = Field(..., description="ID of the agent that will process the message")
-    contact_id: str = Field(..., description="ID of the contact that will process the message")
+    agent_id: str = Field(
+        ..., description="ID of the agent that will process the message"
+    )
+    contact_id: str = Field(
+        ..., description="ID of the contact that will process the message"
+    )
     message: str = Field(..., description="User message")
 
+
 class ChatResponse(BaseModel):
     """Schema for chat responses"""
+
     response: str = Field(..., description="Agent response")
     status: str = Field(..., description="Operation status")
     error: Optional[str] = Field(None, description="Error message, if there is one")
     timestamp: str = Field(..., description="Timestamp of the response")
 
+
 class ErrorResponse(BaseModel):
     """Schema for error responses"""
+
     error: str = Field(..., description="Error message")
     status_code: int = Field(..., description="HTTP status code of the error")
-    details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
+    details: Optional[Dict[str, Any]] = Field(
+        None, description="Additional error details"
+    )

@@ -4,15 +4,18 @@ from datetime import datetime
 from uuid import UUID
 import uuid
 import re
-from .agent_config import LLMConfig, SequentialConfig, ParallelConfig, LoopConfig
+from src.schemas.agent_config import LLMConfig
+
 
 class ClientBase(BaseModel):
     name: str
     email: Optional[EmailStr] = None
 
+
 class ClientCreate(ClientBase):
     pass
 
+
 class Client(ClientBase):
     id: UUID
     created_at: datetime
@@ -20,14 +23,17 @@ class Client(ClientBase):
     class Config:
         from_attributes = True
 
+
 class ContactBase(BaseModel):
     ext_id: Optional[str] = None
     name: Optional[str] = None
     meta: Optional[Dict[str, Any]] = Field(default_factory=dict)
 
+
 class ContactCreate(ContactBase):
     client_id: UUID
 
+
 class Contact(ContactBase):
     id: UUID
     client_id: UUID
@@ -35,67 +41,80 @@ class Contact(ContactBase):
     class Config:
         from_attributes = True
 
+
 class AgentBase(BaseModel):
     name: str = Field(..., description="Agent name (no spaces or special characters)")
     description: Optional[str] = Field(None, description="Agent description")
     type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
-    model: Optional[str] = Field(None, description="Agent model (required only for llm type)")
-    api_key: Optional[str] = Field(None, description="Agent API Key (required only for llm type)")
+    model: Optional[str] = Field(
+        None, description="Agent model (required only for llm type)"
+    )
+    api_key: Optional[str] = Field(
+        None, description="Agent API Key (required only for llm type)"
+    )
     instruction: Optional[str] = None
-    config: Union[LLMConfig, Dict[str, Any]] = Field(..., description="Agent configuration based on type")
+    config: Union[LLMConfig, Dict[str, Any]] = Field(
+        ..., description="Agent configuration based on type"
+    )
 
-    @validator('name')
+    @validator("name")
     def validate_name(cls, v):
-        if not re.match(r'^[a-zA-Z0-9_-]+$', v):
-            raise ValueError('Agent name cannot contain spaces or special characters')
+        if not re.match(r"^[a-zA-Z0-9_-]+$", v):
+            raise ValueError("Agent name cannot contain spaces or special characters")
         return v
 
-    @validator('type')
+    @validator("type")
     def validate_type(cls, v):
-        if v not in ['llm', 'sequential', 'parallel', 'loop']:
-            raise ValueError('Invalid agent type. Must be: llm, sequential, parallel or loop')
+        if v not in ["llm", "sequential", "parallel", "loop"]:
+            raise ValueError(
+                "Invalid agent type. Must be: llm, sequential, parallel or loop"
+            )
         return v
 
-    @validator('model')
+    @validator("model")
     def validate_model(cls, v, values):
-        if 'type' in values and values['type'] == 'llm' and not v:
-            raise ValueError('Model is required for llm type agents')
+        if "type" in values and values["type"] == "llm" and not v:
+            raise ValueError("Model is required for llm type agents")
         return v
 
-    @validator('api_key')
+    @validator("api_key")
     def validate_api_key(cls, v, values):
-        if 'type' in values and values['type'] == 'llm' and not v:
-            raise ValueError('API Key is required for llm type agents')
+        if "type" in values and values["type"] == "llm" and not v:
+            raise ValueError("API Key is required for llm type agents")
         return v
 
-    @validator('config')
+    @validator("config")
     def validate_config(cls, v, values):
-        if 'type' not in values:
+        if "type" not in values:
             return v
 
-        if values['type'] == 'llm':
+        if values["type"] == "llm":
             if isinstance(v, dict):
                 try:
                     # Convert the dictionary to LLMConfig
                     v = LLMConfig(**v)
                 except Exception as e:
-                    raise ValueError(f'Invalid LLM configuration for agent: {str(e)}')
+                    raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
             elif not isinstance(v, LLMConfig):
-                raise ValueError('Invalid LLM configuration for agent')
-        elif values['type'] in ['sequential', 'parallel', 'loop']:
+                raise ValueError("Invalid LLM configuration for agent")
+        elif values["type"] in ["sequential", "parallel", "loop"]:
             if not isinstance(v, dict):
                 raise ValueError(f'Invalid configuration for agent {values["type"]}')
-            if 'sub_agents' not in v:
+            if "sub_agents" not in v:
                 raise ValueError(f'Agent {values["type"]} must have sub_agents')
-            if not isinstance(v['sub_agents'], list):
-                raise ValueError('sub_agents must be a list')
-            if not v['sub_agents']:
-                raise ValueError(f'Agent {values["type"]} must have at least one sub-agent')
+            if not isinstance(v["sub_agents"], list):
+                raise ValueError("sub_agents must be a list")
+            if not v["sub_agents"]:
+                raise ValueError(
+                    f'Agent {values["type"]} must have at least one sub-agent'
+                )
 
         return v
 
+
 class AgentCreate(AgentBase):
     client_id: UUID
 
+
 class Agent(AgentBase):
     id: UUID
     client_id: UUID
@@ -105,6 +124,7 @@ class Agent(AgentBase):
     class Config:
         from_attributes = True
 
+
 class MCPServerBase(BaseModel):
     name: str
     description: Optional[str] = None
@@ -113,9 +133,11 @@ class MCPServerBase(BaseModel):
     tools: List[str] = Field(default_factory=list)
     type: str = Field(default="official")
 
+
 class MCPServerCreate(MCPServerBase):
     pass
 
+
 class MCPServer(MCPServerBase):
     id: uuid.UUID
     created_at: datetime
@@ -124,15 +146,18 @@ class MCPServer(MCPServerBase):
     class Config:
         from_attributes = True
 
+
 class ToolBase(BaseModel):
     name: str
     description: Optional[str] = None
     config_json: Dict[str, Any] = Field(default_factory=dict)
     environments: Dict[str, Any] = Field(default_factory=dict)
 
+
 class ToolCreate(ToolBase):
     pass
 
+
 class Tool(ToolBase):
     id: uuid.UUID
     created_at: datetime

View File
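The agent schema above relies on Pydantic v1-style `@validator` hooks. A minimal standalone sketch of the same pattern (a hypothetical `DemoAgent` model, not the project's `AgentBase`) shows how the name and type rules reject bad input:

```
from pydantic import BaseModel, ValidationError, validator
import re

class DemoAgent(BaseModel):
    # Mirrors the name/type rules shown in the diff above (illustrative only)
    name: str
    type: str

    @validator("name")
    def validate_name(cls, v):
        if not re.match(r"^[a-zA-Z0-9_-]+$", v):
            raise ValueError("Agent name cannot contain spaces or special characters")
        return v

    @validator("type")
    def validate_type(cls, v):
        if v not in ["llm", "sequential", "parallel", "loop"]:
            raise ValueError("Invalid agent type. Must be: llm, sequential, parallel or loop")
        return v

DemoAgent(name="support_bot", type="llm")  # passes
try:
    DemoAgent(name="support bot", type="chat")  # space in name, unknown type
except ValidationError as e:
    print(e)
```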

@@ -1,23 +1,28 @@
from pydantic import BaseModel, EmailStr
from typing import Optional
from datetime import datetime
from uuid import UUID

class UserBase(BaseModel):
    email: EmailStr

class UserCreate(UserBase):
    password: str
    name: str  # For client creation

class AdminUserCreate(UserBase):
    password: str
    name: str

class UserLogin(BaseModel):
    email: EmailStr
    password: str

class UserResponse(UserBase):
    id: UUID
    client_id: Optional[UUID] = None

@@ -29,22 +34,27 @@ class UserResponse(UserBase):

    class Config:
        from_attributes = True

class TokenResponse(BaseModel):
    access_token: str
    token_type: str

class TokenData(BaseModel):
    sub: str  # user email
    exp: datetime
    is_admin: bool
    client_id: Optional[UUID] = None

class PasswordReset(BaseModel):
    token: str
    new_password: str

class ForgotPassword(BaseModel):
    email: EmailStr

class MessageResponse(BaseModel):
    message: str

View File

@@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse, LlmRequest
from google.adk.tools import load_memory

import requests
import os
from datetime import datetime

logger = setup_logger(__name__)

@@ -83,7 +82,7 @@ def before_model_callback(
                llm_request.config.system_instruction = modified_text
                logger.debug(
                    "📝 System instruction updated with search results and history"
                )
            else:
                logger.warning("⚠️ No results found in the search")

@@ -180,7 +179,9 @@ class AgentBuilder:
        mcp_tools = []
        mcp_exit_stack = None
        if agent.config.get("mcp_servers"):
            mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
                agent.config, self.db
            )

        # Combine all tools
        all_tools = custom_tools + mcp_tools

@@ -201,10 +202,13 @@ class AgentBuilder:
        # Check if load_memory is enabled
        # before_model_callback_func = None
        if agent.config.get("load_memory"):
            all_tools.append(load_memory)
            # before_model_callback_func = before_model_callback
            formatted_prompt = (
                formatted_prompt
                + "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
            )

        return (
            LlmAgent(

View File

@@ -22,9 +22,7 @@ async def run_agent(
    db: Session,
):
    try:
        logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
        logger.info(f"Received message: {message}")

        get_root_agent = get_agent(db, agent_id)

View File

@@ -216,9 +216,7 @@ async def update_agent(
        return agent
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}")

def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:

View File

@@ -1,15 +1,15 @@
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from src.models.models import AuditLog
from datetime import datetime
from fastapi import Request
from typing import Optional, Dict, Any, List
import uuid
import logging

logger = logging.getLogger(__name__)

def create_audit_log(
    db: Session,
    user_id: Optional[uuid.UUID],

@@ -17,7 +17,7 @@ def create_audit_log(
    resource_type: str,
    resource_id: Optional[str] = None,
    details: Optional[Dict[str, Any]] = None,
    request: Optional[Request] = None,
) -> Optional[AuditLog]:
    """
    Create a new audit log

@@ -39,7 +39,7 @@ def create_audit_log(
        user_agent = None

        if request:
            ip_address = request.client.host if hasattr(request, "client") else None
            user_agent = request.headers.get("user-agent")

        # Convert details to serializable format

@@ -56,7 +56,7 @@ def create_audit_log(
            resource_id=str(resource_id) if resource_id else None,
            details=details,
            ip_address=ip_address,
            user_agent=user_agent,
        )

        db.add(audit_log)

@@ -64,8 +64,8 @@ def create_audit_log(
        db.refresh(audit_log)

        logger.info(
            f"Audit log created: {action} in {resource_type}"
            + (f" (ID: {resource_id})" if resource_id else "")
        )

        return audit_log

@@ -78,6 +78,7 @@ def create_audit_log(
        logger.error(f"Unexpected error creating audit log: {str(e)}")
        return None

def get_audit_logs(
    db: Session,
    skip: int = 0,

@@ -87,7 +88,7 @@ def get_audit_logs(
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> List[AuditLog]:
    """
    Get audit logs with optional filters

View File

@@ -16,7 +16,10 @@ logger = logging.getLogger(__name__)
# Define OAuth2 authentication scheme with password flow
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")

async def get_current_user(
    token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> User:
    """
    Get the current user from the JWT token

@@ -39,9 +42,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
    try:
        # Decode the token
        payload = jwt.decode(
            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
        )

        # Extract token data

@@ -61,7 +62,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
            sub=email,
            exp=datetime.fromtimestamp(exp),
            is_admin=payload.get("is_admin", False),
            client_id=payload.get("client_id"),
        )

    except JWTError as e:

@@ -77,13 +78,15 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
    if not user.is_active:
        logger.warning(f"Attempt to access inactive user: {user.email}")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
        )

    return user

async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Check if the current user is active

@@ -99,12 +102,14 @@ async def get_current_active_user(current_user: User = Depends(get_current_user)
    if not current_user.is_active:
        logger.warning(f"Attempt to access inactive user: {current_user.email}")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
        )

    return current_user

async def get_current_admin_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Check if the current user is an administrator

@@ -118,13 +123,16 @@ async def get_current_admin_user(current_user: User = Depends(get_current_user))
        HTTPException: If the user is not an administrator
    """
    if not current_user.is_admin:
        logger.warning(
            f"Attempt to access admin by non-admin user: {current_user.email}"
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied. Restricted to administrators.",
        )

    return current_user

def create_access_token(user: User) -> str:
    """
    Create a JWT access token for the user

View File
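The token handling above comes down to a python-jose `jwt.decode` call plus a `TokenData` build. A minimal round-trip sketch, assuming a hypothetical secret and the HS256 algorithm in place of the project's real `settings` values:

```
from datetime import datetime, timedelta
from jose import jwt, JWTError

SECRET = "change-me"   # stands in for settings.JWT_SECRET_KEY
ALGORITHM = "HS256"    # stands in for settings.JWT_ALGORITHM

# Encode a token with the same claim names used in the diff
token = jwt.encode(
    {
        "sub": "user@example.com",
        "is_admin": False,
        "exp": datetime.utcnow() + timedelta(minutes=30),
    },
    SECRET,
    algorithm=ALGORITHM,
)

try:
    payload = jwt.decode(token, SECRET, algorithms=[ALGORITHM])
    print(payload["sub"], payload.get("is_admin", False))
except JWTError as e:
    print(f"Invalid token: {e}")
```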

@@ -11,6 +11,7 @@ import logging
logger = logging.getLogger(__name__)

def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
    """Search for a client by ID"""
    try:

@@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
        logger.error(f"Error searching for client {client_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for client",
        )

def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
    """Search for all clients with pagination"""
    try:

@@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
        logger.error(f"Error searching for clients: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for clients",
        )

def create_client(db: Session, client: ClientCreate) -> Client:
    """Create a new client"""
    try:

@@ -51,10 +54,13 @@ def create_client(db: Session, client: ClientCreate) -> Client:
        logger.error(f"Error creating client: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error creating client",
        )

def update_client(
    db: Session, client_id: uuid.UUID, client: ClientCreate
) -> Optional[Client]:
    """Updates an existing client"""
    try:
        db_client = get_client(db, client_id)

@@ -73,9 +79,10 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op
        logger.error(f"Error updating client {client_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error updating client",
        )

def delete_client(db: Session, client_id: uuid.UUID) -> bool:
    """Removes a client"""
    try:

@@ -92,10 +99,13 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool:
        logger.error(f"Error removing client {client_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error removing client",
        )

def create_client_with_user(
    db: Session, client_data: ClientCreate, user_data: UserCreate
) -> Tuple[Optional[Client], str]:
    """
    Creates a new client with an associated user

View File

@@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)

def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
    """Search for a contact by ID"""
    try:

@@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
        logger.error(f"Error searching for contact {contact_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for contact",
        )

def get_contacts_by_client(
    db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100
) -> List[Contact]:
    """Search for contacts of a client with pagination"""
    try:
        return (
            db.query(Contact)
            .filter(Contact.client_id == client_id)
            .offset(skip)
            .limit(limit)
            .all()
        )
    except SQLAlchemyError as e:
        logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for contacts",
        )

def create_contact(db: Session, contact: ContactCreate) -> Contact:
    """Create a new contact"""
    try:

@@ -49,10 +60,13 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact:
        logger.error(f"Error creating contact: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error creating contact",
        )

def update_contact(
    db: Session, contact_id: uuid.UUID, contact: ContactCreate
) -> Optional[Contact]:
    """Update an existing contact"""
    try:
        db_contact = get_contact(db, contact_id)

@@ -71,9 +85,10 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -
        logger.error(f"Error updating contact {contact_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error updating contact",
        )

def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
    """Remove a contact"""
    try:

@@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
        logger.error(f"Error removing contact {contact_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error removing contact",
        )

View File

@@ -6,6 +6,7 @@ from src.utils.logger import setup_logger
logger = setup_logger(__name__)

class CustomToolBuilder:
    def __init__(self):
        self.tools = []

@@ -53,7 +54,9 @@ class CustomToolBuilder:
            # Adds default values to query params if they are not present
            for param, value in values.items():
                if param not in query_params and param not in parameters.get(
                    "path_params", {}
                ):
                    query_params[param] = value

            # Processes body parameters

@@ -64,7 +67,11 @@ class CustomToolBuilder:
            # Adds default values to body if they are not present
            for param, value in values.items():
                if (
                    param not in body_data
                    and param not in query_params
                    and param not in parameters.get("path_params", {})
                ):
                    body_data[param] = value

            # Makes the HTTP request

@@ -74,7 +81,7 @@ class CustomToolBuilder:
                headers=processed_headers,
                params=query_params,
                json=body_data if body_data else None,
                timeout=error_handling.get("timeout", 30),
            )

            if response.status_code >= 400:

@@ -87,10 +94,12 @@ class CustomToolBuilder:
        except Exception as e:
            logger.error(f"Error executing tool {name}: {str(e)}")
            return json.dumps(
                error_handling.get(
                    "fallback_response",
                    {"error": "tool_execution_error", "message": str(e)},
                )
            )

        # Adds dynamic docstring based on the configuration
        param_docs = []

@@ -109,7 +118,9 @@ class CustomToolBuilder:
        # Adds body parameters
        for param, param_config in parameters.get("body_params", {}).items():
            required = "Required" if param_config.get("required", False) else "Optional"
            param_docs.append(
                f"{param} ({param_config['type']}, {required}): {param_config['description']}"
            )

        # Adds default values
        if values:

View File

@@ -16,9 +16,10 @@ os.makedirs(templates_dir, exist_ok=True)
# Configure Jinja2 with the templates directory
env = Environment(
    loader=FileSystemLoader(templates_dir),
    autoescape=select_autoescape(["html", "xml"]),
)

def _render_template(template_name: str, context: dict) -> str:
    """
    Render a template with the provided data

@@ -37,6 +38,7 @@ def _render_template(template_name: str, context: dict) -> str:
        logger.error(f"Error rendering template '{template_name}': {str(e)}")
        return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"

def send_verification_email(email: str, token: str) -> bool:
    """
    Send a verification email to the user

@@ -56,11 +58,16 @@ def send_verification_email(email: str, token: str) -> bool:
        verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"

        html_content = _render_template(
            "verification_email",
            {
                "verification_link": verification_link,
                "user_name": email.split("@")[
                    0
                ],  # Use part of the email as temporary name
                "current_year": datetime.now().year,
            },
        )

        content = Content("text/html", html_content)

@@ -71,13 +78,16 @@ def send_verification_email(email: str, token: str) -> bool:
            logger.info(f"Verification email sent to {email}")
            return True
        else:
            logger.error(
                f"Failed to send verification email to {email}. Status: {response.status_code}"
            )
            return False

    except Exception as e:
        logger.error(f"Error sending verification email to {email}: {str(e)}")
        return False

def send_password_reset_email(email: str, token: str) -> bool:
    """
    Send a password reset email to the user

@@ -97,11 +107,16 @@ def send_password_reset_email(email: str, token: str) -> bool:
        reset_link = f"{settings.APP_URL}/reset-password?token={token}"

        html_content = _render_template(
            "password_reset",
            {
                "reset_link": reset_link,
                "user_name": email.split("@")[
                    0
                ],  # Use part of the email as temporary name
                "current_year": datetime.now().year,
            },
        )

        content = Content("text/html", html_content)

@@ -112,13 +127,16 @@ def send_password_reset_email(email: str, token: str) -> bool:
            logger.info(f"Password reset email sent to {email}")
            return True
        else:
            logger.error(
                f"Failed to send password reset email to {email}. Status: {response.status_code}"
            )
            return False

    except Exception as e:
        logger.error(f"Error sending password reset email to {email}: {str(e)}")
        return False

def send_welcome_email(email: str, user_name: str = None) -> bool:
    """
    Send a welcome email to the user after verification

@@ -138,11 +156,14 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
        dashboard_link = f"{settings.APP_URL}/dashboard"

        html_content = _render_template(
            "welcome_email",
            {
                "dashboard_link": dashboard_link,
                "user_name": user_name or email.split("@")[0],
                "current_year": datetime.now().year,
            },
        )

        content = Content("text/html", html_content)

@@ -153,14 +174,19 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
            logger.info(f"Welcome email sent to {email}")
            return True
        else:
            logger.error(
                f"Failed to send welcome email to {email}. Status: {response.status_code}"
            )
            return False

    except Exception as e:
        logger.error(f"Error sending welcome email to {email}: {str(e)}")
        return False

def send_account_locked_email(
    email: str, reset_token: str, failed_attempts: int, time_period: str
) -> bool:
    """
    Send an email informing that the account has been locked after login attempts

@@ -181,13 +207,16 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
        reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"

        html_content = _render_template(
            "account_locked",
            {
                "reset_link": reset_link,
                "user_name": email.split("@")[0],
                "failed_attempts": failed_attempts,
                "time_period": time_period,
                "current_year": datetime.now().year,
            },
        )

        content = Content("text/html", html_content)

@@ -198,7 +227,9 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
            logger.info(f"Account locked email sent to {email}")
            return True
        else:
            logger.error(
                f"Failed to send account locked email to {email}. Status: {response.status_code}"
            )
            return False

    except Exception as e:

View File

@@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)

def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
    """Search for an MCP server by ID"""
    try:

@@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
        logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for MCP server",
        )

def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
    """Search for all MCP servers with pagination"""
    try:

@@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer
        logger.error(f"Error searching for MCP servers: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for MCP servers",
        )

def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
    """Create a new MCP server"""
    try:

@@ -49,10 +52,13 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
        logger.error(f"Error creating MCP server: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error creating MCP server",
        )

def update_mcp_server(
    db: Session, server_id: uuid.UUID, server: MCPServerCreate
) -> Optional[MCPServer]:
    """Update an existing MCP server"""
    try:
        db_server = get_mcp_server(db, server_id)

@@ -71,9 +77,10 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate
        logger.error(f"Error updating MCP server {server_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error updating MCP server",
        )

def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
    """Remove an MCP server"""
    try:

@@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
        logger.error(f"Error removing MCP server {server_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error removing MCP server",
        )

View File

@@ -12,20 +12,22 @@ from sqlalchemy.orm import Session
logger = setup_logger(__name__)

class MCPService:
    def __init__(self):
        self.tools = []
        self.exit_stack = AsyncExitStack()

    async def _connect_to_mcp_server(
        self, server_config: Dict[str, Any]
    ) -> Tuple[List[Any], Optional[AsyncExitStack]]:
        """Connect to a specific MCP server and return its tools."""
        try:
            # Determines the type of server (local or remote)
            if "url" in server_config:
                # Remote server (SSE)
                connection_params = SseServerParams(
                    url=server_config["url"], headers=server_config.get("headers", {})
                )
            else:
                # Local server (Stdio)

@@ -39,9 +41,7 @@ class MCPService:
                        os.environ[key] = value

                connection_params = StdioServerParameters(
                    command=command, args=args, env=env
                )

            tools, exit_stack = await MCPToolset.from_server(

@@ -74,7 +74,9 @@ class MCPService:

        return filtered_tools

    def _filter_tools_by_agent(
        self, tools: List[Any], agent_tools: List[str]
    ) -> List[Any]:
        """Filters tools compatible with the agent."""
        filtered_tools = []
        for tool in tools:

@@ -83,7 +85,9 @@ class MCPService:
                filtered_tools.append(tool)
        return filtered_tools

    async def build_tools(
        self, mcp_config: Dict[str, Any], db: Session
    ) -> Tuple[List[Any], AsyncExitStack]:
        """Builds a list of tools from multiple MCP servers."""
        self.tools = []
        self.exit_stack = AsyncExitStack()

@@ -92,7 +96,7 @@ class MCPService:
        for server in mcp_config.get("mcp_servers", []):
            try:
                # Search for the MCP server in the database
                mcp_server = get_mcp_server(db, server["id"])
                if not mcp_server:
                    logger.warning(f"MCP server not found: {server['id']}")
                    continue

@@ -101,14 +105,16 @@ class MCPService:
                server_config = mcp_server.config_json.copy()

                # Replaces the environment variables in the config_json
                if "env" in server_config:
                    for key, value in server_config["env"].items():
                        if value.startswith("env@@"):
                            env_key = value.replace("env@@", "")
                            if env_key in server.get("envs", {}):
                                server_config["env"][key] = server["envs"][env_key]
                            else:
                                logger.warning(
                                    f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
                                )
                                continue

                logger.info(f"Connecting to MCP server: {mcp_server.name}")

@@ -119,20 +125,28 @@ class MCPService:
                    filtered_tools = self._filter_incompatible_tools(tools)

                    # Filters tools compatible with the agent
                    agent_tools = server.get("tools", [])
                    filtered_tools = self._filter_tools_by_agent(
                        filtered_tools, agent_tools
                    )
                    self.tools.extend(filtered_tools)

                    # Registers the exit_stack with the AsyncExitStack
                    await self.exit_stack.enter_async_context(exit_stack)
                    logger.info(
                        f"Connected successfully. Added {len(filtered_tools)} tools."
                    )
                else:
                    logger.warning(
                        f"Failed to connect or no tools available for {mcp_server.name}"
                    )

            except Exception as e:
                logger.error(f"Error connecting to MCP server {server['id']}: {e}")
                continue

        logger.info(
            f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
        )
        return self.tools, self.exit_stack

View File
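The `env@@` convention above maps placeholder values stored in a server's `config_json` to per-agent environment values. A small standalone sketch of that substitution step (hypothetical helper and data, no database involved):

```
def resolve_env_placeholders(server_config: dict, agent_envs: dict) -> dict:
    # Replace "env@@KEY" placeholders with agent-provided values,
    # mirroring the loop in MCPService.build_tools above.
    config = {**server_config, "env": dict(server_config.get("env", {}))}
    for key, value in config["env"].items():
        if isinstance(value, str) and value.startswith("env@@"):
            env_key = value.replace("env@@", "")
            if env_key in agent_envs:
                config["env"][key] = agent_envs[env_key]
            else:
                print(f"Environment variable '{env_key}' not provided")
    return config

resolved = resolve_env_placeholders(
    {"command": "npx", "env": {"API_KEY": "env@@GITHUB_TOKEN"}},
    {"GITHUB_TOKEN": "ghp_example"},
)
print(resolved["env"]["API_KEY"])  # -> ghp_example
```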

@@ -0,0 +1,9 @@
from src.config.settings import settings
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService

# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()

View File

@@ -132,7 +132,7 @@ def get_session_events(
    session = get_session_by_id(session_service, session_id)

    # If we get here, the session exists (get_session_by_id already validates)
    if not hasattr(session, "events") or session.events is None:
        return []

    return session.events

View File

@@ -9,6 +9,7 @@ import logging
logger = logging.getLogger(__name__)

def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
    """Search for a tool by ID"""
    try:

@@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
        logger.error(f"Error searching for tool {tool_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for tool",
        )

def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
    """Search for all tools with pagination"""
    try:

@@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
        logger.error(f"Error searching for tools: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error searching for tools",
        )

def create_tool(db: Session, tool: ToolCreate) -> Tool:
    """Creates a new tool"""
    try:

@@ -49,9 +52,10 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool:
        logger.error(f"Error creating tool: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error creating tool",
        )

def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
    """Updates an existing tool"""
    try:

@@ -71,9 +75,10 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T
        logger.error(f"Error updating tool {tool_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error updating tool",
        )

def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
    """Remove a tool"""
    try:

@@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
        logger.error(f"Error removing tool {tool_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error removing tool",
        )

View File

@@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError
from src.models.models import User, Client
from src.schemas.user import UserCreate
from src.utils.security import get_password_hash, verify_password, generate_token
from src.services.email_service import (
    send_verification_email,
    send_password_reset_email,
)
from datetime import datetime, timedelta
import uuid
import logging

@@ -11,7 +14,13 @@ from typing import Optional, Tuple
logger = logging.getLogger(__name__)

def create_user(
    db: Session,
    user_data: UserCreate,
    is_admin: bool = False,
    client_id: Optional[uuid.UUID] = None,
) -> Tuple[Optional[User], str]:
    """
    Creates a new user in the system

@@ -28,7 +37,9 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
        # Check if email already exists
        db_user = db.query(User).filter(User.email == user_data.email).first()
        if db_user:
            logger.warning(
                f"Attempt to register with existing email: {user_data.email}"
            )
            return None, "Email already registered"

        # Create verification token

@@ -56,7 +67,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
            is_active=False,  # Inactive until email is verified
            email_verified=False,
            verification_token=verification_token,
            verification_token_expiry=token_expiry,
        )
        db.add(user)
        db.commit()

@@ -68,7 +79,10 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
            # We don't do rollback here, we just log the error

        logger.info(f"User created successfully: {user.email}")
        return (
            user,
            "User created successfully. Check your email to activate your account.",
        )

    except SQLAlchemyError as e:
        db.rollback()

@@ -79,6 +93,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
        logger.error(f"Unexpected error creating user: {str(e)}")
        return None, f"Unexpected error: {str(e)}"

def verify_email(db: Session, token: str) -> Tuple[bool, str]:
    """
    Verify the user's email using the provided token

@@ -111,7 +126,9 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
            expiry = expiry.replace(tzinfo=now.tzinfo)

        if expiry < now:
            logger.warning(
                f"Attempt to verify with expired token for user: {user.email}"
            )
            return False, "Verification token expired"

        # Update user

@@ -133,6 +150,7 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
        logger.error(f"Unexpected error verifying email: {str(e)}")
        return False, f"Unexpected error: {str(e)}"

def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
    """
    Resend the verification email

@@ -149,11 +167,15 @@ def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
        user = db.query(User).filter(User.email == email).first()

        if not user:
            logger.warning(
                f"Attempt to resend verification email for non-existent email: {email}"
            )
            return False, "Email not found"

        if user.email_verified:
            logger.info(
                f"Attempt to resend verification email for already verified email: {email}"
            )
            return False, "Email already verified"

        # Generate new token

@@ -184,6 +206,7 @@ def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
        logger.error(f"Unexpected error resending verification: {str(e)}")
        return False, f"Unexpected error: {str(e)}"

def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
    """
    Initiates the password recovery process

@@ -202,7 +225,10 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
        if not user:
            # For security, we don't inform if the email exists or not
            logger.info(f"Attempt to recover password for non-existent email: {email}")
            return (
                True,
                "If the email is registered, you will receive instructions to reset your password.",
            )

        # Generate reset token
        reset_token = generate_token()

@@ -221,7 +247,10 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
            return False, "Failed to send password reset email"

        logger.info(f"Password reset email sent successfully to: {user.email}")
        return (
            True,
            "If the email is registered, you will receive instructions to reset your password.",
        )

    except SQLAlchemyError as e:
        db.rollback()

@@ -232,6 +261,7 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
        logger.error(f"Unexpected error processing password recovery: {str(e)}")
        return False, f"Unexpected error: {str(e)}"

def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
    """
    Resets the user's password using the provided token

@@ -254,7 +284,9 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st
        # Check if the token has expired
        if user.password_reset_expiry < datetime.utcnow():
            logger.warning(
                f"Attempt to reset password with expired token for user: {user.email}"
            )
            return False, "Password reset token expired"

        # Update password

@@ -264,7 +296,10 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st
        db.commit()

        logger.info(f"Password reset successfully for user: {user.email}")
        return (
            True,
            "Password reset successfully. You can now login with your new password.",
        )

    except SQLAlchemyError as e:
        db.rollback()

@@ -275,6 +310,7 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st
        logger.error(f"Unexpected error resetting password: {str(e)}")
        return False, f"Unexpected error: {str(e)}"

def get_user_by_email(db: Session, email: str) -> Optional[User]:
    """
    Searches for a user by email

@@ -292,6 +328,7 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]:
        logger.error(f"Error searching for user by email: {str(e)}")
        return None

def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
    """
    Authenticates a user with email and password

@@ -313,6 +350,7 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
        return None

    return user

def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
    """
    Lists the admin users

@@ -326,7 +364,7 @@ def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
        List[User]: List of admin users
    """
    try:
        users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all()
        logger.info(f"List of admins: {len(users)} found")
        return users

@@ -338,6 +376,7 @@ def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
        logger.error(f"Unexpected error listing admins: {str(e)}")
        return []

def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
    """
    Creates a new admin user

@@ -351,6 +390,7 @@ def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User
    """
    return create_user(db, user_data, is_admin=True)

def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
    """
    Deactivates a user (does not delete, only marks as inactive)

View File
@ -3,6 +3,7 @@ import os
import sys
from src.config.settings import settings


class CustomFormatter(logging.Formatter):
    """Custom formatter for logs"""
@ -12,14 +13,16 @@ class CustomFormatter(logging.Formatter):
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    format_template = (
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
    )
    FORMATS = {
        logging.DEBUG: grey + format_template + reset,
        logging.INFO: grey + format_template + reset,
        logging.WARNING: yellow + format_template + reset,
        logging.ERROR: red + format_template + reset,
        logging.CRITICAL: bold_red + format_template + reset,
    }

    def format(self, record):
@ -27,6 +30,7 @@ class CustomFormatter(logging.Formatter):
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


def setup_logger(name: str) -> logging.Logger:
    """
    Configures a custom logger

View File

@ -11,7 +11,8 @@ from dataclasses import dataclass
logger = logging.getLogger(__name__)

# Fix bcrypt error with passlib
if not hasattr(bcrypt, "__about__"):
    @dataclass
    class BcryptAbout:
        __version__: str = getattr(bcrypt, "__version__")
@ -21,31 +22,33 @@ if not hasattr(bcrypt, '__about__'):
# Context for password hashing using bcrypt
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_password_hash(password: str) -> str:
    """Creates a password hash"""
    return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verifies if the provided password matches the stored hash"""
    return pwd_context.verify(plain_password, hashed_password)


def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str:
    """Creates a JWT token"""
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRATION_TIME)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
    return encoded_jwt


def generate_token(length: int = 32) -> str:
    """Generates a secure token for email verification or password reset"""
    alphabet = string.ascii_letters + string.digits
    token = "".join(secrets.choice(alphabet) for _ in range(length))
    return token
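
A quick usage sketch of the helpers above. The import path and the JWT payload shape are assumptions (the file name for this module is not shown in the diff); everything else mirrors the signatures defined above:

```
# Usage sketch only: the module path and the "sub" claim are assumptions.
from datetime import timedelta

from src.utils.security import (  # assumed module path
    create_jwt_token,
    generate_token,
    get_password_hash,
    verify_password,
)

hashed = get_password_hash("s3cret!")
assert verify_password("s3cret!", hashed)
assert not verify_password("wrong-password", hashed)

# Short-lived token with an explicit expiry override
token = create_jwt_token({"sub": "user@example.com"}, expires_delta=timedelta(minutes=15))

# 32-character alphanumeric token for email verification / password reset links
reset_token = generate_token()
```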

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# Package initialization for tests

1
tests/api/__init__.py Normal file
View File

@ -0,0 +1 @@
# API tests package

11
tests/api/test_root.py Normal file
View File

@ -0,0 +1,11 @@
def test_read_root(client):
    """
    Test that the root endpoint returns the correct response.
    """
    response = client.get("/")
    assert response.status_code == 200
    data = response.json()
    assert "message" in data
    assert "documentation" in data
    assert "version" in data
    assert "auth" in data

View File

@ -0,0 +1 @@
# Services tests package

View File

@ -0,0 +1,27 @@
from src.services.auth_service import create_access_token
from src.models.models import User
import uuid


def test_create_access_token():
    """
    Test that an access token is created with the correct data.
    """
    # Create a mock user
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        hashed_password="hashed_password",
        is_active=True,
        is_admin=False,
        name="Test User",
        client_id=uuid.uuid4(),
    )

    # Create token
    token = create_access_token(user)

    # Simple validation
    assert token is not None
    assert isinstance(token, str)
    assert len(token) > 0
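
If stronger guarantees are wanted, the token could also be decoded and its payload inspected. A hedged follow-up sketch, assuming `create_access_token` signs with the `JWT_SECRET_KEY`/`JWT_ALGORITHM` settings used elsewhere in this commit and that the project's JWT library exposes the usual `decode` call:

```
# Optional extra assertions (sketch). The import below matches python-jose;
# PyJWT exposes a compatible decode() for this call. Claim names depend on
# create_access_token's implementation, which is not shown in this diff,
# so only generic checks are made here.
from jose import jwt

from src.config.settings import settings

payload = jwt.decode(
    token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
assert isinstance(payload, dict)
assert "exp" in payload  # assumption: an expiry claim is set
```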