import logging
import re
from collections import Counter
from typing import Dict, List, Optional

import numpy as np

logger = logging.getLogger(__name__)


class TextEmbeddingService:
    """Simple TF-IDF based text embedding service.

    Fit the service on a corpus first with ``fit_transform``. Until it is
    fitted, ``compute_embedding`` returns all-zero vectors. Every embedding
    this class returns is a ``float32`` vector of length ``max_features``:
    the vocabulary part is zero-padded up to that size, and the vector is
    L2-normalised unless it is all zeros.
    """

    # Pre-compiled tokenizer patterns, hoisted so they are built once. They
    # are applied in this order: URLs first, so that a '#fragment' or '@'
    # inside a URL is removed together with the URL.
    _URL_RE = re.compile(r'https?://\S+')
    _MENTION_RE = re.compile(r'@\w+')
    _HASHTAG_RE = re.compile(r'#\w+')
    # Whole words of 3+ chars: an ASCII letter followed by ASCII letters/digits.
    _TOKEN_RE = re.compile(r'\b[a-z][a-z0-9]{2,}\b')

    def __init__(self, max_features: int = 1000):
        """Create an unfitted service.

        Args:
            max_features: Maximum vocabulary size. This is also the length of
                every embedding vector the service returns.

        Raises:
            ValueError: If ``max_features`` is negative.
        """
        if max_features < 0:
            # A negative value would make the vocabulary slice drop tokens
            # silently and make np.zeros fail later; reject it up front.
            raise ValueError(f"max_features must be non-negative, got {max_features}")
        self.max_features = max_features
        # token -> column index in the embedding vector
        self.vocab: Dict[str, int] = {}
        # idf[i] is the smoothed inverse document frequency of vocab index i
        self.idf: np.ndarray = np.array([])
        # Set to True by _build_vocab
        self._initialized = False

    def _tokenize(self, text: str) -> List[str]:
        """Lower-case *text* and split it into word tokens.

        URLs, ``@mentions`` and ``#hashtags`` are removed before tokenizing.
        Only whole words of three or more characters are kept, where a word
        starts with an ASCII letter and continues with ASCII letters/digits.
        """
        text = text.lower()
        for pattern in (self._URL_RE, self._MENTION_RE, self._HASHTAG_RE):
            text = pattern.sub('', text)
        return self._TOKEN_RE.findall(text)

    def _build_vocab(self, texts: List[str]) -> None:
        """Fit the vocabulary and IDF weights on *texts*.

        The vocabulary keeps the ``max_features`` tokens with the highest
        document frequency. Ties are broken by first appearance in *texts*.
        IDF uses the smoothed form ``log((n_docs + 1) / (df + 1)) + 1``.
        """
        doc_freq = Counter()
        for text in texts:
            # dict.fromkeys de-duplicates one document's tokens while keeping
            # first-occurrence order. Each token therefore counts at most once
            # per document, and the tie-break order stays deterministic.
            doc_freq.update(list(dict.fromkeys(self._tokenize(text))))

        # most_common orders equal counts by first encounter, which is the
        # same result as a stable descending sort followed by a slice.
        top_tokens = doc_freq.most_common(self.max_features)
        self.vocab = {token: idx for idx, (token, _) in enumerate(top_tokens)}

        n_docs = len(texts)
        df = np.array([count for _, count in top_tokens], dtype=np.float64)
        self.idf = np.log((n_docs + 1) / (df + 1)) + 1

        self._initialized = True
        logger.info("Built vocab with %d tokens", len(self.vocab))

    def compute_embedding(self, text: str) -> np.ndarray:
        """Compute the TF-IDF embedding of *text*.

        Term frequency is ``log1p(count)``. Tokens outside the fitted
        vocabulary are ignored. The vector is L2-normalised (unless it is all
        zeros) and zero-padded to ``max_features``.

        Returns:
            A ``float32`` vector of length ``max_features``. If
            ``max_features`` was lowered after fitting, the length is the
            vocabulary size instead. The vector is all zeros if the service
            has not been fitted.
        """
        if not self._initialized:
            # Same dtype as the fitted path, so results can be stacked/compared.
            return np.zeros(self.max_features, dtype=np.float32)

        counts = Counter(tok for tok in self._tokenize(text) if tok in self.vocab)

        # Allocate the padded size directly; vocab indices are < len(vocab).
        vec = np.zeros(max(self.max_features, len(self.vocab)))
        for token, count in counts.items():
            idx = self.vocab[token]
            vec[idx] = np.log1p(count) * self.idf[idx]

        norm = np.linalg.norm(vec)
        if norm > 0:
            vec /= norm

        return vec.astype(np.float32)

    def compute_embeddings(self, texts: List[str]) -> List[np.ndarray]:
        """Return ``compute_embedding(t)`` for each text, in order."""
        return [self.compute_embedding(t) for t in texts]

    def fit_transform(self, texts: List[str]) -> List[np.ndarray]:
        """Fit the vocabulary on *texts*, then return their embeddings."""
        self._build_vocab(texts)
        return self.compute_embeddings(texts)

    def cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of *a* and *b*.

        Returns 0.0 if either vector has zero norm.
        """
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    def compute_taste_profile(
        self,
        seed_texts: List[str],
        weights: Optional[List[float]] = None,
    ) -> np.ndarray:
        """Average the embeddings of *seed_texts* into one L2-normalised profile.

        Args:
            seed_texts: Texts that describe the desired taste.
            weights: Optional per-text weights, one per entry of *seed_texts*.
                ``None`` (the default) weights every text equally.

        Returns:
            A ``float32`` vector. It is all zeros when *seed_texts* is empty
            or when the averaged vector has zero norm (e.g. the service is not
            fitted or no seed text has an in-vocabulary token).

        Raises:
            ValueError/TypeError: From ``np.average``, if *weights* does not
                match *seed_texts* in length.
            ZeroDivisionError: From ``np.average``, if the weights sum to zero.
        """
        if not seed_texts:
            return np.zeros(self.max_features, dtype=np.float32)

        embeddings = np.asarray(self.compute_embeddings(seed_texts))
        # With weights=None this is the plain mean over seed texts.
        profile = np.average(embeddings, axis=0, weights=weights)

        norm = np.linalg.norm(profile)
        if norm > 0:
            profile = profile / norm

        return profile.astype(np.float32)
An unhandled error has occurred. Reload 🗙