#!/usr/bin/env python3
############################################################
# Program is part of MintPy                                #
# Copyright (c) 2013, Zhang Yunjun, Heresh Fattahi         #
# Author: Zhang Yunjun, 2017                               #
############################################################


import os
import sys
import time
import numpy as np

from mintpy.objects.resample import resample
from mintpy.defaults.template import get_template_content
from mintpy.utils import (
    arg_utils,
    readfile,
    writefile,
    utils as ut,
    attribute as attr,
)


######################################################################################
# text returned by get_template_content('geocode');
# used as the first part of the command-line help epilog in create_parser()
TEMPLATE = get_template_content('geocode')

# usage examples; used as the second part of the command-line help epilog in create_parser()
EXAMPLE = """example:
  geocode.py velocity.h5
  geocode.py velocity.h5 -b -0.5 -0.25 -91.3 -91.1
  geocode.py velocity.h5 timeseries.h5 -t smallbaselineApp.cfg --outdir ./geo --update

  # geocode file using ISCE-2 lat/lon.rdr file
  geocode.py filt_fine.int --lat-file ../../geom_reference/lat.rdr --lon-file ../../geom_reference/lon.rdr

  # radar-code file in geo coordinates
  geocode.py swbdLat_S02_N01_Lon_W092_W090.wbd -l geometryRadar.h5 -o waterMask.rdr --geo2radar
  geocode.py geo_velocity.h5 --geo2radar
"""

# degree <--> meter (on the equator) conversion table;
# appended to the help message of the --lalo / --lalo-step option in create_parser()
DEG2METER = """
degrees     --> meters on equator
0.000925926 --> 100
0.000833334 --> 90
0.000555556 --> 60
0.000462963 --> 50
0.000277778 --> 30
0.000185185 --> 20
0.000092593 --> 10
"""


def create_parser(subparsers=None):
    """Create the command-line argument parser.

    Parameters: subparsers - optional sub-parsers object, passed through to
                             arg_utils.create_argument_parser()
    Returns:    parser     - argument parser with all the geocode options
    """
    synopsis = 'Resample radar-coded files into geo-coordinates or vice versa'
    epilog = TEMPLATE + '\n' + EXAMPLE
    name = __name__.split('.')[-1]
    parser = arg_utils.create_argument_parser(
        name, synopsis=synopsis, description=synopsis, epilog=epilog, subparsers=subparsers)

    # input file(s) / dataset
    parser.add_argument('file', nargs='+', help='File(s) to be geocoded')
    parser.add_argument('-d', '--dset', help='dataset to be geocoded, for example:\n' +
                        'height                        for geometryRadar.h5\n' +
                        'unwrapPhase-20100114_20101017 for ifgramStack.h5')

    # lookup table
    parser.add_argument('-l', '--lookup', dest='lookupFile',
                        help='Lookup table file generated by InSAR processors.')
    parser.add_argument('--lat-file', dest='latFile', help='lookup table file for latitude.')
    parser.add_argument('--lon-file', dest='lonFile', help='lookup table file for longitude.')

    # resampling direction: the flag switches OFF the default radar2geo
    parser.add_argument('--geo2radar', '--geo2rdr', dest='radar2geo', action='store_false',
                        help='resample geocoded files into radar coordinates.\n' +
                             'ONLY for lookup table in radar-coord (ISCE, Doris).')

    parser.add_argument('-t', '--template', dest='templateFile',
                        help="Template file with geocoding options.")

    # output grid / geometry
    out = parser.add_argument_group('grid in geo-coordinates')
    out.add_argument('-b', '--bbox', dest='SNWE', type=float, nargs=4, metavar=('S', 'N', 'W', 'E'),
                     help='Bounding box for the area of interest.\n'
                          'using coordinates of the upper left corner of the first pixel\n'
                          '                 and the lower right corner of the last pixel\n'
                          "for radar2geo, it's the output spatial extent\n"
                          "for geo2radar, it's the input  spatial extent")
    out.add_argument('--lalo', '--lalo-step', dest='laloStep', type=float, nargs=2, metavar=('LAT_STEP', 'LON_STEP'),
                     help='output pixel size in degree in latitude / longitude.{}'.format(DEG2METER))

    # interpolation / resampling
    # Note: choices are given as lists (not sets), so that their order in the
    # usage / help message is the same for every run.
    interp = parser.add_argument_group('interpolation')
    interp.add_argument('-i', '--interp', dest='interpMethod', default='nearest', choices=['nearest', 'linear'],
                        help='interpolation/resampling method (default: %(default)s).')
    interp.add_argument('--fill', dest='fillValue', type=float, default=np.nan,
                        help='Fill value for extrapolation (default: %(default)s).')
    interp.add_argument('-n','--nprocs', dest='nprocs', type=int, default=1,
                        help='number of processors to be used for calculation (default: %(default)s).\n'
                             'Note: Do not use more processes than available processor cores.')
    interp.add_argument('--software', dest='software', default='pyresample', choices=['pyresample', 'scipy'],
                        help='software/module used for interpolation (default: %(default)s)\n'
                             'Note: --bbox is not supported for --software scipy')

    # output
    parser.add_argument('--update', dest='updateMode', action='store_true',
                        help='skip resampling if output file exists and newer than input file')
    parser.add_argument('-o', '--output', dest='outfile',
                        help="output file name. Default: add prefix 'geo_'")
    parser.add_argument('--outdir', '--output-dir', dest='out_dir', help='output directory.')

    # computing
    parser = arg_utils.add_memory_argument(parser)

    return parser


