# requirements:
# pip install psutil
# conda install -c plotly plotly-orca

import pandas as pd
import plotly.graph_objects as go
import numpy as np
import textwrap

# Read in CSV files and set the 'original_name' column as the index for both DataFrames
df_fr = pd.read_csv('../table/CropDeepTrans_FigS1_Sankey_FR.csv').set_index('original_name')
df_nl = pd.read_csv('../table/CropDeepTrans_FigS1_Sankey_NL.csv').set_index('original_name')

# Number of codes to display per country: an integer N keeps the N rows with the
# largest 'area' and groups every other row under a single catch-all name;
# any non-integer value (e.g. 'All') keeps every code.
# Num_code_to_display = 'All'
Num_code_to_display = 40

# Add a 'new_name' column to each DataFrame and assign it the same values as the index
df_nl['new_name'] = df_nl.index
df_fr['new_name'] = df_fr.index

if isinstance(Num_code_to_display, int):
    # Rows ranked after the N largest areas are renamed to a catch-all name
    # ('Anderen' for BRP, 'Autre' for RPG).
    # The assignment goes through a single `.loc[rows, column]` call: the chained
    # form `df.new_name.loc[rows] = ...` writes through an intermediate Series,
    # which pandas warns about and which leaves the frame untouched when
    # Copy-on-Write is enabled.
    df_nl = df_nl.reset_index()
    nl_minor_rows = df_nl.sort_values('area', ascending=False).iloc[Num_code_to_display:].index
    df_nl.loc[nl_minor_rows, 'new_name'] = 'Anderen'
    df_nl = df_nl.set_index('new_name')

    df_fr = df_fr.reset_index()
    fr_minor_rows = df_fr.sort_values('area', ascending=False).iloc[Num_code_to_display:].index
    df_fr.loc[fr_minor_rows, 'new_name'] = 'Autre'
    df_fr = df_fr.set_index('new_name')

    # The N largest rows of each country define the displayed nodes
    fr_top = df_fr.sort_values('area', ascending=False).iloc[:Num_code_to_display]
    nl_top = df_nl.sort_values('area', ascending=False).iloc[:Num_code_to_display]

    # Sorted union of the HCAT2 codes reached by the displayed rows of either country
    List_HCAT2 = list(np.unique(
        fr_top.HCAT2_code.unique().tolist()
        + nl_top.HCAT2_code.unique().tolist()
    ))
    rpg = fr_top.index.unique().tolist()
    brp = nl_top.index.unique().tolist()
else:
    # Display all codes: every row keeps its original name
    df_nl = df_nl.set_index('new_name')
    df_fr = df_fr.set_index('new_name')
    List_HCAT2 = list(np.unique(
        df_fr.HCAT2_code.unique().tolist()
        + df_nl.HCAT2_code.unique().tolist()
    ))
    rpg = df_fr.index.unique().tolist()
    brp = df_nl.index.unique().tolist()

# Create a list of labels for the nodes in the Sankey diagram:
# RPG names, then BRP names, then HCAT2 codes, then four special nodes
label = rpg + brp + List_HCAT2 + ['Missing RPG code', 'Missing BRP code', 'Autre', 'Anderen']


def _area_by_name(df, hcat2_code):
    """Return the summed 'area' per 'new_name' for the rows of *df* with *hcat2_code*.

    Only the 'area' column is aggregated, so other (possibly non-numeric)
    columns of *df* are never summed. The result is empty when no row of
    *df* carries *hcat2_code*.
    """
    rows = df.loc[df.HCAT2_code == hcat2_code]
    return rows.groupby(level='new_name')['area'].sum()


# Position of every node inside `label`, looked up per section. Searching the
# whole list with `label.index(name)` returns the first match, so a name that
# exists on both the RPG and BRP side (or equals an HCAT2 code) would be
# attached to the wrong node.
rpg_node = {name: pos for pos, name in enumerate(rpg)}
brp_node = {name: len(rpg) + pos for pos, name in enumerate(brp)}
hcat2_node = {code: len(rpg) + len(brp) + pos for pos, code in enumerate(List_HCAT2)}
first_special = len(rpg) + len(brp) + len(List_HCAT2)
missing_rpg_node, missing_brp_node, autre_node, anderen_node = range(first_special, first_special + 4)

# Parallel lists describing the links: link k goes from node source[k] to
# node target[k] and carries value[k]
source = []
target = []
value = []

# French side: one link per (RPG name -> HCAT2 code) pair
for h in List_HCAT2:
    fr_area = _area_by_name(df_fr, h)

    # No French row maps to this HCAT2 code: attach it to the
    # 'Missing RPG code' node with a very small value
    if fr_area.empty:
        source.append(missing_rpg_node)
        target.append(hcat2_node[h])
        value.append(0.0001)
        continue

    # Iterate over the grouped names (each exactly once) so that a name shared
    # by several rows, such as 'Autre', yields a single link carrying the
    # summed area instead of one full-sum link per row
    for c, area in fr_area.items():
        # Names that are not RPG nodes are routed to the 'Autre' node
        source.append(rpg_node.get(c, autre_node))
        target.append(hcat2_node[h])
        value.append(area)

# Dutch side: one link per (HCAT2 code -> BRP name) pair
for h in List_HCAT2:
    nl_area = _area_by_name(df_nl, h)

    # No Dutch row maps to this HCAT2 code: attach it to the
    # 'Missing BRP code' node with a very small value
    if nl_area.empty:
        source.append(hcat2_node[h])
        target.append(missing_brp_node)
        value.append(0.0001)
        continue

    for c, area in nl_area.items():
        source.append(hcat2_node[h])
        # Names that are not BRP nodes are routed to the 'Anderen' node
        target.append(brp_node.get(c, anderen_node))
        value.append(area)

# With an integer setting (i.e. not "All"), replace the two catch-all node
# names by a descriptive English caption
if isinstance(Num_code_to_display, int):
    for placeholder, scheme in (('Autre', 'RPG'), ('Anderen', 'BRP')):
        position = label.index(placeholder)
        label[position] = f'Other {scheme} codes (>{Num_code_to_display}ᵗʰ main ones)'

# Cap every label at 40 characters; longer text is truncated and ends with '...'
label = [
    textwrap.shorten(caption, width=40, placeholder='...')
    for caption in label
]

# Node appearance: spacing, thickness, outline, captions, fill colour and hover text
node_style = dict(
    pad=15,
    thickness=20,
    line=dict(color='black', width=0.5),
    label=label,
    color='blue',
    hovertemplate='%{value}%<extra></extra>',
)

# Link definition: source/target hold positions in `label`, value the link size
link_data = dict(
    source=source,
    target=target,
    value=value,
)

# Assemble the Sankey diagram from the node and link descriptions
fig = go.Figure(data=[go.Sankey(node=node_style, link=link_data)])

# Title, font size and dimensions; the height grows with the number of labels
fig.update_layout(
    title_text=f'FR / NL crops harmonization ({Num_code_to_display} crops selected)',
    font_size=10,
    width=900,
    height=len(label) * 7,
)

# Display Sankey diagram - may not work everywhere. better to save html first
# fig.show()

# Both exports share the same path stem, which embeds the display setting
output_stem = f'../data/sankey_{Num_code_to_display}_crops'

# Write Sankey diagram as HTML file
fig.write_html(output_stem + '.html')

# Write Sankey diagram as PNG file, rendered with scale=2
fig.write_image(output_stem + '.png', scale=2)
