/*
 Copyright 2001 Joseph R Hastings, Patrick M Cody, Pavel Langer, Ed Cotler

 October 29, 2001

 Demo implementation of Shoup's RSA threshold signature scheme protocol I from
 "Practical Threshold Signatures" EUROCRYPT 2000

 For 6.857 Network and Computer Security
 http://web.mit.edu/6.857/www/

 May be freely reproduced for educational or personal use
*/

import java.math.*;
import java.io.*;
import java.util.*;

public class thresholdRsa {

    private static final BigInteger TWO = new BigInteger ("2");
    private static final int CERTAINTY = 10;

    // Output length L1, in bits, of the hash H used in the share proofs.
    private static final int HASH_BITS = 128;
    // H reduces its result modulo this value (2^HASH_BITS).
    private static final BigInteger HASH_MODULUS = TWO.pow (HASH_BITS);

    // insecure randomness, but good for problem set: the fixed seed makes
    // every run reproducible.  Real key material would need SecureRandom.
    private static Random rnd = new Random (1);

    // Print the command-line synopsis and terminate with exit status 1.
    private static void usage () {
        System.out.println
            ("Usage: java thresholdRsa [generateX | deal | combine <file>]");
        System.exit (1);
    }

    /*
     * Command-line entry point.
     *   generateX      - print s^e mod n for hard-coded s, e, n
     *   deal           - simulate a dealer end to end and self-check
     *   combine <file> - combine signature shares read from <file>
     * Any other (or missing) argument prints usage and exits with status 1.
     */
    public static void main (String argv[]) {
        if (argv.length < 1) {
            usage ();
        } else if (argv[0].equals ("generateX")) {
            generateX ();
        } else if (argv[0].equals ("deal")) {
            simulateDealer ();
        } else if (argv[0].equals ("combine")) {
            if (argv.length < 2) {
                usage ();
            }
            combineSignatures (argv[1]);
        } else {
            usage ();
        }
    }

    /*
     * Return the next line from `in` that does not start with '#'
     * ('#' lines are treated as comments and skipped).
     * Exits with status 1 on an I/O error, or if the input ends before a
     * non-comment line is found.
     */
    private static String read (BufferedReader in) {
        while (true) {
            String line;
            try {
                line = in.readLine ();
            } catch (IOException ex) {
                ex.printStackTrace ();
                System.exit (1);
                return null; // not reached
            }
            if (line == null) {
                // readLine signals end of stream with null; without this
                // check the startsWith call below would throw an NPE.
                System.err.println ("unexpected end of input");
                System.exit (1);
                return null; // not reached
            }
            if (!line.startsWith ("#")) {
                return line;
            }
        }
    }

