#!/usr/bin/python
#
# fourphase - determines optimal delays and phases for coherently combining
# 4 VGOS polarization products, and writes them into a fourfit control file
#
# first created                         2015.6.8  rjc
# modified for > 2 stations             2016.3.16 rjc

import datetime
import optparse
import re
import string
import sys
import os
import math
import numpy as np

from subprocess import Popen, PIPE

def main():
    """Command-line driver for fourphase.

    Parses the options and the two positional arguments (a string of
    one-letter station codes and a root file name), works out which
    baselines have a file in the root file's directory, fits a differential
    TEC per station and, unless test mode is selected, fits per-station Y
    delays and phases and writes them to the output control file.  Finally
    fourfit is re-run with that control file to check the derived values.

    Sets the module globals pol_prods and chans that the other functions
    read.  Always terminates through sys.exit ().
    """
    usage_text = '\n fourphase [options] <stations> <root_filename>' \
                 '\n e.g.: fourphase GKE -c ../cf_3419 3C279.xyzzys'
    parser = optparse.OptionParser(usage=usage_text)
    parser.add_option(
        '-c', '--controlfile', dest='cfile', help='control-file name')

    parser.add_option(
        '-o', '--outputfile', dest='ofile', help='output-file name, overrides default ' \
                                                 'of I appended to control-file name')

    parser.add_option(
        '-p', '--plot', action='store_true', dest='fplot', help='display ff plots (false)',
        default=False)

    parser.add_option(
        '-t', '--test', action='store_true', dest='test',
        help='test mode with pre-existing file (false)', default=False)

    parser.add_option(
        '-v', '--verbose', action='store_true', dest='verbose', help='verbose mode (false)',
        default=False)

    parser.add_option(
        '-i', '--ion_original', action='store_true', dest='ion_original',
        help='use original cf iono. model (false)', default=False)

    (opts, args) = parser.parse_args()

    if len(args) != 2:
        print ("use -h option for help")
        sys.exit(0)

    if opts.verbose:
        print (' '.join (['opts: ', str (opts)]))
        print (' '.join (['args: ', str (args)]))
    if opts.test:
        print ('fourphase running in test mode')

                                    # the control file is read back when the
                                    # output is written and names the default
                                    # output, so fail early without it
    if not opts.cfile:
        parser.error ('a control file must be given with -c')

                                    # initialization
    global pol_prods, chans
    pol_prods = ['XX', 'YY', 'XY', 'YX']
    chans = 'abcdefghijklmnopqrstuvwxyzABCDEF '

    stations, root = args
    suffix = root[-6:]
                                    # generate list of all possible baselines,
                                    # leaving out the autocorrelations
    blist = [s1 + s2 + '..' + suffix
             for s1 in stations for s2 in stations if s1 != s2]

                                    # form path for root's directory
    path = os.path.dirname (root) or '.'

                                    # winnow exhaustive bl list based on actual files present
    present = set (os.listdir (path))
    blines = [bl[0] + bl[1] for bl in blist if bl in present]
    if opts.verbose:
        print ('baselines:')
        for bl in blines:
            print (bl)

    if not blines:
        print ('no baseline files for stations ' + stations + ' found in ' + path)
        sys.exit (1)

                                    # insert control file name
    cf = '-c ' + opts.cfile
                                    # determine output file name
    if opts.test:
        of_name = opts.cfile
    elif opts.ofile:
        of_name = opts.ofile
    else:
        of_name = opts.cfile + 'I'
                                    # are fringe plots desired?
    if opts.fplot:
        ff = 'fourfit -pt ' + cf
    else:
        ff = 'fourfit -t ' + cf

                                    # find best-fit ionosphere tec
    st_tec = all_blines_ion (ff, blines, stations, root, chans, opts)
    if opts.verbose:
                                    # array goes on its own line, as before
        print ('dTEC by station for ' + stations + ' :\n' + str (st_tec))

                                    # delay and phase fits
    if not opts.test:
        bproducts = all_blines_phd (ff, blines, stations, st_tec, root, chans, opts)
                                    # do statistical analysis to catch errors
        bproducts = bp_analyze (blines, bproducts, opts)
                                    # now fit ph & del params to data
        ydelays, yphases = fit_phd (blines, stations, bproducts, opts)
        if opts.verbose:
            print (' '.join (['ydelay (ns)  by station:',
                              '%7.3f ' * len(ydelays) % tuple (ydelays)]))
            print (' '.join (['yphase (deg) by station:',
                              '%7.1f ' * len(yphases) % tuple (yphases)]))

                                    # create output file with new params
        write_out (of_name, root, stations, st_tec, ydelays, yphases, opts)
                                    # see if new file works well
    ff = 'fourfit -t -c ' + of_name
    test_out (of_name, ff, blines, stations, root, st_tec, opts)

    sys.exit()



