"""Tests for text message batching across all gateway adapters.

When a user sends a long message, the messaging client splits it at the
platform's character limit.  Each adapter should buffer rapid successive
text messages from the same session and aggregate them before dispatching.

Covers: Discord, Matrix, WeCom, and the adaptive delay logic for
Telegram and Feishu.
"""

import asyncio
from unittest.mock import AsyncMock

import pytest

from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType, SessionSource


# =====================================================================
# Helpers
# =====================================================================

def _make_event(
    text: str,
    platform: Platform,
    chat_id: str = "12345",
    msg_type: MessageType = MessageType.TEXT,
) -> MessageEvent:
    return MessageEvent(
        text=text,
        message_type=msg_type,
        source=SessionSource(platform=platform, chat_id=chat_id, chat_type="dm"),
    )


# =====================================================================
# Discord text batching
# =====================================================================

def _make_discord_adapter():
    """Create a minimal DiscordAdapter for testing text batching."""
    from plugins.platforms.discord.adapter import DiscordAdapter

    config = PlatformConfig(enabled=True, token="test-token")
    adapter = object.__new__(DiscordAdapter)
    adapter._platform = Platform.DISCORD
    adapter.config = config
    adapter._pending_text_batches = {}
    adapter._pending_text_batch_tasks = {}
    adapter._text_batch_delay_seconds = 0.1  # fast for tests
    adapter._text_batch_split_delay_seconds = 0.3  # fast for tests
    adapter._active_sessions = {}
    adapter._pending_messages = {}
    adapter._message_handler = AsyncMock()
    adapter.handle_message = AsyncMock()
    return adapter


class TestDiscordTextBatching:
    @pytest.mark.asyncio
    async def test_single_message_dispatched_after_delay(self):
        adapter = _make_discord_adapter()
        event = _make_event("hello world", Platform.DISCORD)

        adapter._enqueue_text_event(event)

        # Not dispatched yet
        adapter.handle_message.assert_not_called()

        # Wait for flush
        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        dispatched = adapter.handle_message.call_args[0][0]
        assert dispatched.text == "hello world"

    @pytest.mark.asyncio
    async def test_split_messages_aggregated(self):
        """Two rapid messages from the same chat should be merged."""
        adapter = _make_discord_adapter()

        adapter._enqueue_text_event(_make_event("Part one of a long", Platform.DISCORD))
        await asyncio.sleep(0.02)
        adapter._enqueue_text_event(_make_event("message that was split.", Platform.DISCORD))

        adapter.handle_message.assert_not_called()

        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        text = adapter.handle_message.call_args[0][0].text
        assert "Part one" in text
        assert "split" in text

    @pytest.mark.asyncio
    async def test_three_way_split_aggregated(self):
        adapter = _make_discord_adapter()

        adapter._enqueue_text_event(_make_event("chunk 1", Platform.DISCORD))
        await asyncio.sleep(0.02)
        adapter._enqueue_text_event(_make_event("chunk 2", Platform.DISCORD))
        await asyncio.sleep(0.02)
        adapter._enqueue_text_event(_make_event("chunk 3", Platform.DISCORD))

        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        text = adapter.handle_message.call_args[0][0].text
        assert "chunk 1" in text
        assert "chunk 2" in text
        assert "chunk 3" in text

    @pytest.mark.asyncio
    async def test_different_chats_not_merged(self):
        adapter = _make_discord_adapter()

        adapter._enqueue_text_event(_make_event("from A", Platform.DISCORD, chat_id="111"))
        adapter._enqueue_text_event(_make_event("from B", Platform.DISCORD, chat_id="222"))

        await asyncio.sleep(0.2)

        assert adapter.handle_message.call_count == 2

    @pytest.mark.asyncio
    async def test_batch_cleans_up_after_flush(self):
        adapter = _make_discord_adapter()

        adapter._enqueue_text_event(_make_event("test", Platform.DISCORD))
        await asyncio.sleep(0.2)

        assert len(adapter._pending_text_batches) == 0

    @pytest.mark.asyncio
    async def test_adaptive_delay_for_near_limit_chunk(self):
        """Chunks near the 2000-char limit should trigger longer delay."""
        adapter = _make_discord_adapter()
        # Simulate a chunk near Discord's 2000-char split point
        long_text = "x" * 1950
        adapter._enqueue_text_event(_make_event(long_text, Platform.DISCORD))

        # After the short delay (0.1s), should NOT have flushed yet (split delay is 0.3s)
        await asyncio.sleep(0.15)
        adapter.handle_message.assert_not_called()

        # After the split delay, should be flushed
        await asyncio.sleep(0.25)
        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_shield_protects_handle_message_from_cancel(self):
        """Regression guard: a follow-up chunk arriving while
        handle_message is mid-flight must NOT cancel the running
        dispatch.  _enqueue_text_event fires prior_task.cancel() on
        every new chunk; without asyncio.shield around handle_message
        the cancel propagates into the agent's streaming request and
        aborts the response.
        """
        adapter = _make_discord_adapter()

        handle_started = asyncio.Event()
        release_handle = asyncio.Event()
        first_handle_cancelled = asyncio.Event()
        first_handle_completed = asyncio.Event()
        call_count = [0]

        async def slow_handle(event):
            call_count[0] += 1
            # Only the first call (batch 1) is the one we're protecting.
            if call_count[0] == 1:
                handle_started.set()
                try:
                    await release_handle.wait()
                    first_handle_completed.set()
                except asyncio.CancelledError:
                    first_handle_cancelled.set()
                    raise
            # Second call (batch 2) returns immediately — not the subject
            # of this test.

        adapter.handle_message = slow_handle

        # Prime batch 1 and wait for it to land inside handle_message.
        adapter._enqueue_text_event(_make_event("batch 1", Platform.DISCORD))
        await asyncio.wait_for(handle_started.wait(), timeout=1.0)

        # A new chunk arrives — _enqueue_text_event fires
        # prior_task.cancel() on batch 1's flush task, which is
        # currently awaiting inside handle_message.
        adapter._enqueue_text_event(_make_event("batch 2 follow-up", Platform.DISCORD))

        # Let the cancel propagate.
        await asyncio.sleep(0.05)

        # CRITICAL ASSERTION: batch 1's handle_message must NOT have
        # been cancelled.  Without asyncio.shield this assertion fails
        # because CancelledError propagates from the flush task's
        # `await self.handle_message(event)` into slow_handle.
        assert not first_handle_cancelled.is_set(), (
            "handle_message for batch 1 was cancelled by a follow-up "
            "chunk — asyncio.shield is missing or broken"
        )

        # Release batch 1's handle_message and let it complete.
        release_handle.set()
        await asyncio.wait_for(first_handle_completed.wait(), timeout=1.0)
        assert first_handle_completed.is_set()

        # Cleanup
        for task in list(adapter._pending_text_batch_tasks.values()):
            task.cancel()
        await asyncio.sleep(0.01)


