# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import concurrent.futures
import glob
import importlib
import math
import os
import pickle
import re
import sqlite3
import sys
from collections import deque

import numpy as np
import pandas as pd
import plotly.express as px
from ipywidgets import IntSlider, interact

# import nsys_rep
# import nvtx
# import multinode_loader

# Number of nanoseconds in one millisecond; dividing a nanosecond value by
# this constant yields milliseconds.
k_nanoseconds2milliseconds = 1000000


# Column-name strings. Presumably these are the columns of an NVTX range
# table produced by a loader outside this file (TODO confirm); none of them
# is referenced by the functions visible in this file.
k_nvtx_start = "gpu_start"
k_nvtx_end = "gpu_end"
k_nvtx_duration = "gpu_duration"
k_nvtx_text = "text"


# Column names of the statistics tables. The plotting helpers below index
# their input frames with the same strings ("Min", "Q1", "Median", ...), and
# k_nvtx_stats_rank is used to pick the rank column.
k_nvtx_stats_count = "Count"
k_nvtx_stats_sum = "Sum"
k_nvtx_stats_min = "Min"
k_nvtx_stats_max = "Max"
k_nvtx_stats_mean = "Mean"
k_nvtx_stats_median = "Median"
# k_nvtx_stats_std = 'StdDev'
k_nvtx_stats_q1 = "Q1"
k_nvtx_stats_q3 = "Q3"
k_nvtx_stats_rank = "Rank"


# DTSP-14650: Code cleanup
def display_column_graph(
    figs,
    vis_df,
    columnName,
    title="",
    xaxis_title="Rank",
    yaxis_title="Time",
    legend_title="Legend",
):
    """Plot one column of ``vis_df`` against its rank column as a line graph.

    The column values are divided by ``k_nanoseconds2milliseconds`` before
    being plotted (assumes the column holds nanoseconds - TODO confirm).

    Args:
        figs: List that the created figure is appended to, or ``None``.
        vis_df: DataFrame containing the ``k_nvtx_stats_rank`` column and
            the column named by ``columnName``.
        columnName: Name of the column to plot; also appended to the title.
        title: Text placed in front of ``columnName`` in the figure title.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title of the legend.

    Returns:
        The value returned by ``display_graph``.
    """
    ranks_ds = vis_df[k_nvtx_stats_rank]
    # Keep the scaled values as a raw array: handed over inside a dict they
    # are paired with the ranks by position instead of being re-aligned by
    # index label.
    scaled_values = vis_df[columnName].values / k_nanoseconds2milliseconds

    # The previous code called ``__display_graph``, a name that is not
    # defined in this module; ``display_graph`` has the matching signature.
    return display_graph(
        figs,
        ranks_ds,  # x_axis
        {"": scaled_values},  # y_axes
        title=" ".join([title, columnName]),
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        legend_title=legend_title,
    )