# function to find best-fit ionosphere for all baselines

def all_blines_ion (ff, blines, stations, root, chans, opts):
    """Fit one differential TEC value per station from per-baseline fits.

    bline_ion () is run on every baseline to get a dTEC and a combined snr.
    A weighted least-squares problem is then solved in which each baseline
    row has -1 for its first (reference) station and +1 for its second
    (remote) station, plus one unit-weight pseudo-observation that
    constrains the sum of all station values to 0.

    ff        fourfit command string, handed on to bline_ion ()
    blines    list of 2-character baseline strings
    stations  string of one-letter station codes; fixes the result order
    root      root file name, handed on to bline_ion ()
    chans     handed on to bline_ion ()
    opts      parsed options, handed on to bline_ion ()

    Returns a numpy array with one value per station.
    Raises ValueError if blines is empty.
    """
    print ('\nfitting ionosphere tec...')
    if not blines:
        raise ValueError ('no baselines available for the ionosphere fit')

    bl_tec = []
    bl_combo_snr = []
    for bl in blines:               # determine best-fit tec for each baseline
        dtec, combo_snr = bline_ion (ff, bl, root, chans, opts)
        bl_tec.append (dtec)
        bl_combo_snr.append (combo_snr)

                                    # form weight array, relative to the
                                    # weakest baseline; last entry is the
                                    # weight of sum pseudo-obs.
    snr_min = min (bl_combo_snr)
    wgt = np.array ([bcs / snr_min for bcs in bl_combo_snr] + [1.0])

                                    # build normal equations
    A = []
    for bl in blines:
        arow = np.zeros (len (stations))
        for s, st in enumerate (stations):
            if bl[0] == st:
                arow[s] = -1
            if bl[1] == st:
                arow[s] = 1
        A.append (arow)
    A.append (np.ones (len (stations)))

    Aw = np.array (A) * wgt[:, None]
                                    # append sum pseudo-obs.
    Bw = np.multiply (np.array (bl_tec + [0.0]), wgt)

                                    # rcond=-1 is the historical default; it
                                    # is given explicitly so the result does
                                    # not depend on the numpy version
    rvals = np.linalg.lstsq (Aw, Bw, rcond=-1)

    return rvals[0]





# function to find best ionosphere for 1 baseline

