/* montgomery multiplication / exponentiation
 * pesco, 2009
 */

#include "monty.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>


/* Return 1 if the first n words of A and B are identical, else 0.
 * For n <= 0 there is nothing to compare and the result is 1.
 */
int eq(int n, uint32_t *A, uint32_t *B)
{
	int k = 0;

	/* advance past the common prefix */
	while(k < n && A[k] == B[k])
		k++;

	/* reaching the end means no mismatch was found */
	return k >= n;
}

/* Return 1 if A >= B, 0 otherwise.
 *
 * A and B are n-word little-endian numbers.  A is stored rotated: its
 * least significant word is A[i] and word j lives at A[(i+j) % n].
 * B is stored in normal order.  Equal values compare as 1.
 */
int gte(int n, uint32_t *A, int i, uint32_t *B)
{
	int pos = i;    /* walks A downward from its most significant word */
	int w;          /* matching word index in B */

	for(w = n-1; w >= 0; w--) {
		pos = (pos + n - 1) % n;        /* step back one word, wrapping */

		if(A[pos] != B[w])
			return A[pos] > B[w];   /* first difference decides */
	}

	/* every word matched */
	return 1;
}

/* A -= B, returning the borrow out of the top word (0 or 1).
 *
 * A and B are n-word little-endian numbers.  A is stored rotated with its
 * least significant word at A[i] (word j at A[(i+j) % n]); B is stored
 * in normal order.
 */
int sub(int n, uint32_t *A, int i, uint32_t *B)
{
	uint32_t borrow = 0;
	int w;

	for(w = 0; w < n; w++, i = (i+1) % n) {
		/* read both operands before A[i] is overwritten */
		uint32_t minuend = A[i];
		uint32_t subtrahend = B[w];
		uint32_t partial = minuend - borrow;

		A[i] = partial - subtrahend;

		/* borrow out if removing the incoming borrow wrapped, or if
		 * what remains is smaller than the subtrahend */
		borrow = (minuend < borrow) || (partial < subtrahend);
		assert(borrow == 0 || borrow == 1);
	}

	return borrow;
}

/* A += x*Y, returning the word carried out of the top.
 *
 * A and Y are n-word little-endian numbers.  A is stored rotated with its
 * least significant word at A[i] (word j at A[(i+j) % n]); Y is stored
 * in normal order.  The 64-bit accumulator cannot overflow:
 * (2^32-1) carry + (2^32-1) word + (2^32-1)^2 product = 2^64 - 1.
 */
uint32_t muladd(int n, uint32_t *A, int i, uint32_t x, uint32_t *Y)
{
	uint64_t acc = 0;       /* double precision running sum */
	int w = 0;

	while(w < n) {
		acc += A[i];
		acc += (uint64_t)x * Y[w];

		A[i] = (uint32_t)acc;   /* keep the low word */
		acc >>= 32;             /* high word carries into the next one */

		i = (i+1) % n;
		w++;
	}

	return (uint32_t)acc;
}

/* Montgomery product: A = X*Y / 2^(32*n) mod N
 *
 * All operands are n-word little-endian numbers.  While the product is
 * accumulated, A is kept rotated: its least significant word is A[lo] and
 * word j lives at A[(lo+j) % n].  Dividing by 2^32 is then just lo++
 * instead of shifting every word.  After n rounds the offset wraps back
 * to 0, so A ends in normal order.
 *
 * precondition: N[0] = 0xFFFFFFFF, i.e. N = -1 (mod 2^32), so the
 *               per-round reduction multiplier is simply the low word of A
 * precondition: N >= 2^(32*n-1)   (i.e. highest bit is set)
 * expects Y < N, which keeps A below 2N before the conditional subtraction
 * note: A is zeroed before X, Y and N are read, so A must not overlap them.
 */
