"""
Functions to plot basic objects like tables and image comparisons.
"""

import difflib
import io
from hashlib import sha256
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import PIL.Image
import PIL.ImageDraw
from IPython.core.display import display
from ipywidgets import widgets
from matplotlib import cm, colors
from matplotlib import pyplot as plt

from aidkit_client._endpoints.models import ImageObjectDetectionBoundingBox

# Type variable for observations: constrained to either text (``str``) or a PIL image.
T = TypeVar("T", str, PIL.Image.Image)  # pylint: disable=C0103


def display_table(
    data: Dict[str, List[Union[str, float, int]]],
    header: Optional[Sequence[Union[str, float, int]]] = None,
    highlight_row_header: Optional[Dict[str, str]] = None,
    highlight_cells: Optional[Dict[str, Dict[int, str]]] = None,
    table_width: int = 540,
) -> widgets.GridBox:
    """
    Create a table widget from a dictionary.

    Every key of ``data`` becomes a row header (rendered in bold) and the
    associated list provides the cells of that row. Cells holding a dictionary
    are rendered as one ``key: value`` entry per line.

    :raises RuntimeError: If the header length and the data do not match.
    :param data: Dict containing the table values. The number of value columns
            is taken from the first row.
    :param header: Optional list of headers.
    :param highlight_row_header: Dict specifying which row headers to highlight in which color:
            access color as the following highlight_row_header[row_header].
    :param highlight_cells: Dict specifying which cells to highlight in which color:
            access color as the following highlight_cells[row_header][col].
    :param table_width: Width in pixels of the table. The columns size is
            allocated dynamically.
    :return: Widget containing the table.
    """
    row_header_width = 150

    # The first row defines the number of value columns; an empty table falls
    # back to the header so that it can still be rendered.
    if data:
        n_value_cols = len(next(iter(data.values())))
    else:
        n_value_cols = len(header) if header else 0

    # Validate before any widget is created.
    if header and len(header) != n_value_cols:
        raise RuntimeError("The length of the header has to match the data.")

    items = []
    if header:
        # Empty top-left corner above the row-header column.
        items.append(widgets.Label(""))
        items.extend(widgets.HTML(value=f"<b>{_header}</b>") for _header in header)

    for key, value_list in data.items():
        header_html = f"<b>{key}</b>"

        # Prefix the row header with a small colored swatch.
        if highlight_row_header and key in highlight_row_header:
            header_html = (
                f"<span style='background-color:{highlight_row_header[key]};"
                "'>\u00A0\u00A0\u00A0\u00A0</span> " + header_html
            )
        items.append(widgets.HTML(value=header_html))

        for i, value in enumerate(value_list):
            if isinstance(value, dict):
                value_to_print = _pretty_print_dict(value)
            else:
                value_to_print = str(value)

            if highlight_cells:
                highlight = highlight_cells.get(key, {}).get(i, None)
                if highlight:
                    value_to_print = (
                        f"<span style='background-color:{highlight};'>" + value_to_print + "</span>"
                    )
            items.append(widgets.HTML(value=value_to_print))

    if n_value_cols:
        # Share the remaining width equally between the value columns.
        value_col_width = (table_width - row_header_width) / n_value_cols
        grid_template_columns = (
            f"{row_header_width}px repeat({n_value_cols}, {value_col_width}px)"
        )
    else:
        # Only the row-header column exists; avoid dividing by zero.
        grid_template_columns = f"{row_header_width}px"

    return widgets.GridBox(
        items,
        layout=widgets.Layout(grid_template_columns=grid_template_columns),
    )


def _pretty_print_dict(dictionary: Dict) -> str:
    """
    Pretty print a dictionary.

    :param dictionary: Dictionary to print.
    :return: String with html representation.
    """
    value_to_print = ""
    for key, value in dictionary.items():
        if isinstance(value, list):
            value_to_print += f"{key}: {str(value)[1:-1]}<br>"
        else:
            value_to_print += f"{key}: {value}<br>"
    return value_to_print


