/*
 All work, inefficiencies, and bugs 
 Copyright 2001 Kevin Fu <fubob@mit.edu>  

 October 18, 2001

 Demo implementation of Shoup's RSA threshold signature scheme protocol I from
 "Practical Threshold Signatures" EUROCRYPT 2000

 For 6.857 Network and Computer Security
 http://web.mit.edu/6.857/www/

 May be freely reproduced for educational or personal use
*/


import java.math.*;
import java.io.*;
import java.util.*;

/**
 * Demo dealer for Shoup's threshold RSA signature scheme ("Practical Threshold
 * Signatures", EUROCRYPT 2000, protocol 1).
 *
 * Generates a modulus n = p*q from safe primes, deals key shares s_i of the
 * secret exponent d over a random polynomial mod m = p1*q1, derives verification
 * keys v_i = V^{s_i}, and runs sanity checks. The checks cover plain RSA signing,
 * signature shares with a faked correctness proof, and recombination of every
 * threshold-sized subset of shares.
 */
public class genShares
{

    static final BigInteger TWO = new BigInteger("2");
    static Random rnd = new Random(); // insecure randomness, but good for problem set

    /** Maximum number of attempts to draw q != p, or a usable e, before giving up. */
    static final int MAX_TRIES = 50;

    /**
     * Entry point: java genShares threshold groups bitsize certainty.
     *
     * The "modulus bitsize" argument is the bit length of p1 and q1. The safe
     * primes p = 2*p1 + 1 and q = 2*q1 + 1 are one bit longer.
     */
    public static void main(String args[])
    {
        if (args.length != 4) {
            System.out.println("Usage: java genShares <threshold> <groups> <modulus bitsize> <prime certainty>");
            System.out.println("Try java genShares 3 21 64 10");
            System.out.println("Primes are probably prime with probability 1 - 2^{-certainty}");

            System.exit(1);
        }

        int threshold = Integer.parseInt(args[0]);
        int groups    = Integer.parseInt(args[1]);
        int bitsize   = Integer.parseInt(args[2]);
        int certainty = Integer.parseInt(args[3]);

        // The polynomial needs at least one coefficient, reconstruction needs
        // threshold shares out of groups, and BigInteger's prime constructor
        // needs a bit length of at least 2.
        if (groups < 1 || threshold < 1 || threshold > groups) {
            System.out.println("FAIL: need 1 <= threshold <= groups");
            System.exit(1);
        }
        if (bitsize < 2) {
            System.out.println("FAIL: modulus bitsize must be at least 2");
            System.exit(1);
        }

        /*
          p_1, q_1 are primes of bitsize bits (e.g. 511 bits)
          p = 2p_1 + 1 is a prime of bitsize+1 bits (e.g. 512 bits)
          q = 2q_1 + 1 is a prime of bitsize+1 bits (e.g. 512 bits)
         */

        BigInteger sp[] = safePrime(bitsize, certainty);
        BigInteger p = sp[0];
        BigInteger p1 = sp[1];
        System.out.println("p  = " + p + " = 2*p1 +1");
        System.out.println("p1 = " + p1);

        // n must have two distinct prime factors. At small bit sizes there are only
        // a few safe primes, so q can collide with p; redraw a bounded number of times.
        sp = safePrime(bitsize, certainty);
        for (int tries = 1; sp[0].equals(p) && tries < MAX_TRIES; tries++) {
            sp = safePrime(bitsize, certainty);
        }
        if (sp[0].equals(p)) {
            System.out.println("FAIL: could not find q != p; try a larger modulus bitsize");
            System.exit(1);
        }
        BigInteger q = sp[0];
        BigInteger q1 = sp[1];
        System.out.println("q  = " + q + " = 2*q1 +1");
        System.out.println("q1 = " + q1);

        BigInteger n = p.multiply(q);
        System.out.println("n  = " + n + " = p*q");

        BigInteger m = p1.multiply(q1);
        System.out.println("m  = " + m + " = p1*q1");

        BigInteger l = BigInteger.valueOf(groups);

        // e is a prime one bit longer than groups, so e > l by construction.
        // e must also be odd, so that gcd(e, 4*(l!)^2) = 1 when combining shares,
        // and relatively prime to m, so that d exists. Small parameters can violate
        // this, so redraw a bounded number of times.
        int ebits = Integer.toBinaryString(groups).length() + 1;
        BigInteger e = new BigInteger(ebits, certainty, rnd);
        for (int tries = 1; !isUsableExponent(e, m) && tries < MAX_TRIES; tries++) {
            e = new BigInteger(ebits, certainty, rnd);
        }
        System.out.println("e  = " + e);
        if (e.compareTo(l) <= 0) {
            System.out.println("FAIL: e <= l");
            System.exit(1);
        }
        if (!isUsableExponent(e, m)) {
            System.out.println("FAIL: e not odd and relatively prime to m");
            System.exit(1);
        }

        // Shoup's scheme shares d = e^{-1} mod m.
        BigInteger d = e.modInverse(m);
        System.out.println("d  = " + d);

        // The reference (non-threshold) RSA signature needs e^{-1} modulo
        // lambda(n) = 2m. Only then does x^(e*dRsa) = x hold for every x in Z_n^*.
        // With d taken mod m alone, the check fails whenever (e*d - 1)/m is odd and
        // x is not a square mod both p and q.
        BigInteger dRsa = e.modInverse(m.multiply(TWO));

        System.out.println("Coefficients:");
        BigInteger a[] = new BigInteger[threshold];
        a[0] = d;
        System.out.println("a[0] = " + a[0]);
        for (int i = 1; i < threshold; i++) {
            a[i] = bigRandom(m);
            System.out.println("a[" + i + "] = " + a[i]);
        }

        System.out.println("Key shares:");
        BigInteger s[] = new BigInteger[groups];
        for (int i = 0; i < groups; i++) {
            // Note, start at 1, not 0 because s[0] = a[0] = d = secret!
            s[i] = polynomial(i + 1, threshold, a, m);
            System.out.println("key share s[" + i + "] for group " + (i + 1) + " = " + s[i]);
        }

        System.out.println("QRs mod n:");
        BigInteger V = bigRandQR(p, q);

        System.out.println("V  = " + V);
        BigInteger v[] = new BigInteger[groups];
        for (int i = 0; i < groups; i++) {
            v[i] = V.modPow(s[i], n);
            System.out.println("v[" + i + "] for group " + (i + 1) + " = " + v[i]);
        }

        // Sanity checks

        BigInteger x = new BigInteger("quiessettantusfructus", 36);
        //BigInteger x = new BigInteger ("quiessettantusfructusinprosperis", 36);
        // nisihaberesquiillisaequeactuipse

        System.out.println("Message   = " + x);
        if (!x.gcd(n).equals(BigInteger.ONE)) {
            System.out.println("FAIL: Message not relatively prime!  You got really lucky.");
            System.exit(1);
        }

        BigInteger signature = genRsaSignature(dRsa, n, x);
        System.out.println("Signature = " + signature);

        if (verifyRsa(e, n, signature, x)) {
            System.out.println("PASS: Signature good");
        } else {
            System.out.println("FAIL: Signature invalid");
            System.exit(1);
        }

        if (verifyRsa(e, n, signature, x.add(BigInteger.ONE))) {
            System.out.println("FAIL: Bad Signature good");
            System.exit(1);
        } else {
            System.out.println("PASS: Bad Signature invalid");
        }

        BigInteger sigShare[] = new BigInteger[groups];
        // xhat = x^(4 * l!) is the same for every share; compute it once.
        BigInteger xhat = x.modPow(new BigInteger("4").multiply(factorial(l)), n);

        for (int i = 0; i < groups; i++) {
            sigShare[i] = genRsaSignatureShare(s[i], n, x, l);
            System.out.println("Signature share from group " + (i + 1) + " = " + sigShare[i]);

            // fake the proof of correctness
            BigInteger c = new BigInteger(bitsize, rnd);
            BigInteger r = new BigInteger(bitsize, rnd);
            BigInteger z = s[i].multiply(c).add(r);
            BigInteger vt = V.modPow(r, n);
            BigInteger xt = xhat.modPow(r, n);

            if (verifySignatureShare(sigShare[i], n, V, c, v[i], z, xhat, vt, xt)) {
                System.out.println("Signature share x[" + i + "] NIZK valid");
            } else {
                System.out.println("Signature share x[" + i + "] NIZK invalid");
            }
        }

        testShares(x, e, n, signature, V, v, sigShare, groups, threshold);
    }