    /*
     * Combine signature shares read from `filename` and check the result.
     *
     * Input format: one value per line; lines starting with '#' are skipped.
     *   n, e, l, k, V, x, g, s_g, v_g
     * followed by k-1 records, each of the form
     *   x_i, z_i, c_i, v_i, i
     * Here (n, e) is the RSA public key, l the number of parties, k the
     * number of shares to combine, V the verification base, and x the
     * message.  g is our own party index, with secret share s_g and
     * verification key v_g.  Each record is another party's signature
     * share x_i, its proof (z_i, c_i), its verification key v_i, and its
     * index i.
     *
     * Our own share and its proof are generated here.  Every share's proof
     * is then checked, the k shares are combined, and the combined
     * signature is checked against (e, n).  Results are printed.
     */
    private static void combineSignatures (String filename) {

        System.out.println ("combining signatures.");

        BufferedReader in = null;
        try {
            in = new BufferedReader (new FileReader (filename));
        } catch (FileNotFoundException ex) {
            ex.printStackTrace ();
            System.exit (1);
        }

        BigInteger n, e, l, V, x, sg, vg;
        int k, g;
        int[] sigShareIndex;
        BigInteger[] sigShare;
        BigInteger[][] sigShareProof;
        BigInteger[] v;
        try {
            n = new BigInteger (read (in));
            e = new BigInteger (read (in));
            l = new BigInteger (read (in));
            k = Integer.parseInt (read (in));
            V = new BigInteger (read (in));
            x = new BigInteger (read (in));
            g = Integer.parseInt (read (in));
            sg = new BigInteger (read (in));
            vg = new BigInteger (read (in));

            if (k < 1) {
                // Slot 0 always holds our own share, so k must be >= 1.
                System.err.println ("k must be at least 1, got " + k);
                System.exit (1);
            }

            sigShareIndex = new int [k];
            sigShare = new BigInteger [k];
            sigShareProof = new BigInteger [k][2];
            v = new BigInteger [k];

            // Slot 0 is our own share; slots 1..k-1 come from the file.
            sigShareIndex[0] = g;
            v[0] = vg;
            for (int i = 1; i < k; i++) {
                sigShare[i] = new BigInteger (read (in));
                sigShareProof[i][0] = new BigInteger (read (in));
                sigShareProof[i][1] = new BigInteger (read (in));
                v[i] = new BigInteger (read (in));
                sigShareIndex[i] = Integer.parseInt (read (in));
            }
        } finally {
            try {
                in.close ();
            } catch (IOException ignored) {
                // Nothing further is read from the file; a failed close
                // leaves nothing to recover.
            }
        }

        sigShare[0] = genRsaSignatureShare (sg, n, x, l);

        BigInteger xhat = computeXhat (x, l, n);

        System.out.println ("xhat = " + xhat);

        sigShareProof[0] = generateCorrectnessProof
            (sg, n, V, v[0], sigShare[0], xhat);

        for (int i = 0; i < k; i++) {
            BigInteger z = sigShareProof[i][0];
            BigInteger c = sigShareProof[i][1];
            BigInteger vt = proofCommitmentV (V, v[i], z, c, n);
            System.out.println ("v'[" + sigShareIndex[i] + "] = " + vt);
            BigInteger xt = proofCommitmentX (xhat, sigShare[i], z, c, n);
            boolean b = verifySignatureShare
                (sigShare[i], n, V, c, v[i], z, xhat, vt, xt);
            System.out.println ("signature share from group " +
                                sigShareIndex[i] +
                                (b ? " verifies." : " DOES NOT verify."));
        }

        BigInteger y = combineRsaSignatureShares
            (sigShare, sigShareIndex, k, x, e, n, l);

        System.out.println ("combined signature" +
                            (verifyRsa (e, n, y, x) ? " verifies." :
                             " DOES NOT verify."));

        System.out.println ("xg = " + sigShare[0] +
                            "\nzg = " + sigShareProof[0][0] +
                            "\ncg = " + sigShareProof[0][1] +
                            "\nvg = " + v[0] +
                            "\ng = " + g);

        System.out.println ("y = " + y);
    }

