import string

import hypothesis
import hypothesis.strategies as st
import pytest


def _encode(s):
    # Grab first character.
    c = s[0]
    s = s[1:]
    # Setup initial state.
    last = c
    count = 1
    for c in s:
        if c != last:
            # We found a fresh character, yield what we had buffered.
            yield f"{count}{last}"
            # Reset state and continue loop.
            count = 1
            last = c
        else:
            # Found a repeat, just increase count.
            count += 1
    yield f"{count}{last}"


def encode(s):
    """Run-length encode *s*, e.g. ``"aabccc"`` -> ``"2a1b3c"``.

    A falsy argument (such as ``""``) is handed back unchanged.
    """
    if not s:
        return s
    tokens = _encode(s)
    return "".join(tokens)


def _partition(lst: list, n: int):
    for idx in range(0, len(lst), n):
        yield lst[idx : idx + n]


def partition(lst: list, n: int):
    """Split *lst* into consecutive chunks of *n* items.

    Returns a lazy iterator of slices. Degenerate requests (an empty *lst*
    or a chunk size below 1) are answered with *lst* itself, unwrapped.
    """
    # Guard clause: nothing sensible to chunk, so hand the input straight back.
    if not lst or n < 1:
        return lst
    return _partition(lst, n)


def _decode(s):
    """Yield the expanded run for each ``<digit><char>`` pair in *s*.

    *s* is consumed two characters at a time; an odd trailing character makes
    the pair unpacking fail with ValueError.
    """
    for token in partition(s, 2):
        # First character is the repeat count, second is the repeated char.
        count, char = token
        yield char * int(count)


def decode(s):
    """Expand a run-length encoded string, e.g. ``"2a1b3c"`` -> ``"aabccc"``.

    A falsy argument (such as ``""``) is handed back unchanged.
    """
    if not s:
        return s
    runs = _decode(s)
    return "".join(runs)


#############
# TEST
##########


# Known-good pairs of (encoded input, expected decoded output).
decode_test_data = [
    ("1a", "a"),
    ("3a", "aaa"),
    ("2a2b", "aabb"),
    ("1a3b1c2d", "abbbcdd"),
    ("4A3B2C1D2A", "AAAABBBCCDAA"),
]

# The same pairs flipped around: (plain input, expected encoding).
encode_test_data = [pair[::-1] for pair in decode_test_data]


@pytest.mark.parametrize(("encoded_str", "expected_str"), decode_test_data)
def test_runlength_decoding(encoded_str, expected_str):
    """Each known encoding expands to its expected plain string."""
    decoded = decode(encoded_str)
    assert decoded == expected_str


@pytest.mark.parametrize(("s", "expected_str"), encode_test_data)
def test_runlength_encoding(s, expected_str):
    """Each known plain string compresses to its expected encoding."""
    encoded = encode(s)
    assert encoded == expected_str


@pytest.mark.parametrize(
    ("lst", "expected"),
    [([1, 2, 3, 4, 5, 6], [[1, 2], [3, 4], [5, 6]])],
)
def test_partition_basic(lst, expected):
    """An evenly divisible list splits into equal pairs."""
    chunks = list(partition(lst, 2))
    assert chunks == expected


@hypothesis.given(st.integers(0, 100), st.integers(0, 100))
def test_partition_more(n, r):
    """Property-check ``partition`` with chunk size *n* over ``list(range(r))``."""
    lst = list(range(r))
    partitioned_list = list(partition(lst, n))

    if r == 0:
        # Empty input is returned as-is, so nothing is produced.
        assert partitioned_list == []
    elif n == 0:
        # A chunk size below 1 returns the original, unwrapped list.
        assert partitioned_list == lst
    elif n < r:
        # Concatenating the chunks must give back the original input.
        flattened = [item for part in partitioned_list for item in part]
        assert flattened == lst

        # Expect r // n full chunks plus one short trailing chunk when n does
        # not divide r evenly. The count is computed in its own assignment:
        # written inline as ``assert a == b + 1 if c else d`` it parses as
        # ``assert (a == b + 1) if c else d`` and the check silently vanishes
        # whenever the remainder is zero.
        full_chunks, remainder = divmod(r, n)
        expected_count = full_chunks + 1 if remainder else full_chunks
        assert len(partitioned_list) == expected_count

        if remainder:
            # Only the last chunk may be short, and it holds the leftovers.
            assert len(partitioned_list[-1]) == remainder
            partitioned_list = partitioned_list[:-1]
        assert all(len(part) == n for part in partitioned_list)
    else:
        # n >= r (both positive): everything fits in a single chunk.
        assert partitioned_list == [lst]


@hypothesis.given(s=st.text(alphabet=list(string.ascii_letters)))
def test_roundtrip(s):
    """Decoding the encoding of any ASCII-letter string yields the original."""
    encoded = encode(s)
    assert decode(encoded) == s