    /** True if e is odd and relatively prime to m. */
    private static boolean isUsableExponent(BigInteger e, BigInteger m) {
        return e.testBit(0) && e.gcd(m).equals(BigInteger.ONE);
    }

    /**
     * Combines every subset of threshold signature shares, visiting subsets in
     * lexicographic order of group number. For each subset it reports whether the
     * combination equals the reference signature, and prints the intermediate values.
     *
     * @param x         the signed message
     * @param e         public exponent
     * @param n         RSA modulus
     * @param signature reference signature that every combination should equal
     * @param V         not used here
     * @param v         verification keys; v[i] belongs to group i+1
     * @param sigShare  signature shares; sigShare[i] belongs to group i+1
     * @param groups    number of groups l
     * @param threshold number of shares combined per subset
     */
    static void testShares(BigInteger x, BigInteger e, BigInteger n,
                           BigInteger signature, BigInteger V,
                           BigInteger v[], BigInteger sigShare[],
                           int groups, int threshold) {

        BigInteger l = BigInteger.valueOf(groups);

        System.out.println("Signature shares verification:");

        if (threshold < 1 || threshold > groups) {
            // No subset of that size exists.
            return;
        }

        // e' = 4 * (l!)^2 and the Bezout coefficients a*e' + b*e = 1 depend only on
        // l and e, so compute them once instead of once per subset.
        BigInteger delta = factorial(l);
        BigInteger ee = delta.multiply(delta).multiply(new BigInteger("4"));
        BigInteger euclid[] = extendedEuclid(ee, e);
        if (!euclid[0].equals(BigInteger.ONE)) {
            System.out.println("FAIL: GCD not equal to one!");
            System.exit(1);
        }
        BigInteger a = euclid[1];
        BigInteger b = euclid[2];

        // combo[t] is the 0-based array index of the t-th share of the current subset.
        int combo[] = new int[threshold];
        for (int t = 0; t < threshold; t++) {
            combo[t] = t;
        }

        while (true) {
            int sigShareIndex[] = new int[threshold];
            BigInteger someSigShare[] = new BigInteger[threshold];
            for (int t = 0; t < threshold; t++) {
                sigShareIndex[t] = combo[t] + 1;   // group numbers are 1-based
                someSigShare[t] = sigShare[combo[t]];
            }

            BigInteger combSig =
                combineRsaSignatureShares(someSigShare, sigShareIndex, threshold, x, e, n, l);

            if (combSig.compareTo(signature) != 0) {
                System.out.println("FAIL: signature invalid");
            } else {
                System.out.println("SUCCESS: signature valid");
            }

            for (int mm = 0; mm < threshold; mm++) {
                int group = sigShareIndex[mm];
                System.out.println("Signature share x_" + group + ": " + someSigShare[mm]);
                System.out.println("From group " + group);
                // NOTE(review): lambdaExponent returns l! times the Lagrange coefficient.
                // The label says 2\lambda, but the printed value is not doubled
                // (the doubling happens inside lambdaProduct).
                BigInteger le = lambdaExponent(sigShareIndex, threshold, mm, l);
                System.out.println("2\\lambda_{0," + group + "}^S = " + le);
                // v[] is indexed by group - 1, not by position within the subset.
                System.out.println("v_" + group + " = " + v[group - 1]);
            }

            BigInteger w = lambdaProduct(someSigShare, sigShareIndex, threshold, l, n);
            System.out.println("w = " + w);
            System.out.println("e' " + ee);
            System.out.println("a = " + a);
            System.out.println("b = " + b);

            System.out.println("Combined signature y: " + combSig);
            System.out.println("");

            // Advance to the next subset in lexicographic order. Find the rightmost
            // position that can still be incremented, bump it, then reset everything
            // to its right to consecutive values.
            int pos = threshold - 1;
            while (pos >= 0 && combo[pos] == groups - threshold + pos) {
                pos--;
            }
            if (pos < 0) {
                break;
            }
            combo[pos]++;
            for (int t = pos + 1; t < threshold; t++) {
                combo[t] = combo[t - 1] + 1;
            }
        }
    }

