"""
Title: NPCi Workflow Script

Description:
This script processes spatial data to calculate the Natural Pest Control Index (NPCi). 
It takes input rasters such as Morphological Spatial Pattern Analysis (MSPA), extensive grasslands (extGL), and CORINE data, performs reclassification and resampling, and combines the data into comprehensive outputs. 
The script facilitates various geospatial analyses including masking, merging, and normalization of raster layers.

Key Features:
- Cuts large raster files into smaller tiles for efficient processing.
- Reclassifies MSPA rasters into meaningful categories (Core, Edge, Linear).
- Combines MSPA and extensive grasslands rasters into Semi-Natural Habitat (SNH) layers.
- Generates a binary CORINE raster and masks other layers using this binary mask.
- Applies focal statistics to calculate NPCi and normalizes the results for further analysis.
- Merges processed tiles back into single raster files.
- Designed for flexible input paths and dynamic directory management.

Usage:
1. Define paths to the project working directory and input raster files.
2. Comment/uncomment the relevant steps in the `main` function based on the desired processing.
3. Run the script in an environment with required libraries installed (GDAL, NumPy, Rasterio, etc.).
4. The script produces intermediate and final outputs in the specified output directory.

Author: Rui Catarino (rui.catarino@ec.europa.eu)
Date: 16/01/25
"""


import os
import sys
from pathlib import Path
import subprocess
import time
import geopandas as gpd
from shapely.geometry import box
from osgeo import gdal, ogr
import numpy as np
import rasterio
from scipy.ndimage import distance_transform_edt
from scipy.ndimage import convolve
import matplotlib.pyplot as plt
import pandas as pd
from rasterstats import zonal_stats
from affine import Affine

# User-defined inputs.
# Every prompt falls back to a placeholder path when the answer is empty.
project_path = input("Enter the root project directory: ").strip() or "/path/to/project"  # Default fallback
output_path = os.path.join(project_path, "Output_Data")
temp_path = os.path.join(output_path, "Temp_NPCi")

# Input rasters
MSPA_tif = input("Enter the path to the MSPA raster file: ").strip() or "/path/to/MSPA.tif"
CORINE_tif = input("Enter the path to the CORINE raster file: ").strip() or "/path/to/CORINE.tif"
extGL_tif = input("Enter the path to the extensive grasslands raster file: ").strip() or "/path/to/extGL.tif"

# Final products, all written directly into the output directory
corine_bin_tif, NPCi_tif, NPCimasked_tif, NPCinorm_tif = (
    os.path.join(output_path, file_name)
    for file_name in (
        "U2018_CLC2018_50bin.tif",
        "NPCi_2018.tif",
        "NPCi_masked.tif",
        "NPCi_norm.tif",
    )
)

# Scratch directories for the intermediate tiles, all below the temporary directory
MSPA_tiles_path, recMSPA_tiles_path, extGL_tiles_path, SNH_tiles_path, NPCi_tiles_path = (
    os.path.join(temp_path, folder_name)
    for folder_name in (
        "MSPA_tiles",
        "MSPA_rec_tiles",
        "extGL_tiles",
        "SNH_tiles",
        "NPCi_tiles",
    )
)

# Make sure every scratch directory exists (parents are created as needed)
for _tiles_dir in (
    MSPA_tiles_path,
    recMSPA_tiles_path,
    extGL_tiles_path,
    SNH_tiles_path,
    NPCi_tiles_path,
):
    Path(_tiles_dir).mkdir(parents=True, exist_ok=True)
del _tiles_dir

# Report time function
def report_time(start_time, step_name):
    """
    Print how long a processing step took, as whole hours, minutes and seconds.

    Parameters:
    start_time (float): Value of time.time() captured when the step started.
    step_name (str): Label placed at the start of the printed message.
    """
    elapsed = time.time() - start_time
    hrs, leftover = divmod(elapsed, 3600)
    mins, secs = divmod(leftover, 60)
    duration = "{}h {}m {}s".format(int(hrs), int(mins), int(secs))
    print(f"\n{step_name} completed in {duration}.")

# Function to check if an output file exists and is valid (not corrupted)
def check_output_tif(output_raster_path):
    """
    Tell whether a usable output raster is already on disk.

    A file that exists but cannot be opened by GDAL is deleted so that the
    caller can regenerate it.

    Parameters:
    output_raster_path (str): Path to the output raster file.

    Returns:
    bool: True when the file exists and GDAL can open it; False when the file
          is missing, or when it was unreadable (in which case it has been removed).
    """
    # Nothing on disk: the caller has to produce the file
    if not os.path.exists(output_raster_path):
        return False

    try:
        # gdal.Open returns None for a file it cannot read
        readable = gdal.Open(output_raster_path) is not None
        if not readable:
            print(f"Output file {output_raster_path} is corrupted. Reprocessing...")
            os.remove(output_raster_path)  # Delete corrupted file
            return False
        print(f"Output file {output_raster_path} already exists and is valid. Skipping...")
        return True
    except Exception as e:
        # Any failure while probing the file is treated as corruption
        print(f"Error while checking {output_raster_path}: {e}")
        print(f"File {output_raster_path} is corrupted. Deleting it.")
        os.remove(output_raster_path)  # Delete corrupted file
        return False

# Function to save array to a GeoTIFF file
def save_array_to_tiff(array, output_path, driver, mspa_ds):
    """
    Write an array as a single-band Float32 raster that copies the size,
    projection and geotransform of a reference dataset.

    Parameters:
    array: Array written to band 1 of the new raster.
    output_path (str or Path): Destination file; converted with str() before use.
    driver: GDAL driver used to create the file.
    mspa_ds: Open GDAL dataset providing RasterXSize, RasterYSize, projection and geotransform.
    """
    n_cols, n_rows = mspa_ds.RasterXSize, mspa_ds.RasterYSize
    target_ds = driver.Create(str(output_path), n_cols, n_rows, 1, gdal.GDT_Float32)

    # Copy the georeferencing from the reference dataset
    target_ds.SetProjection(mspa_ds.GetProjection())
    target_ds.SetGeoTransform(mspa_ds.GetGeoTransform())

    target_ds.GetRasterBand(1).WriteArray(array)
    target_ds.FlushCache()
    del target_ds  # Dropping the last reference closes and saves the file
   
