Skip to content
Snippets Groups Projects
Commit ae12d9a8 authored by Kenechukwu Orjiene's avatar Kenechukwu Orjiene
Browse files

Merge branch 'feat/dev' into 'main'

The keys creation and key verification

See merge request !11
parents 3a940348 cc96ea9d
No related branches found
No related tags found
1 merge request!11The keys creation and key verification
...@@ -6,6 +6,8 @@ import asyncio ...@@ -6,6 +6,8 @@ import asyncio
import logging import logging
from pathlib import Path from pathlib import Path
import sys import sys
from config import CONFIG
# Add src directory to Python path # Add src directory to Python path
src_path = Path(__file__).parent.parent / "src" src_path = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(src_path)) sys.path.insert(0, str(src_path))
...@@ -14,15 +16,20 @@ from pyfed.federation.delivery import ActivityDelivery ...@@ -14,15 +16,20 @@ from pyfed.federation.delivery import ActivityDelivery
from pyfed.federation.discovery import InstanceDiscovery from pyfed.federation.discovery import InstanceDiscovery
from pyfed.security.key_management import KeyManager from pyfed.security.key_management import KeyManager
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more detailed logs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def send_activity_to_mastodon(): async def send_activity_to_mastodon():
# Initialize components # Initialize components with config
key_manager = KeyManager( key_manager = KeyManager(
domain="localhost:8000", domain=CONFIG["domain"],
keys_path="example_keys" keys_path=CONFIG["keys_path"],
rotation_config=False
) )
await key_manager.initialize()
active_key = await key_manager.get_active_key()
logger.debug(f"Using active key ID: {active_key.key_id}")
discovery = InstanceDiscovery() discovery = InstanceDiscovery()
delivery = ActivityDelivery(key_manager=key_manager) delivery = ActivityDelivery(key_manager=key_manager)
...@@ -31,7 +38,7 @@ async def send_activity_to_mastodon(): ...@@ -31,7 +38,7 @@ async def send_activity_to_mastodon():
await delivery.initialize() await delivery.initialize()
try: try:
# 1. First, perform WebFinger lookup to get the actor's inbox # 1. First, perform WebFinger lookup to get the actor's URL
logger.info("Performing WebFinger lookup...") logger.info("Performing WebFinger lookup...")
webfinger_result = await discovery.webfinger( webfinger_result = await discovery.webfinger(
resource="acct:kene29@mastodon.social" resource="acct:kene29@mastodon.social"
...@@ -51,15 +58,29 @@ async def send_activity_to_mastodon(): ...@@ -51,15 +58,29 @@ async def send_activity_to_mastodon():
if not actor_url: if not actor_url:
raise Exception("Could not find ActivityPub actor URL") raise Exception("Could not find ActivityPub actor URL")
# 2. Create the Activity # 2. Fetch the actor's profile to get their inbox URL
logger.info(f"Fetching actor profile from {actor_url}")
async with discovery.session.get(actor_url) as response:
if response.status != 200:
raise Exception(f"Failed to fetch actor profile: {response.status}")
actor_data = await response.json()
# Get the inbox URL from the actor's profile
inbox_url = actor_data.get('inbox')
if not inbox_url:
raise Exception("Could not find actor's inbox URL")
logger.info(f"Found actor's inbox: {inbox_url}")
# 3. Create the Activity with ngrok domain
note_activity = { note_activity = {
"@context": "https://www.w3.org/ns/activitystreams", "@context": "https://www.w3.org/ns/activitystreams",
"type": "Create", "type": "Create",
"actor": f"https://localhost:8000/users/testuser", "actor": f"https://{CONFIG['domain']}/users/{CONFIG['user']}",
"object": { "object": {
"type": "Note", "type": "Note",
"content": "Hello @kene29@mastodon.social! This is a test message from PyFed.", "content": "Hello @kene29@mastodon.social! This is a test message from PyFed.",
"attributedTo": f"https://localhost:8000/users/testuser", "attributedTo": f"https://{CONFIG['domain']}/users/{CONFIG['user']}",
"to": [actor_url], "to": [actor_url],
"cc": ["https://www.w3.org/ns/activitystreams#Public"] "cc": ["https://www.w3.org/ns/activitystreams#Public"]
}, },
...@@ -67,12 +88,13 @@ async def send_activity_to_mastodon(): ...@@ -67,12 +88,13 @@ async def send_activity_to_mastodon():
"cc": ["https://www.w3.org/ns/activitystreams#Public"] "cc": ["https://www.w3.org/ns/activitystreams#Public"]
} }
# 3. Deliver the activity # 4. Deliver the activity to the inbox
logger.info(f"Delivering activity: {note_activity}") logger.info(f"Delivering activity: {note_activity}")
result = await delivery.deliver_activity( result = await delivery.deliver_activity(
activity=note_activity, activity=note_activity,
recipients=[actor_url] recipients=[inbox_url] # Use the inbox URL instead of actor URL
) )
logger.info(f"Delivery result: {result}")
if result.success: if result.success:
logger.info("Activity delivered successfully!") logger.info("Activity delivered successfully!")
......
...@@ -9,10 +9,14 @@ import aiohttp ...@@ -9,10 +9,14 @@ import aiohttp
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
import asyncio import asyncio
import certifi
import ssl
import json
from ..utils.exceptions import DeliveryError from ..utils.exceptions import DeliveryError
from ..utils.logging import get_logger from ..utils.logging import get_logger
from ..security.key_management import KeyManager from ..security.key_management import KeyManager
from ..security.http_signatures import HTTPSignatureVerifier
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -36,7 +40,7 @@ class ActivityDelivery: ...@@ -36,7 +40,7 @@ class ActivityDelivery:
key_manager: KeyManager, key_manager: KeyManager,
timeout: int = 30, timeout: int = 30,
max_retries: int = 3, max_retries: int = 3,
retry_delay: int = 300): retry_delay: int = 20):
"""Initialize delivery service.""" """Initialize delivery service."""
self.key_manager = key_manager self.key_manager = key_manager
self.timeout = timeout self.timeout = timeout
...@@ -44,16 +48,28 @@ class ActivityDelivery: ...@@ -44,16 +48,28 @@ class ActivityDelivery:
self.retry_delay = retry_delay self.retry_delay = retry_delay
self.delivery_status = {} self.delivery_status = {}
self.session = None self.session = None
self.signature_verifier = key_manager
async def initialize(self) -> None:
    """Initialize the delivery service.

    Binds an HTTP signature verifier to the key manager's current
    active key and opens an aiohttp session whose TLS context trusts
    the certifi CA bundle.
    """
    # Bind the signer to the currently active key so outgoing requests
    # carry a key id that remote servers can resolve.
    active_key = await self.key_manager.get_active_key()
    key_id = active_key.key_id
    logger.debug(f"Using existing key id: {key_id}")

    self.signature_verifier = HTTPSignatureVerifier(
        key_manager=self.key_manager,
        key_id=key_id
    )

    # Use certifi's CA bundle so certificate verification behaves the
    # same across platforms.
    ssl_context = ssl.create_default_context(cafile=certifi.where())

    self.session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=self.timeout),
        headers={
            "User-Agent": f"PyFed/1.0 (+https://{self.key_manager.domain})",
            "Accept": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\""
        },
        connector=aiohttp.TCPConnector(ssl=ssl_context)
    )
