#!/usr/bin/python

from math import floor, log
from os import environ
from random import randint
from subprocess import Popen, PIPE, run

# Operand/result "types" for the math32 test protocol.
#   sz  -- number of distinct values of the type (2**bits). Random operands
#          are drawn from [0, sz) and expected results are reduced mod sz.
#   fmt -- bytes %-format used to encode a value as hex text.
# u1 must hold two values (0 and 1). With sz == 1 every boolean result would
# be reduced to 0 and formatted as '00', so comparison tests could not fail.
u1  = {'sz': 1 << 1,  'fmt': b'%02x'}
u3  = {'sz': 1 << 3,  'fmt': b'%02x'}
u5  = {'sz': 1 << 5,  'fmt': b'%02x'}
u8  = {'sz': 1 << 8,  'fmt': b'%02x'}
u16 = {'sz': 1 << 16, 'fmt': b'%04x'}
u32 = {'sz': 1 << 32, 'fmt': b'%08x'}

def fmt(gx, x):
    """Render integer ``x`` as the hex text expected for type ``gx``.

    ``x`` is first wrapped into the type's value range (``x % gx['sz']``).
    Negative or overflowing Python results thus become their modular
    equivalent. The result is then formatted with the type's ``fmt`` pattern,
    decoded to ``str``.
    """
    pattern = gx['fmt'].decode('utf-8')
    wrapped = x % gx['sz']
    return pattern % wrapped

def _parse_truthy(text):
    """Return whether hex string ``text`` is non-zero, or None if not hex."""
    try:
        return int(text, 16) != 0
    except ValueError:
        return None


def testcase(p, sym, args, out, f):
    """Run one randomized trial of operator ``sym`` against process ``p``.

    Draws a random value for each ``(name, type)`` pair in ``args``. It then
    writes ``sym`` followed by the space-separated hex-encoded operands as one
    line to ``p.stdin``. Finally it reads one reply line from ``p.stdout`` and
    compares it with ``f(*operands)`` formatted as type ``out``.

    For boolean results (``out == u1``) only the truthiness of the reply is
    compared, so any non-zero hex reply counts as true.

    Returns None when the reply matches. Otherwise returns a dict with the
    ``got`` and ``expected`` strings plus each operand keyed by its name.
    """
    vals = [(name, g, randint(0, g['sz'] - 1)) for (name, g) in args]
    line = b' '.join([sym] + [g['fmt'] % x for _, g, x in vals]) + b'\n'
    p.stdin.write(line)
    p.stdin.flush()
    got = p.stdout.readline().strip().decode('utf-8')
    z = f(*[x for _, _, x in vals])
    expected = fmt(out, z)
    if out == u1:
        # Compare parsed truthiness against the raw Python result. Taking
        # bool() of the two strings would be True for any non-empty reply.
        # An unparsable reply (e.g. empty) yields None and counts as a failure.
        ok = _parse_truthy(got) == bool(z)
    else:
        ok = got == expected
    if ok:
        return None
    res = {'got': got, 'expected': expected}
    for name, _, x in vals:
        res[name] = x
    return res

def test(p, trials, sym, args, out, f):
    """Run ``trials`` randomized trials of operator ``sym`` and print a summary.

    The parameters are passed straight through to ``testcase``. On failure,
    the report dict of every failing trial is included in the printed line.
    """
    cases = []
    for _ in range(trials):
        case = testcase(p, sym, args, out, f)
        if case is not None:
            cases.append(case)
    name = sym.decode('utf-8')
    if not cases:
        print('%s passed %d trials' % (name, trials))
    else:
        print('%s failed %d/%d trials (%r)' % (name, len(cases), trials, cases))

def pipe():
    """Start ``uxncli run.rom`` with its stdin and stdout connected to pipes."""
    command = ['uxncli', 'run.rom']
    return Popen(command, stdout=PIPE, stdin=PIPE)

def bitcount(x):
    """Return the number of bits needed to represent non-negative int ``x``.

    Uses ``int.bit_length`` so the result is exact for every integer. The
    float ``floor(log(x, 2)) + 1`` formula can misround, and it raises
    ``ValueError`` for ``x == 0``; here ``0`` maps to ``0``.
    """
    return x.bit_length()

def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b`` via Euclid.

    ``gcd(a, 0)`` is ``a``; in particular ``gcd(0, 0)`` is ``0``.
    """
    while b != 0:
        a, b = b, a % b
    return a

def main():
    """Assemble test-math32.tal and check each math32 operator against Python.

    Every operator is exercised with ``trials`` random inputs through one
    long-running ``uxncli`` process. A summary line per operator is printed.
    """
    trials = 100
    # check=True: abort if assembly fails instead of testing a stale or
    # missing run.rom.
    run(['uxnasm', 'test-math32.tal', 'run.rom'], check=True)
    p = pipe()
    try:
        test(p, trials, b'+', [('x', u32), ('y', u32)], u32, lambda x, y: x + y)
        test(p, trials, b'-', [('x', u32), ('y', u32)], u32, lambda x, y: x - y)
        test(p, trials, b'*', [('x', u32), ('y', u32)], u32, lambda x, y: x * y)
        test(p, trials, b'/', [('x', u32), ('y', u32)], u32, lambda x, y: x // y)
        test(p, trials, b'%', [('x', u32), ('y', u32)], u32, lambda x, y: x % y)
        test(p, trials, b'G', [('x', u32), ('y', u32)], u32, gcd)
        test(p, trials, b'L', [('x', u32), ('y', u5)], u32, lambda x, y: x << y)
        test(p, trials, b'R', [('x', u32), ('y', u5)], u32, lambda x, y: x >> y)
        test(p, trials, b'B', [('x', u32)], u8, bitcount)
        test(p, trials, b'&', [('x', u32), ('y', u32)], u32, lambda x, y: x & y)
        test(p, trials, b'|', [('x', u32), ('y', u32)], u32, lambda x, y: x | y)
        test(p, trials, b'^', [('x', u32), ('y', u32)], u32, lambda x, y: x ^ y)
        test(p, trials, b'~', [('x', u32)], u32, lambda x: ~x)
        test(p, trials, b'N', [('x', u32)], u32, lambda x: -x)
        test(p, trials, b'=', [('x', u32), ('y', u32)], u1, lambda x, y: int(x == y))
        test(p, trials, b'!', [('x', u32), ('y', u32)], u1, lambda x, y: int(x != y))
        test(p, trials, b'0', [('x', u32)], u1, lambda x: int(x == 0))
        test(p, trials, b'Z', [('x', u32)], u1, lambda x: int(x != 0))
        test(p, trials, b'<', [('x', u32), ('y', u32)], u1, lambda x, y: int(x < y))
        test(p, trials, b'>', [('x', u32), ('y', u32)], u1, lambda x, y: int(x > y))
        test(p, trials, b'{', [('x', u32), ('y', u32)], u1, lambda x, y: int(x <= y))
        test(p, trials, b'}', [('x', u32), ('y', u32)], u1, lambda x, y: int(x >= y))
    finally:
        # Tear the child down even if a trial raised, then reap it so no
        # zombie process is left behind.
        p.stdin.close()
        p.stdout.close()
        p.kill()
        p.wait()

# Run the test suite only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
