import json
import logging
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from ..models import (
    FeedItem, FeedAuthor, FeedSignalRecord, FeedAccountRecord,
    LikeFeedState, StoredFeedState, FeedItemViewModel, FeedImportPayload
)

logger = logging.getLogger(__name__)


class SqliteFeedStore:
    """SQLite-backed store for feed data, embeddings, and signals."""

    def __init__(self, db_path: str = "app_data/feeds.db"):
        self.db_path = db_path
        self._init_db()

    @contextmanager
    def _get_connection(self):
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def _init_db(self) -> None:
        """Initialize database tables."""
        with self._get_connection() as conn:
            # Accounts table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS accounts (
                    did TEXT PRIMARY KEY,
                    handle TEXT NOT NULL,
                    pds_url TEXT,
                    updated_at TEXT NOT NULL
                )
            """)

            # Feed items table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS feed_items (
                    subject_uri TEXT PRIMARY KEY,
                    actor_did TEXT NOT NULL,
                    author_did TEXT NOT NULL,
                    author_handle TEXT NOT NULL,
                    author_display_name TEXT NOT NULL,
                    text TEXT NOT NULL,
                    origin TEXT NOT NULL,
                    labels TEXT,
                    created_at TEXT,
                    imported_at TEXT NOT NULL
                )
            """)

            # Authors table (liked network)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS liked_authors (
                    did TEXT PRIMARY KEY,
                    actor_did TEXT NOT NULL,
                    handle TEXT NOT NULL,
                    display_name TEXT NOT NULL,
                    like_count INTEGER NOT NULL,
                    source TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
            """)

            # Embeddings table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS embeddings (
                    subject_uri TEXT PRIMARY KEY,
                    embedding BLOB NOT NULL,
                    updated_at TEXT NOT NULL
                )
            """)

            # Signals table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS signals (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    actor_did TEXT NOT NULL,
                    collection TEXT NOT NULL,
                    subject_uri TEXT NOT NULL,
                    signal TEXT NOT NULL,
                    weight REAL NOT NULL,
                    generator_id TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    metadata TEXT
                )
            """)

            # Indexes
            conn.execute("CREATE INDEX IF NOT EXISTS idx_items_actor ON feed_items(actor_did)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_items_author ON feed_items(author_did)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_authors_actor ON liked_authors(actor_did)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_signals_actor ON signals(actor_did)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_signals_generator ON signals(actor_did, generator_id)")

            conn.commit()

    def import_feed(self, request: FeedImportPayload) -> None:
        """Import a liked network payload."""
        with self._get_connection() as conn:
            now = datetime.utcnow().isoformat()
            
            account = request.account
            account_did = account.get("did", "")
            account_handle = account.get("handle", "")
            pds_url = account.get("pdsUrl")

            # Upsert account
            conn.execute(
                """INSERT OR REPLACE INTO accounts (did, handle, pds_url, updated_at)
                   VALUES (?, ?, ?, ?)""",
                (account_did, account_handle, pds_url, now)
            )

            # Insert/update authors
            for author in request.authors:
                conn.execute(
                    """INSERT OR REPLACE INTO liked_authors
                       (did, actor_did, handle, display_name, like_count, source, updated_at)
                       VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (author.get("did"), account_did, author.get("handle"), author.get("displayName"),
                     author.get("likeCount", 0), author.get("source", "unknown"), now)
                )

            # Insert/update items
            for item in request.items:
                labels = item.get("labels", [])
                conn.execute(
                    """INSERT OR REPLACE INTO feed_items
                       (subject_uri, actor_did, author_did, author_handle, author_display_name,
                        text, origin, labels, created_at, imported_at)
                       VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                    (item.get("subjectUri"), account_did, item.get("authorDid"), item.get("authorHandle"),
                     item.get("authorDisplayName"), item.get("text"), item.get("origin"),
                     json.dumps(labels) if labels else None,
                     item.get("createdAt"), now)
                )

            conn.commit()
            logger.info(f"Imported {len(request.items)} items for {account_did}")

    def get_like_feed(self, actor_did: Optional[str]) -> LikeFeedState:
        """Get the liked feed state for an actor."""
        if not actor_did:
            return LikeFeedState()

        with self._get_connection() as conn:
            # Get items
            cursor = conn.execute(
                """SELECT * FROM feed_items WHERE actor_did = ? ORDER BY imported_at DESC""",
                (actor_did,)
            )
            items = []
            for row in cursor:
                items.append(FeedItem(
                    subject_uri=row["subject_uri"],
                    actor_did=row["actor_did"],
                    author_did=row["author_did"],
                    author_handle=row["author_handle"],
                    author_display_name=row["author_display_name"],
                    text=row["text"],
                    origin=row["origin"],
                    labels=json.loads(row["labels"]) if row["labels"] else [],
                    created_at=row["created_at"],
                    imported_at=row["imported_at"]
                ))

            # Get authors
            cursor = conn.execute(
                """SELECT * FROM liked_authors WHERE actor_did = ? ORDER BY like_count DESC""",
                (actor_did,)
            )
            authors = []
            for row in cursor:
                authors.append(FeedAuthor(
                    did=row["did"],
                    handle=row["handle"],
                    display_name=row["display_name"],
                    like_count=row["like_count"],
                    source=row["source"]
                ))

            return LikeFeedState(
                actor_did=actor_did,
                liked_authors=authors,
                total_likes=len(items),
                feed_items=items
            )

    def get_state(self, actor_did: Optional[str], generator_id: str,
                  ranking_service, definition) -> StoredFeedState:
        """Get feed state with ranking applied."""
        if not actor_did:
            return StoredFeedState(generator_id=generator_id)

        like_feed = self.get_like_feed(actor_did)
        if not like_feed.feed_items:
            return StoredFeedState(actor_did=actor_did, generator_id=generator_id)

        # Get signals
        signals = self.get_signals(actor_did, generator_id)

        # Rank items using the ranking service
        ranked_items = ranking_service.rank(
            like_feed.feed_items,
            signals,
            definition,
            self.get_embeddings([i.subject_uri for i in like_feed.feed_items])
        )

        return StoredFeedState(
            actor_did=actor_did,
            generator_id=generator_id,
            feed_items=ranked_items,
        )

    def add_signal(self, actor_did: str, signal: FeedSignalRecord) -> None:
        """Add a signal for an actor."""
        with self._get_connection() as conn:
            conn.execute(
                """INSERT INTO signals
                   (actor_did, collection, subject_uri, signal, weight, generator_id, created_at, metadata)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                (actor_did, signal.collection, signal.subject_uri, signal.signal,
                 signal.weight, signal.generator_id, signal.created_at,
                 json.dumps(signal.metadata) if signal.metadata else None)
            )
            conn.commit()

    def get_signals(self, actor_did: str, generator_id: Optional[str] = None) -> List[FeedSignalRecord]:
        """Get signals for an actor, optionally filtered by generator."""
        with self._get_connection() as conn:
            if generator_id:
                cursor = conn.execute(
                    """SELECT * FROM signals WHERE actor_did = ? AND generator_id = ?
                       ORDER BY created_at DESC""",
                    (actor_did, generator_id)
                )
            else:
                cursor = conn.execute(
                    """SELECT * FROM signals WHERE actor_did = ? ORDER BY created_at DESC""",
                    (actor_did,)
                )

            signals = []
            for row in cursor:
                signals.append(FeedSignalRecord(
                    collection=row["collection"],
                    subject_uri=row["subject_uri"],
                    signal=row["signal"],
                    weight=row["weight"],
                    generator_id=row["generator_id"],
                    created_at=row["created_at"],
                    metadata=json.loads(row["metadata"]) if row["metadata"] else {}
                ))
            return signals

    def store_embedding(self, subject_uri: str, embedding: np.ndarray) -> None:
        """Store text embedding for a post."""
        with self._get_connection() as conn:
            conn.execute(
                """INSERT OR REPLACE INTO embeddings (subject_uri, embedding, updated_at)
                   VALUES (?, ?, ?)""",
                (subject_uri, embedding.tobytes(), datetime.utcnow().isoformat())
            )
            conn.commit()

    def get_embeddings(self, subject_uris: List[str]) -> Dict[str, np.ndarray]:
        """Get embeddings for a list of posts."""
        with self._get_connection() as conn:
            placeholders = ",".join(["?"] * len(subject_uris))
            cursor = conn.execute(
                f"SELECT * FROM embeddings WHERE subject_uri IN ({placeholders})",
                tuple(subject_uris)
            )

            results = {}
            for row in cursor:
                uri = row["subject_uri"]
                embedding = np.frombuffer(row["embedding"], dtype=np.float32)
                results[uri] = embedding

            return results

    def get_all_embeddings(self) -> Dict[str, np.ndarray]:
        """Get all stored embeddings."""
        with self._get_connection() as conn:
            cursor = conn.execute("SELECT * FROM embeddings")
            results = {}
            for row in cursor:
                uri = row["subject_uri"]
                embedding = np.frombuffer(row["embedding"], dtype=np.float32)
                results[uri] = embedding
            return results