def cmd_line_parse(iargs=None):
    """Parse and check the command-line arguments.

    Parameters: iargs - list of str, command-line arguments; passed to parse_args() as is
    Returns:    inps  - Namespace, options after the template update and the input checks
    """
    inps = create_parser().parse_args(args=iargs)

    # options found in the template file (if given) are written into the namespace
    template_file = inps.templateFile
    if template_file:
        inps = read_template2inps(template_file, inps)

    return _check_inps(inps)


def _check_inps(inps):
    """Check and finalize the parsed input options.

    Parameters: inps - Namespace, parsed command-line / template options
    Returns:    inps - Namespace, with file, outfile, lookupFile and nprocs updated
    Raises:     FileNotFoundError - if no input file or no lookup table is found
                SystemExit        - status 0 if there is nothing to do (input file is
                                    already in the requested coordinates);
                                    status 1 if the option combination is invalid
    """
    # check 1 - input file(s) existence
    inps.file = ut.get_file_list(inps.file)
    if not inps.file:
        raise FileNotFoundError('ERROR: no input file found!')
    elif len(inps.file) > 1:
        # a single custom output name can not be shared among multiple input files
        inps.outfile = None

    # check 2 - lookup table existence
    if not inps.lookupFile:
        # grab default lookup table
        inps.lookupFile = ut.get_lookup_file(inps.lookupFile)

        # use isce-2 lat/lon.rdr file
        if not inps.lookupFile and inps.latFile:
            inps.lookupFile = inps.latFile

        # final check
        if not inps.lookupFile:
            raise FileNotFoundError('No lookup table found! Can not geocode without it.')

    # check 3 - src file coordinate & radar2geo operation
    # nothing to do is not an error, thus exit with status 0
    atr = readfile.read_attribute(inps.file[0])
    if 'Y_FIRST' in atr and inps.radar2geo:
        print('input file is already geocoded')
        print('to resample geocoded files into radar coordinates, use --geo2radar option')
        print('exit without doing anything.')
        sys.exit(0)
    elif 'Y_FIRST' not in atr and not inps.radar2geo:
        print('input file is already in radar coordinates, exit without doing anything')
        sys.exit(0)

    # check 4 - laloStep
    # valid only if:
    # 1. radar2geo = True AND
    # 2. lookup table is in radar coordinates
    # invalid option combinations exit with a non-zero status,
    # so that calling shells / scripts can detect the failure
    if inps.laloStep:
        if not inps.radar2geo:
            print('ERROR: "--lalo-step" can NOT be used together with "--geo2radar"!')
            sys.exit(1)
        atr = readfile.read_attribute(inps.lookupFile)
        if 'Y_FIRST' in atr:
            print('ERROR: "--lalo-step" can NOT be used with lookup table file in geo-coordinates!')
            sys.exit(1)

    # check 5 - number of processors for multiprocessing
    inps.nprocs = check_num_processor(inps.nprocs)

    # check 6 - geo2radar
    if not inps.radar2geo:
        if inps.SNWE:
            print('ERROR: "--geo2radar" can NOT be used together with "--bbox"!')
            sys.exit(1)
        if inps.software == 'scipy':
            print('ERROR: "--geo2radar" is NOT supported for "--software scipy"!')
            sys.exit(1)

    return inps