    /*
     * End-to-end self-test that plays the trusted dealer.
     *
     * The test proceeds as follows:
     *  - Generate n = p*q from safe primes p = 2p'+1 and q = 2q'+1, and
     *    set m = 4p'q'.
     *  - Pick a prime e > l below m, and set d = e^-1 mod m.
     *  - Share d with a random degree-(k-1) polynomial over Z_m: party i
     *    gets s_i = f(i), with verification key v_i = V^s_i.
     *  - Pick k distinct parties, let each produce a signature share with
     *    a proof, and check each proof.
     *  - Combine the shares and compare with the direct signature x^d.
     *
     * Everything is printed to stdout.
     */
    private static void simulateDealer () {
        System.out.println ("Testing the code by simulating the dealer.");

        int bitlen = 100; // 512
        int l = 21; // 21
        BigInteger ll = BigInteger.valueOf (l);
        int k = 3; // 3

        BigInteger[] pp = safePrime (bitlen, CERTAINTY); // p, p'
        BigInteger[] qq = safePrime (bitlen, CERTAINTY); // q, q'
        while (pp[0].equals (qq[0])) {
            qq = safePrime (bitlen, CERTAINTY);
        }
        BigInteger m = pp[1].multiply (qq[1]).
            multiply (BigInteger.valueOf (4));
        BigInteger n = pp[0].multiply (qq[0]);

        // e must be prime and strictly greater than l (e == l is rejected)
        BigInteger e;
        do {
            e = new BigInteger (m.bitLength (), CERTAINTY, rnd);
        } while (e.compareTo (ll) <= 0 || e.compareTo (m) >= 0);

        System.out.println ("l = " + l +
                            "\nk = " + k +
                            "\ne = " + e +
                            "\np' = " + pp[1] +
                            "\np = " + pp[0] +
                            "\nq' = " + qq[1] +
                            "\nq = " + qq[0] +
                            "\nm = " + m +
                            "\nn = " + n);

        // Polynomial coefficients; a[0] = d is the shared secret exponent.
        BigInteger[] a = new BigInteger [k];
        a[0] = e.modInverse (m);

        System.out.println ("d = " + a[0]);

        for (int i = 1; i < k; i++) {
            a[i] = bigRandom (m);
            System.out.println ("a[" + i + "] = " + a[i]);
        }

        // Secret shares s[1..l]; index 0 is unused.
        BigInteger[] s = new BigInteger [l + 1];
        for (int i = 1; i <= l; i++) {
            s[i] = polynomial (i, k, a, m);
            System.out.println ("s[" + i + "] = " + s[i]);
        }
        BigInteger V = bigRandQR (pp[0], qq[0]);

        System.out.println ("V = " + V);

        // Verification keys v[i] = V^s[i] mod n; index 0 is unused.
        BigInteger[] v = new BigInteger [l + 1];
        for (int i = 1; i <= l; i++) {
            v[i] = V.modPow (s[i], n);
            System.out.println ("v[" + i + "] = " + v[i]);
        }

        // pick k distinct groups at random
        int[] sigShareIndex = new int [k];
        for (int i = 0; i < k; i++) {
            boolean b;
            do {
                b = false;
                sigShareIndex[i] = rnd.nextInt (l) + 1;
                for (int j = 0; j < i && !b; j++) {
                    if (sigShareIndex[i] == sigShareIndex[j]) {
                        b = true;
                    }
                }
            } while (b);
        }

        BigInteger x = new BigInteger ("nisihaberesquiillisaequeactuipse", 36);
        x = x.mod (n);

        System.out.println ("x = " + x);

        BigInteger xhat = computeXhat (x, ll, n);

        System.out.println ("xhat = " + xhat);

        BigInteger[] sigShare = new BigInteger [k];
        for (int i = 0; i < k; i++) {
            int j = sigShareIndex[i];
            BigInteger si = s[j];
            BigInteger vi = v[j];
            BigInteger xi = genRsaSignatureShare (si, n, x, ll);
            sigShare[i] = xi;

            // generate proof
            BigInteger[] zc = generateCorrectnessProof
                (si, n, V, vi, xi, xhat);

            System.out.println ("x[" + j + "] = " + xi +
                                "\nz[" + j + "] = " + zc[0] +
                                "\nc[" + j + "] = " + zc[1]);

            // verify proof
            BigInteger vt = proofCommitmentV (V, vi, zc[0], zc[1], n);
            BigInteger xt = proofCommitmentX (xhat, xi, zc[0], zc[1], n);
            boolean b = verifySignatureShare
                (xi, n, V, zc[1], vi, zc[0], xhat, vt, xt);
            System.out.println ("verify " + j + ": " + b);
        }

        // y1 uses d directly; y2 is combined from the k shares
        BigInteger y1 = x.modPow (a[0], n);
        BigInteger y2 = combineRsaSignatureShares
            (sigShare, sigShareIndex, k, x, e, n, ll);

        System.out.println ("y1 = " + y1 +
                            "\ny2 = " + y2);

        System.out.println ("y1 == y2: " + y1.equals (y2));
        System.out.println ("y1 is valid: " + verifyRsa (e, n, y1, x));
        System.out.println ("y2 is valid: " + verifyRsa (e, n, y2, x));
    }

    // Print s^e mod n for the hard-coded values of s, e and n below.
    private static void generateX () {
        BigInteger s = new BigInteger ("30771931851803123741886562372298615155696330435975237661714002840641542197296");
        BigInteger e = new BigInteger ("67");
        BigInteger n = new BigInteger ("85212746447079824936395777044274071120738223794208795362205208542665542508313");

        BigInteger x = s.modPow (e, n);
        System.out.println ("possible x: " + x);
    }

