#!/usr/bin/env python3
#################################################################
# Program is used for concatenating data under radar geometry   #
# Author: Lv Xiaoran                                            #
# Created: Mar 2023                                             #
#################################################################

import os
import argparse
import numpy as np
import h5py
import copy
import pandas as pd
import math
from sklearn.neighbors import KNeighborsClassifier
from functools import partial

import mintpy
import mintpy.workflow
from mintpy.utils import readfile, ptime, writefile, utils as ut
from mintpy.objects import timeseries

import mimtpy.workflow  # dynamically loads mimtpy submodules used below (e.g. mimtpy.concatenate_offset)
######################################################################################
NOTE = """
Please Note:
1. Four types of data are supported: velocity, timeseries, maskPS and maskTempCoh
2. This script concatenates two datasets. The inputs are the two object datasets, their corresponding geometryRadar.h5 files, and the geometryRadar.h5 file of the whole region processed with 1:1 looks
3. If batch concatenation is needed, please use the concatenate_patches.py script.
4. For timeseries, the two datasets should have the same reference date

"""

EXAMPLE = """example:

    concatenate_radarGeo.py miaplpy_NE/velocity_msk.h5  miaplpy_NNE/velocity_msk.h5 --geo-file1 miaplpy_NE/inputs/geometryRadar.h5 --geo-file2 miaplpy_NNE/inputs/geometryRadar.h5 --geo-full ./inputs/geometryRadar.h5 --geo-write --out-suffix NE_NNE --outdir miaplpyBig

    concatenate_radarGeo.py miaplpy_NE/velocity_msk.h5  miaplpy_NNE/velocity_msk.h5 --geo-file1 miaplpy_NE/inputs/geometryRadar.h5 --geo-file2 miaplpy_NNE/inputs/geometryRadar.h5 --geo-full ./inputs/geometryRadar.h5  --out-suffix NE_NNE --outdir miaplpyBig
    
    concatenate_radarGeo.py miaplpy_NE/timeseries_msk.h5 miaplpy_NNE/timeseries_msk.h5 --geo-file1 miaplpy_NE/inputs/geometryRadar.h5 --geo-file2 miaplpy_NNE/inputs/geometryRadar.h5 --geo-full ./inputs/geometryRadar.h5  --out-suffix NE_NNE --outdir ./miaplpyBig/

    concatenate_radarGeo.py miaplpy_NE/maskPS.h5  miaplpy_NNE/maskPS.h5 --geo-file1 miaplpy_NE/inputs/geometryRadar.h5 --geo-file2 miaplpy_NNE/inputs/geometryRadar.h5 --geo-full ./inputs/geometryRadar.h5  --out-suffix NE_NNE --outdir miaplpyBig

    concatenate_radarGeo.py miaplpy_NE/maskTempCoh.h5  miaplpy_NNE/maskTempCoh.h5 --geo-file1 miaplpy_NE/inputs/geometryRadar.h5 --geo-file2 miaplpy_NNE/inputs/geometryRadar.h5 --geo-full ./inputs/geometryRadar.h5  --out-suffix NE_NNE --outdir miaplpyBig
"""


def cmd_line_parse():
    parser = argparse.ArgumentParser(description='Concatenate miaplpy patches',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog=NOTE + '\n' + EXAMPLE)

    parser.add_argument('patch_files', nargs='+', type=str, help='two displacement datasets to be concatenated \n')

    parser.add_argument('--geo-file1', nargs=1, type=str, help='geometryRadar file of dataset1. \n')
    
    parser.add_argument('--geo-file2', nargs=1, type=str, help='geometryRadar file of dataset2. \n')
    
    parser.add_argument('--geo-full', nargs=1, type=str, help='Whole region geometryRadar.h5 file processed with 1:1 looks. \n')
    
    #parser.add_argument('--datatype', nargs=1, type=str, help='data type: vel, ts, maskPS, maskTC\n')
   
    parser.add_argument('--geo-write', action='store_true', default=False, help='whether to write the concatenated geometryRadar.h5 results. \n')
 
    parser.add_argument('--out-suffix', nargs=1,  default=[''], help='suffix of output name of concatenated data. \n')

    parser.add_argument('--outdir', dest='outdir', nargs=1, default=['.'],
                        help='output directory to store the concatenated results')
    inps = parser.parse_args()  
    
    return inps

def flatten_trans(x):
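    """Return a function that reshapes a flattened array back to the original shape of x."""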
    original_shape = x.shape
    return partial(np.reshape, newshape=original_shape)

def geo_position_backup(lat_sub, lon_sub, lat, lon, ref_flag=None):
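    """Backup method: locate one corner pixel of a patch in the full-region lat/lon grid.

    The corner used depends on ref_flag; returns the corner's (row, col) in the full grid.
    """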
    if ref_flag is None:
        lat_corner = lat_sub[0][0] # use the upper left point
        lon_corner = lon_sub[0][0]
    elif ref_flag == 1 or ref_flag == 4 or ref_flag == 6 or ref_flag == 7:
        lat_corner = lat_sub[0][0] # use the upper left point
        lon_corner = lon_sub[0][0]
    elif ref_flag == 2 or ref_flag == 3 or ref_flag == 8:
        lat_corner = lat_sub[0][-1] # use the upper right point
        lon_corner = lon_sub[0][-1]
    elif ref_flag == 5:
        lat_corner = lat_sub[-1][0] # use the lower left point
        lon_corner = lon_sub[-1][0]

    lat_flag = (lat == lat_corner)
    lon_flag = (lon == lon_corner)

    flag = lat_flag * lon_flag
    pos_flag = np.where(flag)

    return pos_flag[0][0], pos_flag[1][0]

def get_position(lat_sub, lon_sub, lat, lon):
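    """Locate the four corner pixels of a patch within the full-region grid.

    Returns a 4x2 matrix of (row, col) positions ordered as upper-left,
    upper-right, lower-left, lower-right.
    """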
    row_list = [0, -1]
    col_list = [0, -1]

    row_col_matrix = np.full((4, 2), np.nan)
    num = 0

    for row in row_list:
        for col in col_list:
            lat_corner = lat_sub[row][col]
            lon_corner = lon_sub[row][col]
            lat_flag = (lat == lat_corner)
            lon_flag = (lon == lon_corner)
            flag = lat_flag * lon_flag
            pos = np.where(flag)
            row_col_matrix[num][0] = pos[0][0]
            row_col_matrix[num][1] = pos[1][0]
            num += 1

    return row_col_matrix

def design_joined_matrix(rc_ref, rc_aff):
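    """Compute the bounding box (in full-grid coordinates) that covers both patches."""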
    row_join_upper = int(np.min([rc_ref[0, 0], rc_aff[0, 0]]))
    row_join_lower = int(np.max([rc_ref[-1, 0], rc_aff[-1, 0]]))
    col_join_left = int(np.min([rc_ref[0, 1], rc_aff[0, 1]]))
    col_join_right = int(np.max([rc_ref[-1, 1], rc_aff[-1, 1]]))
    return row_join_upper, row_join_lower, col_join_left, col_join_right

def haversin(theta):
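    """Haversine function: hav(theta) = sin^2(theta / 2)."""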
    v = np.sin(theta / 2)
    return v * v

def distance2points(lat1, lon1, lat2, lon2):
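    """Great-circle distance [km] between two (lat, lon) points via the haversine formula."""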
    radius = 6370  # approximate Earth radius [km]

    lat1 = math.radians(lat1)
    lon1 = math.radians(lon1)
    lat2 = math.radians(lat2)
    lon2 = math.radians(lon2)
    
    dlon = lon2 - lon1
    dlat = lat2 - lat1

    h = haversin(dlat) + np.cos(lat1) * np.cos(lat2) * haversin(dlon)

    # haversine formula: d = 2R * arcsin(sqrt(h)); np.sin here was a bug
    dis = 2 * radius * np.arcsin(np.sqrt(h))

    return dis

def overlay_lalo(lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten):
    # calculate the overlapping region between the two images
    lat1_min = np.min(lat1_flatten)
    lat1_max = np.max(lat1_flatten)
    lon1_min = np.min(lon1_flatten)
    lon1_max = np.max(lon1_flatten)

    lat2_min = np.min(lat2_flatten)
    lat2_max = np.max(lat2_flatten)
    lon2_min = np.min(lon2_flatten)
    lon2_max = np.max(lon2_flatten)

    # calculate the overlay lat and lon
    over_lat_min = max(lat1_min,lat2_min)
    over_lon_min = max(lon1_min,lon2_min)
    over_lat_max = min(lat1_max,lat2_max)
    over_lon_max = min(lon1_max,lon2_max)

    return over_lat_min, over_lon_min, over_lat_max, over_lon_max

def PS_overlay(latlon, over_lat_min, over_lon_min, over_lat_max, over_lon_max):
    # extract the PS points located in the overlay region
    flag_lat = np.where((latlon[:,0]<over_lat_max) & (latlon[:,0]>over_lat_min)) 
    flag_lon = np.where((latlon[:,1]<over_lon_max) & (latlon[:,1]>over_lon_min))

    flag = np.intersect1d(flag_lat, flag_lon)
    PS_num = len(flag)
    print('The total number of PS points located in the overlapping region is %d' % PS_num)

    return flag

def calculate_offset_matrix(vel_ref, lat_ref, lon_ref, vel_aff, lat_aff, lon_aff):
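    """Estimate per-point offsets between the reference and affiliate images.

    A 1-nearest-neighbor search on (lon, lat) with haversine distance matches each
    reference PS point to the closest affiliate point; pairs closer than ~6 m
    contribute the difference of their values to the returned offset array.
    """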
    # construct pandas DataFrames for the reference (find) and affiliate (data) points
    find_PS = {'lon':lon_ref, 'lat':lat_ref,'vel':vel_ref}
    find = pd.DataFrame(find_PS)
    data_PS = {'lon':lon_aff, 'lat':lat_aff,'vel':vel_aff}
    data = pd.DataFrame(data_PS)
    
    # feature columns are ordered (lon, lat)
    data_fit = data.iloc[:, [0, 1]]
    # dummy single-class labels: the classifier is only used for nearest-neighbor search
    y = [1] * len(data_fit)

    find_x = find.iloc[:, [0, 1]]

    # the feature columns are (lon, lat), while distance2points expects (lat, lon)
    knn = KNeighborsClassifier(n_neighbors=1,
                               algorithm='ball_tree',
                               metric=lambda s1, s2: distance2points(s1[1], s1[0], s2[1], s2[0]))

    # fit the knn model
    knn.fit(data_fit, y)
    # find the nearest affiliate point for each reference point
    distance, point = knn.kneighbors(find_x, n_neighbors=1, return_distance=True)

    # collect the differences between reference and affiliate values of matched
    # point pairs closer than 6 m; the caller takes the median of these offsets
    offset = []
    for i, row in find.iterrows():
        tmp = data.iloc[point[i]]
        if distance[i][0] < 0.006:  # 6 m, since distance2points returns km
            vel_ref_value = row['vel']
            vel_aff_value = tmp['vel'].iloc[0]
            vel_delta = vel_ref_value - vel_aff_value
            offset.append(vel_delta)

    return np.array(offset)

def concatenate_process(data1_flatten, data2_flatten, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten):
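    """Estimate the median value offset (dataset1 - dataset2) over the overlapping
    region and add it to dataset2, so the affiliate dataset is aligned to the
    reference one. The dataset with fewer valid PS points in the overlap serves
    as the reference for the nearest-neighbor offset estimation.
    """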
    # for two datasets, do concatenation
    latlon1 = np.hstack((np.transpose(np.array([lat1_flatten])), np.transpose(np.array([lon1_flatten]))))
    latlon2 = np.hstack((np.transpose(np.array([lat2_flatten])), np.transpose(np.array([lon2_flatten]))))
    
    # compare the reference values between the original and concatenated data
    # calculate the overlapping lat/lon range
    over_lat_min, over_lon_min, over_lat_max, over_lon_max = overlay_lalo(lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten)

    # calculate the PS points located in the overlay region
    PS_flag1 = PS_overlay(latlon1, over_lat_min, over_lon_min, over_lat_max, over_lon_max)
    PS_flag2 = PS_overlay(latlon2, over_lat_min, over_lon_min, over_lat_max, over_lon_max)

    # calculate the offset between dataset1 (reference image) and dataset2 (affiliate image)
    # extract the PS points whose value is not NaN
    data1_tmp = data1_flatten[PS_flag1]
    mask1 = ~np.isnan(data1_tmp)
    data2_tmp = data2_flatten[PS_flag2]
    mask2 = ~np.isnan(data2_tmp)

    data1_overlay_num = len(data1_tmp[mask1])
    print('The number of non-NaN PS points of the reference image located in the overlapping region is %d' % data1_overlay_num)
    data2_overlay_num = len(data2_tmp[mask2])
    print('The number of non-NaN PS points of the affiliate image located in the overlapping region is %d' % data2_overlay_num)

    if data1_overlay_num <= data2_overlay_num:
        data_ref = data1_tmp[mask1]
        lat_ref = latlon1[:, 0][PS_flag1][mask1]
        lon_ref = latlon1[:, 1][PS_flag1][mask1]
        data_aff = data2_tmp[mask2]
        lat_aff = latlon2[:, 0][PS_flag2][mask2]
        lon_aff = latlon2[:, 1][PS_flag2][mask2]
        offset = calculate_offset_matrix(data_ref, lat_ref, lon_ref, data_aff, lat_aff, lon_aff)
    else:
        data_ref = data2_tmp[mask2]
        lat_ref = latlon2[:, 0][PS_flag2][mask2]
        lon_ref = latlon2[:, 1][PS_flag2][mask2]
        data_aff = data1_tmp[mask1]
        lat_aff = latlon1[:, 0][PS_flag1][mask1]
        lon_aff = latlon1[:, 1][PS_flag1][mask1]
        offset = calculate_offset_matrix(data_ref, lat_ref, lon_ref, data_aff, lat_aff, lon_aff)
        offset *= (-1)

    overlay_offset = np.nanmedian(offset)
    print('The overlay offset is %f' % overlay_offset)

    # adjust the affiliate image
    data2_flatten_adjust = data2_flatten + overlay_offset

    return data2_flatten_adjust

def concatenate_2D(val_ref, val_aff, rc_ref, rc_aff, ref_flag, data_type):
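    """Place the reference and affiliate 2D arrays into one joined matrix.

    ref_flag (from find_the_reference) selects where each patch sits inside the
    joined bounding box; pixels covered by neither patch stay NaN (0 for masks).
    """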
    row_join_start, row_join_end, col_join_start, col_join_end = design_joined_matrix(rc_ref, rc_aff)
    val_join = np.ones((row_join_end - row_join_start + 1, col_join_end - col_join_start + 1)) * np.nan
    row_a_r = int(np.abs(rc_ref[0, 0] - rc_aff[0, 0]))
    col_a_r = int(np.abs(rc_ref[0, 1] - rc_aff[0, 1]))
    print('Size of the concatenated data: {}, {}'.format(val_join.shape[0], val_join.shape[1]))

    if ref_flag == 1 or ref_flag == 4:
        val_join[0: val_ref.shape[0], 0: val_ref.shape[1]] = val_ref
        val_join[row_a_r:, col_a_r:] = val_aff
    elif ref_flag == 2 or ref_flag == 3:
        val_join[0: val_ref.shape[0], col_a_r:] = val_ref
        val_join[row_a_r:, 0: val_aff.shape[1]] = val_aff
    elif ref_flag == 5:
        val_join[row_a_r:, 0: val_ref.shape[1]] = val_ref
        val_join[0: val_aff.shape[0], col_a_r: col_a_r + val_aff.shape[1]] = val_aff
    elif ref_flag == 6:
        val_join[0: val_ref.shape[0], 0: val_ref.shape[1]] = val_ref
        val_join[row_a_r:, col_a_r: col_a_r + val_aff.shape[1]] = val_aff
    elif ref_flag == 7:
        val_join[0: val_ref.shape[0], col_a_r:] = val_ref
        val_join[row_a_r: row_a_r + val_aff.shape[0], 0: val_aff.shape[1]] = val_aff
    elif ref_flag == 8:
        val_join[0: val_ref.shape[0], 0: val_ref.shape[1]] = val_ref
        val_join[row_a_r: row_a_r + val_aff.shape[0], col_a_r: col_a_r + val_aff.shape[1]] = val_aff

    # masks are binary, so replace the NaN fill value with 0
    # (the original `val_join == np.nan` check never matches, since NaN != NaN)
    if data_type == 'msk':
        val_join[np.isnan(val_join)] = 0

    return val_join

def concatenate_vel(inps, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, unflatten_trans1, unflatten_trans2, rc1, rc2, ref_flag):
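    """Concatenate two velocity datasets; return the joined array and its attributes."""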
    ref_No, aff_No, rc_ref, rc_aff,lat_ref_flatten, lon_ref_flatten, lat_aff_flatten, lon_aff_flatten = appoint_ref(rc1, rc2, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, ref_flag)
    data_ref = inps.patch_files[ref_No]
    data_aff = inps.patch_files[aff_No]

    print('Read the reference dataset') 
    vel_ref, vel_ref_atr = readfile.read(data_ref, datasetName='velocity')
    vel_ref_flatten = vel_ref.flatten()
    
    print('Read the affiliate dataset')
    vel_aff, vel_aff_atr = readfile.read(data_aff, datasetName='velocity')
    vel_aff_flatten = vel_aff.flatten()

    vel_aff_flatten_adjust = concatenate_process(vel_ref_flatten, vel_aff_flatten, lat_ref_flatten, lon_ref_flatten, lat_aff_flatten, lon_aff_flatten)
    if aff_No == 0:
        vel_aff_adjust = unflatten_trans1(vel_aff_flatten_adjust)
    else:
        vel_aff_adjust = unflatten_trans2(vel_aff_flatten_adjust)

    # generate 2D concatenation results
    data_type = 'vel'
    vel_joined = concatenate_2D(vel_ref, vel_aff_adjust, rc_ref, rc_aff, ref_flag, data_type)

    # adjust the attribute table
    vel_atr = vel_ref_atr
    vel_atr['LENGTH'] = vel_joined.shape[0]
    vel_atr['WIDTH'] = vel_joined.shape[1]
    
    return vel_joined, vel_atr

def concatenate_ts(inps, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, unflatten_trans1, unflatten_trans2, rc1, rc2, ref_flag):
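    """Concatenate two timeseries datasets date by date.

    Only dates shared by both datasets are concatenated; the affiliate dataset
    is offset-adjusted per date before being placed into the joined matrix.
    """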
    ref_No, aff_No, rc_ref, rc_aff,lat_ref_flatten, lon_ref_flatten, lat_aff_flatten, lon_aff_flatten = appoint_ref(rc1, rc2, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, ref_flag)
    data_ref = inps.patch_files[ref_No]
    data_aff = inps.patch_files[aff_No]

    print('Read the reference dataset')
    ts_ref, ts_ref_atr = readfile.read(data_ref, datasetName='timeseries')
    print('Read the affiliate dataset')
    ts_aff, ts_aff_atr = readfile.read(data_aff, datasetName='timeseries')

    # read the perpendicular baselines into memory so the files can be closed
    with h5py.File(data_ref, 'r') as f_ref:
        bperp_ref = f_ref['/bperp'][:]
    dateList_ref = timeseries(data_ref).get_date_list()

    with h5py.File(data_aff, 'r') as f_aff:
        bperp_aff = f_aff['/bperp'][:]
    dateList_aff = timeseries(data_aff).get_date_list()

    # check whether the reference and affiliate data have the same dimensions
    dim_ref = ts_ref.shape[0]
    rows_ref, colms_ref = ts_ref.shape[1:3]
    dim_aff = ts_aff.shape[0]
    rows_aff, colms_aff = ts_aff.shape[1:3]

    # calculate the intersected dates between the two datasets
    date_final, Date1, Date2, bperp = mimtpy.concatenate_offset.date_match(dateList_ref, dateList_aff, dim_ref, dim_aff, bperp_ref, bperp_aff)
    
    # prepare to concatenate
    join_dim = len(Date1)
    row_join_start, row_join_end, col_join_start, col_join_end = design_joined_matrix(rc_ref, rc_aff)
    row_sum = row_join_end - row_join_start + 1
    col_sum = col_join_end - col_join_start + 1

    ts_join_dataset = dict()
    ts_join = np.empty(shape=(join_dim, row_sum, col_sum), dtype=float)
    # do concatenation
    i = 0
    for date1, date2 in zip(Date1, Date2):
        print('Process displacement data of date %s' % date1)
        dis_ref = readfile.read(data_ref, datasetName=date1)[0]
        dis_ref_flatten = dis_ref.flatten()
        dis_aff = readfile.read(data_aff, datasetName=date2)[0]
        dis_aff_flatten = dis_aff.flatten()

        dis_aff_flatten_adjust = concatenate_process(dis_ref_flatten, dis_aff_flatten, lat_ref_flatten, lon_ref_flatten, lat_aff_flatten, lon_aff_flatten)
        if aff_No == 0:
            dis_aff_adjust = unflatten_trans1(dis_aff_flatten_adjust)
        else:
            dis_aff_adjust = unflatten_trans2(dis_aff_flatten_adjust)

        # generate 2D concatenation results
        data_type = 'ts'
        ts_join[i, :, :] = concatenate_2D(dis_ref, dis_aff_adjust, rc_ref, rc_aff, ref_flag, data_type)
        i += 1

    ts_join_dataset['bperp'] = np.array(bperp, dtype=float)
    ts_join_dataset['date'] = np.array(date_final, dtype=np.bytes_)  # np.string_ is removed in NumPy 2.0
    ts_join_dataset['timeseries'] = ts_join

    # adjust the attribute table
    ts_atr = ts_ref_atr
    ts_atr['LENGTH'] = ts_join.shape[1]
    ts_atr['WIDTH'] = ts_join.shape[2]

    return ts_join_dataset, ts_atr, date_final
        
def concatenate_mask(inps, rc1, rc2, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, ref_flag):
    """Concatenate mask data (maskPS or maskTempCoh)."""
    ref_No, aff_No, rc_ref, rc_aff,lat_ref_flatten, lon_ref_flatten, lat_aff_flatten, lon_aff_flatten = appoint_ref(rc1, rc2, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, ref_flag)
    data_ref = inps.patch_files[ref_No]
    data_aff = inps.patch_files[aff_No]

    print('Read the reference dataset') 
    msk_ref, msk_ref_atr = readfile.read(data_ref) 
    print('Read the affiliate dataset')
    msk_aff, msk_aff_atr = readfile.read(data_aff)

    data_type = 'msk'
    msk_joined = concatenate_2D(msk_ref, msk_aff, rc_ref, rc_aff, ref_flag, data_type)

    # adjust the attribute table
    msk_atr = msk_ref_atr
    msk_atr['LENGTH'] = msk_joined.shape[0]
    msk_atr['WIDTH'] = msk_joined.shape[1]

    return msk_joined, msk_atr

def concatenate_geo(inps):
    """concatenate geometry data"""
    data_geo1 = inps.geo_file1[0]
    data_geo2 = inps.geo_file2[0]    
        
    geo_full = inps.geo_full[0]

    print('Read the geometry data for the full region') 
    lat_full = readfile.read(geo_full, datasetName='latitude')[0]
    lon_full = readfile.read(geo_full, datasetName='longitude')[0]
    inc_full = readfile.read(geo_full, datasetName='incidenceAngle')[0]
    azi_full = readfile.read(geo_full, datasetName='azimuthAngle')[0]
    hgt_full = readfile.read(geo_full, datasetName='height')[0]
    
    print('Read the first dataset') 
    lat1, lat_atr1 = readfile.read(data_geo1, datasetName='latitude')
    lon1, lon_atr1 = readfile.read(data_geo1, datasetName='longitude')
    inc1, inc_atr1 = readfile.read(data_geo1, datasetName='incidenceAngle')
    azi1, azi_atr1 = readfile.read(data_geo1, datasetName='azimuthAngle')
    hgt1, hgt_atr1 = readfile.read(data_geo1, datasetName='height')

    print('Read the second dataset') 
    lat2, lat_atr2 = readfile.read(data_geo2, datasetName='latitude')
    lon2, lon_atr2 = readfile.read(data_geo2, datasetName='longitude')
    inc2, inc_atr2 = readfile.read(data_geo2, datasetName='incidenceAngle')
    azi2, azi_atr2 = readfile.read(data_geo2, datasetName='azimuthAngle')
    hgt2, hgt_atr2 = readfile.read(data_geo2, datasetName='height')
    
    lat1_flatten = lat1.flatten()  # flatten the matrix row by row (C order)
    lon1_flatten = lon1.flatten()

    lat2_flatten = lat2.flatten()
    lon2_flatten = lon2.flatten()
    
    # build functions that reshape flattened arrays back to each patch's shape
    unflatten_trans1 = flatten_trans(lat1)
    unflatten_trans2 = flatten_trans(lat2)

    print('Convert to X-Y coordinates')
    rc1 = get_position(lat1, lon1, lat_full, lon_full)
    rc2 = get_position(lat2, lon2, lat_full, lon_full)

    # extract the geometry info for the joined region
    row_join_start, row_join_end, col_join_start, col_join_end = design_joined_matrix(rc1, rc2)

    lat_joined = lat_full[row_join_start: row_join_end + 1, col_join_start: col_join_end + 1]
    lon_joined = lon_full[row_join_start: row_join_end + 1, col_join_start: col_join_end + 1]
    inc_joined = inc_full[row_join_start: row_join_end + 1, col_join_start: col_join_end + 1]
    azi_joined = azi_full[row_join_start: row_join_end + 1, col_join_start: col_join_end + 1]
    hgt_joined = hgt_full[row_join_start: row_join_end + 1, col_join_start: col_join_end + 1]
    
    return lat_joined, lon_joined, inc_joined, azi_joined, hgt_joined, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, unflatten_trans1, unflatten_trans2, rc1, rc2

def write_vel(vel_joined, vel_atr, datatype, inps):
    
    row, colm = vel_joined.shape

    atr_vel = dict()
    atr_vel['WIDTH'] = str(colm)
    atr_vel['LENGTH'] = str(row)
    atr_vel['FILE_TYPE'] = 'velocity'
    
    vel_data = dict()
    vel_data['velocity'] = vel_joined

    output_dir = inps.outdir[0]
    outname = inps.out_suffix[0]
    if len(outname) == 0:
        vel_filename = os.path.join(output_dir, datatype + '.h5')
    else:
        vel_filename = os.path.join(output_dir, datatype + '_' + outname + '.h5')

    writefile.write(datasetDict=vel_data, out_file=vel_filename, metadata=atr_vel)

    return 

def write_ts(ts_joined_dataset, ts_atr, date_final, datatype, inps):
    row, colm = ts_joined_dataset['timeseries'].shape[1: ]

    atr_ts = ts_atr
    atr_ts['WIDTH'] = str(colm)
    atr_ts['LENGTH'] = str(row)
    atr_ts['FILE_TYPE'] = 'timeseries'
    
    output_dir = inps.outdir[0]
    outname = inps.out_suffix[0]
    # strip the extension, otherwise '.h5' would appear twice in the output name
    file_name = os.path.splitext(os.path.basename(inps.patch_files[0]))[0]
    if len(outname) == 0:
        ts_filename = os.path.join(output_dir, file_name + '.h5')
    else:
        ts_filename = os.path.join(output_dir, file_name + '_' + outname + '.h5')

    writefile.write(datasetDict=ts_joined_dataset, out_file=ts_filename, metadata=atr_ts)
    
    return

def write_mask(msk_joined, datatype, inps):
    row, colm = msk_joined.shape
    
    # write basic attributes
    atr_msk = dict()
    atr_msk['WIDTH'] = str(colm)
    atr_msk['LENGTH'] = str(row)
    atr_msk['FILE_TYPE'] = 'mask'

    msk_data = dict()
    msk_data['mask'] = msk_joined

    output_dir = inps.outdir[0]
    outname = inps.out_suffix[0]
    # strip the extension, otherwise '.h5' would appear twice in the output name
    file_name = os.path.splitext(os.path.basename(inps.patch_files[0]))[0]
    if len(outname) == 0:
        msk_filename = os.path.join(output_dir, file_name + '.h5')
    else:
        msk_filename = os.path.join(output_dir, file_name + '_' + outname + '.h5')

    writefile.write(datasetDict=msk_data, out_file=msk_filename, metadata=atr_msk)
     
    return 

def write_geo(lat_joined, lon_joined, inc_joined, azi_joined, hgt_joined, inps):

    row, colm = lat_joined.shape

    # write basic attributes
    atr_geo = dict()
    atr_geo['WIDTH'] = str(colm)
    atr_geo['LENGTH'] = str(row)
    atr_geo['FILE_TYPE'] = 'geometry'

    geo_data = dict()
    geo_data['azimuthAngle'] = azi_joined
    geo_data['incidenceAngle'] = inc_joined
    geo_data['height'] = hgt_joined
    geo_data['latitude'] = lat_joined
    geo_data['longitude'] = lon_joined

    output_dir = inps.outdir[0]
    outname = inps.out_suffix[0]
    if len(outname) == 0:
        geo_outname = 'geometryRadar'
    else:
        geo_outname = 'geometryRadar_' + outname

    geo_filename = os.path.join(output_dir, geo_outname + '.h5')

    # write h5 file
    writefile.write(datasetDict=geo_data, out_file=geo_filename, metadata=atr_geo)

    print('Finish!')

def find_the_reference(rc1, rc2):
    """Find the reference dataset based on geometry info.

    Return a flag (1-8) encoding the relative placement of the two patches in
    the full-region grid, judged from their upper-left and lower-right corners.
    """
    row1_ul = rc1[0][0] # upper left point
    col1_ul = rc1[0][1]
    row1_lr = rc1[3][0] # lower right point
    col1_lr = rc1[3][1]

    row2_ul = rc2[0][0] # upper left point
    col2_ul = rc2[0][1]
    row2_lr = rc2[3][0] # lower right point
    col2_lr = rc2[3][1]

    if row1_ul <= row2_ul and col1_ul <= col2_ul and row1_lr <= row2_lr and col1_lr <= col2_lr:
        return 1
    elif row1_ul > row2_ul and col1_ul < col2_ul and row1_lr > row2_lr and col1_lr < col2_lr:
        return 2
    elif row1_ul < row2_ul and col1_ul > col2_ul and row1_lr < row2_lr and col1_lr > col2_lr:
        return 3
    elif row1_ul > row2_ul and col1_ul > col2_ul and row1_lr > row2_lr and col1_lr > col2_lr:
        return 4
    elif row1_ul > row2_ul and col1_ul < col2_ul and row1_lr > row2_lr and col1_lr > col2_lr:
        return 5
    elif row1_ul < row2_ul and col1_ul < col2_ul and row1_lr < row2_lr and col1_lr > col2_lr:
        return 6
    elif row1_ul < row2_ul and col1_ul < col2_ul and row1_lr > row2_lr and col1_lr < col2_lr:
        return 7
    elif row1_ul < row2_ul and col1_ul > col2_ul and row1_lr > row2_lr and col1_lr > col2_lr:
        return 8

def appoint_ref(rc1, rc2, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, ref_flag):
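    """Select which dataset acts as the reference according to ref_flag.

    Returns the indices (0/1) of the reference and affiliate datasets together
    with their corner matrices and flattened lat/lon arrays.
    """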
    if ref_flag == 2 or ref_flag == 4:
        data_ref = 1
        data_aff = 0
        rc_ref = rc2
        rc_aff = rc1
        lat_ref_flatten = lat2_flatten
        lon_ref_flatten = lon2_flatten
        lat_aff_flatten = lat1_flatten
        lon_aff_flatten = lon1_flatten
    else:
        data_ref = 0
        data_aff = 1
        rc_ref = rc1
        rc_aff = rc2
        lat_ref_flatten = lat1_flatten
        lon_ref_flatten = lon1_flatten
        lat_aff_flatten = lat2_flatten
        lon_aff_flatten = lon2_flatten

    return data_ref, data_aff, rc_ref, rc_aff, lat_ref_flatten, lon_ref_flatten, lat_aff_flatten, lon_aff_flatten

def determine_datatype(inps):
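    """Read the FILE_TYPE attribute of the first input file to decide the data type."""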
    data = inps.patch_files[0]
    data_atr = readfile.read_attribute(data)
    datatype = data_atr['FILE_TYPE'] 

    return datatype

def main():
    inps = cmd_line_parse()   

    datatype = determine_datatype(inps)
    print('Data type found: ', datatype)

    print('Process the geometry info')
    lat_joined, lon_joined, inc_joined, azi_joined, hgt_joined, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, unflatten_trans1, unflatten_trans2, rc1, rc2 = concatenate_geo(inps)
    if inps.geo_write:
        write_geo(lat_joined, lon_joined, inc_joined, azi_joined, hgt_joined, inps)

    print('Find which dataset is the reference dataset')
    ref_flag = find_the_reference(rc1, rc2)

    print('process %s data' % datatype)
    if datatype == 'velocity':
        vel_joined, vel_atr = concatenate_vel(inps, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, unflatten_trans1, unflatten_trans2, rc1, rc2, ref_flag)
        write_vel(vel_joined, vel_atr, datatype, inps)
    elif datatype == 'timeseries':
        ts_join_dataset, ts_atr, date_final = concatenate_ts(inps, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, unflatten_trans1, unflatten_trans2, rc1, rc2, ref_flag)
        write_ts(ts_join_dataset, ts_atr, date_final, datatype, inps)
    elif datatype == 'mask':
        msk_joined, msk_atr = concatenate_mask(inps, rc1, rc2, lat1_flatten, lon1_flatten, lat2_flatten, lon2_flatten, ref_flag)
        write_mask(msk_joined, datatype, inps)
######################################################################################
if __name__ == '__main__':
    main()