def display_pace_graph(
    figs,
    fileDir,
    tableName,  # nsys_rep.k_table_nvtx
    selectRowByColumnName,  # k_nvtx_text
    selectRowByColumnValue,  # ex: a particular NVTX range name like ncclAllReduce
    paceColumnName,  # ex: start, end, gpu_start, gpu_end
    wall_adjust=True,
    title=None,
    xaxis_title="Ranks",
    yaxis_title="Time",
    legend_title="Steps",
):
    """Plot, per entry of ``fileDir``, the pace column of the selected rows.

    For every entry in ``fileDir`` the table ``tableName`` is filtered to the
    rows where ``paceColumnName`` is not null and ``selectRowByColumnName``
    equals ``selectRowByColumnValue``. The remaining ``paceColumnName``
    values form one row of the plotted frame.

    NOTE(review): a second function named ``display_pace_graph`` is defined
    later in this file, so after the module has loaded the name refers to
    that later definition, not this one. This function also reads
    ``nsys_rep``, whose import is commented out at the top of the file.

    Args:
        figs: List that the created figure is appended to, or ``None``.
        fileDir: Iterable of mappings from table name to DataFrame.
        tableName: Key of the table to read from each entry.
        selectRowByColumnName: Column used to select rows.
        selectRowByColumnValue: Value the selection column must equal.
        paceColumnName: Column whose values are plotted.
        wall_adjust: When true, each entry's values are shifted by the
            difference between its session start (``utcEpochNs`` in row 0 of
            the session-start table) and the earliest session start.
        title: Figure title; defaults to
            ``"Pace of <selectRowByColumnValue> <paceColumnName>"``.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title of the legend.

    Returns:
        The created figure.
    """
    import warnings

    session_start_min = sys.maxsize
    if wall_adjust:
        # Earliest session start over all entries; used as the common origin.
        session_start_min = min(
            (
                fileData[nsys_rep.k_table_session_start].at[0, "utcEpochNs"]
                for fileData in fileDir
            ),
            default=sys.maxsize,
        )

    pace_ds_list = []
    for fileData in fileDir:
        table_df = fileData[tableName]
        selected_df = table_df.loc[
            table_df[paceColumnName].notna()
            & (table_df[selectRowByColumnName] == selectRowByColumnValue)
        ]

        pace_values = selected_df[paceColumnName]
        if wall_adjust:
            session_df = fileData[nsys_rep.k_table_session_start]
            session_offset = session_df.at[0, "utcEpochNs"] - session_start_min
            # Shift a derived Series rather than writing new columns into the
            # filtered slice, which pandas flags as a chained assignment.
            pace_values = pace_values + session_offset

        pace_ds_list.append(pace_values.values)
    pace_df = pd.DataFrame(pace_ds_list)

    # Silence warnings only while plotting, not for the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Relies on ``DataFrame.plot.line`` returning an object that offers
        # update_layout()/show() - presumably the plotly plotting backend is
        # selected elsewhere (TODO confirm).
        fig = pace_df.plot.line(orientation="v")
        fig.update_layout(
            xaxis_title=xaxis_title,
            yaxis_title=yaxis_title,
            legend_title=legend_title,
            title=(
                title
                if title is not None
                else " ".join(["Pace of", selectRowByColumnValue, paceColumnName])
            ),
        )
        fig.show()

    if figs is not None:
        figs.append(fig)

    display("Each line represents how long it took a rank to reach this point in time.")
    return fig


def display_boxplots_grouped(
    figs,
    stats_groups,
    orientation="v",
    title=None,
    xaxis_title="Names",
    yaxis_title="Time",
    legend_title="Legend",
):
    """Draw one box per group of a grouped statistics table.

    Each group is collapsed to a single box: the mean of its ``Mean``
    values, the minimum of ``Min`` and of ``Q1``, the maximum of ``Q3`` and
    of ``Max``, and the median of ``Median``.

    Args:
        figs: List that the created figure is appended to, or ``None``.
        stats_groups: Grouped frame (exposes ``.groups`` and column selection)
            with ``Mean``, ``Min``, ``Q1``, ``Q3``, ``Max`` and ``Median``.
        orientation: Orientation forwarded to ``display_boxplot``.
        title: Figure title.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title of the legend.

    Returns:
        The figure returned by ``display_boxplot``.
    """
    # if we wanted outliers
    # The lower fence is at x = Q1 - 1.5 * IQR.
    # The upper fence is at x = Q3 + 1.5 * IQR.
    # The IQR is the interquartile range: IQR = Q3 - Q1.
    # Since the IQR is the length of the box in the boxplot,
    # outliers are data that is more than 1.5 boxlengths
    # from the boxplot box.
    mean_ds = stats_groups["Mean"].mean()
    min_ds = stats_groups["Min"].min()
    q1_ds = stats_groups["Q1"].min()
    q3_ds = stats_groups["Q3"].max()
    max_ds = stats_groups["Max"].max()
    median_ds = stats_groups["Median"].median()
    index = list(stats_groups.groups.keys())

    return display_boxplot(
        figs,
        index,
        min_ds,
        q1_ds,
        median_ds,
        q3_ds,
        max_ds,
        mean_ds=mean_ds,
        orientation=orientation,
        title=title,
        xaxis_title=xaxis_title,
        # Previously accepted but never forwarded, so caller-supplied values
        # were silently ignored.
        yaxis_title=yaxis_title,
        legend_title=legend_title,
    )