async def close(self) -> None: async def close(self) -> None:
...@@ -68,51 +84,33 @@ class ActivityDelivery: ...@@ -68,51 +84,33 @@ class ActivityDelivery:
result = DeliveryResult() result = DeliveryResult()
try: try:
# Group recipients by domain logger.info(f"Starting delivery to {len(recipients)} recipients")
domain_groups = self._group_by_domain(recipients)
for domain, domain_recipients in domain_groups.items(): for recipient in recipients:
# Try shared inbox first logger.info(f"Attempting delivery to {recipient}")
shared_result = await self._try_shared_inbox( delivery_start = datetime.utcnow()
activity, domain, domain_recipients
inbox_result = await self._deliver_to_inbox(
activity,
recipient
) )
# Update result with shared inbox attempt delivery_time = (datetime.utcnow() - delivery_start).total_seconds()
if shared_result.success: logger.info(f"Delivery to {recipient} took {delivery_time:.2f} seconds")
result.success.extend(shared_result.success)
result.status_code = shared_result.status_code if inbox_result.success:
result.error_message = shared_result.error_message result.success.extend(inbox_result.success)
else: else:
# Individual delivery for failed recipients result.failed.extend(inbox_result.failed)
for recipient in domain_recipients: result.status_code = inbox_result.status_code
inbox_result = await self._deliver_to_inbox( result.error_message = inbox_result.error_message
activity,
recipient
)
if inbox_result.success:
result.success.extend(inbox_result.success)
else:
result.failed.extend(inbox_result.failed)
result.status_code = inbox_result.status_code
result.error_message = inbox_result.error_message
# Ensure no duplicates
result.success = list(set(result.success))
result.failed = list(set(result.failed))
# Store delivery status
if activity.get('id'):
self.delivery_status[activity['id']] = {
'success': result.success,
'failed': result.failed,
'timestamp': datetime.utcnow().isoformat()
}
logger.info(f"Delivery completed. Success: {len(result.success)}, Failed: {len(result.failed)}")
return result return result
except Exception as e: except Exception as e:
logger.error(f"Activity delivery failed: {e}") logger.error(f"Activity delivery failed: {e}")
result.failed = list(set(recipients)) result.failed = recipients
result.error_message = str(e) result.error_message = str(e)
return result return result
...@@ -124,13 +122,23 @@ class ActivityDelivery: ...@@ -124,13 +122,23 @@ class ActivityDelivery:
result = DeliveryResult() result = DeliveryResult()
try: try:
# Prepare headers # Parse URL and get host
parsed_url = urlparse(inbox_url) parsed_url = urlparse(inbox_url)
host = parsed_url.netloc
# Convert activity to JSON string once to ensure consistency
activity_json = json.dumps(activity,
sort_keys=True,
ensure_ascii=True,
separators=(',', ':'))
# Prepare headers
headers = { headers = {
"Accept": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"", "Accept": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"",
"Content-Type": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"", "Content-Type": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"",
"User-Agent": "PyFed/1.0", "User-Agent": "PyFed/1.0",
"Host": parsed_url.netloc "Host": host,
"Date": datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
} }
# Sign request # Sign request
...@@ -138,51 +146,51 @@ class ActivityDelivery: ...@@ -138,51 +146,51 @@ class ActivityDelivery:
method="POST", method="POST",
path=parsed_url.path, path=parsed_url.path,
headers=headers, headers=headers,
body=activity body=activity # Pass the original activity for consistent hashing
) )
async with aiohttp.ClientSession() as session: # Send the request with the exact same JSON string we hashed
async with await session.post( async with self.session.post(
inbox_url, inbox_url,
json=activity, data=activity_json, # Use the exact JSON string
headers=signed_headers, headers={
timeout=self.timeout **signed_headers,
) as response: "Content-Type": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\""
result.status_code = response.status }
error_text = await response.text() ) as response:
result.status_code = response.status
if response.status in [200, 201, 202]: error_text = await response.text()
result.success = [inbox_url]
return result if response.status in [200, 201, 202]:
result.success = [inbox_url]
if response.status == 429:
retry_after = int(response.headers.get('Retry-After', self.retry_delay))
result.retry_after = retry_after
await asyncio.sleep(retry_after)
if attempt < self.max_retries:
return await self._deliver_to_inbox(
activity, inbox_url, attempt + 1
)
result.error_message = f"Delivery failed: {response.status} - {error_text}"
result.failed = [inbox_url]
return result return result
if response.status == 429:
retry_after = int(response.headers.get('Retry-After', self.retry_delay))
result.retry_after = retry_after
await asyncio.sleep(retry_after)
if attempt < self.max_retries:
return await self._deliver_to_inbox(
activity, inbox_url, attempt + 1
)
result.error_message = f"Delivery failed: {response.status} - {error_text}"
result.failed = [inbox_url]
return result
except asyncio.TimeoutError: except asyncio.TimeoutError:
result.error_message = f"Delivery timeout to {inbox_url}" logger.warning(f"Request timed out for {inbox_url}")
result.error_message = f"Request timed out after {self.timeout} seconds"
result.failed = [inbox_url] result.failed = [inbox_url]
except Exception as e: # Retry with backoff if attempts remain
result.error_message = str(e) if attempt < self.max_retries:
result.failed = [inbox_url] retry_delay = self.retry_delay * attempt # Exponential backoff
logger.info(f"Retrying in {retry_delay} seconds (attempt {attempt + 1}/{self.max_retries})")
await asyncio.sleep(retry_delay)
return await self._deliver_to_inbox(activity, inbox_url, attempt + 1)
if attempt < self.max_retries: return result
await asyncio.sleep(self.retry_delay)
return await self._deliver_to_inbox(
activity, inbox_url, attempt + 1
)
return result
async def _try_shared_inbox(self, async def _try_shared_inbox(self,
activity: Dict[str, Any], activity: Dict[str, Any],
......
...@@ -23,6 +23,7 @@ import hashlib ...@@ -23,6 +23,7 @@ import hashlib
from ..utils.exceptions import SignatureError from ..utils.exceptions import SignatureError
from ..utils.logging import get_logger from ..utils.logging import get_logger
from ..cache.memory_cache import MemoryCache from ..cache.memory_cache import MemoryCache
from .key_management import KeyManager
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -44,19 +45,22 @@ class HTTPSignatureVerifier: ...@@ -44,19 +45,22 @@ class HTTPSignatureVerifier:
"""Enhanced HTTP signature verification.""" """Enhanced HTTP signature verification."""
def __init__(self,
             key_manager: Optional[KeyManager] = None,
             private_key_path: str = "",
             public_key_path: str = "",
             key_id: Optional[str] = None):
    """Initialize signature verifier.

    Args:
        key_manager: Optional KeyManager; when provided, signing keys
            are obtained from it instead of from PEM files on disk.
        private_key_path: Path to a PEM private key (file-based mode).
        public_key_path: Path to a PEM public key (file-based mode).
        key_id: Key identifier to advertise in the Signature header.
    """
    self.key_manager = key_manager
    self.private_key_path = private_key_path
    self.public_key_path = public_key_path
    self.key_id = key_id
    self.signature_cache = SignatureCache()
    self._test_now = None  # fixed clock for tests; None = use real time

    # Always define the key attributes so later access cannot raise
    # AttributeError when a KeyManager is used instead of key files.
    self.private_key = None
    self.public_key = None

    # Load keys from disk only in file-based mode (no key manager,
    # at least one path supplied).
    if not self.key_manager and (private_key_path or public_key_path):
        self.private_key = self._load_private_key()
        self.public_key = self._load_public_key()
def set_test_time(self, test_time: datetime) -> None: def set_test_time(self, test_time: datetime) -> None:
"""Set a fixed time for testing.""" """Set a fixed time for testing."""
...@@ -154,35 +158,23 @@ class HTTPSignatureVerifier: ...@@ -154,35 +158,23 @@ class HTTPSignatureVerifier:
path: str, path: str,
headers: Dict[str, str], headers: Dict[str, str],
body: Optional[Dict[str, Any]] = None) -> Dict[str, str]: body: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
""" """Sign HTTP request."""
Sign HTTP request.
Args:
method: HTTP method
path: Request path
headers: Request headers
body: Request body (optional)
Returns:
Dict with signature headers
"""
try: try:
request_headers = headers.copy() request_headers = headers.copy()
# Add date if not present, using test time if set # Add date if not present
if 'date' not in request_headers: if 'date' not in request_headers:
now = self._test_now if self._test_now is not None else datetime.utcnow() now = self._test_now if self._test_now is not None else datetime.utcnow()
request_headers['date'] = now.strftime('%a, %d %b %Y %H:%M:%S GMT') request_headers['date'] = now.strftime('%a, %d %b %Y %H:%M:%S GMT')
# Add digest if body present # Calculate digest first
if body is not None: if body is not None:
body_digest = self._generate_digest(body) digest = self._generate_digest(body)
request_headers['digest'] = f"SHA-256={body_digest}" request_headers['digest'] = f"SHA-256={digest}"
logger.debug(f"Added digest header: SHA-256={digest}")
# Headers to sign # Headers to sign (in specific order)
headers_to_sign = ['(request-target)', 'host', 'date'] headers_to_sign = ['(request-target)', 'host', 'date', 'digest']
if body is not None:
headers_to_sign.append('digest')
# Build signing string # Build signing string
signing_string = self._build_signing_string( signing_string = self._build_signing_string(
...@@ -191,10 +183,19 @@ class HTTPSignatureVerifier: ...@@ -191,10 +183,19 @@ class HTTPSignatureVerifier:
request_headers, request_headers,
headers_to_sign headers_to_sign
) )
logger.debug(f"Headers being signed: {headers_to_sign}")
logger.debug(f"Signing string: {signing_string}")
# Get private key
if self.key_manager:
private_key = await self.key_manager.get_active_private_key()
else:
private_key = self.private_key
# Sign # Sign
signature = self.private_key.sign( signature = private_key.sign(
signing_string.encode(), signing_string.encode('utf-8'),
padding.PKCS1v15(), padding.PKCS1v15(),
hashes.SHA256() hashes.SHA256()
) )
...@@ -207,10 +208,14 @@ class HTTPSignatureVerifier: ...@@ -207,10 +208,14 @@ class HTTPSignatureVerifier:
f'signature="{base64.b64encode(signature).decode()}"' f'signature="{base64.b64encode(signature).decode()}"'
) )
return { # Return headers with both Digest and Signature
signed_headers = {
**request_headers, **request_headers,
'Signature': signature_header 'Signature': signature_header
} }
logger.debug(f"Final headers: {signed_headers}")
return signed_headers
except Exception as e: except Exception as e:
logger.error(f"Request signing failed: {e}") logger.error(f"Request signing failed: {e}")
...@@ -274,20 +279,65 @@ class HTTPSignatureVerifier: ...@@ -274,20 +279,65 @@ class HTTPSignatureVerifier:
headers: Dict[str, str], headers: Dict[str, str],
signed_headers: List[str]) -> str: signed_headers: List[str]) -> str:
"""Build string to sign.""" """Build string to sign."""
lines = [] try:
lines = []
for header in signed_headers: # Convert headers to case-insensitive dictionary
if header == '(request-target)': headers_lower = {k.lower(): v for k, v in headers.items()}
lines.append(f"(request-target): {method.lower()} {path}")
else: for header in signed_headers:
if header.lower() not in [k.lower() for k in headers]: if header == '(request-target)':
raise SignatureError(f"Missing required header: {header}") lines.append(f"(request-target): {method.lower()} {path}")
lines.append(f"{header.lower()}: {headers[header]}") else:
header_lower = header.lower()
if header_lower not in headers_lower:
logger.error(f"Missing required header: {header}")
logger.error(f"Available headers: {list(headers_lower.keys())}")
raise SignatureError(f"Missing required header: {header}")
lines.append(f"{header_lower}: {headers_lower[header_lower]}")
return '\n'.join(lines) signing_string = '\n'.join(lines)
logger.debug(f"Signing string: {signing_string}")
return signing_string
except Exception as e:
logger.error(f"Failed to build signing string: {e}")
logger.error(f"Headers: {headers}")
logger.error(f"Signed headers: {signed_headers}")
raise SignatureError(f"Failed to build signing string: {e}")
def _generate_digest(self, body: Dict[str, Any]) -> str: def _generate_digest(self, body: Dict[str, Any]) -> str:
"""Generate SHA-256 digest of body.""" """Generate SHA-256 digest of body."""
body_bytes = json.dumps(body, sort_keys=True).encode() try:
digest = hashlib.sha256(body_bytes).digest() # Convert body to JSON string with canonical form
return base64.b64encode(digest).decode() # Use compact separators and sort keys
\ No newline at end of file body_json = json.dumps(body,
sort_keys=True,
ensure_ascii=True,
separators=(',', ':'))
# Debug the exact string being hashed
logger.debug(f"JSON string being hashed (length={len(body_json)}): {body_json}")
# Hash the exact bytes that will be sent
body_bytes = body_json.encode('utf-8')
logger.debug(f"Bytes being hashed (length={len(body_bytes)}): {body_bytes}")
# Calculate SHA-256
hasher = hashlib.sha256()
hasher.update(body_bytes)
digest = hasher.digest()
# Base64 encode
digest_b64 = base64.b64encode(digest).decode('ascii')
# Compare with Mastodon's expected digest
logger.debug("Digest comparison:")
logger.debug(f"Our digest: {digest_b64}")
logger.debug(f"Mastodon wants: bzRP4UzlS2nkynOrOT0Nk+m2twsOLoeOcJSXNsk3NW0=")
return digest_b64
except Exception as e:
logger.error(f"Failed to generate digest: {e}")
logger.error(f"Body: {body}")
raise SignatureError(f"Failed to generate digest: {e}")
\ No newline at end of file
...@@ -58,18 +58,25 @@ class KeyManager: ...@@ -58,18 +58,25 @@ class KeyManager:
async def initialize(self) -> None: async def initialize(self) -> None:
"""Initialize key manager.""" """Initialize key manager."""
try: try:
logger.info(f"Initializing key manager with path: {self.keys_path}")
# Create keys directory # Create keys directory
self.keys_path.mkdir(parents=True, exist_ok=True) self.keys_path.mkdir(parents=True, exist_ok=True)
logger.info("Created keys directory")
# Load existing keys # Load existing keys
await self._load_existing_keys() await self._load_existing_keys()
logger.info(f"Loaded {len(self.active_keys)} existing keys")
# Generate initial keys if none exist # Generate initial keys if none exist
if not self.active_keys: if not self.active_keys:
logger.info("No active keys found, generating new key pair")
await self.generate_key_pair() await self.generate_key_pair()
logger.info(f"Generated new key pair, total active keys: {len(self.active_keys)}")
# Start rotation task # Start rotation task
self._rotation_task = asyncio.create_task(self._key_rotation_loop()) self._rotation_task = asyncio.create_task(self._key_rotation_loop())
logger.info("Started key rotation task")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize key manager: {e}") logger.error(f"Failed to initialize key manager: {e}")
...@@ -78,6 +85,7 @@ class KeyManager: ...@@ -78,6 +85,7 @@ class KeyManager:
async def generate_key_pair(self) -> KeyPair: async def generate_key_pair(self) -> KeyPair:
"""Generate new key pair.""" """Generate new key pair."""
try: try:
logger.info("Generating new key pair")
# Generate keys # Generate keys
private_key = rsa.generate_private_key( private_key = rsa.generate_private_key(
public_exponent=65537, public_exponent=65537,
...@@ -89,8 +97,18 @@ class KeyManager: ...@@ -89,8 +97,18 @@ class KeyManager:
created_at = datetime.utcnow() created_at = datetime.utcnow()
expires_at = created_at + timedelta(days=self.rotation_config.rotation_interval) expires_at = created_at + timedelta(days=self.rotation_config.rotation_interval)
# Generate key ID # Generate key ID (for HTTP use)
key_id = f"https://{self.domain}/keys/{created_at.timestamp()}" key_id = f"https://{self.domain}/keys/{int(created_at.timestamp())}"
logger.debug(f"\n\n\n\nkey id - key management {key_id}")
# Generate safe file path (for storage)
safe_timestamp = str(int(created_at.timestamp()))
safe_domain = self.domain.replace(':', '_').replace('/', '_').replace('.', '_')
safe_path = f"{safe_domain}_{safe_timestamp}"
logger.info(f"Generated key ID: {key_id}")
logger.info(f"Safe path: {safe_path}")
# Create key pair # Create key pair
key_pair = KeyPair( key_pair = KeyPair(
...@@ -101,11 +119,13 @@ class KeyManager: ...@@ -101,11 +119,13 @@ class KeyManager:
key_id=key_id key_id=key_id
) )
# Save keys # Save keys with safe path
await self._save_key_pair(key_pair) await self._save_key_pair(key_pair, safe_path)
logger.info("Saved key pair to disk")
# Add to active keys # Add to active keys
self.active_keys[key_id] = key_pair self.active_keys[key_id] = key_pair
logger.info(f"Added key pair to active keys. Total active keys: {len(self.active_keys)}")
return key_pair return key_pair
...@@ -152,6 +172,7 @@ class KeyManager: ...@@ -152,6 +172,7 @@ class KeyManager:
key=lambda k: k.created_at key=lambda k: k.created_at
) )
async def verify_key(self, key_id: str, domain: str) -> bool: async def verify_key(self, key_id: str, domain: str) -> bool:
"""Verify a key's validity.""" """Verify a key's validity."""
try: try:
...@@ -171,12 +192,20 @@ class KeyManager: ...@@ -171,12 +192,20 @@ class KeyManager:
async def _load_existing_keys(self) -> None: async def _load_existing_keys(self) -> None:
"""Load existing keys from disk.""" """Load existing keys from disk."""
try: try:
for key_file in self.keys_path.glob("*.json"): # Recursively search for all json files
for key_file in self.keys_path.rglob("*.json"):
logger.info(f"Found key metadata file: {key_file}")
async with aiofiles.open(key_file, 'r') as f: async with aiofiles.open(key_file, 'r') as f:
metadata = json.loads(await f.read()) metadata = json.loads(await f.read())
# Load private key # Get the private key path from the same directory as the metadata
private_key_path = self.keys_path / f"{metadata['key_id']}_private.pem" private_key_path = key_file.parent / f"{key_file.stem}_private.pem"
logger.info(f"Looking for private key at: {private_key_path}")
if not private_key_path.exists():
logger.warning(f"Private key not found at {private_key_path}")
continue
async with aiofiles.open(private_key_path, 'rb') as f: async with aiofiles.open(private_key_path, 'rb') as f:
private_key = serialization.load_pem_private_key( private_key = serialization.load_pem_private_key(
await f.read(), await f.read(),
...@@ -195,16 +224,19 @@ class KeyManager: ...@@ -195,16 +224,19 @@ class KeyManager:
# Add to active keys if not expired # Add to active keys if not expired
if datetime.utcnow() <= key_pair.expires_at: if datetime.utcnow() <= key_pair.expires_at:
self.active_keys[key_pair.key_id] = key_pair self.active_keys[key_pair.key_id] = key_pair
logger.info(f"Loaded active key: {key_pair.key_id}")
else:
logger.info(f"Skipping expired key: {key_pair.key_id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to load existing keys: {e}") logger.error(f"Failed to load existing keys: {e}")
raise KeyManagementError(f"Failed to load existing keys: {e}") raise KeyManagementError(f"Failed to load existing keys: {e}")
async def _save_key_pair(self, key_pair: KeyPair) -> None: async def _save_key_pair(self, key_pair: KeyPair, safe_path: str) -> None:
"""Save key pair to disk.""" """Save key pair to disk."""
try: try:
# Save private key # Save private key
private_key_path = self.keys_path / f"{key_pair.key_id}_private.pem" private_key_path = self.keys_path / f"{safe_path}_private.pem"
private_pem = key_pair.private_key.private_bytes( private_pem = key_pair.private_key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8, format=serialization.PrivateFormat.PKCS8,
...@@ -214,7 +246,7 @@ class KeyManager: ...@@ -214,7 +246,7 @@ class KeyManager:
await f.write(private_pem) await f.write(private_pem)
# Save public key # Save public key
public_key_path = self.keys_path / f"{key_pair.key_id}_public.pem" public_key_path = self.keys_path / f"{safe_path}_public.pem"
public_pem = key_pair.public_key.public_bytes( public_pem = key_pair.public_key.public_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo format=serialization.PublicFormat.SubjectPublicKeyInfo
...@@ -226,9 +258,10 @@ class KeyManager: ...@@ -226,9 +258,10 @@ class KeyManager:
metadata = { metadata = {
'key_id': key_pair.key_id, 'key_id': key_pair.key_id,
'created_at': key_pair.created_at.isoformat(), 'created_at': key_pair.created_at.isoformat(),
'expires_at': key_pair.expires_at.isoformat() 'expires_at': key_pair.expires_at.isoformat(),
'safe_path': safe_path
} }
metadata_path = self.keys_path / f"{key_pair.key_id}.json" metadata_path = self.keys_path / f"{safe_path}.json"
async with aiofiles.open(metadata_path, 'w') as f: async with aiofiles.open(metadata_path, 'w') as f:
await f.write(json.dumps(metadata)) await f.write(json.dumps(metadata))
...@@ -284,3 +317,15 @@ class KeyManager: ...@@ -284,3 +317,15 @@ class KeyManager:
await self._rotation_task await self._rotation_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
async def get_active_private_key(self) -> RSAPrivateKey:
    """Return the private key of the newest active key pair.

    Raises:
        KeyManagementError: If there are no active keys.
    """
    if not self.active_keys:
        raise KeyManagementError("No active keys available")
    # Pick the pair with the latest creation time.
    newest = None
    for pair in self.active_keys.values():
        if newest is None or pair.created_at > newest.created_at:
            newest = pair
    return newest.private_key
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment