"""Database connection helper using IAM authentication and RDS Proxy."""
import os
import time
import logging
import boto3
import psycopg2
import psycopg2.extras
logger = logging.getLogger(__name__)

# RDS client created once at import time; every call to _generate_auth_token()
# in this process reuses it rather than constructing a new client per call.
_rds_client = boto3.client("rds")

# Module-level connection cache (reused across warm Lambda invocations).
# Holds the psycopg2 connection most recently opened by get_connection(), or
# None when no connection is cached; close_connection() resets it to None.
_connection = None
def _generate_auth_token() -> str:
    """Request an IAM authentication token for the RDS Proxy.

    The token is scoped to the proxy host in ``DB_PROXY_ENDPOINT``, port
    5432, and the database user in ``DB_USER``.

    Raises:
        KeyError: if ``DB_PROXY_ENDPOINT`` or ``DB_USER`` is not set.
    """
    proxy_endpoint = os.environ["DB_PROXY_ENDPOINT"]
    db_user = os.environ["DB_USER"]
    token_request = {
        "DBHostname": proxy_endpoint,
        "Port": 5432,
        "DBUsername": db_user,
    }
    return _rds_client.generate_db_auth_token(**token_request)
def _connection_is_healthy(conn) -> bool:
    """Return True if *conn* is open and answers ``SELECT 1``.

    If the connection was idle before the ping, the transaction the ping
    implicitly opens (autocommit is off) is rolled back, so the connection is
    handed back idle. If a transaction was already open, it is left untouched
    so no caller work is discarded.
    """
    try:
        if conn.closed:
            return False
        was_idle = (
            conn.get_transaction_status()
            == psycopg2.extensions.TRANSACTION_STATUS_IDLE
        )
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
        if was_idle:
            # Without this, the connection would sit "idle in transaction":
            # it keeps the proxy's backend connection held, and a caller that
            # later sets `conn.autocommit = True` gets a ProgrammingError.
            conn.rollback()
    except psycopg2.Error:
        logger.debug("Connection health check failed", exc_info=True)
        return False
    return True


def get_connection():
    """Get a database connection via RDS Proxy with IAM authentication.

    Reuses the cached connection when it still answers a health-check query;
    otherwise closes it and opens a new one. Opening makes up to 3 attempts,
    sleeping 0.1 s and then 0.2 s between them (exponential backoff), and
    requests a fresh IAM token for every attempt.

    Reads ``DB_PROXY_ENDPOINT`` (proxy host), ``DB_NAME`` and ``DB_USER``
    from the environment.

    Returns:
        A psycopg2 connection with ``autocommit`` disabled. It is also stored
        in the module-level cache for reuse by later calls.

    Raises:
        KeyError: if a required environment variable is missing (not retried).
        psycopg2.OperationalError: if every connection attempt fails.
        Errors raised while generating the IAM token propagate immediately.
    """
    global _connection

    if _connection is not None:
        if _connection_is_healthy(_connection):
            return _connection
        logger.warning("Existing connection is stale, reconnecting")
        try:
            _connection.close()
        except Exception:
            # Best effort: the connection is already unusable.
            logger.debug("Error while closing stale connection", exc_info=True)
        _connection = None

    # Read configuration once, up front: a missing variable is a deployment
    # error, so fail fast instead of retrying it with backoff.
    host = os.environ["DB_PROXY_ENDPOINT"]
    dbname = os.environ["DB_NAME"]
    user = os.environ["DB_USER"]

    max_retries = 3
    base_backoff_s = 0.1  # doubled after each failed attempt

    for attempt in range(1, max_retries + 1):
        # Generate a new token per attempt rather than reusing one across
        # retries.
        token = _generate_auth_token()
        try:
            conn = psycopg2.connect(
                host=host,
                port=5432,  # must match the Port used in _generate_auth_token()
                dbname=dbname,
                user=user,
                password=token,
                sslmode="require",
                connect_timeout=5,
            )
        except psycopg2.OperationalError as exc:
            if attempt == max_retries:
                logger.error(
                    "DB connection attempt %d/%d failed: %s",
                    attempt,
                    max_retries,
                    exc,
                )
                raise
            logger.warning(
                "DB connection attempt %d/%d failed: %s",
                attempt,
                max_retries,
                exc,
            )
            time.sleep(base_backoff_s * 2 ** (attempt - 1))
            continue

        conn.autocommit = False
        _connection = conn
        return _connection
def close_connection():
    """Drop the cached database connection, closing it first if one exists.

    Any exception raised by ``close()`` is deliberately ignored; after the
    call the cache is empty, so the next ``get_connection()`` opens a new
    connection.
    """
    global _connection
    if _connection is None:
        return  # nothing cached, nothing to close
    cached = _connection
    try:
        cached.close()
    except Exception:
        pass  # best effort: nothing useful to do if closing fails
    _connection = None