    // True if the signature is valid on x for the given public key (e, n),
    // i.e. signature^e == x (mod n).
    private static boolean verifyRsa
        (BigInteger e, BigInteger n,
         BigInteger signature, BigInteger x) {
        return signature.modPow (e, n).equals (x.mod (n));
    }

    // Compute a hash.  It might not be CR or OW, but you can
    // assume it is.  If you break the CR or OW property,
    // we'll give you a bonus point.  The output is the low HASH_BITS (128)
    // bits of
    //   h = 2^a[0] * 3^a[1] * ... * p_numElements^a[numElements-1] mod modulus
    // where p_k is the kth prime number.
    private static BigInteger H
        (BigInteger a[], int numElements, BigInteger modulus) {

        BigInteger result = BigInteger.ONE;
        BigInteger prime = BigInteger.ZERO;

        for (int i = 0; i < numElements; i++) {
            prime = getNextPrime (prime);
            result = result.multiply (prime.modPow (a[i], modulus)).
                mod (modulus);
        }

        return result.mod (HASH_MODULUS);
    }

    // requires: q is odd, 0, or 2
    // returns: the smallest (probable) prime greater than q
    private static BigInteger getNextPrime (BigInteger q) {
        if (q.equals (BigInteger.ZERO)) {
            return TWO;
        } else if (q.equals (TWO)) {
            return new BigInteger ("3");
        } else {
            // q is odd, so stepping by 2 visits every odd candidate.
            q = q.add (TWO);
            while (!q.isProbablePrime (CERTAINTY)) {
                q = q.add (TWO);
            }
            return q;
        }
    }

    // Signature share for secret share s: x^(2 * l! * s) mod n.
    private static BigInteger genRsaSignatureShare
        (BigInteger s, BigInteger n,
         BigInteger x,
         BigInteger l) {

        return x.modPow (TWO.multiply (factorial (l)).multiply (s), n);
    }

    // xhat = x^(4 * l!) mod n, the base that the share proofs relate to
    // share^2.
    private static BigInteger computeXhat
        (BigInteger x, BigInteger l, BigInteger n) {
        return x.modPow (BigInteger.valueOf (4).multiply (factorial (l)), n);
    }

    // Rebuild the proof's first commitment: v' = V^z * v^(-c) mod n.
    // For an honest proof with z = s*c + r and v = V^s this equals V^r.
    private static BigInteger proofCommitmentV
        (BigInteger V, BigInteger v, BigInteger z, BigInteger c,
         BigInteger n) {
        return V.modPow (z, n).multiply (v.modPow (c.negate (), n)).mod (n);
    }

    // Rebuild the proof's second commitment: x' = xhat^z * share^(-2c) mod n.
    // For an honest proof this equals xhat^r.
    private static BigInteger proofCommitmentX
        (BigInteger xhat, BigInteger share, BigInteger z, BigInteger c,
         BigInteger n) {
        return xhat.modPow (z, n).
            multiply (share.modPow (c.multiply (TWO).negate (), n)).
            mod (n);
    }

    // Check a share proof: true iff
    //   c == H(V, xhat, v, share^2 mod n, vt, xt),
    // where vt and xt are the commitments rebuilt from (z, c).
    // z is not read here; it only enters through vt and xt.
    private static boolean verifySignatureShare
        (BigInteger share, BigInteger n,
         BigInteger V, BigInteger c,
         BigInteger v, BigInteger z, BigInteger xhat,
         BigInteger vt, BigInteger xt) {

        BigInteger[] a = new BigInteger[6];
        a[0] = V;
        a[1] = xhat;
        a[2] = v;
        a[3] = share.modPow (TWO, n);
        a[4] = vt;
        a[5] = xt;

        BigInteger result = H (a, 6, n);

        return result.equals (c);
    }