def bline_ion (ff, bl, root, chans, opts):
    """Find the best ionosphere dTEC for one baseline.

    fourfit is run once per polarization product with the a priori phases
    and delays forced to zero and, unless opts.ion_original is set, with an
    ionosphere search (ion_npts 41, ion_win -80.0 80.0) appended.  The dTEC
    values of the pol-prods are averaged with snr**2 weights.  A message is
    printed when the values spread by more than 1 (warning) or 3 (error)
    TEC units.

    ff     fourfit command string
    bl     2-character baseline string
    root   root file name
    chans  not used here
    opts   parsed options; ion_original and verbose are read

    Returns (dtec, weight) where weight is the square root of the summed
    snr**2.
    """
    bline = '-b' + bl
    setstring = zero_string (bl)
                                    # tack on ion. search stmts.
    if not opts.ion_original:
        setstring += 'if\n'
        setstring += 'ion_npts 41 ion_win -80.0 80.0\n'

    sum_dtec_weighted = 0.0
    sum_weights = 0.0
    dtecs = []

    for pp in pol_prods:
        snr, sb, mb, phi, dtec = pol_prod (ff, bline, pp, root, setstring)
        if opts.verbose:
            print (' '.join (['base', bl, 'pol', pp, 'snr', str (snr),
                              'dtec', str (dtec)]))
        sum_dtec_weighted += dtec * snr * snr
        sum_weights += snr * snr
        dtecs.append (dtec)

    dtec = sum_dtec_weighted / sum_weights
    weight = math.sqrt (sum_weights)
                                    # detect overly large ion discrepancy
    dtec_ptop = max (dtecs) - min (dtecs)
    if dtec_ptop > 3.0:
        print (' '.join (['\n**** ERROR **** bad ionosphere fit! dtec on baseline', bl,
                          'differed by %6.1f ' % dtec_ptop, 'TEC units']))
    elif dtec_ptop > 1.0:
        print (' '.join (['\n**** WARNING **** check ionosphere fit! dtec on baseline', bl,
                          'differed by %6.1f ' % dtec_ptop, 'TEC units']))

    return dtec, weight





# function to generate setstring for 0 a priori case

def zero_string (bl):
    """Build the fourfit 'set' text that zeroes a priori delays and phases.

    For both stations of baseline bl (bl[0] and bl[1]) the text holds
    pc_delay_x / pc_delay_y of 0.0 and pc_phases_x / pc_phases_y lines made
    from the module-level channel string chans followed by 32 zeros.

    Returns the text as one string.
    """
                                    # channel labels followed by 32 zero phases
    zero_phases = chans + 32 * ' 0.0' + '\n'

    pieces = []
                                    # only the first station's condition
                                    # carries the leading 'set' keyword
    for lead, stn in (('set if station ', bl[0]), ('\nif station ', bl[1])):
        pieces.append (lead + stn + '\n')
        pieces.append ('pc_delay_x 0.0' + '\n')
        pieces.append ('pc_delay_y 0.0' + '\n')
        pieces.append ('pc_phases_x ' + zero_phases)
        pieces.append ('pc_phases_y ' + zero_phases)

    return ''.join (pieces)





# function to invoke fourfit on 1 polarization product

def pol_prod (ff, bline, polar, root, extrastring):
    """Run fourfit on one polarization product and pick out its results.

    The command is ff split on whitespace, followed by bline, '-P' + polar,
    '-m1', root and the whitespace-split words of extrastring.  The values
    are parsed from what fourfit writes to stderr.

    ff           fourfit command string, e.g. 'fourfit -t -c cf'
    bline        baseline argument, e.g. '-bGE'
    polar        polarization product, e.g. 'XX'
    root         root file name
    extrastring  extra control statements appended to the command

    Returns (snr, sbd, mbd, phz, dtec); dtec stays 0.0 when fourfit reports
    no differential TEC.  The process is ended with exit status 1 if
    fourfit cannot be started or if any of snr, sbd, mbd or phz is missing
    from its output.
    """
    polar = '-P' + polar
    msglev = '-m1'
    pargs = ff.split () + [bline, polar, msglev, root] + extrastring.split ()
                                    # invoke fourfit via a pipe; text mode so
                                    # that stderr is a str on python 3 too
    try:
        p = Popen (pargs, stdout=PIPE, stderr=PIPE, universal_newlines=True)
    except OSError as e:
        print (' '.join (['System error trying to run "', ff, '"']))
        print (e.strerror)
        sys.exit (1)

    output, stderr = p.communicate ()

    snr = sbd = mbd = phz = None
    dtec = 0.0
    for line in stderr.split ('\n'):
        fields = line.split ()
        if 'max555' in line:
            sbd = float (fields[7])
            mbd = float (fields[9])
        elif 'residual phase' in line:
            phz = float (fields[3])
        elif 'differential TEC' in line:
            dtec = float (fields[5])
        elif 'SNR' in line:
            snr = float (fields[2])
                                    # ensure that values were found
    if snr is None or sbd is None or mbd is None or phz is None:
        print (' '.join (['fourfit returned an error: ', stderr]))
        sys.exit (1)

    return snr, sbd, mbd, phz, dtec





