"""
Adapter WebSocket para comunicação em tempo real
"""
import json
import asyncio
from typing import Dict, Any, Callable, Set
from dataclasses import dataclass


@dataclass
class WebSocketConnection:
    """
    Bookkeeping record for a single open WebSocket connection.

    Field order is part of the positional constructor signature.
    """

    # Unique identifier of this connection.
    id: str
    # Identifier of the user who owns the connection.
    user_id: str
    # Coroutine function awaited with one JSON-encoded ``str`` frame to push
    # data to the client.
    send_func: Callable[..., Any]
    # ISO-8601 timestamp string (local time, no timezone) of registration.
    connected_at: str


class SocketAdapter:
    """
    WebSocket adapter for real-time chat.

    Keeps a registry of open connections (by connection id) grouped per user,
    forwards incoming chat messages to ``brain.stream_process`` and relays the
    streamed chunks back to the originating connection as JSON text frames.
    """

    def __init__(self, brain):
        """
        Args:
            brain: Backend whose ``stream_process(message=..., user_id=...,
                session_id=...)`` is consumed with ``async for``; the yielded
                chunks are joined with ``"".join`` so they must be ``str``.
        """
        self.brain = brain
        # connection_id -> connection record
        self.connections: Dict[str, WebSocketConnection] = {}
        # user_id -> ids of that user's currently registered connections
        self.user_connections: Dict[str, Set[str]] = {}

    async def handle_connect(self, connection_id: str, user_id: str, send_func: Callable):
        """
        Register a new WebSocket connection and send it a welcome frame.

        Args:
            connection_id: Identifier of the connection. If it is already
                registered, the stale registration is removed first.
            user_id: Owner of the connection.
            send_func: Coroutine function awaited with one JSON ``str`` frame.
        """
        from datetime import datetime

        # Re-using an id must not leave it in the previous owner's set;
        # otherwise broadcast_to_user(previous_owner) would reach this socket.
        if connection_id in self.connections:
            await self.handle_disconnect(connection_id)

        self.connections[connection_id] = WebSocketConnection(
            id=connection_id,
            user_id=user_id,
            send_func=send_func,
            connected_at=datetime.now().isoformat()
        )
        self.user_connections.setdefault(user_id, set()).add(connection_id)

        # Welcome message
        await self._send_to_connection(connection_id, {
            "type": "connected",
            "connection_id": connection_id,
            "message": "Conectado ao assistente IA"
        })

    async def handle_disconnect(self, connection_id: str):
        """
        Unregister a connection; does nothing if the id is unknown.

        The owner's entry in ``user_connections`` is removed once its last
        connection is gone.
        """
        conn = self.connections.pop(connection_id, None)
        if conn is None:
            return

        user_conns = self.user_connections.get(conn.user_id)
        if user_conns is not None:
            user_conns.discard(connection_id)
            if not user_conns:
                del self.user_connections[conn.user_id]

    async def handle_message(self, connection_id: str, data: Dict[str, Any]):
        """
        Process a chat message received over WebSocket.

        Frames sent to the connection, in order: ``typing``/``start``; one
        ``chunk`` per item streamed by ``brain.stream_process``;
        ``message_complete`` with the joined response (or an ``error`` frame if
        the stream raised); and always ``typing``/``stop`` last.

        Args:
            connection_id: Originating connection; ignored if not registered.
            data: Decoded client payload. ``"message"`` holds the text; the
                optional ``"session_id"`` is passed through to the brain.
        """
        conn = self.connections.get(connection_id)
        if conn is None:
            return

        # A missing, null or non-string "message" is answered with the
        # empty-message error instead of crashing on ``.strip()``.
        raw_message = data.get("message")
        message = raw_message.strip() if isinstance(raw_message, str) else ""

        if not message:
            await self._send_to_connection(connection_id, {
                "type": "error",
                "message": "Mensagem vazia"
            })
            return

        # Typing indicator
        await self._send_to_connection(connection_id, {
            "type": "typing",
            "status": "start"
        })

        try:
            # Stream chunks to the client while accumulating the full reply
            full_response = []
            async for chunk in self.brain.stream_process(
                message=message,
                user_id=conn.user_id,
                session_id=data.get("session_id")
            ):
                full_response.append(chunk)
                await self._send_to_connection(connection_id, {
                    "type": "chunk",
                    "content": chunk
                })

            # End-of-message confirmation
            await self._send_to_connection(connection_id, {
                "type": "message_complete",
                "full_response": "".join(full_response)
            })

        except Exception as e:
            # NOTE(review): the raw exception text reaches the client; consider
            # a generic message if backend errors can contain internal details.
            await self._send_to_connection(connection_id, {
                "type": "error",
                "message": str(e)
            })

        finally:
            await self._send_to_connection(connection_id, {
                "type": "typing",
                "status": "stop"
            })

    async def broadcast_to_user(self, user_id: str, message: Dict[str, Any]):
        """
        Send ``message`` to every registered connection of ``user_id``.

        Iterates over a snapshot of the id set: a failed send triggers
        ``handle_disconnect``, which mutates (and may delete) the live set, and
        iterating it directly would raise "Set changed size during iteration".
        """
        for conn_id in list(self.user_connections.get(user_id, ())):
            await self._send_to_connection(conn_id, message)

    async def _send_to_connection(self, connection_id: str, data: Dict[str, Any]):
        """
        Send ``data`` as a JSON text frame to one connection.

        Unknown ids are ignored. If the send itself fails, the connection is
        assumed closed and is unregistered.

        Raises:
            TypeError, ValueError: ``data`` is not JSON-serializable. This is a
                caller bug and is raised rather than being mistaken for a dead
                socket (which would disconnect a healthy client).
        """
        conn = self.connections.get(connection_id)
        if conn is None:
            return

        payload = json.dumps(data)
        try:
            await conn.send_func(payload)
        except Exception:
            # Connection most likely closed
            await self.handle_disconnect(connection_id)

    def get_stats(self) -> Dict[str, Any]:
        """
        Return connection statistics: total registered connections, number of
        distinct users, and the connection count per user.
        """
        return {
            "total_connections": len(self.connections),
            "unique_users": len(self.user_connections),
            "connections_per_user": {
                user_id: len(conns)
                for user_id, conns in self.user_connections.items()
            },
        }