# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from nsys_recipe.data_service import DataReader
from nsys_recipe.lib import exceptions


class ProfileInfo:
    """Gather per-report duration and session start information.

    The work is split into a mapper that reads one report and a reducer that
    combines the per-report results into three parallel lists.
    """

    @staticmethod
    def mapper_func(report_path):
        """Read the duration and session start time of a single report.

        Parameters
        ----------
        report_path : str
            Path to the report file handed to DataReader.

        Returns
        -------
        report_path : str or None
            The path that was passed in.
        max_duration : int or None
            Value of the "duration" column in the row labeled 0 of the
            ANALYSIS_DETAILS table.
        session_time : int or None
            Value of the "utcEpochNs" column in the row labeled 0 of the
            TARGET_INFO_SESSION_START_TIME table.

        All three values are None when the tables could not be read.
        """
        service = DataReader(report_path)

        table_column_dict = {
            "ANALYSIS_DETAILS": ["duration"],
            "TARGET_INFO_SESSION_START_TIME": ["utcEpochNs"],
        }

        df_dict = service.read_tables(table_column_dict)
        if df_dict is None:
            return None, None, None

        max_duration = df_dict["ANALYSIS_DETAILS"].at[0, "duration"]
        session_time = df_dict["TARGET_INFO_SESSION_START_TIME"].at[0, "utcEpochNs"]

        return report_path, max_duration, session_time

    @staticmethod
    def reducer_func(mapper_res, parsed_args):
        """Combine mapper results into lists ordered like the input paths.

        Parameters
        ----------
        mapper_res : iterable of tuple
            Results of mapper_func. (None, None, None) entries are dropped.
        parsed_args : argparse.Namespace
            Parsed arguments with the following fields:
            - input (required)
            - disable_alignment (optional)

        Returns
        -------
        report_paths : list of str
            Paths of the reports that produced data, in input order.
        max_durations : list of int
            Duration of each report plus its session offset.
        session_offsets : list of int
            Session start time of each report relative to the earliest
            session start time, or all zeros if alignment is disabled.

        Raises
        ------
        exceptions.NoDataError
            If no mapper result contains data.
        ValueError
            If a result refers to a path that is not in parsed_args.input.
        """
        filtered_mapper_res = [res for res in mapper_res if res != (None, None, None)]
        if not filtered_mapper_res:
            raise exceptions.NoDataError

        inputs = parsed_args.input

        # Record the first position of every input path once, so the sort key
        # is a dictionary lookup instead of a linear scan per result.
        input_order = {}
        for position, path in enumerate(inputs):
            input_order.setdefault(path, position)

        def input_position(res):
            position = input_order.get(res[0])
            # An unknown path falls back to list.index, which raises
            # ValueError exactly as a direct index lookup would.
            return inputs.index(res[0]) if position is None else position

        # Reorder the results based on the input order. The sort is stable, so
        # results sharing a position keep their relative order.
        filtered_mapper_res.sort(key=input_position)

        report_paths, max_durations, session_times = map(
            list, zip(*filtered_mapper_res)
        )

        if getattr(parsed_args, "disable_alignment", False):
            session_offsets = [0] * len(session_times)
        else:
            # The earliest start time is the same for every report; compute it
            # once rather than once per element.
            earliest_session_time = min(session_times)
            session_offsets = [
                session_time - earliest_session_time for session_time in session_times
            ]

        # Shift each duration by its offset so it is expressed relative to the
        # earliest session start.
        max_durations = [
            max_duration + session_offset
            for max_duration, session_offset in zip(max_durations, session_offsets)
        ]

        return report_paths, max_durations, session_offsets

    @staticmethod
    def get_profile_info(context, parsed_args):
        """Get profile information for each report file.

        Parameters
        ----------
        context : Context
            Context object.
        parsed_args : argparse.Namespace
            Parsed arguments with the following fields:
            - input (required)
            - disable_alignment (optional)

        Returns
        -------
        report_paths : list of str
            List of paths to report files.
        max_durations : list of int
            Duration of each report file plus its session offset.
        session_offsets : list of int
            Session offsets for each report file.
        """
        mapper_res = context.wait(
            context.map(ProfileInfo.mapper_func, parsed_args.input)
        )
        return ProfileInfo.reducer_func(mapper_res, parsed_args)