# function to find delay and phase parameters on all baselines

def all_blines_phd (ff, blines, stations, st_tec, root, chans, opts):
    """Collect delay and phase data for every baseline and pol-prod.

    For each baseline the control text zeroes the a priori phases and
    delays, sets ion_npts to 1 and fixes the ionosphere of the baseline's
    two stations to their values in st_tec; bline_phd () then runs the four
    polarization products.

    ff        fourfit command string
    blines    list of 2-character baseline strings
    stations  string of one-letter station codes, in the order of st_tec
    st_tec    per-station ionosphere values
    root      root file name
    chans     handed on to bline_phd ()
    opts      parsed options, handed on to bline_phd ()

    Returns a list with one bline_phd () result per baseline.
    """
    print ('\ngenerating baseline & pol-prod data using dTECs...')
    bproducts = []
    for bl in blines:
        setstring = zero_string (bl) + ' if ion_npts 1\n'
                                    # fix the ionosphere of both stations
        for s, st in enumerate (stations):
            if st == bl[0] or st == bl[1]:
                setstring += ' if station ' + st + ' ionosphere ' + str (st_tec[s])

        products = bline_phd (ff, bl, blines, root, chans, setstring, opts)
        bproducts.append (products)

    return bproducts





# function to find delay and phase parameters on 1 baseline & 4 polarization products

def bline_phd (ff, bl, blines, root, chans, setstring, opts):
    """Run the 4 polarization products of one baseline through fourfit.

    Prints one result line per pol-prod and flags an snr below 25 with a
    warning and below 15 with an error.

    ff         fourfit command string
    bl         2-character baseline string
    blines     not used here
    root       root file name
    chans      not used here
    setstring  control statements handed to fourfit
    opts       not used here

    Returns a list of [mbd, phase, snr], one entry per pol-prod in the
    order of the module-level pol_prods.
    """
    bline = '-b' + bl
    products = []

    for pp in pol_prods:
        snr, sb, mb, phi, dtec = pol_prod (ff, bline, pp, root, setstring)
        print (' '.join (['baseline', bl, 'pol_prod', pp, 'mbd', str (mb),
                          'phase %6.1f' % phi, 'snr', str (snr)]))
        if snr < 15:
            print (' '.join (['\n**** ERROR ****', pp, 'on', bl,
                              'has snr of only', str (snr)]))
        elif snr < 25:
            print (' '.join (['\n**** WARNING ****', pp, 'on', bl,
                              'has snr of only', str (snr)]))
        products.append ([mb, phi, snr])

    return products





# function to fit station y phases and delays to baseline data

# bproducts is triply-nested list, most-significantly indexed 
# by baseline, then pol-prod within baseline, and [mbd, phase, snr]

def fit_phd (blines, stations, bproducts, opts):
    """Fit per-station Y delays and Y phases to the baseline data.

    The mbd, phase and snr of the first four pol-prods of every baseline
    are flattened baseline by baseline; the snr values are the weights.  A
    zero pseudo-observation with weight 1 is appended for the X pol of the
    first station, and fit_data () is run once on the delays and once on
    the phases.

    Returns (delays, phases); the delays are scaled by 1000 (to ns, per the
    original comment).
    """
                                    # one [mbd, phase, snr] entry per
                                    # baseline and pol-prod
    entries = [products[v] for products in bproducts for v in range (4)]

                                    # the trailing item of each list is the
                                    # x of 1st stn. pseudo-obs.
    mbd = [entry[0] for entry in entries] + [0.0]
    phz = [entry[1] for entry in entries] + [0.0]
    wgt = np.array ([entry[2] for entry in entries] + [1])

                                    # fit station y delays to mbd data
    delays = fit_data (blines, stations, mbd, wgt, opts)
    delays *= 1000                  # convert units to ns

                                    # fit station y phases to phase data
    phases = fit_data (blines, stations, phz, wgt, opts)

    return delays, phases





