"""
Middleware for the FastAPI application: request logging, security headers,
and rate limiting.
"""
|
||
|
|
import time
|
||
|
|
import logging
|
||
|
|
from typing import Callable
|
||
|
|
from fastapi import Request, Response
|
||
|
|
from fastapi.responses import JSONResponse
|
||
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||
|
|
from slowapi.util import get_remote_address
|
||
|
|
from slowapi.errors import RateLimitExceeded
|
||
|
|
|
||
|
|
# Configure the root logger once, at import time: INFO and above, each record
# rendered as "[YYYY-mm-dd HH:MM:SS] LEVEL: message".
# NOTE: logging.basicConfig does nothing when the root logger already has
# handlers, so an application that configured logging before importing this
# module keeps its own setup.
logging.basicConfig(
    format='[%(asctime)s] %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

# Per-module logger; records propagate to the root logger by default.
logger = logging.getLogger(__name__)

# Shared slowapi limiter; clients are bucketed by their remote address.
limiter = Limiter(key_func=get_remote_address)


async def log_requests(request: Request, call_next: Callable):
    """Log every incoming HTTP request and how it completed.

    Emits one record when the request arrives and one when the downstream
    handler returns, with the status code and elapsed time in milliseconds.
    The completion record's level follows the status class: 5xx -> ERROR,
    4xx -> WARN, anything else (1xx/2xx/3xx) -> INFO.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain.

    Returns:
        The response produced by ``call_next``, unchanged.

    Raises:
        Whatever ``call_next`` raises; the failure is logged with its
        duration and then re-raised unchanged.
    """
    # perf_counter is monotonic: durations cannot go negative or jump when
    # the wall clock is adjusted, unlike time.time().
    start = time.perf_counter()
    method = request.method
    path = request.url.path
    # Starlette types request.client as Optional (it can be None for some
    # transports/test clients); a log line must never crash the request.
    client_ip = request.client.host if request.client is not None else "unknown"

    # NOTE: the "[DEBUG]" label is kept for log-format compatibility even
    # though the record is emitted at INFO level.
    logger.info(
        "🔍 [DEBUG] Incoming request {method: '%s', path: '%s', ip: '%s'}",
        method, path, client_ip,
    )

    try:
        response = await call_next(request)
    except Exception as exc:
        # Record the failure and its duration, then let the framework's own
        # error handling produce the response (and the traceback).
        duration = int((time.perf_counter() - start) * 1000)
        logger.error(
            "❌ [ERROR] Request failed {method: '%s', path: '%s', error: '%s', duration: '%sms'}",
            method, path, type(exc).__name__, duration,
        )
        raise

    duration = int((time.perf_counter() - start) * 1000)
    status = response.status_code
    status_emoji, label, level_no = _response_log_style(status)
    logger.log(
        level_no,
        "%s [%s] Request completed {method: '%s', path: '%s', status: %s, duration: '%sms'}",
        status_emoji, label, method, path, status, duration,
    )

    return response


def _response_log_style(status_code):
    """Return ``(emoji, level_label, logging_level)`` for a response status.

    Server errors (>= 500) are ERROR and client errors (4xx) are WARN.
    Everything else is routine traffic logged at INFO: 1xx informational,
    2xx success, and 3xx redirects / 304 Not Modified.
    """
    if status_code >= 500:
        return "❌", "ERROR", logging.ERROR
    if status_code >= 400:
        return "⚠️", "WARN", logging.WARNING
    return "📝", "INFO", logging.INFO


async def security_headers(request: Request, call_next: Callable):
    """Attach baseline security headers to every response.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain.

    Returns:
        The downstream response, with these headers set (overwriting any
        value a route already set for them):

        * ``X-Content-Type-Options: nosniff`` - disable MIME sniffing.
        * ``X-Frame-Options: DENY`` - forbid rendering inside frames.
        * ``X-XSS-Protection: 0`` - explicitly disable the legacy XSS auditor.
        * ``Strict-Transport-Security`` - one-year HSTS incl. subdomains
          (browsers only honor it when received over HTTPS).
    """
    response = await call_next(request)

    headers = response.headers
    headers['X-Content-Type-Options'] = 'nosniff'
    headers['X-Frame-Options'] = 'DENY'
    # The legacy XSS auditor is gone from modern browsers, and where it still
    # exists "1; mode=block" can itself be abused to introduce XSS/info leaks.
    # OWASP recommends turning it off explicitly and relying on CSP instead.
    headers['X-XSS-Protection'] = '0'
    headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'

    return response


def setup_middleware(app):
    """Install this module's HTTP middleware and rate-limit support on *app*.

    Registers, in order, the request logger and the security-header
    middleware via ``app.middleware("http")``. It then stores the shared
    slowapi limiter on ``app.state`` and maps ``RateLimitExceeded`` to
    slowapi's bundled handler.
    """
    # Order matters: each registration wraps the stack built so far, so the
    # function registered last sits outermost.
    for http_middleware in (log_requests, security_headers):
        app.middleware("http")(http_middleware)

    # Rate limiting: slowapi reads the limiter from app.state, and the
    # exception handler turns RateLimitExceeded into an HTTP error response.
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
