#!/usr/bin/env python3
# Authors: Farzaneh Aziz Zanjani & Falk Amelung 
import textwrap
import argparse
import math
import os
import numpy as np
import pandas as pd 
import glob
import mintpy
from osgeo import gdal
from mintpy.utils import readfile, utils as ut
import h5py
import datetime
from datetime import date
from datetime import datetime
from datetime import timedelta
from mintpy.objects import timeseries
from mintpy.dem_error import read_geometry
#from mintpy.objects import HDFEOS, giantTimeseries, timeseries
from mintpy.utils import ptime

###
# Citation appended to the --help epilog.
REFERENCE = (
    "reference:\n"
    "  Fattahi, H., and F. Amelung (2013), DEM Error Correction in InSAR Time Series,\n"
    "    IEEE Trans. Geosci. Remote Sens., 51(7), 4249-4259, doi:10.1109/TGRS.2012.2227761.\n"
)

# Usage example appended to the --help epilog, after REFERENCE.
EXAMPLE = (
    "example: \n"
    "             ./DemError_pixel.py --lalo 25.877 -80.120 --exclude 20161015\n"
)
###
def cmd_line_parser(iargs=None):
    """Parse the command-line options of this script.

    Parameters: iargs - list of str, arguments to parse; when None (default)
                        argparse reads them from sys.argv.
    Returns:    args  - argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(
        description='Calculates Dem error for a pixel using timeseries.h5 file',
        epilog=REFERENCE + '\n' + EXAMPLE,
        # REFERENCE / EXAMPLE are hand-formatted multi-line text; the default
        # HelpFormatter would re-flow them into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--project_dir", metavar='', type=str, default="./network_single_reference", help = "path to the directory containing data files")
    parser.add_argument('--lalo', nargs=2, default=(25.793, -80.133), metavar='', type=float, help='latitude and longitude')
    parser.add_argument("--velocity", metavar='', type=str, default="./network_single_reference/velocity.h5", help = "path to the velocity file")
    parser.add_argument("--dem_error", metavar='', type=str, default="./network_single_reference/demErr.h5", help = "path to the Dem error file")
    parser.add_argument("--geometry", metavar='', type=str, default="./network_single_reference/inputs/geometryRadar.h5", help = "path to the geolocation file")
    parser.add_argument("--dsm", type=str, metavar='', default="./dsm_reprojected_wgs84.tif", help = "path to Lidar elevation file")
    parser.add_argument("--timeseries", metavar='', type=str, default="./network_single_reference/timeseries.h5", help = "path to timeseries file")
    parser.add_argument("--slcStack", metavar='', type=str, default="./network_single_reference/inputs/slcStack.h5", help = " slcstack file")
    parser.add_argument('--exclude', metavar='', nargs='*', default=[], help='Exclude date(s) for DEM error estimation.')
    args = parser.parse_args(args=iargs)

    return args

####
def creat_matrix(args=None):
    """Build the design matrix A and observation vector P for one pixel.

    Phase-velocity formulation of Fattahi & Amelung (2013). For every pair of
    consecutive acquisitions (n-1, n), with times t in years relative to the
    time-series reference date and dt = t(n) - t(n-1), the observed phase
    velocity

        P_n = (4*pi/lambda) * (d(n) - d(n-1)) / dt

    is modelled as the row

        [1,
         (t(n) + t(n-1)) / 2,
         (t(n)**3 - t(n-1)**3) / (6 * dt),
         (4*pi/lambda) * (B(n) - B(n-1)) / (dt * r * sin(theta))]

    times the unknown vector X. X[3] is the DEM error. The second and third
    columns are the interval averages of (t - t_ref) and (t - t_ref)**2 / 2.

    Parameters: args   - argparse.Namespace from cmd_line_parser(); when None
                         (default) the command line is parsed here.
    Returns:    A      - np.ndarray of shape (num_used_dates - 1, 4)
                P      - list of num_used_dates - 1 one-element lists
                dem_xy - DEM error of the pixel read from the --dem_error file
    Raises:     ValueError - if a date given with --exclude is not in the
                             time-series date list.
    """
    if args is None:
        args = cmd_line_parser()
    geo_file = args.geometry
    ts_file = args.timeseries
    dem = args.dem_error
    sec_per_year = 31556736  # 365.24 days, same constant as before

    ts_obj = timeseries(ts_file)
    metadata = ts_obj.get_metadata()
    wavelength = float(metadata['radarWavelength'])
    conv1 = (4 * math.pi) / wavelength
    ref_date = datetime.strptime(str(metadata['REF_DATE']), '%Y%m%d')
    all_dates = ts_obj.get_date_list()

    # Geographic -> radar (row, col) coordinates of the requested pixel.
    points_lalo = np.array([[args.lalo[0], args.lalo[1]]])
    attr = readfile.read_attribute(ts_file)
    coord = ut.coordinate(attr, geo_file)
    yg1, xg1 = coord.geo2radar(points_lalo[0][0], points_lalo[0][1])[0:2]

    # Geometry of this pixel only (box = (x0, y0, x1, y1)). The returned
    # incidence angle / range may be a scalar or a 1-element array, and pbase
    # is reshaped to one value per date. Indexing these arrays with the date
    # index, as before, picked the wrong pixel.
    box = (int(xg1), int(yg1), int(xg1) + 1, int(yg1) + 1)
    sin_inc_angle, range_dist, pbase = read_geometry(ts_file, geo_file, box=box)
    conv2 = float(np.ravel(range_dist)[0]) * float(np.ravel(sin_inc_angle)[0])
    pbase = np.asarray(pbase, dtype=float).reshape(len(all_dates), -1)[:, 0]

    ts_data = readfile.read(ts_file, datasetName='timeseries')[0]
    ts_pixel = np.asarray(ts_data[:, yg1, xg1], dtype=np.float64)
    dem_e = readfile.read(dem, datasetName='dem')[0]
    dem_xy = dem_e[yg1, xg1]

    # Drop excluded dates consistently from the dates, displacements and
    # baselines. Previously the wrong time-series row was deleted and pbase was
    # left untouched.
    missing = [d for d in args.exclude if d not in all_dates]
    if missing:
        raise ValueError(f"exclude date(s) not found in time-series: {missing}")
    excluded = set(args.exclude)
    keep = [k for k, d in enumerate(all_dates) if d not in excluded]
    date_list = [all_dates[k] for k in keep]
    ts_pixel = ts_pixel[keep]
    pbase = pbase[keep]

    # Times in years relative to the reference date. Timedelta arithmetic
    # avoids the local-timezone/DST dependence of datetime.timestamp().
    t = [(datetime.strptime(d, '%Y%m%d') - ref_date).total_seconds() / sec_per_year
         for d in date_list]

    A = []
    P = []
    for i in range(1, len(date_list)):
        t_n, t_m = t[i], t[i - 1]
        dt = t_n - t_m
        dx = ts_pixel[i] - ts_pixel[i - 1]
        perpb = float(pbase[i] - pbase[i - 1]) / dt
        col1 = 1
        col2 = (t_n + t_m) / 2
        # Exact interval average. It equals the former "i == 1" special case
        # (t_n**2 / 6) when the reference date is the first date.
        col3 = (t_n ** 3 - t_m ** 3) / (6 * dt)
        col4 = perpb * conv1 / conv2
        A.append([col1, col2, col3, col4])
        P.append([dx * conv1 / dt])
    A = np.array(A, dtype=float).reshape(len(date_list) - 1, 4)
    return A, P, dem_xy

####
# Estimate X from Fattahi & Amelung (2013): solve A @ X = P in the
# least-squares sense. X[3] is the DEM error of the selected pixel; it is
# printed next to the value stored in MintPy's DEM-error file.

if __name__ == "__main__":
    A, P, d = creat_matrix()
    # lstsq avoids explicitly inverting A.T @ A, which is numerically fragile
    # and fails outright when the normal matrix is singular.
    X = np.linalg.lstsq(A, np.asarray(P, dtype=float), rcond=None)[0]

    print("Dem error :", X[3])
    print("Dem error from mintpy :", d)