def display_observation(
    observation: T,
    title: Optional[str] = None,
    caption: Optional[List[Tuple[str, str]]] = None,
    width: Optional[int] = 300,
) -> Union[widgets.Image, widgets.HTML, widgets.VBox]:
    """
    Create a widget displaying a single observation.

    Text observations are rendered as HTML, anything else is saved as PNG and
    shown in an image widget.

    :param observation: Observation to display.
    :param title: Title to be displayed above the observation.
    :param caption: Caption to be displayed below the observation, given as
        ``(key, value)`` pairs; each pair is rendered on its own line. When a
        pair is longer than 30 characters only the last 30 characters of the
        value are shown and the full value is available as a tooltip.
    :param width: Width of the image.
    :return: Widget containing the observation. The bare observation widget is
        returned when neither title nor caption is given, otherwise a vertical
        box holding title, observation and caption.
    """
    max_caption_length = 30

    if isinstance(observation, str):
        observation_widget = widgets.HTML(value=observation)
    else:
        with io.BytesIO() as buff:
            observation.save(buff, format="png")
            observation_widget = widgets.Image(value=buff.getvalue(), width=width)

    widget_list = [observation_widget]

    if title:
        widget_list.insert(0, widgets.HTML(value=title))

    if caption:
        caption_lines = []
        for key, value in caption:
            if len(key + value) > max_caption_length:
                shortened = value[-max_caption_length:]
                # Only signal truncation when part of the value was really cut off.
                if len(value) > max_caption_length:
                    shortened = "..." + shortened
                caption_lines.append(f'<span title="{value}"><b>{key}</b>: {shortened}</span>')
            else:
                caption_lines.append(f"<span><b>{key}</b>: {value}</span>")

        caption_text = "<center>" + "<br>".join(caption_lines) + "</center>"
        widget_list.append(widgets.HTML(value=caption_text))

    if len(widget_list) == 1:
        return widget_list[0]

    return widgets.VBox(widget_list)


def display_observation_difference(original: T, perturbed: T) -> widgets.VBox:
    """
    Create a widget displaying the difference of two observations of the same
    type depending on an interactive scalar, multiplying the difference.

    Text observations are delegated to the static difference view. For all
    other observations the absolute element-wise difference is plotted and can
    be amplified with a slider ranging from 1 to 20; amplified values are
    clipped to 255.

    :param original: Original observation.
    :param perturbed: Perturbed observation.
    :return: Widget displaying the diff.
    """
    if isinstance(original, str):
        return display_static_observation_difference(original, perturbed)

    # Subtract in a signed type: with unsigned integer arrays (e.g. 8-bit image
    # data) a negative difference would wrap around before ``np.abs`` is applied.
    original_array = np.asarray(original).astype(np.int32)
    perturbed_array = np.asarray(perturbed).astype(np.int32)
    diff = np.abs(perturbed_array - original_array)

    def _plot_img(scalar: int) -> None:
        """
        Plot the difference amplified by ``scalar``.

        :param scalar: Current value of the slider.
        """
        curr = np.clip(diff * scalar, 0, 255).astype(np.uint8)
        plt.imshow(curr, vmin=0, vmax=255)
        plt.title("Difference: Perturbation - Original")
        plt.axis("off")

    diff_slider = widgets.IntSlider(value=1, min=1, max=20, description="Scalar")
    interactive_function = widgets.interactive(_plot_img, scalar=diff_slider)
    # Render once so the plot is visible before the slider is moved.
    interactive_function.update()
    return widgets.VBox(children=interactive_function.children)


