"""
Title: WVM and TCD Raster Processing Script

Description:
This script processes geospatial raster data from two input layers: WVM (Water and Vegetation Mask) and TCD (Tree Cover Density). It tiles the large input rasters, resamples the TCD tiles to match the WVM resolution, combines the tiles based on specific reclassification and prioritization rules, and merges the processed tiles into a single output raster.

Key Features:
- Splits WVM and TCD rasters into smaller tiles for efficient processing.
- Resamples TCD tiles from 10m resolution to match the 5m resolution of WVM.
- Reclassifies values in WVM and TCD tiles to align with ecological analysis needs.
- Combines WVM and TCD tiles by prioritizing WVM values and resolving conflicts.
- Merges combined tiles back into a single raster file for further use.
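- Input and output paths are defined as constants in the "Directories" and "Input and output files" sections near the end of the script and can be edited by the user.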

Usage:
1. Define paths to the project directory and input rasters (WVM and TCD).
2. Run the script in an environment with required libraries installed (GDAL, NumPy, etc.).
3. Intermediate tiles and final outputs are saved to specified directories.

Author: Rui Catarino
Date: 16/01/25
"""

import os
from pathlib import Path
import time
import numpy as np
from osgeo import gdal
import subprocess

# Function to report the time taken for a processing step
def report_time(start_time, step_name):
    """
    Prints how long a processing step took.

    Parameters:
        start_time (float): Timestamp from time.time() taken when the step began.
        step_name (str): Label of the step, used as the start of the message.
    """
    total_seconds = time.time() - start_time
    # Split into minutes/seconds first, then minutes into hours/minutes.
    total_minutes, secs = divmod(total_seconds, 60)
    hrs, mins = divmod(total_minutes, 60)
    print(f"\n{step_name} completed in {int(hrs)}h {int(mins)}m {int(secs)}s.")

# Function to check if an output TIFF file exists and is valid
def check_output_tif(output_raster_path):
    """
    Checks if the output raster file exists and can be opened by GDAL.

    A file that exists but cannot be opened is treated as corrupted (e.g. left
    behind by an interrupted run) and is deleted so the caller regenerates it.

    Parameters:
        output_raster_path (str | pathlib.Path): Path to the output raster file.

    Returns:
        bool: True if the file exists and GDAL can open it, False otherwise
        (including after deleting an unreadable file).
    """
    if not os.path.exists(output_raster_path):
        return False

    try:
        ds = gdal.Open(str(output_raster_path))
    except RuntimeError:
        # With gdal.UseExceptions() enabled, a failed open raises instead of
        # returning None. Only this error is caught: a bare `except:` would also
        # swallow KeyboardInterrupt and delete a perfectly valid file on Ctrl-C.
        ds = None

    if ds is not None:
        ds = None  # release the dataset handle before returning
        return True

    os.remove(output_raster_path)
    return False

# Function to create tiles from input rasters
def create_tiles(wvm_rst, tcd_rst, WVM_tiles_path, TCD_tiles_path, y2):
    """
    Cuts WVM and TCD rasters into horizontal tiles and saves them to specified directories.

    Tiles are horizontal strips spanning the full WVM width. Strip boundaries
    are defined on the WVM pixel grid. WVM strips are cut with gdal_translate
    -srcwin. Each TCD strip is warped (nearest neighbour) onto exactly the same
    grid as its WVM counterpart: same extent, pixel size and pixel count.
    Cutting TCD with the WVM *pixel* window would select a different geographic
    area whenever the rasters differ in resolution or origin (e.g. 10 m TCD vs
    5 m WVM), and lower strips would fall outside the TCD raster entirely.

    The final strip is shortened to the remaining rows, so the bottom of the
    raster is not dropped. Tiles that already exist and open cleanly are
    reused (see check_output_tif).

    Parameters:
        wvm_rst (str): Path to the WVM raster file.
        tcd_rst (str): Path to the TCD raster file.
        WVM_tiles_path (str): Directory to save WVM tiles.
        TCD_tiles_path (str): Directory to save TCD tiles.
        y2 (int | float): Height of each tile in kilometers. It is converted to
            WVM rows using the WVM pixel height, which is assumed to be in
            meters (metric CRS) -- confirm for new inputs.

    Raises:
        ValueError: If y2 is not positive.
        RuntimeError: If an input raster cannot be opened or a TCD tile cannot be written.
        subprocess.CalledProcessError: If gdal_translate fails for a WVM tile.
    """
    if y2 <= 0:
        raise ValueError(f"Tile height must be positive, got {y2} km")

    step_start_time = time.time()
    print("\n" + "="*40)
    print(f"Processing input rasters: \n - {wvm_rst} \n - {tcd_rst} \n")

    # Load rasters (only metadata is needed here; both are released below)
    wvm_ds = gdal.Open(wvm_rst)
    tcd_ds = gdal.Open(tcd_rst)
    if wvm_ds is None or tcd_ds is None:
        raise RuntimeError(f"Could not open input rasters: {wvm_rst}, {tcd_rst}")

    wvm_geotransform = wvm_ds.GetGeoTransform()
    tcd_geotransform = tcd_ds.GetGeoTransform()
    wvm_projection = wvm_ds.GetProjection()
    width_x = wvm_ds.RasterXSize
    raster_height = wvm_ds.RasterYSize
    wvm_ds = None
    tcd_ds = None

    # Log geotransforms
    print(f"WVM Geotransform: {wvm_geotransform}")
    print(f"TCD Geotransform: {tcd_geotransform}")

    # Assumes a north-up WVM geotransform (rotation terms are ignored).
    origin_x, pixel_width, _, origin_y, _, pixel_height = wvm_geotransform
    rows_per_tile = max(1, int(round(y2 * 1000 / abs(pixel_height))))

    # Ceiling division: a trailing partial strip still gets its own tile.
    y_cuts = (raster_height + rows_per_tile - 1) // rows_per_tile
    print(f"Calculated number of cuts (height-wise): {y_cuts}")

    x_left = origin_x
    x_right = origin_x + width_x * pixel_width
    min_x, max_x = min(x_left, x_right), max(x_left, x_right)

    for i in range(y_cuts):
        tile_id = f"tile_{i}.tif"
        y1 = i * rows_per_tile
        tile_rows = min(rows_per_tile, raster_height - y1)

        # File paths for output tiles
        wvm_path = os.path.join(WVM_tiles_path, tile_id)
        tcd_path = os.path.join(TCD_tiles_path, tile_id)

        if not check_output_tif(wvm_path):
            command_wvm = [
                "gdal_translate",
                "-of", "GTiff",
                "-co", "BIGTIFF=YES",
                "-co", "TILED=YES",
                "-srcwin", "0", str(y1), str(width_x), str(tile_rows),
                wvm_rst,
                wvm_path
            ]
            print(f"Processing tile {i + 1}/{y_cuts} for WVM")
            subprocess.run(command_wvm, check=True)

        if not check_output_tif(tcd_path):
            # Georeferenced extent of this WVM strip.
            y_top = origin_y + y1 * pixel_height
            y_bottom = origin_y + (y1 + tile_rows) * pixel_height
            print(f"Processing tile {i + 1}/{y_cuts} for TCD")
            # Fixing both the bounds and the pixel count puts the TCD tile on the
            # WVM strip's exact grid, so the arrays line up pixel for pixel.
            tcd_tile_ds = gdal.Warp(
                tcd_path,
                tcd_rst,
                format="GTiff",
                outputBounds=(min_x, min(y_top, y_bottom), max_x, max(y_top, y_bottom)),
                width=width_x,
                height=tile_rows,
                dstSRS=wvm_projection or None,
                resampleAlg="near",
                creationOptions=["BIGTIFF=YES", "TILED=YES"],
            )
            if tcd_tile_ds is None:
                raise RuntimeError(f"Failed to create TCD tile: {tcd_path}")
            tcd_tile_ds = None  # flush and close the tile

    print(f"Tiles saved to -{WVM_tiles_path} \n -{TCD_tiles_path}")
    report_time(step_start_time, "Tiles creation")
    print("="*40)
    print("="*40)

# Function to resample tiles to a target resolution
def resample_tiles(input_path, output_path, target_resolution):
    """
    Resamples tiles in a directory to a specified resolution (nearest neighbour).

    If input_path and output_path are different directories, a tile is skipped
    when a valid counterpart already exists in output_path.

    If they are the same directory, tiles are resampled in place:
    - A tile whose pixel size already equals target_resolution is left untouched.
    - Any other tile is warped to a temporary file, which then replaces it.

    An "output already exists" check cannot be used in this case: the output
    *is* the input, so that check would always succeed and nothing would ever
    be resampled.

    Parameters:
        input_path (str): Directory containing input tiles.
        output_path (str): Directory to save resampled tiles (may equal input_path).
        target_resolution (int | float): Target pixel size, in CRS units (meters for a metric CRS).

    Raises:
        RuntimeError: If GDAL fails to open or warp a tile.
    """
    in_place = Path(input_path).resolve() == Path(output_path).resolve()

    for tile_path in sorted(Path(input_path).glob("*.tif")):
        output_tile_path = os.path.join(output_path, tile_path.name)

        if in_place:
            if _has_resolution(tile_path, target_resolution):
                continue
            # Warp into a sibling temp file (which does not match *.tif), then swap
            # it in, so GDAL never reads and writes the same file.
            temp_tile_path = output_tile_path + ".tmp"
            if os.path.exists(temp_tile_path):
                os.remove(temp_tile_path)  # leftover from an interrupted run
            _warp_to_resolution(tile_path, temp_tile_path, target_resolution)
            os.replace(temp_tile_path, output_tile_path)
        elif not check_output_tif(output_tile_path):
            _warp_to_resolution(tile_path, output_tile_path, target_resolution)


def _has_resolution(tile_path, resolution, tolerance=1e-6):
    """Return True if the raster's pixel width and height both equal `resolution`."""
    ds = gdal.Open(str(tile_path))
    if ds is None:
        raise RuntimeError(f"Could not open tile: {tile_path}")
    geotransform = ds.GetGeoTransform()
    ds = None
    return (abs(abs(geotransform[1]) - resolution) <= tolerance
            and abs(abs(geotransform[5]) - resolution) <= tolerance)


def _warp_to_resolution(src_path, dst_path, resolution):
    """Nearest-neighbour warp of src_path into a GeoTIFF at dst_path with the given pixel size."""
    out_ds = gdal.Warp(str(dst_path), str(src_path), format="GTiff",
                       xRes=resolution, yRes=resolution, resampleAlg="near")
    if out_ds is None:
        raise RuntimeError(f"gdal.Warp failed for {src_path}")
    out_ds = None  # flush and close before the file is used elsewhere

# Function to combine WVM and TCD tiles
def combine_tiles(wvm_tile, tcd_tile, output_tile):
    """
    Combines WVM and TCD tiles based on specified rules.

    Parameters:
        wvm_tile (str): Path to the WVM tile.
        tcd_tile (str): Path to the TCD tile.
        output_tile (str): Path to save the combined tile.
    """
    wvm_ds = gdal.Open(wvm_tile)
    tcd_ds = gdal.Open(tcd_tile)

    wvm_array = wvm_ds.GetRasterBand(1).ReadAsArray()
    tcd_array = tcd_ds.GetRasterBand(1).ReadAsArray()

    # Reclassify TCD array
    tcd_array = np.where(tcd_array < 0, 0, tcd_array)  # Set negative values to 0
    tcd_array = np.where(tcd_array == 2, 1, tcd_array)  # Reclassify 2 to 1
    print("TCD array reclassified.")

    # Reclassify WVM array
    wvm_array = np.where(wvm_array < 0, 0, wvm_array)  # Set negative values to 0
    wvm_array = np.where(wvm_array > 1, 0, wvm_array)  # Set all values > 1 to 0
    wvm_array = np.where(wvm_array == 1, 2, wvm_array)  # Reclassify 1 to 2
    print("WVM array reclassified.")

    # Merge both tiles, keeping the highest value between the corresponding pixels
    combined_array = np.maximum(tcd_array, wvm_array)

    # Reclassify merged array so the final file is only composed of 2s and 0s
    combined_array = np.where((combined_array == 1) | (combined_array == 2), 2, 0)

    # Save the combined array as a new tile
    driver = gdal.GetDriverByName("GTiff")
    output_ds = driver.Create(
        output_tile,
        wvm_ds.RasterXSize,
        wvm_ds.RasterYSize,
        1,
        gdal.GDT_Byte
    )
    output_ds.SetGeoTransform(wvm_ds.GetGeoTransform())
    output_ds.SetProjection(wvm_ds.GetProjection())
    output_ds.GetRasterBand(1).WriteArray(combined_array)
    output_ds.FlushCache()
    output_ds = None

    wvm_ds = None
    tcd_ds = None

# Function to merge tiles into a single raster
def merge_tiles(input_dir, output_file):
    """
    Merges all tiles in a directory into a single raster using gdal_merge.py.

    The output is an LZW-compressed, tiled (256x256 blocks) BigTIFF.

    Parameters:
        input_dir (str): Directory containing input tiles (*.tif).
        output_file (str): Path to save the merged raster.

    Raises:
        FileNotFoundError: If input_dir contains no .tif tiles.
        subprocess.CalledProcessError: If gdal_merge.py exits with an error.
    """
    # Sorted for a deterministic argument order. Paths are passed as separate
    # list items (no join/split round-trip), so paths containing spaces survive.
    tile_paths = sorted(str(tile) for tile in Path(input_dir).glob("*.tif"))
    if not tile_paths:
        raise FileNotFoundError(f"No .tif tiles found in {input_dir}")

    subprocess.run([
        "gdal_merge.py",
        "-o", str(output_file),
        "-co", "COMPRESS=LZW",
        "-co", "TILED=YES",
        "-co", "BIGTIFF=YES",
        "-co", "BLOCKXSIZE=256", "-co", "BLOCKYSIZE=256",
        *tile_paths,
    ], check=True)

# Directories
project_path = '/scratch/silvrui/NPCi2018i_paper/work_dir/'
output_path = os.path.join(project_path, "Output_Data")
temp_path = os.path.join(output_path, "Temp_NPCi")

# Input and output files
WVM_path = os.path.join(project_path, "Input_Data/WVM2018.tif")
TCD_path = os.path.join(project_path, "Input_Data/TCD10_2018.tif")
WVM_tiles_path = os.path.join(temp_path, "WVM_tiles")
TCD_tiles_path = os.path.join(temp_path, "TCD_tiles")
WVM_TCD_tiles_path = os.path.join(temp_path, "WVM_TCD_tiles")
SWF_TCD_tif = os.path.join(output_path, "SWF_TCD_2018.tif")

# Ensure directories exist
Path(WVM_tiles_path).mkdir(parents=True, exist_ok=True)
Path(TCD_tiles_path).mkdir(parents=True, exist_ok=True)
Path(WVM_TCD_tiles_path).mkdir(parents=True, exist_ok=True)

# Main process
start_time = time.time()

# 1. Tile the input rasters
create_tiles(WVM_path, TCD_path, WVM_tiles_path, TCD_tiles_path, y2=2)

# 2. Resample TCD tiles to 5m resolution
resample_tiles(TCD_tiles_path, TCD_tiles_path, 5)

# 3. Combine WVM and TCD tiles
# Tiles are paired by file name. Path.glob() yields entries in arbitrary,
# filesystem-dependent order, so zipping two independent glob results could
# combine e.g. tile_3 of one layer with tile_12 of the other, and zip() would
# silently drop tiles if the counts differed.
wvm_tiles = sorted(Path(WVM_tiles_path).glob("*.tif"))
tcd_tiles = sorted(Path(TCD_tiles_path).glob("*.tif"))
tcd_tiles_by_name = {tile.name: tile for tile in tcd_tiles}

for wvm_tile in wvm_tiles:
    tcd_tile = tcd_tiles_by_name.get(wvm_tile.name)
    if tcd_tile is None:
        raise FileNotFoundError(f"No TCD tile named {wvm_tile.name} in {TCD_tiles_path}")
    output_tile = os.path.join(WVM_TCD_tiles_path, wvm_tile.name)
    combine_tiles(str(wvm_tile), str(tcd_tile), output_tile)

# 4. Merge combined tiles into a single raster
merge_tiles(WVM_TCD_tiles_path, SWF_TCD_tif)

report_time(start_time, "Tiling, merging, and combining rasters")