def display_boxplots_df(
    figs,
    stats_df,
    orientation="v",
    title=None,
    xaxis_title="Names",
    yaxis_title="Time",
    legend_title="Legend",
):
    """Draw one box per row of a statistics frame.

    Column names are resolved with fallbacks so several table flavours are
    accepted: ``Mean``/``mean``, ``Min``/``min``, ``Max``/``max``, and for
    the quartiles ``Q1``/``Q1 (approx)``/``25%``, ``Median``/
    ``Median (approx)``/``50%`` and ``Q3``/``Q3 (approx)``/``75%``. The
    percent-named quartile columns are mandatory last resorts (a missing one
    raises ``KeyError``); mean, min and max may resolve to ``None``.

    Args:
        figs: List that the created figure is appended to, or ``None``.
        stats_df: Frame of statistics; its index labels the boxes.
        orientation: Orientation forwarded to ``display_boxplot``.
        title: Figure title.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title of the legend.

    Returns:
        The figure returned by ``display_boxplot``.
    """
    # Outlier fences, should they ever be wanted:
    #   lower fence = Q1 - 1.5 * IQR, upper fence = Q3 + 1.5 * IQR,
    #   where IQR = Q3 - Q1 is the length of the box.

    def first_present(*names):
        # First column that exists under one of the given names, else None.
        for name in names:
            column = stats_df.get(name, None)
            if column is not None:
                return column
        return None

    def quartile(*names, describe_label):
        # Like first_present, but falls back to a required column.
        column = first_present(*names)
        return stats_df[describe_label] if column is None else column

    mean_ds = first_present("Mean", "mean")
    min_ds = first_present("Min", "min")
    max_ds = first_present("Max", "max")
    q1_ds = quartile("Q1", "Q1 (approx)", describe_label="25%")
    median_ds = quartile("Median", "Median (approx)", describe_label="50%")
    q3_ds = quartile("Q3", "Q3 (approx)", describe_label="75%")

    return display_boxplot(
        figs,
        stats_df.index,
        min_ds,
        q1_ds,
        median_ds,
        q3_ds,
        max_ds,
        mean_ds=mean_ds,
        orientation=orientation,
        title=title,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        legend_title=legend_title,
    )


def display_boxplot_and_graph(
    figs,
    ranks_ds,
    vis_df,
    orientation="v",
    title=None,
    xaxis_title=None,
    yaxis_title="Time",
    legend_title="Legend",
):
    """Show a box plot of the full distribution and a line graph of Q1..Q3.

    Args:
        figs: List that both created figures are appended to, or ``None``.
        ranks_ds: Values used for the x axis of both plots.
        vis_df: Frame with ``Min``, ``Q1``, ``Median``, ``Q3``, ``Max`` and
            ``Mean`` columns.
        orientation: Orientation of the box plot.
        title: Common title prefix. When ``None`` the two figures are titled
            with just their suffix ("Full Distribution" and
            "50% of Distribution").
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title of the legend.

    Returns:
        List holding the figures appended by the two plotting calls.
    """

    def titled(suffix):
        # ``title`` defaults to None; concatenating it directly raised
        # TypeError whenever the caller relied on that default.
        return suffix if title is None else f"{title} - {suffix}"

    result_figs = []
    display_boxplot(
        result_figs,
        ranks_ds,
        vis_df["Min"],
        vis_df["Q1"],
        vis_df["Median"],
        vis_df["Q3"],
        vis_df["Max"],
        mean_ds=vis_df["Mean"],
        orientation=orientation,
        title=titled("Full Distribution"),
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        legend_title=legend_title,
    )

    display_graph(
        result_figs,
        ranks_ds,
        vis_df[["Q1", "Median", "Q3"]],
        title=titled("50% of Distribution"),
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        legend_title=legend_title,
    )

    if figs is not None:
        figs.extend(result_figs)

    return result_figs


def display_boxplot(
    figs,
    x_axis,
    min_ds,
    q1_ds,
    median_ds,
    q3_ds,
    max_ds,
    mean_ds=None,
    orientation="v",
    title=None,
    xaxis_title=None,
    yaxis_title="Time",
    legend_title="Legend",
):
    """Show a plotly box plot built from precomputed statistics.

    A single ``go.Box`` trace is created in which ``min_ds`` and ``max_ds``
    are used as the lower and upper fences and ``q1_ds``, ``median_ds``,
    ``q3_ds`` and ``mean_ds`` as the corresponding box statistics.

    Args:
        figs: List that the created figure is appended to, or ``None``.
        x_axis: Values passed as the trace's ``x``.
        min_ds: Lower-fence values.
        q1_ds: First-quartile values.
        median_ds: Median values.
        q3_ds: Third-quartile values.
        max_ds: Upper-fence values.
        mean_ds: Optional mean values.
        orientation: Trace orientation applied via ``update_traces``.
            NOTE(review): ``x_axis`` is passed as ``x`` regardless of the
            orientation - confirm this is what is wanted for ``"h"``.
        title: Figure title.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title of the legend.

    Returns:
        The created figure.
    """
    import plotly.graph_objects as go

    fig = go.Figure()
    fig.add_trace(
        go.Box(
            x=x_axis,
            lowerfence=min_ds,
            q1=q1_ds,
            median=median_ds,
            q3=q3_ds,
            upperfence=max_ds,
            mean=mean_ds,
        )
    )
    fig.update_traces(orientation=orientation)
    fig.update_layout(
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        legend_title=legend_title,
        title=title,
        height=800,
    )
    fig.show()

    # Identity test: ``!= None`` dispatches to the object's ``__ne__``, which
    # is not a reliable None check for array-like containers.
    if figs is not None:
        figs.append(fig)

    return fig