def display_static_observation_difference(original: T, perturbed: T) -> widgets.VBox:
    """
    Create a widget displaying the difference of two observations of the same
    type.

    For text, a character-wise diff is computed: the first line shows the
    original with removed characters struck through, the second line shows the
    perturbed text with added characters colored. For all other observations
    the absolute element-wise difference is rescaled so that its largest value
    becomes 255 and shown as an image.

    :param original: Original observation.
    :param perturbed: Perturbed observation.
    :return: Widget displaying the diff.
    """
    if isinstance(original, str):
        original_parts = []
        perturbed_parts = []
        for character_diff in difflib.ndiff(original, perturbed):
            marker, character = character_diff[0], character_diff[-1]
            if marker == " ":
                # Unchanged character: present in both texts.
                original_parts.append(character)
                perturbed_parts.append(character)
            elif marker == "-":
                original_parts.append(f'<FONT COLOR="#FF8000"><s>{character}</s></FONT>')
            elif marker == "+":
                perturbed_parts.append(f'<FONT COLOR="0040FF">{character}</FONT>')
        return widgets.VBox(
            [
                widgets.HTML(value="".join(original_parts)),
                widgets.HTML(value="".join(perturbed_parts)),
            ]
        )

    original = np.array(original)
    perturbed = np.array(perturbed)

    # Replicate a two-dimensional perturbed array over three channels when the
    # shapes differ.
    if original.shape != perturbed.shape and len(perturbed.shape) == 2:
        perturbed = np.stack((perturbed,) * 3, axis=-1)

    diff_array = np.abs(perturbed.astype(float) - original.astype(float))
    max_diff = np.max(diff_array)
    # Identical inputs give an all-zero difference; skip the rescaling instead
    # of dividing by zero and producing NaN pixels.
    if max_diff > 0:
        diff_array = diff_array * (255 / max_diff)
    diff = PIL.Image.fromarray(diff_array.astype("uint8"))

    with io.BytesIO() as buff:
        diff.save(buff, format="png")
        diff_bytes = buff.getvalue()

    return widgets.VBox(
        [
            widgets.HTML(value="<b>Difference: Perturbation - Original</b>"),
            widgets.Image(
                value=diff_bytes,
                width=300,
                height=400,
            ),
            widgets.HTML(
                value='<FONT COLOR="#949191">The pixel values are \
                 normalized to the range [0, 1] for visibility.</FONT>'
            ),
        ]
    )


def blended_images_widget(
    background: PIL.Image, foreground: PIL.Image, title: str, slider: widgets.FloatRangeSlider
) -> widgets.Output:
    """
    Build an output widget showing two images blended according to a slider.

    :param background: Image shown when the slider is set to 0.
    :param foreground: The image shown when the slider is set to 1.
    :param title: Title shown above the image.
    :param slider: Widget whose value controls how the images are blended.
    :return: Output widget displaying the image.
    """
    heading = f"<b>{title}</b>"

    def _render_blend(opacity: float) -> None:
        """
        Blend both images with the given opacity and display the result.

        :param opacity: Opacity of the foreground image.
        """
        bottom_layer = background.convert("RGBA")
        top_layer = foreground.convert("RGBA")
        mixed = PIL.Image.blend(bottom_layer, top_layer, alpha=opacity)

        blended_widget = display_observation(mixed, title=heading, width=250)
        display(blended_widget)

    # The dictionary key must match the callback's parameter name.
    return widgets.interactive_output(_render_blend, {"opacity": slider})