# =====================================================================
# Matrix text batching
# =====================================================================

def _make_matrix_adapter():
    """Create a minimal MatrixAdapter for testing text batching."""
    from gateway.platforms.matrix import MatrixAdapter

    config = PlatformConfig(enabled=True, token="test-token")
    adapter = object.__new__(MatrixAdapter)
    adapter._platform = Platform.MATRIX
    adapter.config = config
    adapter._pending_text_batches = {}
    adapter._pending_text_batch_tasks = {}
    adapter._text_batch_delay_seconds = 0.1
    adapter._text_batch_split_delay_seconds = 0.3
    adapter._active_sessions = {}
    adapter._pending_messages = {}
    adapter._message_handler = AsyncMock()
    adapter.handle_message = AsyncMock()
    return adapter


class TestMatrixTextBatching:
    @pytest.mark.asyncio
    async def test_single_message_dispatched_after_delay(self):
        adapter = _make_matrix_adapter()
        event = _make_event("hello world", Platform.MATRIX)

        adapter._enqueue_text_event(event)

        adapter.handle_message.assert_not_called()
        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        assert adapter.handle_message.call_args[0][0].text == "hello world"

    @pytest.mark.asyncio
    async def test_split_messages_aggregated(self):
        adapter = _make_matrix_adapter()

        adapter._enqueue_text_event(_make_event("first part", Platform.MATRIX))
        await asyncio.sleep(0.02)
        adapter._enqueue_text_event(_make_event("second part", Platform.MATRIX))

        adapter.handle_message.assert_not_called()
        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        text = adapter.handle_message.call_args[0][0].text
        assert "first part" in text
        assert "second part" in text

    @pytest.mark.asyncio
    async def test_different_rooms_not_merged(self):
        adapter = _make_matrix_adapter()

        adapter._enqueue_text_event(_make_event("room A", Platform.MATRIX, chat_id="!aaa:matrix.org"))
        adapter._enqueue_text_event(_make_event("room B", Platform.MATRIX, chat_id="!bbb:matrix.org"))

        await asyncio.sleep(0.2)

        assert adapter.handle_message.call_count == 2

    @pytest.mark.asyncio
    async def test_adaptive_delay_for_near_limit_chunk(self):
        """Chunks near the 4000-char limit should trigger longer delay."""
        adapter = _make_matrix_adapter()
        long_text = "x" * 3950
        adapter._enqueue_text_event(_make_event(long_text, Platform.MATRIX))

        await asyncio.sleep(0.15)
        adapter.handle_message.assert_not_called()

        await asyncio.sleep(0.25)
        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_batch_cleans_up_after_flush(self):
        adapter = _make_matrix_adapter()
        adapter._enqueue_text_event(_make_event("test", Platform.MATRIX))
        await asyncio.sleep(0.2)
        assert len(adapter._pending_text_batches) == 0