def display_graph(
    figs,
    x_axis,
    y_axes,
    title=None,
    xaxis_title=None,
    yaxis_title=None,
    legend_title="Legend",
):
    """Show a line graph of ``y_axes`` over ``x_axis``.

    ``y_axes`` may be a DataFrame (re-indexed with ``set_index(x_axis)``), a
    dict (turned into a DataFrame indexed by ``x_axis``), a Series or a
    NumPy array (each plotted as a single unnamed line). Any other type is
    ignored and ``None`` is returned without plotting.

    NOTE(review): for Series and array input the values go through a pandas
    Series before ``index=x_axis`` is applied, so they are matched to
    ``x_axis`` by index label rather than by position - confirm intended.

    Relies on ``DataFrame.plot.line`` returning an object that offers
    ``update_layout()``/``show()`` - presumably the plotly plotting backend
    is selected elsewhere (TODO confirm).

    Args:
        figs: List that the created figure is appended to, or ``None``.
        x_axis: Index values for the x axis.
        y_axes: Data to plot (see above).
        title: Figure title.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title of the legend.

    Returns:
        The created figure, or ``None`` for an unsupported ``y_axes`` type.
    """
    if isinstance(y_axes, pd.DataFrame):
        data = y_axes.set_index(x_axis)
    elif isinstance(y_axes, dict):
        data = pd.DataFrame(y_axes, index=x_axis)
    elif isinstance(y_axes, pd.Series):
        # Was ``d.DataFrame``: an undefined name that raised NameError for
        # every Series input.
        data = pd.DataFrame({"": y_axes}, index=x_axis)
    elif isinstance(y_axes, np.ndarray):
        data = pd.DataFrame({"": pd.Series(y_axes)}, index=x_axis)
    else:
        # Unsupported input type: nothing is plotted.
        return None

    fig = data.plot.line()
    fig.update_layout(
        title=title,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        legend_title=legend_title,
    )

    fig.show()

    if figs is not None:
        figs.append(fig)

    return fig


def display_pace_graph(figs, pace_map_by_column, pace_column, start=1):
    """Plot the pace table stored under ``pace_column``.

    The table is looked up in ``pace_map_by_column``, cut down to the rows
    from index label ``start`` onwards (``.loc`` slicing) and handed to
    ``__display_pace_graph``. Nothing is returned.
    """
    __display_pace_graph(
        figs,
        pace_map_by_column[pace_column].loc[start:],
        pace_column,
    )


def display_pace_graph_delta_minus_median(figs, pace_map_by_column, start=1):
    """Plot the ``"delta"`` pace table after subtracting the per-row median.

    The ``Median`` column of ``pace_map_by_column["delta_stats"]`` is
    subtracted from every column of a copy of
    ``pace_map_by_column["delta"]``. The rows from index label ``start``
    onwards are then passed to ``__display_pace_graph``. Nothing is returned.
    """
    source_column = "delta"
    row_medians = pace_map_by_column["delta_stats"]["Median"]

    # Work on a copy so the table stored in the map stays untouched.
    centered_df = pace_map_by_column[source_column].copy()
    # Snapshot the (name, column) pairs first; columns are reassigned in the loop.
    for name, values in list(centered_df.items()):
        centered_df[name] = values - row_medians

    __display_pace_graph(
        figs,
        centered_df.loc[start:],
        "variance of " + source_column,
    )


def __display_pace_graph(figs, pace_df, pace_column):
    """Show two line graphs of ``pace_df``: as given and transposed.

    The first figure ("Progress") plots ``pace_df`` directly, the second
    ("Consistency") plots ``pace_df.T``. Both titles end with
    ``pace_column``.

    Relies on ``DataFrame.plot.line`` returning an object that offers
    ``update_layout()``/``show()`` - presumably the plotly plotting backend
    is selected elsewhere (TODO confirm).

    Args:
        figs: List that both figures are appended to, or ``None``.
        pace_df: DataFrame to plot.
        pace_column: Text appended to the figure titles.
    """
    import warnings

    views = (
        (pace_df, "Progress - Iterations defined by "),
        (pace_df.T, "Consistency - Iterations defined by "),
    )

    # Silence warnings only while plotting; the earlier global
    # ``filterwarnings("ignore")`` muted every later warning in the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for frame, title_prefix in views:
            fig = frame.plot.line()
            fig.update_layout(
                yaxis_title="Time",
                title=title_prefix + pace_column,
            )
            fig.show()
            # Same optional-list contract as the other helpers in this file.
            if figs is not None:
                figs.append(fig)