def display_semantic_segmentation_inference_argmax_widget(
    original: PIL.Image,
    perturbed: PIL.Image,
    original_prediction: PIL.Image,
    perturbed_prediction: PIL.Image,
    target_classes: List[dict],
) -> widgets.VBox:
    """
    Display a widget containing the prediction of a semantic segmentation
    model.

    The widget shows the original and the perturbed observation side by side,
    each blended with its prediction, followed by a legend of the target
    classes and a slider controlling the opacity of the predictions.

    :param original: Original observation.
    :param perturbed: The perturbed observation.
    :param original_prediction: An image representing the prediction of the model
        on the original image.
    :param perturbed_prediction: An image representing the prediction of the model
        on the perturbed image.
    :param target_classes: A list of dictionaries, each providing the ``name``
        and the ``color`` of one target class.
    :returns: A widget displaying the inference of the model.
    """
    # Every CSS rule must be closed; an unterminated rule invalidates the
    # rules that follow it.
    legend_style = """<style>
        .class_legend_element {
            background-color:#efefef;
            margin: 3px;
            padding-left: 5px;
            padding-right: 5px;
            float: left;
        }
        .class_legend {
            display: block;
        }
        </style>
        <div class="class_legend">
        """
    # One entry per class: a colored square followed by the class name.
    legend_entries = "".join(
        '<span class="class_legend_element">'
        f'<span style="color:{target_class["color"]}; font-size: 16px">&#9632; </span>'
        f'{target_class["name"]}</span>'
        for target_class in target_classes
    )
    target_class_legend_widget = widgets.HTML(value=legend_style + legend_entries + "</div>")

    class_prediction_opacity_slider = widgets.FloatSlider(
        value=0.9, min=0.0, max=1.0, step=0.1, continuous_update=False
    )

    class_prediction_opacity_slider_with_description = widgets.HBox(
        [widgets.Label("Labels Opacity:"), class_prediction_opacity_slider]
    )

    # Both views are driven by the same slider.
    original_with_prediction = blended_images_widget(
        original,
        original_prediction,
        "Prediction for original observation",
        class_prediction_opacity_slider,
    )
    perturbed_with_prediction = blended_images_widget(
        perturbed,
        perturbed_prediction,
        "Prediction for perturbed observation",
        class_prediction_opacity_slider,
    )

    prediction_argmax_widget = widgets.HBox([original_with_prediction, perturbed_with_prediction])

    return widgets.VBox(
        [
            prediction_argmax_widget,
            target_class_legend_widget,
            class_prediction_opacity_slider_with_description,
        ]
    )


def display_image_color_map_legend(
    min_val: float = 0, max_val: float = 0, width: int = 256, color_map: str = "viridis"
) -> widgets.Image:
    """
    Display the legend of an image given a color map.

    The legend is a horizontal gradient rendered with matplotlib and encoded
    as JPEG. Both bounds default to 0, so callers are expected to pass the
    range they want to show.

    :param min_val: The minimum value of the data shown.
    :param max_val: The maximum value of the data shown.
    :param width: The width of the legend in pixels.
    :param color_map: A valid string defining a matplotlib color map.
    :returns: An Image widget showing the color map for the given range.
    """
    # Figure sizes are given in inches; this is the size of one pixel in inches.
    pixels = 1 / plt.rcParams["figure.dpi"]
    gradient = np.linspace(min_val, max_val, 256)
    gradient = np.vstack((gradient, gradient))

    legend = plt.figure(figsize=[width * pixels, 10 * pixels])
    try:
        axes = legend.gca()
        axes.imshow(gradient, aspect="auto", cmap=color_map, extent=[min_val, max_val, 0, 1])
        axes.get_yaxis().set_visible(False)

        with io.BytesIO() as buff:
            legend.savefig(buff, format="jpg", bbox_inches="tight")
            image_bytes = buff.getvalue()
    finally:
        # Always release this specific figure, even when rendering fails.
        plt.close(legend)

    # Declare the encoding so the widget does not label the JPEG data as PNG.
    return widgets.Image(value=image_bytes, format="jpeg")


def display_class_color(class_id: int, n_classes: int, color_map: str) -> str:
    """
    Map a class ID to a color using the specified color map.

    The class ID is divided by ``n_classes - 1`` (at least 1) and the result is
    looked up in the color map.

    :param class_id: The ID of the class we want the color for.
    :param n_classes: The total number of classes.
    :param color_map: A valid string defining a matplotlib color map.
    :returns: A string containing the hexadecimal representation of a color.
    """
    # ``matplotlib.cm.get_cmap`` is no longer available in recent matplotlib
    # releases; ``pyplot.get_cmap`` provides the same lookup.
    cmap = plt.get_cmap(color_map)
    # Guard against a zero denominator when there is only a single class.
    max_class_id = max(n_classes - 1, 1)
    rgba = cmap(class_id / max_class_id)
    return colors.to_hex(rgba)


