Commit cc96ea9d authored by supersonicwisd1

Updates on the key_id and finding the key_id

parent 6da4dad2
1 merge request: !11 The keys creation and key verification
@@ -6,6 +6,8 @@ import asyncio
import logging
from pathlib import Path
import sys
from config import CONFIG
# Add src directory to Python path
src_path = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(src_path))
@@ -14,15 +16,20 @@ from pyfed.federation.delivery import ActivityDelivery
from pyfed.federation.discovery import InstanceDiscovery
from pyfed.security.key_management import KeyManager
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more detailed logs
logger = logging.getLogger(__name__)
async def send_activity_to_mastodon():
# Initialize components
# Initialize components with config
key_manager = KeyManager(
domain="localhost:8000",
keys_path="example_keys"
domain=CONFIG["domain"],
keys_path=CONFIG["keys_path"],
rotation_config=False
)
await key_manager.initialize()
active_key = await key_manager.get_active_key()
logger.debug(f"Using active key ID: {active_key.key_id}")
discovery = InstanceDiscovery()
delivery = ActivityDelivery(key_manager=key_manager)
@@ -31,7 +38,7 @@ async def send_activity_to_mastodon():
await delivery.initialize()
try:
# 1. First, perform WebFinger lookup to get the actor's inbox
# 1. First, perform WebFinger lookup to get the actor's URL
logger.info("Performing WebFinger lookup...")
webfinger_result = await discovery.webfinger(
resource="acct:kene29@mastodon.social"
@@ -51,15 +58,29 @@ async def send_activity_to_mastodon():
if not actor_url:
raise Exception("Could not find ActivityPub actor URL")
# 2. Create the Activity
# 2. Fetch the actor's profile to get their inbox URL
logger.info(f"Fetching actor profile from {actor_url}")
async with discovery.session.get(actor_url) as response:
if response.status != 200:
raise Exception(f"Failed to fetch actor profile: {response.status}")
actor_data = await response.json()
# Get the inbox URL from the actor's profile
inbox_url = actor_data.get('inbox')
if not inbox_url:
raise Exception("Could not find actor's inbox URL")
logger.info(f"Found actor's inbox: {inbox_url}")
# 3. Create the Activity using the configured (ngrok) domain
note_activity = {
"@context": "https://www.w3.org/ns/activitystreams",
"type": "Create",
"actor": f"https://localhost:8000/users/testuser",
"actor": f"https://{CONFIG['domain']}/users/{CONFIG['user']}",
"object": {
"type": "Note",
"content": "Hello @kene29@mastodon.social! This is a test message from PyFed.",
"attributedTo": f"https://localhost:8000/users/testuser",
"attributedTo": f"https://{CONFIG['domain']}/users/{CONFIG['user']}",
"to": [actor_url],
"cc": ["https://www.w3.org/ns/activitystreams#Public"]
},
@@ -67,12 +88,13 @@ async def send_activity_to_mastodon():
"cc": ["https://www.w3.org/ns/activitystreams#Public"]
}
# 3. Deliver the activity
# 4. Deliver the activity to the inbox
logger.info(f"Delivering activity: {note_activity}")
result = await delivery.deliver_activity(
activity=note_activity,
recipients=[actor_url]
recipients=[inbox_url] # Use the inbox URL instead of actor URL
)
logger.info(f"Delivery result: {result}")
if result.success:
logger.info("Activity delivered successfully!")
pyfed/federation/delivery.py
@@ -9,10 +9,14 @@ import aiohttp
from dataclasses import dataclass
from datetime import datetime
import asyncio
import certifi
import ssl
import json
from ..utils.exceptions import DeliveryError
from ..utils.logging import get_logger
from ..security.key_management import KeyManager
from ..security.http_signatures import HTTPSignatureVerifier
logger = get_logger(__name__)
@@ -36,7 +40,7 @@ class ActivityDelivery:
key_manager: KeyManager,
timeout: int = 30,
max_retries: int = 3,
retry_delay: int = 300):
retry_delay: int = 20):
"""Initialize delivery service."""
self.key_manager = key_manager
self.timeout = timeout
@@ -44,16 +48,28 @@ class ActivityDelivery:
self.retry_delay = retry_delay
self.delivery_status = {}
self.session = None
self.signature_verifier = key_manager
async def initialize(self) -> None:
"""Initialize delivery service."""
active_key = await self.key_manager.get_active_key()
key_id = active_key.key_id
logger.debug(f"\n\n\n\n Using existing key id: {key_id}")
self.signature_verifier = HTTPSignatureVerifier(
key_manager=self.key_manager,
key_id=key_id
)
# Create SSL context with certifi certificates
ssl_context = ssl.create_default_context(cafile=certifi.where())
self.session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout),
headers={
"User-Agent": f"PyFed/1.0 (+https://{self.key_manager.domain})",
"Accept": "application/activity+json"
}
"Accept": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\""
},
connector=aiohttp.TCPConnector(ssl=ssl_context)
)
async def close(self) -> None:
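Since initialize() now resolves the active key and opens a TLS-configured session, the service must be initialized before delivering and closed afterwards. A minimal usage sketch, mirroring the example script above:

delivery = ActivityDelivery(key_manager=key_manager)
await delivery.initialize()   # picks the active key, builds the signed session
try:
    result = await delivery.deliver_activity(activity=note_activity,
                                             recipients=[inbox_url])
finally:
    await delivery.close()    # releases the aiohttp session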
@@ -68,51 +84,33 @@ class ActivityDelivery:
result = DeliveryResult()
try:
# Group recipients by domain
domain_groups = self._group_by_domain(recipients)
logger.info(f"Starting delivery to {len(recipients)} recipients")
for domain, domain_recipients in domain_groups.items():
# Try shared inbox first
shared_result = await self._try_shared_inbox(
activity, domain, domain_recipients
for recipient in recipients:
logger.info(f"Attempting delivery to {recipient}")
delivery_start = datetime.utcnow()
inbox_result = await self._deliver_to_inbox(
activity,
recipient
)
# Update result with shared inbox attempt
if shared_result.success:
result.success.extend(shared_result.success)
result.status_code = shared_result.status_code
result.error_message = shared_result.error_message
delivery_time = (datetime.utcnow() - delivery_start).total_seconds()
logger.info(f"Delivery to {recipient} took {delivery_time:.2f} seconds")
if inbox_result.success:
result.success.extend(inbox_result.success)
else:
# Individual delivery for failed recipients
for recipient in domain_recipients:
inbox_result = await self._deliver_to_inbox(
activity,
recipient
)
if inbox_result.success:
result.success.extend(inbox_result.success)
else:
result.failed.extend(inbox_result.failed)
result.status_code = inbox_result.status_code
result.error_message = inbox_result.error_message
# Ensure no duplicates
result.success = list(set(result.success))
result.failed = list(set(result.failed))
# Store delivery status
if activity.get('id'):
self.delivery_status[activity['id']] = {
'success': result.success,
'failed': result.failed,
'timestamp': datetime.utcnow().isoformat()
}
result.failed.extend(inbox_result.failed)
result.status_code = inbox_result.status_code
result.error_message = inbox_result.error_message
logger.info(f"Delivery completed. Success: {len(result.success)}, Failed: {len(result.failed)}")
return result
except Exception as e:
logger.error(f"Activity delivery failed: {e}")
result.failed = list(set(recipients))
result.failed = recipients
result.error_message = str(e)
return result
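Note that deliver_activity reports per-recipient outcomes through DeliveryResult instead of raising on failed recipients, so callers should inspect the result along these lines:

result = await delivery.deliver_activity(activity=note_activity,
                                         recipients=[inbox_url])
if result.failed:
    logger.error(f"Delivery failed ({result.status_code}): {result.error_message}")
    # result.retry_after is populated when the server answered 429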
@@ -124,13 +122,23 @@ class ActivityDelivery:
result = DeliveryResult()
try:
# Prepare headers
# Parse URL and get host
parsed_url = urlparse(inbox_url)
host = parsed_url.netloc
# Convert activity to JSON string once to ensure consistency
activity_json = json.dumps(activity,
sort_keys=True,
ensure_ascii=True,
separators=(',', ':'))
# Prepare headers
headers = {
"Accept": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"",
"Content-Type": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"",
"User-Agent": "PyFed/1.0",
"Host": parsed_url.netloc
"Host": host,
"Date": datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
}
# Sign request
@@ -138,51 +146,51 @@ class ActivityDelivery:
method="POST",
path=parsed_url.path,
headers=headers,
body=activity
body=activity # Pass the original activity for consistent hashing
)
async with aiohttp.ClientSession() as session:
async with await session.post(
inbox_url,
json=activity,
headers=signed_headers,
timeout=self.timeout
) as response:
result.status_code = response.status
error_text = await response.text()
if response.status in [200, 201, 202]:
result.success = [inbox_url]
return result
if response.status == 429:
retry_after = int(response.headers.get('Retry-After', self.retry_delay))
result.retry_after = retry_after
await asyncio.sleep(retry_after)
if attempt < self.max_retries:
return await self._deliver_to_inbox(
activity, inbox_url, attempt + 1
)
result.error_message = f"Delivery failed: {response.status} - {error_text}"
result.failed = [inbox_url]
# Send the request with the exact same JSON string we hashed
async with self.session.post(
inbox_url,
data=activity_json, # Use the exact JSON string
headers={
**signed_headers,
"Content-Type": "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\""
}
) as response:
result.status_code = response.status
error_text = await response.text()
if response.status in [200, 201, 202]:
result.success = [inbox_url]
return result
if response.status == 429:
retry_after = int(response.headers.get('Retry-After', self.retry_delay))
result.retry_after = retry_after
await asyncio.sleep(retry_after)
if attempt < self.max_retries:
return await self._deliver_to_inbox(
activity, inbox_url, attempt + 1
)
result.error_message = f"Delivery failed: {response.status} - {error_text}"
result.failed = [inbox_url]
return result
except asyncio.TimeoutError:
result.error_message = f"Delivery timeout to {inbox_url}"
logger.warning(f"Request timed out for {inbox_url}")
result.error_message = f"Request timed out after {self.timeout} seconds"
result.failed = [inbox_url]
except Exception as e:
result.error_message = str(e)
result.failed = [inbox_url]
# Retry with backoff if attempts remain
if attempt < self.max_retries:
retry_delay = self.retry_delay * attempt # Linear backoff (delay grows with each attempt)
logger.info(f"Retrying in {retry_delay} seconds (attempt {attempt + 1}/{self.max_retries})")
await asyncio.sleep(retry_delay)
return await self._deliver_to_inbox(activity, inbox_url, attempt + 1)
if attempt < self.max_retries:
await asyncio.sleep(self.retry_delay)
return await self._deliver_to_inbox(
activity, inbox_url, attempt + 1
)
return result
return result
async def _try_shared_inbox(self,
activity: Dict[str, Any],
pyfed/security/http_signatures.py
@@ -23,6 +23,7 @@ import hashlib
from ..utils.exceptions import SignatureError
from ..utils.logging import get_logger
from ..cache.memory_cache import MemoryCache
from .key_management import KeyManager
logger = get_logger(__name__)
@@ -44,19 +45,22 @@ class HTTPSignatureVerifier:
"""Enhanced HTTP signature verification."""
def __init__(self,
private_key_path: str,
public_key_path: str,
key_manager: Optional[KeyManager] = None,
private_key_path: str = "",
public_key_path: str = "",
key_id: Optional[str] = None):
"""Initialize signature verifier."""
self.key_manager = key_manager
self.private_key_path = private_key_path
self.public_key_path = public_key_path
self.key_id = key_id or f"https://{urlparse(private_key_path).netloc}#main-key"
self.key_id = key_id
self.signature_cache = SignatureCache()
self._test_now = None
# Load keys
self.private_key = self._load_private_key()
self.public_key = self._load_public_key()
# Load keys only if paths are provided and no key_manager is present
if not self.key_manager and (private_key_path or public_key_path):
self.private_key = self._load_private_key()
self.public_key = self._load_public_key()
def set_test_time(self, test_time: datetime) -> None:
"""Set a fixed time for testing."""
@@ -154,35 +158,23 @@ class HTTPSignatureVerifier:
path: str,
headers: Dict[str, str],
body: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
"""
Sign HTTP request.
Args:
method: HTTP method
path: Request path
headers: Request headers
body: Request body (optional)
Returns:
Dict with signature headers
"""
"""Sign HTTP request."""
try:
request_headers = headers.copy()
# Add date if not present, using test time if set
# Add date if not present
if 'date' not in request_headers:
now = self._test_now if self._test_now is not None else datetime.utcnow()
request_headers['date'] = now.strftime('%a, %d %b %Y %H:%M:%S GMT')
# Add digest if body present
# Calculate digest first
if body is not None:
body_digest = self._generate_digest(body)
request_headers['digest'] = f"SHA-256={body_digest}"
digest = self._generate_digest(body)
request_headers['digest'] = f"SHA-256={digest}"
logger.debug(f"Added digest header: SHA-256={digest}")
# Headers to sign
headers_to_sign = ['(request-target)', 'host', 'date']
if body is not None:
headers_to_sign.append('digest')
# Headers to sign (in specific order)
headers_to_sign = ['(request-target)', 'host', 'date', 'digest']
# Build signing string
signing_string = self._build_signing_string(
@@ -191,10 +183,19 @@ class HTTPSignatureVerifier:
request_headers,
headers_to_sign
)
logger.debug(f"Headers being signed: {headers_to_sign}")
logger.debug(f"Signing string: {signing_string}")
# Get private key
if self.key_manager:
private_key = await self.key_manager.get_active_private_key()
else:
private_key = self.private_key
# Sign
signature = self.private_key.sign(
signing_string.encode(),
signature = private_key.sign(
signing_string.encode('utf-8'),
padding.PKCS1v15(),
hashes.SHA256()
)
@@ -207,10 +208,14 @@ class HTTPSignatureVerifier:
f'signature="{base64.b64encode(signature).decode()}"'
)
return {
# Return headers with both Digest and Signature
signed_headers = {
**request_headers,
'Signature': signature_header
}
logger.debug(f"Final headers: {signed_headers}")
return signed_headers
except Exception as e:
logger.error(f"Request signing failed: {e}")
@@ -274,20 +279,65 @@ class HTTPSignatureVerifier:
headers: Dict[str, str],
signed_headers: List[str]) -> str:
"""Build string to sign."""
lines = []
for header in signed_headers:
if header == '(request-target)':
lines.append(f"(request-target): {method.lower()} {path}")
else:
if header.lower() not in [k.lower() for k in headers]:
raise SignatureError(f"Missing required header: {header}")
lines.append(f"{header.lower()}: {headers[header]}")
try:
lines = []
# Convert headers to case-insensitive dictionary
headers_lower = {k.lower(): v for k, v in headers.items()}
for header in signed_headers:
if header == '(request-target)':
lines.append(f"(request-target): {method.lower()} {path}")
else:
header_lower = header.lower()
if header_lower not in headers_lower:
logger.error(f"Missing required header: {header}")
logger.error(f"Available headers: {list(headers_lower.keys())}")
raise SignatureError(f"Missing required header: {header}")
lines.append(f"{header_lower}: {headers_lower[header_lower]}")
return '\n'.join(lines)
signing_string = '\n'.join(lines)
logger.debug(f"Signing string: {signing_string}")
return signing_string
except Exception as e:
logger.error(f"Failed to build signing string: {e}")
logger.error(f"Headers: {headers}")
logger.error(f"Signed headers: {signed_headers}")
raise SignatureError(f"Failed to build signing string: {e}")
def _generate_digest(self, body: Dict[str, Any]) -> str:
"""Generate SHA-256 digest of body."""
body_bytes = json.dumps(body, sort_keys=True).encode()
digest = hashlib.sha256(body_bytes).digest()
return base64.b64encode(digest).decode()
\ No newline at end of file
try:
# Convert body to JSON string with canonical form
# Use compact separators and sort keys
body_json = json.dumps(body,
sort_keys=True,
ensure_ascii=True,
separators=(',', ':'))
# Debug the exact string being hashed
logger.debug(f"JSON string being hashed (length={len(body_json)}): {body_json}")
# Hash the exact bytes that will be sent
body_bytes = body_json.encode('utf-8')
logger.debug(f"Bytes being hashed (length={len(body_bytes)}): {body_bytes}")
# Calculate SHA-256
hasher = hashlib.sha256()
hasher.update(body_bytes)
digest = hasher.digest()
# Base64 encode
digest_b64 = base64.b64encode(digest).decode('ascii')
# Compare with Mastodon's expected digest
logger.debug("Digest comparison:")
logger.debug(f"Our digest: {digest_b64}")
logger.debug(f"Mastodon wants: bzRP4UzlS2nkynOrOT0Nk+m2twsOLoeOcJSXNsk3NW0=")
return digest_b64
except Exception as e:
logger.error(f"Failed to generate digest: {e}")
logger.error(f"Body: {body}")
raise SignatureError(f"Failed to generate digest: {e}")
\ No newline at end of file
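The Digest header only verifies if the bytes actually posted are the bytes that were hashed, which is why _deliver_to_inbox sends the pre-serialized string via data= rather than letting aiohttp re-serialize the dict. A minimal sketch of that invariant, assuming the same canonical serialization:

import base64, hashlib, json

activity = {"type": "Create"}  # placeholder payload
body_json = json.dumps(activity, sort_keys=True, ensure_ascii=True,
                       separators=(',', ':'))
digest = base64.b64encode(
    hashlib.sha256(body_json.encode('utf-8')).digest()).decode('ascii')
headers = {"Digest": f"SHA-256={digest}"}
# The POST body must be exactly body_json; passing json=activity would
# re-serialize with different separators, so the receiver's recomputed
# hash would no longer match the Digest header.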
pyfed/security/key_management.py
@@ -58,18 +58,25 @@ class KeyManager:
async def initialize(self) -> None:
"""Initialize key manager."""
try:
logger.info(f"Initializing key manager with path: {self.keys_path}")
# Create keys directory
self.keys_path.mkdir(parents=True, exist_ok=True)
logger.info("Created keys directory")
# Load existing keys
await self._load_existing_keys()
logger.info(f"Loaded {len(self.active_keys)} existing keys")
# Generate initial keys if none exist
if not self.active_keys:
logger.info("No active keys found, generating new key pair")
await self.generate_key_pair()
logger.info(f"Generated new key pair, total active keys: {len(self.active_keys)}")
# Start rotation task
self._rotation_task = asyncio.create_task(self._key_rotation_loop())
logger.info("Started key rotation task")
except Exception as e:
logger.error(f"Failed to initialize key manager: {e}")
@@ -78,6 +85,7 @@ class KeyManager:
async def generate_key_pair(self) -> KeyPair:
"""Generate new key pair."""
try:
logger.info("Generating new key pair")
# Generate keys
private_key = rsa.generate_private_key(
public_exponent=65537,
@@ -89,8 +97,18 @@
created_at = datetime.utcnow()
expires_at = created_at + timedelta(days=self.rotation_config.rotation_interval)
# Generate key ID
key_id = f"https://{self.domain}/keys/{created_at.timestamp()}"
# Generate key ID (for HTTP use)
key_id = f"https://{self.domain}/keys/{int(created_at.timestamp())}"
logger.debug(f"\n\n\n\nkey id - key management {key_id}")
# Generate safe file path (for storage)
safe_timestamp = str(int(created_at.timestamp()))
safe_domain = self.domain.replace(':', '_').replace('/', '_').replace('.', '_')
safe_path = f"{safe_domain}_{safe_timestamp}"
logger.info(f"Generated key ID: {key_id}")
logger.info(f"Safe path: {safe_path}")
# Create key pair
key_pair = KeyPair(
@@ -101,11 +119,13 @@
key_id=key_id
)
# Save keys
await self._save_key_pair(key_pair)
# Save keys with safe path
await self._save_key_pair(key_pair, safe_path)
logger.info("Saved key pair to disk")
# Add to active keys
self.active_keys[key_id] = key_pair
logger.info(f"Added key pair to active keys. Total active keys: {len(self.active_keys)}")
return key_pair
@@ -152,6 +172,7 @@ class KeyManager:
key=lambda k: k.created_at
)
async def verify_key(self, key_id: str, domain: str) -> bool:
"""Verify a key's validity."""
try:
@@ -171,12 +192,20 @@
async def _load_existing_keys(self) -> None:
"""Load existing keys from disk."""
try:
for key_file in self.keys_path.glob("*.json"):
# Recursively search for all json files
for key_file in self.keys_path.rglob("*.json"):
logger.info(f"Found key metadata file: {key_file}")
async with aiofiles.open(key_file, 'r') as f:
metadata = json.loads(await f.read())
# Load private key
private_key_path = self.keys_path / f"{metadata['key_id']}_private.pem"
# Get the private key path from the same directory as the metadata
private_key_path = key_file.parent / f"{key_file.stem}_private.pem"
logger.info(f"Looking for private key at: {private_key_path}")
if not private_key_path.exists():
logger.warning(f"Private key not found at {private_key_path}")
continue
async with aiofiles.open(private_key_path, 'rb') as f:
private_key = serialization.load_pem_private_key(
await f.read(),
@@ -195,16 +224,19 @@
# Add to active keys if not expired
if datetime.utcnow() <= key_pair.expires_at:
self.active_keys[key_pair.key_id] = key_pair
logger.info(f"Loaded active key: {key_pair.key_id}")
else:
logger.info(f"Skipping expired key: {key_pair.key_id}")
except Exception as e:
logger.error(f"Failed to load existing keys: {e}")
raise KeyManagementError(f"Failed to load existing keys: {e}")
async def _save_key_pair(self, key_pair: KeyPair) -> None:
async def _save_key_pair(self, key_pair: KeyPair, safe_path: str) -> None:
"""Save key pair to disk."""
try:
# Save private key
private_key_path = self.keys_path / f"{key_pair.key_id}_private.pem"
private_key_path = self.keys_path / f"{safe_path}_private.pem"
private_pem = key_pair.private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
@@ -214,7 +246,7 @@
await f.write(private_pem)
# Save public key
public_key_path = self.keys_path / f"{key_pair.key_id}_public.pem"
public_key_path = self.keys_path / f"{safe_path}_public.pem"
public_pem = key_pair.public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
@@ -226,9 +258,10 @@
metadata = {
'key_id': key_pair.key_id,
'created_at': key_pair.created_at.isoformat(),
'expires_at': key_pair.expires_at.isoformat()
'expires_at': key_pair.expires_at.isoformat(),
'safe_path': safe_path
}
metadata_path = self.keys_path / f"{key_pair.key_id}.json"
metadata_path = self.keys_path / f"{safe_path}.json"
async with aiofiles.open(metadata_path, 'w') as f:
await f.write(json.dumps(metadata))
@@ -284,3 +317,15 @@
await self._rotation_task
except asyncio.CancelledError:
pass
async def get_active_private_key(self) -> RSAPrivateKey:
"""Get the most recent active private key."""
if not self.active_keys:
raise KeyManagementError("No active keys available")
# Return the private key of the most recently created key
most_recent_key = max(
self.active_keys.values(),
key=lambda k: k.created_at
)
return most_recent_key.private_key
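With the safe-path scheme above, a domain of localhost:8000 produces files like the following under keys_path (timestamp illustrative):

example_keys/
    localhost_8000_1730000000_private.pem
    localhost_8000_1730000000_public.pem
    localhost_8000_1730000000.json
        # {"key_id": "https://localhost:8000/keys/1730000000",
        #  "created_at": "...", "expires_at": "...",
        #  "safe_path": "localhost_8000_1730000000"}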