diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..39249ed1a2b4edc7e10e4b738b0a899b868a0d3f --- /dev/null +++ b/examples/README.md @@ -0,0 +1,2 @@ +To run examples, use the following command: +`python <example.py>` \ No newline at end of file diff --git a/examples/federation_example.py b/examples/federation_example.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2d0661af79511f7197bced024ff90d72511b69 --- /dev/null +++ b/examples/federation_example.py @@ -0,0 +1,200 @@ +""" +Federation examples showing common ActivityPub interactions. +""" + +import asyncio +import logging +from datetime import datetime, timezone +from pathlib import Path +import sys + +# Add src directory to Python path +src_path = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(src_path)) + +from pyfed.security.key_management import KeyManager +from pyfed.federation.delivery import ActivityDelivery +from pyfed.federation.discovery import InstanceDiscovery +from pyfed.protocols.webfinger import WebFingerClient +from pyfed.models import APCreate, APNote, APPerson, APFollow, APLike, APAnnounce + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class FederationExample: + """Federation interaction examples.""" + + def __init__(self, domain: str = "example.com"): + self.domain = domain + self.key_manager = None + self.delivery = None + self.webfinger = None + self.discovery = None + + async def initialize(self): + """Initialize components.""" + # Initialize key manager + self.key_manager = KeyManager( + domain=self.domain, + keys_path=str(Path("example_keys").resolve()) + ) + logger.info("Initializing key manager...") + await self.key_manager.initialize() + + # Initialize delivery + self.delivery = ActivityDelivery( + key_manager=self.key_manager, + timeout=30 + ) + logger.info("Initializing delivery...") + await self.delivery.initialize() + + # Initialize WebFinger client with SSL verification disabled for testing + self.webfinger = WebFingerClient(verify_ssl=False) + logger.info("Initializing WebFinger client...") + await self.webfinger.initialize() + + async def send_public_post(self, content: str): + """Send a public post to the Fediverse.""" + logger.info(f"Sending public post: {content}") + + # Create local actor + actor = APPerson( + id=f"https://{self.domain}/users/alice", + name="Alice", + preferred_username="alice", + inbox=f"https://{self.domain}/users/alice/inbox", + outbox=f"https://{self.domain}/users/alice/outbox" + ) + + # Create note + note = APNote( + id=f"https://{self.domain}/notes/{datetime.utcnow().timestamp()}", + content=content, + attributed_to=str(actor.id), + to=["https://www.w3.org/ns/activitystreams#Public"], + published=datetime.utcnow().isoformat() + ) + + # Create activity + create_activity = APCreate( + id=f"https://{self.domain}/activities/{datetime.utcnow().timestamp()}", + actor=str(actor.id), + object=note, + to=note.to, + published=datetime.utcnow().isoformat() + ) + + # Deliver to followers (example) + activity_dict = create_activity.serialize() + result = await self.delivery.deliver_activity( + activity=activity_dict, + recipients=[f"https://{self.domain}/followers"] + ) + logger.info(f"Delivery result: {result}") + + async def send_direct_message(self, recipient: str, content: str): + """Send a direct message to a specific user.""" + logger.info(f"Sending direct message to {recipient}") + + # Resolve recipient's inbox + 
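+        # (WebFinger lookup by handle, e.g. "user@remote.example"; returns None
+        # when the account cannot be resolved, so we bail out below.)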
inbox_url = await self.webfinger.get_inbox_url(recipient) + if not inbox_url: + logger.error(f"Could not find inbox for {recipient}") + return + + # Create note + note = APNote( + id=f"https://{self.domain}/notes/{datetime.utcnow().timestamp()}", + content=content, + attributed_to=f"https://{self.domain}/users/alice", + to=[inbox_url], + published=datetime.utcnow().isoformat() + ) + + # Create activity + create_activity = APCreate( + id=f"https://{self.domain}/activities/{datetime.utcnow().timestamp()}", + actor=f"https://{self.domain}/users/alice", + object=note, + to=note.to, + published=datetime.utcnow().isoformat() + ) + + # Deliver direct message + activity_dict = create_activity.serialize() + result = await self.delivery.deliver_activity( + activity=activity_dict, + recipients=[inbox_url] + ) + logger.info(f"Delivery result: {result}") + + async def follow_account(self, account: str): + """Follow a remote account.""" + logger.info(f"Following account: {account}") + + # Resolve account + actor_url = await self.webfinger.get_actor_url(account) + if not actor_url: + logger.error(f"Could not resolve account {account}") + return + + # Create Follow activity + follow = APFollow( + id=f"https://{self.domain}/activities/follow_{datetime.utcnow().timestamp()}", + actor=f"https://{self.domain}/users/alice", + object=actor_url, + published=datetime.utcnow().isoformat() + ) + + # Get target inbox + inbox_url = await self.webfinger.get_inbox_url(account) + if not inbox_url: + logger.error(f"Could not find inbox for {account}") + return + + # Deliver Follow activity + activity_dict = follow.serialize() + result = await self.delivery.deliver_activity( + activity=activity_dict, + recipients=[inbox_url] + ) + logger.info(f"Follow result: {result}") + + async def close(self): + """Clean up resources.""" + if self.delivery and hasattr(self.delivery, 'close'): + await self.delivery.close() + if self.webfinger and hasattr(self.webfinger, 'close'): + await self.webfinger.close() + +async def main(): + """Run federation examples.""" + federation = FederationExample() + logger.info("Initializing federation...") + await federation.initialize() + + try: + # Example 1: Send public post + logger.info("Sending public post...") + await federation.send_public_post( + "Hello #Fediverse! This is a test post from PyFed!" + ) + + # Example 2: Send direct message + logger.info("Sending direct message...") + await federation.send_direct_message( + "kene29@mastodon.social", + "Hello! This is a direct message test from PyFed." + ) + + # Example 3: Follow account + logger.info("Following account...") + await federation.follow_account("kene29@mastodon.social") + + finally: + await federation.close() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/integration_examples.py b/examples/integration_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..3c483af555494f80313a6625b6957d47eb6116a3 --- /dev/null +++ b/examples/integration_examples.py @@ -0,0 +1,73 @@ +""" +Framework integration examples. 
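+
+Shows minimal FastAPI, Django, and Flask wiring side by side. The domains,
+database URLs, and file paths in these snippets are placeholders; substitute
+the values for your own deployment.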
+""" + +# FastAPI Integration +from fastapi import FastAPI +from pyfed.integration.frameworks.fastapi import FastAPIIntegration +from pyfed.integration.config import IntegrationConfig + +async def fastapi_example(): + """FastAPI integration example.""" + # Create config + config = IntegrationConfig( + domain="example.com", + database_url="postgresql://user:pass@localhost/pyfed", + redis_url="redis://localhost", + media_path="uploads/", + key_path="keys/" + ) + + # Initialize integration + integration = FastAPIIntegration(config) + await integration.initialize() + + app = integration.app + + # Add custom routes + @app.get("/custom") + async def custom_route(): + return {"message": "Custom route"} + + return app + +# Django Integration +from django.urls import path +from pyfed.integration.frameworks.django import DjangoIntegration + +# settings.py +PYFED_CONFIG = { + 'domain': 'example.com', + 'database_url': 'postgresql://user:pass@localhost/pyfed', + 'redis_url': 'redis://localhost', + 'media_path': 'uploads/', + 'key_path': 'keys/', +} + +# urls.py +from django.urls import path, include + +urlpatterns = [ + path('', include('pyfed.integration.frameworks.django.urls')), +] + +# Flask Integration +from flask import Flask +from pyfed.integration.frameworks.flask import FlaskIntegration + +def flask_example(): + """Flask integration example.""" + app = Flask(__name__) + + config = IntegrationConfig( + domain="example.com", + database_url="postgresql://user:pass@localhost/pyfed", + redis_url="redis://localhost", + media_path="uploads/", + key_path="keys/" + ) + + integration = FlaskIntegration(config) + integration.initialize() + + return app \ No newline at end of file diff --git a/examples/note_example.py b/examples/note_example.py new file mode 100644 index 0000000000000000000000000000000000000000..c966d54594fc4141f92563910cf5459e5a5b5d6e --- /dev/null +++ b/examples/note_example.py @@ -0,0 +1,147 @@ +""" +Note creation and interaction examples. 
+""" + +import asyncio +import logging +from datetime import datetime, timezone +import os +import sys +from pathlib import Path +import ssl + +# Add src directory to Python path +src_path = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(src_path)) + +from pyfed.models import APCreate, APNote, APPerson +from pyfed.security import KeyManager +from pyfed.storage import StorageBackend +from pyfed.federation import ActivityDelivery +from pyfed.protocols.webfinger import WebFingerClient +from pyfed.serializers.json_serializer import to_json, from_json + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def ensure_directories(): + """Ensure required directories exist.""" + dirs = [ + "example_keys", + "example_data", + "example_media" + ] + + for dir_name in dirs: + path = Path(dir_name) + path.mkdir(parents=True, exist_ok=True) + logger.info(f"Ensured directory exists: {path}") + +async def create_note_example(): + """Example of creating and delivering a note.""" + try: + ensure_directories() + + # Initialize storage + storage = StorageBackend.create( + provider="sqlite", + database_url="example_data/pyfed_example.db" + ) + await storage.initialize() + + # Initialize key manager + key_manager = KeyManager( + domain="example.com", + active_keys=Path("example_keys").resolve(), + keys_path=str(Path("example_keys").resolve()) + ) + await key_manager.initialize() + + # Initialize delivery + delivery = ActivityDelivery( + key_manager=key_manager, + timeout=30 + ) + await delivery.initialize() + + # Initialize WebFinger client with SSL verification disabled for testing + webfinger = WebFingerClient(verify_ssl=False) + await webfinger.initialize() + + # Get inbox URL for recipient + recipient = "kene29@mastodon.social" + logger.info(f"Looking up inbox for {recipient}...") + inbox_url = await webfinger.get_inbox_url(recipient) + if not inbox_url: + logger.error(f"Could not find inbox for {recipient}") + return + + logger.info(f"Found inbox URL: {inbox_url}") + + # Create actor + actor = APPerson( + id="https://example.com/users/alice", + name="Alice", + preferred_username="alice", + inbox="https://example.com/users/alice/inbox", + outbox="https://example.com/users/alice/outbox", + followers="https://example.com/users/alice/followers" + ) + + # Create note with string attributed_to + note = APNote( + id=f"https://example.com/notes/{datetime.now(timezone.utc).timestamp()}", + content=f"Hello @{recipient}! 
This is a test note!", + attributed_to=str(actor.id), # Convert URL to string + to=[inbox_url], + cc=["https://www.w3.org/ns/activitystreams#Public"], + published=datetime.now(timezone.utc).isoformat() + ) + + # Create activity + create_activity = APCreate( + id=f"https://example.com/activities/{datetime.now(timezone.utc).timestamp()}", + actor=str(actor.id), # Convert URL to string + object=note, + to=note.to, + cc=note.cc, + published=datetime.now(timezone.utc).isoformat(), + ) + + # Serialize and deliver + logger.info("Serializing activity...") + activity_dict = create_activity.serialize() + logger.info(f"Serialized activity: {activity_dict}") + logger.info(f"Activity: {to_json(create_activity, indent=2)}") + + logger.info("Delivering activity...") + result = await delivery.deliver_activity( + activity=activity_dict, + recipients=[inbox_url] + ) + logger.info(f"Delivery result: {result}") + + except Exception as e: + logger.error(f"Error in note example: {e}") + raise + finally: + if 'storage' in locals(): + await storage.close() + if 'delivery' in locals(): + await delivery.close() + if 'webfinger' in locals(): + await webfinger.close() + +def main(): + """Main entry point.""" + try: + asyncio.run(create_note_example()) + except KeyboardInterrupt: + logger.info("Example stopped by user") + except Exception as e: + logger.error(f"Example failed: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/run_example.py b/examples/run_example.py new file mode 100644 index 0000000000000000000000000000000000000000..420ec82ea95b1e9e4196ebaaa6e8248b3c6b4aa7 --- /dev/null +++ b/examples/run_example.py @@ -0,0 +1,145 @@ +""" +PyFed runnable example. +""" + +import asyncio +import logging +from datetime import datetime +import os +import sys +from pathlib import Path + +# Add src directory to Python path +src_path = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(src_path)) + +from pyfed.models import APCreate, APNote, APPerson, APLike +from pyfed.security import KeyManager +from pyfed.storage import StorageBackend +from pyfed.federation import ActivityDelivery +from pyfed.serializers.json_serializer import to_json, from_json + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def ensure_directories(): + """Ensure required directories exist.""" + dirs = [ + "example_keys", + "example_data", + "example_media" + ] + + for dir_name in dirs: + path = Path(dir_name) + path.mkdir(parents=True, exist_ok=True) + logger.info(f"Ensured directory exists: {path}") + +async def run_examples(): + """Run PyFed examples.""" + try: + logger.info("Setting up PyFed components...") + ensure_directories() + + # Initialize storage + storage = StorageBackend.create( + provider="sqlite", + database_url="example_data/pyfed_example.db" + ) + await storage.initialize() + + # Initialize key manager + key_manager = KeyManager( + domain="example.com", + active_keys=Path("example_keys").resolve(), + keys_path=str(Path("example_keys").resolve()) + ) + await key_manager.initialize() + + # Initialize delivery + delivery = ActivityDelivery( + key_manager=key_manager, + timeout=30 + ) + + # Create actor + actor = APPerson( + id="https://example.com/users/alice", + name="Alice", + preferred_username="alice", + inbox="https://example.com/users/alice/inbox", + outbox="https://example.com/users/alice/outbox", + followers="https://example.com/users/alice/followers" + ) + + # Create note + note = APNote( + 
id="https://example.com/notes/123", + content="Hello, Federation! #test @bob@remote.com", + attributed_to=str(actor.id), + to=["https://www.w3.org/ns/activitystreams#Public"], + published=datetime.utcnow().isoformat() + ) + + # Create activity + create_activity = APCreate( + id=f"https://example.com/activities/{datetime.utcnow().timestamp()}", + actor=str(actor.id), + object=note, + to=note.to, + published=datetime.utcnow().isoformat() + ) + + # Store activity + logger.info("Storing activity...") + activity_dict = create_activity.serialize() + logger.info(f"Serialized activity (with context):\n{to_json(create_activity, indent=2)}") + activity_id = await storage.create_activity(activity_dict) + logger.info(f"Activity stored with ID: {activity_id}") + + # Create like activity + like_activity = APLike( + id=f"https://example.com/activities/like_{datetime.utcnow().timestamp()}", + actor=str(actor.id), + object=note.id, + to=["https://www.w3.org/ns/activitystreams#Public"], + published=datetime.utcnow().isoformat() + ) + + # Store like + logger.info("Storing like activity...") + like_dict = like_activity.serialize() + logger.info(f"Serialized like activity (with context):\n{to_json(like_activity, indent=2)}") + like_id = await storage.create_activity(like_dict) + logger.info(f"Like activity stored with ID: {like_id}") + + # Retrieve activities + logger.info("\nRetrieving activities...") + stored_activity = await storage.get_activity(activity_id) + stored_like = await storage.get_activity(like_id) + + logger.info("Retrieved Create activity:") + logger.info(stored_activity) + logger.info("\nRetrieved Like activity:") + logger.info(stored_like) + + except Exception as e: + logger.error(f"Error running examples: {e}") + raise + finally: + if storage: + await storage.close() + +def main(): + """Main entry point.""" + try: + asyncio.run(run_examples()) + except KeyboardInterrupt: + logger.info("Example stopped by user") + except Exception as e: + logger.error(f"Example failed: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 52ccba8ad30b97da82f05abb530f02f2cea6c6ed..c8fb7e92494ece87076a27f43ca73e70b0a92c52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,30 @@ +aiofiles==24.1.0 aiohappyeyeballs==2.4.3 aiohttp==3.10.10 aioredis==2.0.1 aiosignal==1.3.1 +aiosqlite==0.20.0 annotated-types==0.7.0 +anyio==4.6.2.post1 async-timeout==5.0.0 +asyncpg==0.30.0 attrs==24.2.0 +backoff==2.2.1 +beautifulsoup4==4.12.3 +bs4==0.0.2 +cachetools==5.5.0 cffi==1.17.1 cryptography==43.0.3 +dnspython==2.7.0 factory_boy==3.3.1 Faker==30.8.0 +fastapi==0.115.4 frozenlist==1.4.1 greenlet==3.1.1 idna==3.10 iniconfig==2.0.0 +Markdown==3.7 +motor==3.6.0 multidict==6.1.0 packaging==24.1 pluggy==1.5.0 @@ -22,11 +34,17 @@ pycparser==2.22 pydantic==2.9.2 pydantic_core==2.23.4 PyJWT==2.9.0 +pymongo==4.9.2 pytest==8.3.3 +pytest-aiohttp==1.0.5 pytest-asyncio==0.24.0 python-dateutil==2.9.0.post0 +PyYAML==6.0.2 redis==5.2.0 six==1.16.0 +sniffio==1.3.1 +soupsieve==2.6 SQLAlchemy==2.0.36 +starlette==0.41.2 typing_extensions==4.12.2 -yarl==1.15.2 +yarl==1.15.2 \ No newline at end of file diff --git a/setup.py b/setup.py index ca99640739d50ee7f5a4aed9f3282945f107bf88..331dde77841330d9a91d82aca17be3d441bb0b48 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,29 @@ from setuptools import setup, find_packages setup( name="pyfed", - packages=find_packages(where="src"), + version="0.1.0", package_dir={"": "src"}, + 
packages=find_packages(where="src"), + python_requires=">=3.8", + install_requires=[ + "cryptography>=3.4.7", + "aiohttp>=3.8.0", + "pydantic>=1.8.2", + "sqlalchemy>=1.4.0", + "aiosqlite>=0.17.0", + "asyncpg>=0.25.0", + "redis>=4.0.0", + "beautifulsoup4>=4.9.3", + "markdown>=3.3.4", + ], + extras_require={ + "fastapi": ["fastapi>=0.68.0", "uvicorn>=0.15.0"], + "django": ["django>=3.2.0"], + "flask": ["flask>=2.0.0"], + "dev": [ + "pytest>=6.2.5", + "pytest-asyncio>=0.15.1", + "pytest-cov>=2.12.1", + ], + }, ) \ No newline at end of file diff --git a/src/pyfed/__init__.py b/src/pyfed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c99bded97e6d92532f2f90a2cd0d682770776e33 --- /dev/null +++ b/src/pyfed/__init__.py @@ -0,0 +1,27 @@ +""" +pyfed package initialization. +""" +# from .plugins import plugin_manager # Update the import path + +# Export all models +from .models import ( + APObject, APEvent, APPlace, APProfile, APRelationship, APTombstone, + APArticle, APAudio, APDocument, APImage, APNote, APPage, APVideo, + APActor, APPerson, APGroup, APOrganization, APApplication, APService, + APLink, APMention, + APCollection, APOrderedCollection, APCollectionPage, APOrderedCollectionPage, + APCreate, APUpdate, APDelete, APFollow, APUndo, APLike, APAnnounce +) + +# Export serializers +from .serializers.json_serializer import ActivityPubSerializer + +__all__ = [ + 'APObject', 'APEvent', 'APPlace', 'APProfile', 'APRelationship', 'APTombstone', + 'APArticle', 'APAudio', 'APDocument', 'APImage', 'APNote', 'APPage', 'APVideo', + 'APActor', 'APPerson', 'APGroup', 'APOrganization', 'APApplication', 'APService', + 'APLink', 'APMention', + 'APCollection', 'APOrderedCollection', 'APCollectionPage', 'APOrderedCollectionPage', + 'APCreate', 'APUpdate', 'APDelete', 'APFollow', 'APUndo', 'APLike', 'APAnnounce', + 'ActivityPubSerializer' +] diff --git a/src/pyfed/cache/__init__.py b/src/pyfed/cache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e37035a5594facca9b5f86f959d97b02f77e75a9 --- /dev/null +++ b/src/pyfed/cache/__init__.py @@ -0,0 +1,14 @@ +""" +Caching package for ActivityPub data. +""" + +from .actor_cache import ActorCache +from .webfinger_cache import WebFingerCache +from .cache import Cache, object_cache + +__all__ = [ + 'ActorCache', + 'WebFingerCache', + 'Cache', + 'object_cache' +] diff --git a/src/pyfed/cache/actor_cache.py b/src/pyfed/cache/actor_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..b34a361d49e3080863b7049178da8c9af658b445 --- /dev/null +++ b/src/pyfed/cache/actor_cache.py @@ -0,0 +1,30 @@ +""" +Actor cache implementation. 
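+
+Illustrative usage (a sketch; assumes any backend exposing async get/set/delete,
+for example the MemoryCache in this package):
+
+    cache = ActorCache(cache=MemoryCache(), ttl=3600)
+    await cache.set("https://remote.example/users/bob", actor_document)  # actor_document: previously fetched actor dict
+    actor = await cache.get("https://remote.example/users/bob")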
+""" + +from typing import Dict, Any, Optional +from datetime import datetime + +class ActorCache: + """Cache for actor data.""" + + def __init__(self, cache, ttl: int = 3600): + """Initialize actor cache.""" + self.cache = cache + self.ttl = ttl + + async def get(self, actor_id: str) -> Optional[Dict[str, Any]]: + """Get actor data from cache.""" + return await self.cache.get(f"actor:{actor_id}") + + async def set(self, actor_id: str, actor_data: Dict[str, Any]) -> None: + """Set actor data in cache.""" + await self.cache.set(f"actor:{actor_id}", actor_data, self.ttl) + + async def delete(self, actor_id: str) -> None: + """Delete actor data from cache.""" + await self.cache.delete(f"actor:{actor_id}") + + async def clear(self) -> None: + """Clear all cached actors.""" + await self.cache.clear() \ No newline at end of file diff --git a/src/pyfed/cache/cache.py b/src/pyfed/cache/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d96ed96026891b5041976c749c5e23066db66e --- /dev/null +++ b/src/pyfed/cache/cache.py @@ -0,0 +1,32 @@ +from typing import Any, Optional +from functools import lru_cache +from datetime import datetime, timedelta + +class Cache: + def __init__(self, max_size: int = 100, ttl: int = 300): + self.max_size = max_size + self.ttl = ttl + self.cache = {} + + def get(self, key: str) -> Optional[Any]: + if key in self.cache: + value, timestamp = self.cache[key] + if datetime.now() - timestamp < timedelta(seconds=self.ttl): + return value + else: + del self.cache[key] + return None + + def set(self, key: str, value: Any): + if len(self.cache) >= self.max_size: + oldest_key = min(self.cache, key=lambda k: self.cache[k][1]) + del self.cache[oldest_key] + self.cache[key] = (value, datetime.now()) + +# Create a global cache instance +object_cache = Cache() + +@lru_cache(maxsize=100) +def expensive_computation(arg1, arg2): + # This is a placeholder for any expensive computation + return arg1 + arg2 diff --git a/src/pyfed/cache/memory_cache.py b/src/pyfed/cache/memory_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe359ad749de5121c2849849952b1514126049f --- /dev/null +++ b/src/pyfed/cache/memory_cache.py @@ -0,0 +1,45 @@ +""" +In-memory cache implementation. 
+""" + +from typing import Dict, Any, Optional +from datetime import datetime, timedelta + +class MemoryCache: + """Simple in-memory cache.""" + + def __init__(self, ttl: int = 3600): + """Initialize cache.""" + self.data: Dict[str, Any] = {} + self.expires: Dict[str, datetime] = {} + self.ttl = ttl + + async def get(self, key: str) -> Optional[Any]: + """Get value from cache.""" + if key not in self.data: + return None + + # Check expiration + if datetime.utcnow() > self.expires[key]: + del self.data[key] + del self.expires[key] + return None + + return self.data[key] + + async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: + """Set value in cache.""" + self.data[key] = value + self.expires[key] = datetime.utcnow() + timedelta( + seconds=ttl if ttl is not None else self.ttl + ) + + async def delete(self, key: str) -> None: + """Delete value from cache.""" + self.data.pop(key, None) + self.expires.pop(key, None) + + async def clear(self) -> None: + """Clear all cache entries.""" + self.data.clear() + self.expires.clear() \ No newline at end of file diff --git a/src/pyfed/cache/nodeinfo_cache.py b/src/pyfed/cache/nodeinfo_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..76f268ac4a25e3b7ee7aa0c344562736ca60f3d3 --- /dev/null +++ b/src/pyfed/cache/nodeinfo_cache.py @@ -0,0 +1,66 @@ +""" +NodeInfo caching implementation. +""" + +from typing import Optional, Dict, Any +from datetime import datetime, timedelta +import aioredis +import json + +from ..utils.exceptions import CacheError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class NodeInfoCache: + """NodeInfo cache implementation.""" + + def __init__(self, + redis_url: str = "redis://localhost", + ttl: int = 3600): # 1 hour default + self.redis_url = redis_url + self.ttl = ttl + self.redis = None + + async def initialize(self) -> None: + """Initialize cache.""" + try: + self.redis = await aioredis.from_url(self.redis_url) + except Exception as e: + logger.error(f"Failed to initialize NodeInfo cache: {e}") + raise CacheError(f"Cache initialization failed: {e}") + + async def get(self, domain: str) -> Optional[Dict[str, Any]]: + """Get cached NodeInfo.""" + try: + if not self.redis: + return None + + key = f"nodeinfo:{domain}" + data = await self.redis.get(key) + return json.loads(data) if data else None + + except Exception as e: + logger.error(f"Failed to get from cache: {e}") + return None + + async def set(self, domain: str, data: Dict[str, Any]) -> None: + """Cache NodeInfo data.""" + try: + if not self.redis: + return + + key = f"nodeinfo:{domain}" + await self.redis.set( + key, + json.dumps(data), + ex=self.ttl + ) + + except Exception as e: + logger.error(f"Failed to cache NodeInfo: {e}") + + async def close(self) -> None: + """Clean up resources.""" + if self.redis: + await self.redis.close() \ No newline at end of file diff --git a/src/pyfed/cache/webfinger_cache.py b/src/pyfed/cache/webfinger_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..e3dd30e6976da0607f5b6d95adb205faf647c2ba --- /dev/null +++ b/src/pyfed/cache/webfinger_cache.py @@ -0,0 +1,78 @@ +""" +Caching implementation for WebFinger results. +""" + +from typing import Dict, Any, Optional +from datetime import datetime, timedelta +import json + +from pyfed.storage.backends.redis import RedisStorageBackend +from pyfed.utils.logging import get_logger + +logger = get_logger(__name__) + +class WebFingerCache: + """ + Cache for storing WebFinger lookup results. 
+ + This reduces the need for repeated WebFinger queries. + """ + + def __init__(self, + redis: RedisStorageBackend, + ttl: int = 86400): # 24 hours default TTL + """ + Initialize WebFinger cache. + + Args: + redis: Redis storage backend + ttl: Time-to-live for cached entries in seconds + """ + self.redis = redis + self.ttl = ttl + + async def get(self, resource: str) -> Optional[Dict[str, Any]]: + """ + Get WebFinger result from cache. + + Args: + resource: WebFinger resource (e.g., acct:user@domain) + + Returns: + WebFinger data if cached, None otherwise + """ + try: + data = await self.redis.get(f"webfinger:{resource}") + return json.loads(data) if data else None + except Exception as e: + logger.error(f"Error getting WebFinger from cache: {e}") + return None + + async def set(self, resource: str, data: Dict[str, Any]) -> None: + """ + Cache WebFinger result. + + Args: + resource: WebFinger resource + data: WebFinger result data + """ + try: + await self.redis.set( + f"webfinger:{resource}", + json.dumps(data), + expire=self.ttl + ) + except Exception as e: + logger.error(f"Error caching WebFinger data: {e}") + + async def invalidate(self, resource: str) -> None: + """ + Remove WebFinger result from cache. + + Args: + resource: WebFinger resource + """ + try: + await self.redis.delete(f"webfinger:{resource}") + except Exception as e: + logger.error(f"Error invalidating WebFinger cache: {e}") \ No newline at end of file diff --git a/src/pyfed/config.py b/src/pyfed/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2697b830c33d121c88e3a03db50ea70a3435af --- /dev/null +++ b/src/pyfed/config.py @@ -0,0 +1,186 @@ +""" +PyFed configuration management. +""" + +from typing import Dict, Any, Optional, List +from dataclasses import asdict, dataclass, field +from pathlib import Path +import yaml +import json +import os + +from .utils.exceptions import ConfigError +from .utils.logging import get_logger + +logger = get_logger(__name__) + +@dataclass +class DatabaseConfig: + """Database configuration.""" + url: str + min_connections: int = 5 + max_connections: int = 20 + timeout: int = 30 + +@dataclass +class SecurityConfig: + """Security configuration.""" + domain: str + key_path: str + private_key_path: Optional[str] = None + public_key_path: Optional[str] = None + signature_ttl: int = 300 # 5 minutes + max_payload_size: int = 5_000_000 # 5MB + allowed_algorithms: List[str] = field(default_factory=lambda: ["rsa-sha256"]) + +@dataclass +class FederationConfig: + """Federation configuration.""" + domain: str + shared_inbox: bool = True + delivery_timeout: int = 30 + max_recipients: int = 100 + retry_delay: int = 300 + verify_ssl: bool = True + +@dataclass +class MediaConfig: + """Media configuration.""" + upload_path: str + max_size: int = 10_000_000 # 10MB + allowed_types: List[str] = field(default_factory=lambda: [ + 'image/jpeg', + 'image/png', + 'image/gif', + 'video/mp4', + 'audio/mpeg' + ]) + +@dataclass +class StorageConfig: + """Storage configuration.""" + provider: str = "sqlite" + database: DatabaseConfig = field(default_factory=lambda: DatabaseConfig( + url="sqlite:///pyfed.db" + )) + +@dataclass +class PyFedConfig: + """Main PyFed configuration.""" + domain: str + storage: StorageConfig + security: SecurityConfig + federation: FederationConfig + media: MediaConfig + debug: bool = False + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'PyFedConfig': + """Create config from dictionary.""" + try: + # Create nested configs + storage_config = 
StorageConfig( + provider=data.get('storage', {}).get('provider', 'sqlite'), + database=DatabaseConfig(**data.get('storage', {}).get('database', {})) + ) + + security_config = SecurityConfig( + domain=data['domain'], + **data.get('security', {}) + ) + + federation_config = FederationConfig( + domain=data['domain'], + **data.get('federation', {}) + ) + + media_config = MediaConfig( + upload_path=data.get('media', {}).get('upload_path', 'uploads'), + **data.get('media', {}) + ) + + return cls( + domain=data['domain'], + storage=storage_config, + security=security_config, + federation=federation_config, + media=media_config, + debug=data.get('debug', False) + ) + + except Exception as e: + raise ConfigError(f"Failed to create config: {e}") + + @classmethod + def from_file(cls, path: str) -> 'PyFedConfig': + """Load configuration from file.""" + try: + with open(path) as f: + if path.endswith('.yaml') or path.endswith('.yml'): + data = yaml.safe_load(f) + else: + data = json.load(f) + return cls.from_dict(data) + except Exception as e: + raise ConfigError(f"Failed to load config file: {e}") + + @classmethod + def from_env(cls) -> 'PyFedConfig': + """Load configuration from environment variables.""" + try: + return cls( + domain=os.getenv('PYFED_DOMAIN', 'localhost'), + storage=StorageConfig( + provider=os.getenv('PYFED_STORAGE_PROVIDER', 'sqlite'), + database=DatabaseConfig( + url=os.getenv('PYFED_DATABASE_URL', 'sqlite:///pyfed.db'), + min_connections=int(os.getenv('PYFED_DB_MIN_CONNECTIONS', '5')), + max_connections=int(os.getenv('PYFED_DB_MAX_CONNECTIONS', '20')), + timeout=int(os.getenv('PYFED_DB_TIMEOUT', '30')) + ) + ), + security=SecurityConfig( + domain=os.getenv('PYFED_DOMAIN', 'localhost'), + key_path=os.getenv('PYFED_KEY_PATH', 'keys'), + signature_ttl=int(os.getenv('PYFED_SIGNATURE_TTL', '300')) + ), + federation=FederationConfig( + domain=os.getenv('PYFED_DOMAIN', 'localhost'), + shared_inbox=os.getenv('PYFED_SHARED_INBOX', 'true').lower() == 'true', + delivery_timeout=int(os.getenv('PYFED_DELIVERY_TIMEOUT', '30')), + verify_ssl=os.getenv('PYFED_VERIFY_SSL', 'true').lower() == 'true' + ), + media=MediaConfig( + upload_path=os.getenv('PYFED_UPLOAD_PATH', 'uploads'), + max_size=int(os.getenv('PYFED_MAX_UPLOAD_SIZE', '10000000')) + ), + debug=os.getenv('PYFED_DEBUG', 'false').lower() == 'true' + ) + except Exception as e: + raise ConfigError(f"Failed to load config from env: {e}") + + def to_dict(self) -> Dict[str, Any]: + """Convert config to dictionary.""" + return { + 'domain': self.domain, + 'storage': { + 'provider': self.storage.provider, + 'database': asdict(self.storage.database) + }, + 'security': asdict(self.security), + 'federation': asdict(self.federation), + 'media': asdict(self.media), + 'debug': self.debug + } + + def save(self, path: str) -> None: + """Save configuration to file.""" + try: + data = self.to_dict() + with open(path, 'w') as f: + if path.endswith('.yaml') or path.endswith('.yml'): + yaml.dump(data, f, default_flow_style=False) + else: + json.dump(data, f, indent=2) + except Exception as e: + raise ConfigError(f"Failed to save config: {e}") diff --git a/src/pyfed/content/__init__.py b/src/pyfed/content/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..545ddb0b62a998ec7b7208c0e0183f4c107f5ec0 --- /dev/null +++ b/src/pyfed/content/__init__.py @@ -0,0 +1,6 @@ +from .handler import ContentHandler +from .collections import CollectionHandler + +__all__ = [ + 'ContentHandler', 'CollectionHandler' +] diff --git 
a/src/pyfed/content/collections.py b/src/pyfed/content/collections.py new file mode 100644 index 0000000000000000000000000000000000000000..bebd0309a2bb1b2d27b305fa8ec310e0ba1a88f9 --- /dev/null +++ b/src/pyfed/content/collections.py @@ -0,0 +1,190 @@ +""" +Collection handling implementation. +""" + +from typing import Dict, Any, List, Optional, Union +from datetime import datetime +import asyncio + +from ..utils.exceptions import CollectionError +from ..utils.logging import get_logger +from ..storage.base import StorageBackend + +logger = get_logger(__name__) + +class CollectionHandler: + """Handle ActivityPub collections.""" + + def __init__(self, storage: StorageBackend): + self.storage = storage + + async def create_collection(self, + collection_type: str, + owner: str, + items: Optional[List[str]] = None) -> str: + """ + Create a new collection. + + Args: + collection_type: Collection type (OrderedCollection or Collection) + owner: Collection owner + items: Initial collection items + + Returns: + Collection ID + """ + try: + collection = { + "type": collection_type, + "attributedTo": owner, + "totalItems": len(items) if items else 0, + "items": items or [], + "published": datetime.utcnow().isoformat() + } + + collection_id = await self.storage.create_object(collection) + return collection_id + + except Exception as e: + logger.error(f"Failed to create collection: {e}") + raise CollectionError(f"Failed to create collection: {e}") + + async def add_to_collection(self, + collection_id: str, + items: Union[str, List[str]]) -> None: + """ + Add items to collection. + + Args: + collection_id: Collection ID + items: Item(s) to add + """ + try: + collection = await self.storage.get_object(collection_id) + if not collection: + raise CollectionError(f"Collection not found: {collection_id}") + + if isinstance(items, str): + items = [items] + + # Add items + current_items = collection.get('items', []) + new_items = list(set(current_items + items)) + + # Update collection + collection['items'] = new_items + collection['totalItems'] = len(new_items) + + await self.storage.update_object(collection_id, collection) + + except Exception as e: + logger.error(f"Failed to add to collection: {e}") + raise CollectionError(f"Failed to add to collection: {e}") + + async def remove_from_collection(self, + collection_id: str, + items: Union[str, List[str]]) -> None: + """ + Remove items from collection. + + Args: + collection_id: Collection ID + items: Item(s) to remove + """ + try: + collection = await self.storage.get_object(collection_id) + if not collection: + raise CollectionError(f"Collection not found: {collection_id}") + + if isinstance(items, str): + items = [items] + + # Remove items + current_items = collection.get('items', []) + new_items = [i for i in current_items if i not in items] + + # Update collection + collection['items'] = new_items + collection['totalItems'] = len(new_items) + + await self.storage.update_object(collection_id, collection) + + except Exception as e: + logger.error(f"Failed to remove from collection: {e}") + raise CollectionError(f"Failed to remove from collection: {e}") + + async def get_collection_page(self, + collection_id: str, + page: int = 1, + per_page: int = 20) -> Dict[str, Any]: + """ + Get collection page. 
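+
+        The result follows the ActivityStreams OrderedCollectionPage shape,
+        for example (illustrative values):
+
+            {
+                "type": "OrderedCollectionPage",
+                "partOf": "https://example.com/collections/1",
+                "orderedItems": [...],
+                "totalItems": 42,
+                "next": "https://example.com/collections/1?page=2"
+            }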
+ + Args: + collection_id: Collection ID + page: Page number + per_page: Items per page + + Returns: + Collection page + """ + try: + collection = await self.storage.get_object(collection_id) + if not collection: + raise CollectionError(f"Collection not found: {collection_id}") + + items = collection.get('items', []) + total = len(items) + + # Calculate pagination + start = (page - 1) * per_page + end = start + per_page + page_items = items[start:end] + + return { + "type": "OrderedCollectionPage", + "partOf": collection_id, + "orderedItems": page_items, + "totalItems": total, + "current": f"{collection_id}?page={page}", + "first": f"{collection_id}?page=1", + "last": f"{collection_id}?page={-(-total // per_page)}", + "next": f"{collection_id}?page={page + 1}" if end < total else None, + "prev": f"{collection_id}?page={page - 1}" if page > 1 else None + } + + except Exception as e: + logger.error(f"Failed to get collection page: {e}") + raise CollectionError(f"Failed to get collection page: {e}") + + async def merge_collections(self, + target_id: str, + source_id: str) -> None: + """ + Merge two collections. + + Args: + target_id: Target collection ID + source_id: Source collection ID + """ + try: + target = await self.storage.get_object(target_id) + source = await self.storage.get_object(source_id) + + if not target or not source: + raise CollectionError("Collection not found") + + # Merge items + target_items = target.get('items', []) + source_items = source.get('items', []) + merged_items = list(set(target_items + source_items)) + + # Update target + target['items'] = merged_items + target['totalItems'] = len(merged_items) + + await self.storage.update_object(target_id, target) + + except Exception as e: + logger.error(f"Failed to merge collections: {e}") + raise CollectionError(f"Failed to merge collections: {e}") \ No newline at end of file diff --git a/src/pyfed/content/handler.py b/src/pyfed/content/handler.py new file mode 100644 index 0000000000000000000000000000000000000000..5b653180d197ad19ac9794362406abbec2def242 --- /dev/null +++ b/src/pyfed/content/handler.py @@ -0,0 +1,215 @@ +""" +Unified content handling implementation. +""" + +from typing import Dict, Any, List, Optional, Tuple +import re +from urllib.parse import urlparse +import markdown +from bs4 import BeautifulSoup +import html + +from ..utils.exceptions import ContentError +from ..utils.logging import get_logger +from ..federation.discovery import InstanceDiscovery + +logger = get_logger(__name__) + +class ContentHandler: + """Handle content processing.""" + + def __init__(self, + instance_discovery: InstanceDiscovery, + allowed_tags: Optional[List[str]] = None, + allowed_attributes: Optional[Dict[str, List[str]]] = None): + self.instance_discovery = instance_discovery + self.allowed_tags = allowed_tags or [ + 'p', 'br', 'span', 'a', 'em', 'strong', + 'ul', 'ol', 'li', 'blockquote', 'code', + 'pre', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6' + ] + self.allowed_attributes = allowed_attributes or { + 'a': ['href', 'rel', 'class'], + 'span': ['class'], + 'code': ['class'], + 'pre': ['class'] + } + self.markdown = markdown.Markdown( + extensions=['extra', 'smarty', 'codehilite'] + ) + self.mention_pattern = re.compile(r'@([^@\s]+)@([^\s]+)') + + async def process_content(self, + content: str, + content_type: str = "text/markdown", + local_domain: Optional[str] = None) -> Tuple[str, List[Dict[str, Any]]]: + """ + Process content including formatting and mentions. 
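+
+        Illustrative usage (a sketch; assumes an initialized InstanceDiscovery
+        used to resolve remote mentions):
+
+            handler = ContentHandler(instance_discovery=discovery)
+            html, mentions = await handler.process_content(
+                "Hi @alice@remote.example, *welcome*!",
+                content_type="text/markdown",
+                local_domain="example.com",
+            )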
+ + Args: + content: Raw content + content_type: Content type (text/markdown or text/html) + local_domain: Local domain for resolving mentions + + Returns: + Tuple of (processed content, mention objects) + """ + try: + # Process mentions first + processed_content, mentions = await self._process_mentions( + content, local_domain + ) + + # Format content + formatted_content = await self._format_content( + processed_content, + content_type + ) + + return formatted_content, mentions + + except Exception as e: + logger.error(f"Content processing failed: {e}") + raise ContentError(f"Content processing failed: {e}") + + async def _process_mentions(self, + content: str, + local_domain: Optional[str]) -> Tuple[str, List[Dict[str, Any]]]: + """Process mentions in content.""" + mentions = [] + processed = content + + # Find all mentions + for match in self.mention_pattern.finditer(content): + username, domain = match.groups() + mention = await self._resolve_mention(username, domain, local_domain) + if mention: + mentions.append(mention) + # Replace mention with link + processed = processed.replace( + f"@{username}@{domain}", + f"<span class='h-card'><a href='{mention['href']}' class='u-url mention'>@{username}</a></span>" + ) + + return processed, mentions + + async def _resolve_mention(self, + username: str, + domain: str, + local_domain: Optional[str]) -> Optional[Dict[str, Any]]: + """Resolve mention to actor.""" + try: + # Local mention + if domain == local_domain: + return { + "type": "Mention", + "href": f"https://{domain}/users/{username}", + "name": f"@{username}@{domain}" + } + + # Remote mention + webfinger = await self.instance_discovery.webfinger( + f"acct:{username}@{domain}" + ) + + if not webfinger: + return None + + # Find actor URL + actor_url = None + for link in webfinger.get('links', []): + if link.get('rel') == 'self' and link.get('type') == 'application/activity+json': + actor_url = link.get('href') + break + + if not actor_url: + return None + + return { + "type": "Mention", + "href": actor_url, + "name": f"@{username}@{domain}" + } + + except Exception as e: + logger.error(f"Failed to resolve mention: {e}") + return None + + async def _format_content(self, + content: str, + content_type: str) -> str: + """Format content.""" + try: + # Convert to HTML + if content_type == "text/markdown": + html_content = self.markdown.convert(content) + else: + html_content = content + + # Sanitize HTML + clean_html = self._sanitize_html(html_content) + + # Add microformats + formatted = self._add_microformats(clean_html) + + return formatted + + except Exception as e: + logger.error(f"Failed to format content: {e}") + raise ContentError(f"Failed to format content: {e}") + + def _sanitize_html(self, content: str) -> str: + """Sanitize HTML content.""" + try: + soup = BeautifulSoup(content, 'html.parser') + + for tag in soup.find_all(True): + if tag.name not in self.allowed_tags: + tag.unwrap() + else: + # Remove disallowed attributes + allowed = self.allowed_attributes.get(tag.name, []) + for attr in list(tag.attrs): + if attr not in allowed: + del tag[attr] + + # Clean URLs in links + if tag.name == 'a' and tag.get('href'): + tag['href'] = self._clean_url(tag['href']) + tag['rel'] = 'nofollow noopener noreferrer' + + return str(soup) + + except Exception as e: + logger.error(f"Failed to sanitize HTML: {e}") + raise ContentError(f"Failed to sanitize HTML: {e}") + + def _clean_url(self, url: str) -> str: + """Clean and validate URL.""" + url = url.strip() + + # Only allow http(s) URLs + if not 
url.startswith(('http://', 'https://')): + return '#' + + return url + + def _add_microformats(self, content: str) -> str: + """Add microformat classes.""" + try: + soup = BeautifulSoup(content, 'html.parser') + + # Add e-content class to content wrapper + if soup.find(['p', 'div']): + wrapper = soup.find(['p', 'div']) + wrapper['class'] = wrapper.get('class', []) + ['e-content'] + + # Add u-url class to links + for link in soup.find_all('a'): + link['class'] = link.get('class', []) + ['u-url'] + + return str(soup) + + except Exception as e: + logger.error(f"Failed to add microformats: {e}") + return content \ No newline at end of file diff --git a/src/pyfed/content/media.py b/src/pyfed/content/media.py new file mode 100644 index 0000000000000000000000000000000000000000..78e8a27bdd47f5f00912ce38646c29f78997d52f --- /dev/null +++ b/src/pyfed/content/media.py @@ -0,0 +1,155 @@ +""" +Media attachment handling implementation. +""" + +from typing import Dict, Any, Optional, List +import aiohttp +import mimetypes +import hashlib +from pathlib import Path +import magic +from PIL import Image +import asyncio + +from ..utils.exceptions import MediaError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class MediaHandler: + """Handle media attachments.""" + + def __init__(self, + upload_path: str = "uploads", + max_size: int = 10_000_000, # 10MB + allowed_types: Optional[List[str]] = None): + self.upload_path = Path(upload_path) + self.max_size = max_size + self.allowed_types = allowed_types or [ + 'image/jpeg', + 'image/png', + 'image/gif', + 'video/mp4', + 'audio/mpeg', + 'audio/ogg' + ] + self.upload_path.mkdir(parents=True, exist_ok=True) + + async def process_attachment(self, + url: str, + description: Optional[str] = None) -> Dict[str, Any]: + """ + Process media attachment. 
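+
+        Illustrative usage (URL and description are placeholders):
+
+            handler = MediaHandler(upload_path="example_media")
+            attachment = await handler.process_attachment(
+                "https://remote.example/media/cat.png",
+                description="A cat picture",
+            )
+            # attachment["url"] is the locally stored copy, e.g. "/media/<sha256>.png"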
+ + Args: + url: Media URL + description: Media description + + Returns: + Processed attachment object + """ + try: + # Download media + content = await self._download_media(url) + + # Validate media + mime_type = magic.from_buffer(content, mime=True) + if mime_type not in self.allowed_types: + raise MediaError(f"Unsupported media type: {mime_type}") + + if len(content) > self.max_size: + raise MediaError("Media too large") + + # Generate filename + file_hash = hashlib.sha256(content).hexdigest() + ext = mimetypes.guess_extension(mime_type) or '' + filename = f"{file_hash}{ext}" + + # Save file + file_path = self.upload_path / filename + with open(file_path, 'wb') as f: + f.write(content) + + # Generate thumbnails for images + thumbnails = {} + if mime_type.startswith('image/'): + thumbnails = await self._generate_thumbnails(file_path) + + return { + "type": "Document", + "mediaType": mime_type, + "url": f"/media/{filename}", + "name": description or filename, + "blurhash": await self._generate_blurhash(file_path) if mime_type.startswith('image/') else None, + "width": await self._get_image_width(file_path) if mime_type.startswith('image/') else None, + "height": await self._get_image_height(file_path) if mime_type.startswith('image/') else None, + "thumbnails": thumbnails + } + + except Exception as e: + logger.error(f"Failed to process attachment: {e}") + raise MediaError(f"Failed to process attachment: {e}") + + async def _download_media(self, url: str) -> bytes: + """Download media from URL.""" + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status != 200: + raise MediaError(f"Failed to download media: {response.status}") + return await response.read() + except Exception as e: + raise MediaError(f"Failed to download media: {e}") + + async def _generate_thumbnails(self, file_path: Path) -> Dict[str, Dict[str, Any]]: + """Generate image thumbnails.""" + thumbnails = {} + sizes = [(320, 320), (640, 640)] + + try: + image = Image.open(file_path) + for width, height in sizes: + thumb = image.copy() + thumb.thumbnail((width, height)) + + thumb_hash = hashlib.sha256(str(file_path).encode()).hexdigest() + thumb_filename = f"thumb_{width}x{height}_{thumb_hash}.jpg" + thumb_path = self.upload_path / thumb_filename + + thumb.save(thumb_path, "JPEG", quality=85) + + thumbnails[f"{width}x{height}"] = { + "url": f"/media/{thumb_filename}", + "width": thumb.width, + "height": thumb.height + } + + return thumbnails + + except Exception as e: + logger.error(f"Failed to generate thumbnails: {e}") + return {} + + async def _generate_blurhash(self, file_path: Path) -> Optional[str]: + """Generate blurhash for image.""" + try: + # Implementation for blurhash generation + return None + except Exception: + return None + + async def _get_image_width(self, file_path: Path) -> Optional[int]: + """Get image width.""" + try: + with Image.open(file_path) as img: + return img.width + except Exception: + return None + + async def _get_image_height(self, file_path: Path) -> Optional[int]: + """Get image height.""" + try: + with Image.open(file_path) as img: + return img.height + except Exception: + return None \ No newline at end of file diff --git a/src/pyfed/federation/__init__.py b/src/pyfed/federation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7770cfdadd2a832700722298a1c75e52ed719950 --- /dev/null +++ b/src/pyfed/federation/__init__.py @@ -0,0 +1,15 @@ +""" +Federation package for ActivityPub server-to-server 
interactions. +""" + +from .delivery import ActivityDelivery +from .fetch import ResourceFetcher +from .resolver import ActivityPubResolver +from .discovery import InstanceDiscovery + +__all__ = [ + 'ActivityDelivery', + 'ResourceFetcher', + 'ActivityPubResolver', + 'InstanceDiscovery' +] \ No newline at end of file diff --git a/src/pyfed/federation/delivery.py b/src/pyfed/federation/delivery.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1e94e33093774088d25ea67989fc063262c52b --- /dev/null +++ b/src/pyfed/federation/delivery.py @@ -0,0 +1,217 @@ +""" +federation/delivery.py +Activity delivery implementation. +""" + +from typing import Dict, Any, Optional, List +from urllib.parse import urlparse +import aiohttp +from dataclasses import dataclass +from datetime import datetime +import asyncio + +from ..utils.exceptions import DeliveryError +from ..utils.logging import get_logger +from ..security.key_management import KeyManager + +logger = get_logger(__name__) + +@dataclass +class DeliveryResult: + """Delivery result.""" + success: List[str] = None + failed: List[str] = None + status_code: Optional[int] = None + error_message: Optional[str] = None + retry_after: Optional[int] = None + + def __post_init__(self): + self.success = self.success or [] + self.failed = self.failed or [] + +class ActivityDelivery: + """Activity delivery implementation.""" + + def __init__(self, + key_manager: KeyManager, + timeout: int = 30, + max_retries: int = 3, + retry_delay: int = 300): + """Initialize delivery service.""" + self.key_manager = key_manager + self.timeout = timeout + self.max_retries = max_retries + self.retry_delay = retry_delay + self.delivery_status = {} + self.session = None + self.signature_verifier = key_manager + + async def initialize(self) -> None: + """Initialize delivery service.""" + self.session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers={ + "User-Agent": f"PyFed/1.0 (+https://{self.key_manager.domain})", + "Accept": "application/activity+json" + } + ) + + async def close(self) -> None: + """Clean up resources.""" + if self.session: + await self.session.close() + + async def deliver_activity(self, + activity: Dict[str, Any], + recipients: List[str]) -> DeliveryResult: + """Deliver activity to recipients.""" + result = DeliveryResult() + + try: + # Group recipients by domain + domain_groups = self._group_by_domain(recipients) + + for domain, domain_recipients in domain_groups.items(): + # Try shared inbox first + shared_result = await self._try_shared_inbox( + activity, domain, domain_recipients + ) + + # Update result with shared inbox attempt + if shared_result.success: + result.success.extend(shared_result.success) + result.status_code = shared_result.status_code + result.error_message = shared_result.error_message + else: + # Individual delivery for failed recipients + for recipient in domain_recipients: + inbox_result = await self._deliver_to_inbox( + activity, + recipient + ) + if inbox_result.success: + result.success.extend(inbox_result.success) + else: + result.failed.extend(inbox_result.failed) + result.status_code = inbox_result.status_code + result.error_message = inbox_result.error_message + + # Ensure no duplicates + result.success = list(set(result.success)) + result.failed = list(set(result.failed)) + + # Store delivery status + if activity.get('id'): + self.delivery_status[activity['id']] = { + 'success': result.success, + 'failed': result.failed, + 'timestamp': datetime.utcnow().isoformat() + } 
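+
+            # Delivery status is kept only in this in-memory dict; a process
+            # restart discards it. Callers can poll get_delivery_status(id).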
+ + return result + + except Exception as e: + logger.error(f"Activity delivery failed: {e}") + result.failed = list(set(recipients)) + result.error_message = str(e) + return result + + async def _deliver_to_inbox(self, + activity: Dict[str, Any], + inbox_url: str, + attempt: int = 1) -> DeliveryResult: + """Deliver activity to a single inbox.""" + result = DeliveryResult() + + try: + # Prepare headers + parsed_url = urlparse(inbox_url) + headers = { + "Content-Type": "application/activity+json", + "Accept": "application/activity+json", + "User-Agent": "PyFed/1.0", + "Host": parsed_url.netloc + } + + # Sign request + signed_headers = await self.signature_verifier.sign_request( + method="POST", + path=parsed_url.path, + headers=headers, + body=activity + ) + + async with aiohttp.ClientSession() as session: + async with await session.post( + inbox_url, + json=activity, + headers=signed_headers, + timeout=self.timeout + ) as response: + result.status_code = response.status + error_text = await response.text() + + if response.status in [200, 201, 202]: + result.success = [inbox_url] + return result + + if response.status == 429: + retry_after = int(response.headers.get('Retry-After', self.retry_delay)) + result.retry_after = retry_after + await asyncio.sleep(retry_after) + if attempt < self.max_retries: + return await self._deliver_to_inbox( + activity, inbox_url, attempt + 1 + ) + + result.error_message = f"Delivery failed: {response.status} - {error_text}" + result.failed = [inbox_url] + return result + + except asyncio.TimeoutError: + result.error_message = f"Delivery timeout to {inbox_url}" + result.failed = [inbox_url] + + except Exception as e: + result.error_message = str(e) + result.failed = [inbox_url] + + if attempt < self.max_retries: + await asyncio.sleep(self.retry_delay) + return await self._deliver_to_inbox( + activity, inbox_url, attempt + 1 + ) + + return result + + async def _try_shared_inbox(self, + activity: Dict[str, Any], + domain: str, + recipients: List[str]) -> DeliveryResult: + """Try delivering to domain's shared inbox.""" + shared_inbox_url = f"https://{domain}/inbox" + try: + result = await self._deliver_to_inbox(activity, shared_inbox_url) + if result.status_code in [200, 201, 202]: + result.success = list(set(recipients)) + result.failed = [] + else: + result.failed = list(set(recipients)) + return result + except Exception as e: + logger.debug(f"Shared inbox delivery failed: {e}") + return DeliveryResult(failed=list(set(recipients))) + + def _group_by_domain(self, recipients: List[str]) -> Dict[str, List[str]]: + """Group recipients by domain.""" + groups: Dict[str, List[str]] = {} + for recipient in recipients: + domain = urlparse(recipient).netloc + if domain not in groups: + groups[domain] = [] + groups[domain].append(recipient) + return groups + + def get_delivery_status(self, activity_id: str) -> Optional[Dict[str, Any]]: + """Get delivery status for an activity.""" + return self.delivery_status.get(activity_id) \ No newline at end of file diff --git a/src/pyfed/federation/discovery.py b/src/pyfed/federation/discovery.py new file mode 100644 index 0000000000000000000000000000000000000000..85ed45b8dbc78e0ed7ba5507e1b6e92078447fe1 --- /dev/null +++ b/src/pyfed/federation/discovery.py @@ -0,0 +1,334 @@ +""" +federation/discovery.py +Federation instance discovery implementation. 
+ +Handles: +- Instance metadata discovery +- WebFinger resolution +- NodeInfo discovery +- Actor discovery +""" + +from typing import Dict, Any, Optional, List +import aiohttp +import json +from urllib.parse import urlparse, urljoin +from datetime import datetime +import asyncio +from dataclasses import dataclass + +from ..utils.exceptions import DiscoveryError +from ..utils.logging import get_logger +from ..cache.memory_cache import MemoryCache + +logger = get_logger(__name__) + +@dataclass +class NodeInfo: + """NodeInfo data.""" + version: str + software: Dict[str, str] + protocols: List[str] + services: Dict[str, List[str]] + usage: Dict[str, Any] + open_registrations: bool + metadata: Dict[str, Any] + +@dataclass +class InstanceInfo: + """Instance information.""" + domain: str + nodeinfo: Optional[NodeInfo] + software_version: Optional[str] + instance_actor: Optional[Dict[str, Any]] + shared_inbox: Optional[str] + endpoints: Dict[str, str] + features: Dict[str, bool] + last_updated: datetime + +class InstanceDiscovery: + """Federation instance discovery.""" + + def __init__(self, + cache_ttl: int = 3600, # 1 hour + request_timeout: int = 10): + """Initialize instance discovery.""" + self.cache = MemoryCache(ttl=cache_ttl) + self.timeout = request_timeout + self.session = None + + async def initialize(self) -> None: + """Initialize HTTP session.""" + self.session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers={ + "User-Agent": "PyFed/1.0", + "Accept": "application/activity+json" + } + ) + + async def discover_instance(self, domain: str) -> InstanceInfo: + """ + Discover instance information. + + Args: + domain: Instance domain + + Returns: + InstanceInfo with complete instance data + """ + try: + # Check cache first + cache_key = f"instance:{domain}" + if cached := await self.cache.get(cache_key): + return cached + + # Discover instance components + nodeinfo = await self.discover_nodeinfo(domain) + instance_actor = await self.discover_instance_actor(domain) + endpoints = await self.discover_endpoints(domain) + features = await self.discover_features(domain) + + # Build instance info + info = InstanceInfo( + domain=domain, + nodeinfo=nodeinfo, + software_version=nodeinfo.software.get('version') if nodeinfo else None, + instance_actor=instance_actor, + shared_inbox=instance_actor.get('endpoints', {}).get('sharedInbox') + if instance_actor else None, + endpoints=endpoints, + features=features, + last_updated=datetime.utcnow() + ) + + # Cache result + await self.cache.set(cache_key, info) + return info + + except Exception as e: + logger.error(f"Instance discovery failed for {domain}: {e}") + raise DiscoveryError(f"Instance discovery failed: {e}") + + async def discover_nodeinfo(self, domain: str) -> Optional[NodeInfo]: + """ + Discover NodeInfo data. + + Implements NodeInfo 2.0 and 2.1 discovery. 
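+
+        The flow is: fetch https://{domain}/.well-known/nodeinfo, pick the
+        highest supported schema link, then fetch the linked document. A
+        typical well-known response looks like (illustrative):
+
+            {"links": [{"rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
+                        "href": "https://example.com/nodeinfo/2.1"}]}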
+ """ + try: + # Try well-known location first + well_known_url = f"https://{domain}/.well-known/nodeinfo" + async with self.session.get(well_known_url) as response: + if response.status != 200: + return None + + links = await response.json() + nodeinfo_url = None + + # Find highest supported version + for link in links.get('links', []): + if link.get('rel') == 'http://nodeinfo.diaspora.software/ns/schema/2.1': + nodeinfo_url = link.get('href') + break + elif link.get('rel') == 'http://nodeinfo.diaspora.software/ns/schema/2.0': + nodeinfo_url = link.get('href') + + if not nodeinfo_url: + return None + + # Fetch NodeInfo + async with self.session.get(nodeinfo_url) as nodeinfo_response: + if nodeinfo_response.status != 200: + return None + + data = await nodeinfo_response.json() + return NodeInfo( + version=data.get('version', '2.0'), + software=data.get('software', {}), + protocols=data.get('protocols', []), + services=data.get('services', {}), + usage=data.get('usage', {}), + open_registrations=data.get('openRegistrations', False), + metadata=data.get('metadata', {}) + ) + + except Exception as e: + logger.error(f"NodeInfo discovery failed for {domain}: {e}") + return None + + async def discover_instance_actor(self, domain: str) -> Optional[Dict[str, Any]]: + """ + Discover instance actor. + + Tries multiple common locations. + """ + try: + locations = [ + f"https://{domain}/actor", + f"https://{domain}/instance", + f"https://{domain}/instance/actor", + f"https://{domain}/" + ] + + headers = { + "Accept": "application/activity+json" + } + + for url in locations: + try: + async with self.session.get(url, headers=headers) as response: + if response.status == 200: + data = await response.json() + if data.get('type') in ['Application', 'Service']: + return data + except: + continue + + return None + + except Exception as e: + logger.error(f"Instance actor discovery failed for {domain}: {e}") + return None + + async def discover_endpoints(self, domain: str) -> Dict[str, str]: + """ + Discover instance endpoints. + + Finds common ActivityPub endpoints. + """ + endpoints = {} + base_url = f"https://{domain}" + + # Common endpoint paths + paths = { + 'inbox': '/inbox', + 'outbox': '/outbox', + 'following': '/following', + 'followers': '/followers', + 'featured': '/featured', + 'shared_inbox': '/inbox', + 'nodeinfo': '/.well-known/nodeinfo', + 'webfinger': '/.well-known/webfinger' + } + + for name, path in paths.items(): + url = urljoin(base_url, path) + try: + async with self.session.head(url) as response: + if response.status != 404: + endpoints[name] = url + except: + continue + + return endpoints + + async def discover_features(self, domain: str) -> Dict[str, bool]: + """ + Discover supported features. + + Checks for various federation features. 
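+
+        Returns a mapping of feature name to boolean, for example:
+
+            {
+                "activitypub": True,
+                "webfinger": True,
+                "nodeinfo": True,
+                "shared_inbox": True,
+                "collections": False,
+                "media_proxy": False
+            }
+
+        The probes are heuristic (HEAD requests against well-known paths), so a
+        False value means "not detected" rather than "unsupported".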
+ """ + features = { + 'activitypub': False, + 'webfinger': False, + 'nodeinfo': False, + 'shared_inbox': False, + 'collections': False, + 'media_proxy': False + } + + # Check WebFinger + try: + webfinger_url = f"https://{domain}/.well-known/webfinger?resource=acct:test@{domain}" + async with self.session.head(webfinger_url) as response: + features['webfinger'] = response.status != 404 + except: + pass + + # Check NodeInfo + try: + nodeinfo_url = f"https://{domain}/.well-known/nodeinfo" + async with self.session.head(nodeinfo_url) as response: + features['nodeinfo'] = response.status == 200 + except: + pass + + # Check shared inbox + try: + inbox_url = f"https://{domain}/inbox" + async with self.session.head(inbox_url) as response: + features['shared_inbox'] = response.status != 404 + except: + pass + + # Check collections + try: + collections = ['/following', '/followers', '/featured'] + features['collections'] = any( + await self._check_endpoint(domain, path) + for path in collections + ) + except: + pass + + # Check media proxy + try: + proxy_url = f"https://{domain}/proxy" + async with self.session.head(proxy_url) as response: + features['media_proxy'] = response.status != 404 + except: + pass + + # ActivityPub is supported if basic endpoints exist + features['activitypub'] = features['shared_inbox'] or features['collections'] + + return features + + async def _check_endpoint(self, domain: str, path: str) -> bool: + """Check if endpoint exists.""" + try: + url = f"https://{domain}{path}" + async with self.session.head(url) as response: + return response.status != 404 + except: + return False + + async def webfinger(self, + resource: str, + domain: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Perform WebFinger lookup. + + Args: + resource: Resource to look up (acct: or https:) + domain: Optional domain override + + Returns: + WebFinger response data + """ + try: + if not domain: + if resource.startswith('acct:'): + domain = resource.split('@')[1] + else: + domain = urlparse(resource).netloc + + url = ( + f"https://{domain}/.well-known/webfinger" + f"?resource={resource}" + ) + + async with self.session.get(url) as response: + if response.status != 200: + return None + + return await response.json() + + except Exception as e: + logger.error(f"WebFinger lookup failed for {resource}: {e}") + return None + async def close(self) -> None: + """Clean up resources.""" + if self.session: + await self.session.close() diff --git a/src/pyfed/federation/fetch.py b/src/pyfed/federation/fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..cab9e954be32cc540eeaa746e4b7c3722b5166a3 --- /dev/null +++ b/src/pyfed/federation/fetch.py @@ -0,0 +1,88 @@ +""" +federation/fetch.py +Remote resource fetching implementation. +""" + +from typing import Dict, Any, Optional, Union +import aiohttp +from urllib.parse import urlparse + +from pyfed.security.http_signatures import HTTPSignatureVerifier +from pyfed.cache.actor_cache import ActorCache +from pyfed.utils.exceptions import FetchError +from pyfed.utils.logging import get_logger + +logger = get_logger(__name__) + +class ResourceFetcher: + """ + Handles fetching remote ActivityPub resources. + + This class: + - Fetches remote actors and objects + - Handles HTTP signatures + - Uses caching when possible + - Validates responses + """ + + def __init__(self, + signature_verifier: HTTPSignatureVerifier, + actor_cache: Optional[ActorCache] = None, + timeout: int = 30): + """ + Initialize resource fetcher. 
+ + Args: + signature_verifier: HTTP signature verifier + actor_cache: Optional actor cache + timeout: Request timeout in seconds + """ + self.signature_verifier = signature_verifier + self.actor_cache = actor_cache + self.timeout = timeout + + async def fetch_resource(self, url: str) -> Dict[str, Any]: + """ + Fetch a remote resource. + + Args: + url: Resource URL + + Returns: + Resource data + + Raises: + FetchError: If fetch fails + """ + try: + # Check actor cache first + if self.actor_cache: + cached = await self.actor_cache.get(url) + if cached: + return cached + + # Sign request + headers = await self.signature_verifier.sign_request( + method='GET', + path=urlparse(url).path, + host=urlparse(url).netloc + ) + headers['Accept'] = 'application/activity+json' + + # Fetch resource + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session: + async with session.get(url, headers=headers) as response: + if response.status != 200: + raise FetchError(f"Failed to fetch {url}: {response.status}") + + data = await response.json() + + # Cache actor data + if self.actor_cache and data.get('type') in ('Person', 'Group', 'Organization', 'Service'): + await self.actor_cache.set(url, data) + + return data + + except Exception as e: + logger.error(f"Error fetching {url}: {e}") + raise FetchError(f"Failed to fetch {url}: {e}") \ No newline at end of file diff --git a/src/pyfed/federation/queue.py b/src/pyfed/federation/queue.py new file mode 100644 index 0000000000000000000000000000000000000000..47eaf7a8543362cd3887377c17c19b51bcd78150 --- /dev/null +++ b/src/pyfed/federation/queue.py @@ -0,0 +1,219 @@ +""" +federation/queue.py +Activity delivery queue implementation. +""" + +from typing import Dict, List, Any, Optional +from datetime import datetime, timedelta +import asyncio +import json +from dataclasses import dataclass +from enum import Enum +import aioredis + +from ..utils.exceptions import QueueError +from ..utils.logging import get_logger +from .delivery import ActivityDelivery, DeliveryResult + +logger = get_logger(__name__) + +class DeliveryStatus(Enum): + """Delivery status states.""" + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + RETRYING = "retrying" + +@dataclass +class QueuedDelivery: + """Queued delivery item.""" + id: str + activity: Dict[str, Any] + recipients: List[str] + status: DeliveryStatus + attempts: int + next_attempt: Optional[datetime] + created_at: datetime + updated_at: datetime + error: Optional[str] = None + +class DeliveryQueue: + """Activity delivery queue.""" + + def __init__(self, + delivery_service: ActivityDelivery, + redis_url: str = "redis://localhost", + max_attempts: int = 5, + batch_size: int = 20): + self.delivery_service = delivery_service + self.redis_url = redis_url + self.max_attempts = max_attempts + self.batch_size = batch_size + self.redis: Optional[aioredis.Redis] = None + self._processing_task = None + + async def initialize(self) -> None: + """Initialize queue.""" + try: + self.redis = await aioredis.from_url(self.redis_url) + self._processing_task = asyncio.create_task(self._process_queue_loop()) + except Exception as e: + logger.error(f"Failed to initialize queue: {e}") + raise QueueError(f"Queue initialization failed: {e}") + + async def enqueue(self, + activity: Dict[str, Any], + recipients: List[str], + priority: int = 0) -> str: + """ + Queue activity for delivery. 
+ + Args: + activity: Activity to deliver + recipients: List of recipient inboxes + priority: Delivery priority (0-9, higher is more urgent) + + Returns: + Delivery ID + """ + try: + # Create delivery record + delivery_id = f"delivery_{datetime.utcnow().timestamp()}" + delivery = QueuedDelivery( + id=delivery_id, + activity=activity, + recipients=recipients, + status=DeliveryStatus.PENDING, + attempts=0, + next_attempt=datetime.utcnow(), + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + + # Store in Redis + await self.redis.set( + f"delivery:{delivery_id}", + json.dumps(delivery.__dict__) + ) + + # Add to priority queue + score = datetime.utcnow().timestamp() + (9 - priority) * 10 + await self.redis.zadd( + "delivery_queue", + {delivery_id: score} + ) + + logger.info(f"Queued delivery {delivery_id} with priority {priority}") + return delivery_id + + except Exception as e: + logger.error(f"Failed to enqueue delivery: {e}") + raise QueueError(f"Failed to enqueue delivery: {e}") + + async def get_status(self, delivery_id: str) -> Optional[QueuedDelivery]: + """Get delivery status.""" + try: + data = await self.redis.get(f"delivery:{delivery_id}") + if data: + return QueuedDelivery(**json.loads(data)) + return None + except Exception as e: + logger.error(f"Failed to get delivery status: {e}") + raise QueueError(f"Failed to get delivery status: {e}") + + async def _process_queue_loop(self) -> None: + """Background task for processing queue.""" + while True: + try: + # Get batch of deliveries + now = datetime.utcnow().timestamp() + delivery_ids = await self.redis.zrangebyscore( + "delivery_queue", + "-inf", + now, + start=0, + num=self.batch_size + ) + + if not delivery_ids: + await asyncio.sleep(1) + continue + + # Process deliveries + for delivery_id in delivery_ids: + await self._process_delivery(delivery_id) + + except Exception as e: + logger.error(f"Queue processing error: {e}") + await asyncio.sleep(5) + + async def _process_delivery(self, delivery_id: str) -> None: + """Process a single delivery.""" + try: + # Get delivery data + data = await self.redis.get(f"delivery:{delivery_id}") + if not data: + return + + delivery = QueuedDelivery(**json.loads(data)) + + # Update status + delivery.status = DeliveryStatus.IN_PROGRESS + delivery.attempts += 1 + delivery.updated_at = datetime.utcnow() + + # Attempt delivery + try: + result = await self.delivery_service.deliver_activity( + activity=delivery.activity, + recipients=delivery.recipients + ) + + if result.success: + delivery.status = DeliveryStatus.COMPLETED + else: + delivery.status = DeliveryStatus.FAILED + delivery.error = result.error_message + + # Schedule retry if attempts remain + if delivery.attempts < self.max_attempts: + delivery.status = DeliveryStatus.RETRYING + delivery.next_attempt = datetime.utcnow() + timedelta( + minutes=2 ** delivery.attempts + ) + + # Re-queue with backoff + await self.redis.zadd( + "delivery_queue", + {delivery_id: delivery.next_attempt.timestamp()} + ) + + except Exception as e: + delivery.status = DeliveryStatus.FAILED + delivery.error = str(e) + + # Update delivery record + await self.redis.set( + f"delivery:{delivery_id}", + json.dumps(delivery.__dict__) + ) + + # Remove from queue if complete + if delivery.status in [DeliveryStatus.COMPLETED, DeliveryStatus.FAILED]: + await self.redis.zrem("delivery_queue", delivery_id) + + except Exception as e: + logger.error(f"Failed to process delivery {delivery_id}: {e}") + + async def close(self) -> None: + """Clean up resources.""" + if 
self._processing_task: + self._processing_task.cancel() + try: + await self._processing_task + except asyncio.CancelledError: + pass + + if self.redis: + await self.redis.close() \ No newline at end of file diff --git a/src/pyfed/federation/rate_limit.py b/src/pyfed/federation/rate_limit.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab669e0a0bf1468ed342ecbe0c8d590c3b9e4b7 --- /dev/null +++ b/src/pyfed/federation/rate_limit.py @@ -0,0 +1,343 @@ +""" +federation/rate_limit.py +Federation rate limiting implementation. + +Features: +- Per-domain rate limiting +- Multiple rate limit strategies +- Redis-backed storage +- Configurable limits +- Burst handling +""" + +from typing import Dict, Any, Optional, List, Union +from datetime import datetime, timedelta +import asyncio +import aioredis +from dataclasses import dataclass +from enum import Enum +import json + +from ..utils.exceptions import RateLimitError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class RateLimitStrategy(Enum): + """Rate limit strategies.""" + FIXED_WINDOW = "fixed_window" + SLIDING_WINDOW = "sliding_window" + TOKEN_BUCKET = "token_bucket" + LEAKY_BUCKET = "leaky_bucket" + +@dataclass +class RateLimit: + """Rate limit configuration.""" + requests: int + period: int # seconds + burst: Optional[int] = None + +@dataclass +class RateLimitState: + """Current rate limit state.""" + remaining: int + reset: datetime + limit: int + +class RateLimiter: + """Federation rate limiting.""" + + def __init__(self, + redis_url: str = "redis://localhost", + strategy: RateLimitStrategy = RateLimitStrategy.SLIDING_WINDOW, + default_limits: Optional[Dict[str, RateLimit]] = None): + """ + Initialize rate limiter. + + Args: + redis_url: Redis connection URL + strategy: Rate limiting strategy + default_limits: Default rate limits by action type + """ + self.redis_url = redis_url + self.strategy = strategy + self.redis: Optional[aioredis.Redis] = None + self.default_limits = default_limits or { + "inbox": RateLimit(requests=1000, period=3600, burst=100), # 1000/hour with burst + "outbox": RateLimit(requests=100, period=3600), # 100/hour + "follow": RateLimit(requests=50, period=3600), # 50/hour + "media": RateLimit(requests=200, period=3600) # 200/hour + } + self._cleanup_task = None + + async def initialize(self) -> None: + """Initialize rate limiter.""" + try: + self.redis = await aioredis.from_url(self.redis_url) + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info(f"Rate limiter initialized with strategy: {self.strategy.value}") + except Exception as e: + logger.error(f"Failed to initialize rate limiter: {e}") + raise RateLimitError(f"Rate limiter initialization failed: {e}") + + async def check_limit(self, + domain: str, + action: str, + cost: int = 1) -> RateLimitState: + """ + Check if domain is rate limited. 
+ + Args: + domain: Domain to check + action: Action type + cost: Action cost (default: 1) + + Returns: + RateLimitState with current limits + + Raises: + RateLimitError if limit exceeded + """ + try: + limit = self.default_limits.get(action) + if not limit: + raise RateLimitError(f"Unknown action type: {action}") + + key = f"ratelimit:{domain}:{action}" + + if self.strategy == RateLimitStrategy.SLIDING_WINDOW: + state = await self._check_sliding_window(key, limit, cost) + elif self.strategy == RateLimitStrategy.TOKEN_BUCKET: + state = await self._check_token_bucket(key, limit, cost) + elif self.strategy == RateLimitStrategy.LEAKY_BUCKET: + state = await self._check_leaky_bucket(key, limit, cost) + else: + state = await self._check_fixed_window(key, limit, cost) + + if state.remaining < 0: + raise RateLimitError( + f"Rate limit exceeded for {domain} ({action}). " + f"Reset at {state.reset.isoformat()}" + ) + + return state + + except RateLimitError: + raise + except Exception as e: + logger.error(f"Rate limit check failed: {e}") + raise RateLimitError(f"Rate limit check failed: {e}") + + async def _check_fixed_window(self, + key: str, + limit: RateLimit, + cost: int) -> RateLimitState: + """Fixed window rate limiting.""" + now = datetime.utcnow() + window_key = f"{key}:{int(now.timestamp() / limit.period)}" + + async with self.redis.pipeline() as pipe: + # Get current count + count = await self.redis.get(window_key) + current = int(count) if count else 0 + + if current + cost > limit.requests: + # Limit exceeded + expiry = ( + int(now.timestamp() / limit.period) + 1 + ) * limit.period + reset = datetime.fromtimestamp(expiry) + return RateLimitState( + remaining=-1, + reset=reset, + limit=limit.requests + ) + + # Update count + pipe.incrby(window_key, cost) + pipe.expire(window_key, limit.period) + await pipe.execute() + + return RateLimitState( + remaining=limit.requests - (current + cost), + reset=datetime.fromtimestamp( + (int(now.timestamp() / limit.period) + 1) * limit.period + ), + limit=limit.requests + ) + + async def _check_sliding_window(self, + key: str, + limit: RateLimit, + cost: int) -> RateLimitState: + """Sliding window rate limiting.""" + now = datetime.utcnow() + window_start = int((now - timedelta(seconds=limit.period)).timestamp() * 1000) + + # Remove old entries + await self.redis.zremrangebyscore(key, "-inf", window_start) + + # Get current count + count = await self.redis.zcount(key, window_start, "+inf") + + if count + cost > limit.requests: + # Get reset time from oldest remaining entry + oldest = await self.redis.zrange(key, 0, 0, withscores=True) + if oldest: + reset = datetime.fromtimestamp(oldest[0][1] / 1000 + limit.period) + else: + reset = now + timedelta(seconds=limit.period) + return RateLimitState( + remaining=-1, + reset=reset, + limit=limit.requests + ) + + # Add new entry + score = int(now.timestamp() * 1000) + await self.redis.zadd(key, {str(score): score}) + + return RateLimitState( + remaining=limit.requests - (count + cost), + reset=now + timedelta(seconds=limit.period), + limit=limit.requests + ) + + async def _check_token_bucket(self, + key: str, + limit: RateLimit, + cost: int) -> RateLimitState: + """Token bucket rate limiting.""" + now = datetime.utcnow() + bucket_key = f"{key}:bucket" + + # Get current bucket state + state = await self.redis.get(bucket_key) + if state: + tokens, last_update = json.loads(state) + else: + tokens = limit.burst or limit.requests + last_update = now.timestamp() + + # Add new tokens based on time passed + elapsed = 
now.timestamp() - last_update + new_tokens = int(elapsed * (limit.requests / limit.period)) + tokens = min(tokens + new_tokens, limit.burst or limit.requests) + + if tokens < cost: + # Not enough tokens + refill_time = (cost - tokens) * (limit.period / limit.requests) + reset = now + timedelta(seconds=refill_time) + return RateLimitState( + remaining=-1, + reset=reset, + limit=limit.requests + ) + + # Use tokens + tokens -= cost + await self.redis.set( + bucket_key, + json.dumps([tokens, now.timestamp()]), + ex=limit.period * 2 + ) + + return RateLimitState( + remaining=tokens, + reset=now + timedelta(seconds=limit.period), + limit=limit.requests + ) + + async def _check_leaky_bucket(self, + key: str, + limit: RateLimit, + cost: int) -> RateLimitState: + """Leaky bucket rate limiting.""" + now = datetime.utcnow() + bucket_key = f"{key}:leaky" + + # Get current bucket state + state = await self.redis.get(bucket_key) + if state: + level, last_update = json.loads(state) + else: + level = 0 + last_update = now.timestamp() + + # Leak based on time passed + elapsed = now.timestamp() - last_update + leak = elapsed * (limit.requests / limit.period) + level = max(0, level - leak) + + if level + cost > (limit.burst or limit.requests): + # Bucket would overflow + drain_time = ( + level + cost - (limit.burst or limit.requests) + ) * (limit.period / limit.requests) + reset = now + timedelta(seconds=drain_time) + return RateLimitState( + remaining=-1, + reset=reset, + limit=limit.requests + ) + + # Add to bucket + level += cost + await self.redis.set( + bucket_key, + json.dumps([level, now.timestamp()]), + ex=limit.period * 2 + ) + + return RateLimitState( + remaining=int((limit.burst or limit.requests) - level), + reset=now + timedelta(seconds=limit.period), + limit=limit.requests + ) + + async def record_request(self, + domain: str, + action: str, + cost: int = 1) -> None: + """ + Record a request for rate limiting. + + This is a convenience method that combines check_limit and recording. + """ + await self.check_limit(domain, action, cost) + + async def get_limits(self, domain: str) -> Dict[str, RateLimitState]: + """Get current rate limits for domain.""" + limits = {} + for action, limit in self.default_limits.items(): + try: + state = await self.check_limit(domain, action, 0) + limits[action] = state + except RateLimitError: + continue + return limits + + async def _cleanup_loop(self) -> None: + """Background task for cleaning up expired rate limit data.""" + while True: + try: + # Scan for expired keys + async for key in self.redis.scan_iter("ratelimit:*"): + if await self.redis.ttl(key) <= 0: + await self.redis.delete(key) + await asyncio.sleep(3600) # Run hourly + except Exception as e: + logger.error(f"Rate limit cleanup failed: {e}") + await asyncio.sleep(300) # Retry in 5 minutes + + async def close(self) -> None: + """Clean up resources.""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + if self.redis: + await self.redis.close() \ No newline at end of file diff --git a/src/pyfed/federation/resolver.py b/src/pyfed/federation/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..a66a31febdabeff278bcab8e8b3b390a70c0d092 --- /dev/null +++ b/src/pyfed/federation/resolver.py @@ -0,0 +1,122 @@ +""" +federation/resolver.py +ActivityPub resolver implementation. 
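+
+Usage sketch (assumes an initialized ActorCache and WebFingerService; the
+account and field names are illustrative):
+
+    resolver = ActivityPubResolver(actor_cache=cache, discovery_service=webfinger)
+    actor = await resolver.resolve_actor("alice@example.com")
+    if actor:
+        inbox = actor.get("inbox")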
+""" + +from typing import Dict, Any, Optional +import aiohttp +from urllib.parse import urlparse + +from ..utils.exceptions import ResolverError +from ..utils.logging import get_logger +from ..cache.actor_cache import ActorCache +from ..security.webfinger import WebFingerService + +logger = get_logger(__name__) + +class ActivityPubResolver: + """Resolves ActivityPub resources.""" + + def __init__(self, + actor_cache: Optional[ActorCache] = None, + discovery_service: Optional[WebFingerService] = None): + """ + Initialize resolver. + + Args: + actor_cache: Optional actor cache + discovery_service: Optional WebFinger service + """ + self.actor_cache = actor_cache + self.discovery_service = discovery_service + + async def resolve_actor(self, actor_id: str) -> Optional[Dict[str, Any]]: + """ + Resolve an actor by ID or account. + + Args: + actor_id: Actor ID or account (user@domain) + + Returns: + Actor data or None if not found + """ + try: + # Check cache first + if self.actor_cache: + cached = await self.actor_cache.get(actor_id) + if cached: + return cached + + # Try WebFinger if it's an account + if '@' in actor_id and self.discovery_service: + actor_url = await self.discovery_service.get_actor_url(actor_id) + if actor_url: + actor_id = actor_url + + # Fetch actor data + headers = { + "Accept": "application/activity+json", + "User-Agent": "PyFed/1.0" + } + + async with aiohttp.ClientSession() as session: + async with session.get(actor_id, headers=headers) as response: + if response.status != 200: + logger.error( + f"Failed to fetch actor {actor_id}: {response.status}" + ) + return None + + actor_data = await response.json() + + # Cache the result + if self.actor_cache: + await self.actor_cache.set(actor_id, actor_data) + + return actor_data + + except Exception as e: + logger.error(f"Failed to resolve actor {actor_id}: {e}") + return None + + async def resolve_object(self, object_id: str) -> Optional[Dict[str, Any]]: + """ + Resolve an object by ID. + + Args: + object_id: Object ID + + Returns: + Object data or None if not found + """ + try: + headers = { + "Accept": "application/activity+json", + "User-Agent": "PyFed/1.0" + } + + async with aiohttp.ClientSession() as session: + async with session.get(object_id, headers=headers) as response: + if response.status != 200: + logger.error( + f"Failed to fetch object {object_id}: {response.status}" + ) + return None + + return await response.json() + + except Exception as e: + logger.error(f"Failed to resolve object {object_id}: {e}") + return None + + async def resolve_activity(self, activity_id: str) -> Optional[Dict[str, Any]]: + """ + Resolve an activity by ID. + + Args: + activity_id: Activity ID + + Returns: + Activity data or None if not found + """ + return await self.resolve_object(activity_id) \ No newline at end of file diff --git a/src/pyfed/flow.drawio.png b/src/pyfed/flow.drawio.png new file mode 100644 index 0000000000000000000000000000000000000000..3eea5386b6a1991be6a4b4a562e867579719b429 Binary files /dev/null and b/src/pyfed/flow.drawio.png differ diff --git a/src/pyfed/handlers/__init__.py b/src/pyfed/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5be1266bcff62ce921c202d45716f5a7fa5e030c --- /dev/null +++ b/src/pyfed/handlers/__init__.py @@ -0,0 +1,27 @@ +""" +Activity handlers package for processing ActivityPub activities. 
+""" + +from .base import ActivityHandler +from .create import CreateHandler +from .follow import FollowHandler +from .like import LikeHandler +from .delete import DeleteHandler +from .announce import AnnounceHandler +from .update import UpdateHandler +from .undo import UndoHandler +from .accept import AcceptHandler +from .reject import RejectHandler + +__all__ = [ + 'ActivityHandler', + 'CreateHandler', + 'FollowHandler', + 'LikeHandler', + 'DeleteHandler', + 'AnnounceHandler', + 'UpdateHandler', + 'UndoHandler', + 'AcceptHandler', + 'RejectHandler' +] \ No newline at end of file diff --git a/src/pyfed/handlers/accept.py b/src/pyfed/handlers/accept.py new file mode 100644 index 0000000000000000000000000000000000000000..2ccbc964d18cd13306cdcc98df0b94ed91508d6a --- /dev/null +++ b/src/pyfed/handlers/accept.py @@ -0,0 +1,83 @@ +"""Accept activity handler.""" + +from typing import Dict, Any +from .base import ActivityHandler +from pyfed.utils.exceptions import HandlerError, ValidationError +from pyfed.utils.logging import get_logger + +logger = get_logger(__name__) + +class AcceptHandler(ActivityHandler): + """Handles Accept activities.""" + + ACCEPTABLE_TYPES = ['Follow'] + + async def validate(self, activity: Dict[str, Any]) -> None: + """Validate Accept activity.""" + if activity.get('type') != 'Accept': + raise ValidationError("Invalid activity type") + + object_data = activity.get('object') + if not object_data: + raise ValidationError("Accept must have an object") + + # Check if object is a dict with type + if isinstance(object_data, dict): + object_type = object_data.get('type') + if object_type not in self.ACCEPTABLE_TYPES: + raise ValidationError(f"Cannot accept activity type: {object_type}") + + async def handle(self, activity: Dict[str, Any]) -> None: + """Handle Accept activity.""" + try: + await self.validate(activity) + # await self.pre_handle(activity) + + actor = activity.get('actor') + object_data = activity.get('object') + + # Resolve object if it's a string ID + if isinstance(object_data, str): + try: + resolved_object = await self.resolve_object_data(object_data) + object_data = resolved_object + except HandlerError: + raise HandlerError("Could not resolve Follow activity") + + # Verify it's a Follow activity + if not isinstance(object_data, dict) or object_data.get('type') != 'Follow': + raise HandlerError("Invalid Follow activity format") + + # Verify authorization + if object_data.get('object') != actor: + raise HandlerError("Unauthorized: can only accept activities targeting self") + + # Handle Follow acceptance + await self._handle_accept_follow( + follower=object_data['actor'], + following=object_data['object'] + ) + + # Store accept activity + activity_id = await self.storage.create_activity(activity) + + # Notify original actor + if object_data.get('actor'): + await self.delivery.deliver_activity(activity, [object_data['actor']]) + + # await self.post_handle(activity) + + except ValidationError: + raise + except HandlerError: + raise + except Exception as e: + logger.error(f"Failed to handle Accept activity: {e}") + raise HandlerError(f"Failed to handle Accept activity: {e}") + + async def _handle_accept_follow(self, follower: str, following: str) -> None: + """Handle accepting a Follow.""" + await self.storage.confirm_follow( + follower=follower, + following=following + ) \ No newline at end of file diff --git a/src/pyfed/handlers/announce.py b/src/pyfed/handlers/announce.py new file mode 100644 index 
0000000000000000000000000000000000000000..4c5bf4c342cf184aea67b6d64adeae28e882aded --- /dev/null +++ b/src/pyfed/handlers/announce.py @@ -0,0 +1,102 @@ +""" +Announce activity handler implementation. +""" + +from typing import Dict, Any, Optional +from datetime import datetime + +from .base import ActivityHandler +from ..utils.exceptions import ValidationError, HandlerError +from ..models.activities import APAnnounce + +class AnnounceHandler(ActivityHandler): + """Handle Announce (Boost/Reblog) activities.""" + + async def validate(self, activity: Dict[str, Any]) -> None: + """Validate Announce activity.""" + try: + # Validate basic structure + if activity['type'] != 'Announce': + raise ValidationError("Invalid activity type") + + if 'object' not in activity: + raise ValidationError("Missing object") + + if 'actor' not in activity: + raise ValidationError("Missing actor") + + # Validate object exists + object_id = activity['object'] + obj = await self.resolver.resolve_object(object_id) + if not obj: + raise ValidationError(f"Object not found: {object_id}") + + # Validate actor can announce + actor = await self.resolver.resolve_actor(activity['actor']) + if not actor: + raise ValidationError(f"Actor not found: {activity['actor']}") + + except ValidationError: + raise + except Exception as e: + raise ValidationError(f"Validation failed: {e}") + + async def process(self, activity: Dict[str, Any]) -> Optional[str]: + """Process Announce activity.""" + try: + # Store announce activity + activity_id = await self.storage.create_activity(activity) + + # Update object shares collection + object_id = activity['object'] + await self._update_shares_collection(object_id, activity['actor']) + + # Update actor's shares collection + await self._update_actor_shares(activity['actor'], object_id) + + # Send notifications + await self._notify_object_owner(object_id, activity) + + # Deliver to followers + await self._deliver_to_followers(activity) + + return activity_id + + except Exception as e: + raise HandlerError(f"Failed to process Announce: {e}") + + async def _update_shares_collection(self, object_id: str, actor: str) -> None: + """Update object's shares collection.""" + obj = await self.storage.get_object(object_id) + if obj: + shares = obj.get('shares', []) + if actor not in shares: + shares.append(actor) + obj['shares'] = shares + await self.storage.update_object(object_id, obj) + + async def _update_actor_shares(self, actor_id: str, object_id: str) -> None: + """Update actor's shares collection.""" + actor = await self.storage.get_object(actor_id) + if actor: + shares = actor.get('shares', []) + if object_id not in shares: + shares.append(object_id) + actor['shares'] = shares + await self.storage.update_object(actor_id, actor) + + async def _notify_object_owner(self, object_id: str, activity: Dict[str, Any]) -> None: + """Notify object owner about announce.""" + obj = await self.storage.get_object(object_id) + if obj and obj.get('attributedTo'): + # Implementation for notification + pass + + async def _deliver_to_followers(self, activity: Dict[str, Any]) -> None: + """Deliver announce to followers.""" + actor = await self.storage.get_object(activity['actor']) + if actor and actor.get('followers'): + await self.delivery.deliver_activity( + activity=activity, + recipients=[actor['followers']] + ) \ No newline at end of file diff --git a/src/pyfed/handlers/base.py b/src/pyfed/handlers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..caa1400a3f0902017c5d88a53c2824d8a2e98b0f 
--- /dev/null +++ b/src/pyfed/handlers/base.py @@ -0,0 +1,101 @@ +""" +Base activity handler implementation. +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional +from datetime import datetime + +from ..utils.exceptions import HandlerError, ValidationError +from ..utils.logging import get_logger +from ..storage.base import StorageBackend +from ..federation.resolver import ActivityPubResolver +from ..federation.delivery import ActivityDelivery +# from ..monitoring.decorators import monitor, trace_async +from ..content.handler import ContentHandler # Instead of individual imports + +logger = get_logger(__name__) + +class ActivityHandler(ABC): + """Base activity handler.""" + + def __init__(self, + storage: StorageBackend, + resolver: ActivityPubResolver, + delivery: ActivityDelivery): + self.storage = storage + self.resolver = resolver + self.delivery = delivery + + @abstractmethod + async def validate(self, activity: Dict[str, Any]) -> None: + """ + Validate activity. + + Args: + activity: Activity to validate + + Raises: + ValidationError if validation fails + """ + pass + + @abstractmethod + async def process(self, activity: Dict[str, Any]) -> Optional[str]: + """ + Process activity. + + Args: + activity: Activity to process + + Returns: + Activity ID if successful + + Raises: + HandlerError if processing fails + """ + pass + + # @monitor("activity.handle") + # @trace_async("activity.handle") + async def handle(self, activity: Dict[str, Any]) -> Optional[str]: + """ + Handle activity. + + Args: + activity: Activity to handle + + Returns: + Activity ID if successful + + Raises: + HandlerError if handling fails + """ + try: + # Pre-handle operations + await self.pre_handle(activity) + + # Validate activity + await self.validate(activity) + + # Process activity + result = await self.process(activity) + + # Post-handle operations + await self.post_handle(activity) + + return result + + except (ValidationError, HandlerError): + raise + except Exception as e: + logger.error(f"Handler error: {e}") + raise HandlerError(f"Failed to handle activity: {e}") + + async def pre_handle(self, activity: Dict[str, Any]) -> None: + """Pre-handle operations.""" + pass + + async def post_handle(self, activity: Dict[str, Any]) -> None: + """Post-handle operations.""" + pass \ No newline at end of file diff --git a/src/pyfed/handlers/create.py b/src/pyfed/handlers/create.py new file mode 100644 index 0000000000000000000000000000000000000000..44178d0d06473e519cab699d9d7dae2e833e3480 --- /dev/null +++ b/src/pyfed/handlers/create.py @@ -0,0 +1,134 @@ +""" +Enhanced Create activity handler. +""" + +from typing import Dict, Any, Optional, List +from datetime import datetime + +from .base import ActivityHandler +from ..utils.exceptions import ValidationError, HandlerError +from ..models.activities import APCreate +from ..models.objects import APNote, APArticle, APImage + +from pyfed.utils.logging import get_logger + +logger = get_logger(__name__) + +class CreateHandler(ActivityHandler): + """Enhanced Create activity handler.""" + + SUPPORTED_TYPES = { + 'Note': APNote, + 'Article': APArticle, + 'Image': APImage + } + + async def validate(self, activity: Dict[str, Any]) -> None: + """ + Enhanced Create validation. 
+ + Validates: + - Activity structure + - Object type support + - Content rules + - Media attachments + - Rate limits + """ + try: + # Validate basic structure + create = APCreate.model_validate(activity) + + # Validate object + obj = activity.get('object', {}) + obj_type = obj.get('type') + + if not obj_type or obj_type not in self.SUPPORTED_TYPES: + raise ValidationError(f"Unsupported object type: {obj_type}") + + # Validate object model + model_class = self.SUPPORTED_TYPES[obj_type] + model_class.model_validate(obj) + + # Check actor permissions + actor = await self.resolver.resolve_actor(activity['actor']) + if not actor: + raise ValidationError("Actor not found") + + # Check rate limits + await self._check_rate_limits(actor['id']) + + # Validate content + await self._validate_content(obj) + + except Exception as e: + raise ValidationError(f"Create validation failed: {e}") + + async def process(self, activity: Dict[str, Any]) -> Optional[str]: + """ + Enhanced Create processing. + + Handles: + - Object storage + - Side effects + - Notifications + - Federation + """ + try: + # Store object + obj = activity['object'] + object_id = await self.storage.create_object(obj) + + # Store activity + activity_id = await self.storage.create_activity(activity) + + # Process mentions + if mentions := await self._extract_mentions(obj): + await self._handle_mentions(mentions, activity) + + # Process attachments + if attachments := obj.get('attachment', []): + await self._process_attachments(attachments, object_id) + + # Handle notifications + await self._send_notifications(activity) + + # Update collections + await self._update_collections(activity) + + return activity_id + + except Exception as e: + logger.error(f"Create processing failed: {e}") + raise HandlerError(f"Failed to process Create: {e}") + + async def _check_rate_limits(self, actor_id: str) -> None: + """Check rate limits for actor.""" + # Implementation here + + async def _validate_content(self, obj: Dict[str, Any]) -> None: + """Validate object content.""" + # Implementation here + + async def _extract_mentions(self, obj: Dict[str, Any]) -> List[str]: + """Extract mentions from object.""" + # Implementation here + + async def _handle_mentions(self, + mentions: List[str], + activity: Dict[str, Any]) -> None: + """Handle mentions in activity.""" + # Implementation here + + async def _process_attachments(self, + attachments: List[Dict[str, Any]], + object_id: str) -> None: + """Process media attachments.""" + # Implementation here + + async def _send_notifications(self, activity: Dict[str, Any]) -> None: + """Send notifications for activity.""" + # Implementation here + + async def _update_collections(self, activity: Dict[str, Any]) -> None: + """Update relevant collections.""" + # Implementation here \ No newline at end of file diff --git a/src/pyfed/handlers/delete.py b/src/pyfed/handlers/delete.py new file mode 100644 index 0000000000000000000000000000000000000000..142a7e6ee60fdd97c84d28c7ff6fbd3d71f5495f --- /dev/null +++ b/src/pyfed/handlers/delete.py @@ -0,0 +1,81 @@ +""" +Delete activity handler implementation. 
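+
+On success the referenced object is removed from storage and replaced by a
+Tombstone, roughly of the form (id is illustrative):
+
+    {
+        "type": "Tombstone",
+        "id": "https://example.com/notes/123",
+        "deleted": "2024-01-01T00:00:00"
+    }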
+""" + +from typing import Dict, Any, Optional +from datetime import datetime + +from .base import ActivityHandler +# from ..monitoring.decorators import monitor, trace_async +from ..utils.exceptions import ValidationError, HandlerError +from ..models.activities import APDelete + +class DeleteHandler(ActivityHandler): + """Handle Delete activities.""" + + # @monitor("handler.delete.validate") + async def validate(self, activity: Dict[str, Any]) -> None: + """Validate Delete activity.""" + try: + # Validate basic structure + delete = APDelete.model_validate(activity) + + # Verify actor permissions + actor = await self.resolver.resolve_actor(activity['actor']) + if not actor: + raise ValidationError("Actor not found") + + # Verify object ownership + object_id = activity.get('object') + if isinstance(object_id, dict): + object_id = object_id.get('id') + + stored_object = await self.storage.get_object(object_id) + if not stored_object: + raise ValidationError("Object not found") + + if stored_object.get('attributedTo') != activity['actor']: + raise ValidationError("Not authorized to delete this object") + + except Exception as e: + raise ValidationError(f"Delete validation failed: {e}") + + # @monitor("handler.delete.process") + async def process(self, activity: Dict[str, Any]) -> Optional[str]: + """Process Delete activity.""" + try: + # Store delete activity + activity_id = await self.storage.create_activity(activity) + + # Delete object + object_id = activity.get('object') + if isinstance(object_id, dict): + object_id = object_id.get('id') + + await self.storage.delete_object(object_id) + + # Handle tombstone + await self._create_tombstone(object_id) + + # Update collections + await self._update_collections(object_id) + + return activity_id + + except Exception as e: + raise HandlerError(f"Failed to process Delete: {e}") + + async def _create_tombstone(self, object_id: str) -> None: + """Create tombstone for deleted object.""" + tombstone = { + 'type': 'Tombstone', + 'id': object_id, + 'deleted': datetime.utcnow().isoformat() + } + await self.storage.create_object(tombstone) + + async def _update_collections(self, object_id: str) -> None: + """Update collections after deletion.""" + # Remove from collections + # Implementation here + \ No newline at end of file diff --git a/src/pyfed/handlers/follow.py b/src/pyfed/handlers/follow.py new file mode 100644 index 0000000000000000000000000000000000000000..45e042ec059f4e27dad83039586d907f6dab7e83 --- /dev/null +++ b/src/pyfed/handlers/follow.py @@ -0,0 +1,72 @@ +""" +Handler for Follow activities. 
+""" + +from typing import Dict, Any +from .base import ActivityHandler +from pyfed.utils.exceptions import HandlerError, ValidationError +from pyfed.utils.logging import get_logger + +logger = get_logger(__name__) + +class FollowHandler(ActivityHandler): + """Handles Follow activities.""" + + async def validate(self, activity: Dict[str, Any]) -> None: + """Validate Follow activity.""" + if activity.get('type') != 'Follow': + raise ValidationError("Invalid activity type") + + if not activity.get('actor'): + raise ValidationError("Follow must have an actor") + + if not activity.get('object'): + raise ValidationError("Follow must have an object") + + async def handle(self, activity: Dict[str, Any]) -> None: + """Handle Follow activity.""" + try: + await self.validate(activity) + # await self.pre_handle(activity) + + actor = activity.get('actor') + target = activity.get('object') + + # Resolve target actor + target_actor = await self.resolver.resolve_actor(target) + if not target_actor: + raise HandlerError(f"Could not resolve target actor: {target}") + + # Check if already following + if await self.storage.is_following(actor, target): + raise HandlerError("Already following this actor") + + # Store follow request + await self.storage.create_follow_request( + follower=actor, + following=target + ) + + # Store follow activity + activity_id = await self.storage.create_activity(activity) + + # Deliver to target's inbox + target_inbox = target_actor.get('inbox') + if not target_inbox: + raise HandlerError("Target actor has no inbox") + + await self.delivery.deliver_activity( + activity=activity, + recipients=[target_inbox] + ) + + # await self.post_handle(activity) + logger.info(f"Handled Follow activity: {actor} -> {target}") + + except ValidationError: + raise + except HandlerError: + raise + except Exception as e: + logger.error(f"Failed to handle Follow activity: {e}") + raise HandlerError(f"Failed to handle Follow activity: {e}") \ No newline at end of file diff --git a/src/pyfed/handlers/like.py b/src/pyfed/handlers/like.py new file mode 100644 index 0000000000000000000000000000000000000000..4e193a13d8bfe63323462598cfc0a8724f112593 --- /dev/null +++ b/src/pyfed/handlers/like.py @@ -0,0 +1,76 @@ +"""Like activity handler.""" + +from typing import Dict, Any +from .base import ActivityHandler +from pyfed.utils.exceptions import HandlerError, ValidationError +from pyfed.utils.logging import get_logger + +logger = get_logger(__name__) + +class LikeHandler(ActivityHandler): + """Handles Like activities.""" + + async def validate(self, activity: Dict[str, Any]) -> None: + """Validate Like activity.""" + if activity.get('type') != 'Like': + raise ValidationError("Invalid activity type") + + if not activity.get('actor'): + raise ValidationError("Like must have an actor") + + if not activity.get('object'): + raise ValidationError("Like must have an object") + + async def handle(self, activity: Dict[str, Any]) -> None: + """Handle Like activity.""" + try: + await self.validate(activity) + await self.pre_handle(activity) + + actor = activity.get('actor') + object_id = activity.get('object') + + # Resolve target object + object_data = await self.resolve_object_data(object_id) + if not object_data: + raise HandlerError(f"Could not resolve object: {object_id}") + + # Check for duplicate like + if await self.storage.has_liked(actor, object_id): + raise HandlerError("Already liked this object") + + # Store like + await self.storage.create_like( + actor=actor, + object_id=object_id + ) + + # Store like 
activity + activity_id = await self.storage.create_activity(activity) + + # Notify object creator + await self._notify_object_owner(object_data, activity) + + await self.post_handle(activity) + logger.info(f"Handled Like activity: {actor} -> {object_id}") + + except ValidationError: + raise + except HandlerError: + raise + except Exception as e: + logger.error(f"Failed to handle Like activity: {e}") + raise HandlerError(f"Failed to handle Like activity: {e}") + + async def _notify_object_owner(self, object_data: Dict[str, Any], activity: Dict[str, Any]) -> None: + """Notify object owner about the like.""" + if object_data.get('attributedTo'): + target_actor = await self.resolver.resolve_actor( + object_data['attributedTo'] + ) + if target_actor and target_actor.get('inbox'): + await self.delivery.deliver_activity( + activity=activity, + recipients=[target_actor['inbox']] + ) + diff --git a/src/pyfed/handlers/reject.py b/src/pyfed/handlers/reject.py new file mode 100644 index 0000000000000000000000000000000000000000..a5cea9526c75c4359b840acd75bdb287a7bbdc39 --- /dev/null +++ b/src/pyfed/handlers/reject.py @@ -0,0 +1,94 @@ +""" +Handler for Reject activities. +""" + +from typing import Dict, Any +from .base import ActivityHandler +from pyfed.utils.exceptions import HandlerError, ValidationError +from pyfed.utils.logging import get_logger + +logger = get_logger(__name__) + +class RejectHandler(ActivityHandler): + """Handles Reject activities.""" + + REJECTABLE_TYPES = ['Follow'] + + async def handle(self, activity: Dict[str, Any]) -> None: + """Handle Reject activity.""" + try: + await self.validate(activity) + await self.pre_handle(activity) + + actor = activity.get('actor') + object_data = activity.get('object') + + # Resolve object if it's a string ID + if isinstance(object_data, str): + resolved_object = await self.resolve_object_data(object_data) + if not resolved_object: + raise HandlerError("Could not resolve activity") + object_data = resolved_object + + # Verify authorization + if object_data.get('object') != actor: + raise HandlerError("Unauthorized: can only reject activities targeting self") + + # Handle based on rejected activity type + if object_data.get('type') == 'Follow': + # Check if follow request exists and hasn't been handled + status = await self.storage.get_follow_request_status( + follower=object_data['actor'], + following=object_data['object'] + ) + + if status == 'rejected': + raise HandlerError("Follow request already rejected") + elif status == 'accepted': + raise HandlerError("Follow request already accepted") + + await self._handle_reject_follow( + follower=object_data['actor'], + following=object_data['object'], + reason=activity.get('content') + ) + + # Store reject activity + activity_id = await self.storage.create_activity(activity) + + # Notify original actor + if object_data.get('actor'): + await self.delivery.deliver_activity(activity, [object_data['actor']]) + + await self.post_handle(activity) + logger.info(f"Handled Reject activity: {actor} -> {object_data.get('id')}") + + except ValidationError: + raise + except HandlerError: + raise + except Exception as e: + logger.error(f"Failed to handle Reject activity: {e}") + raise HandlerError(f"Failed to handle Reject activity: {e}") + + async def _handle_reject_follow(self, follower: str, following: str, reason: str = None) -> None: + """Handle rejecting a Follow.""" + await self.storage.remove_follow_request( + follower=follower, + following=following, + reason=reason + ) + + async def validate(self, 
activity: Dict[str, Any]) -> None: + """Validate Reject activity.""" + if activity.get('type') != 'Reject': + raise ValidationError("Invalid activity type") + + object_data = activity.get('object') + if not object_data: + raise ValidationError("Reject must have an object") + + if isinstance(object_data, dict): + object_type = object_data.get('type') + if object_type not in self.REJECTABLE_TYPES: + raise ValidationError(f"Cannot reject activity type: {object_type}") \ No newline at end of file diff --git a/src/pyfed/handlers/undo.py b/src/pyfed/handlers/undo.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dd3ce263e9c76321375f03be24a94be443fe36 --- /dev/null +++ b/src/pyfed/handlers/undo.py @@ -0,0 +1,150 @@ +""" +Undo activity handler implementation. +""" + +from typing import Dict, Any, Optional +from datetime import datetime + +from .base import ActivityHandler +from ..utils.exceptions import ValidationError, HandlerError +from ..models.activities import APUndo + +class UndoHandler(ActivityHandler): + """Handle Undo activities.""" + + async def validate(self, activity: Dict[str, Any]) -> None: + """Validate Undo activity.""" + try: + # Validate basic structure + if activity['type'] != 'Undo': + raise ValidationError("Invalid activity type") + + if 'object' not in activity: + raise ValidationError("Missing object") + + if 'actor' not in activity: + raise ValidationError("Missing actor") + + # Validate object exists + obj = activity['object'] + if isinstance(obj, str): + obj = await self.storage.get_activity(obj) + if not obj: + raise ValidationError(f"Activity not found: {activity['object']}") + + # Validate actor has permission + if obj.get('actor') != activity['actor']: + raise ValidationError("Not authorized to undo activity") + + # Validate activity can be undone + if obj['type'] not in ['Like', 'Announce', 'Follow', 'Block']: + raise ValidationError(f"Cannot undo activity type: {obj['type']}") + + except ValidationError: + raise + except Exception as e: + raise ValidationError(f"Validation failed: {e}") + + async def process(self, activity: Dict[str, Any]) -> Optional[str]: + """Process Undo activity.""" + try: + # Store undo activity + activity_id = await self.storage.create_activity(activity) + + # Get original activity + obj = activity['object'] + if isinstance(obj, str): + obj = await self.storage.get_activity(obj) + + # Process based on activity type + if obj['type'] == 'Like': + await self._undo_like(obj) + elif obj['type'] == 'Announce': + await self._undo_announce(obj) + elif obj['type'] == 'Follow': + await self._undo_follow(obj) + elif obj['type'] == 'Block': + await self._undo_block(obj) + + # Deliver undo to recipients + await self._deliver_undo(activity) + + return activity_id + + except Exception as e: + raise HandlerError(f"Failed to process Undo: {e}") + + async def _undo_like(self, activity: Dict[str, Any]) -> None: + """Undo a Like activity.""" + object_id = activity['object'] + actor = activity['actor'] + + obj = await self.storage.get_object(object_id) + if obj: + likes = obj.get('likes', []) + if actor in likes: + likes.remove(actor) + obj['likes'] = likes + await self.storage.update_object(object_id, obj) + + async def _undo_announce(self, activity: Dict[str, Any]) -> None: + """Undo an Announce activity.""" + object_id = activity['object'] + actor = activity['actor'] + + obj = await self.storage.get_object(object_id) + if obj: + shares = obj.get('shares', []) + if actor in shares: + shares.remove(actor) + obj['shares'] = shares + await 
self.storage.update_object(object_id, obj) + + async def _undo_follow(self, activity: Dict[str, Any]) -> None: + """Undo a Follow activity.""" + object_id = activity['object'] + actor = activity['actor'] + + # Remove from following collection + actor_obj = await self.storage.get_object(actor) + if actor_obj: + following = actor_obj.get('following', []) + if object_id in following: + following.remove(object_id) + actor_obj['following'] = following + await self.storage.update_object(actor, actor_obj) + + # Remove from followers collection + target = await self.storage.get_object(object_id) + if target: + followers = target.get('followers', []) + if actor in followers: + followers.remove(actor) + target['followers'] = followers + await self.storage.update_object(object_id, target) + + async def _undo_block(self, activity: Dict[str, Any]) -> None: + """Undo a Block activity.""" + object_id = activity['object'] + actor = activity['actor'] + + actor_obj = await self.storage.get_object(actor) + if actor_obj: + blocks = actor_obj.get('blocks', []) + if object_id in blocks: + blocks.remove(object_id) + actor_obj['blocks'] = blocks + await self.storage.update_object(actor, actor_obj) + + async def _deliver_undo(self, activity: Dict[str, Any]) -> None: + """Deliver undo to recipients.""" + obj = activity['object'] + if isinstance(obj, str): + obj = await self.storage.get_activity(obj) + + # Deliver to original recipients + if obj and obj.get('to'): + await self.delivery.deliver_activity( + activity=activity, + recipients=obj['to'] + ) \ No newline at end of file diff --git a/src/pyfed/handlers/update.py b/src/pyfed/handlers/update.py new file mode 100644 index 0000000000000000000000000000000000000000..63e7a70f69f9f7e27c061bb127d9bcf0484d7370 --- /dev/null +++ b/src/pyfed/handlers/update.py @@ -0,0 +1,96 @@ +""" +Update activity handler implementation. 
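+
+The handler preserves immutable fields (id, type, attributedTo, published) from
+the stored object and stamps an "updated" timestamp, so an updated Note ends up
+roughly like this (values are illustrative):
+
+    {
+        "id": "https://example.com/notes/123",
+        "type": "Note",
+        "attributedTo": "https://example.com/users/alice",
+        "published": "2024-01-01T00:00:00",
+        "content": "edited text",
+        "updated": "2024-01-02T12:00:00"
+    }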
+""" + +from typing import Dict, Any, Optional +from datetime import datetime + +from .base import ActivityHandler +from ..utils.exceptions import ValidationError, HandlerError +from ..models.activities import APUpdate + +class UpdateHandler(ActivityHandler): + """Handle Update activities.""" + + async def validate(self, activity: Dict[str, Any]) -> None: + """Validate Update activity.""" + try: + # Validate basic structure + if activity['type'] != 'Update': + raise ValidationError("Invalid activity type") + + if 'object' not in activity: + raise ValidationError("Missing object") + + if 'actor' not in activity: + raise ValidationError("Missing actor") + + # Validate object exists + obj = activity['object'] + if not isinstance(obj, dict) or 'id' not in obj: + raise ValidationError("Invalid object format") + + existing = await self.storage.get_object(obj['id']) + if not existing: + raise ValidationError(f"Object not found: {obj['id']}") + + # Validate actor has permission + if existing.get('attributedTo') != activity['actor']: + raise ValidationError("Not authorized to update object") + + except ValidationError: + raise + except Exception as e: + raise ValidationError(f"Validation failed: {e}") + + async def process(self, activity: Dict[str, Any]) -> Optional[str]: + """Process Update activity.""" + try: + # Store update activity + activity_id = await self.storage.create_activity(activity) + + # Update object + obj = activity['object'] + obj['updated'] = datetime.utcnow().isoformat() + + # Preserve immutable fields + existing = await self.storage.get_object(obj['id']) + for field in ['id', 'type', 'attributedTo', 'published']: + obj[field] = existing[field] + + # Store updated object + await self.storage.update_object(obj['id'], obj) + + # Deliver update to recipients + await self._deliver_update(activity) + + return activity_id + + except Exception as e: + raise HandlerError(f"Failed to process Update: {e}") + + async def _deliver_update(self, activity: Dict[str, Any]) -> None: + """Deliver update to recipients.""" + obj = activity['object'] + + # Collect recipients + recipients = [] + + # Add mentioned users + if 'tag' in obj: + for tag in obj['tag']: + if tag.get('type') == 'Mention': + recipients.append(tag['href']) + + # Add followers if public + actor = await self.storage.get_object(activity['actor']) + if actor and actor.get('followers'): + if 'public' in obj.get('to', []): + recipients.append(actor['followers']) + + # Deliver activity + if recipients: + await self.delivery.deliver_activity( + activity=activity, + recipients=recipients + ) \ No newline at end of file diff --git a/src/pyfed/integration/base.py b/src/pyfed/integration/base.py new file mode 100644 index 0000000000000000000000000000000000000000..90065272ed14f8f228af70c3325035312c00b5b1 --- /dev/null +++ b/src/pyfed/integration/base.py @@ -0,0 +1,76 @@ +""" +Base integration interfaces. 
+""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List +from dataclasses import dataclass +import yaml +import json +from pathlib import Path + +from ..utils.exceptions import IntegrationError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +@dataclass +class IntegrationConfig: + """Integration configuration.""" + domain: str + database_url: str + redis_url: str + media_path: str + key_path: str + max_payload_size: int = 5_000_000 # 5MB + request_timeout: int = 30 + debug: bool = False + +class BaseIntegration(ABC): + """Base integration interface.""" + + def __init__(self, config: IntegrationConfig): + self.config = config + self.app = None + self.storage = None + self.delivery = None + self.key_manager = None + self.instance = None + + @abstractmethod + async def initialize(self) -> None: + """Initialize integration.""" + pass + + @abstractmethod + async def shutdown(self) -> None: + """Shutdown integration.""" + pass + + @abstractmethod + async def handle_activity(self, activity: Dict[str, Any]) -> Optional[str]: + """Handle incoming activity.""" + pass + + @abstractmethod + async def deliver_activity(self, + activity: Dict[str, Any], + recipients: List[str]) -> None: + """Deliver activity to recipients.""" + pass + + @classmethod + def load_config(cls, config_path: str) -> IntegrationConfig: + """Load configuration from file.""" + try: + with open(config_path) as f: + if config_path.endswith('.yaml') or config_path.endswith('.yml'): + data = yaml.safe_load(f) + else: + data = json.load(f) + + return IntegrationConfig(**data) + + except Exception as e: + logger.error(f"Failed to load config: {e}") + raise IntegrationError(f"Failed to load config: {e}") \ No newline at end of file diff --git a/src/pyfed/integration/config.py b/src/pyfed/integration/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a57cb813e39d188b99e82752905ff725af093060 --- /dev/null +++ b/src/pyfed/integration/config.py @@ -0,0 +1,177 @@ +""" +Configuration management implementation. 
+""" + +from typing import Dict, Any, Optional +import yaml +import json +from pathlib import Path +import os +from dataclasses import dataclass, asdict + +from ..utils.exceptions import ConfigError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +@dataclass +class DatabaseConfig: + """Database configuration.""" + url: str + min_connections: int = 5 + max_connections: int = 20 + timeout: int = 30 + +@dataclass +class RedisConfig: + """Redis configuration.""" + url: str + pool_size: int = 10 + timeout: int = 30 + +@dataclass +class SecurityConfig: + """Security configuration.""" + key_path: str + signature_ttl: int = 300 + max_payload_size: int = 5_000_000 + allowed_algorithms: list = None + + def __post_init__(self): + if self.allowed_algorithms is None: + self.allowed_algorithms = ["rsa-sha256"] + +@dataclass +class FederationConfig: + """Federation configuration.""" + domain: str + shared_inbox: bool = True + delivery_timeout: int = 30 + max_recipients: int = 100 + retry_delay: int = 300 + +@dataclass +class MediaConfig: + """Media configuration.""" + upload_path: str + max_size: int = 10_000_000 + allowed_types: list = None + + def __post_init__(self): + if self.allowed_types is None: + self.allowed_types = [ + 'image/jpeg', + 'image/png', + 'image/gif', + 'video/mp4', + 'audio/mpeg' + ] + +@dataclass +class ApplicationConfig: + """Application configuration.""" + database: DatabaseConfig + redis: RedisConfig + security: SecurityConfig + federation: FederationConfig + media: MediaConfig + debug: bool = False + +class ConfigurationManager: + """Manage application configuration.""" + + def __init__(self, config_path: Optional[str] = None): + self.config_path = config_path + self.config = None + + def load_config(self) -> ApplicationConfig: + """Load configuration.""" + try: + # Load from file if specified + if self.config_path: + return self._load_from_file(self.config_path) + + # Load from environment + return self._load_from_env() + + except Exception as e: + logger.error(f"Failed to load config: {e}") + raise ConfigError(f"Failed to load config: {e}") + + def _load_from_file(self, path: str) -> ApplicationConfig: + """Load configuration from file.""" + try: + with open(path) as f: + if path.endswith('.yaml') or path.endswith('.yml'): + data = yaml.safe_load(f) + else: + data = json.load(f) + + return self._create_config(data) + + except Exception as e: + raise ConfigError(f"Failed to load config file: {e}") + + def _load_from_env(self) -> ApplicationConfig: + """Load configuration from environment variables.""" + try: + return ApplicationConfig( + database=DatabaseConfig( + url=os.getenv('DATABASE_URL', 'sqlite:///pyfed.db'), + min_connections=int(os.getenv('DB_MIN_CONNECTIONS', '5')), + max_connections=int(os.getenv('DB_MAX_CONNECTIONS', '20')), + timeout=int(os.getenv('DB_TIMEOUT', '30')) + ), + redis=RedisConfig( + url=os.getenv('REDIS_URL', 'redis://localhost'), + pool_size=int(os.getenv('REDIS_POOL_SIZE', '10')), + timeout=int(os.getenv('REDIS_TIMEOUT', '30')) + ), + security=SecurityConfig( + key_path=os.getenv('KEY_PATH', 'keys'), + signature_ttl=int(os.getenv('SIGNATURE_TTL', '300')), + max_payload_size=int(os.getenv('MAX_PAYLOAD_SIZE', '5000000')) + ), + federation=FederationConfig( + domain=os.getenv('DOMAIN', 'localhost'), + shared_inbox=os.getenv('SHARED_INBOX', 'true').lower() == 'true', + delivery_timeout=int(os.getenv('DELIVERY_TIMEOUT', '30')), + max_recipients=int(os.getenv('MAX_RECIPIENTS', '100')) + ), + media=MediaConfig( + 
upload_path=os.getenv('UPLOAD_PATH', 'uploads'), + max_size=int(os.getenv('MAX_UPLOAD_SIZE', '10000000')) + ), + debug=os.getenv('DEBUG', 'false').lower() == 'true' + ) + + except Exception as e: + raise ConfigError(f"Failed to load config from env: {e}") + + def _create_config(self, data: Dict[str, Any]) -> ApplicationConfig: + """Create config from dictionary.""" + try: + return ApplicationConfig( + database=DatabaseConfig(**data.get('database', {})), + redis=RedisConfig(**data.get('redis', {})), + security=SecurityConfig(**data.get('security', {})), + federation=FederationConfig(**data.get('federation', {})), + media=MediaConfig(**data.get('media', {})), + debug=data.get('debug', False) + ) + except Exception as e: + raise ConfigError(f"Invalid config data: {e}") + + def save_config(self, config: ApplicationConfig, path: str) -> None: + """Save configuration to file.""" + try: + data = asdict(config) + + with open(path, 'w') as f: + if path.endswith('.yaml') or path.endswith('.yml'): + yaml.dump(data, f, default_flow_style=False) + else: + json.dump(data, f, indent=2) + + except Exception as e: + raise ConfigError(f"Failed to save config: {e}") \ No newline at end of file diff --git a/src/pyfed/integration/frameworks/django.py b/src/pyfed/integration/frameworks/django.py new file mode 100644 index 0000000000000000000000000000000000000000..257ede31b3f479f8aa519fe0640c46e211136b3b --- /dev/null +++ b/src/pyfed/integration/frameworks/django.py @@ -0,0 +1,154 @@ +""" +Django integration implementation. +""" + +from typing import Dict, Any, Optional, List +from django.http import HttpRequest, JsonResponse, HttpResponseBadRequest +from django.views import View +from django.conf import settings +import json +import asyncio + +from ..base import BaseIntegration, IntegrationConfig +from ...utils.exceptions import IntegrationError +from ...utils.logging import get_logger +from ...storage import StorageBackend +from ...federation.delivery import ActivityDelivery +from ...security.key_management import KeyManager +from ..middleware import ActivityPubMiddleware + +logger = get_logger(__name__) + +class DjangoIntegration(BaseIntegration): + """Django integration.""" + + def __init__(self, config: IntegrationConfig): + super().__init__(config) + self.middleware = None + self._setup_views() + + def _setup_views(self) -> None: + """Setup Django views.""" + + class InboxView(View): + """Shared inbox view.""" + + async def post(self, request: HttpRequest): + try: + # Verify request + if not await self.middleware.process_request( + method=request.method, + path=request.path, + headers=dict(request.headers), + body=json.loads(request.body) + ): + return JsonResponse( + {"error": "Unauthorized"}, + status=401 + ) + + # Handle activity + activity = json.loads(request.body) + result = await self.handle_activity(activity) + + return JsonResponse( + {"id": result}, + status=202, + content_type="application/activity+json" + ) + + except Exception as e: + logger.error(f"Inbox error: {e}") + return JsonResponse( + {"error": str(e)}, + status=500 + ) + + class ActorView(View): + """Instance actor view.""" + + async def get(self, request: HttpRequest): + try: + return JsonResponse( + self.instance.actor, + content_type="application/activity+json" + ) + except Exception as e: + logger.error(f"Actor error: {e}") + return JsonResponse( + {"error": str(e)}, + status=500 + ) + + class NodeInfoView(View): + """NodeInfo view.""" + + async def get(self, request: HttpRequest): + try: + nodeinfo = await 
self.instance.get_nodeinfo() + return JsonResponse(nodeinfo) + except Exception as e: + logger.error(f"NodeInfo error: {e}") + return JsonResponse( + {"error": str(e)}, + status=500 + ) + + # Store view classes + self.views = { + 'inbox': InboxView, + 'actor': ActorView, + 'nodeinfo': NodeInfoView + } + + async def initialize(self) -> None: + """Initialize integration.""" + try: + # Initialize components + self.storage = StorageBackend.create( + provider="postgresql", + database_url=self.config.database_url + ) + await self.storage.initialize() + + self.key_manager = KeyManager(self.config.key_path) + await self.key_manager.initialize() + + self.delivery = ActivityDelivery( + storage=self.storage, + key_manager=self.key_manager + ) + await self.delivery.initialize() + + # Initialize middleware + self.middleware = ActivityPubMiddleware( + signature_verifier=self.key_manager.signature_verifier, + rate_limiter=self.delivery.rate_limiter + ) + + logger.info("Django integration initialized") + + except Exception as e: + logger.error(f"Failed to initialize Django integration: {e}") + raise IntegrationError(f"Integration initialization failed: {e}") + + async def shutdown(self) -> None: + """Shutdown integration.""" + try: + await self.storage.close() + await self.delivery.close() + await self.key_manager.close() + + except Exception as e: + logger.error(f"Failed to shutdown Django integration: {e}") + raise IntegrationError(f"Integration shutdown failed: {e}") + + async def handle_activity(self, activity: Dict[str, Any]) -> Optional[str]: + """Handle incoming activity.""" + return await self.delivery.process_activity(activity) + + async def deliver_activity(self, + activity: Dict[str, Any], + recipients: List[str]) -> None: + """Deliver activity to recipients.""" + await self.delivery.deliver_activity(activity, recipients) \ No newline at end of file diff --git a/src/pyfed/integration/frameworks/fastapi.py b/src/pyfed/integration/frameworks/fastapi.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9ae5086885c3b3be97bd6f14d93c9c99db9fd5 --- /dev/null +++ b/src/pyfed/integration/frameworks/fastapi.py @@ -0,0 +1,145 @@ +""" +FastAPI integration implementation. 
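+
+Typical usage (illustrative sketch; serving the app with uvicorn is an
+assumption, not part of this module):
+
+    integration = FastAPIIntegration(config)
+    await integration.initialize()
+    # integration.app is a regular FastAPI application, e.g. run via uvicorn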
+""" + +from typing import Dict, Any, Optional, List +from fastapi import FastAPI, Request, Response, HTTPException, Depends +from fastapi.middleware.cors import CORSMiddleware +import json + +from ..base import BaseIntegration, IntegrationConfig +from ...utils.exceptions import IntegrationError +from ...utils.logging import get_logger +from ...storage import StorageBackend +from ...federation.delivery import ActivityDelivery +from ...security.key_management import KeyManager +from ..middleware import ActivityPubMiddleware + +logger = get_logger(__name__) + +class FastAPIIntegration(BaseIntegration): + """FastAPI integration.""" + + def __init__(self, config: IntegrationConfig): + super().__init__(config) + self.app = FastAPI(title="PyFed ActivityPub") + self.middleware = None + self._setup_middleware() + self._setup_routes() + + def _setup_middleware(self) -> None: + """Setup FastAPI middleware.""" + # Add CORS middleware + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + def _setup_routes(self) -> None: + """Setup API routes.""" + + @self.app.post("/inbox") + async def shared_inbox(request: Request): + """Handle shared inbox.""" + try: + # Verify request + if not await self.middleware.process_request( + method=request.method, + path=request.url.path, + headers=dict(request.headers), + body=await request.json() + ): + raise HTTPException(status_code=401, detail="Unauthorized") + + # Handle activity + activity = await request.json() + result = await self.handle_activity(activity) + + return Response( + content=json.dumps({"id": result}), + media_type="application/activity+json", + status_code=202 + ) + + except Exception as e: + logger.error(f"Inbox error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @self.app.get("/actor") + async def get_instance_actor(): + """Get instance actor.""" + try: + return Response( + content=json.dumps(self.instance.actor), + media_type="application/activity+json" + ) + except Exception as e: + logger.error(f"Actor error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @self.app.get("/.well-known/nodeinfo") + async def get_nodeinfo(): + """Get NodeInfo.""" + try: + return Response( + content=json.dumps(await self.instance.get_nodeinfo()), + media_type="application/json" + ) + except Exception as e: + logger.error(f"NodeInfo error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + async def initialize(self) -> None: + """Initialize integration.""" + try: + # Initialize components + self.storage = StorageBackend.create( + provider="postgresql", + database_url=self.config.database_url + ) + await self.storage.initialize() + + self.key_manager = KeyManager(self.config.key_path) + await self.key_manager.initialize() + + self.delivery = ActivityDelivery( + storage=self.storage, + key_manager=self.key_manager + ) + await self.delivery.initialize() + + # Initialize middleware + self.middleware = ActivityPubMiddleware( + signature_verifier=self.key_manager.signature_verifier, + rate_limiter=self.delivery.rate_limiter + ) + + logger.info("FastAPI integration initialized") + + except Exception as e: + logger.error(f"Failed to initialize FastAPI integration: {e}") + raise IntegrationError(f"Integration initialization failed: {e}") + + async def shutdown(self) -> None: + """Shutdown integration.""" + try: + await self.storage.close() + await self.delivery.close() + await self.key_manager.close() + + except Exception as e: + 
logger.error(f"Failed to shutdown FastAPI integration: {e}") + raise IntegrationError(f"Integration shutdown failed: {e}") + + async def handle_activity(self, activity: Dict[str, Any]) -> Optional[str]: + """Handle incoming activity.""" + return await self.delivery.process_activity(activity) + + async def deliver_activity(self, + activity: Dict[str, Any], + recipients: List[str]) -> None: + """Deliver activity to recipients.""" + await self.delivery.deliver_activity(activity, recipients) \ No newline at end of file diff --git a/src/pyfed/integration/frameworks/flask.py b/src/pyfed/integration/frameworks/flask.py new file mode 100644 index 0000000000000000000000000000000000000000..df11a03ed9bd49a0dea9a820264ad726fc95a069 --- /dev/null +++ b/src/pyfed/integration/frameworks/flask.py @@ -0,0 +1,129 @@ +""" +Flask integration implementation. +""" + +from typing import Dict, Any, Optional, List +from flask import Flask, request, jsonify, Response +import json + +from ..base import BaseIntegration, IntegrationConfig +from ...utils.exceptions import IntegrationError +from ...utils.logging import get_logger +from ...storage import StorageBackend +from ...federation.delivery import ActivityDelivery +from ...security.key_management import KeyManager +from ..middleware import ActivityPubMiddleware + +logger = get_logger(__name__) + +class FlaskIntegration(BaseIntegration): + """Flask integration.""" + + def __init__(self, config: IntegrationConfig): + super().__init__(config) + self.app = Flask(__name__) + self.middleware = None + self._setup_routes() + + def _setup_routes(self) -> None: + """Setup Flask routes.""" + + @self.app.post("/inbox") + async def shared_inbox(): + """Handle shared inbox.""" + try: + # Verify request + if not await self.middleware.process_request( + method=request.method, + path=request.path, + headers=dict(request.headers), + body=request.get_json() + ): + return jsonify({"error": "Unauthorized"}), 401 + + # Handle activity + activity = request.get_json() + result = await self.handle_activity(activity) + + return Response( + response=json.dumps({"id": result}), + status=202, + mimetype="application/activity+json" + ) + + except Exception as e: + logger.error(f"Inbox error: {e}") + return jsonify({"error": str(e)}), 500 + + @self.app.get("/actor") + async def get_instance_actor(): + """Get instance actor.""" + try: + return Response( + response=json.dumps(self.instance.actor), + mimetype="application/activity+json" + ) + except Exception as e: + logger.error(f"Actor error: {e}") + return jsonify({"error": str(e)}), 500 + + @self.app.get("/.well-known/nodeinfo") + async def get_nodeinfo(): + """Get NodeInfo.""" + try: + return jsonify(await self.instance.get_nodeinfo()) + except Exception as e: + logger.error(f"NodeInfo error: {e}") + return jsonify({"error": str(e)}), 500 + + async def initialize(self) -> None: + """Initialize integration.""" + try: + # Initialize components + self.storage = StorageBackend.create( + provider="postgresql", + database_url=self.config.database_url + ) + await self.storage.initialize() + + self.key_manager = KeyManager(self.config.key_path) + await self.key_manager.initialize() + + self.delivery = ActivityDelivery( + storage=self.storage, + key_manager=self.key_manager + ) + await self.delivery.initialize() + + # Initialize middleware + self.middleware = ActivityPubMiddleware( + signature_verifier=self.key_manager.signature_verifier, + rate_limiter=self.delivery.rate_limiter + ) + + logger.info("Flask integration initialized") + + except 
Exception as e: + logger.error(f"Failed to initialize Flask integration: {e}") + raise IntegrationError(f"Integration initialization failed: {e}") + + async def shutdown(self) -> None: + """Shutdown integration.""" + try: + await self.storage.close() + await self.delivery.close() + await self.key_manager.close() + + except Exception as e: + logger.error(f"Failed to shutdown Flask integration: {e}") + raise IntegrationError(f"Integration shutdown failed: {e}") + + async def handle_activity(self, activity: Dict[str, Any]) -> Optional[str]: + """Handle incoming activity.""" + return await self.delivery.process_activity(activity) + + async def deliver_activity(self, + activity: Dict[str, Any], + recipients: List[str]) -> None: + """Deliver activity to recipients.""" + await self.delivery.deliver_activity(activity, recipients) \ No newline at end of file diff --git a/src/pyfed/integration/middleware.py b/src/pyfed/integration/middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..10f72761ca8647459578434eae8442097e7f4740 --- /dev/null +++ b/src/pyfed/integration/middleware.py @@ -0,0 +1,92 @@ +""" +Integration middleware implementation. +""" + +from typing import Dict, Any, Optional, Callable, Awaitable +import asyncio +from datetime import datetime +import json + +from ..utils.exceptions import MiddlewareError +from ..utils.logging import get_logger +from ..security.http_signatures import HTTPSignatureVerifier +from ..federation.rate_limit import RateLimiter + +logger = get_logger(__name__) + +class ActivityPubMiddleware: + """ActivityPub middleware.""" + + def __init__(self, + signature_verifier: HTTPSignatureVerifier, + rate_limiter: RateLimiter): + self.signature_verifier = signature_verifier + self.rate_limiter = rate_limiter + + async def process_request(self, + method: str, + path: str, + headers: Dict[str, str], + body: Optional[Dict[str, Any]] = None) -> bool: + """ + Process incoming request. + + Args: + method: HTTP method + path: Request path + headers: Request headers + body: Request body + + Returns: + bool: True if request is valid + """ + try: + # Verify HTTP signature + if not await self.signature_verifier.verify_request( + headers=headers, + method=method, + path=path + ): + return False + + # Check rate limits + domain = headers.get('Host', '').split(':')[0] + if not await self.rate_limiter.check_limit(domain, 'request'): + return False + + return True + + except Exception as e: + logger.error(f"Middleware error: {e}") + return False + + async def process_response(self, + status: int, + headers: Dict[str, str], + body: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + """ + Process outgoing response. 
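+        Adds standard ActivityPub response headers (Content-Type, Vary,
+        Cache-Control) on top of the headers passed in.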
+ + Args: + status: Response status + headers: Response headers + body: Response body + + Returns: + Dict with processed headers + """ + try: + response_headers = headers.copy() + + # Add standard headers + response_headers.update({ + "Content-Type": "application/activity+json", + "Vary": "Accept, Accept-Encoding", + "Cache-Control": "max-age=0, private, must-revalidate" + }) + + return response_headers + + except Exception as e: + logger.error(f"Middleware error: {e}") + return headers \ No newline at end of file diff --git a/src/pyfed/models/actors.py b/src/pyfed/models/actors.py index 0a5fc0b1cbc777d4546e64271e059a6331350ea7..0056ff87f7e0d365aede0d9b3d3df86e90f759b7 100644 --- a/src/pyfed/models/actors.py +++ b/src/pyfed/models/actors.py @@ -11,10 +11,9 @@ from __future__ import annotations from pydantic import Field, HttpUrl, field_validator from typing import Optional, List, Dict, Any, TypedDict, Literal from datetime import datetime -from pyfed.utils.exceptions import InvalidURLError -# from pyfed.plugins import plugin_manager -from pyfed.utils.logging import get_logger -from pyfed.cache import object_cache +from ..utils.exceptions import InvalidURLError +from ..utils.logging import get_logger +from ..cache import object_cache from .objects import APObject @@ -64,52 +63,52 @@ class APActor(APObject): except ValueError: raise InvalidURLError(f"Invalid URL: {v}") - async def send_to_inbox(self, activity: Dict[str, Any]) -> bool: - """ - Send an activity to this actor's inbox. + # async def send_to_inbox(self, activity: Dict[str, Any]) -> bool: + # """ + # Send an activity to this actor's inbox. - Args: - activity (Dict[str, Any]): The activity to send. + # Args: + # activity (Dict[str, Any]): The activity to send. - Returns: - bool: True if the activity was successfully delivered, False otherwise. + # Returns: + # bool: True if the activity was successfully delivered, False otherwise. - Note: This method should implement the logic described in: - https://www.w3.org/TR/activitypub/#delivery - """ - logger.info(f"Sending activity to inbox: {self.inbox}") - # Execute pre-send hook - # plugin_manager.execute_hook('pre_send_to_inbox', self, activity) + # Note: This method should implement the logic described in: + # https://www.w3.org/TR/activitypub/#delivery + # """ + # logger.info(f"Sending activity to inbox: {self.inbox}") + # # Execute pre-send hook + # # plugin_manager.execute_hook('pre_send_to_inbox', self, activity) - # Placeholder for actual implementation - return True + # # Placeholder for actual implementation + # return True - async def fetch_followers(self) -> List[APActor]: - """ - Fetch the followers of this actor. + # async def fetch_followers(self) -> List[APActor]: + # """ + # Fetch the followers of this actor. - Returns: - List[APActor]: A list of actors following this actor. - """ - logger.info(f"Fetching followers for actor: {self.id}") + # Returns: + # List[APActor]: A list of actors following this actor. 
+ # """ + # logger.info(f"Fetching followers for actor: {self.id}") - # Check cache first - cached_followers = object_cache.get(f"followers:{self.id}") - if cached_followers is not None: - return cached_followers + # # Check cache first + # cached_followers = object_cache.get(f"followers:{self.id}") + # if cached_followers is not None: + # return cached_followers - # Execute pre-fetch hook - # plugin_manager.execute_hook('pre_fetch_followers', self) + # # Execute pre-fetch hook + # # plugin_manager.execute_hook('pre_fetch_followers', self) - # Fetch followers (placeholder implementation) - followers = [] # Actual implementation would go here + # # Fetch followers (placeholder implementation) + # followers = [] # Actual implementation would go here - # Cache the result - object_cache.set(f"followers:{self.id}", followers) + # # Cache the result + # object_cache.set(f"followers:{self.id}", followers) - return followers + # return followers - async def create_activity(self, activity_type: str, object: Dict[str, Any]) -> ActivityDict: + # async def create_activity(self, activity_type: str, object: Dict[str, Any]) -> ActivityDict: """ Create an activity with this actor as the 'actor'. diff --git a/src/pyfed/models/objects.py b/src/pyfed/models/objects.py index 032a573d01d4c48b14957b40e5a63d28c0202a1a..cd49aab7ffe88048ce0155293535b28ffa76eedf 100644 --- a/src/pyfed/models/objects.py +++ b/src/pyfed/models/objects.py @@ -52,7 +52,7 @@ class APObject(ActivityPubBase): Usage: https://www.w3.org/TR/activitypub/#object """ id: HttpUrl = Field(..., description="Unique identifier for the object") - type: str = Field(..., description="The type of the object") + type: Literal["Object"] = Field(..., description="The type of the object") attachment: Optional[List[Union[str, 'APObject']]] = Field(default=None, description="Files attached to the object") attributed_to: Optional[Union[str, 'APObject']] = Field(default=None, description="Entity attributed to this object") audience: Optional[List[Union[str, 'APObject']]] = Field(default=None, description="Intended audience") diff --git a/src/pyfed/protocols/__init__.py b/src/pyfed/protocols/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f4736c422d2d586f76fc4e47c24a9a93a4d05c --- /dev/null +++ b/src/pyfed/protocols/__init__.py @@ -0,0 +1,6 @@ +# from .c2s import C2SConfig, ClientToServerProtocol +# from .s2s import ServerToServerProtocol + +# __all__ = [ +# 'C2SConfig', 'ClientToServerProtocol', 'ServerToServerProtocol' +# ] diff --git a/src/pyfed/protocols/webfinger.py b/src/pyfed/protocols/webfinger.py new file mode 100644 index 0000000000000000000000000000000000000000..32697e5c4de8d4d437926a5ab11299560ae0b90c --- /dev/null +++ b/src/pyfed/protocols/webfinger.py @@ -0,0 +1,137 @@ +""" +WebFinger protocol implementation. 
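+
+Typical usage (illustrative sketch; the account shown is an example):
+
+    client = WebFingerClient()
+    await client.initialize()
+    actor_url = await client.get_actor_url("alice@example.com")
+    inbox_url = await client.get_inbox_url("alice@example.com")
+    await client.close()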
+""" + +from typing import Dict, Any, Optional +import aiohttp +from urllib.parse import quote +import json + +from ..utils.exceptions import WebFingerError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class WebFingerClient: + """WebFinger client implementation.""" + + def __init__(self, timeout: int = 30, verify_ssl: bool = True): + self.timeout = timeout + self.verify_ssl = verify_ssl + self.session = None + + async def initialize(self) -> None: + """Initialize client.""" + # Create SSL context + ssl_context = None + if not self.verify_ssl: + import ssl + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + self.session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers={"Accept": "application/jrd+json, application/json"}, + connector=aiohttp.TCPConnector(ssl=ssl_context) + ) + + async def finger(self, account: str) -> Optional[Dict[str, Any]]: + """ + Perform WebFinger lookup. + + Args: + account: Account to look up (e.g. user@domain.com) + + Returns: + WebFinger response if found + """ + try: + # Parse account + if account.startswith('acct:'): + account = account[5:] + if '@' not in account: + raise WebFingerError(f"Invalid account format: {account}") + + username, domain = account.split('@') + resource = f"acct:{username}@{domain}" + + # Construct WebFinger URL + url = f"https://{domain}/.well-known/webfinger?resource={quote(resource)}" + + # Perform lookup + async with self.session.get(url) as response: + if response.status != 200: + logger.error( + f"WebFinger lookup failed for {account}: {response.status}" + ) + return None + + data = await response.json() + return data + + except Exception as e: + logger.error(f"WebFinger lookup failed for {account}: {e}") + return None + + async def get_actor_url(self, account: str) -> Optional[str]: + """ + Get actor URL from WebFinger response. + + Args: + account: Account to look up + + Returns: + Actor URL if found + """ + try: + data = await self.finger(account) + if not data: + return None + + # Look for actor URL in links + for link in data.get('links', []): + if ( + link.get('rel') == 'self' and + link.get('type') == 'application/activity+json' + ): + return link.get('href') + + return None + + except Exception as e: + logger.error(f"Failed to get actor URL for {account}: {e}") + return None + + async def get_inbox_url(self, account: str) -> Optional[str]: + """ + Get inbox URL for account. 
+ + Args: + account: Account to look up + + Returns: + Inbox URL if found + """ + try: + # First get actor URL + actor_url = await self.get_actor_url(account) + if not actor_url: + return None + + # Fetch actor object + async with self.session.get(actor_url) as response: + if response.status != 200: + return None + + actor = await response.json() + return actor.get('inbox') + + except Exception as e: + logger.error(f"Failed to get inbox URL for {account}: {e}") + return None + + async def close(self) -> None: + """Clean up resources.""" + if self.session: + await self.session.close() \ No newline at end of file diff --git a/src/pyfed/requirements.txt b/src/pyfed/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0519ecba6ea913e21689ec692e81e9e4973fbf73 --- /dev/null +++ b/src/pyfed/requirements.txt @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/pyfed/security/__init__.py b/src/pyfed/security/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6bdfbdf557a34b8d9362cf5161dc20f769d1ac6 --- /dev/null +++ b/src/pyfed/security/__init__.py @@ -0,0 +1,5 @@ +from .key_management import KeyManager + +__all__ = [ + 'KeyManager' +] diff --git a/src/pyfed/security/hardening.py b/src/pyfed/security/hardening.py new file mode 100644 index 0000000000000000000000000000000000000000..b5dc7413f9d13372b4379a2a19f3f8d58275d68d --- /dev/null +++ b/src/pyfed/security/hardening.py @@ -0,0 +1,184 @@ +""" +Security hardening implementation. +""" + +import base64 +from typing import Dict, Any, Optional, List +import hashlib +import secrets +from datetime import datetime, timedelta +from dataclasses import dataclass +from enum import Enum +import re + +from ..utils.exceptions import SecurityError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class SecurityLevel(Enum): + """Security level settings.""" + BASIC = "basic" + ENHANCED = "enhanced" + STRICT = "strict" + +@dataclass +class SecurityPolicy: + """Security policy configuration.""" + min_key_size: int + key_rotation_days: int + signature_max_age: int # seconds + require_digest: bool + allowed_algorithms: List[str] + blocked_ips: List[str] + blocked_domains: List[str] + request_timeout: int + max_payload_size: int + required_headers: List[str] + +class SecurityHardening: + """Security hardening implementation.""" + + def __init__(self, level: SecurityLevel = SecurityLevel.ENHANCED): + self.level = level + self.policy = self._get_policy(level) + self._nonce_cache = {} + + def _get_policy(self, level: SecurityLevel) -> SecurityPolicy: + """Get security policy for level.""" + if level == SecurityLevel.BASIC: + return SecurityPolicy( + min_key_size=2048, + key_rotation_days=90, + signature_max_age=300, # 5 minutes + require_digest=False, + allowed_algorithms=["rsa-sha256"], + blocked_ips=[], + blocked_domains=[], + request_timeout=30, + max_payload_size=5_000_000, # 5MB + required_headers=["date", "host"] + ) + elif level == SecurityLevel.ENHANCED: + return SecurityPolicy( + min_key_size=4096, + key_rotation_days=30, + signature_max_age=120, # 2 minutes + require_digest=True, + allowed_algorithms=["rsa-sha256", "rsa-sha512"], + blocked_ips=[], + blocked_domains=[], + request_timeout=20, + max_payload_size=1_000_000, # 1MB + required_headers=["date", "host", "digest"] + ) + else: # STRICT + return SecurityPolicy( + min_key_size=8192, + key_rotation_days=7, + signature_max_age=60, # 1 minute + require_digest=True, + allowed_algorithms=["rsa-sha512"], + 
blocked_ips=[], + blocked_domains=[], + request_timeout=10, + max_payload_size=500_000, # 500KB + required_headers=["date", "host", "digest", "content-type"] + ) + + def validate_request(self, + headers: Dict[str, str], + body: Optional[str] = None, + remote_ip: Optional[str] = None) -> None: + """ + Validate request security. + + Args: + headers: Request headers + body: Request body + remote_ip: Remote IP address + + Raises: + SecurityError if validation fails + """ + try: + # Check required headers + for header in self.policy.required_headers: + if header.lower() not in [k.lower() for k in headers]: + raise SecurityError(f"Missing required header: {header}") + + # Check signature algorithm + sig_header = headers.get('signature', '') + if 'algorithm=' in sig_header: + algo = re.search(r'algorithm="([^"]+)"', sig_header) + if algo and algo.group(1) not in self.policy.allowed_algorithms: + raise SecurityError(f"Unsupported signature algorithm: {algo.group(1)}") + + # Verify digest if required + if self.policy.require_digest and body: + if 'digest' not in headers: + raise SecurityError("Missing required digest header") + if not self._verify_digest(headers['digest'], body): + raise SecurityError("Invalid digest") + + # Check payload size + if body and len(body) > self.policy.max_payload_size: + raise SecurityError("Payload too large") + + # Check IP/domain blocks + if remote_ip and remote_ip in self.policy.blocked_ips: + raise SecurityError("IP address blocked") + + # Verify nonce + nonce = headers.get('nonce') + if nonce and not self._verify_nonce(nonce): + raise SecurityError("Invalid or reused nonce") + + except SecurityError: + raise + except Exception as e: + logger.error(f"Security validation failed: {e}") + raise SecurityError(f"Security validation failed: {e}") + + def _verify_digest(self, digest_header: str, body: str) -> bool: + """Verify request digest.""" + try: + algo, value = digest_header.split('=', 1) + if algo.upper() == 'SHA-256': + computed = hashlib.sha256(body.encode()).digest() + return base64.b64encode(computed).decode() == value + return False + except: + return False + + def _verify_nonce(self, nonce: str) -> bool: + """Verify and track nonce.""" + now = datetime.utcnow() + + # Clean old nonces + self._nonce_cache = { + n: t for n, t in self._nonce_cache.items() + if t > now - timedelta(minutes=5) + } + + # Check if nonce used + if nonce in self._nonce_cache: + return False + + # Store nonce + self._nonce_cache[nonce] = now + return True + + def generate_nonce(self) -> str: + """Generate secure nonce.""" + return secrets.token_urlsafe(32) + + def block_ip(self, ip: str) -> None: + """Add IP to block list.""" + if ip not in self.policy.blocked_ips: + self.policy.blocked_ips.append(ip) + + def block_domain(self, domain: str) -> None: + """Add domain to block list.""" + if domain not in self.policy.blocked_domains: + self.policy.blocked_domains.append(domain) \ No newline at end of file diff --git a/src/pyfed/security/http_signatures.py b/src/pyfed/security/http_signatures.py new file mode 100644 index 0000000000000000000000000000000000000000..a4db3edf6bdbf84fdb8d239276e5be4a5d497646 --- /dev/null +++ b/src/pyfed/security/http_signatures.py @@ -0,0 +1,293 @@ +""" +security/http_signatures.py +Enhanced HTTP Signatures implementation. 
+ +Implements HTTP Signatures (draft-cavage-http-signatures) with: +- Key management +- Signature caching +- Request verification +- Performance optimization +""" + +from typing import Dict, Any, Optional, List, Union +import base64 +from unittest.mock import Mock +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.exceptions import InvalidSignature +import json +from datetime import datetime, timedelta +from urllib.parse import urlparse +import hashlib + +from ..utils.exceptions import SignatureError +from ..utils.logging import get_logger +from ..cache.memory_cache import MemoryCache + +logger = get_logger(__name__) + +class SignatureCache: + """Cache for HTTP signatures.""" + + def __init__(self, ttl: int = 300): # 5 minutes default TTL + self.cache = MemoryCache(ttl) + + async def get(self, key: str) -> Optional[Dict[str, Any]]: + """Get cached signature.""" + return await self.cache.get(key) + + async def set(self, key: str, value: Dict[str, Any]) -> None: + """Cache signature.""" + await self.cache.set(key, value) + +class HTTPSignatureVerifier: + """Enhanced HTTP signature verification.""" + + def __init__(self, + private_key_path: str, + public_key_path: str, + key_id: Optional[str] = None): + """Initialize signature verifier.""" + self.private_key_path = private_key_path + self.public_key_path = public_key_path + self.key_id = key_id or f"https://{urlparse(private_key_path).netloc}#main-key" + self.signature_cache = SignatureCache() + self._test_now = None + + # Load keys + self.private_key = self._load_private_key() + self.public_key = self._load_public_key() + + def set_test_time(self, test_time: datetime) -> None: + """Set a fixed time for testing.""" + self._test_now = test_time + + def _load_private_key(self) -> rsa.RSAPrivateKey: + """Load private key from file.""" + try: + with open(self.private_key_path, 'rb') as f: + key_data = f.read() + return serialization.load_pem_private_key( + key_data, + password=None + ) + except Exception as e: + logger.error(f"Failed to load private key: {e}") + raise SignatureError(f"Failed to load private key: {e}") + + def _load_public_key(self) -> rsa.RSAPublicKey: + """Load public key from file.""" + try: + with open(self.public_key_path, 'rb') as f: + key_data = f.read() + return serialization.load_pem_public_key(key_data) + except Exception as e: + logger.error(f"Failed to load public key: {e}") + raise SignatureError(f"Failed to load public key: {e}") + + async def verify_request(self, + headers: Dict[str, str], + method: str = "POST", + path: str = "") -> bool: + """Verify HTTP signature on request.""" + try: + if 'Signature' not in headers: + raise SignatureError("Missing Signature header") + + # Parse signature header + sig_params = self._parse_signature_header(headers['Signature']) + + # Check cache + cache_key = f"{sig_params['keyId']}:{headers.get('Date', '')}" + if cached := await self.signature_cache.get(cache_key): + return cached['valid'] + + # Build signing string + signed_headers = sig_params['headers'].split() + + # Verify all required headers are present + for header in signed_headers: + if header != '(request-target)' and header.lower() not in [k.lower() for k in headers]: + raise SignatureError(f"Missing required header: {header}") + + # Verify date before signature verification + if not self._verify_date(headers.get('Date')): + raise SignatureError("Invalid or missing date") + + # Build signing string + signing_string = 
self._build_signing_string( + method, + path, + headers, + signed_headers + ) + + # Verify signature + try: + signature = base64.b64decode(sig_params['signature']) + self.public_key.verify( + signature, + signing_string.encode(), + padding.PKCS1v15(), + hashes.SHA256() + ) + await self.signature_cache.set( + cache_key, + {'valid': True, 'timestamp': datetime.utcnow().isoformat()} + ) + return True + except InvalidSignature: + await self.signature_cache.set( + cache_key, + {'valid': False, 'timestamp': datetime.utcnow().isoformat()} + ) + return False + + except SignatureError: + raise + except Exception as e: + logger.error(f"Signature verification failed: {e}") + raise SignatureError(f"Signature verification failed: {e}") + + async def sign_request(self, + method: str, + path: str, + headers: Dict[str, str], + body: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + """ + Sign HTTP request. + + Args: + method: HTTP method + path: Request path + headers: Request headers + body: Request body (optional) + + Returns: + Dict with signature headers + """ + try: + request_headers = headers.copy() + + # Add date if not present, using test time if set + if 'date' not in request_headers: + now = self._test_now if self._test_now is not None else datetime.utcnow() + request_headers['date'] = now.strftime('%a, %d %b %Y %H:%M:%S GMT') + + # Add digest if body present + if body is not None: + body_digest = self._generate_digest(body) + request_headers['digest'] = f"SHA-256={body_digest}" + + # Headers to sign + headers_to_sign = ['(request-target)', 'host', 'date'] + if body is not None: + headers_to_sign.append('digest') + + # Build signing string + signing_string = self._build_signing_string( + method, + path, + request_headers, + headers_to_sign + ) + + # Sign + signature = self.private_key.sign( + signing_string.encode(), + padding.PKCS1v15(), + hashes.SHA256() + ) + + # Build signature header + signature_header = ( + f'keyId="{self.key_id}",' + f'algorithm="rsa-sha256",' + f'headers="{" ".join(headers_to_sign)}",' + f'signature="{base64.b64encode(signature).decode()}"' + ) + + return { + **request_headers, + 'Signature': signature_header + } + + except Exception as e: + logger.error(f"Request signing failed: {e}") + raise SignatureError(f"Request signing failed: {e}") + + def _parse_signature_header(self, header: str) -> Dict[str, str]: + """Parse HTTP signature header.""" + try: + parts = {} + for part in header.split(','): + if '=' not in part: + continue + key, value = part.split('=', 1) + parts[key.strip()] = value.strip('"') + + required = ['keyId', 'algorithm', 'headers', 'signature'] + if not all(k in parts for k in required): + raise SignatureError("Missing required signature parameters") + + return parts + except Exception as e: + raise SignatureError(f"Invalid signature header: {e}") + + def _verify_date(self, date_header: Optional[str]) -> bool: + """ + Verify request date with clock skew handling. + + Allows 5 minutes of clock skew in either direction. 
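+        For example, with the current time at 12:00:00 GMT, Date headers from
+        11:55:00 to 12:05:00 GMT are accepted.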
+ """ + if not date_header: + return False + + try: + request_time = datetime.strptime( + date_header, + '%a, %d %b %Y %H:%M:%S GMT' + ).replace(tzinfo=None) + + # Use test time if set, otherwise use current time + now = self._test_now if self._test_now is not None else datetime.utcnow() + now = now.replace(tzinfo=None) + + skew = timedelta(minutes=5) + earliest = now - skew + latest = now + skew + + logger.debug( + f"Date verification: request_time={request_time}, " + f"now={now}, earliest={earliest}, latest={latest}" + ) + + return earliest <= request_time <= latest + + except Exception as e: + logger.debug(f"Date verification failed: {e}") + return False + + def _build_signing_string(self, + method: str, + path: str, + headers: Dict[str, str], + signed_headers: List[str]) -> str: + """Build string to sign.""" + lines = [] + + for header in signed_headers: + if header == '(request-target)': + lines.append(f"(request-target): {method.lower()} {path}") + else: + if header.lower() not in [k.lower() for k in headers]: + raise SignatureError(f"Missing required header: {header}") + lines.append(f"{header.lower()}: {headers[header]}") + + return '\n'.join(lines) + + def _generate_digest(self, body: Dict[str, Any]) -> str: + """Generate SHA-256 digest of body.""" + body_bytes = json.dumps(body, sort_keys=True).encode() + digest = hashlib.sha256(body_bytes).digest() + return base64.b64encode(digest).decode() \ No newline at end of file diff --git a/src/pyfed/security/interfaces.py b/src/pyfed/security/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..bc34d471cc8989246c34d934fc7cf0b94fbe92fb --- /dev/null +++ b/src/pyfed/security/interfaces.py @@ -0,0 +1,45 @@ +""" +Security component interfaces. +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional + +class SignatureVerifier(ABC): + """Interface for HTTP signature verification.""" + + @abstractmethod + async def verify_request(self, headers: Dict[str, str]) -> bool: + """ + Verify request signature. + + Args: + headers: Request headers including signature + + Returns: + True if signature is valid + """ + pass + + @abstractmethod + async def sign_request(self, + method: str, + path: str, + host: str, + body: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + """ + Sign a request. + + Args: + method: HTTP method + path: Request path + host: Target host + body: Optional request body + + Returns: + Headers with signature + + Raises: + SignatureError: If signing fails + """ + pass \ No newline at end of file diff --git a/src/pyfed/security/key_management.py b/src/pyfed/security/key_management.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8f392677c8bfdd00a3f76e4aba539ffe885f92 --- /dev/null +++ b/src/pyfed/security/key_management.py @@ -0,0 +1,286 @@ +""" +Enhanced key management with rotation support. 
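+
+Typical usage (illustrative sketch; the domain shown is an example):
+
+    key_manager = KeyManager(domain="example.com", keys_path="keys")
+    await key_manager.initialize()
+    active_pair = await key_manager.get_active_key()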
+""" + +from typing import Dict, Any, Optional, Tuple +from datetime import datetime, timedelta +import json +from pathlib import Path +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import rsa, padding +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey +import aiofiles +import asyncio + +from ..utils.exceptions import KeyManagementError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class KeyRotation: + """Key rotation configuration.""" + def __init__(self, + rotation_interval: int = 30, # days + key_overlap: int = 2, # days + key_size: int = 2048): + self.rotation_interval = rotation_interval + self.key_overlap = key_overlap + self.key_size = key_size + +class KeyPair: + """Key pair with metadata.""" + def __init__(self, + private_key: RSAPrivateKey, + public_key: RSAPublicKey, + created_at: datetime, + expires_at: datetime, + key_id: str): + self.private_key = private_key + self.public_key = public_key + self.created_at = created_at + self.expires_at = expires_at + self.key_id = key_id + +class KeyManager: + """Enhanced key management with rotation.""" + + def __init__(self, + domain: str, + keys_path: str = "keys", + active_keys: Optional[Path] = None, + rotation_config: Optional[KeyRotation] = None): + self.domain = domain + self.keys_path = Path(keys_path) + self.rotation_config = rotation_config or KeyRotation() + self.active_keys: Dict[str, KeyPair] = {} + self._rotation_task = None + + async def initialize(self) -> None: + """Initialize key manager.""" + try: + # Create keys directory + self.keys_path.mkdir(parents=True, exist_ok=True) + + # Load existing keys + await self._load_existing_keys() + + # Generate initial keys if none exist + if not self.active_keys: + await self.generate_key_pair() + + # Start rotation task + self._rotation_task = asyncio.create_task(self._key_rotation_loop()) + + except Exception as e: + logger.error(f"Failed to initialize key manager: {e}") + raise KeyManagementError(f"Key manager initialization failed: {e}") + + async def generate_key_pair(self) -> KeyPair: + """Generate new key pair.""" + try: + # Generate keys + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=self.rotation_config.key_size + ) + public_key = private_key.public_key() + + # Set validity period + created_at = datetime.utcnow() + expires_at = created_at + timedelta(days=self.rotation_config.rotation_interval) + + # Generate key ID + key_id = f"https://{self.domain}/keys/{created_at.timestamp()}" + + # Create key pair + key_pair = KeyPair( + private_key=private_key, + public_key=public_key, + created_at=created_at, + expires_at=expires_at, + key_id=key_id + ) + + # Save keys + await self._save_key_pair(key_pair) + + # Add to active keys + self.active_keys[key_id] = key_pair + + return key_pair + + except Exception as e: + logger.error(f"Failed to generate key pair: {e}") + raise KeyManagementError(f"Key generation failed: {e}") + + async def rotate_keys(self) -> None: + """Perform key rotation.""" + try: + logger.info("Starting key rotation") + + # Generate new key pair + new_pair = await self.generate_key_pair() + logger.info(f"Generated new key pair: {new_pair.key_id}") + + # Remove expired keys + now = datetime.utcnow() + expired = [ + key_id for key_id, pair in self.active_keys.items() + if pair.expires_at < now - timedelta(days=self.rotation_config.key_overlap) + ] + + for key_id in expired: + await 
self._archive_key_pair(self.active_keys[key_id]) + del self.active_keys[key_id] + logger.info(f"Archived expired key: {key_id}") + + # Announce new key to federation + await self._announce_key_rotation(new_pair) + + except Exception as e: + logger.error(f"Key rotation failed: {e}") + raise KeyManagementError(f"Key rotation failed: {e}") + + async def get_active_key(self) -> KeyPair: + """Get the most recent active key.""" + if not self.active_keys: + raise KeyManagementError("No active keys available") + + # Return most recently created key + return max( + self.active_keys.values(), + key=lambda k: k.created_at + ) + + async def verify_key(self, key_id: str, domain: str) -> bool: + """Verify a key's validity.""" + try: + # Check if key is one of our active keys + if key_id in self.active_keys: + key_pair = self.active_keys[key_id] + return datetime.utcnow() <= key_pair.expires_at + + # For external keys, verify with their server + # Implementation for external key verification + return False + + except Exception as e: + logger.error(f"Key verification failed: {e}") + return False + + async def _load_existing_keys(self) -> None: + """Load existing keys from disk.""" + try: + for key_file in self.keys_path.glob("*.json"): + async with aiofiles.open(key_file, 'r') as f: + metadata = json.loads(await f.read()) + + # Load private key + private_key_path = self.keys_path / f"{metadata['key_id']}_private.pem" + async with aiofiles.open(private_key_path, 'rb') as f: + private_key = serialization.load_pem_private_key( + await f.read(), + password=None + ) + + # Create key pair + key_pair = KeyPair( + private_key=private_key, + public_key=private_key.public_key(), + created_at=datetime.fromisoformat(metadata['created_at']), + expires_at=datetime.fromisoformat(metadata['expires_at']), + key_id=metadata['key_id'] + ) + + # Add to active keys if not expired + if datetime.utcnow() <= key_pair.expires_at: + self.active_keys[key_pair.key_id] = key_pair + + except Exception as e: + logger.error(f"Failed to load existing keys: {e}") + raise KeyManagementError(f"Failed to load existing keys: {e}") + + async def _save_key_pair(self, key_pair: KeyPair) -> None: + """Save key pair to disk.""" + try: + # Save private key + private_key_path = self.keys_path / f"{key_pair.key_id}_private.pem" + private_pem = key_pair.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + async with aiofiles.open(private_key_path, 'wb') as f: + await f.write(private_pem) + + # Save public key + public_key_path = self.keys_path / f"{key_pair.key_id}_public.pem" + public_pem = key_pair.public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + async with aiofiles.open(public_key_path, 'wb') as f: + await f.write(public_pem) + + # Save metadata + metadata = { + 'key_id': key_pair.key_id, + 'created_at': key_pair.created_at.isoformat(), + 'expires_at': key_pair.expires_at.isoformat() + } + metadata_path = self.keys_path / f"{key_pair.key_id}.json" + async with aiofiles.open(metadata_path, 'w') as f: + await f.write(json.dumps(metadata)) + + except Exception as e: + logger.error(f"Failed to save key pair: {e}") + raise KeyManagementError(f"Failed to save key pair: {e}") + + async def _archive_key_pair(self, key_pair: KeyPair) -> None: + """Archive an expired key pair.""" + try: + archive_dir = self.keys_path / "archive" + archive_dir.mkdir(exist_ok=True) + + # 
Move key files to archive + for ext in ['_private.pem', '_public.pem', '.json']: + src = self.keys_path / f"{key_pair.key_id}{ext}" + dst = archive_dir / f"{key_pair.key_id}{ext}" + if src.exists(): + src.rename(dst) + + except Exception as e: + logger.error(f"Failed to archive key pair: {e}") + raise KeyManagementError(f"Failed to archive key pair: {e}") + + async def _announce_key_rotation(self, key_pair: KeyPair) -> None: + """Announce new key to federation.""" + # Implementation for announcing key rotation to federation + pass + + async def _key_rotation_loop(self) -> None: + """Background task for key rotation.""" + while True: + try: + # Check for keys needing rotation + now = datetime.utcnow() + for key_pair in self.active_keys.values(): + if key_pair.expires_at <= now + timedelta(days=1): + await self.rotate_keys() + break + + # Sleep for a day + await asyncio.sleep(86400) + + except Exception as e: + logger.error(f"Key rotation loop error: {e}") + await asyncio.sleep(3600) # Retry in an hour + + async def close(self) -> None: + """Clean up resources.""" + if self._rotation_task: + self._rotation_task.cancel() + try: + await self._rotation_task + except asyncio.CancelledError: + pass diff --git a/src/pyfed/security/oauth.py b/src/pyfed/security/oauth.py new file mode 100644 index 0000000000000000000000000000000000000000..62f66c48b97ec50e69f3d814518c37822988d4f0 --- /dev/null +++ b/src/pyfed/security/oauth.py @@ -0,0 +1,298 @@ +""" +Enhanced OAuth2 implementation for ActivityPub C2S authentication. +""" + +from typing import Dict, Any, Optional, List +from dataclasses import dataclass +import aiohttp +import jwt +from datetime import datetime, timedelta +import asyncio +from abc import ABC, abstractmethod +from ..utils.exceptions import AuthenticationError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +@dataclass +class OAuth2Config: + """OAuth2 configuration.""" + token_lifetime: int = 3600 # 1 hour + refresh_token_lifetime: int = 2592000 # 30 days + clock_skew: int = 30 # seconds + max_retries: int = 3 + retry_delay: float = 1.0 + request_timeout: float = 10.0 + allowed_grant_types: List[str] = ("password", "refresh_token") + allowed_scopes: List[str] = ("read", "write") + required_token_fields: List[str] = ( + "access_token", + "token_type", + "expires_in", + "refresh_token" + ) + +class TokenCache(ABC): + """Abstract token cache interface.""" + + @abstractmethod + async def get_token(self, key: str) -> Optional[Dict[str, Any]]: + """Get token from cache.""" + pass + + @abstractmethod + async def store_token(self, key: str, token_data: Dict[str, Any]) -> None: + """Store token in cache.""" + pass + + @abstractmethod + async def invalidate_token(self, key: str) -> None: + """Invalidate cached token.""" + pass + +class OAuth2Handler: + """Enhanced OAuth2 handler with improved security.""" + + def __init__(self, + client_id: str, + client_secret: str, + token_endpoint: str, + config: Optional[OAuth2Config] = None, + token_cache: Optional[TokenCache] = None): + """Initialize OAuth2 handler.""" + self.client_id = client_id + self.client_secret = client_secret + self.token_endpoint = token_endpoint + self.config = config or OAuth2Config() + self.token_cache = token_cache + self._lock = asyncio.Lock() + + # Metrics + self.metrics = { + 'tokens_created': 0, + 'tokens_refreshed': 0, + 'tokens_verified': 0, + 'token_failures': 0, + 'cache_hits': 0, + 'cache_misses': 0 + } + + async def create_token(self, + username: str, + password: str, + scope: Optional[str] 
= None) -> Dict[str, Any]: + """ + Create OAuth2 token using password grant. + + Args: + username: User's username + password: User's password + scope: Optional scope request + + Returns: + Token response data + + Raises: + AuthenticationError: If token creation fails + """ + async with self._lock: + try: + # Validate scope + if scope and not self._validate_scope(scope): + raise AuthenticationError(f"Invalid scope: {scope}") + + data = { + 'grant_type': 'password', + 'username': username, + 'password': password, + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'scope': scope or ' '.join(self.config.allowed_scopes) + } + + token_data = await self._make_token_request(data) + + # Cache token if cache available + if self.token_cache: + await self.token_cache.store_token(username, token_data) + + self.metrics['tokens_created'] += 1 + return token_data + + except AuthenticationError: + self.metrics['token_failures'] += 1 + raise + except Exception as e: + self.metrics['token_failures'] += 1 + logger.error(f"Token creation failed: {e}") + raise AuthenticationError(f"Token creation failed: {e}") + + async def refresh_token(self, + refresh_token: str, + user_id: Optional[str] = None) -> Dict[str, Any]: + """ + Refresh OAuth2 token. + + Args: + refresh_token: Refresh token + user_id: Optional user ID for cache + + Returns: + New token data + + Raises: + AuthenticationError: If refresh fails + """ + async with self._lock: + try: + data = { + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token, + 'client_id': self.client_id, + 'client_secret': self.client_secret + } + + token_data = await self._make_token_request(data) + + # Update cache + if user_id and self.token_cache: + await self.token_cache.store_token(user_id, token_data) + + self.metrics['tokens_refreshed'] += 1 + return token_data + + except Exception as e: + self.metrics['token_failures'] += 1 + logger.error(f"Token refresh failed: {e}") + raise AuthenticationError(f"Token refresh failed: {e}") + + async def verify_token(self, + token: str, + required_scope: Optional[str] = None) -> Dict[str, Any]: + """ + Verify OAuth2 token. + + Args: + token: Token to verify + required_scope: Optional required scope + + Returns: + Token payload if valid + + Raises: + AuthenticationError: If token is invalid + """ + try: + # First check cache if available + if self.token_cache: + cached = await self.token_cache.get_token(token) + if cached: + self.metrics['cache_hits'] += 1 + return cached + self.metrics['cache_misses'] += 1 + + # Verify JWT + payload = jwt.decode( + token, + self.client_secret, + algorithms=['HS256'], + leeway=self.config.clock_skew + ) + + # Check expiry + exp = datetime.fromtimestamp(payload['exp']) + if exp < datetime.utcnow(): + raise AuthenticationError("Token has expired") + + # Verify scope if required + if required_scope: + token_scopes = payload.get('scope', '').split() + if required_scope not in token_scopes: + raise AuthenticationError(f"Missing required scope: {required_scope}") + + self.metrics['tokens_verified'] += 1 + return payload + + except jwt.ExpiredSignatureError: + raise AuthenticationError("Token has expired") + except jwt.InvalidTokenError as e: + raise AuthenticationError(f"Invalid token: {e}") + except Exception as e: + logger.error(f"Token verification failed: {e}") + raise AuthenticationError(f"Token verification failed: {e}") + + async def _make_token_request(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Make OAuth2 token request with retries. 
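+        Retries up to config.max_retries times, waiting config.retry_delay
+        seconds between attempts before giving up.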
+ + Args: + data: Request data + + Returns: + Token response data + + Raises: + AuthenticationError: If request fails + """ + timeout = aiohttp.ClientTimeout(total=self.config.request_timeout) + + for attempt in range(self.config.max_retries): + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + self.token_endpoint, + data=data, + headers={'Accept': 'application/json'} + ) as response: + if response.status != 200: + error_data = await response.text() + raise AuthenticationError( + f"Token request failed: {response.status} - {error_data}" + ) + + token_data = await response.json() + + # Validate response + self._validate_token_response(token_data) + + return token_data + + except aiohttp.ClientError as e: + if attempt == self.config.max_retries - 1: + raise AuthenticationError(f"Network error: {e}") + await asyncio.sleep(self.config.retry_delay) + except AuthenticationError: + raise + except Exception as e: + if attempt == self.config.max_retries - 1: + raise AuthenticationError(f"Token request failed: {e}") + await asyncio.sleep(self.config.retry_delay) + + def _validate_token_response(self, data: Dict[str, Any]) -> None: + """Validate token response data.""" + if not isinstance(data, dict): + raise AuthenticationError("Invalid token response format") + + missing = set(self.config.required_token_fields) - set(data.keys()) + if missing: + raise AuthenticationError(f"Missing required fields: {missing}") + + if data.get('token_type', '').lower() != 'bearer': + raise AuthenticationError("Unsupported token type") + + def _validate_scope(self, scope: str) -> bool: + """Validate requested scope.""" + requested = set(scope.split()) + allowed = set(self.config.allowed_scopes) + return requested.issubset(allowed) + + async def revoke_token(self, token: str, user_id: Optional[str] = None) -> None: + """ + Revoke OAuth2 token. + + Args: + token: Token to revoke + user_id: Optional user ID for cache + """ + if self.token_cache and user_id: + await self.token_cache.invalidate_token(user_id) \ No newline at end of file diff --git a/src/pyfed/security/rate_limiter.py b/src/pyfed/security/rate_limiter.py new file mode 100644 index 0000000000000000000000000000000000000000..de848fca5a480672a0dbce328d760fdc98902cb1 --- /dev/null +++ b/src/pyfed/security/rate_limiter.py @@ -0,0 +1,158 @@ +""" +Enhanced rate limiting with better configurability and monitoring. 
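+
+Example (a minimal usage sketch; the Redis URL is a placeholder):
+
+    config = RateLimitConfig(requests_per_minute=60, redis_url="redis://localhost")
+    limiter = RateLimiter(config)
+    await limiter.init()
+    info = await limiter.check_rate_limit("client-123")
+    print(info.remaining, info.reset_time)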
+""" + +from typing import Dict, Optional, NamedTuple, List +import time +import asyncio +from datetime import datetime +from dataclasses import dataclass +import redis.asyncio as redis +from ..utils.exceptions import RateLimitExceeded + +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +@dataclass +class RateLimitConfig: + """Rate limit configuration.""" + requests_per_minute: int + window_seconds: int = 60 + burst_size: int = 0 # Allow burst over limit + redis_url: Optional[str] = None + +class RateLimitInfo(NamedTuple): + """Rate limit information.""" + remaining: int + reset_time: int + limit: int + +class RateLimiter: + """Enhanced rate limiting implementation.""" + + def __init__(self, config: RateLimitConfig): + self.config = config + self.redis_client = None + self._local_counters: Dict[str, Dict] = {} + self._lock = asyncio.Lock() + self.metrics = { + 'total_requests': 0, + 'exceeded_limits': 0, + 'redis_errors': 0 + } + + async def init(self) -> None: + """Initialize rate limiter.""" + if self.config.redis_url: + await self._init_redis() + + async def _init_redis(self) -> None: + """Initialize Redis connection with retries.""" + retries = 3 + for attempt in range(retries): + try: + self.redis_client = redis.from_url(self.config.redis_url) + await self.redis_client.ping() + break + except Exception as e: + if attempt == retries - 1: + logger.error(f"Failed to connect to Redis: {e}") + self.redis_client = None + else: + await asyncio.sleep(1) + + async def check_rate_limit(self, client_id: str) -> RateLimitInfo: + """Enhanced rate limit check with detailed information.""" + self.metrics['total_requests'] += 1 + try: + current_time = int(time.time()) + + if self.redis_client: + count = await self._check_redis_rate_limit(client_id, current_time) + else: + async with self._lock: + count = await self._check_local_rate_limit(client_id, current_time) + + remaining = max(0, self.config.requests_per_minute - count) + reset_time = (current_time // self.config.window_seconds + 1) * self.config.window_seconds + + if count > self.config.requests_per_minute + self.config.burst_size: + self.metrics['exceeded_limits'] += 1 + raise RateLimitExceeded( + message=f"Rate limit exceeded. 
Try again after {reset_time}", + reset_time=reset_time, + limit=self.config.requests_per_minute, + remaining=remaining + ) + + return RateLimitInfo( + remaining=remaining, + reset_time=reset_time, + limit=self.config.requests_per_minute + ) + + except RateLimitExceeded: + raise + except redis.RedisError as e: + self.metrics['redis_errors'] += 1 + logger.error(f"Redis error in rate limit check: {e}") + # Fall back to local rate limiting + return await self._check_local_rate_limit(client_id, current_time) + + async def _check_redis_rate_limit(self, client_id: str, current_time: int) -> int: + """Enhanced Redis rate limiting with sliding window.""" + key = f"rate_limit:{client_id}" + window_key = f"{key}:{current_time // self.config.window_seconds}" + + async with self.redis_client.pipeline() as pipe: + try: + # Use pipeline for atomic operations + pipe.watch(window_key) + current_count = await pipe.get(window_key) or 0 + + pipe.multi() + pipe.incr(window_key) + pipe.expire(window_key, self.config.window_seconds) + + # Handle sliding window + prev_window_key = f"{key}:{(current_time // self.config.window_seconds) - 1}" + prev_count = await pipe.get(prev_window_key) or 0 + + results = await pipe.execute() + current_count = int(results[0]) + + # Calculate weighted count for sliding window + weight = ((current_time % self.config.window_seconds) / self.config.window_seconds) + total_count = int(current_count + (int(prev_count) * (1 - weight))) + + return total_count + + except redis.WatchError: + # Key modified, retry + return await self._check_redis_rate_limit(client_id, current_time) + + async def get_metrics(self) -> Dict[str, int]: + """Get rate limiter metrics.""" + return { + **self.metrics, + 'active_clients': len(self._local_counters) + } + + async def reset_limits(self, client_id: Optional[str] = None) -> None: + """Reset rate limits for client or all clients.""" + if client_id: + if self.redis_client: + pattern = f"rate_limit:{client_id}:*" + keys = await self.redis_client.keys(pattern) + if keys: + await self.redis_client.delete(*keys) + if client_id in self._local_counters: + del self._local_counters[client_id] + else: + if self.redis_client: + pattern = "rate_limit:*" + keys = await self.redis_client.keys(pattern) + if keys: + await self.redis_client.delete(*keys) + self._local_counters.clear() \ No newline at end of file diff --git a/src/pyfed/security/revocation.py b/src/pyfed/security/revocation.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb7a28ce3ede54db5d83d98cb89d322a5bdf110 --- /dev/null +++ b/src/pyfed/security/revocation.py @@ -0,0 +1,153 @@ +""" +Key revocation system implementation. 
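+
+Example (a minimal usage sketch; the Redis URL and key IDs are placeholders):
+
+    manager = RevocationManager(redis_url="redis://localhost")
+    await manager.initialize()
+    await manager.revoke_key(
+        key_id="key-2024",
+        reason=RevocationReason.SUPERSEDED,
+        replacement_key_id="key-2025"
+    )
+    info = await manager.check_revocation("key-2024")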
+""" + +from typing import Dict, Any, Optional, List +from datetime import datetime, timedelta +import json +import asyncio +import aioredis +from dataclasses import dataclass +from enum import Enum + +from ..utils.exceptions import RevocationError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class RevocationReason(Enum): + """Key revocation reasons.""" + COMPROMISED = "compromised" + SUPERSEDED = "superseded" + CESSATION_OF_OPERATION = "cessation_of_operation" + PRIVILEGE_WITHDRAWN = "privilege_withdrawn" + +@dataclass +class RevocationInfo: + """Key revocation information.""" + key_id: str + reason: RevocationReason + timestamp: datetime + replacement_key_id: Optional[str] = None + details: Optional[str] = None + +class RevocationManager: + """Key revocation management.""" + + def __init__(self, + redis_url: str = "redis://localhost", + propagation_delay: int = 300): # 5 minutes + self.redis_url = redis_url + self.propagation_delay = propagation_delay + self.redis: Optional[aioredis.Redis] = None + self._propagation_task = None + + async def initialize(self) -> None: + """Initialize revocation manager.""" + try: + self.redis = await aioredis.from_url(self.redis_url) + self._propagation_task = asyncio.create_task( + self._propagate_revocations() + ) + logger.info("Revocation manager initialized") + except Exception as e: + logger.error(f"Failed to initialize revocation manager: {e}") + raise RevocationError(f"Revocation initialization failed: {e}") + + async def revoke_key(self, + key_id: str, + reason: RevocationReason, + replacement_key_id: Optional[str] = None, + details: Optional[str] = None) -> None: + """ + Revoke a key. + + Args: + key_id: ID of key to revoke + reason: Reason for revocation + replacement_key_id: ID of replacement key + details: Additional details + """ + try: + revocation = RevocationInfo( + key_id=key_id, + reason=reason, + timestamp=datetime.utcnow(), + replacement_key_id=replacement_key_id, + details=details + ) + + # Store revocation + await self.redis.hset( + "revocations", + key_id, + json.dumps(revocation.__dict__) + ) + + # Add to propagation queue + await self.redis.zadd( + "revocation_queue", + {key_id: datetime.utcnow().timestamp()} + ) + + logger.info(f"Key {key_id} revoked: {reason.value}") + + except Exception as e: + logger.error(f"Failed to revoke key {key_id}: {e}") + raise RevocationError(f"Key revocation failed: {e}") + + async def check_revocation(self, key_id: str) -> Optional[RevocationInfo]: + """Check if a key is revoked.""" + try: + data = await self.redis.hget("revocations", key_id) + if data: + info = json.loads(data) + return RevocationInfo(**info) + return None + except Exception as e: + logger.error(f"Failed to check revocation for {key_id}: {e}") + raise RevocationError(f"Revocation check failed: {e}") + + async def _propagate_revocations(self) -> None: + """Propagate revocations to federation.""" + while True: + try: + now = datetime.utcnow().timestamp() + cutoff = now - self.propagation_delay + + # Get revocations ready for propagation + revocations = await self.redis.zrangebyscore( + "revocation_queue", + "-inf", + cutoff + ) + + for key_id in revocations: + # Propagate revocation + await self._announce_revocation(key_id) + + # Remove from queue + await self.redis.zrem("revocation_queue", key_id) + + await asyncio.sleep(60) # Check every minute + + except Exception as e: + logger.error(f"Revocation propagation failed: {e}") + await asyncio.sleep(300) # Retry in 5 minutes + + async def 
_announce_revocation(self, key_id: str) -> None: + """Announce key revocation to federation.""" + # Implementation for federation announcement + pass + + async def close(self) -> None: + """Clean up resources.""" + if self._propagation_task: + self._propagation_task.cancel() + try: + await self._propagation_task + except asyncio.CancelledError: + pass + + if self.redis: + await self.redis.close() \ No newline at end of file diff --git a/src/pyfed/security/validators.py b/src/pyfed/security/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..0519ecba6ea913e21689ec692e81e9e4973fbf73 --- /dev/null +++ b/src/pyfed/security/validators.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/pyfed/security/webfinger.py b/src/pyfed/security/webfinger.py new file mode 100644 index 0000000000000000000000000000000000000000..e0f2ca91b107df4c5355888ae1ef6b94c720c667 --- /dev/null +++ b/src/pyfed/security/webfinger.py @@ -0,0 +1,263 @@ +""" +Enhanced WebFinger implementation for ActivityPub. +""" + +import aiohttp +import json +from typing import Dict, Any, Optional, List, Union +from dataclasses import dataclass +from urllib.parse import urlparse, urljoin +import cachetools +import asyncio +from datetime import timedelta + +@dataclass +class WebFingerConfig: + """Configuration for WebFinger service.""" + cache_ttl: int = 3600 # Cache TTL in seconds + timeout: int = 10 # Request timeout in seconds + max_redirects: int = 3 + cache_size: int = 1000 + user_agent: str = "PyFed/1.0 (WebFinger)" + allowed_protocols: List[str] = ("https",) + +class WebFingerError(Exception): + """Base exception for WebFinger errors.""" + pass + +class WebFingerService: + """Enhanced WebFinger service.""" + + def __init__(self, + local_domain: str, + config: Optional[WebFingerConfig] = None): + """ + Initialize WebFinger service. + + Args: + local_domain: Local server domain + config: Optional configuration + """ + self.local_domain = local_domain + self.config = config or WebFingerConfig() + + # Initialize cache + self.cache = cachetools.TTLCache( + maxsize=self.config.cache_size, + ttl=self.config.cache_ttl + ) + + # Lock for thread safety + self._lock = asyncio.Lock() + + # Metrics + self.metrics = { + 'webfinger_requests': 0, + 'cache_hits': 0, + 'cache_misses': 0, + 'failed_requests': 0 + } + + async def get_resource(self, resource: str) -> Dict[str, Any]: + """ + Get WebFinger resource data. + + Args: + resource: Resource identifier (acct:user@domain or https://domain/users/user) + + Returns: + Dict containing resource data + + Raises: + WebFingerError: If resource lookup fails + """ + try: + self.metrics['webfinger_requests'] += 1 + + # Check cache + async with self._lock: + if resource in self.cache: + self.metrics['cache_hits'] += 1 + return self.cache[resource] + self.metrics['cache_misses'] += 1 + + # Parse resource + domain = self._get_domain(resource) + if not domain: + raise WebFingerError(f"Invalid resource: {resource}") + + # Make WebFinger request + data = await self._fetch_webfinger(domain, resource) + + # Cache result + async with self._lock: + self.cache[resource] = data + + return data + + except WebFingerError: + raise + except Exception as e: + self.metrics['failed_requests'] += 1 + raise WebFingerError(f"WebFinger lookup failed: {e}") + + async def get_actor_url(self, account: str) -> str: + """ + Get ActivityPub actor URL for account. 
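+        Tries a WebFinger lookup first and, if that fails, falls back to the
+        conventional https://{domain}/users/{username} URL (HTTPS only).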
+ + Args: + account: Account in format user@domain + + Returns: + Actor URL + + Raises: + WebFingerError: If actor lookup fails + """ + try: + # Validate account format + if '@' not in account: + raise WebFingerError("Invalid account format") + + username, domain = account.split('@') + + # First try WebFinger + resource = f"acct:{account}" + try: + data = await self.get_resource(resource) + + # Look for ActivityPub profile URL + for link in data.get('links', []): + if (link.get('rel') == 'self' and + link.get('type') == 'application/activity+json'): + return link['href'] + + except WebFingerError: + # Fall back to direct URL only for allowed protocols + if domain: + url = f"https://{domain}/users/{username}" + if urlparse(url).scheme in self.config.allowed_protocols: + return url + + raise WebFingerError(f"Could not determine actor URL for {account}") + + except WebFingerError: + raise + except Exception as e: + raise WebFingerError(f"Actor URL lookup failed: {e}") + + async def get_actor_data(self, url: str) -> Dict[str, Any]: + """ + Fetch actor data from URL. + + Args: + url: Actor URL + + Returns: + Actor data + + Raises: + WebFingerError: If actor fetch fails + """ + try: + headers = { + "Accept": "application/activity+json", + "User-Agent": self.config.user_agent + } + + timeout = aiohttp.ClientTimeout(total=self.config.timeout) + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as response: + if response.status != 200: + raise WebFingerError( + f"Failed to fetch actor: {response.status}" + ) + + data = await response.json() + + # Validate data + if not isinstance(data, dict): + raise WebFingerError("Invalid actor data format") + + required_fields = {'id', 'type'} + if not all(field in data for field in required_fields): + raise WebFingerError("Missing required actor fields") + + return data + + except aiohttp.ClientError as e: + raise WebFingerError(f"Network error: {e}") + except json.JSONDecodeError as e: + raise WebFingerError(f"Invalid JSON response: {e}") + except Exception as e: + raise WebFingerError(f"Failed to fetch actor data: {e}") + + async def _fetch_webfinger(self, domain: str, resource: str) -> Dict[str, Any]: + """Make WebFinger request with security checks.""" + try: + # Validate domain and build URL + if not self._is_valid_domain(domain): + raise WebFingerError(f"Invalid domain: {domain}") + + webfinger_url = f"https://{domain}/.well-known/webfinger" + params = {"resource": resource} + + headers = { + "Accept": "application/jrd+json", + "User-Agent": self.config.user_agent + } + + timeout = aiohttp.ClientTimeout(total=self.config.timeout) + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get( + webfinger_url, + params=params, + headers=headers, + max_redirects=self.config.max_redirects + ) as response: + if response.status != 200: + raise WebFingerError( + f"WebFinger request failed: {response.status}" + ) + + data = await response.json() + + # Validate response + if not isinstance(data, dict): + raise WebFingerError("Invalid WebFinger response format") + + if 'subject' not in data: + raise WebFingerError("Missing required field: subject") + + return data + + except Exception as e: + raise WebFingerError(f"WebFinger request failed: {e}") + + def _get_domain(self, resource: str) -> Optional[str]: + """Extract domain from resource with validation.""" + try: + if resource.startswith('acct:'): + _, address = resource.split(':', 1) + if '@' in address: + return 
address.split('@')[1] + else: + parsed = urlparse(resource) + if parsed.netloc: + return parsed.netloc + return None + except Exception: + return None + + def _is_valid_domain(self, domain: str) -> bool: + """Validate domain name.""" + if not domain: + return False + # Add your domain validation logic here + return True + + async def cleanup(self): + """Cleanup resources.""" + self.cache.clear() \ No newline at end of file diff --git a/src/pyfed/serializers/json_serializer.py b/src/pyfed/serializers/json_serializer.py index 3cbb28b73720c8153a5c4be7107a4dc89f9c2017..54e34d70453cbc1e788196b13f4f41d59d2664ee 100644 --- a/src/pyfed/serializers/json_serializer.py +++ b/src/pyfed/serializers/json_serializer.py @@ -1,113 +1,238 @@ """ -json_serializer.py -This module provides JSON serialization for ActivityPub objects. +JSON serializer for ActivityPub objects. """ +from typing import Any, Dict, Union, List, Optional, Type, get_origin, get_args +from datetime import datetime, timezone import json -from datetime import datetime -from typing import Any, Dict -from pydantic import BaseModel +import re +from pydantic import BaseModel, AnyUrl, HttpUrl from pydantic_core import Url -from ..utils.logging import get_logger -# from ..plugins import plugin_manager - -logger = get_logger(__name__) def to_camel_case(snake_str: str) -> str: - """Converts snake_case to camelCase.""" + """Convert snake_case to camelCase.""" components = snake_str.split('_') return components[0] + ''.join(x.title() for x in components[1:]) -def convert_dict_keys_to_camel_case(data: Dict[str, Any]) -> Dict[str, Any]: - """Recursively converts all dictionary keys from snake_case to camelCase.""" - if not isinstance(data, dict): - return data - - return { - to_camel_case(key): ( - convert_dict_keys_to_camel_case(value) if isinstance(value, (dict, list)) - else value - ) - for key, value in data.items() - } - -class ActivityPubJSONEncoder(json.JSONEncoder): - """Custom JSON encoder for ActivityPub objects.""" - - def default(self, obj: Any) -> Any: - if isinstance(obj, BaseModel): - return convert_dict_keys_to_camel_case(obj.model_dump()) - if isinstance(obj, Url): - return str(obj) - if isinstance(obj, datetime): - return obj.isoformat() - if isinstance(obj, list): - return [self.default(item) for item in obj] - return super().default(obj) +def is_url_field(field_name: str) -> bool: + """Check if field name suggests it's a URL.""" + url_indicators = [ + 'url', 'href', 'id', 'inbox', 'outbox', 'following', + 'followers', 'liked', 'icon', 'image', 'avatar', + 'endpoints', 'featured', 'streams' + ] + return any(indicator in field_name.lower() for indicator in url_indicators) -class ActivityPubBase(BaseModel): - """Base class for all ActivityPub models.""" +class ActivityPubSerializer: + """ActivityPub serializer implementation.""" - class Config: - """Pydantic model configuration.""" - populate_by_name = True - use_enum_values = True - alias_generator = to_camel_case + @staticmethod + def _process_value(value: Any, field_name: str = "", depth: int = 0) -> Any: + """ + Process a single value for serialization. 
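+        Recursively handles nested models, URLs, datetimes, lists and dicts,
+        capped at a nesting depth of 10.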
+ + Args: + value: Value to process + field_name: Name of the field being processed + depth: Current recursion depth + + Returns: + Processed value + """ + # Prevent infinite recursion + if depth > 10: # Maximum nesting depth + return str(value) -class ActivityPubSerializer: - """Serializer for ActivityPub objects.""" + if value is None: + return None + + # Handle BaseModel instances (nested objects) + if isinstance(value, BaseModel): + # Recursively serialize nested objects + serialized = value.model_dump(exclude_none=True) + return { + to_camel_case(k): ActivityPubSerializer._process_value(v, k, depth + 1) + for k, v in serialized.items() + } + + # Handle URL types - using pydantic_core.Url instead of AnyUrl + if isinstance(value, Url): + return str(value) + + # Handle datetime + if isinstance(value, datetime): + return value.astimezone(timezone.utc).isoformat() + + # Handle lists with potential nested objects + if isinstance(value, list): + return [ + ActivityPubSerializer._process_value(item, field_name, depth + 1) + for item in value + ] + + # Handle dictionaries with potential nested objects + if isinstance(value, dict): + return { + to_camel_case(k): ActivityPubSerializer._process_value(v, k, depth + 1) + for k, v in value.items() + } + + # Convert string to URL if field name suggests it's a URL + if isinstance(value, str) and is_url_field(field_name): + if not value.startswith(('http://', 'https://')): + value = f"https://{value}" + return value + + return value @staticmethod - def serialize(obj: ActivityPubBase, include_context: bool = True, **kwargs) -> str: + def serialize(obj: Any, include_context: bool = True) -> Dict[str, Any]: """ - Serialize an ActivityPub object to JSON string. - + Serialize object to dictionary. + Args: - obj (ActivityPubBase): The object to serialize. - include_context (bool): Whether to include @context field. - **kwargs: Additional arguments passed to json.dumps. - + obj: Object to serialize + include_context: Whether to include @context + Returns: - str: JSON string representation of the object. + Serialized dictionary """ - logger.debug("Serializing object") - - # Execute pre-serialize hook - # plugin_manager.execute_hook('pre_serialize', obj) - - # Convert to dictionary and convert keys to camelCase - data = convert_dict_keys_to_camel_case(obj.model_dump()) + if not isinstance(obj, BaseModel): + return ActivityPubSerializer._process_value(obj) + + # Process each field + processed_data = ActivityPubSerializer._process_value(obj) # Add context if needed if include_context: - data["@context"] = "https://www.w3.org/ns/activitystreams" - - # Serialize to JSON - serialized = json.dumps(data, cls=ActivityPubJSONEncoder, **kwargs) - - return serialized + processed_data["@context"] = "https://www.w3.org/ns/activitystreams" + + return processed_data @staticmethod - def deserialize(json_str: str, model_class: type[ActivityPubBase]) -> ActivityPubBase: + def _process_field_value(value: Any, field_type: Any) -> Any: """ - Deserialize a JSON string to an ActivityPub object. - + Process field value during deserialization. + Args: - json_str (str): The JSON string to deserialize. - model_class (type[ActivityPubBase]): The class to deserialize into. - + value: Value to process + field_type: Type annotation for the field + Returns: - ActivityPubBase: The deserialized object. 
+ Processed value """ - logger.debug(f"Deserializing to {model_class.__name__}") + # Handle None values + if value is None: + return None + + # Handle nested BaseModel + if hasattr(field_type, 'model_fields'): + return ActivityPubSerializer.deserialize(value, field_type) + + # Handle lists + origin = get_origin(field_type) + if origin is list: + args = get_args(field_type) + if args and hasattr(args[0], 'model_fields'): + return [ + ActivityPubSerializer.deserialize(item, args[0]) + if isinstance(item, dict) + else item + for item in value + ] + + # Handle dictionaries + if origin is dict: + key_type, val_type = get_args(field_type) + if hasattr(val_type, 'model_fields'): + return { + k: ActivityPubSerializer.deserialize(v, val_type) + if isinstance(v, dict) + else v + for k, v in value.items() + } + + return value + + @staticmethod + def deserialize(data: Union[str, Dict[str, Any]], model_class: Type[BaseModel]) -> BaseModel: + """ + Deserialize data to object. - # Parse JSON - data = json.loads(json_str) + Args: + data: JSON string or dictionary to deserialize + model_class: Class to deserialize into + + Returns: + Deserialized object + """ + # Handle JSON string input + if isinstance(data, str): + try: + data_dict = json.loads(data) + except json.JSONDecodeError: + raise ValueError("Invalid JSON string") + else: + data_dict = data + + if not isinstance(data_dict, dict): + raise ValueError("Data must be a dictionary or JSON string") + + # Make a copy of the data + data_dict = dict(data_dict) # Remove context if present - data.pop("@context", None) + data_dict.pop('@context', None) - # Create object - obj = model_class.model_validate(data) - - return obj + # Convert keys from camelCase to snake_case and process values + processed_data = {} + for key, value in data_dict.items(): + if key == '@context': + continue + + snake_key = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', key).lower() + + # Get field info from model + field_info = model_class.model_fields.get(snake_key) + if field_info is None: + continue + + # Process the field value + processed_value = ActivityPubSerializer._process_field_value( + value, field_info.annotation + ) + + processed_data[snake_key] = processed_value + + # Use model_validate instead of direct construction + return model_class.model_validate(processed_data) + +class ActivityPubBase(BaseModel): + """Base class for all ActivityPub objects.""" + + def serialize(self, include_context: bool = True) -> Dict[str, Any]: + """Serialize object to dictionary.""" + return ActivityPubSerializer.serialize(self, include_context) + + @classmethod + def deserialize(cls, data: Union[str, Dict[str, Any]]) -> 'ActivityPubBase': + """Deserialize dictionary to object.""" + return ActivityPubSerializer.deserialize(data, cls) + + class Config: + """Pydantic config.""" + alias_generator = to_camel_case + populate_by_alias = True + extra = "allow" + arbitrary_types_allowed = True + populate_by_name = True + +def to_json(obj: ActivityPubBase, **kwargs) -> str: + """Convert object to JSON string.""" + return json.dumps(ActivityPubSerializer.serialize(obj), **kwargs) + +def from_json(json_str: str, model_class: Type[ActivityPubBase]) -> ActivityPubBase: + """Convert JSON string to object.""" + return ActivityPubSerializer.deserialize(json_str, model_class) + + diff --git a/src/pyfed/storage/__init__.py b/src/pyfed/storage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30b7e6136062684e08439d4839072a1a439ff8d9 --- /dev/null +++ b/src/pyfed/storage/__init__.py @@ 
-0,0 +1,44 @@ +""" +Storage package for ActivityPub data persistence. + +This package provides: +- Abstract storage interfaces +- Multiple storage backend implementations +- Storage provider protocol +""" + +from .base import StorageBackend, StorageProvider +from .backends.postgresql import PostgreSQLStorage +from .backends.mongodb import MongoDBStorageBackend +from .backends.redis import RedisStorageBackend +from .backends.sqlite import SQLiteStorage +# Default storage backend +DEFAULT_BACKEND = PostgreSQLStorage + +__all__ = [ + 'StorageBackend', + 'StorageProvider', + 'PostgreSQLStorage', + 'MongoDBStorageBackend', + 'RedisStorageBackend', + 'SQLiteStorage', + 'DEFAULT_BACKEND' +] + +# Storage backend registry +STORAGE_BACKENDS = { + 'postgresql': PostgreSQLStorage, + 'mongodb': MongoDBStorageBackend, + 'redis': RedisStorageBackend, + 'sqlite': SQLiteStorage +} + +def get_storage_backend(backend_type: str) -> type[StorageBackend]: + """Get storage backend class by type.""" + if backend_type not in STORAGE_BACKENDS: + raise ValueError(f"Unknown storage backend: {backend_type}") + return STORAGE_BACKENDS[backend_type] + +def register_backend(name: str, backend_class: type[StorageBackend]) -> None: + """Register a new storage backend.""" + STORAGE_BACKENDS[name] = backend_class \ No newline at end of file diff --git a/src/pyfed/storage/backends/__init__.py b/src/pyfed/storage/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0519ecba6ea913e21689ec692e81e9e4973fbf73 --- /dev/null +++ b/src/pyfed/storage/backends/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/pyfed/storage/backends/memory.py b/src/pyfed/storage/backends/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..b872bf4bdc1d48a32d67c2da8a5da2ae5e1b73fa --- /dev/null +++ b/src/pyfed/storage/backends/memory.py @@ -0,0 +1,196 @@ +""" +In-memory storage backend for testing and examples. 
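+
+Example (a minimal usage sketch; the actor ID is illustrative):
+
+    storage = MemoryStorageBackend()
+    await storage.init()
+    await storage.create_actor({
+        "id": "https://example.com/users/alice",
+        "preferredUsername": "alice"
+    })
+    actor = await storage.get_actor_by_username("alice")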
+""" + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from ..base import StorageBackend +from ...utils.exceptions import StorageError + +class MemoryStorageBackend(StorageBackend): + """Simple in-memory storage backend.""" + + def __init__(self): + """Initialize in-memory storage.""" + self.actors: Dict[str, Dict] = {} + self.activities: Dict[str, Dict] = {} + self.objects: Dict[str, Dict] = {} + self.follows: Dict[str, Dict] = {} + self.likes: Dict[str, Dict] = {} + self.announces: Dict[str, Dict] = {} + self.usernames: Dict[str, str] = {} # username -> actor_id mapping + self.inboxes: Dict[str, List] = {} # actor_id -> activities + self.outboxes: Dict[str, List] = {} # actor_id -> activities + + async def init(self) -> None: + """Initialize storage.""" + pass + + async def close(self) -> None: + """Close storage.""" + self.actors.clear() + self.activities.clear() + self.objects.clear() + self.follows.clear() + self.likes.clear() + self.announces.clear() + + async def create_actor(self, actor_data: Dict[str, Any]) -> str: + """Create actor.""" + actor_id = actor_data.get('id') + if not actor_id: + raise StorageError("Actor must have an ID") + + username = actor_data.get('preferredUsername') + if username: + self.usernames[username] = actor_id + + self.actors[actor_id] = actor_data + self.inboxes[actor_id] = [] + self.outboxes[actor_id] = [] + return actor_id + + async def get_actor(self, actor_id: str) -> Optional[Dict[str, Any]]: + """Get actor by ID.""" + return self.actors.get(actor_id) + + async def get_actor_by_username(self, username: str) -> Optional[Dict[str, Any]]: + """Get actor by username.""" + actor_id = self.usernames.get(username) + if actor_id: + return self.actors.get(actor_id) + return None + + async def update_actor(self, actor_id: str, actor_data: Dict[str, Any]) -> None: + """Update actor.""" + if actor_id not in self.actors: + raise StorageError("Actor not found") + self.actors[actor_id].update(actor_data) + + async def delete_actor(self, actor_id: str) -> None: + """Delete actor.""" + if actor_id in self.actors: + actor = self.actors[actor_id] + username = actor.get('preferredUsername') + if username: + self.usernames.pop(username, None) + self.actors.pop(actor_id) + self.inboxes.pop(actor_id, None) + self.outboxes.pop(actor_id, None) + + async def create_activity(self, activity_data: Dict[str, Any]) -> str: + """Create activity.""" + activity_id = activity_data.get('id') + if not activity_id: + activity_id = f"activity_{len(self.activities)}" + activity_data['id'] = activity_id + + self.activities[activity_id] = activity_data + + # Add to actor's outbox + actor = activity_data.get('actor') + if actor and actor in self.outboxes: + self.outboxes[actor].append(activity_id) + + return activity_id + + async def get_activities(self, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]: + """Get activities with pagination.""" + activities = list(self.activities.values()) + return activities[offset:offset + limit] + + async def create_object(self, object_data: Dict[str, Any]) -> str: + """Create object.""" + object_id = object_data.get('id') + if not object_id: + object_id = f"object_{len(self.objects)}" + object_data['id'] = object_id + + self.objects[object_id] = object_data + return object_id + + async def update_object(self, object_id: str, object_data: Dict[str, Any]) -> None: + """Update object.""" + if object_id not in self.objects: + raise StorageError("Object not found") + self.objects[object_id].update(object_data) + + async 
def delete_object(self, object_id: str) -> None: + """Delete object.""" + self.objects.pop(object_id, None) + + async def get_inbox(self, actor_id: str, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]: + """Get actor's inbox.""" + activity_ids = self.inboxes.get(actor_id, []) + activities = [ + self.activities[aid] + for aid in activity_ids[offset:offset + limit] + if aid in self.activities + ] + return activities + + async def get_outbox(self, actor_id: str, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]: + """Get actor's outbox.""" + activity_ids = self.outboxes.get(actor_id, []) + activities = [ + self.activities[aid] + for aid in activity_ids[offset:offset + limit] + if aid in self.activities + ] + return activities + + async def create_follow(self, follower: str, following: str) -> None: + """Create follow relationship.""" + key = f"{follower}:{following}" + self.follows[key] = { + "follower": follower, + "following": following, + "created_at": datetime.utcnow().isoformat() + } + + async def get_followers(self, actor_id: str, limit: int = 20, offset: int = 0) -> List[str]: + """Get actor's followers.""" + followers = [ + follow["follower"] + for follow in self.follows.values() + if follow["following"] == actor_id + ] + return followers[offset:offset + limit] + + async def get_following(self, actor_id: str, limit: int = 20, offset: int = 0) -> List[str]: + """Get actors being followed.""" + following = [ + follow["following"] + for follow in self.follows.values() + if follow["follower"] == actor_id + ] + return following[offset:offset + limit] + + async def create_like(self, actor: str, object_id: str) -> None: + """Create like.""" + key = f"{actor}:{object_id}" + self.likes[key] = { + "actor": actor, + "object": object_id, + "created_at": datetime.utcnow().isoformat() + } + + async def create_announce(self, actor: str, object_id: str) -> None: + """Create announce.""" + key = f"{actor}:{object_id}" + self.announces[key] = { + "actor": actor, + "object": object_id, + "created_at": datetime.utcnow().isoformat() + } + + async def remove_announce(self, actor: str, object_id: str) -> None: + """Remove announce.""" + key = f"{actor}:{object_id}" + self.announces.pop(key, None) + + async def has_announced(self, actor: str, object_id: str) -> bool: + """Check if actor has announced object.""" + key = f"{actor}:{object_id}" + return key in self.announces \ No newline at end of file diff --git a/src/pyfed/storage/backends/mongodb.py b/src/pyfed/storage/backends/mongodb.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6116885bfc3cb1adc72b38e125999dd6c6dea0 --- /dev/null +++ b/src/pyfed/storage/backends/mongodb.py @@ -0,0 +1,30 @@ +""" +MongoDB storage backend. 
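+
+Only init() and close() are implemented so far; the remaining StorageBackend
+methods still need to be filled in.
+
+Example (a minimal usage sketch; the connection URI is a placeholder):
+
+    backend = MongoDBStorageBackend("mongodb://localhost:27017", "pyfed")
+    await backend.init()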
+""" + +from typing import Dict, Any, List, Optional +from datetime import datetime +import motor.motor_asyncio + +from ..base import StorageBackend + +class MongoDBStorageBackend(StorageBackend): + """MongoDB storage backend implementation.""" + + def __init__(self, uri: str, database: str): + """Initialize MongoDB storage.""" + self.client = motor.motor_asyncio.AsyncIOMotorClient(uri) + self.db = self.client[database] + + async def init(self) -> None: + """Initialize database.""" + # Create indexes + await self.db.actors.create_index('username', unique=True) + await self.db.activities.create_index('actor_id') + await self.db.follows.create_index([('follower_id', 1), ('following_id', 1)]) + + async def close(self) -> None: + """Close database connection.""" + self.client.close() + + # Implement all abstract methods... \ No newline at end of file diff --git a/src/pyfed/storage/backends/postgresql.py b/src/pyfed/storage/backends/postgresql.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a8a9c9491ff6d1d5e798301278c0ea769e66f3 --- /dev/null +++ b/src/pyfed/storage/backends/postgresql.py @@ -0,0 +1,279 @@ +""" +PostgreSQL storage backend implementation. +""" + +from typing import Dict, Any, Optional, List +import json +from datetime import datetime +import asyncpg +from asyncpg.pool import Pool + +from ..base import StorageBackend +from ...utils.exceptions import StorageError +from ...utils.logging import get_logger + +logger = get_logger(__name__) + +SCHEMA = """ +CREATE TABLE IF NOT EXISTS activities ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + actor TEXT NOT NULL, + object_id TEXT, + data JSONB NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT activities_type_idx CHECK (type = ANY(ARRAY['Create', 'Follow', 'Like', 'Announce', 'Delete', 'Update', 'Undo', 'Accept', 'Reject'])) +); + +CREATE INDEX IF NOT EXISTS idx_activities_type ON activities(type); +CREATE INDEX IF NOT EXISTS idx_activities_actor ON activities(actor); +CREATE INDEX IF NOT EXISTS idx_activities_object_id ON activities(object_id); +CREATE INDEX IF NOT EXISTS idx_activities_created_at ON activities(created_at); + +CREATE TABLE IF NOT EXISTS objects ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + attributed_to TEXT, + data JSONB NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT objects_type_idx CHECK (type = ANY(ARRAY['Note', 'Article', 'Image', 'Video', 'Person', 'Group', 'Organization'])) +); + +CREATE INDEX IF NOT EXISTS idx_objects_type ON objects(type); +CREATE INDEX IF NOT EXISTS idx_objects_attributed_to ON objects(attributed_to); +CREATE INDEX IF NOT EXISTS idx_objects_created_at ON objects(created_at); +CREATE INDEX IF NOT EXISTS idx_objects_updated_at ON objects(updated_at); + +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ language 'plpgsql'; + +CREATE TRIGGER update_objects_updated_at + BEFORE UPDATE ON objects + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); +""" + +class PostgreSQLStorage(StorageBackend): + """PostgreSQL storage implementation.""" + + # def __init__(self, + # host: str = "localhost", + # port: int = 5432, + # user: str = "postgres", + # password: str = "", + # database: str = "pyfed", + # min_size: int = 5, + # max_size: int = 20): + # """Initialize PostgreSQL storage.""" + # self.dsn = 
f"postgresql://{user}:{password}@{host}:{port}/{database}" + # self.pool: Optional[Pool] = None + # self.min_size = min_size + # self.max_size = max_size + + def __init__(self, database_url: str, min_size: int = 5, max_size: int = 20, **kwargs): + self.database_url = database_url + self.pool = None + self.min_size = min_size + self.max_size = max_size + + async def initialize(self) -> None: + """Initialize database connection and schema.""" + try: + # Create connection pool + self.pool = await asyncpg.create_pool( + dsn=self.database_url, + min_size=self.min_size, + max_size=self.max_size + ) + + # Initialize schema + async with self.pool.acquire() as conn: + await conn.execute(SCHEMA) + + logger.info("PostgreSQL storage initialized") + + except Exception as e: + logger.error(f"Failed to initialize PostgreSQL storage: {e}") + raise StorageError(f"Storage initialization failed: {e}") + + async def create_activity(self, activity: Dict[str, Any]) -> str: + """Store an activity.""" + try: + activity_id = activity.get('id') + if not activity_id: + raise StorageError("Activity must have an ID") + + async with self.pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO activities (id, type, actor, object_id, data) + VALUES ($1, $2, $3, $4, $5) + """, + activity_id, + activity.get('type'), + activity.get('actor'), + activity.get('object', {}).get('id'), + json.dumps(activity) + ) + return activity_id + + except Exception as e: + logger.error(f"Failed to create activity: {e}") + raise StorageError(f"Failed to create activity: {e}") + + async def get_activity(self, activity_id: str) -> Optional[Dict[str, Any]]: + """Get an activity by ID.""" + try: + async with self.pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT data FROM activities WHERE id = $1", + activity_id + ) + return json.loads(row['data']) if row else None + + except Exception as e: + logger.error(f"Failed to get activity: {e}") + raise StorageError(f"Failed to get activity: {e}") + + async def create_object(self, obj: Dict[str, Any]) -> str: + """Store an object.""" + try: + object_id = obj.get('id') + if not object_id: + raise StorageError("Object must have an ID") + + async with self.pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO objects (id, type, attributed_to, data) + VALUES ($1, $2, $3, $4) + """, + object_id, + obj.get('type'), + obj.get('attributedTo'), + json.dumps(obj) + ) + return object_id + + except Exception as e: + logger.error(f"Failed to create object: {e}") + raise StorageError(f"Failed to create object: {e}") + + async def get_object(self, object_id: str) -> Optional[Dict[str, Any]]: + """Get an object by ID.""" + try: + async with self.pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT data FROM objects WHERE id = $1", + object_id + ) + return json.loads(row['data']) if row else None + + except Exception as e: + logger.error(f"Failed to get object: {e}") + raise StorageError(f"Failed to get object: {e}") + + async def delete_object(self, object_id: str) -> bool: + """Delete an object.""" + try: + async with self.pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM objects WHERE id = $1", + object_id + ) + return result == "DELETE 1" + + except Exception as e: + logger.error(f"Failed to delete object: {e}") + raise StorageError(f"Failed to delete object: {e}") + + async def update_object(self, object_id: str, obj: Dict[str, Any]) -> bool: + """Update an object.""" + try: + async with self.pool.acquire() as conn: + result = await conn.execute( + 
""" + UPDATE objects + SET data = $1, type = $2, attributed_to = $3 + WHERE id = $4 + """, + json.dumps(obj), + obj.get('type'), + obj.get('attributedTo'), + object_id + ) + return result == "UPDATE 1" + + except Exception as e: + logger.error(f"Failed to update object: {e}") + raise StorageError(f"Failed to update object: {e}") + + async def list_activities(self, + actor_id: Optional[str] = None, + activity_type: Optional[str] = None, + limit: int = 20, + offset: int = 0) -> List[Dict[str, Any]]: + """List activities with optional filtering.""" + try: + query = ["SELECT data FROM activities WHERE TRUE"] + params = [] + + if actor_id: + query.append("AND actor = $" + str(len(params) + 1)) + params.append(actor_id) + if activity_type: + query.append("AND type = $" + str(len(params) + 1)) + params.append(activity_type) + + query.append("ORDER BY created_at DESC") + query.append(f"LIMIT {limit} OFFSET {offset}") + + async with self.pool.acquire() as conn: + rows = await conn.fetch(" ".join(query), *params) + return [json.loads(row['data']) for row in rows] + + except Exception as e: + logger.error(f"Failed to list activities: {e}") + raise StorageError(f"Failed to list activities: {e}") + + async def list_objects(self, + object_type: Optional[str] = None, + attributed_to: Optional[str] = None, + limit: int = 20, + offset: int = 0) -> List[Dict[str, Any]]: + """List objects with optional filtering.""" + try: + query = ["SELECT data FROM objects WHERE TRUE"] + params = [] + + if object_type: + query.append("AND type = $" + str(len(params) + 1)) + params.append(object_type) + if attributed_to: + query.append("AND attributed_to = $" + str(len(params) + 1)) + params.append(attributed_to) + + query.append("ORDER BY created_at DESC") + query.append(f"LIMIT {limit} OFFSET {offset}") + + async with self.pool.acquire() as conn: + rows = await conn.fetch(" ".join(query), *params) + return [json.loads(row['data']) for row in rows] + + except Exception as e: + logger.error(f"Failed to list objects: {e}") + raise StorageError(f"Failed to list objects: {e}") + + async def close(self) -> None: + """Close database connection.""" + if self.pool: + await self.pool.close() + +# Register the provider +StorageBackend.register_provider("postgresql", PostgreSQLStorage) \ No newline at end of file diff --git a/src/pyfed/storage/backends/redis.py b/src/pyfed/storage/backends/redis.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7ad51fb8e9d1a18ab6108a743cf850f4351cda --- /dev/null +++ b/src/pyfed/storage/backends/redis.py @@ -0,0 +1,79 @@ +""" +Redis storage backend for caching. 
+""" + +from typing import Dict, Any, List, Optional +import json +from datetime import datetime +import redis.asyncio as redis + +from ..base import StorageBackend + +class RedisStorageBackend(StorageBackend): + """Redis storage backend implementation.""" + + def __init__(self, url: str, ttl: int = 3600): + """Initialize Redis storage.""" + self.redis = redis.from_url(url) + self.ttl = ttl + + async def init(self) -> None: + """Initialize Redis connection.""" + await self.redis.ping() + + async def close(self) -> None: + """Close Redis connection.""" + await self.redis.close() + + async def store_actor(self, actor_data: Dict[str, Any]) -> str: + """Store actor data.""" + await self.redis.set(actor_data["id"], json.dumps(actor_data), ex=self.ttl) + + async def get_actor(self, actor_id: str) -> Optional[Dict[str, Any]]: + """Get actor data.""" + data = await self.redis.get(actor_id) + return json.loads(data) if data else None + + async def get_actors(self, actor_ids: List[str]) -> List[Dict[str, Any]]: + """Get actor data.""" + data = await self.redis.mget(actor_ids) + return [json.loads(d) for d in data if d] + + async def store_activity(self, activity_data: Dict[str, Any]) -> str: + """Store activity data.""" + await self.redis.set(activity_data["id"], json.dumps(activity_data), ex=self.ttl) + + async def get_activity(self, activity_id: str) -> Optional[Dict[str, Any]]: + """Get activity data.""" + data = await self.redis.get(activity_id) + return json.loads(data) if data else None + + async def get_activities(self, activity_ids: List[str]) -> List[Dict[str, Any]]: + """Get activity data.""" + data = await self.redis.mget(activity_ids) + return [json.loads(d) for d in data if d] + + async def store_temp_token(self, token: str, data: Dict[str, Any], ttl: int) -> None: + """Store temporary token data.""" + await self.redis.set(token, json.dumps(data), ex=ttl) + + async def get_temp_token(self, token: str) -> Optional[Dict[str, Any]]: + """Get temporary token data.""" + data = await self.redis.get(token) + return json.loads(data) if data else None + + async def delete_temp_token(self, token: str) -> None: + """Delete temporary token data.""" + await self.redis.delete(token) + + async def delete_temp_tokens(self, tokens: List[str]) -> None: + """Delete temporary token data.""" + await self.redis.delete(*tokens) + + async def delete_all_temp_tokens(self) -> None: + """Delete all temporary token data.""" + await self.redis.flushall() + + async def delete_all_data(self) -> None: + """Delete all data.""" + await self.redis.flushall() \ No newline at end of file diff --git a/src/pyfed/storage/backends/sqlite.py b/src/pyfed/storage/backends/sqlite.py new file mode 100644 index 0000000000000000000000000000000000000000..e075062846d9f7b5c0619dc9ee481e85b178d1d7 --- /dev/null +++ b/src/pyfed/storage/backends/sqlite.py @@ -0,0 +1,247 @@ +""" +SQLite storage backend implementation. 
+""" + +from typing import Dict, Any, Optional, List +import aiosqlite +import json +from datetime import datetime + +from ..base import StorageBackend +from ...utils.exceptions import StorageError + +class SQLiteStorage(StorageBackend): + """SQLite storage implementation.""" + + def __init__(self, database_url: str, **kwargs): + self.database_url = database_url + self.db = None + + async def initialize(self) -> None: + """Initialize storage.""" + try: + self.db = await aiosqlite.connect(self.database_url) + + # Create tables + await self.db.execute(""" + CREATE TABLE IF NOT EXISTS activities ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + actor TEXT NOT NULL, + object_id TEXT, + data JSON NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + await self.db.execute(""" + CREATE TABLE IF NOT EXISTS objects ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + attributed_to TEXT, + data JSON NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes + await self.db.execute( + "CREATE INDEX IF NOT EXISTS idx_activities_type ON activities(type)" + ) + await self.db.execute( + "CREATE INDEX IF NOT EXISTS idx_activities_actor ON activities(actor)" + ) + await self.db.execute( + "CREATE INDEX IF NOT EXISTS idx_objects_type ON objects(type)" + ) + await self.db.execute( + "CREATE INDEX IF NOT EXISTS idx_objects_attributed_to ON objects(attributed_to)" + ) + + await self.db.commit() + + except Exception as e: + raise StorageError(f"Failed to initialize SQLite storage: {e}") + + async def create_activity(self, activity: Dict[str, Any]) -> str: + """Store an activity.""" + try: + activity_id = activity.get('id') + if not activity_id: + raise StorageError("Activity must have an ID") + + # Extract object_id based on whether object is a string or dict + object_data = activity.get('object') + object_id = object_data.get('id') if isinstance(object_data, dict) else object_data + + await self.db.execute( + """ + INSERT INTO activities (id, type, actor, object_id, data) + VALUES (?, ?, ?, ?, ?) + """, + ( + activity_id, + activity.get('type'), + activity.get('actor'), + object_id, # Use extracted object_id + json.dumps(activity) + ) + ) + await self.db.commit() + return activity_id + + except Exception as e: + raise StorageError(f"Failed to create activity: {e}") + + async def get_activity(self, activity_id: str) -> Optional[Dict[str, Any]]: + """Get an activity by ID.""" + try: + async with self.db.execute( + "SELECT data FROM activities WHERE id = ?", + (activity_id,) + ) as cursor: + row = await cursor.fetchone() + if row: + return json.loads(row[0]) + return None + + except Exception as e: + raise StorageError(f"Failed to get activity: {e}") + + async def list_activities(self, + actor_id: Optional[str] = None, + activity_type: Optional[str] = None, + limit: int = 20, + offset: int = 0) -> List[Dict[str, Any]]: + """List activities with optional filtering.""" + try: + query = ["SELECT data FROM activities WHERE 1=1"] + params = [] + + if actor_id: + query.append("AND actor = ?") + params.append(actor_id) + if activity_type: + query.append("AND type = ?") + params.append(activity_type) + + query.append("ORDER BY created_at DESC LIMIT ? 
OFFSET ?") + params.extend([limit, offset]) + + async with self.db.execute(" ".join(query), params) as cursor: + rows = await cursor.fetchall() + return [json.loads(row[0]) for row in rows] + + except Exception as e: + raise StorageError(f"Failed to list activities: {e}") + + async def create_object(self, obj: Dict[str, Any]) -> str: + """Store an object.""" + try: + object_id = obj.get('id') + if not object_id: + raise StorageError("Object must have an ID") + + await self.db.execute( + """ + INSERT INTO objects (id, type, attributed_to, data) + VALUES (?, ?, ?, ?) + """, + ( + object_id, + obj.get('type'), + obj.get('attributedTo'), + json.dumps(obj) + ) + ) + await self.db.commit() + return object_id + + except Exception as e: + raise StorageError(f"Failed to create object: {e}") + + async def get_object(self, object_id: str) -> Optional[Dict[str, Any]]: + """Get an object by ID.""" + try: + async with self.db.execute( + "SELECT data FROM objects WHERE id = ?", + (object_id,) + ) as cursor: + row = await cursor.fetchone() + if row: + return json.loads(row[0]) + return None + + except Exception as e: + raise StorageError(f"Failed to get object: {e}") + + async def update_object(self, object_id: str, obj: Dict[str, Any]) -> bool: + """Update an object.""" + try: + await self.db.execute( + """ + UPDATE objects + SET data = ?, type = ?, attributed_to = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, + ( + json.dumps(obj), + obj.get('type'), + obj.get('attributedTo'), + object_id + ) + ) + await self.db.commit() + return True + + except Exception as e: + raise StorageError(f"Failed to update object: {e}") + + async def delete_object(self, object_id: str) -> bool: + """Delete an object.""" + try: + await self.db.execute( + "DELETE FROM objects WHERE id = ?", + (object_id,) + ) + await self.db.commit() + return True + + except Exception as e: + raise StorageError(f"Failed to delete object: {e}") + + async def list_objects(self, + object_type: Optional[str] = None, + attributed_to: Optional[str] = None, + limit: int = 20, + offset: int = 0) -> List[Dict[str, Any]]: + """List objects with optional filtering.""" + try: + query = ["SELECT data FROM objects WHERE 1=1"] + params = [] + + if object_type: + query.append("AND type = ?") + params.append(object_type) + if attributed_to: + query.append("AND attributed_to = ?") + params.append(attributed_to) + + query.append("ORDER BY created_at DESC LIMIT ? OFFSET ?") + params.extend([limit, offset]) + + async with self.db.execute(" ".join(query), params) as cursor: + rows = await cursor.fetchall() + return [json.loads(row[0]) for row in rows] + + except Exception as e: + raise StorageError(f"Failed to list objects: {e}") + + async def close(self) -> None: + """Close database connection.""" + if self.db: + await self.db.close() + +# Register the provider +StorageBackend.register_provider("sqlite", SQLiteStorage) \ No newline at end of file diff --git a/src/pyfed/storage/base.py b/src/pyfed/storage/base.py new file mode 100644 index 0000000000000000000000000000000000000000..52ef2b49ef2585ebe9d1bfe8adcd241c902e1747 --- /dev/null +++ b/src/pyfed/storage/base.py @@ -0,0 +1,94 @@ +""" +Base storage interface. 
+""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List +from enum import Enum + +from ..utils.exceptions import StorageError + +class StorageProvider(Enum): + """Storage provider types.""" + MEMORY = "memory" + SQLITE = "sqlite" + POSTGRESQL = "postgresql" + REDIS = "redis" + +class StorageBackend(ABC): + """Abstract storage backend.""" + + # Register storage backends + _providers = {} + + @classmethod + def register_provider(cls, provider_name: str, provider_class: type): + """Register a storage provider.""" + cls._providers[provider_name] = provider_class + + @classmethod + def create(cls, provider: str, **kwargs) -> 'StorageBackend': + """Create storage backend instance.""" + if provider not in cls._providers: + raise StorageError(f"Unsupported storage provider: {provider}") + + provider_class = cls._providers[provider] + return provider_class(**kwargs) + + @abstractmethod + async def initialize(self) -> None: + """Initialize storage.""" + pass + + @abstractmethod + async def create_activity(self, activity: Dict[str, Any]) -> str: + """Store an activity.""" + pass + + @abstractmethod + async def get_activity(self, activity_id: str) -> Optional[Dict[str, Any]]: + """Get an activity by ID.""" + pass + + @abstractmethod + async def create_object(self, obj: Dict[str, Any]) -> str: + """Store an object.""" + pass + + @abstractmethod + async def get_object(self, object_id: str) -> Optional[Dict[str, Any]]: + """Get an object by ID.""" + pass + + @abstractmethod + async def delete_object(self, object_id: str) -> bool: + """Delete an object.""" + pass + + @abstractmethod + async def update_object(self, object_id: str, obj: Dict[str, Any]) -> bool: + """Update an object.""" + pass + + @abstractmethod + async def list_activities(self, + actor_id: Optional[str] = None, + activity_type: Optional[str] = None, + limit: int = 20, + offset: int = 0) -> List[Dict[str, Any]]: + """List activities with optional filtering.""" + pass + + @abstractmethod + async def list_objects(self, + object_type: Optional[str] = None, + attributed_to: Optional[str] = None, + limit: int = 20, + offset: int = 0) -> List[Dict[str, Any]]: + """List objects with optional filtering.""" + pass + + @abstractmethod + async def close(self) -> None: + """Close storage connection.""" + pass \ No newline at end of file diff --git a/src/pyfed/storage/connection.py b/src/pyfed/storage/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a7112287788555e5fde7b3ce6ea43677f6b23e --- /dev/null +++ b/src/pyfed/storage/connection.py @@ -0,0 +1,244 @@ +""" +Enhanced connection pooling implementation. 
+""" + +from typing import Dict, Any, Optional, List +import asyncio +import asyncpg +from datetime import datetime +import time +from dataclasses import dataclass +from enum import Enum + +from ..utils.exceptions import StorageError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class PoolStrategy(Enum): + """Connection pool strategies.""" + FIXED = "fixed" + DYNAMIC = "dynamic" + ADAPTIVE = "adaptive" + +@dataclass +class PoolConfig: + """Pool configuration.""" + min_size: int + max_size: int + strategy: PoolStrategy + idle_timeout: int # seconds + max_queries: int # per connection + connection_timeout: int # seconds + +class ConnectionMetrics: + """Connection pool metrics.""" + def __init__(self): + self.total_connections = 0 + self.active_connections = 0 + self.idle_connections = 0 + self.waiting_queries = 0 + self.total_queries = 0 + self.failed_queries = 0 + self.connection_timeouts = 0 + self.query_timeouts = 0 + +class EnhancedPool: + """Enhanced connection pool.""" + + def __init__(self, + dsn: str, + config: Optional[PoolConfig] = None): + self.dsn = dsn + self.config = config or PoolConfig( + min_size=5, + max_size=20, + strategy=PoolStrategy.ADAPTIVE, + idle_timeout=300, + max_queries=1000, + connection_timeout=30 + ) + self.pool: Optional[asyncpg.Pool] = None + self.metrics = ConnectionMetrics() + self._monitor_task = None + self._cleanup_task = None + self._connection_stats: Dict[asyncpg.Connection, Dict[str, Any]] = {} + + async def initialize(self) -> None: + """Initialize connection pool.""" + try: + # Create pool + self.pool = await asyncpg.create_pool( + self.dsn, + min_size=self.config.min_size, + max_size=self.config.max_size, + command_timeout=self.config.connection_timeout + ) + + # Start monitoring + self._monitor_task = asyncio.create_task(self._monitor_pool()) + self._cleanup_task = asyncio.create_task(self._cleanup_connections()) + + logger.info( + f"Connection pool initialized with strategy: {self.config.strategy.value}" + ) + + except Exception as e: + logger.error(f"Failed to initialize connection pool: {e}") + raise StorageError(f"Pool initialization failed: {e}") + + async def acquire(self) -> asyncpg.Connection: + """ + Acquire database connection. + + Implements connection management based on strategy. 
+ """ + try: + # Update metrics + self.metrics.waiting_queries += 1 + + # Get connection + if self.config.strategy == PoolStrategy.ADAPTIVE: + conn = await self._get_adaptive_connection() + else: + conn = await self.pool.acquire() + + # Update stats + self._connection_stats[conn] = { + 'acquired_at': datetime.utcnow(), + 'queries': 0 + } + + self.metrics.active_connections += 1 + self.metrics.waiting_queries -= 1 + + return conn + + except Exception as e: + self.metrics.waiting_queries -= 1 + self.metrics.connection_timeouts += 1 + logger.error(f"Failed to acquire connection: {e}") + raise StorageError(f"Failed to acquire connection: {e}") + + async def release(self, conn: asyncpg.Connection) -> None: + """Release database connection.""" + try: + # Update stats + if conn in self._connection_stats: + del self._connection_stats[conn] + + self.metrics.active_connections -= 1 + self.metrics.idle_connections += 1 + + await self.pool.release(conn) + + except Exception as e: + logger.error(f"Failed to release connection: {e}") + raise StorageError(f"Failed to release connection: {e}") + + async def _get_adaptive_connection(self) -> asyncpg.Connection: + """Get connection using adaptive strategy.""" + current_size = self.metrics.total_connections + active = self.metrics.active_connections + waiting = self.metrics.waiting_queries + + # Check if we need to grow pool + if ( + current_size < self.config.max_size and + (active / current_size > 0.75 or waiting > 0) + ): + # Grow pool + await self.pool.set_min_size(min( + current_size + 5, + self.config.max_size + )) + + return await self.pool.acquire() + + async def _monitor_pool(self) -> None: + """Monitor pool health and performance.""" + while True: + try: + # Update metrics + self.metrics.total_connections = len(self._connection_stats) + + # Log metrics + logger.debug( + f"Pool metrics - Total: {self.metrics.total_connections}, " + f"Active: {self.metrics.active_connections}, " + f"Idle: {self.metrics.idle_connections}, " + f"Waiting: {self.metrics.waiting_queries}" + ) + + # Adjust pool size if needed + if self.config.strategy == PoolStrategy.ADAPTIVE: + await self._adjust_pool_size() + + await asyncio.sleep(60) + + except Exception as e: + logger.error(f"Pool monitoring failed: {e}") + await asyncio.sleep(300) + + async def _adjust_pool_size(self) -> None: + """Adjust pool size based on usage.""" + current_size = self.metrics.total_connections + active = self.metrics.active_connections + + # Shrink if underutilized + if active / current_size < 0.25 and current_size > self.config.min_size: + new_size = max( + current_size - 5, + self.config.min_size + ) + await self.pool.set_min_size(new_size) + + # Grow if heavily utilized + elif active / current_size > 0.75 and current_size < self.config.max_size: + new_size = min( + current_size + 5, + self.config.max_size + ) + await self.pool.set_min_size(new_size) + + async def _cleanup_connections(self) -> None: + """Clean up idle and overused connections.""" + while True: + try: + now = datetime.utcnow() + + # Check each connection + for conn, stats in self._connection_stats.items(): + # Check idle timeout + idle_time = (now - stats['acquired_at']).total_seconds() + if idle_time > self.config.idle_timeout: + await self.release(conn) + + # Check query limit + if stats['queries'] >= self.config.max_queries: + await self.release(conn) + + await asyncio.sleep(60) + + except Exception as e: + logger.error(f"Connection cleanup failed: {e}") + await asyncio.sleep(300) + + async def close(self) -> None: + """Close 
connection pool.""" + if self._monitor_task: + self._monitor_task.cancel() + try: + await self._monitor_task + except asyncio.CancelledError: + pass + + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + if self.pool: + await self.pool.close() \ No newline at end of file diff --git a/src/pyfed/storage/migrations.py b/src/pyfed/storage/migrations.py new file mode 100644 index 0000000000000000000000000000000000000000..307ce5166cb5d7646c1b00630ba5a02aa13dd390 --- /dev/null +++ b/src/pyfed/storage/migrations.py @@ -0,0 +1,256 @@ +from typing import Dict, Any, List, Optional +import importlib +import pkgutil +import asyncio +from datetime import datetime +from pathlib import Path +import aiosqlite +import asyncpg +from enum import Enum + +from ..utils.exceptions import MigrationError +from ..utils.logging import get_logger + +logger = get_logger(__name__) + +class DatabaseType(Enum): + """Supported database types.""" + SQLITE = "sqlite" + POSTGRESQL = "postgresql" + +class MigrationInfo: + """Migration metadata.""" + def __init__(self, + version: str, + name: str, + description: str, + applied_at: Optional[datetime] = None): + self.version = version + self.name = name + self.description = description + self.applied_at = applied_at + +class MigrationManager: + """Database migration manager.""" + + def __init__(self, + db_type: DatabaseType, + connection_string: str, + migrations_dir: str = "migrations"): + self.db_type = db_type + self.connection_string = connection_string + self.migrations_dir = Path(migrations_dir) + self.conn = None + + async def initialize(self) -> None: + """Initialize migration system.""" + try: + # Connect to database + if self.db_type == DatabaseType.SQLITE: + self.conn = await aiosqlite.connect(self.connection_string) + else: + self.conn = await asyncpg.connect(self.connection_string) + + # Create migrations table + await self._create_migrations_table() + + except Exception as e: + logger.error(f"Failed to initialize migrations: {e}") + raise MigrationError(f"Migration initialization failed: {e}") + + async def _create_migrations_table(self) -> None: + """Create migrations tracking table.""" + if self.db_type == DatabaseType.SQLITE: + await self.conn.execute(""" + CREATE TABLE IF NOT EXISTS migrations ( + version TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + await self.conn.commit() + else: + await self.conn.execute(""" + CREATE TABLE IF NOT EXISTS migrations ( + version TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + applied_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP + ) + """) + + async def get_applied_migrations(self) -> List[MigrationInfo]: + """Get list of applied migrations.""" + try: + if self.db_type == DatabaseType.SQLITE: + async with self.conn.execute( + "SELECT version, name, description, applied_at FROM migrations ORDER BY version" + ) as cursor: + rows = await cursor.fetchall() + return [ + MigrationInfo( + version=row[0], + name=row[1], + description=row[2], + applied_at=datetime.fromisoformat(row[3]) + ) + for row in rows + ] + else: + rows = await self.conn.fetch( + "SELECT version, name, description, applied_at FROM migrations ORDER BY version" + ) + return [ + MigrationInfo( + version=row['version'], + name=row['name'], + description=row['description'], + applied_at=row['applied_at'] + ) + for row in rows + ] + + except Exception as e: + logger.error(f"Failed to get applied 
migrations: {e}") + raise MigrationError(f"Failed to get applied migrations: {e}") + + async def get_pending_migrations(self) -> List[MigrationInfo]: + """Get list of pending migrations.""" + try: + applied = await self.get_applied_migrations() + applied_versions = {m.version for m in applied} + + pending = [] + for migration in self._load_migrations(): + if migration.version not in applied_versions: + pending.append(migration) + + return sorted(pending, key=lambda m: m.version) + + except Exception as e: + logger.error(f"Failed to get pending migrations: {e}") + raise MigrationError(f"Failed to get pending migrations: {e}") + + def _load_migrations(self) -> List[MigrationInfo]: + """Load migration files.""" + migrations = [] + + for item in sorted(self.migrations_dir.glob("*.sql")): + version = item.stem + with open(item) as f: + description = f.readline().strip("-- ").strip() + migrations.append(MigrationInfo( + version=version, + name=item.name, + description=description + )) + + return migrations + + async def migrate(self, target_version: Optional[str] = None) -> None: + """ + Run migrations up to target version. + + Args: + target_version: Version to migrate to, or None for latest + """ + try: + pending = await self.get_pending_migrations() + if not pending: + logger.info("No pending migrations") + return + + for migration in pending: + if target_version and migration.version > target_version: + break + + logger.info(f"Applying migration {migration.version}: {migration.name}") + + # Read migration file + with open(self.migrations_dir / migration.name) as f: + sql = f.read() + + # Apply migration + if self.db_type == DatabaseType.SQLITE: + await self.conn.executescript(sql) + await self.conn.execute( + """ + INSERT INTO migrations (version, name, description) + VALUES (?, ?, ?) + """, + (migration.version, migration.name, migration.description) + ) + await self.conn.commit() + else: + async with self.conn.transaction(): + await self.conn.execute(sql) + await self.conn.execute( + """ + INSERT INTO migrations (version, name, description) + VALUES ($1, $2, $3) + """, + migration.version, + migration.name, + migration.description + ) + + logger.info(f"Applied migration {migration.version}") + + except Exception as e: + logger.error(f"Migration failed: {e}") + raise MigrationError(f"Migration failed: {e}") + + async def rollback(self, target_version: str) -> None: + """ + Rollback migrations to target version. 
+ + Args: + target_version: Version to rollback to + """ + try: + applied = await self.get_applied_migrations() + to_rollback = [ + m for m in reversed(applied) + if m.version > target_version + ] + + for migration in to_rollback: + logger.info(f"Rolling back migration {migration.version}") + + # Read rollback file + rollback_file = self.migrations_dir / f"{migration.version}_rollback.sql" + if not rollback_file.exists(): + raise MigrationError( + f"No rollback file for migration {migration.version}" + ) + + with open(rollback_file) as f: + sql = f.read() + + # Apply rollback + if self.db_type == DatabaseType.SQLITE: + await self.conn.executescript(sql) + await self.conn.execute( + "DELETE FROM migrations WHERE version = ?", + (migration.version,) + ) + await self.conn.commit() + else: + async with self.conn.transaction(): + await self.conn.execute(sql) + await self.conn.execute( + "DELETE FROM migrations WHERE version = $1", + migration.version + ) + + logger.info(f"Rolled back migration {migration.version}") + + except Exception as e: + logger.error(f"Rollback failed: {e}") + raise MigrationError(f"Rollback failed: {e}") + + async def close(self) -> None: + """Close database connection.""" + if self.conn: + await self.conn.close() \ No newline at end of file diff --git a/src/pyfed/storage/optimization.py b/src/pyfed/storage/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6635f60b5005926b8aa73b5d0b48d6e8cb3ebd --- /dev/null +++ b/src/pyfed/storage/optimization.py @@ -0,0 +1,125 @@ +""" +Storage optimization implementation. +""" + +from typing import Dict, Any, Optional, List, Type +from datetime import datetime +import asyncio +import asyncpg +from dataclasses import dataclass +import json +from enum import Enum + +from ..utils.exceptions import StorageError +from ..utils.logging import get_logger +from .base import StorageBackend + +logger = get_logger(__name__) + +class QueryOptimizer: + """Query optimization and caching.""" + + def __init__(self, + cache_ttl: int = 300, # 5 minutes + max_cache_size: int = 1000): + self.cache_ttl = cache_ttl + self.max_cache_size = max_cache_size + self.query_cache: Dict[str, Dict[str, Any]] = {} + self.query_stats: Dict[str, Dict[str, int]] = {} + self._cleanup_task = None + + async def initialize(self) -> None: + """Initialize optimizer.""" + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + def optimize_query(self, query: str) -> str: + """ + Optimize SQL query. 
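+
+        Dispatches to the helper methods below based on simple keyword checks
+        in the query text; the helpers are currently placeholders that return
+        the query unchanged, so this is an extension point rather than a
+        working rewrite pass.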
+ + Applies optimizations like: + - Index hints + - Join optimizations + - Subquery optimization + """ + # Add index hints + if "WHERE" in query and "ORDER BY" in query: + query = self._add_index_hints(query) + + # Optimize joins + if "JOIN" in query: + query = self._optimize_joins(query) + + # Optimize subqueries + if "SELECT" in query and "IN (SELECT" in query: + query = self._optimize_subqueries(query) + + return query + + def _add_index_hints(self, query: str) -> str: + """Add index hints to query.""" + # Implementation for adding index hints + return query + + def _optimize_joins(self, query: str) -> str: + """Optimize query joins.""" + # Implementation for optimizing joins + return query + + def _optimize_subqueries(self, query: str) -> str: + """Optimize subqueries.""" + # Implementation for optimizing subqueries + return query + + async def get_cached_result(self, query: str, params: tuple) -> Optional[Any]: + """Get cached query result.""" + cache_key = f"{query}:{params}" + if cache_key in self.query_cache: + entry = self.query_cache[cache_key] + if datetime.utcnow().timestamp() < entry['expires']: + return entry['result'] + return None + + async def cache_result(self, query: str, params: tuple, result: Any) -> None: + """Cache query result.""" + cache_key = f"{query}:{params}" + expires = datetime.utcnow().timestamp() + self.cache_ttl + + # Manage cache size + if len(self.query_cache) >= self.max_cache_size: + # Remove least used entries + sorted_stats = sorted( + self.query_stats.items(), + key=lambda x: x[1]['hits'] + ) + to_remove = len(self.query_cache) - self.max_cache_size + 1 + for key, _ in sorted_stats[:to_remove]: + del self.query_cache[key] + del self.query_stats[key] + + self.query_cache[cache_key] = { + 'result': result, + 'expires': expires + } + + # Update stats + if cache_key not in self.query_stats: + self.query_stats[cache_key] = {'hits': 0} + self.query_stats[cache_key]['hits'] += 1 + + async def _cleanup_loop(self) -> None: + """Clean up expired cache entries.""" + while True: + try: + now = datetime.utcnow().timestamp() + expired = [ + key for key, entry in self.query_cache.items() + if entry['expires'] < now + ] + for key in expired: + del self.query_cache[key] + if key in self.query_stats: + del self.query_stats[key] + await asyncio.sleep(60) + except Exception as e: + logger.error(f"Cache cleanup failed: {e}") + await asyncio.sleep(300) \ No newline at end of file diff --git a/src/pyfed/utils/exceptions.py b/src/pyfed/utils/exceptions.py index f9cbc9caf1978486ad446950910175384a3bec0f..50c854e79e1293277dacb24f92c53d657740dc06 100644 --- a/src/pyfed/utils/exceptions.py +++ b/src/pyfed/utils/exceptions.py @@ -21,22 +21,118 @@ class SignatureVerificationError(ActivityPubException): pass class SignatureError(ActivityPubException): - """Raised when signature verification fails.""" + """Raised when signature creation fails.""" pass class AuthenticationError(ActivityPubException): - """Raised when signature verification fails.""" + """Raised when authentication fails.""" pass class RateLimitExceeded(ActivityPubException): - """Raised when signature verification fails.""" + """Raised when rate limit is exceeded.""" pass class WebFingerError(ActivityPubException): - """Raised when signature verification fails.""" + """Raised when WebFinger lookup fails.""" pass class SecurityError(ActivityPubException): - """Raised when signature verification fails.""" + """Raised when security-related errors occur.""" + pass + +class DeliveryError(ActivityPubException): + 
"""Raised when delivery fails.""" + pass + +class FetchError(ActivityPubException): + """Raised when fetching fails.""" + pass + +class DiscoveryError(ActivityPubException): + """Raised when discovery fails.""" + pass + +class ResolutionError(ActivityPubException): + """Raised when rosolving fails.""" + pass + +class HandlerError(ActivityPubException): + """Raised when handling fails""" + pass + +class RateLimitError(ActivityPubException): + """Raised when rate limit is exceeded.""" + pass + +class TokenError(ActivityPubException): + """Raised when token-related errors occur.""" + pass + +class OAuthError(ActivityPubException): + """Raised when OAuth-related errors occur.""" + pass + +class RateLimiterError(ActivityPubException): + """Raised when rate limiter-related errors occur.""" + pass + +class FetchError(ActivityPubException): + """Raised when fetching fails.""" + pass + +class StorageError(ActivityPubException): + """Raised when storage-related errors occur.""" + pass + +class DiscoveryError(ActivityPubException): + """Raised when discovery fails.""" + pass + +class ResolverError(ActivityPubException): + """Raised when resolver-related errors occur.""" + pass + +class SignatureError(ActivityPubException): + """Raised when signature-related errors occur.""" + pass + +class SecurityValidatorError(ActivityPubException): + """Raised when security validator-related errors occur.""" + pass + +class DeliveryError(ActivityPubException): + """Raised when delivery-related errors occur.""" + pass + +class ResourceFetcherError(ActivityPubException): + """Raised when resource fetcher-related errors occur.""" + pass + +class SecurityValidatorError(ActivityPubException): + """Raised when security validator-related errors occur.""" + pass + +class WebFingerError(ActivityPubException): + """Raised when WebFinger-related errors occur.""" + pass + +class KeyManagementError(ActivityPubException): + """Raised when key manager-related errors occur.""" + pass + +class CollectionError(ActivityPubException): + """Raised when collection-related errors occur.""" + pass + +class ContentError(ActivityPubException): + """Raised when content-related errors occur.""" + pass + +class ContentHandlerError(ActivityPubException): + """Raised when content handler-related errors occur.""" + pass + +class CollectionHandlerError(ActivityPubException): + """Raised when collection handler-related errors occur.""" pass diff --git a/tests/conftest.py b/tests/conftest.py index 50159df17aa7631645848fa6df4f565f24ee3a1d..544021b2a24c45ba4f75dff7239831d0ef89f6ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,34 +1,75 @@ -import pytest -from pathlib import Path -import subprocess - """ -PyTest configuration and fixtures. +Global test configuration and fixtures. 
""" -pytest_plugins = ('pytest_asyncio',) - -@pytest.fixture(scope="session", autouse=True) -def test_keys(): - """Ensure test keys exist.""" - keys_dir = Path("tests/fixtures/keys") - private_key = keys_dir / "private.pem" - public_key = keys_dir / "public.pem" - - if not private_key.exists() or not public_key.exists(): - # Generate keys if they don't exist - subprocess.run(["python", "tests/fixtures/generate_keys.py"], check=True) - - return { - "private_key": str(private_key), - "public_key": str(public_key) - } +import pytest +from aiohttp import web +import asyncio +from typing import AsyncGenerator, Callable +from unittest.mock import AsyncMock, Mock + +pytest_plugins = ['pytest_aiohttp'] + +@pytest.fixture +def app() -> web.Application: + """Create base application for testing.""" + return web.Application() + +@pytest.fixture +def auth_middleware() -> Callable: + """Create auth middleware for testing.""" + async def middleware(app: web.Application, handler: Callable): + async def auth_handler(request: web.Request): + request['user'] = {"sub": "test_user"} + return await handler(request) + return auth_handler + return middleware + +@pytest.fixture +def rate_limit_middleware() -> Callable: + """Create rate limit middleware for testing.""" + async def middleware(app: web.Application, handler: Callable): + async def rate_limit_handler(request: web.Request): + # Simple in-memory rate limiting for tests + key = f"rate_limit:{request.remote}" + count = app.get(key, 0) + 1 + app[key] = count + + if count > 5: # Rate limit after 5 requests + return web.Response(status=429) + + return await handler(request) + return rate_limit_handler + return middleware + +@pytest.fixture +async def client(aiohttp_client: Callable, app: web.Application) -> AsyncGenerator: + """Create test client with configured application.""" + return await aiohttp_client(app) + +@pytest.fixture +def mock_storage() -> AsyncMock: + """Create mock storage.""" + storage = AsyncMock() + storage.create_activity = AsyncMock(return_value="test_activity_id") + storage.create_object = AsyncMock(return_value="test_object_id") + storage.is_following = AsyncMock(return_value=False) + return storage + +@pytest.fixture +def mock_delivery() -> AsyncMock: + """Create mock delivery service.""" + delivery = AsyncMock() + delivery.deliver_activity = AsyncMock( + return_value={"success": ["inbox1"], "failed": []} + ) + return delivery @pytest.fixture -def signature_verifier(test_keys): - """Create a test signature verifier with generated keys.""" - from pyfed.security.http_signatures import HTTPSignatureVerifier - return HTTPSignatureVerifier( - private_key_path=test_keys["private_key"], - public_key_path=test_keys["public_key"] - ) \ No newline at end of file +def mock_resolver() -> AsyncMock: + """Create mock resolver.""" + resolver = AsyncMock() + resolver.resolve_actor = AsyncMock( + return_value={"id": "https://example.com/users/test"} + ) + return resolver \ No newline at end of file diff --git a/tests/tests_models/__init__.py b/tests/unit_tests/models/__init__.py similarity index 100% rename from tests/tests_models/__init__.py rename to tests/unit_tests/models/__init__.py diff --git a/tests/tests_models/test_activities.py b/tests/unit_tests/models/test_activities.py similarity index 94% rename from tests/tests_models/test_activities.py rename to tests/unit_tests/models/test_activities.py index 544ec39e5db8e1e3a61b1e6dd0507f6f36232031..5f516325245806df0c0e4248ae07efb22dd8289a 100644 --- a/tests/tests_models/test_activities.py +++ 
b/tests/unit_tests/models/test_activities.py @@ -14,7 +14,6 @@ def test_create_activity(): """Test creating a Create activity.""" activity = APCreate( id="https://example.com/activity/123", - type="Create", actor="https://example.com/user/1", object={ "id": "https://example.com/note/123", @@ -30,7 +29,6 @@ def test_update_activity(): """Test creating an Update activity.""" activity = APUpdate( id="https://example.com/activity/123", - type="Update", actor="https://example.com/user/1", object={ "id": "https://example.com/note/123", @@ -45,7 +43,6 @@ def test_delete_activity(): """Test creating a Delete activity.""" activity = APDelete( id="https://example.com/activity/123", - type="Delete", actor="https://example.com/user/1", object="https://example.com/note/123" ) @@ -56,7 +53,6 @@ def test_follow_activity(): """Test creating a Follow activity.""" activity = APFollow( id="https://example.com/activity/123", - type="Follow", actor="https://example.com/user/1", object="https://example.com/user/2" ) @@ -67,7 +63,6 @@ def test_undo_activity(): """Test creating an Undo activity.""" activity = APUndo( id="https://example.com/activity/123", - type="Undo", actor="https://example.com/user/1", object={ "id": "https://example.com/activity/456", @@ -83,7 +78,6 @@ def test_like_activity(): """Test creating a Like activity.""" activity = APLike( id="https://example.com/activity/123", - type="Like", actor="https://example.com/user/1", object="https://example.com/note/123" ) @@ -94,7 +88,6 @@ def test_announce_activity(): """Test creating an Announce activity.""" activity = APAnnounce( id="https://example.com/activity/123", - type="Announce", actor="https://example.com/user/1", object="https://example.com/note/123" ) @@ -105,7 +98,6 @@ def test_activity_with_target(): """Test creating an activity with a target.""" activity = APCreate( id="https://example.com/activity/123", - type="Create", actor="https://example.com/user/1", object="https://example.com/note/123", target="https://example.com/collection/1" @@ -116,7 +108,6 @@ def test_activity_with_result(): """Test creating an activity with a result.""" activity = APCreate( id="https://example.com/activity/123", - type="Create", actor="https://example.com/user/1", object="https://example.com/note/123", result={ @@ -132,7 +123,6 @@ def test_invalid_activity_missing_actor(): with pytest.raises(ValidationError): APCreate( id="https://example.com/activity/123", - type="Create", object="https://example.com/note/123" ) @@ -141,6 +131,5 @@ def test_invalid_activity_missing_object(): with pytest.raises(ValidationError): APCreate( id="https://example.com/activity/123", - type="Create", actor="https://example.com/user/1" ) diff --git a/tests/tests_models/test_actors.py b/tests/unit_tests/models/test_actors.py similarity index 94% rename from tests/tests_models/test_actors.py rename to tests/unit_tests/models/test_actors.py index 29659a4fe627d1f3d6e9a44ed8fec3dfa69d7624..1a1b72b8a4786e9d55858994f066b7fc319a61f0 100644 --- a/tests/tests_models/test_actors.py +++ b/tests/unit_tests/models/test_actors.py @@ -12,7 +12,6 @@ from pyfed.models import ( def test_valid_person(): person = APPerson( id="https://example.com/users/alice", - type="Person", name="Alice", inbox="https://example.com/users/alice/inbox", outbox="https://example.com/users/alice/outbox", @@ -21,12 +20,12 @@ def test_valid_person(): assert person.type == "Person" assert str(person.inbox) == "https://example.com/users/alice/inbox" assert str(person.outbox) == "https://example.com/users/alice/outbox" + 
print(person.preferred_username) assert person.preferred_username == "alice" def test_person_with_optional_fields(): person = APPerson( id="https://example.com/users/alice", - type="Person", name="Alice", inbox="https://example.com/users/alice/inbox", outbox="https://example.com/users/alice/outbox", @@ -44,7 +43,6 @@ def test_invalid_person_missing_required(): with pytest.raises(ValidationError): APPerson( id="https://example.com/users/alice", - type="Person", name="Alice" # Missing required inbox and outbox ) @@ -53,7 +51,6 @@ def test_invalid_person_invalid_url(): with pytest.raises(ValidationError): APPerson( id="https://example.com/users/alice", - type="Person", name="Alice", inbox="not-a-url", # Invalid URL outbox="https://example.com/users/alice/outbox" @@ -62,7 +59,6 @@ def test_invalid_person_invalid_url(): def test_valid_group(): group = APGroup( id="https://example.com/groups/admins", - type="Group", name="Administrators", inbox="https://example.com/groups/admins/inbox", outbox="https://example.com/groups/admins/outbox" @@ -73,7 +69,6 @@ def test_valid_group(): def test_valid_organization(): org = APOrganization( id="https://example.com/org/acme", - type="Organization", name="ACME Corporation", inbox="https://example.com/org/acme/inbox", outbox="https://example.com/org/acme/outbox" @@ -84,7 +79,6 @@ def test_valid_organization(): def test_valid_application(): app = APApplication( id="https://example.com/apps/bot", - type="Application", name="Bot Application", inbox="https://example.com/apps/bot/inbox", outbox="https://example.com/apps/bot/outbox" @@ -95,7 +89,6 @@ def test_valid_application(): def test_valid_service(): service = APService( id="https://example.com/services/api", - type="Service", name="API Service", inbox="https://example.com/services/api/inbox", outbox="https://example.com/services/api/outbox" @@ -106,7 +99,6 @@ def test_valid_service(): def test_actor_with_public_key(): person = APPerson( id="https://example.com/users/alice", - type="Person", name="Alice", inbox="https://example.com/users/alice/inbox", outbox="https://example.com/users/alice/outbox", @@ -122,7 +114,6 @@ def test_actor_with_public_key(): def test_actor_with_endpoints(): person = APPerson( id="https://example.com/users/alice", - type="Person", name="Alice", inbox="https://example.com/users/alice/inbox", outbox="https://example.com/users/alice/outbox", diff --git a/tests/tests_models/test_collections.py b/tests/unit_tests/models/test_collections.py similarity index 93% rename from tests/tests_models/test_collections.py rename to tests/unit_tests/models/test_collections.py index 1d715263452da32f6a3e294069565da76a013224..b54833f6d6acf0519b15de4cca81293f0779fb6a 100644 --- a/tests/tests_models/test_collections.py +++ b/tests/unit_tests/models/test_collections.py @@ -13,7 +13,6 @@ def test_valid_collection(): """Test creating a valid Collection.""" collection = APCollection( id="https://example.com/collection/123", - type="Collection", total_items=10, items=["https://example.com/item/1", "https://example.com/item/2"] ) @@ -25,7 +24,6 @@ def test_collection_with_optional_fields(): """Test creating a Collection with all optional fields.""" collection = APCollection( id="https://example.com/collection/123", - type="Collection", total_items=2, current="https://example.com/collection/123/current", first="https://example.com/collection/123/first", @@ -49,7 +47,6 @@ def test_valid_ordered_collection(): """Test creating a valid OrderedCollection.""" collection = APOrderedCollection( 
id="https://example.com/collection/123", - type="OrderedCollection", total_items=2, ordered_items=["https://example.com/item/1", "https://example.com/item/2"] ) @@ -60,7 +57,6 @@ def test_valid_collection_page(): """Test creating a valid CollectionPage.""" page = APCollectionPage( id="https://example.com/collection/123/page/1", - type="CollectionPage", part_of="https://example.com/collection/123", items=["https://example.com/item/1"] ) @@ -71,7 +67,6 @@ def test_collection_page_with_navigation(): """Test creating a CollectionPage with navigation links.""" page = APCollectionPage( id="https://example.com/collection/123/page/2", - type="CollectionPage", part_of="https://example.com/collection/123", items=["https://example.com/item/2"], next="https://example.com/collection/123/page/3", @@ -84,7 +79,6 @@ def test_valid_ordered_collection_page(): """Test creating a valid OrderedCollectionPage.""" page = APOrderedCollectionPage( id="https://example.com/collection/123/page/1", - type="OrderedCollectionPage", part_of="https://example.com/collection/123", ordered_items=["https://example.com/item/1"], start_index=0 @@ -97,7 +91,6 @@ def test_invalid_ordered_page_negative_index(): with pytest.raises(ValidationError): APOrderedCollectionPage( id="https://example.com/collection/123/page/1", - type="OrderedCollectionPage", start_index=-1 ) @@ -105,7 +98,6 @@ def test_collection_with_object_items(): """Test creating a Collection with APObject items.""" collection = APCollection( id="https://example.com/collection/123", - type="Collection", items=[{ "id": "https://example.com/item/1", "type": "Note", @@ -119,7 +111,6 @@ def test_ordered_collection_empty(): """Test creating an empty OrderedCollection.""" collection = APOrderedCollection( id="https://example.com/collection/123", - type="OrderedCollection", total_items=0 ) assert collection.total_items == 0 diff --git a/tests/tests_models/test_factories.py b/tests/unit_tests/models/test_factories.py similarity index 100% rename from tests/tests_models/test_factories.py rename to tests/unit_tests/models/test_factories.py diff --git a/tests/tests_models/test_imports.py b/tests/unit_tests/models/test_imports.py similarity index 100% rename from tests/tests_models/test_imports.py rename to tests/unit_tests/models/test_imports.py diff --git a/tests/tests_models/test_interactions.py b/tests/unit_tests/models/test_interactions.py similarity index 97% rename from tests/tests_models/test_interactions.py rename to tests/unit_tests/models/test_interactions.py index 3464332d56a7c0627fb8212c63300e6054c63344..7389fdd9b7ef5df93bfd7a573a6258a81f4a3e00 100644 --- a/tests/tests_models/test_interactions.py +++ b/tests/unit_tests/models/test_interactions.py @@ -72,13 +72,11 @@ def test_collection_with_pagination(): items = [ APNote( id=f"https://example.com/note/{i}", - type="Note", content=f"Note {i}" ) for i in range(1, 6) ] collection = APCollection( id="https://example.com/collection/123", - type="Collection", total_items=len(items), items=items, first="https://example.com/collection/123?page=1", @@ -186,7 +184,6 @@ def test_mention_in_content(): mention = APMention( id="https://example.com/mention/123", - type="Mention", href=mentioned.id, name=f"@{mentioned.preferred_username}" ) @@ -206,21 +203,18 @@ def test_collection_pagination_interaction(): items = [ APNote( id=f"https://example.com/note/{i}", - type="Note", content=f"Note {i}" ) for i in range(1, 11) ] collection = APOrderedCollection( id="https://example.com/collection/123", - type="OrderedCollection", 
total_items=len(items), ordered_items=items[:5] # First page ) page = APOrderedCollectionPage( id="https://example.com/collection/123/page/1", - type="OrderedCollectionPage", part_of=collection.id, ordered_items=items[5:], # Second page start_index=5 @@ -256,7 +250,6 @@ def test_actor_relationships(): relationship = APRelationship( id="https://example.com/relationship/1", - type="Relationship", subject=member.id, object=group.id, relationship="member" @@ -270,7 +263,6 @@ def test_content_with_multiple_attachments(): """Test creating content with multiple types of attachments.""" image = APImage( id="https://example.com/image/1", - type="Image", url="https://example.com/image.jpg", width=800, height=600 @@ -278,7 +270,6 @@ def test_content_with_multiple_attachments(): document = APDocument( id="https://example.com/document/1", - type="Document", name="Specification", url="https://example.com/spec.pdf" ) @@ -314,7 +305,6 @@ def test_event_series(): collection = APOrderedCollection( id="https://example.com/collection/workshop-series", - type="OrderedCollection", name="Workshop Series", ordered_items=events ) diff --git a/tests/tests_models/test_links.py b/tests/unit_tests/models/test_links.py similarity index 94% rename from tests/tests_models/test_links.py rename to tests/unit_tests/models/test_links.py index 40a3f3f7561fa972496853d85712196cb8f63fba..5e47e287e65c8f839ec3c5f5f195dc97dc461c35 100644 --- a/tests/tests_models/test_links.py +++ b/tests/unit_tests/models/test_links.py @@ -10,7 +10,7 @@ def test_valid_link(): """Test creating a valid Link object.""" link = APLink( id="https://example.com/link/123", - type="Link", + href="https://example.com/resource" ) assert link.type == "Link" @@ -20,7 +20,7 @@ def test_link_with_optional_fields(): """Test creating a Link with all optional fields.""" link = APLink( id="https://example.com/link/123", - type="Link", + href="https://example.com/resource", name="Test Link", hreflang="en", @@ -40,7 +40,6 @@ def test_link_with_rel(): """Test creating a Link with relationship fields.""" link = APLink( id="https://example.com/link/123", - type="Link", href="https://example.com/resource", rel=["canonical", "alternate"] ) @@ -51,7 +50,6 @@ def test_valid_mention(): """Test creating a valid Mention object.""" mention = APMention( id="https://example.com/mention/123", - type="Mention", href="https://example.com/user/alice", name="@alice" ) @@ -64,7 +62,6 @@ def test_invalid_link_missing_href(): with pytest.raises(ValidationError): APLink( id="https://example.com/link/123", - type="Link" ) def test_invalid_link_invalid_url(): @@ -72,7 +69,6 @@ def test_invalid_link_invalid_url(): with pytest.raises(ValidationError): APLink( id="https://example.com/link/123", - type="Link", href="not-a-url" ) @@ -81,7 +77,6 @@ def test_invalid_link_invalid_media_type(): with pytest.raises(ValidationError): APLink( id="https://example.com/link/123", - type="Link", href="https://example.com/resource", media_type="invalid/type" ) diff --git a/tests/tests_models/test_objects.py b/tests/unit_tests/models/test_objects.py similarity index 100% rename from tests/tests_models/test_objects.py rename to tests/unit_tests/models/test_objects.py diff --git a/tests/tests_serializers/__init__.py b/tests/unit_tests/serializers/__init__.py similarity index 100% rename from tests/tests_serializers/__init__.py rename to tests/unit_tests/serializers/__init__.py diff --git a/tests/tests_serializers/test_serialization.py b/tests/unit_tests/serializers/test_serialization.py similarity index 55% 
rename from tests/tests_serializers/test_serialization.py rename to tests/unit_tests/serializers/test_serialization.py index 5f9468fe2a7c1d2e590da55a80645d1c82812951..9d30b32eedf11935f2298dc2f2575637cb9f82fb 100644 --- a/tests/tests_serializers/test_serialization.py +++ b/tests/unit_tests/serializers/test_serialization.py @@ -1,13 +1,14 @@ """ -test_serialization.py -This module contains tests for JSON serialization of ActivityPub objects. +Tests for ActivityPub serialization. """ + import pytest -from datetime import datetime -from pydantic import ValidationError +from datetime import datetime, timezone +import json + from pyfed.models import ( - APObject, APPerson, APNote, APImage, APCollection, - APCreate, APLike, APFollow, APPlace, APEvent + APObject, APNote, APPerson, APCollection, + APCreate, APPlace, APEvent ) from pyfed.serializers.json_serializer import ActivityPubSerializer @@ -19,32 +20,29 @@ def test_serialize_ap_object(): name="Test Object", content="This is a test object." ) - serialized = ActivityPubSerializer.serialize(obj) - assert '"@context"' in serialized - assert '"type": "Object"' in serialized - assert '"name": "Test Object"' in serialized - -def test_serialize_without_context(): - """Test serialization without @context field.""" - obj = APObject( - id="https://example.com/object/123", - type="Object", - name="Test Object" - ) - serialized = ActivityPubSerializer.serialize(obj, include_context=False) - assert '"@context"' not in serialized + serialized = obj.serialize() + + # Verify serialization + assert serialized["@context"] == "https://www.w3.org/ns/activitystreams" + assert serialized["id"] == "https://example.com/object/123" + assert serialized["type"] == "Object" + assert serialized["name"] == "Test Object" + assert serialized["content"] == "This is a test object." 
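+
+def test_serialize_output_is_json_ready():
+    """Illustrative check: serialized output should be JSON-serializable."""
+    obj = APObject(
+        id="https://example.com/object/123",
+        type="Object",
+        name="Test Object"
+    )
+    serialized = obj.serialize()
+    # Round-trip through json to confirm only JSON-native types are emitted.
+    assert json.loads(json.dumps(serialized)) == serialized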
def test_serialize_with_datetime(): """Test serialization of objects with datetime fields.""" - now = datetime.utcnow() + now = datetime.now(timezone.utc) obj = APObject( id="https://example.com/object/123", type="Object", published=now, updated=now ) - serialized = ActivityPubSerializer.serialize(obj) - assert now.isoformat() in serialized + serialized = obj.serialize() + + # Verify datetime serialization + assert serialized["published"] == now.isoformat() + assert serialized["updated"] == now.isoformat() def test_serialize_nested_objects(): """Test serialization of objects with nested objects.""" @@ -61,9 +59,12 @@ def test_serialize_nested_objects(): content="Hello, World!", attributed_to=author ) - serialized = ActivityPubSerializer.serialize(note) - assert '"attributedTo"' in serialized - assert '"type": "Person"' in serialized + serialized = note.serialize() + + # Verify nested object serialization + assert serialized["attributedTo"]["id"] == "https://example.com/users/alice" + assert serialized["attributedTo"]["type"] == "Person" + assert serialized["attributedTo"]["name"] == "Alice" def test_serialize_collection(): """Test serialization of collections.""" @@ -80,10 +81,13 @@ def test_serialize_collection(): total_items=len(items), items=items ) - serialized = ActivityPubSerializer.serialize(collection) - assert '"type": "Collection"' in serialized - assert '"totalItems": 3' in serialized - assert '"items"' in serialized + serialized = collection.serialize() + + # Verify collection serialization + assert serialized["type"] == "Collection" + assert serialized["totalItems"] == 3 + assert len(serialized["items"]) == 3 + assert all(item["type"] == "Note" for item in serialized["items"]) def test_serialize_activity(): """Test serialization of activities.""" @@ -98,46 +102,44 @@ def test_serialize_activity(): actor="https://example.com/users/alice", object=note ) - serialized = ActivityPubSerializer.serialize(create) - assert '"type": "Create"' in serialized - assert '"object"' in serialized - assert '"content": "Hello, World!"' in serialized - -def test_serialize_with_urls(): - """Test serialization of objects with URL fields.""" - place = APPlace( - id="https://example.com/places/123", - type="Place", - name="Test Place", - latitude=51.5074, - longitude=-0.1278 - ) - event = APEvent( - id="https://example.com/events/123", - type="Event", - name="Test Event", - location=place, - url="https://example.com/events/123/details" - ) - serialized = ActivityPubSerializer.serialize(event) - assert '"url":' in serialized - assert '"location"' in serialized + serialized = create.serialize() + + # Verify activity serialization + assert serialized["type"] == "Create" + assert serialized["actor"] == "https://example.com/users/alice" + assert serialized["object"]["type"] == "Note" + assert serialized["object"]["content"] == "Hello, World!" def test_deserialize_ap_object(): """Test basic object deserialization.""" - json_str = ''' - { + data = { "@context": "https://www.w3.org/ns/activitystreams", "type": "Object", "id": "https://example.com/object/123", "name": "Test Object", "content": "This is a test object." } - ''' - obj = ActivityPubSerializer.deserialize(json_str, APObject) + obj = ActivityPubSerializer.deserialize(data, APObject) + + # Verify deserialization + assert str(obj.id) == "https://example.com/object/123" assert obj.type == "Object" assert obj.name == "Test Object" + assert obj.content == "This is a test object." 
+ +def test_deserialize_from_json_string(): + """Test deserialization from JSON string.""" + json_str = json.dumps({ + "type": "Object", + "id": "https://example.com/object/123", + "name": "Test Object" + }) + obj = ActivityPubSerializer.deserialize(json_str, APObject) + + # Verify deserialization from string assert str(obj.id) == "https://example.com/object/123" + assert obj.type == "Object" + assert obj.name == "Test Object" def test_deserialize_invalid_json(): """Test deserialization of invalid JSON.""" @@ -146,9 +148,9 @@ def test_deserialize_invalid_json(): def test_deserialize_missing_required_fields(): """Test deserialization with missing required fields.""" - json_str = '{"type": "Object", "name": "Test"}' # Missing required 'id' - with pytest.raises(ValidationError): - ActivityPubSerializer.deserialize(json_str, APObject) + data = {"type": "Object", "name": "Test"} # Missing required 'id' + with pytest.raises(Exception): # Pydantic will raise validation error + ActivityPubSerializer.deserialize(data, APObject) def test_serialize_deserialize_complex_object(): """Test round-trip serialization and deserialization.""" @@ -159,32 +161,27 @@ def test_serialize_deserialize_complex_object(): to=["https://example.com/users/bob"], cc=["https://www.w3.org/ns/activitystreams#Public"] ) - serialized = ActivityPubSerializer.serialize(original) + serialized = original.serialize() deserialized = ActivityPubSerializer.deserialize(serialized, APNote) - assert deserialized.id == original.id + + # Verify round-trip + assert str(deserialized.id) == str(original.id) + assert deserialized.type == original.type assert deserialized.content == original.content assert deserialized.to == original.to assert deserialized.cc == original.cc -def test_serialize_with_custom_json_options(): - """Test serialization with custom JSON options.""" - obj = APObject( - id="https://example.com/object/123", - type="Object", - name="Test Object" - ) - serialized = ActivityPubSerializer.serialize(obj, indent=2) - assert '\n "' in serialized # Check for indentation - def test_deserialize_with_extra_fields(): """Test deserialization with extra fields in JSON.""" - json_str = ''' - { + data = { "type": "Object", "id": "https://example.com/object/123", "name": "Test Object", "extra_field": "Should be ignored" } - ''' - obj = ActivityPubSerializer.deserialize(json_str, APObject) - assert not hasattr(obj, "extra_field") + obj = ActivityPubSerializer.deserialize(data, APObject) + + # Verify extra fields are handled + assert str(obj.id) == "https://example.com/object/123" + assert obj.type == "Object" + assert obj.name == "Test Object"