def get_inference_per_class_confidence(inference_result: np.ndarray) -> List[PIL.Image.Image]:
    """
    Turn an array of per pixel confidences into one image per class.

    :param inference_result: A numpy array with shape `[r,c,n]` where `r` is the number of rows,
        `c` the number of columns and `n` the number of classes.
    :returns: A list of size `n` of PIL Images corresponding to the per pixel confidence
        of every class, colored with the viridis color map.
    """
    # Move the class axis to the front so that iterating yields one plane per class.
    class_planes = inference_result.transpose(2, 0, 1)
    return [
        PIL.Image.fromarray(np.uint8(cm.viridis(class_plane, bytes=True)))
        for class_plane in class_planes
    ]


def get_inference_argmax_prediction(inference_result: np.ndarray) -> PIL.Image:
    """
    Build an image in which every pixel carries the color of its most
    confident class.

    :param inference_result: A numpy array with shape `[r,c,n]` where `r` is the number of rows,
        `c` the number of columns and `n` the number of classes.
    :returns: A PIL Image where every pixel has the color of the class for which it has the
        highest confidence.
    """
    # Largest class index, but never below 1 so the division stays defined.
    highest_class_id = max(inference_result.shape[2] - 1, 1)
    winning_class = inference_result.argmax(axis=2)
    # Scale the winning class index by the highest class index before the color lookup.
    color_position = winning_class / highest_class_id
    rgba_pixels = cm.turbo(color_position, bytes=True)
    return PIL.Image.fromarray(np.uint8(rgba_pixels))


def display_inference_per_class_widget(
    per_class_confidence_images: List,
    perturbed_per_class_confidence_images: List,
    target_classes_dropdown_options: List[Tuple[str, int]],
) -> widgets.VBox:
    """
    Show the per class confidence of a model for a selectable class.

    :param per_class_confidence_images: A list of images representing the inference per class.
    :param perturbed_per_class_confidence_images: A list of images representing the inference
        per class on the perturbed observation.
    :param target_classes_dropdown_options: A list of pairs of class names and associated
        class id.
    :returns: A widget containing the class selector, the inference of the selected class
        and a color legend.
    """
    class_selector = widgets.Dropdown(
        options=target_classes_dropdown_options, description="Target class: "
    )

    def _side_by_side(original_image, perturbed_image) -> widgets.HBox:
        """
        Place the confidence images of one class next to each other.

        :param original_image: Confidence image for the original observation.
        :param perturbed_image: Confidence image for the perturbed observation.
        :returns: Horizontal box holding both titled images.
        """
        left = display_observation(
            original_image, title="<b>Confidence for original image</b>", width=250
        )
        right = display_observation(
            perturbed_image, title="<b>Confidence for perturbed image</b>", width=250
        )
        return widgets.HBox([left, right])

    # One pre-built comparison per class, addressed by position.
    comparisons = [
        _side_by_side(original_image, perturbed_image)
        for original_image, perturbed_image in zip(
            per_class_confidence_images, perturbed_per_class_confidence_images
        )
    ]

    color_bar = display_image_color_map_legend(
        min_val=0.0, max_val=1.0, width=256, color_map="viridis"
    )
    labelled_color_bar = widgets.HBox([widgets.Label(value="Confidence:"), color_bar])

    def show_class_inference(selection: int) -> None:
        """
        Display the comparison belonging to the selected class.

        :param selection: ID of the class for which the inference widget is displayed.
        """
        display(comparisons[selection])

    # The dictionary key must match the callback's parameter name.
    selected_comparison = widgets.interactive_output(
        show_class_inference, {"selection": class_selector}
    )

    return widgets.VBox([class_selector, selected_comparison, labelled_color_bar])


