from vergeml.display import _parse_ansi, BufferOutput, ProgressBar, Table, Display, StatsTable
import time
import re

def test_parse_chars():
    """Plain text with no control characters yields a single 'ch' token."""
    expected = [('ch', 'Hello World!')]
    tokens = list(_parse_ansi("Hello World!"))
    assert tokens == expected

def test_parse_newline():
    """A trailing newline is emitted as its own 'nl' token with no payload."""
    tokens = list(_parse_ansi("Hello World!\n"))
    assert tokens == [
        ('ch', 'Hello World!'),
        ('nl', None),
    ]

def test_parse_newline_cr():
    """A carriage return splits the text into two 'ch' runs around a 'cr' token."""
    text = "Hello World!\rHallo Welt!!!\n"
    expected = [
        ('ch', 'Hello World!'),
        ('cr', None),
        ('ch', 'Hallo Welt!!!'),
        ('nl', None),
    ]
    assert list(_parse_ansi(text)) == expected

def test_parse_up():
    """ESC[<n>A is parsed into an ('up', n) token carrying the integer count."""
    text = "Hello!\033[10ABye!\n"
    expected = [
        ('ch', 'Hello!'),
        ('up', 10),
        ('ch', 'Bye!'),
        ('nl', None),
    ]
    assert list(_parse_ansi(text)) == expected

def test_parse_down():
    """ESC[<n>B is parsed into a ('down', n) token carrying the integer count."""
    text = "Hello!\033[1BBye!\n"
    expected = [
        ('ch', 'Hello!'),
        ('down', 1),
        ('ch', 'Bye!'),
        ('nl', None),
    ]
    assert list(_parse_ansi(text)) == expected

def test_parse_comb():
    """Back-to-back escape sequences each produce their own movement token, in order."""
    tokens = list(_parse_ansi("Hello!\033[1B\033[2ABye!"))
    assert tokens == [
        ('ch', 'Hello!'),
        ('down', 1),
        ('up', 2),
        ('ch', 'Bye!'),
    ]

def test_buffero():
    """A single printed line is returned verbatim, trailing newline included."""
    out = BufferOutput()
    print("Hello World!", file=out)
    expected = "Hello World!\n"
    assert out.getvalue() == expected

def test_buffero_long():
    """An overlong printed line is cut off: the '-YYY...' tail does not survive."""
    out = BufferOutput()
    long_line = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-YYYYYYYYYYYYYYYYYYYYY"
    print(long_line, file=out)
    expected = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n"
    assert out.getvalue() == expected

def test_buffero_multiline():
    """Successive printed lines accumulate in the buffer in order."""
    out = BufferOutput()
    for line in ("Hello World 1!", "Hello World 2!", "Hello World 3!"):
        print(line, file=out)
    assert out.getvalue() == "Hello World 1!\nHello World 2!\nHello World 3!\n"

def test_buffero_cr():
    """A carriage return rewrites from column 0 without erasing the rest of the line."""
    out = BufferOutput()
    out.write("Hello World!\rHa\n")
    expected = "Hallo World!\n"
    assert out.getvalue() == expected

def test_buffero_too_far_up():
    """A cursor-up larger than the buffer's height keeps writing on the same line."""
    out = BufferOutput()
    out.write("Hello World!\033[12A!!\n")
    expected = "Hello World!!!\n"
    assert out.getvalue() == expected

def test_buffero_too_far_down():
    """A cursor-down past the last line keeps writing on the same line."""
    out = BufferOutput()
    out.write("Hello World!\033[12B!!\n")
    expected = "Hello World!!!\n"
    assert out.getvalue() == expected

def test_buffero_stats():
    """Moving the cursor up lets later prints overwrite earlier lines, leaving lines below intact."""
    out = BufferOutput()
    for line in ("Accuracy: -", "Validation Accuracy: -", "", "Training ..."):
        print(line, file=out)
    # Jump back four lines so the next prints land on the first "Accuracy" line.
    out.write("\033[4A")
    for line in ("Accuracy: 0.73", "Validation Accuracy: 0.70"):
        print(line, file=out)
    assert out.getvalue() == "Accuracy: 0.73\nValidation Accuracy: 0.70\n\nTraining ...\n"

def test_progress():
    """ProgressBar renders a percentage, a bar, an item counter and an it/sec rate."""
    out = BufferOutput()
    bar = ProgressBar(range(100), file=out)
    bar.start()
    bar.update(1)
    # Presumably gives a non-zero elapsed time so a numeric rate is rendered -- confirm.
    time.sleep(0.001)
    bar.update(2)
    pattern = r'  3%\|█▏                                      \|   3/100 \[[0-9]+\.[0-9][0-9] it/sec\]'
    assert re.match(pattern, out.getvalue()) is not None
    
def test_table1():
    """A single-row table renders as one boxed row with rounded corners."""
    expected = "\n".join([
        "╭───┬───┬───╮",
        "│ 1 │ 2 │ 3 │",
        "╰───┴───┴───╯",
    ])
    assert str(Table([[1, 2, 3]])) == expected