# =====================================================================
# WeCom text batching
# =====================================================================

def _make_wecom_adapter():
    """Create a minimal WeComAdapter for testing text batching."""
    from gateway.platforms.wecom import WeComAdapter

    config = PlatformConfig(enabled=True, token="test-token")
    adapter = object.__new__(WeComAdapter)
    adapter._platform = Platform.WECOM
    adapter.config = config
    adapter._pending_text_batches = {}
    adapter._pending_text_batch_tasks = {}
    adapter._text_batch_delay_seconds = 0.1
    adapter._text_batch_split_delay_seconds = 0.3
    adapter._active_sessions = {}
    adapter._pending_messages = {}
    adapter._message_handler = AsyncMock()
    adapter.handle_message = AsyncMock()
    return adapter


class TestWeComTextBatching:
    @pytest.mark.asyncio
    async def test_single_message_dispatched_after_delay(self):
        adapter = _make_wecom_adapter()
        event = _make_event("hello world", Platform.WECOM)

        adapter._enqueue_text_event(event)

        adapter.handle_message.assert_not_called()
        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        assert adapter.handle_message.call_args[0][0].text == "hello world"

    @pytest.mark.asyncio
    async def test_split_messages_aggregated(self):
        adapter = _make_wecom_adapter()

        adapter._enqueue_text_event(_make_event("first part", Platform.WECOM))
        await asyncio.sleep(0.02)
        adapter._enqueue_text_event(_make_event("second part", Platform.WECOM))

        adapter.handle_message.assert_not_called()
        await asyncio.sleep(0.2)

        adapter.handle_message.assert_called_once()
        text = adapter.handle_message.call_args[0][0].text
        assert "first part" in text
        assert "second part" in text

    @pytest.mark.asyncio
    async def test_different_chats_not_merged(self):
        adapter = _make_wecom_adapter()

        adapter._enqueue_text_event(_make_event("chat A", Platform.WECOM, chat_id="chat_a"))
        adapter._enqueue_text_event(_make_event("chat B", Platform.WECOM, chat_id="chat_b"))

        await asyncio.sleep(0.2)

        assert adapter.handle_message.call_count == 2

    @pytest.mark.asyncio
    async def test_adaptive_delay_for_near_limit_chunk(self):
        """Chunks near the 4000-char limit should trigger longer delay."""
        adapter = _make_wecom_adapter()
        long_text = "x" * 3950
        adapter._enqueue_text_event(_make_event(long_text, Platform.WECOM))

        await asyncio.sleep(0.15)
        adapter.handle_message.assert_not_called()

        await asyncio.sleep(0.25)
        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_batch_cleans_up_after_flush(self):
        adapter = _make_wecom_adapter()
        adapter._enqueue_text_event(_make_event("test", Platform.WECOM))
        await asyncio.sleep(0.2)
        assert len(adapter._pending_text_batches) == 0


# =====================================================================
# Telegram adaptive delay (PR #6891)
# =====================================================================