def display_object_detection_box_count_widget(
    original_detection_output: List[ImageObjectDetectionBoundingBox],
    perturbed_detection_output: List[ImageObjectDetectionBoundingBox],
    class_names: List[str],
) -> widgets.GridBox:
    """
    Create a widget displaying the distribution of boxes in a table.

    The table has one row per class that occurs in at least one of the two
    detection outputs, ordered by class index, followed by a ``Sum`` row with
    the total number of boxes. Row headers are highlighted with a color
    derived from the class name.

    :param original_detection_output: Object detection inference of the original observation.
    :param perturbed_detection_output: Object detection inference of the perturbed observation.
    :param class_names: List of all class names.
    :return: Table widget.
    """
    original_counts: Dict[int, int] = {}
    for bbox in original_detection_output:
        original_counts[bbox.class_index] = original_counts.get(bbox.class_index, 0) + 1
    perturbed_counts: Dict[int, int] = {}
    for bbox in perturbed_detection_output:
        perturbed_counts[bbox.class_index] = perturbed_counts.get(bbox.class_index, 0) + 1

    box_table_counter: Dict[str, List[Union[str, float, int]]] = {}
    highlights: Dict[str, str] = {}
    # Sort the class indices so that the row order is deterministic.
    for class_index in sorted(set(original_counts) | set(perturbed_counts)):
        class_name = class_names[class_index]
        box_table_counter[class_name] = [
            original_counts.get(class_index, 0),
            perturbed_counts.get(class_index, 0),
        ]
        # Use the last six hex digits of the name's hash as a stable color.
        highlights[class_name] = f"#{sha256(class_name.encode()).hexdigest()[-6:]}"

    box_table_counter["Sum"] = [len(original_detection_output), len(perturbed_detection_output)]
    return display_table(
        box_table_counter, header=["Original", "Perturbed"], highlight_row_header=highlights
    )


def display_detection_observation(
    observation: PIL.Image.Image,
    observation_inference: List[ImageObjectDetectionBoundingBox],
    class_names: List[str],
    class_filter: str = "All Classes",
    title: Optional[str] = None,
    caption: Optional[str] = None,
) -> Union[widgets.VBox, widgets.Image]:
    """
    Create a widget displaying a single observation with its bounding boxes.

    Every box is outlined in a color derived from its class name. Box
    coordinates given as floats are interpreted relative to the image size and
    scaled to pixels; other coordinates are used as they are.

    :param observation: Observation to display.
    :param observation_inference: Inference of the passed observation.
    :param class_names: List of all class names.
    :param class_filter: Class name of the class to display,
                         or "All Classes".
    :param title: Title to display above the observation.
    :param caption: Caption to be displayed below the observation.
    :return: Widget containing the observation.
    """
    org = observation.copy()
    image_width, image_height = observation.size

    # Scale the outline with the image height but keep it at least one pixel
    # wide so that boxes stay visible on small images.
    outline_width = max(1, int(image_height / 100))

    for bbox in observation_inference:
        class_name = class_names[bbox.class_index]
        if class_filter not in ["All Classes", class_name]:
            continue

        # Draw every box on its own transparent layer and composite it onto
        # the image, so the transparent fill never erases earlier outlines.
        overlay = PIL.Image.new("RGBA", observation.size, (255, 255, 255, 0))
        draw = PIL.ImageDraw.Draw(overlay)

        x1, y1, x2, y2 = (  # pylint: disable=invalid-name
            bbox.min_x,
            bbox.min_y,
            bbox.max_x,
            bbox.max_y,
        )
        if isinstance(x1, float):
            # Float coordinates are fractions of the image size.
            x1 = int(x1 * image_width)  # pylint: disable=invalid-name
            x2 = int(x2 * image_width)  # pylint: disable=invalid-name
            y1 = int(y1 * image_height)  # pylint: disable=invalid-name
            y2 = int(y2 * image_height)  # pylint: disable=invalid-name

        draw.rectangle(
            ((x1, y1), (x2, y2)),
            fill=(255, 0, 0, 0),
            width=outline_width,
            outline=f"#{sha256(class_name.encode()).hexdigest()[-6:]}",
        )

        org = PIL.Image.alpha_composite(org.convert("RGBA"), overlay)

    with io.BytesIO() as buff:
        org.save(buff, format="png")
        image_bytes = buff.getvalue()

    observation_widget = widgets.Image(
        value=image_bytes,
        width=300,
        height=400,
    )

    widget_list = [observation_widget]

    if title:
        widget_list.insert(0, widgets.HTML(value=title))

    if caption:
        widget_list.append(widgets.HTML(value=caption))

    if len(widget_list) == 1:
        return observation_widget

    return widgets.VBox(widget_list)