# function to determine best-fit delays or phases using least-squares

def fit_data (blines, stations, data, wgt, opts):
    """Least-squares fit of per-station Y values to pol-prod data.

    The unknowns are an X and a Y value for every station (all X first,
    then all Y).  Each baseline / pol-prod row has -1 at the reference
    station's polarization and +1 at the remote station's; a last row
    constrains the first station's X value to 0.  Rows and data are scaled
    by wgt before solving.

    blines    list of 2-character baseline strings
    stations  string of one-letter station codes
    data      observations, 4 per baseline in pol_prods order, followed by
              the pseudo-observation
    wgt       numpy array of weights, one per entry of data
    opts      not used here

    Returns a numpy array with, per station, its fitted X value minus its
    fitted Y value.
    """
    ns = len (stations)
                                    # build normal equations
    A = []
    for bl in blines:
        sref = stations.index (bl[0])
        srem = stations.index (bl[1])
        for pp in pol_prods:
                                    # each row has both x&y per station
            arow = 2 * ns * [0]
            for xy, pol in ((0, 'X'), (ns, 'Y')):
                if pp[0] == pol:
                    arow[sref + xy] = -1
                if pp[1] == pol:
                    arow[srem + xy] = +1
            A.append (arow)
                                    # constrain first station's X pol. to be 0
    A.append ([1] + (2 * ns - 1) * [0])

    Aw = np.array (A) * wgt[:, None]

    Bw = np.multiply (np.array (data), wgt)

                                    # rcond=-1 is the historical default; it
                                    # is given explicitly so the result does
                                    # not depend on the numpy version
    rvals = np.linalg.lstsq (Aw, Bw, rcond=-1)
    params = rvals[0]
                                    # make all stn params relative to xpol = 0
    for n in range (ns):            # and flip sign
        params[n+ns] = params[n] - params[n+ns]
                                    # ret y values, which are in 2nd half of the array
    return params[ns:2*ns]





