/* testing montgomery multiplication and associated routines
 * pesco, 2009
 */

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "monty.h"

/* Report the outcome of one named check on stdout.
 *   s: human-readable name of the check
 *   b: nonzero if the check passed
 * Prints "  <s>: ok" on success and "  <s>: FAIL" otherwise.
 */
void test(const char *s, int b)
{
	const char *verdict = "FAIL";

	if (b)
		verdict = "ok";
	printf("  %s: %s\n", s, verdict);
}

/* Self-checks for the multi-word arithmetic and Montgomery routines in
 * monty.h, run against the 1536-bit MODP prime from RFC 3526.
 *
 * Numbers are arrays of n 32-bit words, least significant word first.
 * Each check prints one "ok"/"FAIL" line via test(). Montgomery
 * associativity is sampled over many random triples and reported as a
 * percentage breakdown. The exit status is always 0; failures are only
 * reported on stdout.
 */
int main(int argc, char **argv)
{
	/* 1536-bit modulus from RFC 3526, little-endian */
	uint32_t N[] = {0xFFFFFFFF, 0xFFFFFFFF, 0xCA237327, 0xF1746C08, 0x4ABC9804, 0x670C354E,
	                0x7096966D, 0x9ED52907, 0x208552BB, 0x1C62F356, 0xDCA3AD96, 0x83655D23,
	                0xFD24CF5F, 0x69163FA8, 0x1C55D39A, 0x98DA4836, 0xA163BF05, 0xC2007CB8,
	                0xECE45B3D, 0x49286651, 0x7C4B1FE6, 0xAE9F2411, 0x5A899FA5, 0xEE386BFB,
	                0xF406B7ED, 0x0BFF5CB6, 0xA637ED6B, 0xF44C42E9, 0x625E7EC6, 0xE485B576,
	                0x6D51C245, 0x4FE1356D, 0xF25F1437, 0x302B0A6D, 0xCD3A431B, 0xEF9519B3,
	                0x8E3404DD, 0x514A0879, 0x3B139B22, 0x020BBEA6, 0x8A67CC74, 0x29024E08,
	                0x80DC1CD1, 0xC4C6628B, 0x2168C234, 0xC90FDAA2, 0xFFFFFFFF, 0xFFFFFFFF};
	/* 2^3072 mod N */
	uint32_t RR[] = {0x32c695e0, 0xf115d27d, 0x67478c73, 0x8e0e3e21, 0x8397f245, 0xd0ab92e1,
	                 0xbcd49d68, 0xf466ee5f, 0x3b01e018, 0x8f2331b1, 0x98b5fb62, 0x7e8cd2ac,
	                 0x7a58f170, 0xb9052bb4, 0xdb102d39, 0xb004a750, 0x93ae1ceb, 0x04a541ff,
	                 0x8e434130, 0x07cd0a62, 0x04b9f796, 0x1c729c7e, 0x196b7e88, 0xb8fe6121,
	                 0x0223b76b, 0x8e1abd78, 0xd46fec23, 0x22c296e9, 0xb270521b, 0xd62a0eea,
	                 0xd4053f54, 0xdc541a4e, 0x969b7f02, 0xf8056564, 0xa87c7b37, 0x0be49647,
	                 0x67984460, 0x57b59348, 0x9a36a51f, 0x102630fa, 0xcc2456ef, 0xe9c3fa02,
	                 0x7929a1c7, 0xae594104, 0x6cc1ebd2, 0xee9c9a21, 0x59541c01, 0xe3b33c72};
	const int n = sizeof(N) / sizeof(N[0]);   /* number of words per value */
	const size_t nbytes = sizeof(N);          /* byte size of every n-word buffer */
	const int assoc_runs = 1000;              /* random triples for the associativity sample */
	uint32_t R[n];
	uint32_t one[n];
	uint32_t zero[n];
	uint32_t minus1[n];
	uint32_t a[n];
	uint32_t b[n];
	uint32_t c[n];
	uint32_t d[n];
	uint32_t e[n];
	uint32_t f[n];
	int i;
	int k1, k2, k3, k4;
	uint32_t x, y;

	(void)argc;
	(void)argv;

	/* initialize constants */
	memset(zero, 0, nbytes);
	memset(one, 0, nbytes);
	one[0] = 1;
	for (i = 0; i < n; i++)
		minus1[i] = 0xFFFFFFFF;   /* -1 modulo 2^(32n) */
	/* Montgomery form of 1. Presumably R = 2^(32n) mod N, assuming
	 * monty_mul(x, y) computes x*y/R mod N -- confirm against monty.h. */
	monty_mul(n, N, R, one, RR);


	printf("addition:\n");

	mrand(n, N, a);
	memcpy(b, a, nbytes);      /* b <- a */
	muladd(n, b, 0, 1, zero);  /* b <- b + 1*0 */
	test("zero neutrality", eq(n, a, b));

	mrand(n, N, a);
	memcpy(b, a, nbytes);
	memcpy(c, a, nbytes);
	muladd(n, b, 0, 1, one);   /* b = a + 1 */
	/* reference increment with manual carry propagation */
	for (i = 0; i < n; i++) {
		c[i]++;
		if (c[i] != 0)
			break;
	}
	test("one increments", eq(n, b, c));

	mrand(n, N, a);
	mrand(n, N, b);
	mrand(n, N, c);
	memcpy(d, a, nbytes);
	muladd(n, d, 0, 1, b);  /* d = a + b */
	muladd(n, d, 0, 1, c);  /* d = (a + b) + c */
	muladd(n, b, 0, 1, c);  /* b = b + c */
	muladd(n, a, 0, 1, b);  /* a = a + (b + c) */
	test("associativity", eq(n, a, d));

	mrand(n, N, a);
	mrand(n, N, b);
	memcpy(c, a, nbytes);
	muladd(n, c, 0, 1, b);  /* c = a + b */
	muladd(n, b, 0, 1, a);  /* b = b + a */
	test("commutativity", eq(n, b, c));

	printf("subtraction:\n");

	mrand(n, N, a);
	memcpy(b, a, nbytes);
	sub(n, b, 0, one);      /* b = a - 1 */
	/* reference decrement with manual borrow propagation: keep going
	 * while the word wrapped from 0 to 0xFFFFFFFF */
	for (i = 0; i < n; i++) {
		a[i]--;
		if (a[i] != 0xFFFFFFFF)
			break;
	}
	test("one decrements", eq(n, a, b));

	mrand(n, N, a);
	memcpy(b, a, nbytes);  /* b = a */
	sub(n, b, 0, a);       /* b = a - a */
	test("yield zero", eq(n, b, zero));

	mrand(n, N, a);
	mrand(n, N, b);
	memcpy(c, b, nbytes);
	sub(n, b, 0, a);        /* b = b - a */
	muladd(n, b, 0, 1, a);  /* b = (b - a) + a */
	test("addition neutralizes", eq(n, b, c));

	mrand(n, N, a);
	mrand(n, N, b);
	memcpy(c, b, nbytes);
	muladd(n, b, 0, 1, a);  /* b = b + a */
	sub(n, b, 0, a);        /* b = (b + a) - a */
	test("neutralizes addition", eq(n, b, c));

	printf("multiplication:\n");

	mrand(n, N, a);
	memset(b, 0, nbytes);
	muladd(n, b, 0, 1, a);  /* b = 0 + 1*a */
	test("one left-neutral", eq(n, a, b));

	memset(a, 0, nbytes);
	a[0] = random();
	memset(b, 0, nbytes);
	muladd(n, b, 0, a[0], one);  /* b = 0 + a*1 */
	test("one right-neutral", eq(n, a, b));

	/* Use a random multiplicand: with a == 0 (as before) the result
	 * would be zero regardless of how the scalar is applied. */
	mrand(n, N, a);
	memset(b, 0, nbytes);
	muladd(n, b, 0, 0, a);  /* b = 0 + 0*a */
	test("zero left nullifies", eq(n, b, zero));

	memset(a, 0, nbytes);
	muladd(n, a, 0, random(), zero);  /* a = 0 + r*0 */
	test("zero right nullifies", eq(n, a, zero));

	mrand(n, N, a);
	memset(b, 0, nbytes);
	muladd(n, b, 0, 2, a);  /* b = 2*a */
	muladd(n, a, 0, 1, a);  /* a = a + a (dst and src alias) */
	test("two doubles", eq(n, a, b));

	/* small scalars so that y*x cannot overflow 32 bits */
	mrand(n, N, a);
	x = random() % 100;
	y = random() % 100;
	memset(b, 0, nbytes);
	memset(c, 0, nbytes);
	memset(d, 0, nbytes);
	muladd(n, b, 0, x, a);    /* b = x * a */
	muladd(n, c, 0, y, b);    /* c = y * (x * a) */
	muladd(n, d, 0, y*x, a);  /* d = (y * x) * a */
	test("associativity", eq(n, c, d));

	mrand(n, N, a);
	mrand(n, N, b);
	x = random();
	memcpy(c, a, nbytes);
	muladd(n, c, 0, x, b);  /* c = a + x b */
	memset(d, 0, nbytes);
	muladd(n, d, 0, x, b);  /* d = 0 + x b */
	muladd(n, d, 0, 1, a);  /* d = x b + a */
	test("muladd combination", eq(n, c, d));

	printf("montgomery multiplication:\n");

	mrand(n, N, a);
	monty_mul(n, N, b, a, R);
	test("R left-neutral", eq(n, a, b));

	mrand(n, N, a);
	monty_mul(n, N, b, R, a);
	test("R right-neutral", eq(n, a, b));

	mrand(n, N, a);
	monty_mul(n, N, b, zero, a);
	test("zero left nullifies", eq(n, b, zero));

	mrand(n, N, a);
	monty_mul(n, N, b, a, zero);
	test("zero right nullifies", eq(n, b, zero));

	mrand(n, N, a);
	mrand(n, N, b);
	monty_mul(n, N, c, a, b);
	monty_mul(n, N, d, b, a);
	test("commutativity", eq(n, c, d));

	/* Classify the difference of both bracketings over many random
	 * triples: k1 exact, k2 off by +1, k3 off by -1, k4 anything else. */
	k1 = k2 = k3 = k4 = 0;
	for (i = 0; i < assoc_runs; i++) {
		mrand(n, N, a);
		mrand(n, N, b);
		mrand(n, N, c);
		monty_mul(n, N, d, b, c);
		monty_mul(n, N, e, a, d);  /* e = a * (b * c) */
		monty_mul(n, N, d, a, b);
		monty_mul(n, N, f, d, c);  /* f = (a * b) * c */
		sub(n, e, 0, f);           /* e = e - f */
		if (eq(n, e, zero))
			k1++;
		else if (eq(n, e, one))
			k2++;
		else if (eq(n, e, minus1))
			k3++;
		else
			k4++;
		/* Report after counting so that the last line printed covers
		 * all runs; flush because the line ends in \r, not \n. */
		printf("\r  associativity: %d%% ok  %d%% off by 1  %d%% off by -1  %d%% junk ",
		       k1*100/(i+1), k2*100/(i+1), k3*100/(i+1), k4*100/(i+1));
		fflush(stdout);
	}
	printf("\n");


	return 0;
}