def display_selection_class_detection_widget(
    observation_image: PIL.Image,
    adversarial_example_image: PIL.Image,
    observation_detection_output: List[ImageObjectDetectionBoundingBox],
    adversarial_example_detection_output: List[ImageObjectDetectionBoundingBox],
    class_names: List[str],
    observation_title: Optional[str] = None,
    observation_caption: Optional[str] = None,
    perturbation_title: Optional[str] = None,
    perturbation_caption: Optional[str] = None,
) -> widgets.VBox:
    """
    Create a widget displaying an object detection inference in addition to a
    selector filtering the output by class.

    The selector offers "All Classes" followed by the names of the classes
    that occur in at least one of the two detection outputs, ordered by class
    index.

    :param observation_image: Image of the original observation.
    :param adversarial_example_image: Image of the adversarial example.
    :param observation_detection_output: Inference of the original observation.
    :param adversarial_example_detection_output: Inference of the adversarial example.
    :param class_names: List of all class names.
    :param observation_title: Title to be displayed above the original observation.
    :param observation_caption: Caption to be displayed under the original observation.
    :param perturbation_title: Title to be displayed above the perturbed observation.
    :param perturbation_caption: Caption to be displayed under the perturbed observation.
    :return: Widget displaying both detection outputs with a selector to either
             display only boxes of the selected class or to display all boxes.
    """

    def _plot_img_box(selector: str) -> None:
        """
        Display both observations with the boxes of the selected class.

        :param selector: Currently selected class name or "All Classes".
        """
        display(
            widgets.VBox(
                [
                    widgets.HBox(
                        [
                            display_detection_observation(
                                observation_image,
                                observation_detection_output,
                                class_names=class_names,
                                class_filter=selector,
                                title=observation_title,
                                caption=observation_caption,
                            ),
                            display_detection_observation(
                                adversarial_example_image,
                                adversarial_example_detection_output,
                                class_names=class_names,
                                class_filter=selector,
                                title=perturbation_title,
                                caption=perturbation_caption,
                            ),
                        ]
                    )
                ]
            ),
        )

    # Sort the detected class indices so the dropdown order is deterministic.
    class_indices = sorted(
        {bbox.class_index for bbox in observation_detection_output}
        | {bbox.class_index for bbox in adversarial_example_detection_output}
    )
    filtered_class_names = [class_names[idx] for idx in class_indices]

    class_selection = widgets.Dropdown(
        options=["All Classes"] + filtered_class_names,
        value="All Classes",
        description="Filter Bounding Boxes:",
        disabled=False,
    )
    class_selection.style.description_width = "150px"

    interactive_function = widgets.interactive(_plot_img_box, selector=class_selection)
    # Render once so the images are visible before a class is selected.
    interactive_function.update()

    # Reverse the children so the rendered output comes before the selector.
    return widgets.VBox(children=interactive_function.children[::-1])