def __display_stats_per_rank_of_group(selected, rank_stats_gdf):
    """Display the table and plots for one group of a grouped stats frame.

    The group ``selected`` is fetched from ``rank_stats_gdf``, its old index
    is dropped and the ``k_nvtx_stats_rank`` column is installed as the new
    index (the column itself is kept). The frame is then displayed, followed
    by a box plot and a line graph of the ``Q1``, ``Median`` and ``Q3``
    columns.
    """
    group_df = rank_stats_gdf.get_group(selected).reset_index(drop=True)
    group_df = group_df.set_index(group_df[k_nvtx_stats_rank])

    display(group_df)

    # The figures are collected locally only; nothing is returned.
    collected = []
    display_boxplots_df(collected, group_df, xaxis_title="Ranks")

    quartile_columns = ["Q1", "Median", "Q3"]
    display_graph(
        collected,
        group_df.index,
        group_df[quartile_columns],
        title="50% of Distribution",
        xaxis_title="Ranks",
    )


def display_stats_per_rank_groups_combobox(rank_stats_gdf):
    """Show per-rank statistics for a group chosen from a dropdown.

    With more than one group an ipywidgets dropdown is wired to
    ``__display_stats_per_rank_of_group``. With exactly one group that
    function is called directly. With no groups nothing is shown.
    """
    from ipywidgets import Dropdown, fixed, interact

    group_names = list(rank_stats_gdf.groups.keys())

    if not group_names:
        return

    if len(group_names) == 1:
        __display_stats_per_rank_of_group(group_names[0], rank_stats_gdf)
        return

    # The plot does not display unless the value changes, so the dropdown
    # starts on the second entry and is switched to the first one afterwards.
    selector = Dropdown(
        options=group_names,
        layout={"width": "max-content"},
        value=group_names[1],
    )
    interact(
        __display_stats_per_rank_of_group,
        selected=selector,
        rank_stats_gdf=fixed(rank_stats_gdf),
    )
    selector.value = group_names[0]


def display_top_n_per_rank(df, x, y, color, xaxis_title, yaxis_title, **layout_args):
    """Interactively show the top-N ``x`` groups by summed ``y``.

    An integer slider (1..10) selects N. For each N the ``x`` values with
    the N largest sums of ``y`` are picked. A pivot table (``x`` rows,
    ``color`` columns, aggregated with ``pivot_table``'s default) is
    displayed, and a grouped histogram of the matching rows is shown.

    Args:
        df: Input frame; its index is turned into columns first.
        x: Column that defines the groups (histogram x axis).
        y: Column that is summed to rank the groups (histogram y axis).
        color: Column used for the histogram colours and pivot columns.
        xaxis_title: Label of the x axis and of the pivot table's index.
        yaxis_title: Label of the y axis.
        **layout_args: Extra keyword arguments passed to ``update_layout``.
    """
    display_df = df.reset_index()

    # Only the summed ``y`` column is needed to rank the groups. Summing just
    # that column avoids aggregating unrelated (possibly non-numeric)
    # columns, and doing it once avoids repeating it on every slider move.
    totals_by_group = display_df.groupby(x)[y].sum()

    def display_histogram(n):
        # Labels of the n groups with the largest totals, largest first.
        top_groups = totals_by_group.nlargest(n).index

        top_df = display_df[display_df[x].isin(top_groups)]
        top_df = top_df.sort_values([color])

        table_df = top_df.pivot_table(index=x, columns=color, values=y)
        table_df = table_df.reindex(top_groups)
        table_df = table_df.rename_axis(xaxis_title)
        display(table_df)

        fig = px.histogram(top_df, x=x, y=y, color=color, barmode="group")

        fig.update_layout(**layout_args)
        fig.update_xaxes(title=xaxis_title, categoryorder="total descending")
        fig.update_yaxes(title=yaxis_title)

        fig.update_traces(
            hovertemplate=f"{xaxis_title}: %{{x}}<br>{yaxis_title}: %{{y}}<extra></extra>"
        )

        fig.show()

    num_group = display_df[x].nunique()
    slider = IntSlider(
        description="N", min=1, max=10, step=1, value=min(5, num_group)
    )
    interact(display_histogram, n=slider)