    /*
     * Combine numShares signature shares into one RSA signature on x.
     *
     * The steps are:
     *  - w = product of sigShare[i]^lambda_i mod n (see lambdaProduct);
     *  - e' = 4 * (l!)^2;
     *  - extended Euclid gives a, b with a*e' + b*e = gcd(e', e);
     *  - the result is w^a * x^b mod n.
     *
     * The result is a valid signature only when gcd(e', e) = 1, i.e. e is
     * a prime greater than l.  That is not checked here.
     *
     * Throws IllegalArgumentException if two shares have the same index.
     * In that case the Lagrange coefficients are meaningless.
     */
    private static BigInteger combineRsaSignatureShares
        (BigInteger sigShare[], int sigShareIndex[],
         int numShares,
         BigInteger x, BigInteger e,
         BigInteger n, BigInteger l) {

        for (int i = 0; i < numShares; i++) {
            for (int j = 0; j < i; j++) {
                if (sigShareIndex[i] == sigShareIndex[j]) {
                    throw new IllegalArgumentException
                        ("duplicate signature share index " + sigShareIndex[i]);
                }
            }
        }

        BigInteger w = lambdaProduct (sigShare, sigShareIndex, numShares, l, n);
        System.out.println ("w = " + w);
        BigInteger et = factorial (l).pow (2).multiply (BigInteger.valueOf (4));
        System.out.println ("e' = " + et);
        BigInteger ab[] = extendedEuclid (et, e);
        System.out.println ("a = " + ab[1] +
                            "\nb = " + ab[2]);

        // modPow accepts negative exponents (a or b may be negative) by
        // inverting the base mod n.
        return w.modPow (ab[1], n).multiply (x.modPow (ab[2], n)).mod (n);
    }

    // Given an array of signature shares and the indexes (whose share it is)
    // return the value w = product over k of
    //   sigShare[k]^lambdaExponent(sigShareIndex[k]) mod n
    private static BigInteger lambdaProduct
        (BigInteger sigShare[], int sigShareIndex[],
         int numShares,
         BigInteger l, BigInteger n) {

        BigInteger w = BigInteger.ONE;
        for (int k = 0; k < numShares; k++) {
            BigInteger lambda = lambdaExponent
                (sigShareIndex, numShares, sigShareIndex[k], l);
            System.out.println ("lambda[" + sigShareIndex[k] + "] = "
                                + lambda);
            w = w.multiply (sigShare[k].modPow (lambda, n)).mod (n);
        }
        return w;
    }

    /*
     * Exponent for share j: 2 * l! * prod_{j' != j} (0 - j') / (j - j'),
     * with j' ranging over sigShareIndex[0..numShares-1].
     *
     * The factor l! makes the division exact for distinct indices in
     * 1..l.  The products are built in BigInteger because int products of
     * up to l-1 indices overflow (e.g. ~21! for l = 21).
     */
    private static BigInteger lambdaExponent
        (int sigShareIndex[], int numShares, int j, BigInteger l) {

        BigInteger numer = TWO;
        BigInteger denom = BigInteger.ONE;

        for (int k = 0; k < numShares; k++) {
            int other = sigShareIndex[k];
            if (j != other) {
                numer = numer.multiply (BigInteger.valueOf (other).negate ());
                denom = denom.multiply
                    (BigInteger.valueOf ((long) j - (long) other));
            }
        }

        return factorial (l).multiply (numer).divide (denom);
    }

    // Returns [c, x, y] where gcd(a, b) = c = x*a + y*b.
    // Recursive; intended for non-negative a and b.
    private static BigInteger[] extendedEuclid (BigInteger a, BigInteger b) {
        BigInteger[] r;
        if (b.equals (BigInteger.ZERO)) {
            r = new BigInteger[3];
            r[0] = a; // d
            r[1] = BigInteger.ONE; // x
            r[2] = BigInteger.ZERO; // y
        } else {
            r = extendedEuclid (b, a.mod (b));
            BigInteger t = r[2]; // y
            r[2] = r[1].subtract (a.divide (b).multiply (r[2]));
            r[1] = t;
        }

        return r;
    }

    // Returns l! for l >= 1, and 1 for any l < 1.
    private static BigInteger factorial (BigInteger l) {
        BigInteger result = BigInteger.ONE;

        while (l.compareTo (BigInteger.ONE) >= 0) {
            result = result.multiply (l);
            l = l.subtract (BigInteger.ONE);
        }

        return result;
    }