def test_table2():
    """With two rows, the first is separated from the second by a rule and columns widen to fit."""
    expected = "\n".join([
        "╭────┬────┬────╮",
        "│ 1  │ 2  │ 3  │",
        "├────┼────┼────┤",
        "│ 10 │ 20 │ 30 │",
        "╰────┴────┴────╯",
    ])
    assert str(Table([[1, 2, 3], [10, 20, 30]])) == expected

def test_table3():
    """Header cells are left-aligned while the numeric row below is right-aligned."""
    rows = [
        ["Accuracy", "Val Accuracy", "Loss", "Val Loss"],
        [0.89, 0.88, 0.213, 0.334],
    ]
    expected = "\n".join([
        "╭──────────┬──────────────┬───────┬──────────╮",
        "│ Accuracy │ Val Accuracy │ Loss  │ Val Loss │",
        "├──────────┼──────────────┼───────┼──────────┤",
        "│     0.89 │         0.88 │ 0.213 │    0.334 │",
        "╰──────────┴──────────────┴───────┴──────────╯",
    ])
    assert str(Table(rows)) == expected

def test_table4():
    """By default only the header row is followed by a separator; data rows are adjacent."""
    rows = [
        ["Accuracy", "Val Accuracy", "Loss", "Val Loss"],
        [0.89, 0.88, 0.213, 0.334],
        [0.23, 0.89, 0.001, 0.003],
    ]
    expected = "\n".join([
        "╭──────────┬──────────────┬───────┬──────────╮",
        "│ Accuracy │ Val Accuracy │ Loss  │ Val Loss │",
        "├──────────┼──────────────┼───────┼──────────┤",
        "│     0.89 │         0.88 │ 0.213 │    0.334 │",
        "│     0.23 │         0.89 │ 0.001 │    0.003 │",
        "╰──────────┴──────────────┴───────┴──────────╯",
    ])
    assert str(Table(rows)) == expected

def test_table5():
    """separate='row' draws a rule between every pair of rows."""
    rows = [
        ["Accuracy", "Val Accuracy", "Loss", "Val Loss"],
        [0.89, 0.88, 0.213, 0.334],
        [0.23, 0.89, 0.001, 0.003],
    ]
    expected = "\n".join([
        "╭──────────┬──────────────┬───────┬──────────╮",
        "│ Accuracy │ Val Accuracy │ Loss  │ Val Loss │",
        "├──────────┼──────────────┼───────┼──────────┤",
        "│     0.89 │         0.88 │ 0.213 │    0.334 │",
        "├──────────┼──────────────┼───────┼──────────┤",
        "│     0.23 │         0.89 │ 0.001 │    0.003 │",
        "╰──────────┴──────────────┴───────┴──────────╯",
    ])
    assert str(Table(rows, separate='row')) == expected

def test_table6():
    """separate='none' omits all inner rules, including the one under the header."""
    rows = [
        ["Accuracy", "Val Accuracy", "Loss", "Val Loss"],
        [0.89, 0.88, 0.213, 0.334],
        [0.23, 0.89, 0.001, 0.003],
    ]
    expected = "\n".join([
        "╭──────────┬──────────────┬───────┬──────────╮",
        "│ Accuracy │ Val Accuracy │ Loss  │ Val Loss │",
        "│     0.89 │         0.88 │ 0.213 │    0.334 │",
        "│     0.23 │         0.89 │ 0.001 │    0.003 │",
        "╰──────────┴──────────────┴───────┴──────────╯",
    ])
    assert str(Table(rows, separate='none')) == expected

def test_default_table():
    """Display.table forwards its options and renders like a directly built Table."""
    out = BufferOutput()
    display = Display(stdout=out, stderr=out)
    rows = [
        ["Accuracy", "Val Accuracy", "Loss", "Val Loss"],
        [0.89, 0.88, 0.213, 0.334],
        [0.23, 0.89, 0.001, 0.003],
    ]
    table = display.table(rows, separate='none')
    expected = "\n".join([
        "╭──────────┬──────────────┬───────┬──────────╮",
        "│ Accuracy │ Val Accuracy │ Loss  │ Val Loss │",
        "│     0.89 │         0.88 │ 0.213 │    0.334 │",
        "│     0.23 │         0.89 │ 0.001 │    0.003 │",
        "╰──────────┴──────────────┴───────┴──────────╯",
    ])
    assert str(table) == expected


def test_default_progress():
    """Display.progressbar with epochs prefixes the bar with the epoch and shows '-' as the rate.

    start() already reports 1/100; update(24) advances the counter to 25/100.
    """
    out = BufferOutput()
    display = Display(stdout=out, stderr=out)
    bar = display.progressbar(range(100), epochs=10, file=out)

    bar.start()
    after_start = 'Epoch  1/10|▎                                |   1/100 [     - it/sec]'
    assert out.getvalue() == after_start

    bar.update(24)
    after_update = 'Epoch  1/10|████████▎                        |  25/100 [     - it/sec]'
    assert out.getvalue() == after_update