chore: update project structure and add testing framework
This commit is contained in:
parent
7af234ef48
commit
e7e030dfd5
77
.cursorrules
77
.cursorrules
@ -17,9 +17,16 @@
|
|||||||
```
|
```
|
||||||
src/
|
src/
|
||||||
├── api/
|
├── api/
|
||||||
│ ├── routes.py # API routes definition
|
│ ├── __init__.py # Package initialization
|
||||||
│ ├── auth_routes.py # Authentication routes (login, registration, etc.)
|
│ ├── admin_routes.py # Admin routes for management interface
|
||||||
│ └── admin_routes.py # Protected admin routes
|
│ ├── agent_routes.py # Routes to manage agents
|
||||||
|
│ ├── auth_routes.py # Authentication routes (login, registration)
|
||||||
|
│ ├── chat_routes.py # Routes for chat interactions with agents
|
||||||
|
│ ├── client_routes.py # Routes to manage clients
|
||||||
|
│ ├── contact_routes.py # Routes to manage contacts
|
||||||
|
│ ├── mcp_server_routes.py # Routes to manage MCP servers
|
||||||
|
│ ├── session_routes.py # Routes to manage chat sessions
|
||||||
|
│ └── tool_routes.py # Routes to manage tools for agents
|
||||||
├── config/
|
├── config/
|
||||||
│ ├── database.py # Database configuration
|
│ ├── database.py # Database configuration
|
||||||
│ └── settings.py # General settings
|
│ └── settings.py # General settings
|
||||||
@ -30,18 +37,21 @@ src/
|
|||||||
│ └── models.py # SQLAlchemy models
|
│ └── models.py # SQLAlchemy models
|
||||||
├── schemas/
|
├── schemas/
|
||||||
│ ├── schemas.py # Main Pydantic schemas
|
│ ├── schemas.py # Main Pydantic schemas
|
||||||
|
│ ├── chat.py # Chat schemas
|
||||||
│ ├── user.py # User and authentication schemas
|
│ ├── user.py # User and authentication schemas
|
||||||
│ └── audit.py # Audit logs schemas
|
│ └── audit.py # Audit logs schemas
|
||||||
├── services/
|
├── services/
|
||||||
│ ├── agent_service.py # Business logic for agents
|
│ ├── agent_service.py # Business logic for agents
|
||||||
|
│ ├── agent_runner.py # Agent execution logic
|
||||||
|
│ ├── auth_service.py # JWT authentication logic
|
||||||
|
│ ├── audit_service.py # Audit logs logic
|
||||||
│ ├── client_service.py # Business logic for clients
|
│ ├── client_service.py # Business logic for clients
|
||||||
│ ├── contact_service.py # Business logic for contacts
|
│ ├── contact_service.py # Business logic for contacts
|
||||||
│ ├── mcp_server_service.py # Business logic for MCP servers
|
|
||||||
│ ├── tool_service.py # Business logic for tools
|
|
||||||
│ ├── user_service.py # User and authentication logic
|
|
||||||
│ ├── auth_service.py # JWT authentication logic
|
|
||||||
│ ├── email_service.py # Email sending service
|
│ ├── email_service.py # Email sending service
|
||||||
│ └── audit_service.py # Audit logs logic
|
│ ├── mcp_server_service.py # Business logic for MCP servers
|
||||||
|
│ ├── session_service.py # Business logic for chat sessions
|
||||||
|
│ ├── tool_service.py # Business logic for tools
|
||||||
|
│ └── user_service.py # User management logic
|
||||||
├── templates/
|
├── templates/
|
||||||
│ ├── emails/
|
│ ├── emails/
|
||||||
│ │ ├── base_email.html # Base template with common structure and styles
|
│ │ ├── base_email.html # Base template with common structure and styles
|
||||||
@ -49,7 +59,21 @@ src/
|
|||||||
│ │ ├── password_reset.html # Password reset template
|
│ │ ├── password_reset.html # Password reset template
|
||||||
│ │ ├── welcome_email.html # Welcome email after verification
|
│ │ ├── welcome_email.html # Welcome email after verification
|
||||||
│ │ └── account_locked.html # Security alert for locked accounts
|
│ │ └── account_locked.html # Security alert for locked accounts
|
||||||
|
├── tests/
|
||||||
|
│ ├── __init__.py # Package initialization
|
||||||
|
│ ├── api/
|
||||||
|
│ │ ├── __init__.py # Package initialization
|
||||||
|
│ │ ├── test_auth_routes.py # Test for authentication routes
|
||||||
|
│ │ └── test_root.py # Test for root endpoint
|
||||||
|
│ ├── models/
|
||||||
|
│ │ ├── __init__.py # Package initialization
|
||||||
|
│ │ ├── test_models.py # Test for models
|
||||||
|
│ ├── services/
|
||||||
|
│ │ ├── __init__.py # Package initialization
|
||||||
|
│ │ ├── test_auth_service.py # Test for authentication service
|
||||||
|
│ │ └── test_user_service.py # Test for user service
|
||||||
└── utils/
|
└── utils/
|
||||||
|
├── logger.py # Logger configuration
|
||||||
└── security.py # Security utilities (JWT, hash)
|
└── security.py # Security utilities (JWT, hash)
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -63,6 +87,15 @@ src/
|
|||||||
- Code examples in documentation must be in English
|
- Code examples in documentation must be in English
|
||||||
- Commit messages must be in English
|
- Commit messages must be in English
|
||||||
|
|
||||||
|
### Project Configuration
|
||||||
|
- Dependencies managed in `pyproject.toml` using modern Python packaging standards
|
||||||
|
- Development dependencies specified as optional dependencies in `pyproject.toml`
|
||||||
|
- Single source of truth for project metadata in `pyproject.toml`
|
||||||
|
- Build system configured to use setuptools
|
||||||
|
- Pytest configuration in `pyproject.toml` under `[tool.pytest.ini_options]`
|
||||||
|
- Code formatting with Black configured in `pyproject.toml`
|
||||||
|
- Linting with Flake8 configured in `.flake8`
|
||||||
|
|
||||||
### Schemas (Pydantic)
|
### Schemas (Pydantic)
|
||||||
- Use `BaseModel` as base for all schemas
|
- Use `BaseModel` as base for all schemas
|
||||||
- Define fields with explicit types
|
- Define fields with explicit types
|
||||||
@ -136,6 +169,28 @@ src/
|
|||||||
- Indentation with 4 spaces
|
- Indentation with 4 spaces
|
||||||
- Maximum of 79 characters per line
|
- Maximum of 79 characters per line
|
||||||
|
|
||||||
|
## Commit Rules
|
||||||
|
- Use Conventional Commits format for all commit messages
|
||||||
|
- Format: `<type>(<scope>): <description>`
|
||||||
|
- Types:
|
||||||
|
- `feat`: A new feature
|
||||||
|
- `fix`: A bug fix
|
||||||
|
- `docs`: Documentation changes
|
||||||
|
- `style`: Changes that do not affect code meaning (formatting, etc.)
|
||||||
|
- `refactor`: Code changes that neither fix a bug nor add a feature
|
||||||
|
- `perf`: Performance improvements
|
||||||
|
- `test`: Adding or modifying tests
|
||||||
|
- `chore`: Changes to build process or auxiliary tools
|
||||||
|
- Scope is optional and should be the module or component affected
|
||||||
|
- Description should be concise, in the imperative mood, and not capitalized
|
||||||
|
- Use body for more detailed explanations if needed
|
||||||
|
- Reference issues in the footer with `Fixes #123` or `Relates to #123`
|
||||||
|
- Examples:
|
||||||
|
- `feat(auth): add password reset functionality`
|
||||||
|
- `fix(api): correct validation error in client registration`
|
||||||
|
- `docs: update API documentation for new endpoints`
|
||||||
|
- `refactor(services): improve error handling in authentication`
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
- Always validate input data
|
- Always validate input data
|
||||||
- Implement appropriate logging
|
- Implement appropriate logging
|
||||||
@ -163,6 +218,7 @@ src/
|
|||||||
|
|
||||||
## Useful Commands
|
## Useful Commands
|
||||||
- `make run`: Start the server
|
- `make run`: Start the server
|
||||||
|
- `make run-prod`: Start the server in production mode
|
||||||
- `make alembic-revision message="description"`: Create new migration
|
- `make alembic-revision message="description"`: Create new migration
|
||||||
- `make alembic-upgrade`: Apply pending migrations
|
- `make alembic-upgrade`: Apply pending migrations
|
||||||
- `make alembic-downgrade`: Revert last migration
|
- `make alembic-downgrade`: Revert last migration
|
||||||
@ -170,3 +226,8 @@ src/
|
|||||||
- `make alembic-reset`: Reset database to initial state
|
- `make alembic-reset`: Reset database to initial state
|
||||||
- `make alembic-upgrade-cascade`: Force upgrade removing dependencies
|
- `make alembic-upgrade-cascade`: Force upgrade removing dependencies
|
||||||
- `make clear-cache`: Clean project cache
|
- `make clear-cache`: Clean project cache
|
||||||
|
- `make seed-all`: Run all database seeders
|
||||||
|
- `make lint`: Run linting checks with flake8
|
||||||
|
- `make format`: Format code with black
|
||||||
|
- `make install`: Install project for development
|
||||||
|
- `make install-dev`: Install project with development dependencies
|
||||||
|
@ -1,3 +1,47 @@
|
|||||||
|
# Environment and IDE
|
||||||
|
.venv
|
||||||
|
venv
|
||||||
|
.env
|
||||||
|
.idea
|
||||||
|
.vscode
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.Python
|
||||||
|
.pytest_cache
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
|
||||||
|
# Version control
|
||||||
|
.git
|
||||||
|
.github
|
||||||
|
.gitignore
|
||||||
|
|
||||||
|
# Logs and temp files
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
tmp
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
.dockerignore
|
||||||
|
Dockerfile*
|
||||||
|
docker-compose*
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
README.md
|
||||||
|
LICENSE
|
||||||
|
docs/
|
||||||
|
|
||||||
|
# Development tools
|
||||||
|
tests/
|
||||||
|
.flake8
|
||||||
|
pyproject.toml
|
||||||
|
requirements-dev.txt
|
||||||
|
Makefile
|
||||||
|
|
||||||
# Ambiente virtual
|
# Ambiente virtual
|
||||||
venv/
|
venv/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
8
.flake8
Normal file
8
.flake8
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
[flake8]
|
||||||
|
max-line-length = 88
|
||||||
|
exclude = .git,__pycache__,venv,alembic/versions/*
|
||||||
|
ignore = E203, W503, E501
|
||||||
|
per-file-ignores =
|
||||||
|
__init__.py: F401
|
||||||
|
src/models/models.py: E712
|
||||||
|
alembic/*: E711,E712,F401
|
77
.gitignore
vendored
77
.gitignore
vendored
@ -1,13 +1,10 @@
|
|||||||
# Byte-compiled / optimized / DLL files
|
# Python
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
|
||||||
# C extensions
|
|
||||||
*.so
|
*.so
|
||||||
|
|
||||||
# Distribution / packaging
|
|
||||||
.Python
|
.Python
|
||||||
|
env/
|
||||||
build/
|
build/
|
||||||
develop-eggs/
|
develop-eggs/
|
||||||
dist/
|
dist/
|
||||||
@ -19,19 +16,10 @@ lib64/
|
|||||||
parts/
|
parts/
|
||||||
sdist/
|
sdist/
|
||||||
var/
|
var/
|
||||||
wheels/
|
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
|
|
||||||
# PyInstaller
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# Installer logs
|
|
||||||
pip-log.txt
|
|
||||||
pip-delete-this-directory.txt
|
|
||||||
|
|
||||||
# Unit test / coverage reports
|
# Unit test / coverage reports
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.tox/
|
.tox/
|
||||||
@ -42,14 +30,57 @@ nosetests.xml
|
|||||||
coverage.xml
|
coverage.xml
|
||||||
*.cover
|
*.cover
|
||||||
.hypothesis/
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.sublime-project
|
||||||
|
*.sublime-workspace
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs/
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Database
|
||||||
|
*.db
|
||||||
|
*.sqlite
|
||||||
|
*.sqlite3
|
||||||
|
backup/
|
||||||
|
|
||||||
|
# Local
|
||||||
|
local_settings.py
|
||||||
|
local.py
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
.docker/
|
||||||
|
|
||||||
|
# Alembic versions
|
||||||
|
# alembic/versions/
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
*.pot
|
*.pot
|
||||||
|
|
||||||
# Django stuff:
|
# Django stuff:
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
db.sqlite3
|
||||||
|
|
||||||
# Flask stuff:
|
# Flask stuff:
|
||||||
@ -77,17 +108,6 @@ celerybeat-schedule
|
|||||||
# SageMath parsed files
|
# SageMath parsed files
|
||||||
*.sage.py
|
*.sage.py
|
||||||
|
|
||||||
# Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
.venv/
|
|
||||||
.env/
|
|
||||||
ENV/
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
.spyderproject
|
.spyderproject
|
||||||
.spyproject
|
.spyproject
|
||||||
@ -102,11 +122,8 @@ venv.bak/
|
|||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.idea/
|
|
||||||
.vscode/
|
|
||||||
*.swp
|
*.swp
|
||||||
*.swo
|
*.swo
|
||||||
|
|
||||||
# OS
|
# OS
|
||||||
.DS_Store
|
|
||||||
Thumbs.db
|
Thumbs.db
|
16
Dockerfile
16
Dockerfile
@ -1,4 +1,4 @@
|
|||||||
FROM python:3.11-slim
|
FROM python:3.10-slim
|
||||||
|
|
||||||
# Define o diretório de trabalho
|
# Define o diretório de trabalho
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
@ -15,19 +15,19 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
&& apt-get clean \
|
&& apt-get clean \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copia os arquivos de requisitos
|
# Copy project files
|
||||||
COPY requirements.txt .
|
|
||||||
|
|
||||||
# Instala as dependências
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
|
||||||
|
|
||||||
# Copia o código-fonte
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir -e .
|
||||||
|
|
||||||
# Configuração para produção
|
# Configuração para produção
|
||||||
ENV PORT=8000 \
|
ENV PORT=8000 \
|
||||||
HOST=0.0.0.0 \
|
HOST=0.0.0.0 \
|
||||||
DEBUG=false
|
DEBUG=false
|
||||||
|
|
||||||
|
# Expose port
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
# Define o comando de inicialização
|
# Define o comando de inicialização
|
||||||
CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT
|
CMD alembic upgrade head && uvicorn src.main:app --host $HOST --port $PORT
|
51
Makefile
51
Makefile
@ -1,42 +1,42 @@
|
|||||||
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs
|
.PHONY: migrate init revision upgrade downgrade run seed-admin seed-client seed-agents seed-mcp-servers seed-tools seed-contacts seed-all docker-build docker-up docker-down docker-logs lint format install install-dev venv
|
||||||
|
|
||||||
# Comandos do Alembic
|
# Alembic commands
|
||||||
init:
|
init:
|
||||||
alembic init alembics
|
alembic init alembics
|
||||||
|
|
||||||
# make alembic-revision message="descrição da migração"
|
# make alembic-revision message="migration description"
|
||||||
alembic-revision:
|
alembic-revision:
|
||||||
alembic revision --autogenerate -m "$(message)"
|
alembic revision --autogenerate -m "$(message)"
|
||||||
|
|
||||||
# Comando para atualizar o banco de dados
|
# Command to update database to latest version
|
||||||
alembic-upgrade:
|
alembic-upgrade:
|
||||||
alembic upgrade head
|
alembic upgrade head
|
||||||
|
|
||||||
# Comando para voltar uma versão
|
# Command to downgrade one version
|
||||||
alembic-downgrade:
|
alembic-downgrade:
|
||||||
alembic downgrade -1
|
alembic downgrade -1
|
||||||
|
|
||||||
# Comando para rodar o servidor
|
# Command to run the server
|
||||||
run:
|
run:
|
||||||
uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src
|
uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 --reload-dir src
|
||||||
|
|
||||||
# Comando para limpar o cache em todas as pastas do projeto pastas pycache
|
# Command to run the server in production mode
|
||||||
|
run-prod:
|
||||||
|
uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4
|
||||||
|
|
||||||
|
# Command to clean cache in all project folders
|
||||||
clear-cache:
|
clear-cache:
|
||||||
rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} +
|
rm -rf ~/.cache/uv/environments-v2/* && find . -type d -name "__pycache__" -exec rm -r {} +
|
||||||
|
|
||||||
# Comando para criar uma nova migração
|
# Command to create a new migration and apply it
|
||||||
alembic-migrate:
|
alembic-migrate:
|
||||||
alembic revision --autogenerate -m "$(message)" && alembic upgrade head
|
alembic revision --autogenerate -m "$(message)" && alembic upgrade head
|
||||||
|
|
||||||
# Comando para resetar o banco de dados
|
# Command to reset the database
|
||||||
alembic-reset:
|
alembic-reset:
|
||||||
alembic downgrade base && alembic upgrade head
|
alembic downgrade base && alembic upgrade head
|
||||||
|
|
||||||
# Comando para forçar upgrade com CASCADE
|
# Commands to run seeders
|
||||||
alembic-upgrade-cascade:
|
|
||||||
psql -U postgres -d a2a_saas -c "DROP TABLE IF EXISTS events CASCADE; DROP TABLE IF EXISTS sessions CASCADE; DROP TABLE IF EXISTS user_states CASCADE; DROP TABLE IF EXISTS app_states CASCADE;" && alembic upgrade head
|
|
||||||
|
|
||||||
# Comandos para executar seeders
|
|
||||||
seed-admin:
|
seed-admin:
|
||||||
python -m scripts.seeders.admin_seeder
|
python -m scripts.seeders.admin_seeder
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ seed-contacts:
|
|||||||
seed-all:
|
seed-all:
|
||||||
python -m scripts.run_seeders
|
python -m scripts.run_seeders
|
||||||
|
|
||||||
# Comandos Docker
|
# Docker commands
|
||||||
docker-build:
|
docker-build:
|
||||||
docker-compose build
|
docker-compose build
|
||||||
|
|
||||||
@ -72,4 +72,21 @@ docker-logs:
|
|||||||
docker-compose logs -f
|
docker-compose logs -f
|
||||||
|
|
||||||
docker-seed:
|
docker-seed:
|
||||||
docker-compose exec api python -m scripts.run_seeders
|
docker-compose exec api python -m scripts.run_seeders
|
||||||
|
|
||||||
|
# Testing, linting and formatting commands
|
||||||
|
lint:
|
||||||
|
flake8 src/ tests/
|
||||||
|
|
||||||
|
format:
|
||||||
|
black src/ tests/
|
||||||
|
|
||||||
|
# Virtual environment and installation commands
|
||||||
|
venv:
|
||||||
|
python -m venv venv
|
||||||
|
|
||||||
|
install:
|
||||||
|
pip install -e .
|
||||||
|
|
||||||
|
install-dev:
|
||||||
|
pip install -e ".[dev]"
|
53
conftest.py
Normal file
53
conftest.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from src.config.database import Base, get_db
|
||||||
|
from src.main import app
|
||||||
|
|
||||||
|
# Use in-memory SQLite for tests
|
||||||
|
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
SQLALCHEMY_DATABASE_URL,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def db_session():
|
||||||
|
"""Creates a fresh database session for each test."""
|
||||||
|
Base.metadata.create_all(bind=engine) # Create tables
|
||||||
|
|
||||||
|
connection = engine.connect()
|
||||||
|
transaction = connection.begin()
|
||||||
|
session = TestingSessionLocal(bind=connection)
|
||||||
|
|
||||||
|
# Use our test database instead of the standard one
|
||||||
|
def override_get_db():
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
yield session # The test will run here
|
||||||
|
|
||||||
|
# Teardown
|
||||||
|
transaction.rollback()
|
||||||
|
connection.close()
|
||||||
|
Base.metadata.drop_all(bind=engine)
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def client(db_session):
|
||||||
|
"""Creates a FastAPI TestClient with database session fixture."""
|
||||||
|
with TestClient(app) as test_client:
|
||||||
|
yield test_client
|
89
pyproject.toml
Normal file
89
pyproject.toml
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=42", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "evo-ai"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "API for executing AI agents"
|
||||||
|
readme = "README.md"
|
||||||
|
authors = [
|
||||||
|
{name = "EvoAI Team", email = "admin@evoai.com"}
|
||||||
|
]
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
license = {text = "Proprietary"}
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"License :: Other/Proprietary License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Main dependencies
|
||||||
|
dependencies = [
|
||||||
|
"fastapi==0.115.12",
|
||||||
|
"uvicorn==0.34.2",
|
||||||
|
"pydantic==2.11.3",
|
||||||
|
"sqlalchemy==2.0.40",
|
||||||
|
"psycopg2-binary==2.9.10",
|
||||||
|
"google-cloud-aiplatform==1.90.0",
|
||||||
|
"python-dotenv==1.1.0",
|
||||||
|
"google-adk==0.3.0",
|
||||||
|
"litellm==1.67.4.post1",
|
||||||
|
"python-multipart==0.0.20",
|
||||||
|
"alembic==1.15.2",
|
||||||
|
"asyncpg==0.30.0",
|
||||||
|
"python-jose==3.4.0",
|
||||||
|
"passlib==1.7.4",
|
||||||
|
"sendgrid==6.11.0",
|
||||||
|
"pydantic-settings==2.9.1",
|
||||||
|
"fastapi_utils==0.8.0",
|
||||||
|
"bcrypt==4.3.0",
|
||||||
|
"jinja2==3.1.6",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"black==25.1.0",
|
||||||
|
"flake8==7.2.0",
|
||||||
|
"pytest==8.3.5",
|
||||||
|
"pytest-cov==6.1.1",
|
||||||
|
"httpx==0.28.1",
|
||||||
|
"pytest-asyncio==0.26.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
packages = ["src"]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 88
|
||||||
|
target-version = ["py310"]
|
||||||
|
include = '\.pyi?$'
|
||||||
|
exclude = '''
|
||||||
|
/(
|
||||||
|
\.git
|
||||||
|
| \.hg
|
||||||
|
| \.mypy_cache
|
||||||
|
| \.tox
|
||||||
|
| \.venv
|
||||||
|
| _build
|
||||||
|
| buck-out
|
||||||
|
| build
|
||||||
|
| dist
|
||||||
|
| venv
|
||||||
|
| alembic/versions
|
||||||
|
)/
|
||||||
|
'''
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = "test_*.py"
|
||||||
|
python_functions = "test_*"
|
||||||
|
python_classes = "Test*"
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::DeprecationWarning",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["src"]
|
||||||
|
omit = ["tests/*", "alembic/*"]
|
@ -1,22 +0,0 @@
|
|||||||
fastapi
|
|
||||||
uvicorn
|
|
||||||
pydantic
|
|
||||||
sqlalchemy
|
|
||||||
psycopg2
|
|
||||||
psycopg2-binary
|
|
||||||
google-cloud-aiplatform
|
|
||||||
python-dotenv
|
|
||||||
google-adk
|
|
||||||
litellm
|
|
||||||
python-multipart
|
|
||||||
alembic
|
|
||||||
asyncpg
|
|
||||||
# Novas dependências para autenticação
|
|
||||||
python-jose[cryptography]
|
|
||||||
passlib[bcrypt]
|
|
||||||
sendgrid
|
|
||||||
pydantic[email]
|
|
||||||
pydantic-settings
|
|
||||||
fastapi_utils
|
|
||||||
bcrypt
|
|
||||||
jinja2
|
|
6
setup.py
Normal file
6
setup.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
"""Setup script for the package."""
|
||||||
|
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
setup()
|
@ -1,3 +1,3 @@
|
|||||||
"""
|
"""
|
||||||
Pacote principal da aplicação
|
Main package of the application
|
||||||
"""
|
"""
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
|
from typing import List
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from src.config.database import get_db
|
from src.config.database import get_db
|
||||||
from src.core.jwt_middleware import get_jwt_token, verify_admin
|
from src.core.jwt_middleware import get_jwt_token, verify_admin
|
||||||
from src.schemas.audit import AuditLogResponse, AuditLogFilter
|
from src.schemas.audit import AuditLogResponse, AuditLogFilter
|
||||||
from src.services.audit_service import get_audit_logs, create_audit_log
|
from src.services.audit_service import get_audit_logs, create_audit_log
|
||||||
from src.services.user_service import get_admin_users, create_admin_user, deactivate_user
|
from src.services.user_service import (
|
||||||
|
get_admin_users,
|
||||||
|
create_admin_user,
|
||||||
|
deactivate_user,
|
||||||
|
)
|
||||||
from src.schemas.user import UserResponse, AdminUserCreate
|
from src.schemas.user import UserResponse, AdminUserCreate
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@ -18,6 +21,7 @@ router = APIRouter(
|
|||||||
responses={403: {"description": "Permission denied"}},
|
responses={403: {"description": "Permission denied"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Audit routes
|
# Audit routes
|
||||||
@router.get("/audit-logs", response_model=List[AuditLogResponse])
|
@router.get("/audit-logs", response_model=List[AuditLogResponse])
|
||||||
async def read_audit_logs(
|
async def read_audit_logs(
|
||||||
@ -27,12 +31,12 @@ async def read_audit_logs(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get audit logs with optional filters
|
Get audit logs with optional filters
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filters: Filters for log search
|
filters: Filters for log search
|
||||||
db: Database session
|
db: Database session
|
||||||
payload: JWT token payload
|
payload: JWT token payload
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[AuditLogResponse]: List of audit logs
|
List[AuditLogResponse]: List of audit logs
|
||||||
"""
|
"""
|
||||||
@ -45,9 +49,10 @@ async def read_audit_logs(
|
|||||||
resource_type=filters.resource_type,
|
resource_type=filters.resource_type,
|
||||||
resource_id=filters.resource_id,
|
resource_id=filters.resource_id,
|
||||||
start_date=filters.start_date,
|
start_date=filters.start_date,
|
||||||
end_date=filters.end_date
|
end_date=filters.end_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Admin routes
|
# Admin routes
|
||||||
@router.get("/users", response_model=List[UserResponse])
|
@router.get("/users", response_model=List[UserResponse])
|
||||||
async def read_admin_users(
|
async def read_admin_users(
|
||||||
@ -58,18 +63,19 @@ async def read_admin_users(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
List admin users
|
List admin users
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
skip: Number of records to skip
|
skip: Number of records to skip
|
||||||
limit: Maximum number of records to return
|
limit: Maximum number of records to return
|
||||||
db: Database session
|
db: Database session
|
||||||
payload: JWT token payload
|
payload: JWT token payload
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[UserResponse]: List of admin users
|
List[UserResponse]: List of admin users
|
||||||
"""
|
"""
|
||||||
return get_admin_users(db, skip, limit)
|
return get_admin_users(db, skip, limit)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_new_admin_user(
|
async def create_new_admin_user(
|
||||||
user_data: AdminUserCreate,
|
user_data: AdminUserCreate,
|
||||||
@ -79,16 +85,16 @@ async def create_new_admin_user(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new admin user
|
Create a new admin user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_data: User data to be created
|
user_data: User data to be created
|
||||||
request: FastAPI Request object
|
request: FastAPI Request object
|
||||||
db: Database session
|
db: Database session
|
||||||
payload: JWT token payload
|
payload: JWT token payload
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
UserResponse: Created user data
|
UserResponse: Created user data
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If there is an error in creation
|
HTTPException: If there is an error in creation
|
||||||
"""
|
"""
|
||||||
@ -97,17 +103,14 @@ async def create_new_admin_user(
|
|||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Unable to identify the logged in user"
|
detail="Unable to identify the logged in user",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create admin user
|
# Create admin user
|
||||||
user, message = create_admin_user(db, user_data)
|
user, message = create_admin_user(db, user_data)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=message
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register action in audit log
|
# Register action in audit log
|
||||||
create_audit_log(
|
create_audit_log(
|
||||||
db,
|
db,
|
||||||
@ -116,11 +119,12 @@ async def create_new_admin_user(
|
|||||||
resource_type="admin_user",
|
resource_type="admin_user",
|
||||||
resource_id=str(user.id),
|
resource_id=str(user.id),
|
||||||
details={"email": user.email},
|
details={"email": user.email},
|
||||||
request=request
|
request=request,
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def deactivate_admin_user(
|
async def deactivate_admin_user(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
@ -130,13 +134,13 @@ async def deactivate_admin_user(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Deactivate an admin user (does not delete, only deactivates)
|
Deactivate an admin user (does not delete, only deactivates)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: ID of the user to be deactivated
|
user_id: ID of the user to be deactivated
|
||||||
request: FastAPI Request object
|
request: FastAPI Request object
|
||||||
db: Database session
|
db: Database session
|
||||||
payload: JWT token payload
|
payload: JWT token payload
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If there is an error in deactivation
|
HTTPException: If there is an error in deactivation
|
||||||
"""
|
"""
|
||||||
@ -145,24 +149,21 @@ async def deactivate_admin_user(
|
|||||||
if not current_user_id:
|
if not current_user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Unable to identify the logged in user"
|
detail="Unable to identify the logged in user",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Do not allow deactivating yourself
|
# Do not allow deactivating yourself
|
||||||
if str(user_id) == current_user_id:
|
if str(user_id) == current_user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Unable to deactivate your own user"
|
detail="Unable to deactivate your own user",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Deactivate user
|
# Deactivate user
|
||||||
success, message = deactivate_user(db, user_id)
|
success, message = deactivate_user(db, user_id)
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=message
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register action in audit log
|
# Register action in audit log
|
||||||
create_audit_log(
|
create_audit_log(
|
||||||
db,
|
db,
|
||||||
@ -171,5 +172,5 @@ async def deactivate_admin_user(
|
|||||||
resource_type="admin_user",
|
resource_type="admin_user",
|
||||||
resource_id=str(user_id),
|
resource_id=str(user_id),
|
||||||
details=None,
|
details=None,
|
||||||
request=request
|
request=request,
|
||||||
)
|
)
|
||||||
|
@ -7,10 +7,6 @@ from src.core.jwt_middleware import (
|
|||||||
get_jwt_token,
|
get_jwt_token,
|
||||||
verify_user_client,
|
verify_user_client,
|
||||||
)
|
)
|
||||||
from src.core.jwt_middleware import (
|
|
||||||
get_jwt_token,
|
|
||||||
verify_user_client,
|
|
||||||
)
|
|
||||||
from src.schemas.schemas import (
|
from src.schemas.schemas import (
|
||||||
Agent,
|
Agent,
|
||||||
AgentCreate,
|
AgentCreate,
|
||||||
|
@ -162,9 +162,7 @@ async def login_for_access_token(form_data: UserLogin, db: Session = Depends(get
|
|||||||
"""
|
"""
|
||||||
user = authenticate_user(db, form_data.email, form_data.password)
|
user = authenticate_user(db, form_data.email, form_data.password)
|
||||||
if not user:
|
if not user:
|
||||||
logger.warning(
|
logger.warning(f"Login attempt with invalid credentials: {form_data.email}")
|
||||||
f"Login attempt with invalid credentials: {form_data.email}"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid email or password",
|
detail="Invalid email or password",
|
||||||
|
@ -11,7 +11,11 @@ from src.services import (
|
|||||||
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
|
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
|
||||||
from src.services.agent_runner import run_agent
|
from src.services.agent_runner import run_agent
|
||||||
from src.core.exceptions import AgentNotFoundError
|
from src.core.exceptions import AgentNotFoundError
|
||||||
from src.main import session_service, artifacts_service, memory_service
|
from src.services.service_providers import (
|
||||||
|
session_service,
|
||||||
|
artifacts_service,
|
||||||
|
memory_service,
|
||||||
|
)
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
@ -71,4 +75,4 @@ async def chat(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||||
)
|
)
|
||||||
|
@ -19,7 +19,7 @@ from src.services.session_service import (
|
|||||||
get_sessions_by_agent,
|
get_sessions_by_agent,
|
||||||
get_sessions_by_client,
|
get_sessions_by_client,
|
||||||
)
|
)
|
||||||
from src.main import session_service
|
from src.services.service_providers import session_service
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -30,6 +30,7 @@ router = APIRouter(
|
|||||||
responses={404: {"description": "Not found"}},
|
responses={404: {"description": "Not found"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Session Routes
|
# Session Routes
|
||||||
@router.get("/client/{client_id}", response_model=List[Adk_Session])
|
@router.get("/client/{client_id}", response_model=List[Adk_Session])
|
||||||
async def get_client_sessions(
|
async def get_client_sessions(
|
||||||
|
@ -10,9 +10,10 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
@ -1,55 +1,66 @@
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class BaseAPIException(HTTPException):
|
class BaseAPIException(HTTPException):
|
||||||
"""Base class for API exceptions"""
|
"""Base class for API exceptions"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
status_code: int,
|
status_code: int,
|
||||||
message: str,
|
message: str,
|
||||||
error_code: str,
|
error_code: str,
|
||||||
details: Optional[Dict[str, Any]] = None
|
details: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
super().__init__(status_code=status_code, detail={
|
super().__init__(
|
||||||
"error": message,
|
status_code=status_code,
|
||||||
"error_code": error_code,
|
detail={
|
||||||
"details": details or {}
|
"error": message,
|
||||||
})
|
"error_code": error_code,
|
||||||
|
"details": details or {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AgentNotFoundError(BaseAPIException):
|
class AgentNotFoundError(BaseAPIException):
|
||||||
"""Exception when the agent is not found"""
|
"""Exception when the agent is not found"""
|
||||||
|
|
||||||
def __init__(self, agent_id: str):
|
def __init__(self, agent_id: str):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
message=f"Agent with ID {agent_id} not found",
|
message=f"Agent with ID {agent_id} not found",
|
||||||
error_code="AGENT_NOT_FOUND"
|
error_code="AGENT_NOT_FOUND",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidParameterError(BaseAPIException):
|
class InvalidParameterError(BaseAPIException):
|
||||||
"""Exception for invalid parameters"""
|
"""Exception for invalid parameters"""
|
||||||
|
|
||||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
message=message,
|
message=message,
|
||||||
error_code="INVALID_PARAMETER",
|
error_code="INVALID_PARAMETER",
|
||||||
details=details
|
details=details,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidRequestError(BaseAPIException):
|
class InvalidRequestError(BaseAPIException):
|
||||||
"""Exception for invalid requests"""
|
"""Exception for invalid requests"""
|
||||||
|
|
||||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
message=message,
|
message=message,
|
||||||
error_code="INVALID_REQUEST",
|
error_code="INVALID_REQUEST",
|
||||||
details=details
|
details=details,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InternalServerError(BaseAPIException):
|
class InternalServerError(BaseAPIException):
|
||||||
"""Exception for server errors"""
|
"""Exception for server errors"""
|
||||||
|
|
||||||
def __init__(self, message: str = "Server error"):
|
def __init__(self, message: str = "Server error"):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=500,
|
status_code=500, message=message, error_code="INTERNAL_SERVER_ERROR"
|
||||||
message=message,
|
)
|
||||||
error_code="INTERNAL_SERVER_ERROR"
|
|
||||||
)
|
|
||||||
|
@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||||
|
|
||||||
|
|
||||||
async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
||||||
"""
|
"""
|
||||||
Extracts and validates the JWT token
|
Extracts and validates the JWT token
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: Token JWT
|
token: Token JWT
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Token payload data
|
dict: Token payload data
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If the token is invalid
|
HTTPException: If the token is invalid
|
||||||
"""
|
"""
|
||||||
@ -31,86 +32,90 @@ async def get_jwt_token(token: str = Depends(oauth2_scheme)) -> dict:
|
|||||||
detail="Invalid credentials",
|
detail="Invalid credentials",
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token,
|
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||||
settings.JWT_SECRET_KEY,
|
|
||||||
algorithms=[settings.JWT_ALGORITHM]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
email: str = payload.get("sub")
|
email: str = payload.get("sub")
|
||||||
if email is None:
|
if email is None:
|
||||||
logger.warning("Token without email (sub)")
|
logger.warning("Token without email (sub)")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
exp = payload.get("exp")
|
exp = payload.get("exp")
|
||||||
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
|
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
|
||||||
logger.warning(f"Token expired for {email}")
|
logger.warning(f"Token expired for {email}")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
except JWTError as e:
|
except JWTError as e:
|
||||||
logger.error(f"Error decoding JWT token: {str(e)}")
|
logger.error(f"Error decoding JWT token: {str(e)}")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
|
|
||||||
async def verify_user_client(
|
async def verify_user_client(
|
||||||
payload: dict = Depends(get_jwt_token),
|
payload: dict = Depends(get_jwt_token),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
required_client_id: UUID = None
|
required_client_id: UUID = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Verifies if the user is associated with the specified client
|
Verifies if the user is associated with the specified client
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
payload: Token JWT payload
|
payload: Token JWT payload
|
||||||
db: Database session
|
db: Database session
|
||||||
required_client_id: Client ID to be verified
|
required_client_id: Client ID to be verified
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True se verificado com sucesso
|
bool: True se verificado com sucesso
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If the user does not have permission
|
HTTPException: If the user does not have permission
|
||||||
"""
|
"""
|
||||||
# Administrators have access to all clients
|
# Administrators have access to all clients
|
||||||
if payload.get("is_admin", False):
|
if payload.get("is_admin", False):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Para não-admins, verificar se o client_id corresponde
|
# Para não-admins, verificar se o client_id corresponde
|
||||||
user_client_id = payload.get("client_id")
|
user_client_id = payload.get("client_id")
|
||||||
if not user_client_id:
|
if not user_client_id:
|
||||||
logger.warning(f"Non-admin user without client_id in token: {payload.get('sub')}")
|
logger.warning(
|
||||||
|
f"Non-admin user without client_id in token: {payload.get('sub')}"
|
||||||
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="User not associated with a client"
|
detail="User not associated with a client",
|
||||||
)
|
)
|
||||||
|
|
||||||
# If no client_id is specified to verify, any client is valid
|
# If no client_id is specified to verify, any client is valid
|
||||||
if not required_client_id:
|
if not required_client_id:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Verify if the user's client_id corresponds to the required_client_id
|
# Verify if the user's client_id corresponds to the required_client_id
|
||||||
if str(user_client_id) != str(required_client_id):
|
if str(user_client_id) != str(required_client_id):
|
||||||
logger.warning(f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}")
|
logger.warning(
|
||||||
|
f"Access denied: User {payload.get('sub')} tried to access resources of client {required_client_id}"
|
||||||
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied to access resources of this client"
|
detail="Access denied to access resources of this client",
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
|
async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
|
||||||
"""
|
"""
|
||||||
Verifies if the user is an administrator
|
Verifies if the user is an administrator
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
payload: Token JWT payload
|
payload: Token JWT payload
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the user is an administrator
|
bool: True if the user is an administrator
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If the user is not an administrator
|
HTTPException: If the user is not an administrator
|
||||||
"""
|
"""
|
||||||
@ -118,26 +123,29 @@ async def verify_admin(payload: dict = Depends(get_jwt_token)) -> bool:
|
|||||||
logger.warning(f"Access denied to admin: User {payload.get('sub')}")
|
logger.warning(f"Access denied to admin: User {payload.get('sub')}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied. Restricted to administrators."
|
detail="Access denied. Restricted to administrators.",
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_current_user_client_id(payload: dict = Depends(get_jwt_token)) -> Optional[UUID]:
|
|
||||||
|
def get_current_user_client_id(
|
||||||
|
payload: dict = Depends(get_jwt_token),
|
||||||
|
) -> Optional[UUID]:
|
||||||
"""
|
"""
|
||||||
Gets the ID of the client associated with the current user
|
Gets the ID of the client associated with the current user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
payload: Token JWT payload
|
payload: Token JWT payload
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[UUID]: Client ID or None if the user is an administrator
|
Optional[UUID]: Client ID or None if the user is an administrator
|
||||||
"""
|
"""
|
||||||
if payload.get("is_admin", False):
|
if payload.get("is_admin", False):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
client_id = payload.get("client_id")
|
client_id = payload.get("client_id")
|
||||||
if client_id:
|
if client_id:
|
||||||
return UUID(client_id)
|
return UUID(client_id)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
55
src/main.py
55
src/main.py
@ -1,35 +1,30 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Add the root directory to PYTHONPATH
|
|
||||||
root_dir = Path(__file__).parent.parent
|
|
||||||
sys.path.append(str(root_dir))
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from src.config.database import engine, Base
|
from src.config.database import engine, Base
|
||||||
from src.config.settings import settings
|
from src.config.settings import settings
|
||||||
from src.utils.logger import setup_logger
|
from src.utils.logger import setup_logger
|
||||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
|
||||||
from google.adk.sessions import DatabaseSessionService
|
|
||||||
from google.adk.memory import InMemoryMemoryService
|
|
||||||
|
|
||||||
# Initialize service instances
|
# Necessary for other modules
|
||||||
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
|
from src.services.service_providers import session_service # noqa: F401
|
||||||
artifacts_service = InMemoryArtifactService()
|
from src.services.service_providers import artifacts_service # noqa: F401
|
||||||
memory_service = InMemoryMemoryService()
|
from src.services.service_providers import memory_service # noqa: F401
|
||||||
|
|
||||||
# Import routers after service initialization to avoid circular imports
|
import src.api.auth_routes
|
||||||
from src.api.auth_routes import router as auth_router
|
import src.api.admin_routes
|
||||||
from src.api.admin_routes import router as admin_router
|
import src.api.chat_routes
|
||||||
from src.api.chat_routes import router as chat_router
|
import src.api.session_routes
|
||||||
from src.api.session_routes import router as session_router
|
import src.api.agent_routes
|
||||||
from src.api.agent_routes import router as agent_router
|
import src.api.contact_routes
|
||||||
from src.api.contact_routes import router as contact_router
|
import src.api.mcp_server_routes
|
||||||
from src.api.mcp_server_routes import router as mcp_server_router
|
import src.api.tool_routes
|
||||||
from src.api.tool_routes import router as tool_router
|
import src.api.client_routes
|
||||||
from src.api.client_routes import router as client_router
|
|
||||||
|
# Add the root directory to PYTHONPATH
|
||||||
|
root_dir = Path(__file__).parent.parent
|
||||||
|
sys.path.append(str(root_dir))
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
@ -52,8 +47,7 @@ app.add_middleware(
|
|||||||
|
|
||||||
# PostgreSQL configuration
|
# PostgreSQL configuration
|
||||||
POSTGRES_CONNECTION_STRING = os.getenv(
|
POSTGRES_CONNECTION_STRING = os.getenv(
|
||||||
"POSTGRES_CONNECTION_STRING",
|
"POSTGRES_CONNECTION_STRING", "postgresql://postgres:root@localhost:5432/evo_ai"
|
||||||
"postgresql://postgres:root@localhost:5432/evo_ai"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create database tables
|
# Create database tables
|
||||||
@ -61,6 +55,17 @@ Base.metadata.create_all(bind=engine)
|
|||||||
|
|
||||||
API_PREFIX = "/api/v1"
|
API_PREFIX = "/api/v1"
|
||||||
|
|
||||||
|
# Define router references
|
||||||
|
auth_router = src.api.auth_routes.router
|
||||||
|
admin_router = src.api.admin_routes.router
|
||||||
|
chat_router = src.api.chat_routes.router
|
||||||
|
session_router = src.api.session_routes.router
|
||||||
|
agent_router = src.api.agent_routes.router
|
||||||
|
contact_router = src.api.contact_routes.router
|
||||||
|
mcp_server_router = src.api.mcp_server_routes.router
|
||||||
|
tool_router = src.api.tool_routes.router
|
||||||
|
client_router = src.api.client_routes.router
|
||||||
|
|
||||||
# Include routes
|
# Include routes
|
||||||
app.include_router(auth_router, prefix=API_PREFIX)
|
app.include_router(auth_router, prefix=API_PREFIX)
|
||||||
app.include_router(admin_router, prefix=API_PREFIX)
|
app.include_router(admin_router, prefix=API_PREFIX)
|
||||||
@ -79,5 +84,5 @@ def read_root():
|
|||||||
"message": "Welcome to Evo AI API",
|
"message": "Welcome to Evo AI API",
|
||||||
"documentation": "/docs",
|
"documentation": "/docs",
|
||||||
"version": settings.API_VERSION,
|
"version": settings.API_VERSION,
|
||||||
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'"
|
"auth": "To access the API, use JWT authentication via '/api/v1/auth/login'",
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,20 @@
|
|||||||
from sqlalchemy import Column, String, UUID, DateTime, ForeignKey, JSON, Text, BigInteger, CheckConstraint, Boolean
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
String,
|
||||||
|
UUID,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
JSON,
|
||||||
|
Text,
|
||||||
|
CheckConstraint,
|
||||||
|
Boolean,
|
||||||
|
)
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from sqlalchemy.orm import relationship, backref
|
from sqlalchemy.orm import relationship, backref
|
||||||
from src.config.database import Base
|
from src.config.database import Base
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
class Client(Base):
|
class Client(Base):
|
||||||
__tablename__ = "clients"
|
__tablename__ = "clients"
|
||||||
|
|
||||||
@ -13,13 +24,16 @@ class Client(Base):
|
|||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
email = Column(String, unique=True, index=True, nullable=False)
|
email = Column(String, unique=True, index=True, nullable=False)
|
||||||
password_hash = Column(String, nullable=False)
|
password_hash = Column(String, nullable=False)
|
||||||
client_id = Column(UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True)
|
client_id = Column(
|
||||||
|
UUID(as_uuid=True), ForeignKey("clients.id", ondelete="CASCADE"), nullable=True
|
||||||
|
)
|
||||||
is_active = Column(Boolean, default=False)
|
is_active = Column(Boolean, default=False)
|
||||||
is_admin = Column(Boolean, default=False)
|
is_admin = Column(Boolean, default=False)
|
||||||
email_verified = Column(Boolean, default=False)
|
email_verified = Column(Boolean, default=False)
|
||||||
@ -29,9 +43,12 @@ class User(Base):
|
|||||||
password_reset_expiry = Column(DateTime(timezone=True), nullable=True)
|
password_reset_expiry = Column(DateTime(timezone=True), nullable=True)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
# Relationship with Client (One-to-One, optional for administrators)
|
# Relationship with Client (One-to-One, optional for administrators)
|
||||||
client = relationship("Client", backref=backref("user", uselist=False, cascade="all, delete-orphan"))
|
client = relationship(
|
||||||
|
"Client", backref=backref("user", uselist=False, cascade="all, delete-orphan")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Contact(Base):
|
class Contact(Base):
|
||||||
__tablename__ = "contacts"
|
__tablename__ = "contacts"
|
||||||
@ -44,6 +61,7 @@ class Contact(Base):
|
|||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
|
|
||||||
class Agent(Base):
|
class Agent(Base):
|
||||||
__tablename__ = "agents"
|
__tablename__ = "agents"
|
||||||
|
|
||||||
@ -60,21 +78,30 @@ class Agent(Base):
|
|||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
CheckConstraint("type IN ('llm', 'sequential', 'parallel', 'loop')", name='check_agent_type'),
|
CheckConstraint(
|
||||||
|
"type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""Converts the object to a dictionary, converting UUIDs to strings"""
|
"""Converts the object to a dictionary, converting UUIDs to strings"""
|
||||||
result = {}
|
result = {}
|
||||||
for key, value in self.__dict__.items():
|
for key, value in self.__dict__.items():
|
||||||
if key.startswith('_'):
|
if key.startswith("_"):
|
||||||
continue
|
continue
|
||||||
if isinstance(value, uuid.UUID):
|
if isinstance(value, uuid.UUID):
|
||||||
result[key] = str(value)
|
result[key] = str(value)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
result[key] = self._convert_dict(value)
|
result[key] = self._convert_dict(value)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
|
result[key] = [
|
||||||
|
(
|
||||||
|
self._convert_dict(item)
|
||||||
|
if isinstance(item, dict)
|
||||||
|
else str(item) if isinstance(item, uuid.UUID) else item
|
||||||
|
)
|
||||||
|
for item in value
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
result[key] = value
|
result[key] = value
|
||||||
return result
|
return result
|
||||||
@ -88,11 +115,19 @@ class Agent(Base):
|
|||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
result[key] = self._convert_dict(value)
|
result[key] = self._convert_dict(value)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
result[key] = [self._convert_dict(item) if isinstance(item, dict) else str(item) if isinstance(item, uuid.UUID) else item for item in value]
|
result[key] = [
|
||||||
|
(
|
||||||
|
self._convert_dict(item)
|
||||||
|
if isinstance(item, dict)
|
||||||
|
else str(item) if isinstance(item, uuid.UUID) else item
|
||||||
|
)
|
||||||
|
for item in value
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
result[key] = value
|
result[key] = value
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class MCPServer(Base):
|
class MCPServer(Base):
|
||||||
__tablename__ = "mcp_servers"
|
__tablename__ = "mcp_servers"
|
||||||
|
|
||||||
@ -105,11 +140,14 @@ class MCPServer(Base):
|
|||||||
type = Column(String, nullable=False, default="official")
|
type = Column(String, nullable=False, default="official")
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
CheckConstraint("type IN ('official', 'community')", name='check_mcp_server_type'),
|
CheckConstraint(
|
||||||
|
"type IN ('official', 'community')", name="check_mcp_server_type"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Tool(Base):
|
class Tool(Base):
|
||||||
__tablename__ = "tools"
|
__tablename__ = "tools"
|
||||||
|
|
||||||
@ -121,11 +159,12 @@ class Tool(Base):
|
|||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
|
|
||||||
class Session(Base):
|
class Session(Base):
|
||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions"
|
||||||
# The directive below makes Alembic ignore this table in migrations
|
# The directive below makes Alembic ignore this table in migrations
|
||||||
__table_args__ = {'extend_existing': True, 'info': {'skip_autogenerate': True}}
|
__table_args__ = {"extend_existing": True, "info": {"skip_autogenerate": True}}
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
app_name = Column(String)
|
app_name = Column(String)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
@ -133,11 +172,14 @@ class Session(Base):
|
|||||||
create_time = Column(DateTime(timezone=True))
|
create_time = Column(DateTime(timezone=True))
|
||||||
update_time = Column(DateTime(timezone=True))
|
update_time = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
|
||||||
class AuditLog(Base):
|
class AuditLog(Base):
|
||||||
__tablename__ = "audit_logs"
|
__tablename__ = "audit_logs"
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
action = Column(String, nullable=False)
|
action = Column(String, nullable=False)
|
||||||
resource_type = Column(String, nullable=False)
|
resource_type = Column(String, nullable=False)
|
||||||
resource_id = Column(String, nullable=True)
|
resource_id = Column(String, nullable=True)
|
||||||
@ -145,6 +187,6 @@ class AuditLog(Base):
|
|||||||
ip_address = Column(String, nullable=True)
|
ip_address = Column(String, nullable=True)
|
||||||
user_agent = Column(String, nullable=True)
|
user_agent = Column(String, nullable=True)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
# Relationship with User
|
# Relationship with User
|
||||||
user = relationship("User", backref="audit_logs")
|
user = relationship("User", backref="audit_logs")
|
||||||
|
@ -1,26 +1,38 @@
|
|||||||
from typing import List, Optional, Dict, Any, Union
|
from typing import List, Optional, Dict, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
class ToolConfig(BaseModel):
|
class ToolConfig(BaseModel):
|
||||||
"""Configuration of a tool"""
|
"""Configuration of a tool"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the tool")
|
envs: Dict[str, str] = Field(
|
||||||
|
default_factory=dict, description="Environment variables of the tool"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class MCPServerConfig(BaseModel):
|
class MCPServerConfig(BaseModel):
|
||||||
"""Configuration of an MCP server"""
|
"""Configuration of an MCP server"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
envs: Dict[str, str] = Field(default_factory=dict, description="Environment variables of the server")
|
envs: Dict[str, str] = Field(
|
||||||
tools: List[str] = Field(default_factory=list, description="List of tools of the server")
|
default_factory=dict, description="Environment variables of the server"
|
||||||
|
)
|
||||||
|
tools: List[str] = Field(
|
||||||
|
default_factory=list, description="List of tools of the server"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class HTTPToolParameter(BaseModel):
|
class HTTPToolParameter(BaseModel):
|
||||||
"""Parameter of an HTTP tool"""
|
"""Parameter of an HTTP tool"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
required: bool
|
required: bool
|
||||||
description: str
|
description: str
|
||||||
@ -28,8 +40,10 @@ class HTTPToolParameter(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class HTTPToolParameters(BaseModel):
|
class HTTPToolParameters(BaseModel):
|
||||||
"""Parameters of an HTTP tool"""
|
"""Parameters of an HTTP tool"""
|
||||||
|
|
||||||
path_params: Optional[Dict[str, str]] = None
|
path_params: Optional[Dict[str, str]] = None
|
||||||
query_params: Optional[Dict[str, Union[str, List[str]]]] = None
|
query_params: Optional[Dict[str, Union[str, List[str]]]] = None
|
||||||
body_params: Optional[Dict[str, HTTPToolParameter]] = None
|
body_params: Optional[Dict[str, HTTPToolParameter]] = None
|
||||||
@ -37,8 +51,10 @@ class HTTPToolParameters(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class HTTPToolErrorHandling(BaseModel):
|
class HTTPToolErrorHandling(BaseModel):
|
||||||
"""Configuration of error handling"""
|
"""Configuration of error handling"""
|
||||||
|
|
||||||
timeout: int
|
timeout: int
|
||||||
retry_count: int
|
retry_count: int
|
||||||
fallback_response: Dict[str, str]
|
fallback_response: Dict[str, str]
|
||||||
@ -46,8 +62,10 @@ class HTTPToolErrorHandling(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class HTTPTool(BaseModel):
|
class HTTPTool(BaseModel):
|
||||||
"""Configuration of an HTTP tool"""
|
"""Configuration of an HTTP tool"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
method: str
|
method: str
|
||||||
values: Dict[str, str]
|
values: Dict[str, str]
|
||||||
@ -60,42 +78,72 @@ class HTTPTool(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class CustomTools(BaseModel):
|
class CustomTools(BaseModel):
|
||||||
"""Configuration of custom tools"""
|
"""Configuration of custom tools"""
|
||||||
http_tools: List[HTTPTool] = Field(default_factory=list, description="List of HTTP tools")
|
|
||||||
|
http_tools: List[HTTPTool] = Field(
|
||||||
|
default_factory=list, description="List of HTTP tools"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
"""Configuration for LLM agents"""
|
"""Configuration for LLM agents"""
|
||||||
tools: Optional[List[ToolConfig]] = Field(default=None, description="List of available tools")
|
|
||||||
custom_tools: Optional[CustomTools] = Field(default=None, description="Custom tools")
|
tools: Optional[List[ToolConfig]] = Field(
|
||||||
mcp_servers: Optional[List[MCPServerConfig]] = Field(default=None, description="List of MCP servers")
|
default=None, description="List of available tools"
|
||||||
sub_agents: Optional[List[UUID]] = Field(default=None, description="List of IDs of sub-agents")
|
)
|
||||||
|
custom_tools: Optional[CustomTools] = Field(
|
||||||
|
default=None, description="Custom tools"
|
||||||
|
)
|
||||||
|
mcp_servers: Optional[List[MCPServerConfig]] = Field(
|
||||||
|
default=None, description="List of MCP servers"
|
||||||
|
)
|
||||||
|
sub_agents: Optional[List[UUID]] = Field(
|
||||||
|
default=None, description="List of IDs of sub-agents"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class SequentialConfig(BaseModel):
|
class SequentialConfig(BaseModel):
|
||||||
"""Configuration for sequential agents"""
|
"""Configuration for sequential agents"""
|
||||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents in execution order")
|
|
||||||
|
sub_agents: List[UUID] = Field(
|
||||||
|
..., description="List of IDs of sub-agents in execution order"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class ParallelConfig(BaseModel):
|
class ParallelConfig(BaseModel):
|
||||||
"""Configuration for parallel agents"""
|
"""Configuration for parallel agents"""
|
||||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for parallel execution")
|
|
||||||
|
sub_agents: List[UUID] = Field(
|
||||||
|
..., description="List of IDs of sub-agents for parallel execution"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class LoopConfig(BaseModel):
|
class LoopConfig(BaseModel):
|
||||||
"""Configuration for loop agents"""
|
"""Configuration for loop agents"""
|
||||||
sub_agents: List[UUID] = Field(..., description="List of IDs of sub-agents for loop execution")
|
|
||||||
max_iterations: Optional[int] = Field(default=None, description="Maximum number of iterations")
|
sub_agents: List[UUID] = Field(
|
||||||
condition: Optional[str] = Field(default=None, description="Condition to stop the loop")
|
..., description="List of IDs of sub-agents for loop execution"
|
||||||
|
)
|
||||||
|
max_iterations: Optional[int] = Field(
|
||||||
|
default=None, description="Maximum number of iterations"
|
||||||
|
)
|
||||||
|
condition: Optional[str] = Field(
|
||||||
|
default=None, description="Condition to stop the loop"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
@ -3,30 +3,38 @@ from typing import Optional, Dict, Any
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
class AuditLogBase(BaseModel):
|
class AuditLogBase(BaseModel):
|
||||||
"""Base schema for audit log"""
|
"""Base schema for audit log"""
|
||||||
|
|
||||||
action: str
|
action: str
|
||||||
resource_type: str
|
resource_type: str
|
||||||
resource_id: Optional[str] = None
|
resource_id: Optional[str] = None
|
||||||
details: Optional[Dict[str, Any]] = None
|
details: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class AuditLogCreate(AuditLogBase):
|
class AuditLogCreate(AuditLogBase):
|
||||||
"""Schema for creating audit log"""
|
"""Schema for creating audit log"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AuditLogResponse(AuditLogBase):
|
class AuditLogResponse(AuditLogBase):
|
||||||
"""Schema for audit log response"""
|
"""Schema for audit log response"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
user_id: Optional[UUID] = None
|
user_id: Optional[UUID] = None
|
||||||
ip_address: Optional[str] = None
|
ip_address: Optional[str] = None
|
||||||
user_agent: Optional[str] = None
|
user_agent: Optional[str] = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class AuditLogFilter(BaseModel):
|
class AuditLogFilter(BaseModel):
|
||||||
"""Schema for audit log search filters"""
|
"""Schema for audit log search filters"""
|
||||||
|
|
||||||
user_id: Optional[UUID] = None
|
user_id: Optional[UUID] = None
|
||||||
action: Optional[str] = None
|
action: Optional[str] = None
|
||||||
resource_type: Optional[str] = None
|
resource_type: Optional[str] = None
|
||||||
@ -34,4 +42,4 @@ class AuditLogFilter(BaseModel):
|
|||||||
start_date: Optional[datetime] = None
|
start_date: Optional[datetime] = None
|
||||||
end_date: Optional[datetime] = None
|
end_date: Optional[datetime] = None
|
||||||
skip: Optional[int] = Field(0, ge=0)
|
skip: Optional[int] = Field(0, ge=0)
|
||||||
limit: Optional[int] = Field(100, ge=1, le=1000)
|
limit: Optional[int] = Field(100, ge=1, le=1000)
|
||||||
|
@ -1,21 +1,33 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
"""Schema for chat requests"""
|
"""Schema for chat requests"""
|
||||||
agent_id: str = Field(..., description="ID of the agent that will process the message")
|
|
||||||
contact_id: str = Field(..., description="ID of the contact that will process the message")
|
agent_id: str = Field(
|
||||||
|
..., description="ID of the agent that will process the message"
|
||||||
|
)
|
||||||
|
contact_id: str = Field(
|
||||||
|
..., description="ID of the contact that will process the message"
|
||||||
|
)
|
||||||
message: str = Field(..., description="User message")
|
message: str = Field(..., description="User message")
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
"""Schema for chat responses"""
|
"""Schema for chat responses"""
|
||||||
|
|
||||||
response: str = Field(..., description="Agent response")
|
response: str = Field(..., description="Agent response")
|
||||||
status: str = Field(..., description="Operation status")
|
status: str = Field(..., description="Operation status")
|
||||||
error: Optional[str] = Field(None, description="Error message, if there is one")
|
error: Optional[str] = Field(None, description="Error message, if there is one")
|
||||||
timestamp: str = Field(..., description="Timestamp of the response")
|
timestamp: str = Field(..., description="Timestamp of the response")
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
class ErrorResponse(BaseModel):
|
||||||
"""Schema for error responses"""
|
"""Schema for error responses"""
|
||||||
|
|
||||||
error: str = Field(..., description="Error message")
|
error: str = Field(..., description="Error message")
|
||||||
status_code: int = Field(..., description="HTTP status code of the error")
|
status_code: int = Field(..., description="HTTP status code of the error")
|
||||||
details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
|
details: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="Additional error details"
|
||||||
|
)
|
||||||
|
@ -4,15 +4,18 @@ from datetime import datetime
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import uuid
|
import uuid
|
||||||
import re
|
import re
|
||||||
from .agent_config import LLMConfig, SequentialConfig, ParallelConfig, LoopConfig
|
from src.schemas.agent_config import LLMConfig
|
||||||
|
|
||||||
|
|
||||||
class ClientBase(BaseModel):
|
class ClientBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
|
|
||||||
|
|
||||||
class ClientCreate(ClientBase):
|
class ClientCreate(ClientBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Client(ClientBase):
|
class Client(ClientBase):
|
||||||
id: UUID
|
id: UUID
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
@ -20,14 +23,17 @@ class Client(ClientBase):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class ContactBase(BaseModel):
|
class ContactBase(BaseModel):
|
||||||
ext_id: Optional[str] = None
|
ext_id: Optional[str] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
meta: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
meta: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ContactCreate(ContactBase):
|
class ContactCreate(ContactBase):
|
||||||
client_id: UUID
|
client_id: UUID
|
||||||
|
|
||||||
|
|
||||||
class Contact(ContactBase):
|
class Contact(ContactBase):
|
||||||
id: UUID
|
id: UUID
|
||||||
client_id: UUID
|
client_id: UUID
|
||||||
@ -35,67 +41,80 @@ class Contact(ContactBase):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class AgentBase(BaseModel):
|
class AgentBase(BaseModel):
|
||||||
name: str = Field(..., description="Agent name (no spaces or special characters)")
|
name: str = Field(..., description="Agent name (no spaces or special characters)")
|
||||||
description: Optional[str] = Field(None, description="Agent description")
|
description: Optional[str] = Field(None, description="Agent description")
|
||||||
type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
|
type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
|
||||||
model: Optional[str] = Field(None, description="Agent model (required only for llm type)")
|
model: Optional[str] = Field(
|
||||||
api_key: Optional[str] = Field(None, description="Agent API Key (required only for llm type)")
|
None, description="Agent model (required only for llm type)"
|
||||||
|
)
|
||||||
|
api_key: Optional[str] = Field(
|
||||||
|
None, description="Agent API Key (required only for llm type)"
|
||||||
|
)
|
||||||
instruction: Optional[str] = None
|
instruction: Optional[str] = None
|
||||||
config: Union[LLMConfig, Dict[str, Any]] = Field(..., description="Agent configuration based on type")
|
config: Union[LLMConfig, Dict[str, Any]] = Field(
|
||||||
|
..., description="Agent configuration based on type"
|
||||||
|
)
|
||||||
|
|
||||||
@validator('name')
|
@validator("name")
|
||||||
def validate_name(cls, v):
|
def validate_name(cls, v):
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
|
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||||
raise ValueError('Agent name cannot contain spaces or special characters')
|
raise ValueError("Agent name cannot contain spaces or special characters")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('type')
|
@validator("type")
|
||||||
def validate_type(cls, v):
|
def validate_type(cls, v):
|
||||||
if v not in ['llm', 'sequential', 'parallel', 'loop']:
|
if v not in ["llm", "sequential", "parallel", "loop"]:
|
||||||
raise ValueError('Invalid agent type. Must be: llm, sequential, parallel or loop')
|
raise ValueError(
|
||||||
|
"Invalid agent type. Must be: llm, sequential, parallel or loop"
|
||||||
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('model')
|
@validator("model")
|
||||||
def validate_model(cls, v, values):
|
def validate_model(cls, v, values):
|
||||||
if 'type' in values and values['type'] == 'llm' and not v:
|
if "type" in values and values["type"] == "llm" and not v:
|
||||||
raise ValueError('Model is required for llm type agents')
|
raise ValueError("Model is required for llm type agents")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('api_key')
|
@validator("api_key")
|
||||||
def validate_api_key(cls, v, values):
|
def validate_api_key(cls, v, values):
|
||||||
if 'type' in values and values['type'] == 'llm' and not v:
|
if "type" in values and values["type"] == "llm" and not v:
|
||||||
raise ValueError('API Key is required for llm type agents')
|
raise ValueError("API Key is required for llm type agents")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('config')
|
@validator("config")
|
||||||
def validate_config(cls, v, values):
|
def validate_config(cls, v, values):
|
||||||
if 'type' not in values:
|
if "type" not in values:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
if values['type'] == 'llm':
|
if values["type"] == "llm":
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
try:
|
try:
|
||||||
# Convert the dictionary to LLMConfig
|
# Convert the dictionary to LLMConfig
|
||||||
v = LLMConfig(**v)
|
v = LLMConfig(**v)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f'Invalid LLM configuration for agent: {str(e)}')
|
raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
|
||||||
elif not isinstance(v, LLMConfig):
|
elif not isinstance(v, LLMConfig):
|
||||||
raise ValueError('Invalid LLM configuration for agent')
|
raise ValueError("Invalid LLM configuration for agent")
|
||||||
elif values['type'] in ['sequential', 'parallel', 'loop']:
|
elif values["type"] in ["sequential", "parallel", "loop"]:
|
||||||
if not isinstance(v, dict):
|
if not isinstance(v, dict):
|
||||||
raise ValueError(f'Invalid configuration for agent {values["type"]}')
|
raise ValueError(f'Invalid configuration for agent {values["type"]}')
|
||||||
if 'sub_agents' not in v:
|
if "sub_agents" not in v:
|
||||||
raise ValueError(f'Agent {values["type"]} must have sub_agents')
|
raise ValueError(f'Agent {values["type"]} must have sub_agents')
|
||||||
if not isinstance(v['sub_agents'], list):
|
if not isinstance(v["sub_agents"], list):
|
||||||
raise ValueError('sub_agents must be a list')
|
raise ValueError("sub_agents must be a list")
|
||||||
if not v['sub_agents']:
|
if not v["sub_agents"]:
|
||||||
raise ValueError(f'Agent {values["type"]} must have at least one sub-agent')
|
raise ValueError(
|
||||||
|
f'Agent {values["type"]} must have at least one sub-agent'
|
||||||
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
class AgentCreate(AgentBase):
|
class AgentCreate(AgentBase):
|
||||||
client_id: UUID
|
client_id: UUID
|
||||||
|
|
||||||
|
|
||||||
class Agent(AgentBase):
|
class Agent(AgentBase):
|
||||||
id: UUID
|
id: UUID
|
||||||
client_id: UUID
|
client_id: UUID
|
||||||
@ -105,6 +124,7 @@ class Agent(AgentBase):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class MCPServerBase(BaseModel):
|
class MCPServerBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
@ -113,9 +133,11 @@ class MCPServerBase(BaseModel):
|
|||||||
tools: List[str] = Field(default_factory=list)
|
tools: List[str] = Field(default_factory=list)
|
||||||
type: str = Field(default="official")
|
type: str = Field(default="official")
|
||||||
|
|
||||||
|
|
||||||
class MCPServerCreate(MCPServerBase):
|
class MCPServerCreate(MCPServerBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MCPServer(MCPServerBase):
|
class MCPServer(MCPServerBase):
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
@ -124,19 +146,22 @@ class MCPServer(MCPServerBase):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class ToolBase(BaseModel):
|
class ToolBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
config_json: Dict[str, Any] = Field(default_factory=dict)
|
config_json: Dict[str, Any] = Field(default_factory=dict)
|
||||||
environments: Dict[str, Any] = Field(default_factory=dict)
|
environments: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ToolCreate(ToolBase):
|
class ToolCreate(ToolBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Tool(ToolBase):
|
class Tool(ToolBase):
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
@ -1,23 +1,28 @@
|
|||||||
from pydantic import BaseModel, EmailStr, Field
|
from pydantic import BaseModel, EmailStr
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(UserBase):
|
class UserCreate(UserBase):
|
||||||
password: str
|
password: str
|
||||||
name: str # For client creation
|
name: str # For client creation
|
||||||
|
|
||||||
|
|
||||||
class AdminUserCreate(UserBase):
|
class AdminUserCreate(UserBase):
|
||||||
password: str
|
password: str
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class UserLogin(BaseModel):
|
class UserLogin(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(UserBase):
|
class UserResponse(UserBase):
|
||||||
id: UUID
|
id: UUID
|
||||||
client_id: Optional[UUID] = None
|
client_id: Optional[UUID] = None
|
||||||
@ -25,26 +30,31 @@ class UserResponse(UserBase):
|
|||||||
email_verified: bool
|
email_verified: bool
|
||||||
is_admin: bool
|
is_admin: bool
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
class TokenData(BaseModel):
|
||||||
sub: str # user email
|
sub: str # user email
|
||||||
exp: datetime
|
exp: datetime
|
||||||
is_admin: bool
|
is_admin: bool
|
||||||
client_id: Optional[UUID] = None
|
client_id: Optional[UUID] = None
|
||||||
|
|
||||||
|
|
||||||
class PasswordReset(BaseModel):
|
class PasswordReset(BaseModel):
|
||||||
token: str
|
token: str
|
||||||
new_password: str
|
new_password: str
|
||||||
|
|
||||||
|
|
||||||
class ForgotPassword(BaseModel):
|
class ForgotPassword(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
class MessageResponse(BaseModel):
|
class MessageResponse(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
|
@ -1 +1 @@
|
|||||||
from .agent_runner import run_agent
|
from .agent_runner import run_agent
|
||||||
|
@ -13,11 +13,10 @@ from google.adk.agents.callback_context import CallbackContext
|
|||||||
from google.adk.models import LlmResponse, LlmRequest
|
from google.adk.models import LlmResponse, LlmRequest
|
||||||
from google.adk.tools import load_memory
|
from google.adk.tools import load_memory
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import requests
|
import requests
|
||||||
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -83,7 +82,7 @@ def before_model_callback(
|
|||||||
llm_request.config.system_instruction = modified_text
|
llm_request.config.system_instruction = modified_text
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"📝 System instruction updated with search results and history"
|
"📝 System instruction updated with search results and history"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("⚠️ No results found in the search")
|
logger.warning("⚠️ No results found in the search")
|
||||||
@ -180,11 +179,13 @@ class AgentBuilder:
|
|||||||
mcp_tools = []
|
mcp_tools = []
|
||||||
mcp_exit_stack = None
|
mcp_exit_stack = None
|
||||||
if agent.config.get("mcp_servers"):
|
if agent.config.get("mcp_servers"):
|
||||||
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(agent.config, self.db)
|
mcp_tools, mcp_exit_stack = await self.mcp_service.build_tools(
|
||||||
|
agent.config, self.db
|
||||||
|
)
|
||||||
|
|
||||||
# Combine all tools
|
# Combine all tools
|
||||||
all_tools = custom_tools + mcp_tools
|
all_tools = custom_tools + mcp_tools
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
current_datetime = now.strftime("%d/%m/%Y %H:%M")
|
current_datetime = now.strftime("%d/%m/%Y %H:%M")
|
||||||
current_day_of_week = now.strftime("%A")
|
current_day_of_week = now.strftime("%A")
|
||||||
@ -201,10 +202,13 @@ class AgentBuilder:
|
|||||||
|
|
||||||
# Check if load_memory is enabled
|
# Check if load_memory is enabled
|
||||||
# before_model_callback_func = None
|
# before_model_callback_func = None
|
||||||
if agent.config.get("load_memory") == True:
|
if agent.config.get("load_memory"):
|
||||||
all_tools.append(load_memory)
|
all_tools.append(load_memory)
|
||||||
# before_model_callback_func = before_model_callback
|
# before_model_callback_func = before_model_callback
|
||||||
formatted_prompt = formatted_prompt + "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
|
formatted_prompt = (
|
||||||
|
formatted_prompt
|
||||||
|
+ "\n\n<memory_instructions>ALWAYS use the load_memory tool to retrieve knowledge for your context</memory_instructions>\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
LlmAgent(
|
LlmAgent(
|
||||||
|
@ -22,9 +22,7 @@ async def run_agent(
|
|||||||
db: Session,
|
db: Session,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(f"Starting execution of agent {agent_id} for contact {contact_id}")
|
||||||
f"Starting execution of agent {agent_id} for contact {contact_id}"
|
|
||||||
)
|
|
||||||
logger.info(f"Received message: {message}")
|
logger.info(f"Received message: {message}")
|
||||||
|
|
||||||
get_root_agent = get_agent(db, agent_id)
|
get_root_agent = get_agent(db, agent_id)
|
||||||
@ -77,15 +75,15 @@ async def run_agent(
|
|||||||
if event.is_final_response() and event.content and event.content.parts:
|
if event.is_final_response() and event.content and event.content.parts:
|
||||||
final_response_text = event.content.parts[0].text
|
final_response_text = event.content.parts[0].text
|
||||||
logger.info(f"Final response received: {final_response_text}")
|
logger.info(f"Final response received: {final_response_text}")
|
||||||
|
|
||||||
completed_session = session_service.get_session(
|
completed_session = session_service.get_session(
|
||||||
app_name=agent_id,
|
app_name=agent_id,
|
||||||
user_id=contact_id,
|
user_id=contact_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_service.add_session_to_memory(completed_session)
|
memory_service.add_session_to_memory(completed_session)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Ensure the exit_stack is closed correctly
|
# Ensure the exit_stack is closed correctly
|
||||||
if exit_stack:
|
if exit_stack:
|
||||||
|
@ -216,9 +216,7 @@ async def update_agent(
|
|||||||
return agent
|
return agent
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=f"Error updating agent: {str(e)}")
|
||||||
status_code=500, detail=f"Error updating agent: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:
|
def delete_agent(db: Session, agent_id: uuid.UUID) -> bool:
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from src.models.models import AuditLog, User
|
from src.models.models import AuditLog
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
import json
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_audit_log(
|
def create_audit_log(
|
||||||
db: Session,
|
db: Session,
|
||||||
user_id: Optional[uuid.UUID],
|
user_id: Optional[uuid.UUID],
|
||||||
@ -17,11 +17,11 @@ def create_audit_log(
|
|||||||
resource_type: str,
|
resource_type: str,
|
||||||
resource_id: Optional[str] = None,
|
resource_id: Optional[str] = None,
|
||||||
details: Optional[Dict[str, Any]] = None,
|
details: Optional[Dict[str, Any]] = None,
|
||||||
request: Optional[Request] = None
|
request: Optional[Request] = None,
|
||||||
) -> Optional[AuditLog]:
|
) -> Optional[AuditLog]:
|
||||||
"""
|
"""
|
||||||
Create a new audit log
|
Create a new audit log
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
user_id: User ID that performed the action (or None if anonymous)
|
user_id: User ID that performed the action (or None if anonymous)
|
||||||
@ -30,25 +30,25 @@ def create_audit_log(
|
|||||||
resource_id: Resource ID (optional)
|
resource_id: Resource ID (optional)
|
||||||
details: Additional details of the action (optional)
|
details: Additional details of the action (optional)
|
||||||
request: FastAPI Request object (optional, to get IP and User-Agent)
|
request: FastAPI Request object (optional, to get IP and User-Agent)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[AuditLog]: Created audit log or None in case of error
|
Optional[AuditLog]: Created audit log or None in case of error
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
ip_address = None
|
ip_address = None
|
||||||
user_agent = None
|
user_agent = None
|
||||||
|
|
||||||
if request:
|
if request:
|
||||||
ip_address = request.client.host if hasattr(request, 'client') else None
|
ip_address = request.client.host if hasattr(request, "client") else None
|
||||||
user_agent = request.headers.get("user-agent")
|
user_agent = request.headers.get("user-agent")
|
||||||
|
|
||||||
# Convert details to serializable format
|
# Convert details to serializable format
|
||||||
if details:
|
if details:
|
||||||
# Convert UUIDs to strings
|
# Convert UUIDs to strings
|
||||||
for key, value in details.items():
|
for key, value in details.items():
|
||||||
if isinstance(value, uuid.UUID):
|
if isinstance(value, uuid.UUID):
|
||||||
details[key] = str(value)
|
details[key] = str(value)
|
||||||
|
|
||||||
audit_log = AuditLog(
|
audit_log = AuditLog(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
action=action,
|
action=action,
|
||||||
@ -56,20 +56,20 @@ def create_audit_log(
|
|||||||
resource_id=str(resource_id) if resource_id else None,
|
resource_id=str(resource_id) if resource_id else None,
|
||||||
details=details,
|
details=details,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(audit_log)
|
db.add(audit_log)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(audit_log)
|
db.refresh(audit_log)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Audit log created: {action} in {resource_type}" +
|
f"Audit log created: {action} in {resource_type}"
|
||||||
(f" (ID: {resource_id})" if resource_id else "")
|
+ (f" (ID: {resource_id})" if resource_id else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
return audit_log
|
return audit_log
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error creating audit log: {str(e)}")
|
logger.error(f"Error creating audit log: {str(e)}")
|
||||||
@ -78,6 +78,7 @@ def create_audit_log(
|
|||||||
logger.error(f"Unexpected error creating audit log: {str(e)}")
|
logger.error(f"Unexpected error creating audit log: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_audit_logs(
|
def get_audit_logs(
|
||||||
db: Session,
|
db: Session,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
@ -87,11 +88,11 @@ def get_audit_logs(
|
|||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
resource_id: Optional[str] = None,
|
resource_id: Optional[str] = None,
|
||||||
start_date: Optional[datetime] = None,
|
start_date: Optional[datetime] = None,
|
||||||
end_date: Optional[datetime] = None
|
end_date: Optional[datetime] = None,
|
||||||
) -> List[AuditLog]:
|
) -> List[AuditLog]:
|
||||||
"""
|
"""
|
||||||
Get audit logs with optional filters
|
Get audit logs with optional filters
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
skip: Number of records to skip
|
skip: Number of records to skip
|
||||||
@ -102,35 +103,35 @@ def get_audit_logs(
|
|||||||
resource_id: Filter by resource ID
|
resource_id: Filter by resource ID
|
||||||
start_date: Start date
|
start_date: Start date
|
||||||
end_date: End date
|
end_date: End date
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[AuditLog]: List of audit logs
|
List[AuditLog]: List of audit logs
|
||||||
"""
|
"""
|
||||||
query = db.query(AuditLog)
|
query = db.query(AuditLog)
|
||||||
|
|
||||||
# Apply filters, if provided
|
# Apply filters, if provided
|
||||||
if user_id:
|
if user_id:
|
||||||
query = query.filter(AuditLog.user_id == user_id)
|
query = query.filter(AuditLog.user_id == user_id)
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
query = query.filter(AuditLog.action == action)
|
query = query.filter(AuditLog.action == action)
|
||||||
|
|
||||||
if resource_type:
|
if resource_type:
|
||||||
query = query.filter(AuditLog.resource_type == resource_type)
|
query = query.filter(AuditLog.resource_type == resource_type)
|
||||||
|
|
||||||
if resource_id:
|
if resource_id:
|
||||||
query = query.filter(AuditLog.resource_id == resource_id)
|
query = query.filter(AuditLog.resource_id == resource_id)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
query = query.filter(AuditLog.created_at >= start_date)
|
query = query.filter(AuditLog.created_at >= start_date)
|
||||||
|
|
||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(AuditLog.created_at <= end_date)
|
query = query.filter(AuditLog.created_at <= end_date)
|
||||||
|
|
||||||
# Order by creation date (most recent first)
|
# Order by creation date (most recent first)
|
||||||
query = query.order_by(AuditLog.created_at.desc())
|
query = query.order_by(AuditLog.created_at.desc())
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
query = query.offset(skip).limit(limit)
|
query = query.offset(skip).limit(limit)
|
||||||
|
|
||||||
return query.all()
|
return query.all()
|
||||||
|
@ -16,17 +16,20 @@ logger = logging.getLogger(__name__)
|
|||||||
# Define OAuth2 authentication scheme with password flow
|
# Define OAuth2 authentication scheme with password flow
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||||
|
|
||||||
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
|
||||||
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Get the current user from the JWT token
|
Get the current user from the JWT token
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: JWT token
|
token: JWT token
|
||||||
db: Database session
|
db: Database session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
User: Current user
|
User: Current user
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If the token is invalid or the user is not found
|
HTTPException: If the token is invalid or the user is not found
|
||||||
"""
|
"""
|
||||||
@ -35,103 +38,108 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
|
|||||||
detail="Invalid credentials",
|
detail="Invalid credentials",
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Decode the token
|
# Decode the token
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token,
|
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||||
settings.JWT_SECRET_KEY,
|
|
||||||
algorithms=[settings.JWT_ALGORITHM]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract token data
|
# Extract token data
|
||||||
email: str = payload.get("sub")
|
email: str = payload.get("sub")
|
||||||
if email is None:
|
if email is None:
|
||||||
logger.warning("Token without email (sub)")
|
logger.warning("Token without email (sub)")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
# Check if the token has expired
|
# Check if the token has expired
|
||||||
exp = payload.get("exp")
|
exp = payload.get("exp")
|
||||||
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
|
if exp is None or datetime.fromtimestamp(exp) < datetime.utcnow():
|
||||||
logger.warning(f"Token expired for {email}")
|
logger.warning(f"Token expired for {email}")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
# Create TokenData object
|
# Create TokenData object
|
||||||
token_data = TokenData(
|
token_data = TokenData(
|
||||||
sub=email,
|
sub=email,
|
||||||
exp=datetime.fromtimestamp(exp),
|
exp=datetime.fromtimestamp(exp),
|
||||||
is_admin=payload.get("is_admin", False),
|
is_admin=payload.get("is_admin", False),
|
||||||
client_id=payload.get("client_id")
|
client_id=payload.get("client_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
except JWTError as e:
|
except JWTError as e:
|
||||||
logger.error(f"Error decoding JWT token: {str(e)}")
|
logger.error(f"Error decoding JWT token: {str(e)}")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
# Search for user in the database
|
# Search for user in the database
|
||||||
user = get_user_by_email(db, email=token_data.sub)
|
user = get_user_by_email(db, email=token_data.sub)
|
||||||
if user is None:
|
if user is None:
|
||||||
logger.warning(f"User not found for email: {token_data.sub}")
|
logger.warning(f"User not found for email: {token_data.sub}")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
logger.warning(f"Attempt to access inactive user: {user.email}")
|
logger.warning(f"Attempt to access inactive user: {user.email}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||||
detail="Inactive user"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
|
|
||||||
|
async def get_current_active_user(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Check if the current user is active
|
Check if the current user is active
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
current_user: Current user
|
current_user: Current user
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
User: Current user if active
|
User: Current user if active
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If the user is not active
|
HTTPException: If the user is not active
|
||||||
"""
|
"""
|
||||||
if not current_user.is_active:
|
if not current_user.is_active:
|
||||||
logger.warning(f"Attempt to access inactive user: {current_user.email}")
|
logger.warning(f"Attempt to access inactive user: {current_user.email}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||||
detail="Inactive user"
|
|
||||||
)
|
)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
async def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
|
|
||||||
|
async def get_current_admin_user(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Check if the current user is an administrator
|
Check if the current user is an administrator
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
current_user: Current user
|
current_user: Current user
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
User: Current user if administrator
|
User: Current user if administrator
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If the user is not an administrator
|
HTTPException: If the user is not an administrator
|
||||||
"""
|
"""
|
||||||
if not current_user.is_admin:
|
if not current_user.is_admin:
|
||||||
logger.warning(f"Attempt to access admin by non-admin user: {current_user.email}")
|
logger.warning(
|
||||||
|
f"Attempt to access admin by non-admin user: {current_user.email}"
|
||||||
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied. Restricted to administrators."
|
detail="Access denied. Restricted to administrators.",
|
||||||
)
|
)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user: User) -> str:
|
def create_access_token(user: User) -> str:
|
||||||
"""
|
"""
|
||||||
Create a JWT access token for the user
|
Create a JWT access token for the user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user: User for which to create the token
|
user: User for which to create the token
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: JWT token
|
str: JWT token
|
||||||
"""
|
"""
|
||||||
@ -140,10 +148,10 @@ def create_access_token(user: User) -> str:
|
|||||||
"sub": user.email,
|
"sub": user.email,
|
||||||
"is_admin": user.is_admin,
|
"is_admin": user.is_admin,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Include client_id only if not administrator and client_id is set
|
# Include client_id only if not administrator and client_id is set
|
||||||
if not user.is_admin and user.client_id:
|
if not user.is_admin and user.client_id:
|
||||||
token_data["client_id"] = str(user.client_id)
|
token_data["client_id"] = str(user.client_id)
|
||||||
|
|
||||||
# Create token
|
# Create token
|
||||||
return create_jwt_token(token_data)
|
return create_jwt_token(token_data)
|
||||||
|
@ -11,6 +11,7 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
|
def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
|
||||||
"""Search for a client by ID"""
|
"""Search for a client by ID"""
|
||||||
try:
|
try:
|
||||||
@ -23,9 +24,10 @@ def get_client(db: Session, client_id: uuid.UUID) -> Optional[Client]:
|
|||||||
logger.error(f"Error searching for client {client_id}: {str(e)}")
|
logger.error(f"Error searching for client {client_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for client"
|
detail="Error searching for client",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
|
def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
|
||||||
"""Search for all clients with pagination"""
|
"""Search for all clients with pagination"""
|
||||||
try:
|
try:
|
||||||
@ -34,9 +36,10 @@ def get_clients(db: Session, skip: int = 0, limit: int = 100) -> List[Client]:
|
|||||||
logger.error(f"Error searching for clients: {str(e)}")
|
logger.error(f"Error searching for clients: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for clients"
|
detail="Error searching for clients",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_client(db: Session, client: ClientCreate) -> Client:
|
def create_client(db: Session, client: ClientCreate) -> Client:
|
||||||
"""Create a new client"""
|
"""Create a new client"""
|
||||||
try:
|
try:
|
||||||
@ -51,19 +54,22 @@ def create_client(db: Session, client: ClientCreate) -> Client:
|
|||||||
logger.error(f"Error creating client: {str(e)}")
|
logger.error(f"Error creating client: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error creating client"
|
detail="Error creating client",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Optional[Client]:
|
|
||||||
|
def update_client(
|
||||||
|
db: Session, client_id: uuid.UUID, client: ClientCreate
|
||||||
|
) -> Optional[Client]:
|
||||||
"""Updates an existing client"""
|
"""Updates an existing client"""
|
||||||
try:
|
try:
|
||||||
db_client = get_client(db, client_id)
|
db_client = get_client(db, client_id)
|
||||||
if not db_client:
|
if not db_client:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for key, value in client.model_dump().items():
|
for key, value in client.model_dump().items():
|
||||||
setattr(db_client, key, value)
|
setattr(db_client, key, value)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_client)
|
db.refresh(db_client)
|
||||||
logger.info(f"Client updated successfully: {client_id}")
|
logger.info(f"Client updated successfully: {client_id}")
|
||||||
@ -73,16 +79,17 @@ def update_client(db: Session, client_id: uuid.UUID, client: ClientCreate) -> Op
|
|||||||
logger.error(f"Error updating client {client_id}: {str(e)}")
|
logger.error(f"Error updating client {client_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error updating client"
|
detail="Error updating client",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def delete_client(db: Session, client_id: uuid.UUID) -> bool:
|
def delete_client(db: Session, client_id: uuid.UUID) -> bool:
|
||||||
"""Removes a client"""
|
"""Removes a client"""
|
||||||
try:
|
try:
|
||||||
db_client = get_client(db, client_id)
|
db_client = get_client(db, client_id)
|
||||||
if not db_client:
|
if not db_client:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
db.delete(db_client)
|
db.delete(db_client)
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Client removed successfully: {client_id}")
|
logger.info(f"Client removed successfully: {client_id}")
|
||||||
@ -92,18 +99,21 @@ def delete_client(db: Session, client_id: uuid.UUID) -> bool:
|
|||||||
logger.error(f"Error removing client {client_id}: {str(e)}")
|
logger.error(f"Error removing client {client_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error removing client"
|
detail="Error removing client",
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_client_with_user(db: Session, client_data: ClientCreate, user_data: UserCreate) -> Tuple[Optional[Client], str]:
|
|
||||||
|
def create_client_with_user(
|
||||||
|
db: Session, client_data: ClientCreate, user_data: UserCreate
|
||||||
|
) -> Tuple[Optional[Client], str]:
|
||||||
"""
|
"""
|
||||||
Creates a new client with an associated user
|
Creates a new client with an associated user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
client_data: Client data to be created
|
client_data: Client data to be created
|
||||||
user_data: User data to be created
|
user_data: User data to be created
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Optional[Client], str]: Tuple with the created client (or None in case of error) and status message
|
Tuple[Optional[Client], str]: Tuple with the created client (or None in case of error) and status message
|
||||||
"""
|
"""
|
||||||
@ -112,27 +122,27 @@ def create_client_with_user(db: Session, client_data: ClientCreate, user_data: U
|
|||||||
client = Client(**client_data.model_dump())
|
client = Client(**client_data.model_dump())
|
||||||
db.add(client)
|
db.add(client)
|
||||||
db.flush() # Get client ID without committing the transaction
|
db.flush() # Get client ID without committing the transaction
|
||||||
|
|
||||||
# Use client ID to create the associated user
|
# Use client ID to create the associated user
|
||||||
user, message = create_user(db, user_data, is_admin=False, client_id=client.id)
|
user, message = create_user(db, user_data, is_admin=False, client_id=client.id)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
# If there was an error creating the user, rollback
|
# If there was an error creating the user, rollback
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error creating user for client: {message}")
|
logger.error(f"Error creating user for client: {message}")
|
||||||
return None, f"Error creating user: {message}"
|
return None, f"Error creating user: {message}"
|
||||||
|
|
||||||
# If everything went well, commit the transaction
|
# If everything went well, commit the transaction
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Client and user created successfully: {client.id}")
|
logger.info(f"Client and user created successfully: {client.id}")
|
||||||
return client, "Client and user created successfully"
|
return client, "Client and user created successfully"
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error creating client with user: {str(e)}")
|
logger.error(f"Error creating client with user: {str(e)}")
|
||||||
return None, f"Error creating client with user: {str(e)}"
|
return None, f"Error creating client with user: {str(e)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Unexpected error creating client with user: {str(e)}")
|
logger.error(f"Unexpected error creating client with user: {str(e)}")
|
||||||
return None, f"Unexpected error: {str(e)}"
|
return None, f"Unexpected error: {str(e)}"
|
||||||
|
@ -9,6 +9,7 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
|
def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
|
||||||
"""Search for a contact by ID"""
|
"""Search for a contact by ID"""
|
||||||
try:
|
try:
|
||||||
@ -21,20 +22,30 @@ def get_contact(db: Session, contact_id: uuid.UUID) -> Optional[Contact]:
|
|||||||
logger.error(f"Error searching for contact {contact_id}: {str(e)}")
|
logger.error(f"Error searching for contact {contact_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for contact"
|
detail="Error searching for contact",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_contacts_by_client(db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[Contact]:
|
|
||||||
|
def get_contacts_by_client(
|
||||||
|
db: Session, client_id: uuid.UUID, skip: int = 0, limit: int = 100
|
||||||
|
) -> List[Contact]:
|
||||||
"""Search for contacts of a client with pagination"""
|
"""Search for contacts of a client with pagination"""
|
||||||
try:
|
try:
|
||||||
return db.query(Contact).filter(Contact.client_id == client_id).offset(skip).limit(limit).all()
|
return (
|
||||||
|
db.query(Contact)
|
||||||
|
.filter(Contact.client_id == client_id)
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
|
logger.error(f"Error searching for contacts of client {client_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for contacts"
|
detail="Error searching for contacts",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_contact(db: Session, contact: ContactCreate) -> Contact:
|
def create_contact(db: Session, contact: ContactCreate) -> Contact:
|
||||||
"""Create a new contact"""
|
"""Create a new contact"""
|
||||||
try:
|
try:
|
||||||
@ -49,19 +60,22 @@ def create_contact(db: Session, contact: ContactCreate) -> Contact:
|
|||||||
logger.error(f"Error creating contact: {str(e)}")
|
logger.error(f"Error creating contact: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error creating contact"
|
detail="Error creating contact",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -> Optional[Contact]:
|
|
||||||
|
def update_contact(
|
||||||
|
db: Session, contact_id: uuid.UUID, contact: ContactCreate
|
||||||
|
) -> Optional[Contact]:
|
||||||
"""Update an existing contact"""
|
"""Update an existing contact"""
|
||||||
try:
|
try:
|
||||||
db_contact = get_contact(db, contact_id)
|
db_contact = get_contact(db, contact_id)
|
||||||
if not db_contact:
|
if not db_contact:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for key, value in contact.model_dump().items():
|
for key, value in contact.model_dump().items():
|
||||||
setattr(db_contact, key, value)
|
setattr(db_contact, key, value)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_contact)
|
db.refresh(db_contact)
|
||||||
logger.info(f"Contact updated successfully: {contact_id}")
|
logger.info(f"Contact updated successfully: {contact_id}")
|
||||||
@ -71,16 +85,17 @@ def update_contact(db: Session, contact_id: uuid.UUID, contact: ContactCreate) -
|
|||||||
logger.error(f"Error updating contact {contact_id}: {str(e)}")
|
logger.error(f"Error updating contact {contact_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error updating contact"
|
detail="Error updating contact",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
|
def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
|
||||||
"""Remove a contact"""
|
"""Remove a contact"""
|
||||||
try:
|
try:
|
||||||
db_contact = get_contact(db, contact_id)
|
db_contact = get_contact(db, contact_id)
|
||||||
if not db_contact:
|
if not db_contact:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
db.delete(db_contact)
|
db.delete(db_contact)
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Contact removed successfully: {contact_id}")
|
logger.info(f"Contact removed successfully: {contact_id}")
|
||||||
@ -90,5 +105,5 @@ def delete_contact(db: Session, contact_id: uuid.UUID) -> bool:
|
|||||||
logger.error(f"Error removing contact {contact_id}: {str(e)}")
|
logger.error(f"Error removing contact {contact_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error removing contact"
|
detail="Error removing contact",
|
||||||
)
|
)
|
||||||
|
@ -6,6 +6,7 @@ from src.utils.logger import setup_logger
|
|||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CustomToolBuilder:
|
class CustomToolBuilder:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tools = []
|
self.tools = []
|
||||||
@ -53,7 +54,9 @@ class CustomToolBuilder:
|
|||||||
|
|
||||||
# Adds default values to query params if they are not present
|
# Adds default values to query params if they are not present
|
||||||
for param, value in values.items():
|
for param, value in values.items():
|
||||||
if param not in query_params and param not in parameters.get("path_params", {}):
|
if param not in query_params and param not in parameters.get(
|
||||||
|
"path_params", {}
|
||||||
|
):
|
||||||
query_params[param] = value
|
query_params[param] = value
|
||||||
|
|
||||||
# Processa body parameters
|
# Processa body parameters
|
||||||
@ -64,7 +67,11 @@ class CustomToolBuilder:
|
|||||||
|
|
||||||
# Adds default values to body if they are not present
|
# Adds default values to body if they are not present
|
||||||
for param, value in values.items():
|
for param, value in values.items():
|
||||||
if param not in body_data and param not in query_params and param not in parameters.get("path_params", {}):
|
if (
|
||||||
|
param not in body_data
|
||||||
|
and param not in query_params
|
||||||
|
and param not in parameters.get("path_params", {})
|
||||||
|
):
|
||||||
body_data[param] = value
|
body_data[param] = value
|
||||||
|
|
||||||
# Makes the HTTP request
|
# Makes the HTTP request
|
||||||
@ -74,7 +81,7 @@ class CustomToolBuilder:
|
|||||||
headers=processed_headers,
|
headers=processed_headers,
|
||||||
params=query_params,
|
params=query_params,
|
||||||
json=body_data if body_data else None,
|
json=body_data if body_data else None,
|
||||||
timeout=error_handling.get("timeout", 30)
|
timeout=error_handling.get("timeout", 30),
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code >= 400:
|
if response.status_code >= 400:
|
||||||
@ -87,30 +94,34 @@ class CustomToolBuilder:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error executing tool {name}: {str(e)}")
|
logger.error(f"Error executing tool {name}: {str(e)}")
|
||||||
return json.dumps(error_handling.get("fallback_response", {
|
return json.dumps(
|
||||||
"error": "tool_execution_error",
|
error_handling.get(
|
||||||
"message": str(e)
|
"fallback_response",
|
||||||
}))
|
{"error": "tool_execution_error", "message": str(e)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Adds dynamic docstring based on the configuration
|
# Adds dynamic docstring based on the configuration
|
||||||
param_docs = []
|
param_docs = []
|
||||||
|
|
||||||
# Adds path parameters
|
# Adds path parameters
|
||||||
for param, value in parameters.get("path_params", {}).items():
|
for param, value in parameters.get("path_params", {}).items():
|
||||||
param_docs.append(f"{param}: {value}")
|
param_docs.append(f"{param}: {value}")
|
||||||
|
|
||||||
# Adds query parameters
|
# Adds query parameters
|
||||||
for param, value in parameters.get("query_params", {}).items():
|
for param, value in parameters.get("query_params", {}).items():
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
param_docs.append(f"{param}: List[{', '.join(value)}]")
|
param_docs.append(f"{param}: List[{', '.join(value)}]")
|
||||||
else:
|
else:
|
||||||
param_docs.append(f"{param}: {value}")
|
param_docs.append(f"{param}: {value}")
|
||||||
|
|
||||||
# Adds body parameters
|
# Adds body parameters
|
||||||
for param, param_config in parameters.get("body_params", {}).items():
|
for param, param_config in parameters.get("body_params", {}).items():
|
||||||
required = "Required" if param_config.get("required", False) else "Optional"
|
required = "Required" if param_config.get("required", False) else "Optional"
|
||||||
param_docs.append(f"{param} ({param_config['type']}, {required}): {param_config['description']}")
|
param_docs.append(
|
||||||
|
f"{param} ({param_config['type']}, {required}): {param_config['description']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Adds default values
|
# Adds default values
|
||||||
if values:
|
if values:
|
||||||
param_docs.append("\nDefault values:")
|
param_docs.append("\nDefault values:")
|
||||||
@ -119,10 +130,10 @@ class CustomToolBuilder:
|
|||||||
|
|
||||||
http_tool.__doc__ = f"""
|
http_tool.__doc__ = f"""
|
||||||
{description}
|
{description}
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
{chr(10).join(param_docs)}
|
{chr(10).join(param_docs)}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
String containing the response in JSON format
|
String containing the response in JSON format
|
||||||
"""
|
"""
|
||||||
@ -140,4 +151,4 @@ class CustomToolBuilder:
|
|||||||
for http_tool_config in tools_config.get("http_tools", []):
|
for http_tool_config in tools_config.get("http_tools", []):
|
||||||
self.tools.append(self._create_http_tool(http_tool_config))
|
self.tools.append(self._create_http_tool(http_tool_config))
|
||||||
|
|
||||||
return self.tools
|
return self.tools
|
||||||
|
@ -16,17 +16,18 @@ os.makedirs(templates_dir, exist_ok=True)
|
|||||||
# Configure Jinja2 with the templates directory
|
# Configure Jinja2 with the templates directory
|
||||||
env = Environment(
|
env = Environment(
|
||||||
loader=FileSystemLoader(templates_dir),
|
loader=FileSystemLoader(templates_dir),
|
||||||
autoescape=select_autoescape(['html', 'xml'])
|
autoescape=select_autoescape(["html", "xml"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _render_template(template_name: str, context: dict) -> str:
|
def _render_template(template_name: str, context: dict) -> str:
|
||||||
"""
|
"""
|
||||||
Render a template with the provided data
|
Render a template with the provided data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
template_name: Template file name
|
template_name: Template file name
|
||||||
context: Data to render in the template
|
context: Data to render in the template
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Rendered HTML
|
str: Rendered HTML
|
||||||
"""
|
"""
|
||||||
@ -37,14 +38,15 @@ def _render_template(template_name: str, context: dict) -> str:
|
|||||||
logger.error(f"Error rendering template '{template_name}': {str(e)}")
|
logger.error(f"Error rendering template '{template_name}': {str(e)}")
|
||||||
return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"
|
return f"<p>Could not display email content. Please access {context.get('verification_link', '') or context.get('reset_link', '')}</p>"
|
||||||
|
|
||||||
|
|
||||||
def send_verification_email(email: str, token: str) -> bool:
|
def send_verification_email(email: str, token: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Send a verification email to the user
|
Send a verification email to the user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
email: Recipient's email
|
email: Recipient's email
|
||||||
token: Email verification token
|
token: Email verification token
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the email was sent successfully, False otherwise
|
bool: True if the email was sent successfully, False otherwise
|
||||||
"""
|
"""
|
||||||
@ -53,39 +55,47 @@ def send_verification_email(email: str, token: str) -> bool:
|
|||||||
from_email = Email(settings.EMAIL_FROM)
|
from_email = Email(settings.EMAIL_FROM)
|
||||||
to_email = To(email)
|
to_email = To(email)
|
||||||
subject = "Email Verification - Evo AI"
|
subject = "Email Verification - Evo AI"
|
||||||
|
|
||||||
verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"
|
verification_link = f"{settings.APP_URL}/auth/verify-email/{token}"
|
||||||
|
|
||||||
html_content = _render_template('verification_email', {
|
html_content = _render_template(
|
||||||
'verification_link': verification_link,
|
"verification_email",
|
||||||
'user_name': email.split('@')[0], # Use part of the email as temporary name
|
{
|
||||||
'current_year': datetime.now().year
|
"verification_link": verification_link,
|
||||||
})
|
"user_name": email.split("@")[
|
||||||
|
0
|
||||||
|
], # Use part of the email as temporary name
|
||||||
|
"current_year": datetime.now().year,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
content = Content("text/html", html_content)
|
content = Content("text/html", html_content)
|
||||||
|
|
||||||
mail = Mail(from_email, to_email, subject, content)
|
mail = Mail(from_email, to_email, subject, content)
|
||||||
response = sg.client.mail.send.post(request_body=mail.get())
|
response = sg.client.mail.send.post(request_body=mail.get())
|
||||||
|
|
||||||
if response.status_code >= 200 and response.status_code < 300:
|
if response.status_code >= 200 and response.status_code < 300:
|
||||||
logger.info(f"Verification email sent to {email}")
|
logger.info(f"Verification email sent to {email}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to send verification email to {email}. Status: {response.status_code}")
|
logger.error(
|
||||||
|
f"Failed to send verification email to {email}. Status: {response.status_code}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending verification email to {email}: {str(e)}")
|
logger.error(f"Error sending verification email to {email}: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def send_password_reset_email(email: str, token: str) -> bool:
|
def send_password_reset_email(email: str, token: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Send a password reset email to the user
|
Send a password reset email to the user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
email: Recipient's email
|
email: Recipient's email
|
||||||
token: Password reset token
|
token: Password reset token
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the email was sent successfully, False otherwise
|
bool: True if the email was sent successfully, False otherwise
|
||||||
"""
|
"""
|
||||||
@ -94,39 +104,47 @@ def send_password_reset_email(email: str, token: str) -> bool:
|
|||||||
from_email = Email(settings.EMAIL_FROM)
|
from_email = Email(settings.EMAIL_FROM)
|
||||||
to_email = To(email)
|
to_email = To(email)
|
||||||
subject = "Password Reset - Evo AI"
|
subject = "Password Reset - Evo AI"
|
||||||
|
|
||||||
reset_link = f"{settings.APP_URL}/reset-password?token={token}"
|
reset_link = f"{settings.APP_URL}/reset-password?token={token}"
|
||||||
|
|
||||||
html_content = _render_template('password_reset', {
|
html_content = _render_template(
|
||||||
'reset_link': reset_link,
|
"password_reset",
|
||||||
'user_name': email.split('@')[0], # Use part of the email as temporary name
|
{
|
||||||
'current_year': datetime.now().year
|
"reset_link": reset_link,
|
||||||
})
|
"user_name": email.split("@")[
|
||||||
|
0
|
||||||
|
], # Use part of the email as temporary name
|
||||||
|
"current_year": datetime.now().year,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
content = Content("text/html", html_content)
|
content = Content("text/html", html_content)
|
||||||
|
|
||||||
mail = Mail(from_email, to_email, subject, content)
|
mail = Mail(from_email, to_email, subject, content)
|
||||||
response = sg.client.mail.send.post(request_body=mail.get())
|
response = sg.client.mail.send.post(request_body=mail.get())
|
||||||
|
|
||||||
if response.status_code >= 200 and response.status_code < 300:
|
if response.status_code >= 200 and response.status_code < 300:
|
||||||
logger.info(f"Password reset email sent to {email}")
|
logger.info(f"Password reset email sent to {email}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to send password reset email to {email}. Status: {response.status_code}")
|
logger.error(
|
||||||
|
f"Failed to send password reset email to {email}. Status: {response.status_code}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending password reset email to {email}: {str(e)}")
|
logger.error(f"Error sending password reset email to {email}: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def send_welcome_email(email: str, user_name: str = None) -> bool:
|
def send_welcome_email(email: str, user_name: str = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Send a welcome email to the user after verification
|
Send a welcome email to the user after verification
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
email: Recipient's email
|
email: Recipient's email
|
||||||
user_name: User's name (optional)
|
user_name: User's name (optional)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the email was sent successfully, False otherwise
|
bool: True if the email was sent successfully, False otherwise
|
||||||
"""
|
"""
|
||||||
@ -135,41 +153,49 @@ def send_welcome_email(email: str, user_name: str = None) -> bool:
|
|||||||
from_email = Email(settings.EMAIL_FROM)
|
from_email = Email(settings.EMAIL_FROM)
|
||||||
to_email = To(email)
|
to_email = To(email)
|
||||||
subject = "Welcome to Evo AI"
|
subject = "Welcome to Evo AI"
|
||||||
|
|
||||||
dashboard_link = f"{settings.APP_URL}/dashboard"
|
dashboard_link = f"{settings.APP_URL}/dashboard"
|
||||||
|
|
||||||
html_content = _render_template('welcome_email', {
|
html_content = _render_template(
|
||||||
'dashboard_link': dashboard_link,
|
"welcome_email",
|
||||||
'user_name': user_name or email.split('@')[0],
|
{
|
||||||
'current_year': datetime.now().year
|
"dashboard_link": dashboard_link,
|
||||||
})
|
"user_name": user_name or email.split("@")[0],
|
||||||
|
"current_year": datetime.now().year,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
content = Content("text/html", html_content)
|
content = Content("text/html", html_content)
|
||||||
|
|
||||||
mail = Mail(from_email, to_email, subject, content)
|
mail = Mail(from_email, to_email, subject, content)
|
||||||
response = sg.client.mail.send.post(request_body=mail.get())
|
response = sg.client.mail.send.post(request_body=mail.get())
|
||||||
|
|
||||||
if response.status_code >= 200 and response.status_code < 300:
|
if response.status_code >= 200 and response.status_code < 300:
|
||||||
logger.info(f"Welcome email sent to {email}")
|
logger.info(f"Welcome email sent to {email}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to send welcome email to {email}. Status: {response.status_code}")
|
logger.error(
|
||||||
|
f"Failed to send welcome email to {email}. Status: {response.status_code}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending welcome email to {email}: {str(e)}")
|
logger.error(f"Error sending welcome email to {email}: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def send_account_locked_email(email: str, reset_token: str, failed_attempts: int, time_period: str) -> bool:
|
|
||||||
|
def send_account_locked_email(
|
||||||
|
email: str, reset_token: str, failed_attempts: int, time_period: str
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Send an email informing that the account has been locked after login attempts
|
Send an email informing that the account has been locked after login attempts
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
email: Recipient's email
|
email: Recipient's email
|
||||||
reset_token: Token to reset the password
|
reset_token: Token to reset the password
|
||||||
failed_attempts: Number of failed attempts
|
failed_attempts: Number of failed attempts
|
||||||
time_period: Time period of the attempts
|
time_period: Time period of the attempts
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the email was sent successfully, False otherwise
|
bool: True if the email was sent successfully, False otherwise
|
||||||
"""
|
"""
|
||||||
@ -178,29 +204,34 @@ def send_account_locked_email(email: str, reset_token: str, failed_attempts: int
|
|||||||
from_email = Email(settings.EMAIL_FROM)
|
from_email = Email(settings.EMAIL_FROM)
|
||||||
to_email = To(email)
|
to_email = To(email)
|
||||||
subject = "Security Alert - Account Locked"
|
subject = "Security Alert - Account Locked"
|
||||||
|
|
||||||
reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"
|
reset_link = f"{settings.APP_URL}/reset-password?token={reset_token}"
|
||||||
|
|
||||||
html_content = _render_template('account_locked', {
|
html_content = _render_template(
|
||||||
'reset_link': reset_link,
|
"account_locked",
|
||||||
'user_name': email.split('@')[0],
|
{
|
||||||
'failed_attempts': failed_attempts,
|
"reset_link": reset_link,
|
||||||
'time_period': time_period,
|
"user_name": email.split("@")[0],
|
||||||
'current_year': datetime.now().year
|
"failed_attempts": failed_attempts,
|
||||||
})
|
"time_period": time_period,
|
||||||
|
"current_year": datetime.now().year,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
content = Content("text/html", html_content)
|
content = Content("text/html", html_content)
|
||||||
|
|
||||||
mail = Mail(from_email, to_email, subject, content)
|
mail = Mail(from_email, to_email, subject, content)
|
||||||
response = sg.client.mail.send.post(request_body=mail.get())
|
response = sg.client.mail.send.post(request_body=mail.get())
|
||||||
|
|
||||||
if response.status_code >= 200 and response.status_code < 300:
|
if response.status_code >= 200 and response.status_code < 300:
|
||||||
logger.info(f"Account locked email sent to {email}")
|
logger.info(f"Account locked email sent to {email}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to send account locked email to {email}. Status: {response.status_code}")
|
logger.error(
|
||||||
|
f"Failed to send account locked email to {email}. Status: {response.status_code}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending account locked email to {email}: {str(e)}")
|
logger.error(f"Error sending account locked email to {email}: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
@ -9,6 +9,7 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
|
def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
|
||||||
"""Search for an MCP server by ID"""
|
"""Search for an MCP server by ID"""
|
||||||
try:
|
try:
|
||||||
@ -21,9 +22,10 @@ def get_mcp_server(db: Session, server_id: uuid.UUID) -> Optional[MCPServer]:
|
|||||||
logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
|
logger.error(f"Error searching for MCP server {server_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for MCP server"
|
detail="Error searching for MCP server",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
|
def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPServer]:
|
||||||
"""Search for all MCP servers with pagination"""
|
"""Search for all MCP servers with pagination"""
|
||||||
try:
|
try:
|
||||||
@ -32,9 +34,10 @@ def get_mcp_servers(db: Session, skip: int = 0, limit: int = 100) -> List[MCPSer
|
|||||||
logger.error(f"Error searching for MCP servers: {str(e)}")
|
logger.error(f"Error searching for MCP servers: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for MCP servers"
|
detail="Error searching for MCP servers",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
|
def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
|
||||||
"""Create a new MCP server"""
|
"""Create a new MCP server"""
|
||||||
try:
|
try:
|
||||||
@ -49,19 +52,22 @@ def create_mcp_server(db: Session, server: MCPServerCreate) -> MCPServer:
|
|||||||
logger.error(f"Error creating MCP server: {str(e)}")
|
logger.error(f"Error creating MCP server: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error creating MCP server"
|
detail="Error creating MCP server",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate) -> Optional[MCPServer]:
|
|
||||||
|
def update_mcp_server(
|
||||||
|
db: Session, server_id: uuid.UUID, server: MCPServerCreate
|
||||||
|
) -> Optional[MCPServer]:
|
||||||
"""Update an existing MCP server"""
|
"""Update an existing MCP server"""
|
||||||
try:
|
try:
|
||||||
db_server = get_mcp_server(db, server_id)
|
db_server = get_mcp_server(db, server_id)
|
||||||
if not db_server:
|
if not db_server:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for key, value in server.model_dump().items():
|
for key, value in server.model_dump().items():
|
||||||
setattr(db_server, key, value)
|
setattr(db_server, key, value)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_server)
|
db.refresh(db_server)
|
||||||
logger.info(f"MCP server updated successfully: {server_id}")
|
logger.info(f"MCP server updated successfully: {server_id}")
|
||||||
@ -71,16 +77,17 @@ def update_mcp_server(db: Session, server_id: uuid.UUID, server: MCPServerCreate
|
|||||||
logger.error(f"Error updating MCP server {server_id}: {str(e)}")
|
logger.error(f"Error updating MCP server {server_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error updating MCP server"
|
detail="Error updating MCP server",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
|
def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
|
||||||
"""Remove an MCP server"""
|
"""Remove an MCP server"""
|
||||||
try:
|
try:
|
||||||
db_server = get_mcp_server(db, server_id)
|
db_server = get_mcp_server(db, server_id)
|
||||||
if not db_server:
|
if not db_server:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
db.delete(db_server)
|
db.delete(db_server)
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"MCP server removed successfully: {server_id}")
|
logger.info(f"MCP server removed successfully: {server_id}")
|
||||||
@ -90,5 +97,5 @@ def delete_mcp_server(db: Session, server_id: uuid.UUID) -> bool:
|
|||||||
logger.error(f"Error removing MCP server {server_id}: {str(e)}")
|
logger.error(f"Error removing MCP server {server_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error removing MCP server"
|
detail="Error removing MCP server",
|
||||||
)
|
)
|
||||||
|
@ -12,26 +12,28 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MCPService:
|
class MCPService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tools = []
|
self.tools = []
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
async def _connect_to_mcp_server(self, server_config: Dict[str, Any]) -> Tuple[List[Any], Optional[AsyncExitStack]]:
|
async def _connect_to_mcp_server(
|
||||||
|
self, server_config: Dict[str, Any]
|
||||||
|
) -> Tuple[List[Any], Optional[AsyncExitStack]]:
|
||||||
"""Connect to a specific MCP server and return its tools."""
|
"""Connect to a specific MCP server and return its tools."""
|
||||||
try:
|
try:
|
||||||
# Determines the type of server (local or remote)
|
# Determines the type of server (local or remote)
|
||||||
if "url" in server_config:
|
if "url" in server_config:
|
||||||
# Remote server (SSE)
|
# Remote server (SSE)
|
||||||
connection_params = SseServerParams(
|
connection_params = SseServerParams(
|
||||||
url=server_config["url"],
|
url=server_config["url"], headers=server_config.get("headers", {})
|
||||||
headers=server_config.get("headers", {})
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Local server (Stdio)
|
# Local server (Stdio)
|
||||||
command = server_config.get("command", "npx")
|
command = server_config.get("command", "npx")
|
||||||
args = server_config.get("args", [])
|
args = server_config.get("args", [])
|
||||||
|
|
||||||
# Adds environment variables if specified
|
# Adds environment variables if specified
|
||||||
env = server_config.get("env", {})
|
env = server_config.get("env", {})
|
||||||
if env:
|
if env:
|
||||||
@ -39,9 +41,7 @@ class MCPService:
|
|||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
connection_params = StdioServerParameters(
|
connection_params = StdioServerParameters(
|
||||||
command=command,
|
command=command, args=args, env=env
|
||||||
args=args,
|
|
||||||
env=env
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tools, exit_stack = await MCPToolset.from_server(
|
tools, exit_stack = await MCPToolset.from_server(
|
||||||
@ -73,8 +73,10 @@ class MCPService:
|
|||||||
logger.warning(f"Removed {removed_count} incompatible tools.")
|
logger.warning(f"Removed {removed_count} incompatible tools.")
|
||||||
|
|
||||||
return filtered_tools
|
return filtered_tools
|
||||||
|
|
||||||
def _filter_tools_by_agent(self, tools: List[Any], agent_tools: List[str]) -> List[Any]:
|
def _filter_tools_by_agent(
|
||||||
|
self, tools: List[Any], agent_tools: List[str]
|
||||||
|
) -> List[Any]:
|
||||||
"""Filters tools compatible with the agent."""
|
"""Filters tools compatible with the agent."""
|
||||||
filtered_tools = []
|
filtered_tools = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
@ -83,7 +85,9 @@ class MCPService:
|
|||||||
filtered_tools.append(tool)
|
filtered_tools.append(tool)
|
||||||
return filtered_tools
|
return filtered_tools
|
||||||
|
|
||||||
async def build_tools(self, mcp_config: Dict[str, Any], db: Session) -> Tuple[List[Any], AsyncExitStack]:
|
async def build_tools(
|
||||||
|
self, mcp_config: Dict[str, Any], db: Session
|
||||||
|
) -> Tuple[List[Any], AsyncExitStack]:
|
||||||
"""Builds a list of tools from multiple MCP servers."""
|
"""Builds a list of tools from multiple MCP servers."""
|
||||||
self.tools = []
|
self.tools = []
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
@ -92,23 +96,25 @@ class MCPService:
|
|||||||
for server in mcp_config.get("mcp_servers", []):
|
for server in mcp_config.get("mcp_servers", []):
|
||||||
try:
|
try:
|
||||||
# Search for the MCP server in the database
|
# Search for the MCP server in the database
|
||||||
mcp_server = get_mcp_server(db, server['id'])
|
mcp_server = get_mcp_server(db, server["id"])
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
logger.warning(f"Servidor MCP não encontrado: {server['id']}")
|
logger.warning(f"Servidor MCP não encontrado: {server['id']}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Prepares the server configuration
|
# Prepares the server configuration
|
||||||
server_config = mcp_server.config_json.copy()
|
server_config = mcp_server.config_json.copy()
|
||||||
|
|
||||||
# Replaces the environment variables in the config_json
|
# Replaces the environment variables in the config_json
|
||||||
if 'env' in server_config:
|
if "env" in server_config:
|
||||||
for key, value in server_config['env'].items():
|
for key, value in server_config["env"].items():
|
||||||
if value.startswith('env@@'):
|
if value.startswith("env@@"):
|
||||||
env_key = value.replace('env@@', '')
|
env_key = value.replace("env@@", "")
|
||||||
if env_key in server.get('envs', {}):
|
if env_key in server.get("envs", {}):
|
||||||
server_config['env'][key] = server['envs'][env_key]
|
server_config["env"][key] = server["envs"][env_key]
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}")
|
logger.warning(
|
||||||
|
f"Environment variable '{env_key}' not provided for the MCP server {mcp_server.name}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"Connecting to MCP server: {mcp_server.name}")
|
logger.info(f"Connecting to MCP server: {mcp_server.name}")
|
||||||
@ -117,22 +123,30 @@ class MCPService:
|
|||||||
if tools and exit_stack:
|
if tools and exit_stack:
|
||||||
# Filters incompatible tools
|
# Filters incompatible tools
|
||||||
filtered_tools = self._filter_incompatible_tools(tools)
|
filtered_tools = self._filter_incompatible_tools(tools)
|
||||||
|
|
||||||
# Filters tools compatible with the agent
|
# Filters tools compatible with the agent
|
||||||
agent_tools = server.get('tools', [])
|
agent_tools = server.get("tools", [])
|
||||||
filtered_tools = self._filter_tools_by_agent(filtered_tools, agent_tools)
|
filtered_tools = self._filter_tools_by_agent(
|
||||||
|
filtered_tools, agent_tools
|
||||||
|
)
|
||||||
self.tools.extend(filtered_tools)
|
self.tools.extend(filtered_tools)
|
||||||
|
|
||||||
# Registers the exit_stack with the AsyncExitStack
|
# Registers the exit_stack with the AsyncExitStack
|
||||||
await self.exit_stack.enter_async_context(exit_stack)
|
await self.exit_stack.enter_async_context(exit_stack)
|
||||||
logger.info(f"Connected successfully. Added {len(filtered_tools)} tools.")
|
logger.info(
|
||||||
|
f"Connected successfully. Added {len(filtered_tools)} tools."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Failed to connect or no tools available for {mcp_server.name}")
|
logger.warning(
|
||||||
|
f"Failed to connect or no tools available for {mcp_server.name}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
|
logger.error(f"Error connecting to MCP server {server['id']}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"MCP Toolset created successfully. Total of {len(self.tools)} tools.")
|
logger.info(
|
||||||
|
f"MCP Toolset created successfully. Total of {len(self.tools)} tools."
|
||||||
|
)
|
||||||
|
|
||||||
return self.tools, self.exit_stack
|
return self.tools, self.exit_stack
|
||||||
|
9
src/services/service_providers.py
Normal file
9
src/services/service_providers.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from src.config.settings import settings
|
||||||
|
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||||
|
from google.adk.sessions import DatabaseSessionService
|
||||||
|
from google.adk.memory import InMemoryMemoryService
|
||||||
|
|
||||||
|
# Initialize service instances
|
||||||
|
session_service = DatabaseSessionService(db_url=settings.POSTGRES_CONNECTION_STRING)
|
||||||
|
artifacts_service = InMemoryArtifactService()
|
||||||
|
memory_service = InMemoryMemoryService()
|
@ -66,7 +66,7 @@ def get_session_by_id(
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Invalid session ID. Expected format: app_name_user_id",
|
detail="Invalid session ID. Expected format: app_name_user_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
parts = session_id.split("_", 1)
|
parts = session_id.split("_", 1)
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
logger.error(f"Invalid session ID format: {session_id}")
|
logger.error(f"Invalid session ID format: {session_id}")
|
||||||
@ -74,22 +74,22 @@ def get_session_by_id(
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Invalid session ID format. Expected format: app_name_user_id",
|
detail="Invalid session ID format. Expected format: app_name_user_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id, app_name = parts
|
user_id, app_name = parts
|
||||||
|
|
||||||
session = session_service.get_session(
|
session = session_service.get_session(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
logger.error(f"Session not found: {session_id}")
|
logger.error(f"Session not found: {session_id}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=f"Session not found: {session_id}",
|
detail=f"Session not found: {session_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error searching for session {session_id}: {str(e)}")
|
logger.error(f"Error searching for session {session_id}: {str(e)}")
|
||||||
@ -106,7 +106,7 @@ def delete_session(session_service: DatabaseSessionService, session_id: str) ->
|
|||||||
try:
|
try:
|
||||||
session = get_session_by_id(session_service, session_id)
|
session = get_session_by_id(session_service, session_id)
|
||||||
# If we get here, the session exists (get_session_by_id already validates)
|
# If we get here, the session exists (get_session_by_id already validates)
|
||||||
|
|
||||||
session_service.delete_session(
|
session_service.delete_session(
|
||||||
app_name=session.app_name,
|
app_name=session.app_name,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
@ -131,10 +131,10 @@ def get_session_events(
|
|||||||
try:
|
try:
|
||||||
session = get_session_by_id(session_service, session_id)
|
session = get_session_by_id(session_service, session_id)
|
||||||
# If we get here, the session exists (get_session_by_id already validates)
|
# If we get here, the session exists (get_session_by_id already validates)
|
||||||
|
|
||||||
if not hasattr(session, 'events') or session.events is None:
|
if not hasattr(session, "events") or session.events is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return session.events
|
return session.events
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
# Passes HTTP exceptions from get_session_by_id
|
# Passes HTTP exceptions from get_session_by_id
|
||||||
|
@ -9,6 +9,7 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
|
def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
|
||||||
"""Search for a tool by ID"""
|
"""Search for a tool by ID"""
|
||||||
try:
|
try:
|
||||||
@ -21,9 +22,10 @@ def get_tool(db: Session, tool_id: uuid.UUID) -> Optional[Tool]:
|
|||||||
logger.error(f"Error searching for tool {tool_id}: {str(e)}")
|
logger.error(f"Error searching for tool {tool_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for tool"
|
detail="Error searching for tool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
|
def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
|
||||||
"""Search for all tools with pagination"""
|
"""Search for all tools with pagination"""
|
||||||
try:
|
try:
|
||||||
@ -32,9 +34,10 @@ def get_tools(db: Session, skip: int = 0, limit: int = 100) -> List[Tool]:
|
|||||||
logger.error(f"Error searching for tools: {str(e)}")
|
logger.error(f"Error searching for tools: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error searching for tools"
|
detail="Error searching for tools",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_tool(db: Session, tool: ToolCreate) -> Tool:
|
def create_tool(db: Session, tool: ToolCreate) -> Tool:
|
||||||
"""Creates a new tool"""
|
"""Creates a new tool"""
|
||||||
try:
|
try:
|
||||||
@ -49,19 +52,20 @@ def create_tool(db: Session, tool: ToolCreate) -> Tool:
|
|||||||
logger.error(f"Error creating tool: {str(e)}")
|
logger.error(f"Error creating tool: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error creating tool"
|
detail="Error creating tool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
|
def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[Tool]:
|
||||||
"""Updates an existing tool"""
|
"""Updates an existing tool"""
|
||||||
try:
|
try:
|
||||||
db_tool = get_tool(db, tool_id)
|
db_tool = get_tool(db, tool_id)
|
||||||
if not db_tool:
|
if not db_tool:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for key, value in tool.model_dump().items():
|
for key, value in tool.model_dump().items():
|
||||||
setattr(db_tool, key, value)
|
setattr(db_tool, key, value)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_tool)
|
db.refresh(db_tool)
|
||||||
logger.info(f"Tool updated successfully: {tool_id}")
|
logger.info(f"Tool updated successfully: {tool_id}")
|
||||||
@ -71,16 +75,17 @@ def update_tool(db: Session, tool_id: uuid.UUID, tool: ToolCreate) -> Optional[T
|
|||||||
logger.error(f"Error updating tool {tool_id}: {str(e)}")
|
logger.error(f"Error updating tool {tool_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error updating tool"
|
detail="Error updating tool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
|
def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
|
||||||
"""Remove a tool"""
|
"""Remove a tool"""
|
||||||
try:
|
try:
|
||||||
db_tool = get_tool(db, tool_id)
|
db_tool = get_tool(db, tool_id)
|
||||||
if not db_tool:
|
if not db_tool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
db.delete(db_tool)
|
db.delete(db_tool)
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Tool removed successfully: {tool_id}")
|
logger.info(f"Tool removed successfully: {tool_id}")
|
||||||
@ -90,5 +95,5 @@ def delete_tool(db: Session, tool_id: uuid.UUID) -> bool:
|
|||||||
logger.error(f"Error removing tool {tool_id}: {str(e)}")
|
logger.error(f"Error removing tool {tool_id}: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error removing tool"
|
detail="Error removing tool",
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,10 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||||||
from src.models.models import User, Client
|
from src.models.models import User, Client
|
||||||
from src.schemas.user import UserCreate
|
from src.schemas.user import UserCreate
|
||||||
from src.utils.security import get_password_hash, verify_password, generate_token
|
from src.utils.security import get_password_hash, verify_password, generate_token
|
||||||
from src.services.email_service import send_verification_email, send_password_reset_email
|
from src.services.email_service import (
|
||||||
|
send_verification_email,
|
||||||
|
send_password_reset_email,
|
||||||
|
)
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
@ -11,16 +14,22 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, client_id: Optional[uuid.UUID] = None) -> Tuple[Optional[User], str]:
|
|
||||||
|
def create_user(
|
||||||
|
db: Session,
|
||||||
|
user_data: UserCreate,
|
||||||
|
is_admin: bool = False,
|
||||||
|
client_id: Optional[uuid.UUID] = None,
|
||||||
|
) -> Tuple[Optional[User], str]:
|
||||||
"""
|
"""
|
||||||
Creates a new user in the system
|
Creates a new user in the system
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
user_data: User data to be created
|
user_data: User data to be created
|
||||||
is_admin: If the user is an administrator
|
is_admin: If the user is an administrator
|
||||||
client_id: Associated client ID (optional, a new one will be created if not provided)
|
client_id: Associated client ID (optional, a new one will be created if not provided)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
|
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
|
||||||
"""
|
"""
|
||||||
@ -28,17 +37,19 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
|
|||||||
# Check if email already exists
|
# Check if email already exists
|
||||||
db_user = db.query(User).filter(User.email == user_data.email).first()
|
db_user = db.query(User).filter(User.email == user_data.email).first()
|
||||||
if db_user:
|
if db_user:
|
||||||
logger.warning(f"Attempt to register with existing email: {user_data.email}")
|
logger.warning(
|
||||||
|
f"Attempt to register with existing email: {user_data.email}"
|
||||||
|
)
|
||||||
return None, "Email already registered"
|
return None, "Email already registered"
|
||||||
|
|
||||||
# Create verification token
|
# Create verification token
|
||||||
verification_token = generate_token()
|
verification_token = generate_token()
|
||||||
token_expiry = datetime.utcnow() + timedelta(hours=24)
|
token_expiry = datetime.utcnow() + timedelta(hours=24)
|
||||||
|
|
||||||
# Start transaction
|
# Start transaction
|
||||||
user = None
|
user = None
|
||||||
local_client_id = client_id
|
local_client_id = client_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# If not admin and no client_id, create an associated client
|
# If not admin and no client_id, create an associated client
|
||||||
if not is_admin and local_client_id is None:
|
if not is_admin and local_client_id is None:
|
||||||
@ -46,7 +57,7 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
|
|||||||
db.add(client)
|
db.add(client)
|
||||||
db.flush() # Get the client ID
|
db.flush() # Get the client ID
|
||||||
local_client_id = client.id
|
local_client_id = client.id
|
||||||
|
|
||||||
# Create user
|
# Create user
|
||||||
user = User(
|
user = User(
|
||||||
email=user_data.email,
|
email=user_data.email,
|
||||||
@ -56,52 +67,56 @@ def create_user(db: Session, user_data: UserCreate, is_admin: bool = False, clie
|
|||||||
is_active=False, # Inactive until email is verified
|
is_active=False, # Inactive until email is verified
|
||||||
email_verified=False,
|
email_verified=False,
|
||||||
verification_token=verification_token,
|
verification_token=verification_token,
|
||||||
verification_token_expiry=token_expiry
|
verification_token_expiry=token_expiry,
|
||||||
)
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Send verification email
|
# Send verification email
|
||||||
email_sent = send_verification_email(user.email, verification_token)
|
email_sent = send_verification_email(user.email, verification_token)
|
||||||
if not email_sent:
|
if not email_sent:
|
||||||
logger.error(f"Failed to send verification email to {user.email}")
|
logger.error(f"Failed to send verification email to {user.email}")
|
||||||
# We don't do rollback here, we just log the error
|
# We don't do rollback here, we just log the error
|
||||||
|
|
||||||
logger.info(f"User created successfully: {user.email}")
|
logger.info(f"User created successfully: {user.email}")
|
||||||
return user, "User created successfully. Check your email to activate your account."
|
return (
|
||||||
|
user,
|
||||||
|
"User created successfully. Check your email to activate your account.",
|
||||||
|
)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error creating user: {str(e)}")
|
logger.error(f"Error creating user: {str(e)}")
|
||||||
return None, f"Error creating user: {str(e)}"
|
return None, f"Error creating user: {str(e)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error creating user: {str(e)}")
|
logger.error(f"Unexpected error creating user: {str(e)}")
|
||||||
return None, f"Unexpected error: {str(e)}"
|
return None, f"Unexpected error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def verify_email(db: Session, token: str) -> Tuple[bool, str]:
|
def verify_email(db: Session, token: str) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Verify the user's email using the provided token
|
Verify the user's email using the provided token
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
token: Verification token
|
token: Verification token
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: Tuple with verification status and message
|
Tuple[bool, str]: Tuple with verification status and message
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Search for user by token
|
# Search for user by token
|
||||||
user = db.query(User).filter(User.verification_token == token).first()
|
user = db.query(User).filter(User.verification_token == token).first()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
logger.warning(f"Attempt to verify with invalid token: {token}")
|
logger.warning(f"Attempt to verify with invalid token: {token}")
|
||||||
return False, "Invalid verification token"
|
return False, "Invalid verification token"
|
||||||
|
|
||||||
# Check if the token has expired
|
# Check if the token has expired
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
expiry = user.verification_token_expiry
|
expiry = user.verification_token_expiry
|
||||||
|
|
||||||
# Ensure both dates are of the same type (aware or naive)
|
# Ensure both dates are of the same type (aware or naive)
|
||||||
if expiry.tzinfo is not None and now.tzinfo is None:
|
if expiry.tzinfo is not None and now.tzinfo is None:
|
||||||
# If expiry has timezone and now doesn't, convert now to have timezone
|
# If expiry has timezone and now doesn't, convert now to have timezone
|
||||||
@ -109,180 +124,201 @@ def verify_email(db: Session, token: str) -> Tuple[bool, str]:
|
|||||||
elif now.tzinfo is not None and expiry.tzinfo is None:
|
elif now.tzinfo is not None and expiry.tzinfo is None:
|
||||||
# If now has timezone and expiry doesn't, convert expiry to have timezone
|
# If now has timezone and expiry doesn't, convert expiry to have timezone
|
||||||
expiry = expiry.replace(tzinfo=now.tzinfo)
|
expiry = expiry.replace(tzinfo=now.tzinfo)
|
||||||
|
|
||||||
if expiry < now:
|
if expiry < now:
|
||||||
logger.warning(f"Attempt to verify with expired token for user: {user.email}")
|
logger.warning(
|
||||||
|
f"Attempt to verify with expired token for user: {user.email}"
|
||||||
|
)
|
||||||
return False, "Verification token expired"
|
return False, "Verification token expired"
|
||||||
|
|
||||||
# Update user
|
# Update user
|
||||||
user.email_verified = True
|
user.email_verified = True
|
||||||
user.is_active = True
|
user.is_active = True
|
||||||
user.verification_token = None
|
user.verification_token = None
|
||||||
user.verification_token_expiry = None
|
user.verification_token_expiry = None
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Email verified successfully for user: {user.email}")
|
logger.info(f"Email verified successfully for user: {user.email}")
|
||||||
return True, "Email verified successfully. Your account is active."
|
return True, "Email verified successfully. Your account is active."
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error verifying email: {str(e)}")
|
logger.error(f"Error verifying email: {str(e)}")
|
||||||
return False, f"Error verifying email: {str(e)}"
|
return False, f"Error verifying email: {str(e)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error verifying email: {str(e)}")
|
logger.error(f"Unexpected error verifying email: {str(e)}")
|
||||||
return False, f"Unexpected error: {str(e)}"
|
return False, f"Unexpected error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
|
def resend_verification(db: Session, email: str) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Resend the verification email
|
Resend the verification email
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
email: User email
|
email: User email
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: Tuple with operation status and message
|
Tuple[bool, str]: Tuple with operation status and message
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Search for user by email
|
# Search for user by email
|
||||||
user = db.query(User).filter(User.email == email).first()
|
user = db.query(User).filter(User.email == email).first()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
logger.warning(f"Attempt to resend verification email for non-existent email: {email}")
|
logger.warning(
|
||||||
|
f"Attempt to resend verification email for non-existent email: {email}"
|
||||||
|
)
|
||||||
return False, "Email not found"
|
return False, "Email not found"
|
||||||
|
|
||||||
if user.email_verified:
|
if user.email_verified:
|
||||||
logger.info(f"Attempt to resend verification email for already verified email: {email}")
|
logger.info(
|
||||||
|
f"Attempt to resend verification email for already verified email: {email}"
|
||||||
|
)
|
||||||
return False, "Email already verified"
|
return False, "Email already verified"
|
||||||
|
|
||||||
# Generate new token
|
# Generate new token
|
||||||
verification_token = generate_token()
|
verification_token = generate_token()
|
||||||
token_expiry = datetime.utcnow() + timedelta(hours=24)
|
token_expiry = datetime.utcnow() + timedelta(hours=24)
|
||||||
|
|
||||||
# Update user
|
# Update user
|
||||||
user.verification_token = verification_token
|
user.verification_token = verification_token
|
||||||
user.verification_token_expiry = token_expiry
|
user.verification_token_expiry = token_expiry
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Send email
|
# Send email
|
||||||
email_sent = send_verification_email(user.email, verification_token)
|
email_sent = send_verification_email(user.email, verification_token)
|
||||||
if not email_sent:
|
if not email_sent:
|
||||||
logger.error(f"Failed to resend verification email to {user.email}")
|
logger.error(f"Failed to resend verification email to {user.email}")
|
||||||
return False, "Failed to send verification email"
|
return False, "Failed to send verification email"
|
||||||
|
|
||||||
logger.info(f"Verification email resent successfully to: {user.email}")
|
logger.info(f"Verification email resent successfully to: {user.email}")
|
||||||
return True, "Verification email resent. Check your inbox."
|
return True, "Verification email resent. Check your inbox."
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error resending verification: {str(e)}")
|
logger.error(f"Error resending verification: {str(e)}")
|
||||||
return False, f"Error resending verification: {str(e)}"
|
return False, f"Error resending verification: {str(e)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error resending verification: {str(e)}")
|
logger.error(f"Unexpected error resending verification: {str(e)}")
|
||||||
return False, f"Unexpected error: {str(e)}"
|
return False, f"Unexpected error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
|
def forgot_password(db: Session, email: str) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Initiates the password recovery process
|
Initiates the password recovery process
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
email: User email
|
email: User email
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: Tuple with operation status and message
|
Tuple[bool, str]: Tuple with operation status and message
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Search for user by email
|
# Search for user by email
|
||||||
user = db.query(User).filter(User.email == email).first()
|
user = db.query(User).filter(User.email == email).first()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
# For security, we don't inform if the email exists or not
|
# For security, we don't inform if the email exists or not
|
||||||
logger.info(f"Attempt to recover password for non-existent email: {email}")
|
logger.info(f"Attempt to recover password for non-existent email: {email}")
|
||||||
return True, "If the email is registered, you will receive instructions to reset your password."
|
return (
|
||||||
|
True,
|
||||||
|
"If the email is registered, you will receive instructions to reset your password.",
|
||||||
|
)
|
||||||
|
|
||||||
# Generate reset token
|
# Generate reset token
|
||||||
reset_token = generate_token()
|
reset_token = generate_token()
|
||||||
token_expiry = datetime.utcnow() + timedelta(hours=1) # Token valid for 1 hour
|
token_expiry = datetime.utcnow() + timedelta(hours=1) # Token valid for 1 hour
|
||||||
|
|
||||||
# Update user
|
# Update user
|
||||||
user.password_reset_token = reset_token
|
user.password_reset_token = reset_token
|
||||||
user.password_reset_expiry = token_expiry
|
user.password_reset_expiry = token_expiry
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# Send email
|
# Send email
|
||||||
email_sent = send_password_reset_email(user.email, reset_token)
|
email_sent = send_password_reset_email(user.email, reset_token)
|
||||||
if not email_sent:
|
if not email_sent:
|
||||||
logger.error(f"Failed to send password reset email to {user.email}")
|
logger.error(f"Failed to send password reset email to {user.email}")
|
||||||
return False, "Failed to send password reset email"
|
return False, "Failed to send password reset email"
|
||||||
|
|
||||||
logger.info(f"Password reset email sent successfully to: {user.email}")
|
logger.info(f"Password reset email sent successfully to: {user.email}")
|
||||||
return True, "If the email is registered, you will receive instructions to reset your password."
|
return (
|
||||||
|
True,
|
||||||
|
"If the email is registered, you will receive instructions to reset your password.",
|
||||||
|
)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error processing password recovery: {str(e)}")
|
logger.error(f"Error processing password recovery: {str(e)}")
|
||||||
return False, f"Error processing password recovery: {str(e)}"
|
return False, f"Error processing password recovery: {str(e)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error processing password recovery: {str(e)}")
|
logger.error(f"Unexpected error processing password recovery: {str(e)}")
|
||||||
return False, f"Unexpected error: {str(e)}"
|
return False, f"Unexpected error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
|
def reset_password(db: Session, token: str, new_password: str) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Resets the user's password using the provided token
|
Resets the user's password using the provided token
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
token: Password reset token
|
token: Password reset token
|
||||||
new_password: New password
|
new_password: New password
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: Tuple with operation status and message
|
Tuple[bool, str]: Tuple with operation status and message
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Search for user by token
|
# Search for user by token
|
||||||
user = db.query(User).filter(User.password_reset_token == token).first()
|
user = db.query(User).filter(User.password_reset_token == token).first()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
logger.warning(f"Attempt to reset password with invalid token: {token}")
|
logger.warning(f"Attempt to reset password with invalid token: {token}")
|
||||||
return False, "Invalid password reset token"
|
return False, "Invalid password reset token"
|
||||||
|
|
||||||
# Check if the token has expired
|
# Check if the token has expired
|
||||||
if user.password_reset_expiry < datetime.utcnow():
|
if user.password_reset_expiry < datetime.utcnow():
|
||||||
logger.warning(f"Attempt to reset password with expired token for user: {user.email}")
|
logger.warning(
|
||||||
|
f"Attempt to reset password with expired token for user: {user.email}"
|
||||||
|
)
|
||||||
return False, "Password reset token expired"
|
return False, "Password reset token expired"
|
||||||
|
|
||||||
# Update password
|
# Update password
|
||||||
user.password_hash = get_password_hash(new_password)
|
user.password_hash = get_password_hash(new_password)
|
||||||
user.password_reset_token = None
|
user.password_reset_token = None
|
||||||
user.password_reset_expiry = None
|
user.password_reset_expiry = None
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Password reset successfully for user: {user.email}")
|
logger.info(f"Password reset successfully for user: {user.email}")
|
||||||
return True, "Password reset successfully. You can now login with your new password."
|
return (
|
||||||
|
True,
|
||||||
|
"Password reset successfully. You can now login with your new password.",
|
||||||
|
)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error resetting password: {str(e)}")
|
logger.error(f"Error resetting password: {str(e)}")
|
||||||
return False, f"Error resetting password: {str(e)}"
|
return False, f"Error resetting password: {str(e)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error resetting password: {str(e)}")
|
logger.error(f"Unexpected error resetting password: {str(e)}")
|
||||||
return False, f"Unexpected error: {str(e)}"
|
return False, f"Unexpected error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||||
"""
|
"""
|
||||||
Searches for a user by email
|
Searches for a user by email
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
email: User email
|
email: User email
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[User]: User found or None
|
Optional[User]: User found or None
|
||||||
"""
|
"""
|
||||||
@ -292,15 +328,16 @@ def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
|||||||
logger.error(f"Error searching for user by email: {str(e)}")
|
logger.error(f"Error searching for user by email: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||||
"""
|
"""
|
||||||
Authenticates a user with email and password
|
Authenticates a user with email and password
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
email: User email
|
email: User email
|
||||||
password: User password
|
password: User password
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[User]: Authenticated user or None
|
Optional[User]: Authenticated user or None
|
||||||
"""
|
"""
|
||||||
@ -313,75 +350,78 @@ def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
|||||||
return None
|
return None
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
|
def get_admin_users(db: Session, skip: int = 0, limit: int = 100):
|
||||||
"""
|
"""
|
||||||
Lists the admin users
|
Lists the admin users
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
skip: Number of records to skip
|
skip: Number of records to skip
|
||||||
limit: Maximum number of records to return
|
limit: Maximum number of records to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[User]: List of admin users
|
List[User]: List of admin users
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
users = db.query(User).filter(User.is_admin == True).offset(skip).limit(limit).all()
|
users = db.query(User).filter(User.is_admin).offset(skip).limit(limit).all()
|
||||||
logger.info(f"List of admins: {len(users)} found")
|
logger.info(f"List of admins: {len(users)} found")
|
||||||
return users
|
return users
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
logger.error(f"Error listing admins: {str(e)}")
|
logger.error(f"Error listing admins: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error listing admins: {str(e)}")
|
logger.error(f"Unexpected error listing admins: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
|
def create_admin_user(db: Session, user_data: UserCreate) -> Tuple[Optional[User], str]:
|
||||||
"""
|
"""
|
||||||
Creates a new admin user
|
Creates a new admin user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
user_data: User data to be created
|
user_data: User data to be created
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
|
Tuple[Optional[User], str]: Tuple with the created user (or None in case of error) and status message
|
||||||
"""
|
"""
|
||||||
return create_user(db, user_data, is_admin=True)
|
return create_user(db, user_data, is_admin=True)
|
||||||
|
|
||||||
|
|
||||||
def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
|
def deactivate_user(db: Session, user_id: uuid.UUID) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Deactivates a user (does not delete, only marks as inactive)
|
Deactivates a user (does not delete, only marks as inactive)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
user_id: ID of the user to be deactivated
|
user_id: ID of the user to be deactivated
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: Tuple with operation status and message
|
Tuple[bool, str]: Tuple with operation status and message
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Search for user by ID
|
# Search for user by ID
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
logger.warning(f"Attempt to deactivate non-existent user: {user_id}")
|
logger.warning(f"Attempt to deactivate non-existent user: {user_id}")
|
||||||
return False, "User not found"
|
return False, "User not found"
|
||||||
|
|
||||||
# Deactivate user
|
# Deactivate user
|
||||||
user.is_active = False
|
user.is_active = False
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"User deactivated successfully: {user.email}")
|
logger.info(f"User deactivated successfully: {user.email}")
|
||||||
return True, "User deactivated successfully"
|
return True, "User deactivated successfully"
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"Error deactivating user: {str(e)}")
|
logger.error(f"Error deactivating user: {str(e)}")
|
||||||
return False, f"Error deactivating user: {str(e)}"
|
return False, f"Error deactivating user: {str(e)}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error deactivating user: {str(e)}")
|
logger.error(f"Unexpected error deactivating user: {str(e)}")
|
||||||
return False, f"Unexpected error: {str(e)}"
|
return False, f"Unexpected error: {str(e)}"
|
||||||
|
@ -3,23 +3,26 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from src.config.settings import settings
|
from src.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
class CustomFormatter(logging.Formatter):
|
class CustomFormatter(logging.Formatter):
|
||||||
"""Custom formatter for logs"""
|
"""Custom formatter for logs"""
|
||||||
|
|
||||||
grey = "\x1b[38;20m"
|
grey = "\x1b[38;20m"
|
||||||
yellow = "\x1b[33;20m"
|
yellow = "\x1b[33;20m"
|
||||||
red = "\x1b[31;20m"
|
red = "\x1b[31;20m"
|
||||||
bold_red = "\x1b[31;1m"
|
bold_red = "\x1b[31;1m"
|
||||||
reset = "\x1b[0m"
|
reset = "\x1b[0m"
|
||||||
|
|
||||||
format_template = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
|
format_template = (
|
||||||
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
|
||||||
|
)
|
||||||
|
|
||||||
FORMATS = {
|
FORMATS = {
|
||||||
logging.DEBUG: grey + format_template + reset,
|
logging.DEBUG: grey + format_template + reset,
|
||||||
logging.INFO: grey + format_template + reset,
|
logging.INFO: grey + format_template + reset,
|
||||||
logging.WARNING: yellow + format_template + reset,
|
logging.WARNING: yellow + format_template + reset,
|
||||||
logging.ERROR: red + format_template + reset,
|
logging.ERROR: red + format_template + reset,
|
||||||
logging.CRITICAL: bold_red + format_template + reset
|
logging.CRITICAL: bold_red + format_template + reset,
|
||||||
}
|
}
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
@ -27,33 +30,34 @@ class CustomFormatter(logging.Formatter):
|
|||||||
formatter = logging.Formatter(log_fmt)
|
formatter = logging.Formatter(log_fmt)
|
||||||
return formatter.format(record)
|
return formatter.format(record)
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(name: str) -> logging.Logger:
|
def setup_logger(name: str) -> logging.Logger:
|
||||||
"""
|
"""
|
||||||
Configures a custom logger
|
Configures a custom logger
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Logger name
|
name: Logger name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logging.Logger: Logger configurado
|
logging.Logger: Logger configurado
|
||||||
"""
|
"""
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
|
|
||||||
# Remove existing handlers to avoid duplication
|
# Remove existing handlers to avoid duplication
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
logger.handlers.clear()
|
logger.handlers.clear()
|
||||||
|
|
||||||
# Configure the logger level based on the environment variable or configuration
|
# Configure the logger level based on the environment variable or configuration
|
||||||
log_level = getattr(logging, os.getenv("LOG_LEVEL", settings.LOG_LEVEL).upper())
|
log_level = getattr(logging, os.getenv("LOG_LEVEL", settings.LOG_LEVEL).upper())
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# Console handler
|
# Console handler
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
console_handler.setFormatter(CustomFormatter())
|
console_handler.setFormatter(CustomFormatter())
|
||||||
console_handler.setLevel(log_level)
|
console_handler.setLevel(log_level)
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
# Prevent logs from being propagated to the root logger
|
# Prevent logs from being propagated to the root logger
|
||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
@ -11,41 +11,44 @@ from dataclasses import dataclass
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Fix bcrypt error with passlib
|
# Fix bcrypt error with passlib
|
||||||
if not hasattr(bcrypt, '__about__'):
|
if not hasattr(bcrypt, "__about__"):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BcryptAbout:
|
class BcryptAbout:
|
||||||
__version__: str = getattr(bcrypt, "__version__")
|
__version__: str = getattr(bcrypt, "__version__")
|
||||||
|
|
||||||
setattr(bcrypt, "__about__", BcryptAbout())
|
setattr(bcrypt, "__about__", BcryptAbout())
|
||||||
|
|
||||||
# Context for password hashing using bcrypt
|
# Context for password hashing using bcrypt
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
def get_password_hash(password: str) -> str:
|
def get_password_hash(password: str) -> str:
|
||||||
"""Creates a password hash"""
|
"""Creates a password hash"""
|
||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
"""Verifies if the provided password matches the stored hash"""
|
"""Verifies if the provided password matches the stored hash"""
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str:
|
def create_jwt_token(data: dict, expires_delta: timedelta = None) -> str:
|
||||||
"""Creates a JWT token"""
|
"""Creates a JWT token"""
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.utcnow() + timedelta(
|
expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRATION_TIME)
|
||||||
minutes=settings.JWT_EXPIRATION_TIME
|
|
||||||
)
|
|
||||||
to_encode.update({"exp": expire})
|
to_encode.update({"exp": expire})
|
||||||
encoded_jwt = jwt.encode(
|
encoded_jwt = jwt.encode(
|
||||||
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||||
)
|
)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
def generate_token(length: int = 32) -> str:
|
def generate_token(length: int = 32) -> str:
|
||||||
"""Generates a secure token for email verification or password reset"""
|
"""Generates a secure token for email verification or password reset"""
|
||||||
alphabet = string.ascii_letters + string.digits
|
alphabet = string.ascii_letters + string.digits
|
||||||
token = ''.join(secrets.choice(alphabet) for _ in range(length))
|
token = "".join(secrets.choice(alphabet) for _ in range(length))
|
||||||
return token
|
return token
|
||||||
|
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Package initialization for tests
|
1
tests/api/__init__.py
Normal file
1
tests/api/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# API tests package
|
11
tests/api/test_root.py
Normal file
11
tests/api/test_root.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
def test_read_root(client):
|
||||||
|
"""
|
||||||
|
Test that the root endpoint returns the correct response.
|
||||||
|
"""
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "message" in data
|
||||||
|
assert "documentation" in data
|
||||||
|
assert "version" in data
|
||||||
|
assert "auth" in data
|
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Services tests package
|
27
tests/services/test_auth_service.py
Normal file
27
tests/services/test_auth_service.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from src.services.auth_service import create_access_token
|
||||||
|
from src.models.models import User
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_access_token():
|
||||||
|
"""
|
||||||
|
Test that an access token is created with the correct data.
|
||||||
|
"""
|
||||||
|
# Create a mock user
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password="hashed_password",
|
||||||
|
is_active=True,
|
||||||
|
is_admin=False,
|
||||||
|
name="Test User",
|
||||||
|
client_id=uuid.uuid4(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create token
|
||||||
|
token = create_access_token(user)
|
||||||
|
|
||||||
|
# Simple validation
|
||||||
|
assert token is not None
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert len(token) > 0
|
Loading…
Reference in New Issue
Block a user