    /*
     * Proof that the same secret si links V to vi and xhat to xi^2
     * (vi = V^si and xi^2 = xhat^si for an honest share).
     * Returns [z, c], where:
     *   c = H(V, xhat, vi, xi^2, V^r, xhat^r)
     *   z = si*c + r
     *
     * si*c has up to L(n) + L1 bits (L1 = HASH_BITS).  r must be about L1
     * bits longer than that for z to statistically hide si, so r is drawn
     * from [0, 2^(L(n) + 2*L1)), as in Shoup's protocol.
     */
    private static BigInteger[] generateCorrectnessProof
        (BigInteger si,
         BigInteger n, BigInteger V,
         BigInteger vi, BigInteger xi,
         BigInteger xhat) {

        BigInteger r = new BigInteger (n.bitLength () + 2 * HASH_BITS, rnd);

        // c = H(V, xhat, vi, xi^2, V^r, xhat^r)
        BigInteger[] hArray = new BigInteger[6];
        hArray[0] = V;
        hArray[1] = xhat;
        hArray[2] = vi;
        hArray[3] = xi.modPow (TWO, n);
        hArray[4] = V.modPow (r, n);    // v' = V^r
        hArray[5] = xhat.modPow (r, n); // x' = xhat^r

        BigInteger[] result = new BigInteger[2];
        result[1] = H (hArray, 6, n);

        // z = si*c + r (over the integers, not reduced)
        result[0] = si.multiply (result[1]).add (r);

        return result;
    }

    // Return: Sigma_{i=0}^{threshold - 1} a[i]*group^i mod modulus
    // (Horner's rule).
    // Requires group >= 1
    private static BigInteger polynomial
        (int group, int threshold,
         BigInteger a[], BigInteger modulus) {

        BigInteger l = BigInteger.valueOf (group);

        BigInteger r = BigInteger.ZERO;
        for (int i = threshold - 1; i >= 0; i--) {
            r = r.multiply (l).add (a[i]).mod (modulus);
        }

        return r;
    }

    // Returns [p, p'] where p' is a probable prime of bitlen bits and
    // p = 2p' + 1 is also a probable prime (a safe prime).
    private static BigInteger[] safePrime (int bitlen, int certainty) {
        BigInteger p[] = new BigInteger[2];

        do {
            p[1] = new BigInteger (bitlen, certainty, rnd);
            p[0] = p[1].multiply (TWO).add (BigInteger.ONE);
        } while (!p[0].isProbablePrime (certainty));

        return p;
    }

    // Returns a number uniformly in the range 0..n-1.
    // Throws IllegalArgumentException if n <= 0 (the range would be empty).
    private static BigInteger bigRandom (BigInteger n) {
        if (n.signum () <= 0) {
            throw new IllegalArgumentException
                ("bigRandom bound must be positive: " + n);
        }

        // Note, we must toss out any random number out of range.
        // Other operations like truncation would throw the distribution
        // out of whack.
        BigInteger r;
        do {
            r = new BigInteger (n.bitLength (), rnd);
        } while (r.compareTo (n) >= 0);

        return r;
    }

    // Return true if "a" is a QR mod p where p prime.  False otherwise.
    // Requires p prime.
    private static boolean QR (BigInteger a, BigInteger p) {
        // Euler's criterion: QR iff a^((p-1)/2) = 1 mod p
        return a.modPow (p.subtract (BigInteger.ONE).divide (TWO), p).
            equals (BigInteger.ONE);
    }

    // Return a random quadratic residue mod pq: a value in [0, pq) that is
    // a QR mod both p and q.  Prints one '.' per candidate tried.
    private static BigInteger bigRandQR (BigInteger p, BigInteger q) {
        BigInteger r;
        int len = p.bitLength () + q.bitLength ();
        BigInteger n = p.multiply (q);
        do {
            r = new BigInteger (len, rnd);
            System.out.print (".");
        } while (r.compareTo (n) >= 0 || r.compareTo (BigInteger.ZERO) < 0 ||
                 !QR (r, p) || !QR (r, q));
        System.out.println ();
        return r;
    }
}
