#!/usr/bin/python

#core imports
from __future__ import print_function
from __future__ import division
from builtins import str
from builtins import range
from past.utils import old_div
import argparse
import sys
import os
import math
import re
from datetime import datetime

#non-core imports
import numpy as np
import scipy.stats

import matplotlib
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
matplotlib.rcParams.update({'savefig.dpi':300})
import pylab
import matplotlib.patches as mpatches

#HOPS module imports
import vpal.processing
import vpal.utility
import vpal.fringe_file_manipulation

import mk4b
import hopstestb as ht

################################################################################

# Linear polarization products whose dTEC values are compared scan-by-scan.
_POL_PRODUCTS = ('XX', 'XY', 'YX', 'YY')


def _empty_scan_record():
    """Return a fresh per-scan record: one {'snr', 'dTEC', 'mbd'} dict (all None) per pol-product."""
    return {pp: {'snr': None, 'dTEC': None, 'mbd': None} for pp in _POL_PRODUCTS}


def _collect_fringe_results(exp_dir, control_file, baseline_list):
    """Gather per-scan, per-pol-product fringe results for every baseline.

    Returns a nested dict ``results[baseline][scan_name][polprod]`` whose leaf is a
    dict with keys 'snr', 'dTEC' and 'mbd'. A leaf stays None when no fringe file
    supplied that pol-product for that scan. Baselines without any fringe files
    are reported and omitted from the result.
    """
    results = {}
    for bline in baseline_list:
        print('Collecting fringe files for baseline', bline)
        ff_list = vpal.processing.gather_fringe_files(
            exp_dir, control_file, [bline], pol_products=list(_POL_PRODUCTS), max_depth=2)
        print("n fringe files  =", str(len(ff_list)))

        if not ff_list:
            print("Error: no fringe files available after cuts, skipping baseline: ", bline)
            continue

        scans = {}
        results[bline] = scans
        for ff in ff_list:
            # Register the scan even if this particular file is not a usable linear
            # pol-product, so that incomplete scans are reported later on.
            if ff.scan_name not in scans:
                scans[ff.scan_name] = _empty_scan_record()

            # This helper returns e.g. ['I'] for combined Ixy results. Only a single
            # linear pol-product is accepted; an empty return is skipped rather than
            # raising IndexError.
            polprod = ht.get_file_polarization_product_provisional(ff.filename)
            if len(polprod) != 1 or polprod[0] not in _POL_PRODUCTS:
                continue

            entry = scans[ff.scan_name][polprod[0]]
            entry['snr'] = ff.snr
            entry['dTEC'] = ff.dtec
            entry['mbd'] = ff.mbdelay
    return results


def _compute_ddtec_vs_min_snr(results):
    """Compute (min SNR, ddTEC) for every scan that has all four linear pol-products.

    ddTEC is the maximum absolute deviation of the per-pol-product dTEC values from
    their mean. This matches the mean-based (not median-based) definition noted for
    fringe_file_manipulation.SingleBaselinePolarizationProductCollection. Scans that
    are missing an snr or dTEC for any pol-product are reported and skipped.

    Returns two equal-length lists: ``(min_snr, ddTEC)``.
    """
    min_snr = []
    ddTEC = []
    for bline, scans in results.items():
        for scan, record in scans.items():
            snr = [v['snr'] for v in record.values()]
            dtec = [v['dTEC'] for v in record.values()]

            if any(s is None for s in snr) or any(d is None for d in dtec):
                print('Scan', scan, 'is not complete on baseline', bline, '!', snr, dtec)
                continue

            # Convert explicitly: subtracting a scalar from a plain list only works by
            # accident of numpy scalar broadcasting.
            dtec_arr = np.asarray(dtec)
            ddTEC.append(np.max(np.abs(dtec_arr - np.mean(dtec_arr))))
            min_snr.append(np.min(snr))
    return min_snr, ddTEC


def _plot_ddtec_vs_min_snr(min_snr, ddTEC, plot_name):
    """Scatter ddTEC against min SNR on log-log axes and save it as ``plot_name + '.png'``."""
    fig_width_pt = 600  # Get this from LaTeX using \showthe\columnwidth
    inches_per_pt = 1.0 / 72.27  # Convert pt to inch
    aspect_ratio = 1.0  # square figure
    fig_width = fig_width_pt * inches_per_pt
    fig_size = [fig_width, fig_width * aspect_ratio]

    # A fresh auto-numbered figure; the original used a random figure number for no reason.
    fig = plt.figure(figsize=fig_size)
    ax = fig.add_subplot(111)

    ax.plot(min_snr, ddTEC, 'k.', markersize=4)
    ax.grid(True, which='both', linestyle=':', alpha=0.6)
    ax.tick_params(labelsize=10)
    ax.set_xlabel('min SNR', fontsize=11)
    ax.set_ylabel('max dTEC deviation from the mean [TECU]', fontsize=11)
    ax.set_yscale('log')
    ax.set_xscale('log')

    fig.savefig(plot_name + '.png', dpi=300, bbox_inches='tight')
    plt.close(fig)


def main():
    """Command-line entry point: plot ddTEC vs min SNR for the linear pol-products.

    Parses the command line (control file, reference station, remote stations,
    experiment directory) and builds the network-reference baseline list. It then
    gathers XX/XY/YX/YY fringe results per scan. For each scan with all four
    products it plots the spread of dTEC across products against the smallest SNR
    among them. The plot is written to ``./ddTEC_minsnr_<ref>_<stations>_<exp>.png``.

    Exits with status 1 if the control file does not exist.
    """
    parser = argparse.ArgumentParser(
        prog='plot_ddTEC_minsnr.py',
        description='''utility for plotting ddTEC vs min_snr for polarization products''',
    )
    parser.add_argument('control_file', help='the control file to be applied to all scans')
    parser.add_argument('ref_station', help='single character code of station for channel-by-channel residual analysis')
    parser.add_argument('stations', help='concatenated string of single character codes for remote stations to use')
    parser.add_argument('experiment_directory', help='relative path to directory containing experiment data')
    args = parser.parse_args()

    control_file = args.control_file
    ref_station = args.ref_station
    stations = args.stations
    exp_dir = args.experiment_directory

    abs_exp_dir = os.path.abspath(exp_dir)
    exp_name = os.path.basename(abs_exp_dir)

    if not os.path.isfile(os.path.abspath(control_file)):
        print("could not find control file: ", control_file)
        sys.exit(1)

    print('Calculating baselines')
    baseline_list = vpal.processing.construct_valid_baseline_list(
        abs_exp_dir, ref_station, stations, network_reference_baselines_only=True)
    print('Baselines:', baseline_list)

    # default output filename (".png" is appended when saving)
    plot_name = "./ddTEC_minsnr_" + ref_station + "_" + stations + '_' + exp_name

    results = _collect_fringe_results(exp_dir, control_file, baseline_list)
    min_snr, ddTEC = _compute_ddtec_vs_min_snr(results)

    if not min_snr:
        # Nothing to draw; plotting empty data on log axes yields a blank figure.
        print("Error: no scans with complete XX/XY/YX/YY results, no plot written")
        return

    _plot_ddtec_vs_min_snr(min_snr, ddTEC, plot_name)


if __name__ == '__main__':
    # Official entry point: run the tool, then report success to the shell.
    # main() returns None, so the exit status is always 0 here.
    sys.exit(main() or 0)