def _make_telegram_adapter():
    """Create a minimal TelegramAdapter for testing adaptive delay."""
    from gateway.platforms.telegram import TelegramAdapter

    config = PlatformConfig(enabled=True, token="test-token")
    adapter = object.__new__(TelegramAdapter)
    adapter._platform = Platform.TELEGRAM
    adapter.config = config
    adapter._pending_text_batches = {}
    adapter._pending_text_batch_tasks = {}
    adapter._text_batch_delay_seconds = 0.1
    adapter._text_batch_split_delay_seconds = 0.3
    adapter._active_sessions = {}
    adapter._pending_messages = {}
    adapter._message_handler = AsyncMock()
    adapter.handle_message = AsyncMock()
    return adapter


class TestTelegramAdaptiveDelay:
    @pytest.mark.asyncio
    async def test_short_chunk_uses_normal_delay(self):
        adapter = _make_telegram_adapter()
        adapter._enqueue_text_event(_make_event("short msg", Platform.TELEGRAM))

        # Should flush after the normal 0.1s delay
        await asyncio.sleep(0.15)
        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_near_limit_chunk_uses_split_delay(self):
        """A chunk near the 4096-char limit should trigger longer delay."""
        adapter = _make_telegram_adapter()
        long_text = "x" * 4050  # near the 4096 limit
        adapter._enqueue_text_event(_make_event(long_text, Platform.TELEGRAM))

        # After the short delay, should NOT have flushed yet
        await asyncio.sleep(0.15)
        adapter.handle_message.assert_not_called()

        # After the split delay, should be flushed
        await asyncio.sleep(0.25)
        adapter.handle_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_split_continuation_merged(self):
        """Two near-limit chunks should both be merged."""
        adapter = _make_telegram_adapter()

        adapter._enqueue_text_event(_make_event("x" * 4050, Platform.TELEGRAM))
        await asyncio.sleep(0.05)
        adapter._enqueue_text_event(_make_event("continuation text", Platform.TELEGRAM))

        # Short chunk arrived → should use normal delay now
        await asyncio.sleep(0.15)
        adapter.handle_message.assert_called_once()
        text = adapter.handle_message.call_args[0][0].text
        assert "continuation text" in text


# =====================================================================
# Feishu adaptive delay
# =====================================================================

def _make_feishu_adapter():
    """Create a minimal FeishuAdapter for testing adaptive delay."""
    from gateway.platforms.feishu import FeishuAdapter, FeishuBatchState

    config = PlatformConfig(enabled=True, token="test-token")
    adapter = object.__new__(FeishuAdapter)
    adapter._platform = Platform.FEISHU
    adapter.config = config
    batch_state = FeishuBatchState()
    adapter._pending_text_batches = batch_state.events
    adapter._pending_text_batch_tasks = batch_state.tasks
    adapter._pending_text_batch_counts = batch_state.counts
    adapter._text_batch_delay_seconds = 0.1
    adapter._text_batch_split_delay_seconds = 0.3
    adapter._text_batch_max_messages = 20
    adapter._text_batch_max_chars = 50000
    adapter._active_sessions = {}
    adapter._pending_messages = {}
    adapter._message_handler = AsyncMock()
    adapter._handle_message_with_guards = AsyncMock()
    return adapter


class TestFeishuAdaptiveDelay:
    @pytest.mark.asyncio
    async def test_short_chunk_uses_normal_delay(self):
        adapter = _make_feishu_adapter()
        event = _make_event("short msg", Platform.FEISHU)
        await adapter._enqueue_text_event(event)

        await asyncio.sleep(0.15)
        adapter._handle_message_with_guards.assert_called_once()

    @pytest.mark.asyncio
    async def test_near_limit_chunk_uses_split_delay(self):
        """A chunk near the 4096-char limit should trigger longer delay."""
        adapter = _make_feishu_adapter()
        long_text = "x" * 4050
        event = _make_event(long_text, Platform.FEISHU)
        await adapter._enqueue_text_event(event)

        await asyncio.sleep(0.15)
        adapter._handle_message_with_guards.assert_not_called()

        await asyncio.sleep(0.25)
        adapter._handle_message_with_guards.assert_called_once()

    @pytest.mark.asyncio
    async def test_split_continuation_merged(self):
        adapter = _make_feishu_adapter()

        await adapter._enqueue_text_event(_make_event("x" * 4050, Platform.FEISHU))
        await asyncio.sleep(0.05)
        await adapter._enqueue_text_event(_make_event("continuation text", Platform.FEISHU))

        await asyncio.sleep(0.15)
        adapter._handle_message_with_guards.assert_called_once()
        text = adapter._handle_message_with_guards.call_args[0][0].text
        assert "continuation text" in text