# Define function to print the extent using rasterio
def print_tile_extent(tile_path):
    """
    Print the bounding coordinates of a tile, read with rasterio.

    Parameters:
    tile_path (str): Path to the tile file.
    """
    with rasterio.open(tile_path) as tile:
        extent = tile.bounds
        print(f"Extent of tile {tile_path}:")
        labelled_edges = (
            ("X Min", extent.left),
            ("X Max", extent.right),
            ("Y Min", extent.bottom),
            ("Y Max", extent.top),
        )
        for label, coordinate in labelled_edges:
            print(f"  {label}: {coordinate}")

# Function to generate the kernel to be used in focal statistics 
def generate_focal_kernel(u, v, exp, save_folder):
    """
    Build the distance-weighted kernel used for the focal statistics and save it to disk.

    The kernel is a square matrix on a grid running from -500 to 500 in steps of 50
    (21 x 21 cells). Each cell holds ``a * (1 + d**2 / (v * u**2)) ** exp``, where ``d`` is
    the Euclidean distance of the cell from the centre of the grid. The matrix is divided by
    its central value, and every cell whose normalised weight is lower than the weight of the
    middle cell of the first row is set to 0.

    Files written to ``save_folder`` (created if missing):
      - ``kernel.csv``: the thresholded, normalised kernel.
      - ``kernel_val_plot.png``: a profile plot of ``(1 + d**2 / (v * u**2)) ** (-(v + 1) / 2)``
        for d from -500 to 499. Note that this profile uses the fixed exponent
        ``-(v + 1) / 2`` and not ``exp``.

    Parameters:
        u (int): Scaling factor in the kernel formula.
        v (int): Degrees of freedom in the kernel formula.
        exp (int): Exponent value in the kernel formula.
        save_folder (str): The folder where the kernel matrix and plot will be saved.

    Returns:
        None
    """

    # Step 1: Create the Euclidean Distance Matrix (Kernel Base)
    # Grid from -500 to 500 with increments of 50; each cell holds its distance to the centre
    axis = np.arange(-500, 501, 50)
    X, Y = np.meshgrid(axis, axis)
    kernel_base = np.sqrt(X**2 + Y**2)

    # Step 2: Apply the Kernel Formula
    # The constant factor `a` cancels out in the normalisation of step 3, so it only
    # affects the central value printed below, not the saved kernel.
    a = (1 / (u**2 * v * 3.14159)) * 12
    m1 = a * (1 + (kernel_base**2 / (v * u**2)))**exp

    # Step 3: Normalize the Kernel by the Central Value
    centre = m1.shape[0] // 2  # index 10 for the 21 x 21 grid
    central_value = m1[centre, centre]
    print(f'Central value (at [{centre}, {centre}]): {central_value}')
    m1_norm = m1 / central_value

    # Step 4: Apply Threshold to the Kernel Matrix
    # The threshold is the weight of the middle cell of the first row (the cell straight
    # above the centre on the grid border); lower weights are set to 0.
    m1_threshold = m1_norm[0, centre]
    print(f'Threshold value (at [0, {centre}]): {m1_threshold}')
    m2 = np.where(m1_norm >= m1_threshold, m1_norm, 0)

    # Save the final kernel (m2) matrix for focal statistics as 'kernel.csv'
    os.makedirs(save_folder, exist_ok=True)
    kernel_tif = os.path.join(save_folder, 'kernel.csv')
    np.savetxt(kernel_tif, m2, delimiter=",", fmt='%.15e')
    print(f'Kernel matrix saved at: {kernel_tif}')

    ### Plot Distance Kernel Profile
    column_dist = np.arange(-500, 500, 1)
    kernel_val = (1 + (column_dist**2 / (v * u**2)))**(-(v + 1) / 2)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(column_dist, kernel_val, label=f'Kernel Profile (u={u}, v={v})')
        plt.fill_between(column_dist, kernel_val, color='lightblue', alpha=0.5)
        plt.title(f'Kernel Profile with u={u}m and v={v}')
        plt.xlabel('Distance (meters)')
        plt.ylabel('Kernel Value')
        plt.legend()

        # Save the plot as 'kernel_val_plot.png'
        plot_tif = os.path.join(save_folder, 'kernel_val_plot.png')
        plt.savefig(plot_tif)
        print(f'Kernel profile plot saved at: {plot_tif}')
    finally:
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)

# Function to Convert the NumPy array back into a GDAL dataset before using gdal.Translate.
def array_to_gdal_dataset(array, affine_transform, projection, nodata_value=None):
    """
    Wrap a 2-D NumPy array in an in-memory, single-band Float32 GDAL dataset.

    Parameters:
    array: 2-D array; its shape is unpacked as (rows, cols).
    affine_transform: Geotransform handed to SetGeoTransform.
    projection: Projection handed to SetProjection.
    nodata_value: Optional NoData value for the band; skipped when None.

    Returns:
    The in-memory GDAL dataset holding the array in band 1.
    """
    mem_driver = gdal.GetDriverByName('MEM')  # In-memory GDAL driver
    height, width = array.shape
    mem_ds = mem_driver.Create('', width, height, 1, gdal.GDT_Float32)

    # Georeferencing
    mem_ds.SetGeoTransform(affine_transform)
    mem_ds.SetProjection(projection)

    # Pixel values, plus the optional NoData marker
    target_band = mem_ds.GetRasterBand(1)
    target_band.WriteArray(array)
    if nodata_value is not None:
        target_band.SetNoDataValue(nodata_value)

    target_band.FlushCache()  # Ensure changes are written to the dataset
    return mem_ds
   