    /**
     * Self-test of BigInteger.modPow (including a negative exponent) and of the
     * share-proof identity on small fixed numbers. Prints FAIL lines on a mismatch.
     * According to the original note, Kaffe 1.0.5 fails this test.
     */
    static void test () {
        BigInteger n = new BigInteger("17");
        BigInteger c = new BigInteger("5");
        BigInteger r = new BigInteger("3");
        BigInteger s = new BigInteger("7");
        BigInteger z = s.multiply(c);
        z = z.add(r);
        System.out.println("z' = " + z);

        BigInteger V = new BigInteger("11");

        BigInteger x = V.modPow(s, n);
        BigInteger v2 = V.modPow(r, n);
        BigInteger v1 = V.modPow(z, n);
        System.out.println("v1' = " + v1);

        System.out.println("x = " + x);
        System.out.println("c.neg = " + c.negate());
        System.out.println("n = " + n);

        // A negative exponent makes modPow compute a modular inverse.
        BigInteger tmp = x.modPow(c.negate(), n);

        System.out.println("tmp = " + tmp);

        v1 = v1.multiply(tmp);
        v1 = v1.mod(n);

        if (!v1.equals(v2)) {
            System.out.println("FAIL: v1 = " + v1);
            System.out.println("FAIL: v2 = " + v2);
        }

        BigInteger mess = new BigInteger("13");
        BigInteger xhat = mess.modPow(new BigInteger("480"), n);
        System.out.println("xhat = " + xhat);
        BigInteger xt = xhat.modPow(r, n);
        BigInteger xi = mess.modPow(s.multiply(new BigInteger("240")), n);
        BigInteger foo = (xi.modPow(c.multiply(TWO), n)).modInverse(n);
        BigInteger xxx = (xhat.modPow(z, n)).multiply(foo);
        xxx = xxx.mod(n);
        if (!xxx.equals(xt)) {
            System.out.println("FAIL: x' not equal to ~x^z x_i^{-2c}");
            System.out.println("FAIL: x' = " + xt);
            System.out.println("FAIL: xxx = " + xxx);
        }
    }