def read_template2inps(template_file, inps):
    """Update the Namespace inps with the options found in the template file.

    Parameters: template_file - str, path of the template file
                inps          - Namespace, options to be updated;
                                if empty, it is obtained from cmd_line_parse()
    Returns:    inps          - Namespace, updated options
    """
    print('read input option from template file: ' + template_file)
    if not inps:
        inps = cmd_line_parse()
    opts = vars(inps)

    template = readfile.read_template(template_file, skip_chars=['[', ']'])
    template = ut.check_template_auto_value(template)

    # geocode options: only the names that exist in both the namespace and the template
    prefix = 'mintpy.geocode.'
    for name in list(opts.keys()):
        full_name = prefix + name
        if full_name not in template.keys():
            continue

        value = template[full_name]
        if not value:
            # empty template value: keep the existing option
            continue

        if name in ('SNWE', 'laloStep'):
            # comma separated list of numbers
            opts[name] = [float(num) for num in value.split(',')]

        elif name == 'interpMethod':
            opts[name] = value

        elif name == 'fillValue':
            opts[name] = np.nan if 'nan' in value.lower() else float(value)

    # computing configurations
    mem_key = 'mintpy.compute.maxMemory'
    if mem_key in template.keys():
        max_memory = template[mem_key]
        if max_memory:
            inps.maxMemory = float(max_memory)

    return inps


############################################################################################
def check_num_processor(nprocs):
    """Check number of processors
    Note by Yunjun, 2019-05-02:
    1. conda install pyresample will install pykdtree and openmp, but it seems not working:
        geocode.py is getting slower with more processors
            Test on a TS HDF5 file in size of (241, 2267, 2390)
            Memory: up to 10GB
            Run time: 2.5 mins for nproc=1, 3 mins for nproc=4
    2. macports seems to have minor speedup when more processors
    Thus, default number of processors is set to 1; although the capability of using multiple
    processors is written here.

    Parameters: nprocs - int, requested number of processors;
                         0 / None to use the OMP_NUM_THREADS environment variable if it is
                         set to a positive integer, otherwise half of the available cores
    Returns:    nprocs - int, number of processors to be used, within [1, number of cores]
    """
    # os.cpu_count() returns None if the number of cores can not be determined
    num_core = os.cpu_count() or 1

    if not nprocs:
        # OMP_NUM_THREADS is defined in environment variable for OpenMP
        # it may hold a comma separated list (one value per nesting level): use the first one
        omp_num_threads = os.environ.get('OMP_NUM_THREADS', '')
        try:
            nprocs = int(omp_num_threads.split(',')[0])
        except ValueError:
            # variable is not set or not an integer
            nprocs = 0

        if nprocs <= 0:
            nprocs = num_core // 2

    # use at least one processor and not more than the available ones
    nprocs = max(1, min(num_core, nprocs))
    print('number of processor to be used: {}'.format(nprocs))
    return nprocs


def auto_output_filename(infile, inps):
    """Get the output file name for the given input file.

    Parameters: infile  - str, path of the input file
                inps    - Namespace with the attributes:
                          file, outfile, radar2geo, dset and out_dir
    Returns:    outfile - str, path of the output file
    """
    # the custom output name is used for single input file only
    if inps.outfile and len(inps.file) == 1:
        return inps.outfile

    # default name: prefix + dataset name [if specified] or input file name
    prefix = 'geo_' if inps.radar2geo else 'rdr_'
    if inps.dset:
        fname = '{}{}.h5'.format(prefix, inps.dset)
    else:
        fname = '{}{}'.format(prefix, os.path.basename(infile))

    if not inps.out_dir:
        return fname

    # put the file into the output directory, which is created if it does not exist
    if not os.path.isdir(inps.out_dir):
        os.makedirs(inps.out_dir)
        print('create directory: {}'.format(inps.out_dir))

    return os.path.join(inps.out_dir, fname)


