Source code for cogrecon.core.visualization.vis_iposition

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
import logging

if __name__ == "__main__":
    from cogrecon.core.tools import lerp
    from cogrecon.core.cogrecon_globals import default_animation_duration, default_animation_ticks, \
        default_visualization_transformed_points_color, default_visualization_transformed_points_alpha, \
        default_visualization_actual_points_color, default_visualization_data_points_color, \
        default_visualization_actual_points_size, default_visualization_data_points_size, \
        default_visualization_font_size, default_visualization_accuracies_corrected_alpha, \
        default_visualization_accuracies_incorrect_color, default_visualization_accuracies_correct_color, \
        default_visualization_accuracies_uncorrected_color, default_visualization_accuracies_uncorrected_alpha
else:
    from ..tools import lerp
    from ..cogrecon_globals import default_animation_duration, default_animation_ticks, \
        default_visualization_transformed_points_color, default_visualization_transformed_points_alpha, \
        default_visualization_actual_points_color, default_visualization_data_points_color, \
        default_visualization_actual_points_size, default_visualization_data_points_size, \
        default_visualization_font_size, \
        default_visualization_accuracies_corrected_alpha, default_visualization_accuracies_incorrect_color, \
        default_visualization_accuracies_correct_color, default_visualization_accuracies_uncorrected_color, \
        default_visualization_accuracies_uncorrected_alpha


# noinspection PyDefaultArgument
[docs]def visualization(trial_data, analysis_configuration, min_points, transformed_points, output_list, start_threshold, end_threshold, start_accuracy_map, end_accuracy_map, animation_duration=default_animation_duration, animation_ticks=default_animation_ticks, print_output=True, extent=None, fig_size=None, legend_args=None): """ This function visualizes TrialData, showing all the steps in the pipeline. :param end_accuracy_map: the accuracy map at the end of processing :param start_accuracy_map: the accuracy map at the beginning of processing :param end_threshold: the accuracy threshold at the end of processing :param start_threshold: the accuracy threshold at the beginning of processing :param trial_data: the TrialData to be visualized :param analysis_configuration: the AnalysisConfiguration to use to visualize (for accuracy visualization mainly) :param min_points: the points output from the deanonymization task :param transformed_points: the points output from the transformation task :param output_list: the final outputs produced by full_pipeline :param animation_duration: a time in seconds specifying the duration of the transform animation :param animation_ticks: the number of ticks (frame updates) which should occur throughout the animation :param print_output: if True, the output_list values will be printed in a user friendly form :param extent: the extents to plot in the data space :param fig_size: a tuple containing the size of the figure in inches :param legend_args: a list of arguments to be passed to the legend (default is None) """ from ..full_pipeline import get_header_labels actual_points = trial_data.actual_points data_points = trial_data.data_points # z_value = analysis_configuration.z_value debug_labels = analysis_configuration.debug_labels if print_output: for l, o in zip(get_header_labels(), output_list): print(l + ": " + str(o)) if len(actual_points[0]) == 1: logging.warning("the visualization method expects 2D points, but 1D was found. Appending 0s for 'y' for " "visualization.") for i in range(len(actual_points)): actual_points[i] = [actual_points[i][0], 0.] data_points[i] = [data_points[i][0], 0.] min_points[i] = [min_points[i][0], 0.] transformed_points[i] = [transformed_points[i][0], 0.] if len(actual_points[0]) != 2: logging.error("the visualization method expects 2D points, found {0}D".format(len(actual_points[0]))) return # Generate a figure with 3 scatter plots (actual points, data points, and transformed points) fig, ax = plt.subplots() plt.title(str(debug_labels)) ax.set_aspect('equal') labels = range(len(actual_points)) x = [float(v) for v in list(np.transpose(transformed_points)[0])] y = [float(v) for v in list(np.transpose(transformed_points)[1])] transformed_scatter = ax.scatter(x, y, c=default_visualization_transformed_points_color, alpha=default_visualization_transformed_points_alpha, label='Transformed Points') scat = ax.scatter(x, y, c=default_visualization_transformed_points_color, animated=True) actual_scatter = ax.scatter(np.transpose(actual_points)[0], np.transpose(actual_points)[1], c=default_visualization_actual_points_color, s=default_visualization_actual_points_size, label='Actual Points') data_scatter = ax.scatter(np.transpose(data_points)[0], np.transpose(data_points)[1], c=default_visualization_data_points_color, s=default_visualization_data_points_size, label='Data Points') # Label the stationary points (actual and data) for idx, xy in enumerate(zip(np.transpose(actual_points)[0], np.transpose(actual_points)[1])): ax.annotate(labels[idx], xy=xy, textcoords='data', fontsize=default_visualization_font_size) for idx, xy in enumerate(zip(np.transpose(data_points)[0], np.transpose(data_points)[1])): ax.annotate(labels[idx], xy=xy, textcoords='data', fontsize=default_visualization_font_size) # Generate a set of interpolated points to animate the transformation lerp_data = [[lerp(p1, p2, t) for p1, p2 in zip(min_points, transformed_points)] for t in np.linspace(0.0, 1.0, animation_ticks)] # Generate accuracy patches for acc, x, y in zip(start_accuracy_map, np.transpose(transformed_points)[0], np.transpose(transformed_points)[1]): ax.add_patch(plt.Circle((x, y), start_threshold, alpha=default_visualization_accuracies_uncorrected_alpha, color=default_visualization_accuracies_uncorrected_color)) for acc, x, y in zip(end_accuracy_map, np.transpose(min_points)[0], np.transpose(min_points)[1]): color = default_visualization_accuracies_incorrect_color if acc: color = default_visualization_accuracies_correct_color ax.add_patch(plt.Circle((x, y), end_threshold, alpha=default_visualization_accuracies_corrected_alpha, color=color)) # Generate legend transformed_threshold_incorrect_patch = mpatches.Patch(color=default_visualization_accuracies_incorrect_color, alpha=default_visualization_accuracies_corrected_alpha, label='Threshold, Incorrect') transformed_threshold_correct_patch = mpatches.Patch(color=default_visualization_accuracies_correct_color, alpha=default_visualization_accuracies_corrected_alpha, label='Threshold, Correct') pre_transformed_threshold = mpatches.Patch(color=default_visualization_accuracies_uncorrected_color, alpha=default_visualization_accuracies_uncorrected_alpha, label='Pre-Transformed Threshold') if legend_args is None: plt.legend(handles=[transformed_scatter, actual_scatter, data_scatter, transformed_threshold_incorrect_patch, transformed_threshold_correct_patch, pre_transformed_threshold]) else: plt.legend(handles=[transformed_scatter, actual_scatter, data_scatter, transformed_threshold_incorrect_patch, transformed_threshold_correct_patch, pre_transformed_threshold], *legend_args) # An update function which will set the animated scatter plot to the next interpolated points def update(_i): scat.set_offsets(lerp_data[_i % animation_ticks]) return scat, # Begin the animation/plot # noinspection PyUnusedLocal anim = animation.FuncAnimation(fig, update, interval=(float(animation_duration) / float(animation_ticks)) * 1000, blit=True) if extent is not None: assert isinstance(extent, list) and np.array(extent).shape == (2, 2) and \ all([isinstance(x, float) for x in np.array(extent).flatten().tolist()]), \ 'extent must be a 2 by 2 list of floating point values' axes = plt.gca() axes.set_xlim(extent[0]) axes.set_ylim(extent[1]) fig2 = plt.gcf() if fig_size is not None: fig2.set_size_inches(*fig_size) else: fig2.set_size_inches(15, 9) fig.show() plt.show()