# examine closure quantities to see if fits look OK
def bp_analyze (blines, bproducts, opts):
    """Check pol-prod closure per baseline and resolve phase ambiguities.

    For every baseline the closure quantity xx + yy - xy - yx is formed for
    the delays (scaled by 1e6 and reported as ps) and for the phases.
    Delay or phase misclosures above 10 give a warning, above 25 an error.
    In between, 360 deg is added to every subset of the four phases
    (k = 1..14) and the combination with the smallest rms replaces the
    phases in bproducts when it beats the original.

    blines     list of 2-character baseline strings
    bproducts  per baseline, a list of [mbd, phase, snr] per pol-prod;
               the phases are modified in place
    opts       parsed options; verbose is read

    Returns bproducts.
    """
    signs = [1, 1, -1, -1]          # xx + yy - xy - yx

    for b, bl in enumerate (blines):
        prods = bproducts[b]
        dsum = 0
        phsum = 0
        for p, sign in enumerate (signs):
            dsum  += prods[p][0] * sign * 1e6
            phsum += prods[p][1] * sign

        if opts.verbose:
            print (' '.join (['baseline', bl, 'closure delay %8.1f ps' % dsum,
                              'closure phase %7.1f deg' % phsum]))
        if abs (dsum) > 25:
            print (' '.join (['\n**** ERROR **** delay misclosure on', bl,
                              'pol_prods is %6.0f' % dsum, 'ps']))
        elif abs (dsum) > 10:
            print (' '.join (['\n**** WARNING **** delay misclosure on', bl,
                              'pol_prods is %6.0f' % dsum, 'ps']))

                                    # find ambiguity combo with minimum rms
        orig_phases = [prods[i][1] for i in range (4)]
        orig_rms = rms (orig_phases)
        min_rms = orig_rms
        best_phases = list (orig_phases)
                                    # search over all possible ambiguities;
                                    # bit i of k adds a turn to phase i
        for k in range (1, 15):
            phases = [orig_phases[i] + 360 * ((k >> i) & 1) for i in range (4)]
            r = rms (phases)
            if r < min_rms:
                min_rms = r         # remember new minimum
                best_phases = phases

        if min_rms < orig_rms:
            for i in range (4):     # update original phases
                prods[i][1] = best_phases[i]
            print (' '.join (['\nnew phase ambiguities selected for baseline', bl]))
            print (' '.join (['old phases: %7.1f %7.1f %7.1f %7.1f' % tuple (orig_phases),
                              'rms %7.1f' % orig_rms]))
            print (' '.join (['new phases: %7.1f %7.1f %7.1f %7.1f' % tuple (best_phases),
                              'rms %7.1f' % min_rms]))

                                    # recalculate phase sum after modifications
        phsum = 0
        for p, sign in enumerate (signs):
            phsum += prods[p][1] * sign

                                    # report only now, so that the closure
                                    # phase reflects any new ambiguities
        if opts.verbose:
            print (' '.join (['baseline', bl, 'closure delay %8.1f ps' % dsum,
                              'closure phase %7.1f deg' % phsum]))

        if abs (phsum) > 25:
            print (' '.join (['\n**** ERROR **** phase misclosure on', bl,
                              'pol_prods is %6.0f' % phsum, 'deg']))
        elif abs (phsum) > 10:
            print (' '.join (['\n**** WARNING **** phase misclosure on', bl,
                              'pol_prods is %6.0f' % phsum, 'deg']))

    return bproducts





# create output file and write new delay & phase params to it
def write_out (of_name, root, stations, st_tec, ydelays, yphases, opts):
    """Write the output control file with the new delay & phase params.

    The contents of the control file opts.cfile are copied to of_name and
    followed by a dated comment naming the root file and one 'if station'
    section per station.  Each section records the ionosphere value as a
    comment, sets pc_delay_x to 0.0 and pc_delay_y to the station's Y
    delay, and writes 32-channel pc_phases_x (all 0.0) and pc_phases_y
    (all equal to the station's Y phase) lines.

    of_name   name of the file to create (overwritten if it exists)
    root      root file name, quoted in the comment only
    stations  string of one-letter station codes
    st_tec, ydelays, yphases
              per-station values, indexed in the order of stations
    opts      parsed options; cfile is read
    """
                                    # read the whole control file first, so
                                    # that of_name may be the same file
    with open (opts.cfile, 'r') as ifile:
        orig = ifile.read ()

    now = datetime.datetime.now ()

    with open (of_name, 'w') as ofile:
        ofile.write (orig)          # copy contents of input cf
                                    # keep the comment off the last line of
                                    # a cf that lacks a final newline
        if orig and not orig.endswith ('\n'):
            ofile.write ('\n')

        ofile.write ('* following lines added by fourphase on ' + str (now))
        ofile.write ('\n* by analysis of root file ' + root)

                                    # add lines for every station
        for n, stn in enumerate (stations):
            ofile.write ('\nif station ' + stn + '\n')
            ofile.write ('  * following ionosphere was used to derive phase and delay\n')
            ofile.write ('  * ionosphere ' + '%8.3f' % (st_tec[n]) + '\n')
            ofile.write ('  pc_delay_x ' + ' 0.0' + '\n')
            ofile.write ('  pc_delay_y ' + '%8.3f' % (ydelays[n]) + '\n')

                                    # same value in each of the 32 channels
            pcpx = '  pc_phases_x abcdefghijklmnopqrstuvwxyzABCDEF' + 32 * ' 0.0'
            pcpy = '  pc_phases_y abcdefghijklmnopqrstuvwxyzABCDEF' \
                   + 32 * ('%7.1f' % (yphases[n]))

            ofile.write (pcpx + '\n')
            ofile.write (pcpy + '\n')
    return





