import api.streaming as streaming


class _FakeSession:
    def __init__(self):
        self.messages = [{'role': 'user', 'content': 'hello'}]
        self.title = 'Untitled'
        self.saved = 0

    def save(self):
        self.saved += 1


def test_persist_stream_error_appends_assistant_message(monkeypatch):
    """A rate-limit error is appended as an assistant message and saved once."""
    session = _FakeSession()

    def fake_title_from(messages, fallback):
        # Parameter names mirror the real helper in case it is called by keyword.
        return 'Updated title'

    monkeypatch.setattr(streaming, 'get_session', lambda _sid: session)
    monkeypatch.setattr(streaming, 'title_from', fake_title_from)

    streaming._persist_stream_error(
        'sess-1',
        'rate_limit',
        'HTTP 429 upstream rate limit',
        'Try again in a moment.',
    )

    assert session.saved == 1
    assert session.title == 'Updated title'

    last = session.messages[-1]
    assert last['role'] == 'assistant'
    # The rendered error must carry the headline, the raw detail and the hint.
    for fragment in (
        'Rate limit reached',
        'HTTP 429 upstream rate limit',
        'Try again in a moment.',
    ):
        assert fragment in last['content']


def test_persist_stream_error_deduplicates_same_message(monkeypatch):
    """Persisting an error identical to the last assistant message is a no-op."""
    fake = _FakeSession()
    content = '**⚠️ Error:** Boom'
    fake.messages.append({'role': 'assistant', 'content': content})
    count_before = len(fake.messages)
    monkeypatch.setattr(streaming, 'get_session', lambda _sid: fake)

    streaming._persist_stream_error('sess-1', 'error', 'Boom')

    assert fake.saved == 0
    # Checking only the last message cannot detect a duplicate append (the
    # duplicate would also end with ``content``), so pin the length as well.
    assert len(fake.messages) == count_before
    assert fake.messages[-1]['content'] == content


def test_persist_stream_error_uses_persisted_session_when_cache_is_stale(monkeypatch):
    """When the cached session has no messages, the persisted copy is updated."""
    persisted = _FakeSession()
    persisted.messages = [{'role': 'user', 'content': 'test'}]
    cached = _FakeSession()
    cached.messages = []

    class FakeSessionLoader:
        # Replaces ``streaming.Session``; only ``load`` is provided.
        @staticmethod
        def load(_sid):
            return persisted

    monkeypatch.setattr(streaming, 'get_session', lambda _sid: cached)
    monkeypatch.setattr(streaming, 'Session', FakeSessionLoader)
    monkeypatch.setattr(streaming, 'title_from', lambda messages, fallback: 'Updated title')

    streaming._persist_stream_error(
        'sess-1',
        'rate_limit',
        'HTTP 429 upstream rate limit',
        'Try again in a moment.',
    )

    first, last = persisted.messages[0], persisted.messages[-1]
    assert persisted.saved == 1
    assert first['role'] == 'user'
    assert last['role'] == 'assistant'
    assert 'Rate limit reached' in last['content']


def test_ensure_final_response_message_appends_missing_assistant():
    """A missing final assistant reply is appended after the existing history."""
    final_text = 'API call failed after 3 retries'
    history = [{'role': 'user', 'content': 'test'}]

    result = streaming._ensure_final_response_message(history, final_text)

    tail = result[-1]
    assert tail['role'] == 'assistant'
    assert tail['content'] == final_text
    assert result[0]['role'] == 'user'


def test_ensure_final_response_message_skips_duplicate_assistant():
    """An identical trailing assistant message must not be appended again."""
    final_text = 'API call failed after 3 retries'
    messages = [
        {'role': 'user', 'content': 'test'},
        {'role': 'assistant', 'content': final_text},
    ]
    # Snapshot the expected history before the call. Comparing against
    # ``messages`` itself would pass trivially if the helper appended in place
    # and returned the same list object.
    expected = [dict(message) for message in messages]

    updated = streaming._ensure_final_response_message(messages, final_text)

    assert updated == expected