    /**
     * Returns {p, p1}, where p1 is a probable prime of bitlen bits and
     * p = 2*p1 + 1 is also probably prime (a safe prime). Also prints how many
     * candidates were tried.
     *
     * @param bitlen    bit length of p1 (at least 2)
     * @param certainty primality certainty passed to BigInteger
     */
    static BigInteger[] safePrime (int bitlen, int certainty) {
        BigInteger p[] = new BigInteger[2];
        p[0] = p[1] = BigInteger.ZERO;

        int counter = 0;
        while (!p[0].isProbablePrime(certainty)) {
            p[1] = new BigInteger(bitlen, certainty, rnd);
            p[0] = (p[1].multiply(TWO)).add(BigInteger.ONE);
            counter++;
        }

        System.out.println("prime iterations = " + counter);

        return p;
    }

    /**
     * Returns a number drawn uniformly from 0..n-1.
     *
     * @throws IllegalArgumentException if n is not positive; the rejection loop
     *         would otherwise never terminate
     */
    static BigInteger bigRandom (BigInteger n) {
        if (n.signum() <= 0) {
            throw new IllegalArgumentException("bigRandom: n must be positive, got " + n);
        }

        // Note, we must toss out any random number out of range.
        // Other operations like truncation would throw the distribution
        // out of whack.
        BigInteger r;
        do {
            r = new BigInteger(n.bitLength(), rnd);
        } while (r.compareTo(n) >= 0);

        return r;
    }

    /**
     * Returns (sum for i = 0 .. threshold-1 of a[i] * group^i) mod modulus.
     * Requires group >= 1 and a.length >= threshold.
     */
    static BigInteger polynomial (int group, int threshold,
                                  BigInteger a[], BigInteger modulus) {
        BigInteger X = BigInteger.valueOf(group);
        BigInteger result = a[0];

        for (int i = 1; i < threshold; i++) {
            result = result.add(a[i].multiply(X.modPow(BigInteger.valueOf(i), modulus)));
        }

        return result.mod(modulus);
    }