# do consistency tests on the derived offsets by re-running fourfit on each baseline
def test_out (cf_name, ff, blines, stations, root, st_tec, opts):
    """Re-run fourfit per baseline and check the spread of the results.

    For every baseline the four polarization products are run with
    ion_npts 1 and the ionosphere of both stations fixed to their st_tec
    values.  The rms of the four mbd values (scaled by 1e6, reported as ps)
    and of the four phases is printed; an rms above 10 gives a warning, an
    mbd rms above 25 or a phase rms above 40 an error.

    cf_name   not used here
    ff        fourfit command string
    blines    list of 2-character baseline strings
    stations  string of one-letter station codes, in the order of st_tec
    root      root file name
    st_tec    per-station ionosphere values
    opts      not used here
    """
    print ('\nchecking derived offset parameters...')
    for bl in blines:               # test baselines one at a time
        bline = '-b' + bl
        sref = stations.index (bl[0])
        srem = stations.index (bl[1])
        setstring = 'set ion_npts 1'
        for s in (sref, srem):
            setstring += ' if station ' + stations[s] + ' ionosphere ' + str (st_tec[s])

        mbds = []
        phis = []
        for pp in pol_prods:        # gather data on each pol-prod per bl
            snr, sb, mb, phi, dtec = pol_prod (ff, bline, pp, root, setstring)
            mbds.append (mb)
            phis.append (phi)
            print (' '.join ([pp, 'mbd', str (mb), 'phase %6.1f' % phi]))
                                    # adjust phases if wrap-around is detected
        if abs (phis[0]) + abs (phis[1]) + abs (phis[2]) + abs (phis[3]) > 360:
                                    # modify negative phases only
            phis = [ph + 360 if ph < 0 else ph for ph in phis]
                                    # do rms calculations, maybe issue warnings
        mb_rms = 1e6 * rms (mbds)
        phi_rms = rms (phis)

        print (' '.join ([bl, 'rms mbd %7.1f' % mb_rms,
                          'ps   rms phase %6.1f' % phi_rms, 'deg\n']))

        if mb_rms > 25:
            print (' '.join (['\n**** ERROR ****', bl,
                              'mb delay rms of %7.1f' % mb_rms, ' ps is too large']))
        elif mb_rms > 10:
            print (' '.join (['\n**** WARNING ****', bl,
                              'mb delay rms of %7.1f' % mb_rms, ' ps is large']))

        if phi_rms > 40:
            print (' '.join (['\n**** ERROR ****', bl,
                              'phase rms of %6.1f' % phi_rms, ' deg is too large']))
        elif phi_rms > 10:
            print (' '.join (['\n**** WARNING ****', bl,
                              'phase rms of %6.1f' % phi_rms, ' deg is large']))
    return





# calculate rms of a numeric list
def rms (x_array):
    """Return the rms of the values in x_array about their mean.

    The squared deviations are divided by the number of values (not by one
    less).  An empty x_array raises ZeroDivisionError.
    """
    count = len (x_array)
                                    # the 0.0 start value keeps the sums in
                                    # floating point for integer input
    xbar = sum (x_array, 0.0) / count
    sq_dev = sum (((x - xbar) ** 2 for x in x_array), 0.0)

    return math.sqrt (sq_dev / count)






if __name__ == '__main__':          # official entry point
                                    # main () leaves through sys.exit ()
                                    # itself; the exit below is only reached
                                    # if main () ever returns normally
    main()
    sys.exit(0)