def run_geocode(inps):
    """Resample all the input files, from radar to geo coordinates or vice versa.

    One resample object is initiated from the lookup table and the largest input file,
    and re-used for all the input files. Each dataset is read, resampled and saved block
    by block: blocks are written directly into HDF5 output files, while for the other
    output files they are collected in memory and written once per file.

    Parameters: inps    - Namespace, options, e.g. from cmd_line_parse()
    Returns:    outfile - str, path of the output file of the last input file
    """
    start_time = time.time()

    # feed the largest file for resample object initiation
    ind_max = np.argmax([os.path.getsize(i) for i in inps.file])

    # prepare geometry for geocoding
    kwargs = dict(interp_method=inps.interpMethod,
                  fill_value=inps.fillValue,
                  nprocs=inps.nprocs,
                  max_memory=inps.maxMemory,
                  software=inps.software,
                  print_msg=True)
    if inps.latFile and inps.lonFile:
        kwargs['lat_file'] = inps.latFile
        kwargs['lon_file'] = inps.lonFile
    res_obj = resample(lut_file=inps.lookupFile,
                       src_file=inps.file[ind_max],
                       SNWE=inps.SNWE,
                       lalo_step=inps.laloStep,
                       **kwargs)
    res_obj.open()
    res_obj.prepare()

    # resample input files one by one
    for infile in inps.file:
        print('-' * 50+'\nresampling file: {}'.format(infile))
        atr = readfile.read_attribute(infile, datasetName=inps.dset)
        outfile = auto_output_filename(infile, inps)

        # update_mode
        if inps.updateMode:
            print('update mode: ON')
            if ut.run_or_skip(outfile, in_file=[infile, inps.lookupFile]) == 'skip':
                continue

        ## prepare output
        # update metadata
        if inps.radar2geo:
            atr = attr.update_attribute4radar2geo(atr, res_obj=res_obj)
        else:
            atr = attr.update_attribute4geo2radar(atr, res_obj=res_obj)

        # instantiate output file
        file_is_hdf5 = os.path.splitext(outfile)[1] in ['.h5', '.he5']
        if file_is_hdf5:
            compression = readfile.get_hdf5_compression(infile)
            writefile.layout_hdf5(outfile, metadata=atr, ref_file=infile, compression=compression)
        else:
            dsDict = dict()

        ## run
        dsNames = readfile.get_dataset_list(infile, datasetName=inps.dset)
        maxDigit = max(len(i) for i in dsNames)
        for dsName in dsNames:

            # in-memory output of the current dataset, for non-HDF5 file only.
            # It is allocated after the first block is resampled, in order to use the
            # data type of the resampled data: writing e.g. complex data into a
            # float array would discard the imaginary part.
            ds_data = None

            # loop for block-by-block IO
            for i in range(res_obj.num_box):
                src_box = res_obj.src_box_list[i]
                dest_box = res_obj.dest_box_list[i]

                # read
                print('-'*50+'\nreading {d:<{w}} in block {b} from {f} ...'.format(
                    d=dsName, w=maxDigit, b=src_box, f=os.path.basename(infile)))

                data = readfile.read(infile,
                                     datasetName=dsName,
                                     box=src_box,
                                     print_msg=False)[0]

                # resample
                data = res_obj.run_resample(src_data=data, box_ind=i)

                # write / save block data
                # block is in the order of ([dim0_start, dim0_end,] y0, y1, x0, x1)
                if data.ndim == 3:
                    block = [0, data.shape[0],
                             dest_box[1], dest_box[3],
                             dest_box[0], dest_box[2]]
                else:
                    block = [dest_box[1], dest_box[3],
                             dest_box[0], dest_box[2]]

                if file_is_hdf5:
                    print('write data in block {} to file: {}'.format(block, outfile))
                    writefile.write_hdf5_block(outfile,
                                               data=data,
                                               datasetName=dsName,
                                               block=block,
                                               print_msg=False)
                else:
                    if ds_data is None:
                        ds_data = np.zeros((res_obj.length, res_obj.width), dtype=data.dtype)
                    ds_data[block[0]:block[1],
                            block[2]:block[3]] = data

            # for binary file: save the dataset in the data type of the resampled data
            if not file_is_hdf5:
                if ds_data is None:
                    # no block has been processed
                    ds_data = np.zeros((res_obj.length, res_obj.width))
                dsDict[dsName] = ds_data

        # write binary file
        if not file_is_hdf5:
            atr['BANDS'] = len(dsDict.keys())
            writefile.write(dsDict, out_file=outfile, metadata=atr, ref_file=infile)

            # create ISCE XML and GDAL VRT file if using ISCE lookup table file
            if inps.latFile and inps.lonFile:
                writefile.write_isce_xml(atr, fname=outfile)

    m, s = divmod(time.time()-start_time, 60)
    print('time used: {:02.0f} mins {:02.1f} secs.\n'.format(m, s))
    return outfile


######################################################################################
def main(iargs=None):
    """Entry point: parse the command-line arguments and resample the input file(s).

    Parameters: iargs - list of str, command-line arguments, passed to cmd_line_parse()
    """
    run_geocode(cmd_line_parse(iargs))


######################################################################################
# executed as a script: pass the command-line arguments without the program name to main()
if __name__ == '__main__':
    main(sys.argv[1:])