void monty_mul(int n, uint32_t *N, uint32_t *A, uint32_t *X, uint32_t *Y)
{
	uint64_t top;   /* carry out of the top word; needs up to 33 bits */
	int lo;         /* A[lo] is A's current least significant word */

	memset(A, 0, n * sizeof *A);

	for(lo = 0; lo < n; lo++) {
		int next = (lo + 1) % n;        /* low-word position after the shift */

		assert(!gte(n, A, lo, N));

		/* A += X[lo] * Y */
		top = muladd(n, A, lo, X[lo], Y);

		/* A += q*N with q = low word of A; since N = -1 (mod 2^32)
		 * this clears the low word */
		top += muladd(n, A, lo, A[lo], N);

		/* the cleared slot becomes the new top word once lo advances */
		assert(A[lo] == 0);
		A[lo] = (uint32_t)top;
		top >>= 32;                     /* at most one bit remains */
		assert(top < 2);

		/* A < 2N here, so a single subtraction brings it below N */
		if(top != 0 || gte(n, A, next, N))
			sub(n, A, next, N);
	}
}

/* A = (X / 2^(32*n))^Y * 2^(32*n) mod N
 *
 * Modular exponentiation with X and A in Montgomery representation and the
 * exponent Y as a plain n-word little-endian number (Y[n-1] most
 * significant).  Uses left-to-right square-and-multiply over all 32 bits of
 * every word, from Y's most significant nonzero word down.  If Y == 0 the
 * result is R.
 *
 * precondition: N[0] = 0xFFFFFFFF
 * precondition: N >= 2^(32*n-1)     (i.e. highest bit is set)
 * precondition: R = 2^(32*n) mod N  (i.e. 1 in montgomery representation)
 * note: A is overwritten with R before X and Y are read, so A must not
 *       overlap R, X or Y.
 * For n <= 0 the function does nothing.
 */
void monty_exp(int n, uint32_t *N, uint32_t *R, uint32_t *A, uint32_t *X, uint32_t *Y)
{
	int w;          /* index of the exponent word being processed */
	int bit;

	/* a VLA of size <= 0 is undefined behaviour, and there is no work */
	if(n <= 0)
		return;

	uint32_t AA[n]; /* squared value; monty_mul's output must not alias its inputs */

	memcpy(A, R, n * sizeof *A);    /* A = 1 (montgomery form) */

	/* skip zero digits.  An index is used rather than walking a pointer
	 * down from Y+n-1: stepping a pointer to Y-1 is undefined behaviour,
	 * and that happens every time the loop runs off the bottom. */
	w = n - 1;
	while(w >= 0 && Y[w] == 0)
		w--;

	/* square & multiply */
	for(; w >= 0; w--) {
		uint32_t y = Y[w];

		for(bit = 0; bit < 32; bit++) {
			monty_mul(n, N, AA, A, A);          /* square */

			if(y & 0x80000000u)
				monty_mul(n, N, A, X, AA);  /* multiply */
			else
				memcpy(A, AA, n * sizeof *A);

			y <<= 1;
		}
	}
}

/* A = random mod N
 *
 * Fills A with n words read from /dev/urandom, redrawing until A < N
 * (rejection sampling, so no value in [0, N) is favoured).  If the top bit
 * of N is set, as monty_mul requires, each draw is accepted with
 * probability above 1/2.  precondition: N > 0, otherwise no draw is ever
 * accepted.
 *
 * Exits the process with status -1 if /dev/urandom cannot be opened or
 * fully read.  A short read must never be treated as success, because the
 * words would then be stale or predictable.
 * For n <= 0 the function does nothing.
 */
void mrand(int n, uint32_t *N, uint32_t *A)
{
	FILE *f;

	/* nothing to draw; also avoids spinning forever, since an empty A
	 * always compares >= an empty N */
	if(n <= 0)
		return;

	f = fopen("/dev/urandom", "rb");
	if(!f) {
		perror("/dev/urandom");
		exit(-1);
	}

	do {
		if(fread(A, sizeof *A, (size_t)n, f) != (size_t)n) {
			fprintf(stderr, "/dev/urandom: short read\n");
			fclose(f);
			exit(-1);
		}
	} while(gte(n, A, 0, N));

	fclose(f);
}