# Function to crop tiles 
def create_tiles(mspa_rst, gl_rst, MSPA_tiles_path, gl_tiles_path, y2, buffer=50):
    """
    Cut the MSPA and grassland rasters into matching full-width horizontal strips (tiles).

    The rasters are divided into strips of ``y2 * 1000`` rows. Every strip is extended by up to
    ``buffer`` rows on each side that borders another strip (the window is clamped to the raster),
    so the first strip is buffered at the bottom only, the last strip at the top only and the
    middle strips on both sides. A final, shorter strip is produced when the raster height is not
    a multiple of the strip height, so no rows are dropped. Both rasters are cut with the same
    pixel window and written as ``tile_<i>.tif`` into their output folders; tiles that already
    exist and are readable are skipped.

    NOTE(review): the window is handed to ``gdal_translate -srcwin``, which is expressed in
    pixels/lines. ``y2 * 1000`` and ``buffer`` are therefore used as row counts of the source
    raster, whatever the units suggested by their names - confirm the intended units.
    NOTE(review): the strip width is taken from the MSPA raster and the height from the taller
    of the two rasters; the two inputs are presumably expected to share the same grid.

    Parameters:
    mspa_rst (str): Path to the MSPA raster file.
    gl_rst (str): Path to the gl raster file.
    MSPA_tiles_path (str): Path to the folder where MSPA tiles will be saved.
    gl_tiles_path (str): Path to the folder where gl tiles will be saved.
    y2 (int): Height of each tile; multiplied by 1000 and used as a number of rows.
    buffer (int): Number of extra rows added on each side of a strip that borders another strip.

    Raises:
    RuntimeError: If one of the input rasters cannot be opened.
    ValueError: If the strip height is not positive or the buffer is negative.
    subprocess.CalledProcessError: If gdal_translate fails.
    """

    step_start_time = time.time()
    print("\n" + "="*40)  # A line of "=" for separation
    print(f"Processing input rasters: \n - {mspa_rst} \n - {gl_rst} \n")

    # Load rasters (only their metadata is needed; the cutting is done by gdal_translate)
    mspa_ds = gdal.Open(mspa_rst)
    if mspa_ds is None:
        raise RuntimeError(f"Could not open MSPA raster: {mspa_rst}")
    gl_ds = gdal.Open(gl_rst)
    if gl_ds is None:
        raise RuntimeError(f"Could not open grassland raster: {gl_rst}")

    def _extent(ds, gt):
        # (top-left x, top-left y, bottom-right x, bottom-right y)
        return (
            gt[0],
            gt[3],
            gt[0] + ds.RasterXSize * gt[1],
            gt[3] + ds.RasterYSize * gt[5],
        )

    mspa_geotransform = mspa_ds.GetGeoTransform()
    gl_geotransform = gl_ds.GetGeoTransform()
    print(f"MSPA Geotransform: {mspa_geotransform}")
    print(f"Grassland Geotransform: {gl_geotransform}")
    print(f"MSPA Extent: {_extent(mspa_ds, mspa_geotransform)}")
    print(f"Grassland Extent: {_extent(gl_ds, gl_geotransform)}")

    # Full width of the MSPA raster for every strip; height of the taller raster
    width_x = mspa_ds.RasterXSize
    raster_height = max(mspa_ds.RasterYSize, gl_ds.RasterYSize)

    # Release the handles before gdal_translate reads the same files
    mspa_ds = None
    gl_ds = None

    strip_rows = int(y2 * 1000)
    buffer_rows = int(buffer)
    if strip_rows <= 0:
        raise ValueError(f"y2 must give a positive strip height, got {y2}")
    if buffer_rows < 0:
        raise ValueError(f"buffer must not be negative, got {buffer}")

    # Round up so that a trailing partial strip is not silently dropped
    y_cuts = -(-raster_height // strip_rows)
    print(f"Calculated number of cuts (height-wise): {y_cuts}")

    for i in range(y_cuts):
        tile_id = f"tile_{i}.tif"

        # Nominal strip [strip_top, strip_bottom), then the buffered window clamped to the raster
        strip_top = i * strip_rows
        strip_bottom = min(strip_top + strip_rows, raster_height)
        win_top = max(strip_top - buffer_rows, 0)
        win_bottom = min(strip_bottom + buffer_rows, raster_height)
        srcwin = ["-srcwin", "0", str(win_top), str(width_x), str(win_bottom - win_top)]

        # File paths for MSPA and gl tiles
        mspa_path = os.path.join(MSPA_tiles_path, tile_id)
        gl_path = os.path.join(gl_tiles_path, tile_id)

        # (label, source raster, tile path, extra creation options)
        jobs = (
            ("MSPA", mspa_rst, mspa_path, ["-co", "TILED=YES"]),
            ("Grasslands", gl_rst, gl_path, []),
        )
        for label, source, destination, extra_options in jobs:
            # Skip tiles that already exist and are readable
            if check_output_tif(str(destination)):
                continue

            command = [
                "gdal_translate",
                "-of", "GTiff",  # Output format (GeoTIFF)
                "-co", "BIGTIFF=YES",  # Enable BigTIFF for large files
                *extra_options,
                *srcwin,  # xoff yoff xsize ysize, in pixels
                source,
                destination,
            ]
            print(f"Processing tile {i + 1}/{y_cuts} for {label}")
            subprocess.run(command, check=True)

    #  Report the time taken
    print(f"Tiles saved to -{MSPA_tiles_path} \n -{gl_tiles_path}")
    report_time(step_start_time, "Tiles creation")
    print("="*40)  # A line of "=" for separation
       
# Function to reclassify MSPA tiles
def recla_MSPA(MSPA_tiles_path, recMSPA_tiles_path, start_feature=0):
    """
    Reclassify MSPA tiles into scored classes (Core, Edge, Linear) and save them.

    For every ``*.tif`` in ``MSPA_tiles_path`` (processed in sorted name order):
      - Edge codes are scored 45.6 and Linear codes 34.4.
      - Core codes start at 20.7 and are multiplied by ``exp(-0.063 * (d - 1))``, where ``d`` is
        the distance in pixels to the nearest Edge pixel; Core pixels further than 10 pixels
        from an Edge pixel are set to 0.
      - All other pixels are 0.
    The sum of the three layers is written as Float32 to ``recMSPA_tiles_path/tile_<id>.tif``.
    Outputs that already exist and are readable are skipped.

    Parameters:
    MSPA_tiles_path (str): Folder containing the MSPA tiles.
    recMSPA_tiles_path (str): Folder where the reclassified tiles are saved.
    start_feature (int): Number of tiles (in sorted order) to skip before processing starts.

    Raises:
    RuntimeError: If an MSPA tile cannot be opened.
    """

    step_start_time = time.time()
    print("\n" + "="*40)  # A line of "=" for separation
    print(f"Reclassifying MSPA tiles in classes (Core, Edge, Linear), ...")

    # MSPA codes belonging to each class, and the score given to the class
    core_codes = (17, 117)  # Woody areal interior
    edge_codes = (3, 5, 9, 35, 67, 103, 105, 119, 135, 137, 167, 169)  # Woody areal exterior
    linear_codes = (1, 33, 37, 65, 69, 101, 109, 133, 165)  # Woody linear
    core_score = np.float32(20.7)
    edge_score = np.float32(45.6)
    linear_score = np.float32(34.4)
    zero = np.float32(0)

    # Sorted so that the meaning of start_feature does not depend on file-system order
    mspa_tiles = sorted(Path(MSPA_tiles_path).glob("*.tif"))
    total_tiles = len(mspa_tiles)

    # Initialize GDAL driver once
    driver = gdal.GetDriverByName('GTiff')

    for i, mspa_tile in enumerate(mspa_tiles):
        if i < start_feature:
            continue  # Skip tiles until reaching the start_feature index

        tile_id = mspa_tile.stem.split("_")[-1]

        print(f"MSPA reclassification, processing tile {tile_id} [{i+1}/{total_tiles}] ...")

        # Check if the output file exists and is not corrupted
        output_mspa_tile_path = Path(recMSPA_tiles_path) / f"tile_{tile_id}.tif"
        if check_output_tif(str(output_mspa_tile_path)):
            continue  # Skip if the file exists and is valid

        # Open MSPA tile
        mspa_ds = gdal.Open(str(mspa_tile))
        if mspa_ds is None:
            raise RuntimeError(f"Could not open MSPA tile: {mspa_tile}")

        # Work in float32 so the non-integer class scores can be stored
        mspa_array = mspa_ds.GetRasterBand(1).ReadAsArray().astype(np.float32)

        ### Reclassify into the three scored layers (0 everywhere else) ###
        is_edge = np.isin(mspa_array, edge_codes)
        core_array = np.where(np.isin(mspa_array, core_codes), core_score, zero)
        edge_array = np.where(is_edge, edge_score, zero)
        linear_array = np.where(np.isin(mspa_array, linear_codes), linear_score, zero)

        ### Distance (in pixels) of every non-edge pixel to the nearest edge pixel ###
        # distance_transform_edt already returns exact Euclidean distances (no square root needed)
        array_euc = distance_transform_edt(~is_edge)
        array_euc = array_euc.astype(np.float32)

        # Pixels further than 10 pixels from an edge are set to 0; after subtracting 1 they
        # (and the edge pixels themselves) hold -1, pixels within range hold distance - 1.
        maxdist = 10.0
        array_euc = np.where(array_euc <= maxdist, array_euc, 0)
        array_euc = array_euc - 1

        core_array_rec = core_array * np.exp(-0.063 * array_euc)  # Negative exponential decay

        # Core pixels beyond 10 pixels carry the exponent -1, i.e. 20.7 * exp(0.063) ~ 22.04.
        # These deep-interior pixels must be 0, so everything above 22 is reset.
        core_array_rec[(core_array_rec > 22)] = 0
        core_array_rec[(core_array_rec == 1)] = 0

        # The distance is measured on both sides of the edge, but only the core side is
        # of interest: keep the decayed score on core pixels only.
        core_array_rec = core_array_rec * (core_array != 0)

        ### Combine arrays ###
        mspa_array = edge_array + core_array_rec + linear_array

        # Save the final reclassified MSPA array
        save_array_to_tiff(mspa_array, output_mspa_tile_path, driver, mspa_ds)

        # Clean up
        mspa_ds = None

    # Report the time taken
    print(f"Output saved in: {recMSPA_tiles_path}")
    report_time(step_start_time, "Reclassifying MSPA Tiles in classes (Core, Edge, Linear), ")
    print("="*40)  # A line of "=" for separation

# Function to score extensive-grassland (extGL) pixels as Herbaceous Areal and add them to the reclassified MSPA tiles
def SNH_tiles(recMSPA_tiles_path, extGL_path, SNH_tiles_path, start_feature=0, buffer=50):
    """
    Build the Semi-Natural Habitat (SNH) tiles from reclassified MSPA and grassland tiles.

    For every ``*.tif`` in ``recMSPA_tiles_path`` (processed in sorted name order) the matching
    ``tile_<id>.tif`` in ``extGL_path`` is read, grassland pixels equal to 1 are scored 26.8,
    and the grassland value is added only where the reclassified MSPA value is not greater
    than 0. The result is written to ``SNH_tiles_path/tile_<id>.tif``. Outputs that already
    exist and are readable are skipped.

    NOTE(review): grassland values other than 1 are left unchanged and are therefore also
    added where the MSPA value is not positive - confirm the grassland raster is strictly 0/1.

    Parameters:
    recMSPA_tiles_path (str): Folder containing the reclassified MSPA tiles.
    extGL_path (str): Folder containing the grassland tiles.
    SNH_tiles_path (str): Folder where the SNH tiles are saved.
    start_feature (int): Number of tiles (in sorted order) to skip before processing starts.
    buffer (int): Not used by this function.

    Returns:
    None. Processing stops early (with a message) at the first tile that has no grassland tile.

    Raises:
    RuntimeError: If a tile cannot be opened.
    ValueError: If the two tiles of a pair do not have the same shape.
    """

    step_start_time = time.time()
    print("\n" + "="*40)  # A line of "=" for separation
    print(f"Adding Herbaceous Areal to MSPA...")

    # Sorted so that the meaning of start_feature does not depend on file-system order
    recMSPA_tiles = sorted(Path(recMSPA_tiles_path).glob("*.tif"))
    total_tiles = len(recMSPA_tiles)

    # Initialize GDAL driver once
    driver = gdal.GetDriverByName('GTiff')

    for i, recMSPA_tile in enumerate(recMSPA_tiles):
        if i < start_feature:
            continue  # Skip tiles until reaching the start_feature index

        tile_id = recMSPA_tile.stem.split("_")[-1]
        extGL_tile = Path(extGL_path) / f"tile_{tile_id}.tif"

        # Check if corresponding extGL tile exists
        if not extGL_tile.exists():
            print(f"Corresponding extGL tile not found for {tile_id}. Stopping process.")
            return  # Stop the function if the tile is missing in extGL_path

        # Define output paths for merged tiles
        SNH_tile_path = Path(SNH_tiles_path) / f"tile_{tile_id}.tif"

        # Check if the output file already exists and is not corrupted
        if check_output_tif(str(SNH_tile_path)):
            continue  # Skip if the file exists and is valid

        print(f"SNH classification. Processing tile {tile_id} [{i+1}/{total_tiles}]...")

        # Open recMSPA and extGL tiles
        recMSPA_ds = gdal.Open(str(recMSPA_tile))
        if recMSPA_ds is None:
            raise RuntimeError(f"Could not open reclassified MSPA tile: {recMSPA_tile}")
        extGL_ds = gdal.Open(str(extGL_tile))
        if extGL_ds is None:
            raise RuntimeError(f"Could not open extGL tile: {extGL_tile}")

        recMSPA_array = recMSPA_ds.GetRasterBand(1).ReadAsArray()
        extGL_array = extGL_ds.GetRasterBand(1).ReadAsArray().astype(np.float32)
        if recMSPA_array.shape != extGL_array.shape:
            raise ValueError(
                f"Tile {tile_id}: MSPA shape {recMSPA_array.shape} "
                f"does not match extGL shape {extGL_array.shape}"
            )

        ### Reclassify Herbaceous Areal (HA) ###
        extGL_array[(extGL_array == 1)] = 26.8

        # Build the Semi-Natural Habitat layer: the grassland score is added only where the
        # reclassified MSPA tile holds no positive score.
        recMSPA_array_bin = np.where(recMSPA_array > 0, 0, 1)
        SNH_array = (recMSPA_array_bin * extGL_array) + recMSPA_array

        # Save the SNH array with the georeferencing of the MSPA tile
        save_array_to_tiff(SNH_array, SNH_tile_path, driver, recMSPA_ds)

        # Clean up
        recMSPA_ds, extGL_ds, recMSPA_array_bin, SNH_array = None, None, None, None
        print(f"Tile processed")

    # Report the time taken
    print(f"Output saved in: {SNH_tiles_path}")
    report_time(step_start_time, "Herbaceous Areal to MSPA Tiles")
    print("="*40)  # A line of "=" for separation

# Function to resample tiles to 50m, apply focal statistics, and save the output
def calculate_NPCi(SNH_tiles_path, NPCi_tiles_path, project_path, target_resolution, buffer=50):
    """
    Resample each SNH tile, apply the focal kernel, trim the buffer rows and save the result.

    Every ``*.tif`` in ``SNH_tiles_path`` is resampled (nearest neighbour) to ``target_resolution``,
    convolved with the kernel stored in ``<project_path>/Output_Data/kernel.csv`` (zero padding at
    the borders) and written as Float32 to ``NPCi_tiles_path/NPCi_<id>.tif``. Outputs that already
    exist and are readable are skipped.

    Tiles are ordered by their numeric id (``tile_2`` before ``tile_10``) so that the first and
    the last strip are identified reliably: buffer rows are removed from the top of every tile
    except the first and from the bottom of every tile except the last. A single tile is both
    first and last and is therefore not trimmed.

    NOTE(review): the buffer is converted with ``int(buffer / target_resolution)``, i.e. it is
    read as a distance in the units of ``target_resolution``. ``create_tiles`` passes its buffer
    to ``gdal_translate -srcwin`` as a number of source rows - confirm both agree.

    Parameters:
    SNH_tiles_path (str): Directory containing input tiles.
    NPCi_tiles_path (str): Directory where the output tiles will be saved.
    project_path (str): Project path whose ``Output_Data`` folder contains kernel.csv.
    target_resolution (int): The target resolution (pixel size) of the resampled tiles.
    buffer (int): Buffer size, in the same units as target_resolution (default 50).

    Raises:
    RuntimeError: If a tile cannot be opened or resampled.
    """
    step_start_time = time.time()
    print("\n" + "="*40)  # A line of "=" for separation
    print("Processing tiles to calculate NPCi...")

    # Load the kernel created from 'generate_focal_kernel'
    kernel_tif = os.path.join(project_path, "Output_Data", 'kernel.csv')
    kernel = pd.read_csv(kernel_tif, header=None).to_numpy()  # Load the kernel as a NumPy array

    def _tile_order(tile_path):
        # Numeric ids sort by value; anything else sorts after them, by name
        tile_key = tile_path.stem.split("_")[-1]
        if tile_key.isdigit():
            return (0, int(tile_key), "")
        return (1, 0, tile_key)

    # Strips ordered from the top of the raster to the bottom
    SNH_tiles = sorted(Path(SNH_tiles_path).glob("*.tif"), key=_tile_order)
    total_tiles = len(SNH_tiles)

    # Initialize GDAL driver once
    driver = gdal.GetDriverByName('GTiff')

    # Number of buffer rows at the target resolution
    buffer_pixels = int(buffer / target_resolution)

    for i, SNH_tile in enumerate(SNH_tiles):
        tile_id = SNH_tile.stem.split("_")[-1]

        # Define output path for the resampled tile
        resampled_tile_path = Path(NPCi_tiles_path) / f"NPCi_{tile_id}.tif"

        # Check if the output file already exists and is not corrupted
        if check_output_tif(str(resampled_tile_path)):
            continue  # Skip if the file exists and is valid

        print(f"NPCi calculus. Processing tile {tile_id} [{i+1}/{total_tiles}]...")

        # Open the SNH tile
        snh_ds = gdal.Open(str(SNH_tile))
        if snh_ds is None:
            raise RuntimeError(f"Could not open SNH tile: {SNH_tile}")

        # Resample the tile in memory; 'near' is the gdalwarp name of nearest neighbour
        snh_resampled_ds = gdal.Warp('', snh_ds, format='MEM', xRes=target_resolution,
                                     yRes=target_resolution, resampleAlg='near')
        if snh_resampled_ds is None:
            raise RuntimeError(f"Could not resample SNH tile: {SNH_tile}")
        snh_resampled_array = snh_resampled_ds.ReadAsArray()

        # Apply focal statistics using the kernel (zeros outside the tile)
        snh_focal_array = convolve(snh_resampled_array, kernel, mode='constant', cval=0)
        print(f"Focal statistics applied to tile {tile_id}.")

        # Buffer rows exist only on the sides that border another strip
        upper_trim = 0 if i == 0 else buffer_pixels
        lower_trim = 0 if i == total_tiles - 1 else buffer_pixels

        # Slice away the buffer rows
        y_size = snh_focal_array.shape[0]
        snh_shrunk_array = snh_focal_array[upper_trim : y_size - lower_trim, :]

        # Move the origin down by the rows removed at the top
        affine_transform = list(snh_resampled_ds.GetGeoTransform())
        if upper_trim:
            affine_transform[3] -= upper_trim * target_resolution  # shift origin Y

        # Create the output dataset with the shrunk dimensions
        output_ds = driver.Create(
            str(resampled_tile_path),
            snh_shrunk_array.shape[1],  # width (x)
            snh_shrunk_array.shape[0],  # height (y)
            1,  # single band
            gdal.GDT_Float32,
        )
        output_ds.SetGeoTransform(affine_transform)
        output_ds.SetProjection(snh_ds.GetProjection())
        output_ds.GetRasterBand(1).WriteArray(snh_shrunk_array)
        output_ds.FlushCache()

        # Close the datasets so the tile is complete on disk before the next one starts
        output_ds = None
        snh_resampled_ds = None
        snh_ds = None
        print(f"Processed NPCi tile at {target_resolution}m resolution")

    # Report the time taken
    print(f"Processed tiles saved in: {NPCi_tiles_path}")
    report_time(step_start_time, "Calculate NPCi Tiles")
    print("="*40)  # A line of "=" for separation

# Create the binary CORINE layer (reclassify, then resample to the target resolution)
def binary_CORINE(CORINE_tif, corine_bin_path, output_path, target_resolution=50):
    """
    Reclassify the CORINE raster to a binary mask and resample it to the target resolution.

    Pixels whose value is in ``classes_to_mask`` (12 to 22) become 1 and all others 0. The mask
    is saved at the original resolution as ``U2018_CLC2018_100bin.tif`` in ``output_path`` (with
    0 declared as NoData), then resampled with nearest neighbour to ``target_resolution`` and
    saved to ``corine_bin_path``. Nothing is done when ``corine_bin_path`` already exists, and
    the reclassification is skipped when the intermediate file already exists.

    Parameters:
    - CORINE_tif (str): Path to the original CORINE raster file.
    - corine_bin_path (str): Path to save the resampled binary raster file.
    - output_path (str): Directory to save the binary raster before resampling.
    - target_resolution (int): The target resolution in meters for the resampled file (default is 50m).

    Raises:
    - RuntimeError: If the CORINE raster cannot be opened, or the output files cannot be created.
    """

    step_start_time = time.time()
    print("\n" + "="*40)
    print(f"Masking and resampling CORINE raster...")

    # GeoTIFF creation options shared by the intermediate and the final file
    creation_options = ["COMPRESS=DEFLATE", "PREDICTOR=2", "ZLEVEL=9", "BIGTIFF=YES"]

    corine_100bin_tif = os.path.join(output_path, "U2018_CLC2018_100bin.tif")  # Binary output file
    # Check if the resampled binary CORINE file already exists
    if os.path.exists(corine_bin_path):
        print(f"Resampled binary CORINE file already exists: {corine_bin_path}. Skipping reclassification and resampling...")
        return

    # Step 1: Reclassify and save the binary version
    if not os.path.exists(corine_100bin_tif):
        # Open the CORINE raster
        corine_ds = gdal.Open(CORINE_tif)
        if corine_ds is None:
            raise RuntimeError(f"Could not open CORINE raster: {CORINE_tif}")
        corine_array = corine_ds.GetRasterBand(1).ReadAsArray()

        # Apply the mask: values in classes_to_mask become 1, everything else 0
        classes_to_mask = [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]  # Classes to retain
        corine_mask_array = np.where(np.isin(corine_array, classes_to_mask), 1, 0)

        # Create the output binary raster (corine_100bin_tif)
        driver = gdal.GetDriverByName('GTiff')
        output_masked = driver.Create(str(corine_100bin_tif), corine_mask_array.shape[1],
                                      corine_mask_array.shape[0], 1, gdal.GDT_Byte,
                                      creation_options)
        if output_masked is None:
            raise RuntimeError(f"Could not create binary CORINE file: {corine_100bin_tif}")
        output_masked.SetGeoTransform(corine_ds.GetGeoTransform())  # Use the original GeoTransform
        output_masked.SetProjection(corine_ds.GetProjection())  # Use the original projection

        # Write the reclassified mask to the output GeoTIFF
        output_masked.GetRasterBand(1).WriteArray(corine_mask_array)  # Save the reclassified array
        output_masked.GetRasterBand(1).SetNoDataValue(0)  # Set NoData value
        output_masked.FlushCache()  # Ensure everything is written to disk
        output_masked = None  # Close the dataset
        corine_ds = None

        print(f"Binary CORINE file saved as: {corine_100bin_tif}")
    else:
        print(f"Binary CORINE file already exists: {corine_100bin_tif}. Skipping binary creation...")

    # Step 2: Resample the binary version to the target resolution.
    # Creation options must go through `creationOptions`; `options` is reserved for
    # gdalwarp command-line arguments.
    resampled_ds = gdal.Warp(corine_bin_path, corine_100bin_tif, xRes=target_resolution,
                             yRes=target_resolution, resampleAlg="near",
                             creationOptions=creation_options)
    if resampled_ds is None:
        raise RuntimeError(f"Could not resample {corine_100bin_tif} to {corine_bin_path}")
    resampled_ds = None  # Close the dataset so the file is complete on disk
    print(f"Resampled CORINE file saved as: {corine_bin_path}")

    # Report the time taken
    report_time(step_start_time, "CORINE binary and resampling")
    print("="*40)

# Merge the processed strips back together 
def merge_strips(input_dir, output_file):
    """
    Merge all processed strips (tiles) found in input_dir into output_file, then write
    a DEFLATE-recompressed copy next to it with "_comp" inserted before the extension.

    Parameters:
    input_dir (str): Directory containing the processed strips (tiles) to merge.
    output_file (str): Output file path where the merged raster will be saved.

    Raises:
    FileNotFoundError: If input_dir contains no .tif files.
    subprocess.CalledProcessError: If gdal_merge.py or gdal_translate exits with a non-zero status.
    """
    import shlex  # Local import: only this function needs shell quoting

    def run_tool(args):
        # The tools are still launched through the shell, as before, but every argument
        # is quoted so that paths containing spaces or shell metacharacters stay intact.
        if os.name == "nt":
            # On Windows subprocess converts an argument list into a quoted command string itself
            subprocess.run(args, shell=True, check=True)
        else:
            subprocess.run(" ".join(shlex.quote(arg) for arg in args), shell=True, check=True)

    step_start_time = time.time()
    print("\n" + "="*40)  # A line of "=" for separation
    print(f"Merging processed strips from {input_dir} \n into {output_file}...")

    # Collect all .tif files in the input directory. Sorting makes the input order
    # (and therefore the result wherever strips overlap) independent of the
    # directory listing order.
    strips = sorted(str(tile) for tile in Path(input_dir).glob("*.tif"))
    if not strips:
        # Fail early with a clear message instead of an opaque gdal_merge usage error
        raise FileNotFoundError(f"No .tif strips found in {input_dir}")

    output_file = os.fspath(output_file)  # Accept pathlib.Path as well as str

    # Run the merge command
    run_tool([
        "gdal_merge.py", "-o", output_file,
        "-co", "COMPRESS=LZW",
        "-co", "PREDICTOR=2",
        "-co", "BIGTIFF=YES",
        "-co", "TILED=YES",
        "-co", "BLOCKXSIZE=256", "-co", "BLOCKYSIZE=256",
        *strips,
    ])
    print(f"Merged raster saved to {output_file}.")

    report_time(step_start_time, "Merging Strips")
    print("="*40)  # A line of "=" for separation

    # Build the compressed output name by adding "_comp" before the extension.
    # splitext only touches the final extension, so directory names containing
    # ".tif" are left alone and the result always differs from output_file.
    root, extension = os.path.splitext(output_file)
    output_file_comp = f"{root}_comp{extension}"

    # Apply additional compression with gdal_translate
    run_tool([
        "gdal_translate", "-of", "GTiff",
        "-co", "COMPRESS=DEFLATE",
        "-co", "PREDICTOR=2",
        "-co", "ZLEVEL=9",
        "-co", "BIGTIFF=YES",
        "-co", "TILED=YES",
        "-co", "BLOCKXSIZE=256",
        "-co", "BLOCKYSIZE=256",
        output_file, output_file_comp,
    ])
    print(f"Compressed raster saved to {output_file_comp}.")

    # Report the time taken (measured from the start of the merge)
    report_time(step_start_time, "Merging Strips")
    print("="*40)  # A line of "=" for separation

# Function to Resample source_tif to match the resolution, extent, and geotransform of target_tif
def resample_raster(source_tif, target_tif):
    """
    Warp source_tif onto the grid of target_tif: same projection, same bounds and the
    same number of columns and rows, using nearest-neighbour resampling.

    The result is written as "<source_tif without extension>_res.tif" with DEFLATE
    compression, and that path is returned.

    Raises FileNotFoundError when target_tif cannot be opened and RuntimeError when
    the warp does not return a dataset.
    """
    reference_ds = gdal.Open(target_tif)
    if reference_ds is None:
        raise FileNotFoundError(f"Target raster not found or invalid: {target_tif}")

    # Grid description of the reference raster
    geotransform = reference_ds.GetGeoTransform()
    projection = reference_ds.GetProjection()
    n_cols = reference_ds.RasterXSize
    n_rows = reference_ds.RasterYSize

    # Output goes next to the source, with "_res" appended to its base name
    source_stem, _ = os.path.splitext(source_tif)
    resampled_tif = source_stem + "_res.tif"

    # Reference extent as origin plus pixel size times pixel count, in the order
    # gdal.Warp takes for outputBounds (minX, minY, maxX, maxY)
    bounds = [
        geotransform[0],
        geotransform[3] + geotransform[5] * n_rows,
        geotransform[0] + geotransform[1] * n_cols,
        geotransform[3],
    ]

    warp_settings = {
        "format": "GTiff",
        "width": n_cols,
        "height": n_rows,
        "dstSRS": projection,
        "resampleAlg": "near",  # Nearest neighbour keeps the original cell values
        "outputBounds": bounds,
        "creationOptions": ["COMPRESS=DEFLATE", "PREDICTOR=2", "ZLEVEL=9", "BIGTIFF=YES"],
    }
    warped = gdal.Warp(resampled_tif, source_tif, **warp_settings)

    if warped is None:
        raise RuntimeError(f"Failed to resample raster {source_tif} to match {target_tif}")

    del warped  # Drop the reference so GDAL closes the output dataset
    print(f"Resampled raster with compression saved to: {resampled_tif}")
    return resampled_tif

# Mask NPCi with CORINE raster
def mask_NPCi(corine_bin_tif, npci_tif, output_tif):
    """
    Keep the NPCi values only where the CORINE binary raster equals 1 and save the
    result to output_tif (Float32, DEFLATE-compressed). All other cells are written
    as NaN, which is also registered as the NoData value.

    The CORINE binary raster is first resampled onto the NPCi grid with
    resample_raster, so it does not need to match the NPCi raster beforehand.

    Parameters:
    corine_bin_tif (str): Path to the binary CORINE raster (1 = keep).
    npci_tif (str): Path to the NPCi raster to mask.
    output_tif (str): Path of the masked raster to create.

    Raises:
    FileNotFoundError: If the NPCi raster or the resampled CORINE raster cannot be opened.
    ValueError: If the resampled CORINE grid and the NPCi grid differ in shape.
    RuntimeError: If output_tif cannot be created.
    """

    step_start_time = time.time()
    print("\n" + "="*40)
    print("Processing NPCi with CORINE mask...")
    print(f"CORINE Binary: {corine_bin_tif}")
    print(f"NPCi Raster: {npci_tif}")

    # Open NPCi and read array
    npci_ds = gdal.Open(npci_tif)
    if npci_ds is None:
        raise FileNotFoundError(f"NPCi raster not found: {npci_tif}")

    npci_array = npci_ds.GetRasterBand(1).ReadAsArray()

    # Resample CORINE binary to match NPCi
    print("Resampling CORINE raster to match NPCi...")
    resampled_corine_tif = resample_raster(corine_bin_tif, npci_tif)

    # Open the resampled CORINE raster
    corine_ds = gdal.Open(resampled_corine_tif)
    if corine_ds is None:
        raise FileNotFoundError(f"Resampled CORINE raster not found: {resampled_corine_tif}")

    corine_array = corine_ds.GetRasterBand(1).ReadAsArray()

    # Guard against a grid mismatch so the failure is explicit rather than a broadcasting error
    if corine_array.shape != npci_array.shape:
        raise ValueError(
            f"Resampled CORINE grid {corine_array.shape} does not match NPCi grid {npci_array.shape}"
        )

    # Mask the NPCi array using the CORINE binary array
    print("Applying CORINE mask...")
    masked_array = np.where(corine_array == 1, npci_array, np.nan)  # NaN where CORINE != 1

    # Save the result to output_tif with compression
    print(f"Saving masked NPCi raster to: {output_tif}")
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(
        output_tif,
        npci_ds.RasterXSize,
        npci_ds.RasterYSize,
        1,
        gdal.GDT_Float32,
        options=["COMPRESS=DEFLATE", "PREDICTOR=2", "ZLEVEL=9", "BIGTIFF=YES"],
    )
    if out_ds is None:
        # Create returns None (e.g. missing output folder) when GDAL exceptions are disabled
        raise RuntimeError(f"Could not create output raster: {output_tif}")

    out_ds.SetGeoTransform(npci_ds.GetGeoTransform())
    out_ds.SetProjection(npci_ds.GetProjection())
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(masked_array)
    out_band.SetNoDataValue(np.nan)
    out_ds.FlushCache()

    # Clean up: dropping the references closes the datasets
    out_band = None
    npci_ds, corine_ds, out_ds = None, None, None
    print(f"Masked NPCi saved successfully to: {output_tif}")

    # Report the time taken
    report_time(step_start_time, "Masking NPCi with CORINE binary")
    print("="*40)

# Normalizes NPCi 0-100
def norm_NPCi(input_tif, output_tif):
    """
    Rescale the values of the input raster linearly to the range [0, 100], rounded to
    two decimals, and save the result to output_tif (Float32, DEFLATE-compressed).

    NaN cells are treated as missing: they are ignored when finding the minimum and
    maximum and stay NaN in the output, where NaN is registered as the NoData value.
    If every valid cell holds the same value (max == min) the rescaling is undefined;
    in that case the valid cells are written as 0 instead of dividing by zero.

    Parameters:
    input_tif (str): Path to the raster to normalize.
    output_tif (str): Path of the normalized raster to create.

    Raises:
    FileNotFoundError: If input_tif cannot be opened.
    RuntimeError: If output_tif cannot be created.
    """

    step_start_time = time.time()
    print("\n" + "="*40)
    print(f"Normalizing raster values for {input_tif}...")

    # Open the input raster
    input_ds = gdal.Open(input_tif)
    if input_ds is None:
        raise FileNotFoundError(f"Input raster not found: {input_tif}")
    input_array = input_ds.GetRasterBand(1).ReadAsArray()

    valid_mask = ~np.isnan(input_array)  # Mask to handle NaN values

    # Always build a Float32 result (the output band type): an array created with the
    # input's dtype could not hold NaN if the input raster were an integer type.
    normalized_array = np.full(input_array.shape, np.nan, dtype=np.float32)

    if valid_mask.any():  # Nothing to rescale when every cell is NaN
        valid_values = input_array[valid_mask]
        min_val = valid_values.min()
        max_val = valid_values.max()
        value_range = max_val - min_val

        if value_range > 0:
            # Normalize and round
            normalized_array[valid_mask] = np.round((valid_values - min_val) / value_range * 100, decimals=2)
        else:
            # Constant raster: avoid 0/0 and keep the valid cells distinguishable from NoData
            normalized_array[valid_mask] = 0

    # Save the normalized raster to output_tif with compression
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_tif, input_ds.RasterXSize, input_ds.RasterYSize, 1, gdal.GDT_Float32,
                           options=["COMPRESS=DEFLATE", "PREDICTOR=2", "ZLEVEL=9", "BIGTIFF=YES"])
    if out_ds is None:
        raise RuntimeError(f"Could not create output raster: {output_tif}")

    out_ds.SetGeoTransform(input_ds.GetGeoTransform())
    out_ds.SetProjection(input_ds.GetProjection())
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(normalized_array)
    out_band.SetNoDataValue(np.nan)
    out_ds.FlushCache()

    # Clean up: dropping the references closes the datasets
    out_band = None
    input_ds, out_ds = None, None
    print(f"Normalized raster saved successfully to: {output_tif}")

    # Report the time taken
    report_time(step_start_time, "Normalizing raster")
    print("="*40)