    /**
     * Returns true if a is a nonzero quadratic residue mod p (Euler's criterion).
     * Requires p to be an odd prime. Returns false when gcd(a, p) != 1.
     */
    static boolean QR (BigInteger a, BigInteger p) {

        if (!a.gcd(p).equals(BigInteger.ONE)) {
            // a not in Z_{p}^*
            return false;
        }

        BigInteger half = p.subtract(BigInteger.ONE).divide(TWO);
        return a.modPow(half, p).equals(BigInteger.ONE);
    }

    /** Returns a random a in 0..pq-1 that is a quadratic residue mod both p and q. */
    static BigInteger bigRandQR (BigInteger p, BigInteger q) {
        BigInteger n = p.multiply(q);
        BigInteger a = bigRandom(n);

        while (!QR(a, p) || !QR(a, q)) {
            a = bigRandom(n);
        }

        return a;
    }

    /** Returns (p-1)*(q-1), Euler's totient of pq for distinct primes p and q. */
    static BigInteger phi (BigInteger p, BigInteger q) {
        return (p.subtract(BigInteger.ONE)).multiply(q.subtract(BigInteger.ONE));
    }

    /**
     * Returns true if signature^e is congruent to x mod n. The message is
     * reduced mod n first, so a message >= n can still verify.
     */
    static boolean verifyRsa (BigInteger e, BigInteger n, BigInteger signature,
                              BigInteger x) {
        return signature.modPow(e, n).equals(x.mod(n));
    }

    /**
     * Returns the RSA signature x^d mod n for private exponent d.
     * For the signature to verify under public exponent e for every x, d must
     * satisfy e*d = 1 mod lambda(n).
     */
    static BigInteger genRsaSignature (BigInteger d, BigInteger n, BigInteger x) {
        return x.modPow(d, n);
    }

    /**
     * Returns the signature share x^(2 * s * l!) mod n.
     *
     * @param s key share
     * @param l number of groups
     */
    static BigInteger genRsaSignatureShare (BigInteger s, BigInteger n,
                                            BigInteger x, BigInteger l) {
        return x.modPow((TWO.multiply(s)).multiply(factorial(l)), n);
    }

    /**
     * Checks only that the last two arguments of the H' hash are consistent:
     * V^z * v^{-c} == vt and xhat^z * sigShare^{-2c} == xt (mod n).
     * It's not a real NIZK proof. Prints diagnostics on failure.
     */
    static boolean verifySignatureShare (BigInteger sigShare, BigInteger n,
                                         BigInteger V, BigInteger c,
                                         BigInteger v, BigInteger z,
                                         BigInteger xhat,
                                         BigInteger vt, BigInteger xt) {

        BigInteger vvv = ((V.modPow(z, n)).
                          multiply((v.modPow(c, n)).modInverse(n))).mod(n);
        if (!vvv.equals(vt)) {
            System.out.println("FAIL: v' not equal to v^z v_i^{-c}");
            System.out.println("v' = " + vt);
            System.out.println("vvv = " + vvv);

            return false;
        }

        BigInteger foo = (sigShare.modPow(c.multiply(TWO), n)).modInverse(n);
        BigInteger xxx = (xhat.modPow(z, n)).multiply(foo);
        xxx = xxx.mod(n);
        if (!xxx.equals(xt)) {
            System.out.println("FAIL: x' not equal to ~x^z x_i^{-2c}");
            System.out.println("x' = " + xt);
            System.out.println("xxx = " + xxx);

            return false;
        }

        return true;
    }

