chore: update project structure and add testing framework
parent 7af234ef48
commit e7e030dfd5

.cursorrules (77 lines changed)
@@ -17,9 +17,16 @@
```
src/
├── api/
│   ├── routes.py # API routes definition
│   ├── auth_routes.py # Authentication routes (login, registration, etc.)
│   └── admin_routes.py # Protected admin routes
│   ├── __init__.py # Package initialization
│   ├── admin_routes.py # Admin routes for management interface
│   ├── agent_routes.py # Routes to manage agents
│   ├── auth_routes.py # Authentication routes (login, registration)
│   ├── chat_routes.py # Routes for chat interactions with agents
│   ├── client_routes.py # Routes to manage clients
│   ├── contact_routes.py # Routes to manage contacts
│   ├── mcp_server_routes.py # Routes to manage MCP servers
│   ├── session_routes.py # Routes to manage chat sessions
│   └── tool_routes.py # Routes to manage tools for agents
├── config/
│   ├── database.py # Database configuration
│   └── settings.py # General settings
@@ -30,18 +37,21 @@ src/
│   └── models.py # SQLAlchemy models
├── schemas/
│   ├── schemas.py # Main Pydantic schemas
│   ├── chat.py # Chat schemas
│   ├── user.py # User and authentication schemas
│   └── audit.py # Audit logs schemas
├── services/
│   ├── agent_service.py # Business logic for agents
│   ├── agent_runner.py # Agent execution logic
│   ├── auth_service.py # JWT authentication logic
│   ├── audit_service.py # Audit logs logic
│   ├── client_service.py # Business logic for clients
│   ├── contact_service.py # Business logic for contacts
│   ├── mcp_server_service.py # Business logic for MCP servers
│   ├── tool_service.py # Business logic for tools
│   ├── user_service.py # User and authentication logic
│   ├── auth_service.py # JWT authentication logic
│   ├── email_service.py # Email sending service
│   └── audit_service.py # Audit logs logic
│   ├── mcp_server_service.py # Business logic for MCP servers
│   ├── session_service.py # Business logic for chat sessions
│   ├── tool_service.py # Business logic for tools
│   └── user_service.py # User management logic
├── templates/
│   ├── emails/
│   │   ├── base_email.html # Base template with common structure and styles
@@ -49,7 +59,21 @@ src/
│   │   ├── password_reset.html # Password reset template
│   │   ├── welcome_email.html # Welcome email after verification
│   │   └── account_locked.html # Security alert for locked accounts
├── tests/
│   ├── __init__.py # Package initialization
│   ├── api/
│   │   ├── __init__.py # Package initialization
│   │   ├── test_auth_routes.py # Test for authentication routes
│   │   └── test_root.py # Test for root endpoint
│   ├── models/
│   │   ├── __init__.py # Package initialization
│   │   ├── test_models.py # Test for models
│   ├── services/
│   │   ├── __init__.py # Package initialization
│   │   ├── test_auth_service.py # Test for authentication service
│   │   └── test_user_service.py # Test for user service
└── utils/
    ├── logger.py # Logger configuration
    └── security.py # Security utilities (JWT, hash)
```
@@ -63,6 +87,15 @@ src/
- Code examples in documentation must be in English
- Commit messages must be in English

### Project Configuration
- Dependencies managed in `pyproject.toml` using modern Python packaging standards
- Development dependencies specified as optional dependencies in `pyproject.toml`
- Single source of truth for project metadata in `pyproject.toml`
- Build system configured to use setuptools
- Pytest configuration in `pyproject.toml` under `[tool.pytest.ini_options]`
- Code formatting with Black configured in `pyproject.toml`
- Linting with Flake8 configured in `.flake8`

### Schemas (Pydantic)
- Use `BaseModel` as base for all schemas
- Define fields with explicit types
@@ -136,6 +169,28 @@ src/
- Indentation with 4 spaces
- Maximum of 79 characters per line

## Commit Rules
- Use Conventional Commits format for all commit messages
- Format: `<type>(<scope>): <description>`
- Types:
  - `feat`: A new feature
  - `fix`: A bug fix
  - `docs`: Documentation changes
  - `style`: Changes that do not affect code meaning (formatting, etc.)
  - `refactor`: Code changes that neither fix a bug nor add a feature
  - `perf`: Performance improvements
  - `test`: Adding or modifying tests
  - `chore`: Changes to build process or auxiliary tools
- Scope is optional and should be the module or component affected
- Description should be concise, in the imperative mood, and not capitalized
- Use body for more detailed explanations if needed
- Reference issues in the footer with `Fixes #123` or `Relates to #123`
- Examples:
  - `feat(auth): add password reset functionality`
  - `fix(api): correct validation error in client registration`
  - `docs: update API documentation for new endpoints`
  - `refactor(services): improve error handling in authentication`

## Best Practices
- Always validate input data
- Implement appropriate logging
@@ -163,6 +218,7 @@ src/

## Useful Commands
- `make run`: Start the server
- `make run-prod`: Start the server in production mode
- `make alembic-revision message="description"`: Create new migration
- `make alembic-upgrade`: Apply pending migrations
- `make alembic-downgrade`: Revert last migration
@@ -170,3 +226,8 @@ src/
- `make alembic-reset`: Reset database to initial state
- `make alembic-upgrade-cascade`: Force upgrade removing dependencies
- `make clear-cache`: Clean project cache
- `make seed-all`: Run all database seeders
- `make lint`: Run linting checks with flake8
- `make format`: Format code with black
- `make install`: Install project for development
- `make install-dev`: Install project with development dependencies
.dockerignore

@@ -1,3 +1,47 @@
# Environment and IDE
.venv
venv
.env
.idea
.vscode
__pycache__
*.pyc
*.pyo
*.pyd
.Python
.pytest_cache
.coverage
htmlcov/
.tox/

# Version control
.git
.github
.gitignore

# Logs and temp files
logs
*.log
tmp
.DS_Store

# Docker
.dockerignore
Dockerfile*
docker-compose*

# Documentation
README.md
LICENSE
docs/

# Development tools
tests/
.flake8
pyproject.toml
requirements-dev.txt
Makefile

# Ambiente virtual
venv/
__pycache__/
.flake8 (new file, 8 lines)

@@ -0,0 +1,8 @@
[flake8]
max-line-length = 88
exclude = .git,__pycache__,venv,alembic/versions/*
ignore = E203, W503, E501
per-file-ignores =
    __init__.py: F401
    src/models/models.py: E712
    alembic/*: E711,E712,F401
.gitignore (77 lines changed, vendored)
@ -1,13 +1,10 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
env/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
@ -19,19 +16,10 @@ lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
@ -42,14 +30,57 @@ nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.sublime-project
|
||||
*.sublime-workspace
|
||||
.DS_Store
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
backup/
|
||||
|
||||
# Local
|
||||
local_settings.py
|
||||
local.py
|
||||
|
||||
# Docker
|
||||
.docker/
|
||||
|
||||
# Alembic versions
|
||||
# alembic/versions/
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
@ -77,17 +108,6 @@ celerybeat-schedule
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
.venv/
|
||||
.env/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
@ -102,11 +122,8 @@ venv.bak/
|
||||
.mypy_cache/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
Dockerfile (16 lines changed)

@@ -1,4 +1,4 @@
FROM python:3.11-slim
FROM python:3.10-slim

# Define o diretório de trabalho
WORKDIR /app
@@ -15,19 +15,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Copia os arquivos de requisitos
COPY requirements.txt .

# Instala as dependências
RUN pip install --no-cache-dir -r requirements.txt

# Copia o código-fonte
# Copy project files
COPY . .

# Install dependencies
RUN pip install --no-cache-dir -e .

# Configuração para produção
ENV PORT=8000 \
    HOST=0.0.0.0 \
    DEBUG=false

# Expose port
EXPOSE 8000

# Define o comando de inicialização
CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT
Makefile (47 lines changed)

@@ -1,42 +1,42 @@
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs lint format install install-dev venv

# Comandos do Alembic
# Alembic commands
init:
	alembic init alembics

# make alembic-revision message="descrição da migração"
# make alembic-revision message="migration description"
alembic-revision:
	alembic revision --autogenerate -m "$(message)"

# Comando para atualizar o banco de dados
# Command to update database to latest version
alembic-upgrade:
	alembic upgrade head

# Comando para voltar uma versão
# Command to downgrade one version
alembic-downgrade:
	alembic downgrade -1

# Comando para rodar o servidor
# Command to run the server
run:
	uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src

# Comando para limpar o cache em todas as pastas do projeto pastas pycache
# Command to run the server in production mode
run-prod:
	uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4

# Command to clean cache in all project folders
clear-cache:
	rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} +

# Comando para criar uma nova migração
# Command to create a new migration and apply it
alembic-migrate:
	alembic revision --autogenerate -m "$(message)" && alembic upgrade head

# Comando para resetar o banco de dados
# Command to reset the database
alembic-reset:
	alembic downgrade base && alembic upgrade head

# Comando para forçar upgrade com CASCADE
alembic-upgrade-cascade:
	psql -U postgres -d a2a_saas -c "DROP TABLE IF EXISTS events CASCADE; DROP TABLE IF EXISTS sessions CASCADE; DROP TABLE IF EXISTS user_states CASCADE; DROP TABLE IF EXISTS app_states CASCADE;" && alembic upgrade head

# Comandos para executar seeders
# Commands to run seeders
seed-admin:
	python -m scripts.seeders.admin_seeder

@@ -58,7 +58,7 @@ seed-contacts:
seed-all:
	python -m scripts.run_seeders

# Comandos Docker
# Docker commands
docker-build:
	docker-compose build

@@ -73,3 +73,20 @@ docker-logs:

docker-seed:
	docker-compose exec api python -m scripts.run_seeders

# Testing, linting and formatting commands
lint:
	flake8 src/ tests/

format:
	black src/ tests/

# Virtual environment and installation commands
venv:
	python -m venv venv

install:
	pip install -e .

install-dev:
	pip install -e ".[dev]"
conftest.py (new file, 53 lines)

@@ -0,0 +1,53 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from src.config.database import Base, get_db
from src.main import app

# Use in-memory SQLite for tests
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


@pytest.fixture(scope="function")
def db_session():
    """Creates a fresh database session for each test."""
    Base.metadata.create_all(bind=engine)  # Create tables

    connection = engine.connect()
    transaction = connection.begin()
    session = TestingSessionLocal(bind=connection)

    # Use our test database instead of the standard one
    def override_get_db():
        try:
            yield session
            session.commit()
        finally:
            session.close()

    app.dependency_overrides[get_db] = override_get_db

    yield session  # The test will run here

    # Teardown
    transaction.rollback()
    connection.close()
    Base.metadata.drop_all(bind=engine)
    app.dependency_overrides.clear()


@pytest.fixture(scope="function")
def client(db_session):
    """Creates a FastAPI TestClient with database session fixture."""
    with TestClient(app) as test_client:
        yield test_client
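A minimal sketch of how these fixtures could be consumed; the endpoint path and expected message are assumptions based on the `read_root` handler further down in this commit, and the file name mirrors `tests/api/test_root.py` from the project structure above:

```python
# tests/api/test_root.py -- hypothetical usage of the fixtures defined in conftest.py
def test_read_root_returns_welcome_message(client):
    # `client` comes from conftest.py and already has get_db overridden
    # to use the in-memory SQLite session created for this test.
    response = client.get("/")

    assert response.status_code == 200
    # Assumes the read_root handler shown in src/main.py in this commit.
    assert response.json()["message"] == "Welcome to Evo AI API"
```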
pyproject.toml (new file, 89 lines)

@@ -0,0 +1,89 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "evo-ai"
version = "1.0.0"
description = "API for executing AI agents"
readme = "README.md"
authors = [
    {name = "EvoAI Team", email = "admin@evoai.com"}
]
requires-python = ">=3.10"
license = {text = "Proprietary"}
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "License :: Other/Proprietary License",
    "Operating System :: OS Independent",
]

# Main dependencies
dependencies = [
    "fastapi==0.115.12",
    "uvicorn==0.34.2",
    "pydantic==2.11.3",
    "sqlalchemy==2.0.40",
    "psycopg2-binary==2.9.10",
    "google-cloud-aiplatform==1.90.0",
    "python-dotenv==1.1.0",
    "google-adk==0.3.0",
    "litellm==1.67.4.post1",
    "python-multipart==0.0.20",
    "alembic==1.15.2",
    "asyncpg==0.30.0",
    "python-jose==3.4.0",
    "passlib==1.7.4",
    "sendgrid==6.11.0",
    "pydantic-settings==2.9.1",
    "fastapi_utils==0.8.0",
    "bcrypt==4.3.0",
    "jinja2==3.1.6",
]

[project.optional-dependencies]
dev = [
    "black==25.1.0",
    "flake8==7.2.0",
    "pytest==8.3.5",
    "pytest-cov==6.1.1",
    "httpx==0.28.1",
    "pytest-asyncio==0.26.0",
]

[tool.setuptools]
packages = ["src"]

[tool.black]
line-length = 88
target-version = ["py310"]
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
  | venv
  | alembic/versions
)/
'''

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_functions = "test_*"
python_classes = "Test*"
filterwarnings = [
    "ignore::DeprecationWarning",
]

[tool.coverage.run]
source = ["src"]
omit = ["tests/*", "alembic/*"]
requirements.txt (deleted)

@@ -1,22 +0,0 @@
fastapi
uvicorn
pydantic
sqlalchemy
psycopg2
psycopg2-binary
google-cloud-aiplatform
python-dotenv
google-adk
litellm
python-multipart
alembic
asyncpg
# Novas dependências para autenticação
python-jose[cryptography]
passlib[bcrypt]
sendgrid
pydantic[email]
pydantic-settings
fastapi_utils
bcrypt
jinja2
setup.py (new file, 6 lines)

@@ -0,0 +1,6 @@
"""Setup script for the package."""

from setuptools import setup

if __name__ == "__main__":
    setup()
@@ -1,3 +1,3 @@
"""
Pacote principal da aplicação
Main package of the application
"""
@ -1,14 +1,17 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from src.config.database import get_db
|
||||
from src.core.jwt_middleware import get_jwt_token, verify_admin
|
||||
from src.schemas.audit import AuditLogResponse, AuditLogFilter
|
||||
from src.services.audit_service import get_audit_logs, create_audit_log
|
||||
from src.services.user_service import get_admin_users, create_admin_user, deactivate_user
|
||||
from src.services.user_service import (
|
||||
get_admin_users,
|
||||
create_admin_user,
|
||||
deactivate_user,
|
||||
)
|
||||
from src.schemas.user import UserResponse, AdminUserCreate
|
||||
|
||||
router = APIRouter(
|
||||
@ -18,6 +21,7 @@ router = APIRouter(
|
||||
responses={403: {"description": "Permission denied"}},
|
||||
)
|
||||
|
||||
|
||||
# Audit routes
|
||||
@router.get("/audit-logs", response_model=List[AuditLogResponse])
|
||||
async def read_audit_logs(
|
||||
@ -45,9 +49,10 @@ async def read_audit_logs(
|
||||
resource_type=filters.resource_type,
|
||||
resource_id=filters.resource_id,
|
||||
start_date=filters.start_date,
|
||||
end_date=filters.end_date
|
||||
end_date=filters.end_date,
|
||||
)
|
||||
|
||||
|
||||
# Admin routes
|
||||
@router.get("/users", response_model=List[UserResponse])
|
||||
async def read_admin_users(
|
||||
@ -70,6 +75,7 @@ async def read_admin_users(
|
||||
"""
|
||||
return get_admin_users(db, skip, limit)
|
||||
|
||||
|
||||
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_new_admin_user(
|
||||
user_data: AdminUserCreate,
|
||||
@ -97,16 +103,13 @@ async def create_new_admin_user(
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unable to identify the logged in user"
|
||||
detail="Unable to identify the logged in user",
|
||||
)
|
||||
|
||||
# Create admin user
|
||||
user, message = create_admin_user(db, user_data)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=message
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
||||
|
||||
# Register action in audit log
|
||||
create_audit_log(
|
||||
@ -116,11 +119,12 @@ async def create_new_admin_user(
|
||||
resource_type="admin_user",
|
||||
resource_id=str(user.id),
|
||||
details={"email": user.email},
|
||||
request=request
|
||||
request=request,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def deactivate_admin_user(
|
||||
user_id: uuid.UUID,
|
||||
@ -145,23 +149,20 @@ async def deactivate_admin_user(
|
||||
if not current_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unable to identify the logged in user"
|
||||
detail="Unable to identify the logged in user",
|
||||
)
|
||||
|
||||
# Do not allow deactivating yourself
|
||||
if str(user_id) == current_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Unable to deactivate your own user"
|
||||
detail="Unable to deactivate your own user",
|
||||
)
|
||||
|
||||
# Deactivate user
|
||||
success, message = deactivate_user(db, user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=message
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
||||
|
||||
# Register action in audit log
|
||||
create_audit_log(
|
||||
@ -171,5 +172,5 @@ async def deactivate_admin_user(
|
||||
resource_type="admin_user",
|
||||
resource_id=str(user_id),
|
||||
details=None,
|
||||
request=request
|
||||
request=request,
|
||||
)
|
@ -7,10 +7,6 @@ from src.core.jwt_middleware import (
|
||||
get_jwt_token,
|
||||
verify_user_client,
|
||||
)
|
||||
from src.core.jwt_middleware import (
|
||||
get_jwt_token,
|
||||
verify_user_client,
|
||||
)
|
||||
from src.schemas.schemas import (
|
||||
Agent,
|
||||
AgentCreate,
|
||||
|
@ -162,9 +162,7 @@ async def login_for_access_token(form_data: UserLogin, db: Session = Depends(get
|
||||
"""
|
||||
user = authenticate_user(db, form_data.email, form_data.password)
|
||||
if not user:
|
||||
logger.warning(
|
||||
f"Login attempt with invalid credentials: {form_data.email}"
|
||||
)
|
||||
logger.warning(f"Login attempt with invalid credentials: {form_data.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
|
@ -11,7 +11,11 @@ from src.services import (
|
||||
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
|
||||
from src.services.agent_runner import run_agent
|
||||
from src.core.exceptions import AgentNotFoundError
|
||||
from src.main import session_service, artifacts_service, memory_service
|
||||
from src.services.service_providers import (
|
||||
session_service,
|
||||
artifacts_service,
|
||||
memory_service,
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
@ -19,7 +19,7 @@ from src.services.session_service import (
|
||||
get_sessions_by_agent,
|
||||
get_sessions_by_client,
|
||||
)
|
||||
from src.main import session_service
|
||||
from src.services.service_providers import session_service
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -30,6 +30,7 @@ router = APIRouter(
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
# Session Routes
|
||||
@router.get("/client/{client_id}", response_model=List[Adk_Session])
|
||||
async def get_client_sessions(
|
||||
|
@ -10,6 +10,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
@ -1,55 +1,66 @@
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class BaseAPIException(HTTPException):
|
||||
"""Base class for API exceptions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
error_code: str,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(status_code=status_code, detail={
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail={
|
||||
"error": message,
|
||||
"error_code": error_code,
|
||||
"details": details or {}
|
||||
})
|
||||
"details": details or {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class AgentNotFoundError(BaseAPIException):
|
||||
"""Exception when the agent is not found"""
|
||||
|
||||
def __init__(self, agent_id: str):
|
||||
super().__init__(
|
||||
status_code=404,
|
||||
message=f"Agent with ID {agent_id} not found",
|
||||
error_code="AGENT_NOT_FOUND"
|
||||
error_code="AGENT_NOT_FOUND",
|
||||
)
|
||||
|
||||
|
||||
class InvalidParameterError(BaseAPIException):
|
||||
"""Exception for invalid parameters"""
|
||||
|
||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=400,
|
||||
message=message,
|
||||
error_code="INVALID_PARAMETER",
|
||||
details=details
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
class InvalidRequestError(BaseAPIException):
|
||||
"""Exception for invalid requests"""
|
||||
|
||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=400,
|
||||
message=message,
|
||||
error_code="INVALID_REQUEST",
|
||||
details=details
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
class InternalServerError(BaseAPIException):
|
||||
"""Exception for server errors"""
|
||||
|
||||
def __init__(self, message: str = "Server error"):
|
||||
super().__init__(
|
||||
status_code=500,
|
||||
message=message,
|
||||
error_code="INTERNAL_SERVER_ERROR"
|
||||
status_code=500, message=message, error_code="INTERNAL_SERVER_ERROR"
|
||||
)
|
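A hedged sketch of how these exception classes are meant to be raised from a route; the endpoint and the lookup below are hypothetical and only illustrate the error payload produced by `BaseAPIException`:

```python
from fastapi import APIRouter

from src.core.exceptions import AgentNotFoundError, InvalidParameterError

router = APIRouter()


@router.get("/agents/{agent_id}/status")  # hypothetical route, for illustration only
async def read_agent_status(agent_id: str):
    if not agent_id.strip():
        # Returns HTTP 400 with {"error": ..., "error_code": "INVALID_PARAMETER", "details": {...}}
        raise InvalidParameterError(
            "agent_id must not be empty", details={"agent_id": agent_id}
        )

    agent = None  # stand-in for a real service lookup, e.g. get_agent(db, agent_id)
    if agent is None:
        # Returns HTTP 404 with error_code "AGENT_NOT_FOUND"
        raise AgentNotFoundError(agent_id)

    return {"agent_id": agent_id, "status": "ok"}
```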
@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
|
||||
async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
||||
"""
|
||||
Extracts and validates the JWT token
|
||||
@ -34,9 +35,7 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM]
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
|
||||
email: str = payload.get("sub")
|
||||
@ -55,10 +54,11 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
||||
logger.error(f"Error decoding JWT token: {str(e)}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
async def verify_user_client(
|
||||
payload: dict = Depends(get_jwt_token),
|
||||
db: Session = Depends(get_db),
|
||||
required_client_id: UUID = None
|
||||
required_client_id: UUID = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Verifies if the user is associated with the specified client
|
||||
@ -81,10 +81,12 @@ async def verify_user_client(
|
||||
# Para não-admins, verificar se o client_id corresponde
|
||||
user_client_id = payload.get("client_id")
|
||||
if not user_client_id:
|
||||
logger.warning(f"Non-admin user without client_id in token: {payload.get('sub')}")
|
||||
logger.warning(
|
||||
f"Non-admin user without client_id in token: {payload.get('sub')}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User not associated with a client"
|
||||
detail="User not associated with a client",
|
||||
)
|
||||
|
||||
# If no client_id is specified to verify, any client is valid
|
||||
@ -93,14 +95,17 @@ async def verify_user_client(
|
||||
|
||||
# Verify if the user's client_id corresponds to the required_client_id
|
||||
if str(user_client_id) != str(required_client_id):
|
||||
logger.warning(f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}")
|
||||
logger.warning(
|
||||
f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to access resources of this client"
|
||||
detail="Access denied to access resources of this client",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
|
||||
"""
|
||||
Verifies if the user is an administrator
|
||||
@ -118,12 +123,15 @@ async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
|
||||
logger.warning(f"Access denied to admin: User {payload.get('sub')}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. Restricted to administrators."
|
||||
detail="Access denied. Restricted to administrators.",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def get_current_user_client_id(payload: dict = Depends(get_jwt_token)) -> Optional[UUID]:
|
||||
|
||||
def get_current_user_client_id(
|
||||
payload: dict = Depends(get_jwt_token),
|
||||
) -> Optional[UUID]:
|
||||
"""
|
||||
Gets the ID of the client associated with the current user
|
||||
|
||||
|
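A hedged sketch of how these dependencies are wired into routers; the endpoints below are hypothetical and only illustrate the `verify_admin` / `verify_user_client` pattern used by the route modules in this commit:

```python
from uuid import UUID

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from src.config.database import get_db
from src.core.jwt_middleware import get_jwt_token, verify_admin, verify_user_client

router = APIRouter()


@router.get("/admin/ping")  # hypothetical endpoint, for illustration only
async def admin_ping(_: bool = Depends(verify_admin)):
    # verify_admin raises HTTP 403 unless the JWT payload has is_admin set
    return {"status": "admin ok"}


@router.get("/clients/{client_id}/ping")  # hypothetical endpoint, for illustration only
async def client_ping(
    client_id: UUID,
    payload: dict = Depends(get_jwt_token),
    db: Session = Depends(get_db),
):
    # Explicitly check that the token belongs to this client (admins always pass)
    await verify_user_client(payload, db, required_client_id=client_id)
    return {"status": "client ok"}
```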
55
src/main.py
55
src/main.py
@ -1,35 +1,30 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the root directory to PYTHONPATH
|
||||
root_dir = Path(__file__).parent.parent
|
||||
sys.path.append(str(root_dir))
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.config.database import engine, Base
|
||||
from src.config.settings import settings
|
||||
from src.utils.logger import setup_logger
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.sessions import DatabaseSessionService
|
||||
from google.adk.memory import InMemoryMemoryService
|
||||
|
||||
# Initialize service instances
|
||||
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
|
||||
artifacts_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
# Necessary for other modules
|
||||
from src.services.service_providers import session_service # noqa: F401
|
||||
from src.services.service_providers import artifacts_service # noqa: F401
|
||||
from src.services.service_providers import memory_service # noqa: F401
|
||||
|
||||
# Import routers after service initialization to avoid circular imports
|
||||
from src.api.auth_routes import router as auth_router
|
||||
from src.api.admin_routes import router as admin_router
|
||||
from src.api.chat_routes import router as chat_router
|
||||
from src.api.session_routes import router as session_router
|
||||
from src.api.agent_routes import router as agent_router
|
||||
from src.api.contact_routes import router as contact_router
|
||||
from src.api.mcp_server_routes import router as mcp_server_router
|
||||
from src.api.tool_routes import router as tool_router
|
||||
from src.api.client_routes import router as client_router
|
||||
import src.api.auth_routes
|
||||
import src.api.admin_routes
|
||||
import src.api.chat_routes
|
||||
import src.api.session_routes
|
||||
import src.api.agent_routes
|
||||
import src.api.contact_routes
|
||||
import src.api.mcp_server_routes
|
||||
import src.api.tool_routes
|
||||
import src.api.client_routes
|
||||
|
||||
# Add the root directory to PYTHONPATH
|
||||
root_dir = Path(__file__).parent.parent
|
||||
sys.path.append(str(root_dir))
|
||||
|
||||
# Configure logger
|
||||
logger = setup_logger(__name__)
|
||||
@ -52,8 +47,7 @@ app.add_middleware(
|
||||
|
||||
# PostgreSQL configuration
|
||||
POSTGRES_CONNECTION_STRING = os.getenv(
|
||||
"POSTGRES_CONNECTION_STRING",
|
||||
"postgresql://postgres:root@localhost:5432/evo_ai"
|
||||
"POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
|
||||
)
|
||||
|
||||
# Create database tables
|
||||
@ -61,6 +55,17 @@ Base.metadata.create_all(bind=engine)
|
||||
|
||||
API_PREFIX = "/api/v1"
|
||||
|
||||
# Define router references
|
||||
auth_router = src.api.auth_routes.router
|
||||
admin_router = src.api.admin_routes.router
|
||||
chat_router = src.api.chat_routes.router
|
||||
session_router = src.api.session_routes.router
|
||||
agent_router = src.api.agent_routes.router
|
||||
contact_router = src.api.contact_routes.router
|
||||
mcp_server_router = src.api.mcp_server_routes.router
|
||||
tool_router = src.api.tool_routes.router
|
||||
client_router = src.api.client_routes.router
|
||||
|
||||
# Include routes
|
||||
app.include_router(auth_router, prefix=API_PREFIX)
|
||||
app.include_router(admin_router, prefix=API_PREFIX)
|
||||
@ -79,5 +84,5 @@ def read_root():
|
||||
"message": "Welcome to Evo AI API",
|
||||
"documentation": "/docs",
|
||||
"version": settings.API_VERSION,
|
||||
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'"
|
||||
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'",
|
||||
}
|
||||
|
@ -1,9 +1,20 @@
|
||||
from sqlalchemy import Column, String, UUID, DateTime, ForeignKey, JSON, Text, BigInteger, CheckConstraint, Boolean
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
String,
|
||||
UUID,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
Text,
|
||||
CheckConstraint,
|
||||
Boolean,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, backref
|
||||
from src.config.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Client(Base):
|
||||
__tablename__ = "clients"
|
||||
|
||||
@ -13,13 +24,16 @@ class Client(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
email = Column(String, unique=True, index=True, nullable=False)
|
||||
password_hash = Column(String, nullable=False)
|
||||
client_id = Column(UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True)
|
||||
client_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
is_active = Column(Boolean, default=False)
|
||||
is_admin = Column(Boolean, default=False)
|
||||
email_verified = Column(Boolean, default=False)
|
||||
@ -31,7 +45,10 @@ class User(Base):
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
# Relationship with Client (One-to-One, optional for administrators)
|
||||
client = relationship("Client", backref=backref("user", uselist=False, cascade="all, delete-orphan"))
|
||||
client = relationship(
|
||||
"Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")
|
||||
)
|
||||
|
||||
|
||||
class Contact(Base):
|
||||
__tablename__ = "contacts"
|
||||
@ -44,6 +61,7 @@ class Contact(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class Agent(Base):
|
||||
__tablename__ = "agents"
|
||||
|
||||
@ -60,21 +78,30 @@ class Agent(Base):
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("type IN ('llm', 'sequential', 'parallel', 'loop')", name='check_agent_type'),
|
||||
CheckConstraint(
|
||||
"type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type"
|
||||
),
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Converts the object to a dictionary, converting UUIDs to strings"""
|
||||
result = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if key.startswith('_'):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if isinstance(value, uuid.UUID):
|
||||
result[key] = str(value)
|
||||
elif isinstance(value, dict):
|
||||
result[key] = self._convert_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
|
||||
result[key] = [
|
||||
(
|
||||
self._convert_dict(item)
|
||||
if isinstance(item, dict)
|
||||
else str(item) if isinstance(item, uuid.UUID) else item
|
||||
)
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
@ -88,11 +115,19 @@ class Agent(Base):
|
||||
elif isinstance(value, dict):
|
||||
result[key] = self._convert_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
|
||||
result[key] = [
|
||||
(
|
||||
self._convert_dict(item)
|
||||
if isinstance(item, dict)
|
||||
else str(item) if isinstance(item, uuid.UUID) else item
|
||||
)
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
class MCPServer(Base):
|
||||
__tablename__ = "mcp_servers"
|
||||
|
||||
@ -107,9 +142,12 @@ class MCPServer(Base):
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("type IN ('official', 'community')", name='check_mcp_server_type'),
|
||||
CheckConstraint(
|
||||
"type IN ('official', 'community')", name="check_mcp_server_type"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Tool(Base):
|
||||
__tablename__ = "tools"
|
||||
|
||||
@ -121,10 +159,11 @@ class Tool(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class Session(Base):
|
||||
__tablename__ = "sessions"
|
||||
# The directive below makes Alembic ignore this table in migrations
|
||||
__table_args__ = {'extend_existing': True, 'info': {'skip_autogenerate': True}}
|
||||
__table_args__ = {"extend_existing": True, "info": {"skip_autogenerate": True}}
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
app_name = Column(String)
|
||||
@ -133,11 +172,14 @@ class Session(Base):
|
||||
create_time = Column(DateTime(timezone=True))
|
||||
update_time = Column(DateTime(timezone=True))
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
action = Column(String, nullable=False)
|
||||
resource_type = Column(String, nullable=False)
|
||||
resource_id = Column(String, nullable=True)
|
||||
|
@ -1,26 +1,38 @@
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from typing import List, Optional, Dict, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""Configuration of a tool"""
|
||||
|
||||
id: UUID
|
||||
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the tool")
|
||||
envs: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Environment variables of the tool"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration of an MCP server"""
|
||||
|
||||
id: UUID
|
||||
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the server")
|
||||
tools: List[str] = Field(default_factory=list, description="List of tools of the server")
|
||||
envs: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Environment variables of the server"
|
||||
)
|
||||
tools: List[str] = Field(
|
||||
default_factory=list, description="List of tools of the server"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPToolParameter(BaseModel):
|
||||
"""Parameter of an HTTP tool"""
|
||||
|
||||
type: str
|
||||
required: bool
|
||||
description: str
|
||||
@ -28,8 +40,10 @@ class HTTPToolParameter(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPToolParameters(BaseModel):
|
||||
"""Parameters of an HTTP tool"""
|
||||
|
||||
path_params: Optional[Dict[str, str]] = None
|
||||
query_params: Optional[Dict[str, Union[str, List[str]]]] = None
|
||||
body_params: Optional[Dict[str, HTTPToolParameter]] = None
|
||||
@ -37,8 +51,10 @@ class HTTPToolParameters(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPToolErrorHandling(BaseModel):
|
||||
"""Configuration of error handling"""
|
||||
|
||||
timeout: int
|
||||
retry_count: int
|
||||
fallback_response: Dict[str, str]
|
||||
@ -46,8 +62,10 @@ class HTTPToolErrorHandling(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class HTTPTool(BaseModel):
|
||||
"""Configuration of an HTTP tool"""
|
||||
|
||||
name: str
|
||||
method: str
|
||||
values: Dict[str, str]
|
||||
@ -60,42 +78,72 @@ class HTTPTool(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CustomTools(BaseModel):
|
||||
"""Configuration of custom tools"""
|
||||
http_tools: List[HTTPTool] = Field(default_factory=list, description="List of HTTP tools")
|
||||
|
||||
http_tools: List[HTTPTool] = Field(
|
||||
default_factory=list, description="List of HTTP tools"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
"""Configuration for LLM agents"""
|
||||
tools: Optional[List[ToolConfig]] = Field(default=None, description="List of available tools")
|
||||
custom_tools: Optional[CustomTools] = Field(default=None, description="Custom tools")
|
||||
mcp_servers: Optional[List[MCPServerConfig]] = Field(default=None, description="List of MCP servers")
|
||||
sub_agents: Optional[List[UUID]] = Field(default=None, description="List of IDs of sub-agents")
|
||||
|
||||
tools: Optional[List[ToolConfig]] = Field(
|
||||
default=None, description="List of available tools"
|
||||
)
|
||||
custom_tools: Optional[CustomTools] = Field(
|
||||
default=None, description="Custom tools"
|
||||
)
|
||||
mcp_servers: Optional[List[MCPServerConfig]] = Field(
|
||||
default=None, description="List of MCP servers"
|
||||
)
|
||||
sub_agents: Optional[List[UUID]] = Field(
|
||||
default=None, description="List of IDs of sub-agents"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SequentialConfig(BaseModel):
|
||||
"""Configuration for sequential agents"""
|
||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents in execution order")
|
||||
|
||||
sub_agents: List[UUID] = Field(
|
||||
..., description="List of IDs of sub-agents in execution order"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ParallelConfig(BaseModel):
|
||||
"""Configuration for parallel agents"""
|
||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for parallel execution")
|
||||
|
||||
sub_agents: List[UUID] = Field(
|
||||
..., description="List of IDs of sub-agents for parallel execution"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LoopConfig(BaseModel):
|
||||
"""Configuration for loop agents"""
|
||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for loop execution")
|
||||
max_iterations: Optional[int] = Field(default=None, description="Maximum number of iterations")
|
||||
condition: Optional[str] = Field(default=None, description="Condition to stop the loop")
|
||||
|
||||
sub_agents: List[UUID] = Field(
|
||||
..., description="List of IDs of sub-agents for loop execution"
|
||||
)
|
||||
max_iterations: Optional[int] = Field(
|
||||
default=None, description="Maximum number of iterations"
|
||||
)
|
||||
condition: Optional[str] = Field(
|
||||
default=None, description="Condition to stop the loop"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
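As a rough sketch of what an agent configuration built from these schemas might look like (the UUIDs and environment values are placeholders, not data from this commit):

```python
from uuid import uuid4

from src.schemas.agent_config import LLMConfig, MCPServerConfig, ToolConfig

# Placeholder IDs; in practice they reference rows in the tools / mcp_servers tables.
config = LLMConfig(
    tools=[ToolConfig(id=uuid4(), envs={"API_BASE_URL": "https://example.invalid"})],
    mcp_servers=[
        MCPServerConfig(id=uuid4(), envs={"TOKEN": "placeholder"}, tools=["search"])
    ],
    sub_agents=[uuid4()],
)

print(config.model_dump())  # pydantic v2 API; pydantic==2.11.3 is pinned in pyproject.toml
```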
@ -3,19 +3,25 @@ from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class AuditLogBase(BaseModel):
|
||||
"""Base schema for audit log"""
|
||||
|
||||
action: str
|
||||
resource_type: str
|
||||
resource_id: Optional[str] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AuditLogCreate(AuditLogBase):
|
||||
"""Schema for creating audit log"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuditLogResponse(AuditLogBase):
|
||||
"""Schema for audit log response"""
|
||||
|
||||
id: UUID
|
||||
user_id: Optional[UUID] = None
|
||||
ip_address: Optional[str] = None
|
||||
@ -25,8 +31,10 @@ class AuditLogResponse(AuditLogBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AuditLogFilter(BaseModel):
|
||||
"""Schema for audit log search filters"""
|
||||
|
||||
user_id: Optional[UUID] = None
|
||||
action: Optional[str] = None
|
||||
resource_type: Optional[str] = None
|
||||
|
@ -1,21 +1,33 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""Schema for chat requests"""
|
||||
agent_id: str = Field(..., description="ID of the agent that will process the message")
|
||||
contact_id: str = Field(..., description="ID of the contact that will process the message")
|
||||
|
||||
agent_id: str = Field(
|
||||
..., description="ID of the agent that will process the message"
|
||||
)
|
||||
contact_id: str = Field(
|
||||
..., description="ID of the contact that will process the message"
|
||||
)
|
||||
message: str = Field(..., description="User message")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Schema for chat responses"""
|
||||
|
||||
response: str = Field(..., description="Agent response")
|
||||
status: str = Field(..., description="Operation status")
|
||||
error: Optional[str] = Field(None, description="Error message, if there is one")
|
||||
timestamp: str = Field(..., description="Timestamp of the response")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Schema for error responses"""
|
||||
|
||||
error: str = Field(..., description="Error message")
|
||||
status_code: int = Field(..., description="HTTP status code of the error")
|
||||
details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
|
||||
details: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional error details"
|
||||
)
|
||||
|
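For illustration, a hedged sketch of the payload these schemas validate (the IDs are placeholders):

```python
from src.schemas.chat import ChatRequest

# Placeholder UUID strings; real values identify an existing agent and contact.
payload = ChatRequest(
    agent_id="00000000-0000-0000-0000-000000000001",
    contact_id="00000000-0000-0000-0000-000000000002",
    message="Hello, agent!",
)

print(payload.model_dump())
```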
@ -4,15 +4,18 @@ from datetime import datetime
|
||||
from uuid import UUID
|
||||
import uuid
|
||||
import re
|
||||
from .agent_config import LLMConfig, SequentialConfig, ParallelConfig, LoopConfig
|
||||
from src.schemas.agent_config import LLMConfig
|
||||
|
||||
|
||||
class ClientBase(BaseModel):
|
||||
name: str
|
||||
email: Optional[EmailStr] = None
|
||||
|
||||
|
||||
class ClientCreate(ClientBase):
|
||||
pass
|
||||
|
||||
|
||||
class Client(ClientBase):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
@ -20,14 +23,17 @@ class Client(ClientBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ContactBase(BaseModel):
|
||||
ext_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
meta: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ContactCreate(ContactBase):
|
||||
client_id: UUID
|
||||
|
||||
|
||||
class Contact(ContactBase):
|
||||
id: UUID
|
||||
client_id: UUID
|
||||
@ -35,67 +41,80 @@ class Contact(ContactBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AgentBase(BaseModel):
|
||||
name: str = Field(..., description="Agent name (no spaces or special characters)")
|
||||
description: Optional[str] = Field(None, description="Agent description")
|
||||
type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
|
||||
model: Optional[str] = Field(None, description="Agent model (required only for llm type)")
|
||||
api_key: Optional[str] = Field(None, description="Agent API Key (required only for llm type)")
|
||||
model: Optional[str] = Field(
|
||||
None, description="Agent model (required only for llm type)"
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
None, description="Agent API Key (required only for llm type)"
|
||||
)
|
||||
instruction: Optional[str] = None
|
||||
config: Union[LLMConfig, Dict[str, Any]] = Field(..., description="Agent configuration based on type")
|
||||
config: Union[LLMConfig, Dict[str, Any]] = Field(
|
||||
..., description="Agent configuration based on type"
|
||||
)
|
||||
|
||||
@validator('name')
|
||||
@validator("name")
|
||||
def validate_name(cls, v):
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
|
||||
raise ValueError('Agent name cannot contain spaces or special characters')
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("Agent name cannot contain spaces or special characters")
|
||||
return v
|
||||
|
||||
@validator('type')
|
||||
@validator("type")
|
||||
def validate_type(cls, v):
|
||||
if v not in ['llm', 'sequential', 'parallel', 'loop']:
|
||||
raise ValueError('Invalid agent type. Must be: llm, sequential, parallel or loop')
|
||||
if v not in ["llm", "sequential", "parallel", "loop"]:
|
||||
raise ValueError(
|
||||
"Invalid agent type. Must be: llm, sequential, parallel or loop"
|
||||
)
|
||||
return v
|
||||
|
||||
@validator('model')
|
||||
@validator("model")
|
||||
def validate_model(cls, v, values):
|
||||
if 'type' in values and values['type'] == 'llm' and not v:
|
||||
raise ValueError('Model is required for llm type agents')
|
||||
if "type" in values and values["type"] == "llm" and not v:
|
||||
raise ValueError("Model is required for llm type agents")
|
||||
return v
|
||||
|
||||
@validator('api_key')
|
||||
@validator("api_key")
|
||||
def validate_api_key(cls, v, values):
|
||||
if 'type' in values and values['type'] == 'llm' and not v:
|
||||
raise ValueError('API Key is required for llm type agents')
|
||||
if "type" in values and values["type"] == "llm" and not v:
|
||||
raise ValueError("API Key is required for llm type agents")
|
||||
return v
|
||||
|
||||
@validator('config')
|
||||
@validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
if 'type' not in values:
|
||||
if "type" not in values:
|
||||
return v
|
||||
|
||||
if values['type'] == 'llm':
|
||||
if values["type"] == "llm":
|
||||
if isinstance(v, dict):
|
||||
try:
|
||||
# Convert the dictionary to LLMConfig
|
||||
v = LLMConfig(**v)
|
||||
except Exception as e:
|
||||
raise ValueError(f'Invalid LLM configuration for agent: {str(e)}')
|
||||
raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
|
||||
elif not isinstance(v, LLMConfig):
|
||||
raise ValueError('Invalid LLM configuration for agent')
|
||||
elif values['type'] in ['sequential', 'parallel', 'loop']:
|
||||
raise ValueError("Invalid LLM configuration for agent")
|
||||
elif values["type"] in ["sequential", "parallel", "loop"]:
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError(f'Invalid configuration for agent {values["type"]}')
|
||||
if 'sub_agents' not in v:
|
||||
if "sub_agents" not in v:
|
||||
raise ValueError(f'Agent {values["type"]} must have sub_agents')
|
||||
if not isinstance(v['sub_agents'], list):
|
||||
raise ValueError('sub_agents must be a list')
|
||||
if not v['sub_agents']:
|
||||
raise ValueError(f'Agent {values["type"]} must have at least one sub-agent')
|
||||
if not isinstance(v["sub_agents"], list):
|
||||
raise ValueError("sub_agents must be a list")
|
||||
if not v["sub_agents"]:
|
||||
raise ValueError(
|
||||
f'Agent {values["type"]} must have at least one sub-agent'
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class AgentCreate(AgentBase):
|
||||
client_id: UUID
|
||||
|
||||
|
||||
class Agent(AgentBase):
|
||||
id: UUID
|
||||
client_id: UUID
|
||||
@ -105,6 +124,7 @@ class Agent(AgentBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MCPServerBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
@ -113,9 +133,11 @@ class MCPServerBase(BaseModel):
|
||||
tools: List[str] = Field(default_factory=list)
|
||||
type: str = Field(default="official")
|
||||
|
||||
|
||||
class MCPServerCreate(MCPServerBase):
|
||||
pass
|
||||
|
||||
|
||||
class MCPServer(MCPServerBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
@ -124,15 +146,18 @@ class MCPServer(MCPServerBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ToolBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
config_json: Dict[str, Any] = Field(default_factory=dict)
|
||||
environments: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolCreate(ToolBase):
|
||||
pass
|
||||
|
||||
|
||||
class Tool(ToolBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
|
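A rough sketch of how the validators above behave; the field values are placeholders and the error handling assumes pydantic's standard `ValidationError`:

```python
from uuid import uuid4

from pydantic import ValidationError

from src.schemas.schemas import AgentCreate

try:
    AgentCreate(
        name="support agent",  # space is rejected by validate_name
        type="llm",            # llm agents also require model and api_key
        client_id=uuid4(),
        config={},
    )
except ValidationError as exc:
    print(exc.errors())

# A minimal valid llm agent: name without spaces, model, api_key and an LLM config.
agent = AgentCreate(
    name="support-agent",
    type="llm",
    model="gpt-4o-mini",       # placeholder model name
    api_key="sk-placeholder",  # placeholder key
    client_id=uuid4(),
    config={},                 # coerced into an empty LLMConfig by validate_config
)
```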
@ -1,23 +1,28 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
name: str # For client creation
|
||||
|
||||
|
||||
class AdminUserCreate(UserBase):
|
||||
password: str
|
||||
name: str
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: UUID
|
||||
client_id: Optional[UUID] = None
|
||||
@ -29,22 +34,27 @@ class UserResponse(UserBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
sub: str # user email
|
||||
exp: datetime
|
||||
is_admin: bool
|
||||
client_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class ForgotPassword(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
@@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext

from google.adk.models import LlmResponse, LlmRequest
from google.adk.tools import load_memory

from typing import Optional
import logging
import os
import requests
import os
from datetime import datetime

logger = setup_logger(__name__)


@@ -83,7 +82,7 @@ def before_model_callback(

llm_request.config.system_instruction = modified_text

logger.debug(
f"📝 System instruction updated with search results and history"
"📝 System instruction updated with search results and history"
)
else:
logger.warning("⚠️ No results found in the search")

@@ -180,7 +179,9 @@ class AgentBuilder:

mcp_tools = []
mcp_exit_stack = None
if agent.config.get("mcp_servers"):
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(agent.config, self.db)
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
agent.config, self.db
)

# Combine all tools
all_tools = custom_tools + mcp_tools

@@ -201,10 +202,13 @@ class AgentBuilder:

# Check if load_memory is enabled
# before_model_callback_func = None
if agent.config.get("load_memory") == True:
if agent.config.get("load_memory"):
all_tools.append(load_memory)
# before_model_callback_func = before_model_callback
formatted_prompt = formatted_prompt + "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
formatted_prompt = (
formatted_prompt
+ "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
)

return (
LlmAgent(
@@ -22,9 +22,7 @@ async def run_agent(

db: Session,
):
try:
logger.info(
f"Starting execution of agent {agent_id} for contact {contact_id}"
)
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
logger.info(f"Received message: {message}")

get_root_agent = get_agent(db, agent_id)
@@ -216,9 +216,7 @@ async def update_agent(

return agent
except Exception as e:
db.rollback()
raise HTTPException(
status_code=500, detail=f"Error updating agent: {str(e)}"
)
raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}")


def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:
@@ -1,15 +1,15 @@

from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from src.models.models import AuditLog, User
from src.models.models import AuditLog
from datetime import datetime
from fastapi import Request
from typing import Optional, Dict, Any, List
import uuid
import logging
import json

logger = logging.getLogger(__name__)


def create_audit_log(
db: Session,
user_id: Optional[uuid.UUID],

@@ -17,7 +17,7 @@ def create_audit_log(

resource_type: str,
resource_id: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
request: Optional[Request] = None
request: Optional[Request] = None,
) -> Optional[AuditLog]:
"""
Create a new audit log

@@ -39,7 +39,7 @@ def create_audit_log(

user_agent = None

if request:
ip_address = request.client.host if hasattr(request, 'client') else None
ip_address = request.client.host if hasattr(request, "client") else None
user_agent = request.headers.get("user-agent")

# Convert details to serializable format

@@ -56,7 +56,7 @@ def create_audit_log(

resource_id=str(resource_id) if resource_id else None,
details=details,
ip_address=ip_address,
user_agent=user_agent
user_agent=user_agent,
)

db.add(audit_log)

@@ -64,8 +64,8 @@ def create_audit_log(

db.refresh(audit_log)

logger.info(
f"Audit log created: {action} in {resource_type}" +
(f" (ID: {resource_id})" if resource_id else "")
f"Audit log created: {action} in {resource_type}"
+ (f" (ID: {resource_id})" if resource_id else "")
)

return audit_log

@@ -78,6 +78,7 @@ def create_audit_log(

logger.error(f"Unexpected error creating audit log: {str(e)}")
return None


def get_audit_logs(
db: Session,
skip: int = 0,

@@ -87,7 +88,7 @@ def get_audit_logs(

resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
end_date: Optional[datetime] = None,
) -> List[AuditLog]:
"""
Get audit logs with optional filters
@@ -16,7 +16,10 @@ logger = logging.getLogger(__name__)

# Define OAuth2 authentication scheme with password flow
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")

async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:

async def get_current_user(
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> User:
"""
Get the current user from the JWT token

@@ -39,9 +42,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De

try:
# Decode the token
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)

# Extract token data

@@ -61,7 +62,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De

sub=email,
exp=datetime.fromtimestamp(exp),
is_admin=payload.get("is_admin", False),
client_id=payload.get("client_id")
client_id=payload.get("client_id"),
)

except JWTError as e:

@@ -77,13 +78,15 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De

if not user.is_active:
logger.warning(f"Attempt to access inactive user: {user.email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)

return user

async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:

async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
Check if the current user is active

@@ -99,12 +102,14 @@ async def get_current_active_user(current_user: User = Depends(get_current_user)

if not current_user.is_active:
logger.warning(f"Attempt to access inactive user: {current_user.email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return current_user

async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:

async def get_current_admin_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
Check if the current user is an administrator

@@ -118,13 +123,16 @@ async def get_current_admin_user(current_user: User = Depends(get_current_user))

HTTPException: If the user is not an administrator
"""
if not current_user.is_admin:
logger.warning(f"Attempt to access admin by non-admin user: {current_user.email}")
logger.warning(
f"Attempt to access admin by non-admin user: {current_user.email}"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. Restricted to administrators."
detail="Access denied. Restricted to administrators.",
)
return current_user


def create_access_token(user: User) -> str:
"""
Create a JWT access token for the user
@@ -11,6 +11,7 @@ import logging

logger = logging.getLogger(__name__)


def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
"""Search for a client by ID"""
try:

@@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:

logger.error(f"Error searching for client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for client"
detail="Error searching for client",
)


def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
"""Search for all clients with pagination"""
try:

@@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:

logger.error(f"Error searching for clients: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for clients"
detail="Error searching for clients",
)


def create_client(db: Session, client: ClientCreate) -> Client:
"""Create a new client"""
try:

@@ -51,10 +54,13 @@ def create_client(db: Session, client: ClientCreate) -> Client:

logger.error(f"Error creating client: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating client"
detail="Error creating client",
)

def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Optional[Client]:

def update_client(
db: Session, client_id: uuid.UUID, client: ClientCreate
) -> Optional[Client]:
"""Updates an existing client"""
try:
db_client = get_client(db, client_id)

@@ -73,9 +79,10 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op

logger.error(f"Error updating client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating client"
detail="Error updating client",
)


def delete_client(db: Session, client_id: uuid.UUID) -> bool:
"""Removes a client"""
try:

@@ -92,10 +99,13 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool:

logger.error(f"Error removing client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing client"
detail="Error removing client",
)

def create_client_with_user(db: Session, client_data: ClientCreate, user_data: UserCreate) -> Tuple[Optional[Client], str]:

def create_client_with_user(
db: Session, client_data: ClientCreate, user_data: UserCreate
) -> Tuple[Optional[Client], str]:
"""
Creates a new client with an associated user
@@ -9,6 +9,7 @@ import logging

logger = logging.getLogger(__name__)


def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
"""Search for a contact by ID"""
try:

@@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:

logger.error(f"Error searching for contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contact"
detail="Error searching for contact",
)

def get_contacts_by_client(db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[Contact]:

def get_contacts_by_client(
db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100
) -> List[Contact]:
"""Search for contacts of a client with pagination"""
try:
return db.query(Contact).filter(Contact.client_id == client_id).offset(skip).limit(limit).all()
return (
db.query(Contact)
.filter(Contact.client_id == client_id)
.offset(skip)
.limit(limit)
.all()
)
except SQLAlchemyError as e:
logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for contacts"
detail="Error searching for contacts",
)


def create_contact(db: Session, contact: ContactCreate) -> Contact:
"""Create a new contact"""
try:

@@ -49,10 +60,13 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact:

logger.error(f"Error creating contact: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating contact"
detail="Error creating contact",
)

def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -> Optional[Contact]:

def update_contact(
db: Session, contact_id: uuid.UUID, contact: ContactCreate
) -> Optional[Contact]:
"""Update an existing contact"""
try:
db_contact = get_contact(db, contact_id)

@@ -71,9 +85,10 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -

logger.error(f"Error updating contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating contact"
detail="Error updating contact",
)


def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
"""Remove a contact"""
try:

@@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:

logger.error(f"Error removing contact {contact_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing contact"
detail="Error removing contact",
)
@@ -6,6 +6,7 @@ from src.utils.logger import setup_logger

logger = setup_logger(__name__)


class CustomToolBuilder:
def __init__(self):
self.tools = []

@@ -53,7 +54,9 @@ class CustomToolBuilder:

# Adds default values to query params if they are not present
for param, value in values.items():
if param not in query_params and param not in parameters.get("path_params", {}):
if param not in query_params and param not in parameters.get(
"path_params", {}
):
query_params[param] = value

# Process body parameters

@@ -64,7 +67,11 @@ class CustomToolBuilder:

# Adds default values to body if they are not present
for param, value in values.items():
if param not in body_data and param not in query_params and param not in parameters.get("path_params", {}):
if (
param not in body_data
and param not in query_params
and param not in parameters.get("path_params", {})
):
body_data[param] = value

# Makes the HTTP request

@@ -74,7 +81,7 @@ class CustomToolBuilder:

headers=processed_headers,
params=query_params,
json=body_data if body_data else None,
timeout=error_handling.get("timeout", 30)
timeout=error_handling.get("timeout", 30),
)

if response.status_code >= 400:

@@ -87,10 +94,12 @@ class CustomToolBuilder:

except Exception as e:
logger.error(f"Error executing tool {name}: {str(e)}")
return json.dumps(error_handling.get("fallback_response", {
"error": "tool_execution_error",
"message": str(e)
}))
return json.dumps(
error_handling.get(
"fallback_response",
{"error": "tool_execution_error", "message": str(e)},
)
)

# Adds dynamic docstring based on the configuration
param_docs = []

@@ -109,7 +118,9 @@ class CustomToolBuilder:

# Adds body parameters
for param, param_config in parameters.get("body_params", {}).items():
required = "Required" if param_config.get("required", False) else "Optional"
param_docs.append(f"{param} ({param_config['type']}, {required}): {param_config['description']}")
param_docs.append(
f"{param} ({param_config['type']}, {required}): {param_config['description']}"
)

# Adds default values
if values:
@@ -16,9 +16,10 @@ os.makedirs(templates_dir, exist_ok=True)

# Configure Jinja2 with the templates directory
env = Environment(
loader=FileSystemLoader(templates_dir),
autoescape=select_autoescape(['html', 'xml'])
autoescape=select_autoescape(["html", "xml"]),
)


def _render_template(template_name: str, context: dict) -> str:
"""
Render a template with the provided data

@@ -37,6 +38,7 @@ def _render_template(template_name: str, context: dict) -> str:

logger.error(f"Error rendering template '{template_name}': {str(e)}")
return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"


def send_verification_email(email: str, token: str) -> bool:
"""
Send a verification email to the user

@@ -56,11 +58,16 @@ def send_verification_email(email: str, token: str) -> bool:

verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"

html_content = _render_template('verification_email', {
'verification_link': verification_link,
'user_name': email.split('@')[0], # Use part of the email as temporary name
'current_year': datetime.now().year
})
html_content = _render_template(
"verification_email",
{
"verification_link": verification_link,
"user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)

content = Content("text/html", html_content)

@@ -71,13 +78,16 @@ def send_verification_email(email: str, token: str) -> bool:

logger.info(f"Verification email sent to {email}")
return True
else:
logger.error(f"Failed to send verification email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send verification email to {email}. Status: {response.status_code}"
)
return False

except Exception as e:
logger.error(f"Error sending verification email to {email}: {str(e)}")
return False


def send_password_reset_email(email: str, token: str) -> bool:
"""
Send a password reset email to the user

@@ -97,11 +107,16 @@ def send_password_reset_email(email: str, token: str) -> bool:

reset_link = f"{settings.APP_URL}/reset-password?token={token}"

html_content = _render_template('password_reset', {
'reset_link': reset_link,
'user_name': email.split('@')[0], # Use part of the email as temporary name
'current_year': datetime.now().year
})
html_content = _render_template(
"password_reset",
{
"reset_link": reset_link,
"user_name": email.split("@")[
0
], # Use part of the email as temporary name
"current_year": datetime.now().year,
},
)

content = Content("text/html", html_content)

@@ -112,13 +127,16 @@ def send_password_reset_email(email: str, token: str) -> bool:

logger.info(f"Password reset email sent to {email}")
return True
else:
logger.error(f"Failed to send password reset email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send password reset email to {email}. Status: {response.status_code}"
)
return False

except Exception as e:
logger.error(f"Error sending password reset email to {email}: {str(e)}")
return False


def send_welcome_email(email: str, user_name: str = None) -> bool:
"""
Send a welcome email to the user after verification

@@ -138,11 +156,14 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:

dashboard_link = f"{settings.APP_URL}/dashboard"

html_content = _render_template('welcome_email', {
'dashboard_link': dashboard_link,
'user_name': user_name or email.split('@')[0],
'current_year': datetime.now().year
})
html_content = _render_template(
"welcome_email",
{
"dashboard_link": dashboard_link,
"user_name": user_name or email.split("@")[0],
"current_year": datetime.now().year,
},
)

content = Content("text/html", html_content)

@@ -153,14 +174,19 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:

logger.info(f"Welcome email sent to {email}")
return True
else:
logger.error(f"Failed to send welcome email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send welcome email to {email}. Status: {response.status_code}"
)
return False

except Exception as e:
logger.error(f"Error sending welcome email to {email}: {str(e)}")
return False

def send_account_locked_email(email: str, reset_token: str, failed_attempts: int, time_period: str) -> bool:

def send_account_locked_email(
email: str, reset_token: str, failed_attempts: int, time_period: str
) -> bool:
"""
Send an email informing that the account has been locked after login attempts

@@ -181,13 +207,16 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int

reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"

html_content = _render_template('account_locked', {
'reset_link': reset_link,
'user_name': email.split('@')[0],
'failed_attempts': failed_attempts,
'time_period': time_period,
'current_year': datetime.now().year
})
html_content = _render_template(
"account_locked",
{
"reset_link": reset_link,
"user_name": email.split("@")[0],
"failed_attempts": failed_attempts,
"time_period": time_period,
"current_year": datetime.now().year,
},
)

content = Content("text/html", html_content)

@@ -198,7 +227,9 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int

logger.info(f"Account locked email sent to {email}")
return True
else:
logger.error(f"Failed to send account locked email to {email}. Status: {response.status_code}")
logger.error(
f"Failed to send account locked email to {email}. Status: {response.status_code}"
)
return False

except Exception as e:
@@ -9,6 +9,7 @@ import logging

logger = logging.getLogger(__name__)


def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
"""Search for an MCP server by ID"""
try:

@@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:

logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP server"
detail="Error searching for MCP server",
)


def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
"""Search for all MCP servers with pagination"""
try:

@@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer

logger.error(f"Error searching for MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for MCP servers"
detail="Error searching for MCP servers",
)


def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
"""Create a new MCP server"""
try:

@@ -49,10 +52,13 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:

logger.error(f"Error creating MCP server: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating MCP server"
detail="Error creating MCP server",
)

def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate) -> Optional[MCPServer]:

def update_mcp_server(
db: Session, server_id: uuid.UUID, server: MCPServerCreate
) -> Optional[MCPServer]:
"""Update an existing MCP server"""
try:
db_server = get_mcp_server(db, server_id)

@@ -71,9 +77,10 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate

logger.error(f"Error updating MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating MCP server"
detail="Error updating MCP server",
)


def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
"""Remove an MCP server"""
try:

@@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:

logger.error(f"Error removing MCP server {server_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing MCP server"
detail="Error removing MCP server",
)
@@ -12,20 +12,22 @@ from sqlalchemy.orm import Session

logger = setup_logger(__name__)


class MCPService:
def __init__(self):
self.tools = []
self.exit_stack = AsyncExitStack()

async def _connect_to_mcp_server(self, server_config: Dict[str, Any]) -> Tuple[List[Any], Optional[AsyncExitStack]]:
async def _connect_to_mcp_server(
self, server_config: Dict[str, Any]
) -> Tuple[List[Any], Optional[AsyncExitStack]]:
"""Connect to a specific MCP server and return its tools."""
try:
# Determines the type of server (local or remote)
if "url" in server_config:
# Remote server (SSE)
connection_params = SseServerParams(
url=server_config["url"],
headers=server_config.get("headers", {})
url=server_config["url"], headers=server_config.get("headers", {})
)
else:
# Local server (Stdio)

@@ -39,9 +41,7 @@ class MCPService:

os.environ[key] = value

connection_params = StdioServerParameters(
command=command,
args=args,
env=env
command=command, args=args, env=env
)

tools, exit_stack = await MCPToolset.from_server(

@@ -74,7 +74,9 @@ class MCPService:

return filtered_tools

def _filter_tools_by_agent(self, tools: List[Any], agent_tools: List[str]) -> List[Any]:
def _filter_tools_by_agent(
self, tools: List[Any], agent_tools: List[str]
) -> List[Any]:
"""Filters tools compatible with the agent."""
filtered_tools = []
for tool in tools:

@@ -83,7 +85,9 @@ class MCPService:

filtered_tools.append(tool)
return filtered_tools

async def build_tools(self, mcp_config: Dict[str, Any], db: Session) -> Tuple[List[Any], AsyncExitStack]:
async def build_tools(
self, mcp_config: Dict[str, Any], db: Session
) -> Tuple[List[Any], AsyncExitStack]:
"""Builds a list of tools from multiple MCP servers."""
self.tools = []
self.exit_stack = AsyncExitStack()

@@ -92,7 +96,7 @@ class MCPService:

for server in mcp_config.get("mcp_servers", []):
try:
# Search for the MCP server in the database
mcp_server = get_mcp_server(db, server['id'])
mcp_server = get_mcp_server(db, server["id"])
if not mcp_server:
logger.warning(f"MCP server not found: {server['id']}")
continue

@@ -101,14 +105,16 @@ class MCPService:

server_config = mcp_server.config_json.copy()

# Replaces the environment variables in the config_json
if 'env' in server_config:
for key, value in server_config['env'].items():
if value.startswith('env@@'):
env_key = value.replace('env@@', '')
if env_key in server.get('envs', {}):
server_config['env'][key] = server['envs'][env_key]
if "env" in server_config:
for key, value in server_config["env"].items():
if value.startswith("env@@"):
env_key = value.replace("env@@", "")
if env_key in server.get("envs", {}):
server_config["env"][key] = server["envs"][env_key]
else:
logger.warning(f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}")
logger.warning(
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
)
continue

logger.info(f"Connecting to MCP server: {mcp_server.name}")

@@ -119,20 +125,28 @@ class MCPService:

filtered_tools = self._filter_incompatible_tools(tools)

# Filters tools compatible with the agent
agent_tools = server.get('tools', [])
filtered_tools = self._filter_tools_by_agent(filtered_tools, agent_tools)
agent_tools = server.get("tools", [])
filtered_tools = self._filter_tools_by_agent(
filtered_tools, agent_tools
)
self.tools.extend(filtered_tools)

# Registers the exit_stack with the AsyncExitStack
await self.exit_stack.enter_async_context(exit_stack)
logger.info(f"Connected successfully. Added {len(filtered_tools)} tools.")
logger.info(
f"Connected successfully. Added {len(filtered_tools)} tools."
)
else:
logger.warning(f"Failed to connect or no tools available for {mcp_server.name}")
logger.warning(
f"Failed to connect or no tools available for {mcp_server.name}"
)

except Exception as e:
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
continue

logger.info(f"MCP Toolset created successfully. Total of {len(self.tools)} tools.")
logger.info(
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
)

return self.tools, self.exit_stack
src/services/service_providers.py (new file, +9 lines)

@@ -0,0 +1,9 @@
from src.config.settings import settings
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService

# Initialize service instances
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
artifacts_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
@@ -132,7 +132,7 @@ def get_session_events(

session = get_session_by_id(session_service, session_id)
# If we get here, the session exists (get_session_by_id already validates)

if not hasattr(session, 'events') or session.events is None:
if not hasattr(session, "events") or session.events is None:
return []

return session.events
@@ -9,6 +9,7 @@ import logging

logger = logging.getLogger(__name__)


def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
"""Search for a tool by ID"""
try:

@@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:

logger.error(f"Error searching for tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tool"
detail="Error searching for tool",
)


def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
"""Search for all tools with pagination"""
try:

@@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:

logger.error(f"Error searching for tools: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error searching for tools"
detail="Error searching for tools",
)


def create_tool(db: Session, tool: ToolCreate) -> Tool:
"""Creates a new tool"""
try:

@@ -49,9 +52,10 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool:

logger.error(f"Error creating tool: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error creating tool"
detail="Error creating tool",
)


def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
"""Updates an existing tool"""
try:

@@ -71,9 +75,10 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T

logger.error(f"Error updating tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating tool"
detail="Error updating tool",
)


def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
"""Remove a tool"""
try:

@@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:

logger.error(f"Error removing tool {tool_id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error removing tool"
detail="Error removing tool",
)
@@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError

from src.models.models import User, Client
from src.schemas.user import UserCreate
from src.utils.security import get_password_hash, verify_password, generate_token
from src.services.email_service import send_verification_email, send_password_reset_email
from src.services.email_service import (
send_verification_email,
send_password_reset_email,
)
from datetime import datetime, timedelta
import uuid
import logging

@@ -11,7 +14,13 @@ from typing import Optional, Tuple

logger = logging.getLogger(__name__)

def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, client_id: Optional[uuid.UUID] = None) -> Tuple[Optional[User], str]:

def create_user(
db: Session,
user_data: UserCreate,
is_admin: bool = False,
client_id: Optional[uuid.UUID] = None,
) -> Tuple[Optional[User], str]:
"""
Creates a new user in the system

@@ -28,7 +37,9 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie

# Check if email already exists
db_user = db.query(User).filter(User.email == user_data.email).first()
if db_user:
logger.warning(f"Attempt to register with existing email: {user_data.email}")
logger.warning(
f"Attempt to register with existing email: {user_data.email}"
)
return None, "Email already registered"

# Create verification token

@@ -56,7 +67,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie

is_active=False, # Inactive until email is verified
email_verified=False,
verification_token=verification_token,
verification_token_expiry=token_expiry
verification_token_expiry=token_expiry,
)
db.add(user)
db.commit()

@@ -68,7 +79,10 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie

# We don't do rollback here, we just log the error

logger.info(f"User created successfully: {user.email}")
return user, "User created successfully. Check your email to activate your account."
return (
user,
"User created successfully. Check your email to activate your account.",
)

except SQLAlchemyError as e:
db.rollback()

@@ -79,6 +93,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie

logger.error(f"Unexpected error creating user: {str(e)}")
return None, f"Unexpected error: {str(e)}"


def verify_email(db: Session, token: str) -> Tuple[bool, str]:
"""
Verify the user's email using the provided token

@@ -111,7 +126,9 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:

expiry = expiry.replace(tzinfo=now.tzinfo)

if expiry < now:
logger.warning(f"Attempt to verify with expired token for user: {user.email}")
logger.warning(
f"Attempt to verify with expired token for user: {user.email}"
)
return False, "Verification token expired"

# Update user

@@ -133,6 +150,7 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:

logger.error(f"Unexpected error verifying email: {str(e)}")
return False, f"Unexpected error: {str(e)}"


def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
"""
Resend the verification email

@@ -149,11 +167,15 @@ def resend_verification(db: Session, email: str) -> Tuple[bool, str]:

user = db.query(User).filter(User.email == email).first()

if not user:
logger.warning(f"Attempt to resend verification email for non-existent email: {email}")
logger.warning(
f"Attempt to resend verification email for non-existent email: {email}"
)
return False, "Email not found"

if user.email_verified:
logger.info(f"Attempt to resend verification email for already verified email: {email}")
logger.info(
f"Attempt to resend verification email for already verified email: {email}"
)
return False, "Email already verified"

# Generate new token

@@ -184,6 +206,7 @@ def resend_verification(db: Session, email: str) -> Tuple[bool, str]:

logger.error(f"Unexpected error resending verification: {str(e)}")
return False, f"Unexpected error: {str(e)}"


def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
"""
Initiates the password recovery process

@@ -202,7 +225,10 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:

if not user:
# For security, we don't inform if the email exists or not
logger.info(f"Attempt to recover password for non-existent email: {email}")
return True, "If the email is registered, you will receive instructions to reset your password."
return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)

# Generate reset token
reset_token = generate_token()

@@ -221,7 +247,10 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:

return False, "Failed to send password reset email"

logger.info(f"Password reset email sent successfully to: {user.email}")
return True, "If the email is registered, you will receive instructions to reset your password."
return (
True,
"If the email is registered, you will receive instructions to reset your password.",
)

except SQLAlchemyError as e:
db.rollback()

@@ -232,6 +261,7 @@ def forgot_password(db: Session, email: str) -> Tuple[bool, str]:

logger.error(f"Unexpected error processing password recovery: {str(e)}")
return False, f"Unexpected error: {str(e)}"


def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
"""
Resets the user's password using the provided token

@@ -254,7 +284,9 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st

# Check if the token has expired
if user.password_reset_expiry < datetime.utcnow():
logger.warning(f"Attempt to reset password with expired token for user: {user.email}")
logger.warning(
f"Attempt to reset password with expired token for user: {user.email}"
)
return False, "Password reset token expired"

# Update password

@@ -264,7 +296,10 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st

db.commit()
logger.info(f"Password reset successfully for user: {user.email}")
return True, "Password reset successfully. You can now login with your new password."
return (
True,
"Password reset successfully. You can now login with your new password.",
)

except SQLAlchemyError as e:
db.rollback()

@@ -275,6 +310,7 @@ def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, st

logger.error(f"Unexpected error resetting password: {str(e)}")
return False, f"Unexpected error: {str(e)}"


def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""
Searches for a user by email

@@ -292,6 +328,7 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]:

logger.error(f"Error searching for user by email: {str(e)}")
return None


def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticates a user with email and password

@@ -313,6 +350,7 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:

return None
return user


def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
"""
Lists the admin users

@@ -326,7 +364,7 @@ def get_admin_users(db: Session, skip: int = 0, limit: int = 100):

List[User]: List of admin users
"""
try:
users = db.query(User).filter(User.is_admin == True).offset(skip).limit(limit).all()
users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all()
logger.info(f"List of admins: {len(users)} found")
return users

@@ -338,6 +376,7 @@ def get_admin_users(db: Session, skip: int = 0, limit: int = 100):

logger.error(f"Unexpected error listing admins: {str(e)}")
return []


def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
"""
Creates a new admin user

@@ -351,6 +390,7 @@ def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User

"""
return create_user(db, user_data, is_admin=True)


def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
"""
Deactivates a user (does not delete, only marks as inactive)
@@ -3,6 +3,7 @@ import os

import sys
from src.config.settings import settings


class CustomFormatter(logging.Formatter):
"""Custom formatter for logs"""

@@ -12,14 +13,16 @@ class CustomFormatter(logging.Formatter):

bold_red = "\x1b[31;1m"
reset = "\x1b[0m"

format_template = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
format_template = (
"%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
)

FORMATS = {
logging.DEBUG: grey + format_template + reset,
logging.INFO: grey + format_template + reset,
logging.WARNING: yellow + format_template + reset,
logging.ERROR: red + format_template + reset,
logging.CRITICAL: bold_red + format_template + reset
logging.CRITICAL: bold_red + format_template + reset,
}

def format(self, record):

@@ -27,6 +30,7 @@ class CustomFormatter(logging.Formatter):

formatter = logging.Formatter(log_fmt)
return formatter.format(record)


def setup_logger(name: str) -> logging.Logger:
"""
Configures a custom logger
@@ -11,7 +11,8 @@ from dataclasses import dataclass

logger = logging.getLogger(__name__)

# Fix bcrypt error with passlib
if not hasattr(bcrypt, '__about__'):
if not hasattr(bcrypt, "__about__"):

@dataclass
class BcryptAbout:
__version__: str = getattr(bcrypt, "__version__")

@@ -21,31 +22,33 @@ if not hasattr(bcrypt, '__about__'):

# Context for password hashing using bcrypt
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_password_hash(password: str) -> str:
"""Creates a password hash"""
return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies if the provided password matches the stored hash"""
return pwd_context.verify(plain_password, hashed_password)


def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str:
"""Creates a JWT token"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.JWT_EXPIRATION_TIME
)
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRATION_TIME)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
return encoded_jwt


def generate_token(length: int = 32) -> str:
"""Generates a secure token for email verification or password reset"""
alphabet = string.ascii_letters + string.digits
token = ''.join(secrets.choice(alphabet) for _ in range(length))
token = "".join(secrets.choice(alphabet) for _ in range(length))
return token
tests/__init__.py (new file, +1 line)

@@ -0,0 +1 @@
# Package initialization for tests

tests/api/__init__.py (new file, +1 line)

@@ -0,0 +1 @@
# API tests package

tests/api/test_root.py (new file, +11 lines)

@@ -0,0 +1,11 @@
def test_read_root(client):
    """
    Test that the root endpoint returns the correct response.
    """
    response = client.get("/")
    assert response.status_code == 200
    data = response.json()
    assert "message" in data
    assert "documentation" in data
    assert "version" in data
    assert "auth" in data
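The `test_read_root` test above receives a `client` fixture that is not part of this diff; it would normally be defined in a `conftest.py`. A minimal sketch of what such a fixture could look like, assuming the FastAPI application object is importable from `src.main` (that entrypoint path is an assumption, not something shown in this commit):

```python
# Hypothetical tests/conftest.py — not included in this commit.
# Assumes the FastAPI app instance lives in src.main; adjust the import to the real entrypoint.
import pytest
from fastapi.testclient import TestClient

from src.main import app  # assumption: actual module/attribute may differ


@pytest.fixture
def client():
    """Provide a test client for API route tests such as test_read_root."""
    return TestClient(app)
```

With a fixture like this in place, the suite can be run with a plain `pytest tests/`.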
tests/services/__init__.py (new file, +1 line)

@@ -0,0 +1 @@
# Services tests package

tests/services/test_auth_service.py (new file, +27 lines)

@@ -0,0 +1,27 @@
from src.services.auth_service import create_access_token
from src.models.models import User
import uuid


def test_create_access_token():
    """
    Test that an access token is created with the correct data.
    """
    # Create a mock user
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        hashed_password="hashed_password",
        is_active=True,
        is_admin=False,
        name="Test User",
        client_id=uuid.uuid4(),
    )

    # Create token
    token = create_access_token(user)

    # Simple validation
    assert token is not None
    assert isinstance(token, str)
    assert len(token) > 0