# Main function to run all steps
def main():
    """
    Run the NPCi workflow steps in order and report the total elapsed time.

    Each numbered step below is a single call; steps can be commented out to re-run
    only part of the workflow. All paths used here are module-level settings.
    """
    # Imported here because shutil does not appear in the module-level import block;
    # without it the final clean-up step would fail with a NameError.
    import shutil

    print("Starting the process...")
    start_time = time.time()

    buffer_npci = 500

    # 1. Split MSPA and Grassland rasters into tiles
    create_tiles(MSPA_tif, extGL_tif, MSPA_tiles_path, extGL_tiles_path, y2=2, buffer=buffer_npci)

    # 2. Reclassify MSPA tiles into Core, Edge, and Linear categories
    recla_MSPA(MSPA_tiles_path, recMSPA_tiles_path, start_feature=0)

    # 3. Combine reclassified MSPA tiles with Grassland tiles to create SNH tiles
    SNH_tiles(recMSPA_tiles_path, extGL_tiles_path, SNH_tiles_path, start_feature=0)

    # 4. Generate a focal kernel and calculate NPCi
    #generate_focal_kernel(u=200, v=25, exp=-13, save_folder=os.path.join(project_path, "Output_Data") )
    calculate_NPCi(SNH_tiles_path, NPCi_tiles_path, project_path, 50, buffer_npci)

    # 5. Create a binary CORINE raster
    binary_CORINE(CORINE_tif, corine_bin_tif, output_path, target_resolution=50)

    # 6. Merge NPCi tiles into a single raster
    merge_strips(NPCi_tiles_path, NPCi_tif)

    # 7. Mask NPCi raster using the CORINE binary layer.
    #    mask_NPCi resamples the CORINE binary raster onto the NPCi grid itself, so a
    #    separate resample_raster call beforehand would only repeat the same warp.
    mask_NPCi(corine_bin_tif, NPCi_tif, NPCimasked_tif)

    # 8. Normalize NPCi values to a range of 0-100
    norm_NPCi(NPCimasked_tif, NPCinorm_tif)

    # 9. Delete temporary folder
    shutil.rmtree(os.path.join(output_path, "Temp_recMSPAs"), ignore_errors=True)

    # Report total elapsed time
    print("="*40)  # A line of "=" for separation
    report_time(start_time, "Script completed ")
    print("="*40)  # A line of "=" for separation
    
    
# Run the workflow only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