    /**
     * Combines signature shares into the signature y = w^a * x^b mod n.
     * Here w = lambdaProduct(...), e' = 4 * (l!)^2, and a*e' + b*e = 1.
     * Exits the process if gcd(e', e) != 1.
     */
    static BigInteger combineRsaSignatureShares (BigInteger sigShare[],
                                                 int sigShareIndex[],
                                                 int numShares,
                                                 BigInteger x, BigInteger e,
                                                 BigInteger n, BigInteger l) {

        BigInteger w = lambdaProduct(sigShare, sigShareIndex, numShares, l, n);

        BigInteger delta = factorial(l);

        BigInteger ee = delta.multiply(delta);
        ee = ee.multiply(new BigInteger("4"));

        BigInteger euclid[] = extendedEuclid(ee, e);
        if (!euclid[0].equals(BigInteger.ONE)) {
            System.out.println("FAIL: GCD not equal to one!");
            System.exit(1);
        }

        BigInteger a = euclid[1];
        BigInteger b = euclid[2];

        // a or b may be negative; modPow then inverts the base, which is why x
        // must be invertible mod n.
        return (((w.modPow(a, n)).multiply(x.modPow(b, n))).mod(n));
    }

    /**
     * Given signature shares and the group each came from, returns
     * w = product over i of sigShare[i]^(2 * lambdaExponent(i)) mod n.
     */
    static BigInteger lambdaProduct (BigInteger sigShare[], int sigShareIndex[], int numShares,
                                     BigInteger l, BigInteger n) {
        BigInteger w = BigInteger.ONE;

        for (int i = 0; i < numShares; i++) {
            BigInteger le = lambdaExponent(sigShareIndex, numShares, i, l);
            if (le.compareTo(BigInteger.ZERO) < 0) {
                // Negative exponent: raise to |2*le|, then invert.
                w = w.multiply((sigShare[i].modPow((le.negate()).multiply(TWO), n)).modInverse(n));
            } else {
                w = w.multiply(sigShare[i].modPow(le.multiply(TWO), n));
            }

            w = w.mod(n);
        }

        return w;
    }

    /**
     * Returns l! times the Lagrange coefficient of share j, evaluated at 0:
     * l! * prod_{k != j} (0 - i_k) / prod_{k != j} (i_j - i_k), where i_k = sigShareIndex[k].
     *
     * It's critically important not to reduce this mod n: the result is an exact
     * (possibly negative) integer exponent, and n is not the order of the group
     * it is applied in.
     */
    static BigInteger lambdaExponent (int sigShareIndex[], int numShares, int j, BigInteger l) {
        BigInteger lambda = factorial(l);

        // When the indices are distinct values in 1..l, the product of the
        // differences divides l!, so each division here is exact.
        for (int k = 0; k < numShares; k++) {
            if (k == j)
                continue;

            lambda = lambda.divide(BigInteger.valueOf(sigShareIndex[j] - sigShareIndex[k]));
        }

        for (int k = 0; k < numShares; k++) {
            if (k == j)
                continue;

            lambda = lambda.multiply(BigInteger.valueOf(0 - sigShareIndex[k]));
        }

        return lambda;
    }

    /* Based on code from
         Boufounos,Petros T    : petrosb@mit.edu
         Castagnola,Luciano    : luciano@mit.edu
         Michalakis,Nikolaos   : nikos@mit.edu
    */
    /**
     * Recursive extended Euclidean algorithm. Returns [c, x, y] where
     * gcd(a, b) = c = x*a + y*b. When b == 0 it returns [a, 1, 0].
     */
    static BigInteger[] extendedEuclid(BigInteger a, BigInteger b)
    {
        BigInteger[] result = new BigInteger[3];

        if (b.equals(BigInteger.ZERO)) {
            result[0] = a;
            result[1] = BigInteger.ONE;
            result[2] = BigInteger.ZERO;
            return result;
        }

        BigInteger[] tmp = extendedEuclid(b, a.mod(b));

        result[0] = tmp[0];
        result[1] = tmp[2];
        result[2] = tmp[1].subtract(a.divide(b).multiply(tmp[2]));
        return result;
    }

    /** Returns l! (1 for any l <= 1). */
    static BigInteger factorial(BigInteger l) {
        BigInteger result = BigInteger.ONE;
        BigInteger count = l;

        while (count.compareTo(BigInteger.ONE) > 0) {
            result = result.multiply(count);
            count = count.subtract(BigInteger.ONE);
        }

        return result;
    }

}





