"""cogrecon.core.visualization.vis_time_travel_task: visualizations of Time Travel Task data files."""

import logging
import os
import easygui
import math
import sys

import numpy as np
import pyqtgraph as pg
import pyqtgraph.opengl as gl

from pyqtgraph.Qt import QtCore, QtGui
from scipy.misc import imread

import matplotlib as mpl
import matplotlib.pyplot as plt

if __name__ == '__main__':
    from cogrecon.core.data_flexing.time_travel_task.time_travel_task_binary_reader import parse_test_items, \
        get_item_details, get_click_locations_and_indicies, get_items_solutions, phase_num_to_str, \
        read_binary_file, get_filename_meta_data, find_data_files_in_directory
else:
    from ..data_flexing.time_travel_task.time_travel_task_binary_reader import parse_test_items, get_item_details, \
        get_click_locations_and_indicies, get_items_solutions, phase_num_to_str, read_binary_file, \
        get_filename_meta_data, find_data_files_in_directory


# TODO: Fix the PyQt globals issue (the animation state should not be kept in module-level globals).
def _get_ground_truth_items(is_practice, inverse):
    """
    Build the hard-coded ground-truth item table drawn by visualize_time_travel_data.

    :param is_practice: True for the practice phases ('0', '3', '6'), which use 4 timed items and 1 static item;
                        otherwise 8 timed items and 2 static items are returned.
    :param inverse: True when meta['inverse'] == '1'; the time and direction sequences are then reversed before
                    being assigned to the items.
    :return: a list of dicts with keys 'direction' (Fall = 2, Fly = 1, Stay = 0), 'pos' (x, z, time) and
             'color' (0-255 RGB). Static items always have direction 0, time 0 and color (128, 0, 128).
    """
    if is_practice:
        times = [2, 12, 18, 25]
        directions = [2, 1, 2, 1]  # Fall = 2, Fly = 1, Stay = 0
        locations = [(2, -12), (2, 13), (-13, 2), (-12, -17)]
        colors = [(255, 255, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
        static_locations = [(13, 5)]
    else:
        times = [4, 10, 16, 25, 34, 40, 46, 51]
        directions = [2, 1, 1, 2, 2, 1, 2, 1]  # Fall = 2, Fly = 1, Stay = 0
        locations = [(18, -13), (-13, 9), (-10, -2), (6, -2), (17, -8), (-2, -7), (-15, -15), (6, 18)]
        colors = [(255, 255, 0), (255, 255, 0), (255, 0, 0), (255, 0, 0),
                  (0, 255, 0), (0, 255, 0), (0, 0, 255), (0, 0, 255)]
        static_locations = [(14, 6), (-2, 10)]

    if inverse:
        times.reverse()
        directions.reverse()

    items = [{'direction': direction, 'pos': (x, z, t), 'color': rgb}
             for direction, (x, z), t, rgb in zip(directions, locations, times, colors)]
    items.extend({'direction': 0, 'pos': (x, z, 0), 'color': (128, 0, 128)} for x, z in static_locations)
    return items


def _item_line_end(start, direction, end_time):
    """
    Return the far end of an item's vertical line.

    :param start: the item's (x, z, time) position
    :param direction: 1 ends the line at time 0; 0 or 2 end it at end_time; any other value gives a zero-length
                      line ending at start
    :param end_time: the top of the timeline for the current phase
    :return: an (x, z, time) tuple
    """
    if direction == 1:
        return start[0], start[1], 0
    if direction == 2 or direction == 0:
        return start[0], start[1], end_time
    return start[0], start[1], start[2]


def _make_item_line(start, end, rgb):
    """Create a 3 px wide, antialiased GLLinePlotItem from start to end in the given 0-255 RGB color."""
    return gl.GLLinePlotItem(pos=np.array([[start[0], start[1], start[2]], [end[0], end[1], end[2]]]),
                             color=pg.glColor(rgb), width=3, antialias=True)


def _rgb_to_gl_rgba(rgb):
    """Convert a 0-255 RGB triple to an opaque 0.0-1.0 RGBA tuple (float division is also correct on Python 2)."""
    return rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0, 1.0


def visualize_time_travel_data(path=None, automatically_rotate=True):
    """
    This function visualizes data from the Time Travel Task in 3D.

    Keyboard controls: '+'/'-' change the playback speed by 5 points per tick, space pauses/unpauses, 'R' resets to
    time zero, 'E' jumps to the end, Escape closes the window, and '1'-'7' toggle the grid, base image, timeline
    color bars, ground-truth items, path line (with clicks), reconstruction items and item label images.

    :param path: the path to a data file to visualize (a file dialog is shown when None)
    :param automatically_rotate: If True, the figure will automatically rotate in an Abs(Sin(x)) function shape,
                                 otherwise the user can interact with the figure.
    :return: Nothing. Exits the process (SystemExit) when no file is selected, the file does not exist, its name
             cannot be parsed, or it contains no iterations.
    """
    # update() reads these names, so they must remain module-level globals.
    # noinspection PyGlobalUndefined
    global path_line, idx, timer, iterations, click_scatter, click_pos, click_color, click_size, window, meta, \
        reconstruction_items, num_points_to_update, line_color, line_color_state, auto_rotate

    auto_rotate = automatically_rotate

    ####################################################################################################################
    # Setup
    ####################################################################################################################

    local_directory = os.path.dirname(os.path.realpath(__file__))  # The directory of this script

    if path is None:
        path = easygui.fileopenbox()
    # easygui returns None when the dialog is cancelled; an empty string is treated the same way.
    if not path:
        logging.info('No file selected. Closing.')
        sys.exit()
    if not os.path.exists(path):
        logging.error('File not found. Closing.')
        sys.exit()

    try:
        meta = get_filename_meta_data(os.path.basename(path))  # The meta filename information for convenience
    except Exception:
        logging.error('There was an error reading the filename meta-information. Please confirm this is a valid log '
                      'file.')
        sys.exit()

    phase = meta['phase']
    is_practice = phase in ('0', '3', '6')
    is_test = phase in ('2', '5', '8')
    inverse = meta['inverse'] == '1'

    logging.info("Parsing file (" + str(path) + ")...")
    iterations = read_binary_file(path)
    logging.info("Plotting " + str(len(iterations)) + " iterations.")
    if len(iterations) == 0:
        # update() indexes the per-iteration arrays, so an empty file cannot be animated.
        logging.error('The file contains no iterations to plot. Closing.')
        sys.exit()

    # Reuse an existing QApplication (e.g. on a second call) instead of constructing a second one.
    app = QtGui.QApplication.instance()
    if app is None:
        app = QtGui.QApplication([])
    window = gl.GLViewWidget()
    window.opts['center'] = pg.Qt.QtGui.QVector3D(0, 0, 30)
    window.opts['distance'] = 200
    window.setWindowTitle('Timeline Visualizer - Subject {0}, Trial {1}, Phase {2}'.format(
        meta['subID'], meta['trial'], phase_num_to_str(int(meta['phase']))))

    ####################################################################################################################
    # Generate static graphical items
    ####################################################################################################################

    # Grid walls: practice phases have one time section, other phases two stacked sections, plus a floor grid.
    def make_grid_item(loc, rot, scale):
        g = gl.GLGridItem()
        g.scale(scale[0], scale[1], scale[2])
        g.rotate(rot[0], rot[1], rot[2], rot[3])
        g.translate(loc[0], loc[1], loc[2])
        return g

    if is_practice:
        grid_specs = [((-19, 0, 15), (90, 0, 1, 0), (1.5, 1.9, 1.9)),
                      ((0, -19, 15), (90, 1, 0, 0), (1.9, 1.5, 1.9))]
    else:
        grid_specs = [((-19, 0, 15), (90, 0, 1, 0), (1.5, 1.9, 1.9)),
                      ((-19, 0, 45), (90, 0, 1, 0), (1.5, 1.9, 1.9)),
                      ((0, -19, 15), (90, 1, 0, 0), (1.9, 1.5, 1.9)),
                      ((0, -19, 45), (90, 1, 0, 0), (1.9, 1.5, 1.9))]
    grid_specs.append(((0, 0, 0), (0, 0, 0, 0), (1.9, 1.9, 1.9)))
    grid_items = [make_grid_item(*spec) for spec in grid_specs]
    for grid_item in grid_items:
        window.addItem(grid_item)

    # Base image: practice phases use the practice background, all others the study background.
    # NOTE(review): img_location is resolved against the current working directory, not local_directory.
    img_location = './media/time_travel_task/'
    bg_path = 'practiceBG.png' if is_practice else 'studyBG.png'
    img = imread(os.path.abspath(os.path.join(img_location, bg_path)))
    image_scale = (19.0 * 2.0) / float(img.shape[0])
    base_image = gl.GLImageItem(pg.makeRGBA(img)[0])
    base_image.translate(-19, -19, 0)
    base_image.rotate(270, 0, 0, 1)
    base_image.scale(image_scale, image_scale, image_scale)
    window.addItem(base_image)

    # Timeline colored bars, one per time section.
    def make_color_bar(rgb, p, r, s):
        v = gl.GLImageItem(np.array([[rgb + (255,)]]))
        v.translate(p[0], p[1], p[2])
        v.scale(s[0], s[1], s[2])
        v.rotate(r[0], r[1], r[2], r[3])
        return v

    if is_practice:
        bar_times = [0, 7.5, 15, 22.5]
        color_bar_length = 7.5
    else:
        bar_times = [0, 15, 30, 45]
        color_bar_length = 15
    if inverse:
        bar_times.reverse()
    bar_colors = [(255, 255, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    color_bars = [make_color_bar(rgb, (19, t, 19), (90, 1, 0, 0), (5, color_bar_length, 0))
                  for rgb, t in zip(bar_colors, bar_times)]
    for bar in color_bars:
        window.addItem(bar)

    # Path line: white where time moves forward, magenta where 'timescale' <= 0. Every vertex starts transparent;
    # update() copies line_color into line_color_state to reveal the path progressively.
    forward_color = pg.glColor((255, 255, 255, 255))
    backward_color = pg.glColor((255, 0, 255, 255))
    pts = np.array([[float(it['x']), float(it['z']), float(it['time_val'])] for it in iterations], dtype=float)
    line_color = np.array([backward_color if it['timescale'] <= 0 else forward_color for it in iterations],
                          dtype=float)
    line_color_state = np.zeros((len(iterations), 4))
    path_line = gl.GLLinePlotItem(pos=pts, color=line_color_state, mode='line_strip', antialias=True)
    window.addItem(path_line)

    # Ground-truth item lines.
    # NOTE(review): the items returned here are superseded by the hard-coded table below; the call is kept so any
    # validation it performs on meta still happens.
    get_items_solutions(meta)
    items = _get_ground_truth_items(is_practice, inverse)
    end_time = 30 if is_practice else 60

    item_lines = []
    pos = np.empty((len(items), 3))
    size = np.empty(len(items))
    color = np.empty((len(items), 4))
    for item_index, item in enumerate(items):
        pos[item_index] = item['pos']
        size[item_index] = 0 if item['direction'] == 0 else 2  # static items get no marker
        color[item_index] = _rgb_to_gl_rgba(item['color'])
        line = _make_item_line(item['pos'], _item_line_end(item['pos'], item['direction'], end_time), item['color'])
        item_lines.append(line)
        window.addItem(line)
    item_scatter_plot = gl.GLScatterPlotItem(pos=pos, size=size, color=color, pxMode=False)
    window.addItem(item_scatter_plot)

    ####################################################################################################################
    # Generate data graphical items
    ####################################################################################################################

    click_pos, _, click_size, click_color = get_click_locations_and_indicies(iterations, items, meta)
    click_scatter = gl.GLScatterPlotItem(pos=click_pos, size=click_size[0], color=click_color, pxMode=False)
    window.addItem(click_scatter)

    # Test phases: draw the participant's reconstructed items with a label image next to each one.
    event_state_labels, item_number_label, item_label_filename, cols = get_item_details()
    reconstruction_items = []
    reconstruction_item_scatter_plot = None
    reconstruction_item_lines = []
    billboard_item_labels = []
    if is_test:
        reconstruction_items, _ = parse_test_items(iterations, cols, item_number_label, event_state_labels)
        pos = np.empty((len(reconstruction_items), 3))
        size = np.empty(len(reconstruction_items))
        color = np.empty((len(reconstruction_items), 4))
        for item_index, item in enumerate(reconstruction_items):
            pos[item_index] = item['pos']
            size[item_index] = 0 if item['direction'] == 0 else 2
            color[item_index] = _rgb_to_gl_rgba(item['color'])
            start = pos[item_index]
            line = _make_item_line(start, _item_line_end(start, item['direction'], end_time), item['color'])
            reconstruction_item_lines.append(line)
            window.addItem(line)

            # Label image scaled to expected_size, offset so it sits beside the item marker.
            label_img = imread(os.path.join(local_directory, item_label_filename[item_index]))
            expected_size = 2.0
            label_scale = expected_size / float(label_img.shape[0])
            offset_param = 0.0 - label_scale / 2 - expected_size / 2
            label_image = gl.GLImageItem(pg.makeRGBA(label_img)[0])
            label_time = end_time if item['direction'] == 0 else start[2]
            label_image.translate(start[0] + offset_param, start[1] + offset_param, label_time)
            label_image.scale(label_scale, label_scale, label_scale)
            window.addItem(label_image)
            billboard_item_labels.append(label_image)
        reconstruction_item_scatter_plot = gl.GLScatterPlotItem(pos=pos, size=size, color=color, pxMode=False)
        window.addItem(reconstruction_item_scatter_plot)

    ####################################################################################################################
    # Show UI
    ####################################################################################################################

    window.show()
    logging.info("Showing plot. Close plot to exit program.")

    ####################################################################################################################
    # Custom Keyboard Controls
    ####################################################################################################################

    idx = 0
    num_points_to_update = 5
    # Callback state that is not needed by update() lives in dicts, so the nested callbacks can mutate it.
    # `global` would point at nonexistent module names and `nonlocal` is unavailable on Python 2.
    playback = {'paused': False, 'saved_points_to_update': 0}
    visibility = {}

    def speed_up():
        global num_points_to_update
        if not playback['paused']:
            num_points_to_update += 5
            logging.info("Setting speed to " + str(num_points_to_update) + " points per tick.")

    def speed_down():
        global num_points_to_update
        if not playback['paused']:
            num_points_to_update -= 5  # negative speeds make update() rewind the path
            logging.info("Setting speed to " + str(num_points_to_update) + " points per tick.")

    def pause():
        global num_points_to_update
        if not playback['paused']:
            logging.info("Paused.")
            playback['saved_points_to_update'] = num_points_to_update
            num_points_to_update = 0
            playback['paused'] = True
        else:
            logging.info("Unpaused.")
            num_points_to_update = playback['saved_points_to_update']
            playback['saved_points_to_update'] = 0
            playback['paused'] = False

    def reset():
        global idx
        logging.info("Resetting to time zero.")
        idx = 0
        line_color_state[:] = 0.0  # hide every vertex, including the last one

    def go_to_end():
        global idx
        logging.info("Going to end.")
        idx = len(line_color_state) - 1
        line_color_state[:] = line_color  # reveal every vertex, including the last one

    def close_all():
        logging.info("User Shutdown Via Button Press")
        timer.stop()
        app.closeAllWindows()

    def make_visibility_toggle(name, graphics_items):
        """Return a callback that flips the visibility of graphics_items (None entries are skipped)."""
        visibility[name] = True

        def toggle():
            visible = not visibility[name]
            visibility[name] = visible
            for graphics_item in graphics_items:
                if graphics_item is None:
                    continue
                if visible:
                    graphics_item.show()
                else:
                    graphics_item.hide()

        return toggle

    shortcuts = [
        ("+", speed_up),
        ("-", speed_down),
        (" ", pause),
        ("R", reset),
        ("E", go_to_end),
        ("Escape", close_all),
        ("1", make_visibility_toggle('grid', grid_items)),
        ("2", make_visibility_toggle('base', [base_image])),
        ("3", make_visibility_toggle('color_bars', color_bars)),
        ("4", make_visibility_toggle('items', [item_scatter_plot] + item_lines)),
        ("5", make_visibility_toggle('path_line', [path_line, click_scatter])),
        ("6", make_visibility_toggle('reconstruction_items',
                                     [reconstruction_item_scatter_plot] + reconstruction_item_lines)),
        ("7", make_visibility_toggle('billboard_item_labels', billboard_item_labels)),
    ]
    for key, callback in shortcuts:
        shortcut = QtGui.QShortcut(QtGui.QKeySequence(key), window, callback)
        shortcut.setContext(QtCore.Qt.ApplicationShortcut)

    ####################################################################################################################
    # Animation Loop
    ####################################################################################################################

    timer = QtCore.QTimer()
    # noinspection PyUnresolvedReferences
    timer.timeout.connect(update)
    timer.start(1)

    ####################################################################################################################
    # PyQtGraph Initialization
    ####################################################################################################################

    if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'):
        # noinspection PyArgumentList
        QtGui.QApplication.instance().exec_()
def update():
    """
    Advance the Time Travel Task animation by one timer tick.

    Used internally by visualize_time_travel_data, whose module-level globals this function reads and writes.
    Each tick does three things:
    - When auto_rotate is set, it moves the camera.
    - It reveals up to num_points_to_update path vertices, or hides them when the value is negative, advancing idx.
    - It refreshes the scatter plot of clicks, or of active reconstruction items in test phases '2', '5' and '8'.
    """
    global path_line, idx, timer, iterations, click_scatter, click_pos, click_color, click_size, window, meta, \
        reconstruction_items, num_points_to_update, line_color, line_color_state, auto_rotate

    if auto_rotate:
        angle = float(idx) / 800.
        window.opts['elevation'] = math.fabs(math.cos(angle) * 20.) + 10.
        window.opts['azimuth'] = math.sin(angle) * 45 + 45

    moving_forward = num_points_to_update > 0
    last_index = len(line_color) - 1
    for _ in range(0, abs(num_points_to_update)):
        if moving_forward:
            line_color_state[idx] = line_color[idx]
            idx += 1
        else:
            line_color_state[idx] = (0, 0, 0, 0)
            idx -= 1
        # Clamp idx to the valid vertex range; stop early once the end of the path is reached.
        if idx < 0:
            idx = 0
        elif idx > last_index:
            idx = last_index
            break

    path_line.setData(color=line_color_state)

    if meta['phase'] in ('2', '5', '8'):
        position = np.array([(entry['pos'][0], entry['pos'][1], entry['pos'][2]) for entry in reconstruction_items])
        active_colors = np.array([(255, 255, 255, 255) if is_active else (0, 0, 0, 0)
                                  for is_active in iterations[idx]['itemsactive']])
        active_sizes = np.array([0.5] * len(position))
        click_scatter.setData(pos=position, size=active_sizes, color=active_colors)
    else:
        click_scatter.setData(pos=click_pos, size=click_size[idx], color=click_color, pxMode=False)
def get_rotation_matrix(i_v, unit=None):
    """
    This function gets the rotation matrix that rotates the direction of a vector onto a reference unit vector.

    The matrix is built with Rodrigues' rotation formula about the axis ``i_v x unit``, so that
    ``R.dot(i_v / |i_v|)`` equals ``unit / |unit|``. Parallel and anti-parallel inputs, where the cross product
    vanishes, are handled explicitly instead of producing NaNs.

    :param i_v: the vector whose rotation should be calculated
    :param unit: the reference direction (defaults to [1.0, 0.0, 0.0]); it is normalized internally
    :return: a 3x3 rotation matrix; the identity when i_v has zero length (the rotation is undefined)
    :raises ValueError: if unit has zero length
    """
    if unit is None:
        unit = [1.0, 0.0, 0.0]
    i_v = np.asarray(i_v, dtype=float)
    unit = np.asarray(unit, dtype=float)

    unit_norm = np.sqrt(np.dot(unit, unit))
    if unit_norm == 0.0:
        raise ValueError('The reference unit vector must have non-zero length.')
    unit = unit / unit_norm

    i_v_norm = np.sqrt(np.dot(i_v, i_v))
    if i_v_norm == 0.0:
        return np.identity(3)
    i_v = i_v / i_v_norm

    d = np.dot(i_v, unit)
    axis = np.cross(i_v, unit)
    axis_norm = np.sqrt(np.dot(axis, axis))

    if axis_norm < 1e-12:
        if d > 0:
            # Already aligned.
            return np.identity(3)
        # Anti-parallel: a half-turn about any axis perpendicular to unit maps -unit onto unit.
        helper = np.zeros(3)
        helper[np.argmin(np.abs(unit))] = 1.0
        perpendicular = np.cross(unit, helper)
        perpendicular = perpendicular / np.sqrt(np.dot(perpendicular, perpendicular))
        return 2.0 * np.outer(perpendicular, perpendicular) - np.identity(3)

    u, v, w = axis / axis_norm
    # atan2(|a x b|, a . b) is numerically stable near 0 and pi, unlike arccos of a rounded dot product.
    phi = np.arctan2(axis_norm, d)
    rcos = np.cos(phi)
    rsin = np.sin(phi)

    matrix = np.zeros((3, 3))
    matrix[0][0] = rcos + u * u * (1.0 - rcos)
    matrix[1][0] = w * rsin + v * u * (1.0 - rcos)
    matrix[2][0] = -v * rsin + w * u * (1.0 - rcos)
    matrix[0][1] = -w * rsin + u * v * (1.0 - rcos)
    matrix[1][1] = rcos + v * v * (1.0 - rcos)
    matrix[2][1] = u * rsin + w * v * (1.0 - rcos)
    matrix[0][2] = v * rsin + u * w * (1.0 - rcos)
    matrix[1][2] = -u * rsin + v * w * (1.0 - rcos)
    matrix[2][2] = rcos + w * w * (1.0 - rcos)
    return matrix
def generate_normed_segments(path, __meta=None, normalize_translation=True, normalize_length=True,
                             normalize_rotation=True):
    """
    This function generates normalized line segments given a path, meta-data, and normalization flags.

    The path is cut into segments between consecutive (sorted) click indices. Each segment's points are first
    translated so the segment starts at the origin, then divided by the start-to-end distance, then rotated so the
    start-to-end vector lies along +x. Each step is applied only when its flag is set.

    :param path: the path to the data to process.
    :param __meta: the meta information from the data filename (automatically detected if None)
    :param normalize_translation: if True, translation will be normalized
    :param normalize_length: if True, scaling will be normalized (skipped for zero-length segments)
    :param normalize_rotation: if True, rotation will be normalized
    :return: a tuple (xs, ys, zs, num_lines). xs, ys and zs each hold one list of coordinates per segment, and
             num_lines is the number of segments. With fewer than two clicks the result is ([], [], [], 0).
    """
    _iterations = read_binary_file(path)
    logging.info("Plotting " + str(len(_iterations)) + " iterations.")

    if __meta is None:
        try:
            __meta = get_filename_meta_data(os.path.basename(path))  # The meta filename information for convenience
        except Exception:
            logging.error(
                'There was an error reading the filename meta-information. Please confirm this is a valid log file.')
            sys.exit()

    items, _, _ = get_items_solutions(__meta)
    _, click_indices, _, _ = get_click_locations_and_indicies(_iterations, items, __meta)
    # Only the order of the click indices is needed. Sorting the indices alone also avoids comparing the
    # array-valued click positions when two clicks share an index.
    click_indices = sorted(int(click_index) for click_index in click_indices)

    num_lines = max(len(click_indices) - 1, 0)
    xs = []
    ys = []
    zs = []
    for start_idx, end_idx in zip(click_indices[:-1], click_indices[1:]):
        start_iter = _iterations[start_idx]
        end_iter = _iterations[end_idx - 1]
        start_pos = np.array([float(start_iter['x']), float(start_iter['z']), float(start_iter['time_val'])])
        end_pos = np.array([float(end_iter['x']), float(end_iter['z']), float(end_iter['time_val'])])
        original_vector = end_pos - start_pos
        magnitude = np.sqrt(np.dot(original_vector, original_vector))

        points = np.array([[float(i['x']), float(i['z']), float(i['time_val'])]
                           for i in _iterations[start_idx:end_idx]], dtype=float).reshape(-1, 3)
        if normalize_translation:
            points = points - start_pos
        if normalize_length and magnitude > 0:
            points = points / magnitude
        if normalize_rotation:
            # Row vectors times R^T == (R . p)^T for every point p.
            points = points.dot(get_rotation_matrix(original_vector).T)

        xs.append(points[:, 0].tolist())
        ys.append(points[:, 1].tolist())
        zs.append(points[:, 2].tolist())

    return xs, ys, zs, num_lines
def subsetter(path, sub_id=29, min_trial=0):
    """
    This function is used as a helper to subset data via a particular criteria.

    A file is kept when its subject ID equals sub_id and its trial number is at least min_trial. The defaults
    reproduce the original hard-coded criterion (subject 29, any trial >= 0).

    :param path: the path to the data
    :param sub_id: the subject ID to keep
    :param min_trial: the smallest trial number to keep
    :return: a bool which, if True, suggests keeping the data, otherwise it suggests removing the data from the
             visualization. Exits the process (SystemExit) if the filename meta-information cannot be read.
    """
    try:
        file_meta = get_filename_meta_data(os.path.basename(path))  # The meta filename information for convenience
    except Exception:
        logging.error(
            'There was an error reading the filename meta-information. Please confirm this is a valid log file.')
        sys.exit()
    return int(file_meta["subID"]) == sub_id and int(file_meta["trial"]) >= min_trial
def item_path_visualization(search_directory=None,
                            file_regex=r"\d\d\d_\d_1_\d_\d\d\d\d-\d\d-\d\d_\d\d-\d\d-\d\d.dat"):
    """
    This function visualizes item-to-item paths given some input files.

    Every file accepted by subsetter is split into normalized click-to-click segments (see generate_normed_segments).
    The segments are drawn in a 3D matplotlib plot, one base color per file, with alpha increasing from 0.5 to 1
    along the segment order. Only each file's last segment carries a legend label.

    :param search_directory: a directory to search recursively for input files (a dialog is shown when None)
    :param file_regex: a regular expression to search for in the files in search_directory
    :rtype: bool
    :return: True if files were successfully processed, False otherwise.
    :raises IOError: if search_directory does not exist
    """
    # Importing mplot3d registers the '3d' projection with matplotlib.
    # noinspection PyUnresolvedReferences
    import mpl_toolkits.mplot3d  # noqa: F401

    ####################################################################################################################
    # Setup
    ####################################################################################################################

    mpl.rcParams['legend.fontsize'] = 10
    fig = plt.figure()
    # fig.gca(projection=...) no longer accepts keyword arguments in recent matplotlib; add_subplot works everywhere.
    ax = fig.add_subplot(111, projection='3d')
    # Colormaps: http://matplotlib.org/1.2.1/examples/pylab_examples/show_colormaps.html
    colormap = plt.cm.Accent

    if search_directory is None:
        search_directory = easygui.diropenbox()
    # easygui returns None when the dialog is cancelled; an empty string is treated the same way.
    if not search_directory:
        logging.info('No directory selected, returning.')
        return False
    if not os.path.exists(search_directory):
        raise IOError('Specified search directory was not found.')

    files = find_data_files_in_directory(search_directory, file_regex=file_regex)
    if len(files) == 0:
        logging.info('No files found. returning.')
        return False

    selected_files = [path for path in files if subsetter(path)]
    base_colors = [colormap(i) for i in np.linspace(0, 0.9, len(selected_files))]

    for count, path in enumerate(selected_files):
        try:
            _meta = get_filename_meta_data(os.path.basename(path))  # The meta filename information for convenience
        except Exception:
            logging.error(
                'There was an error reading the filename meta-information. Please confirm this is a valid log file.')
            sys.exit()

        xs, ys, zs, num_lines = generate_normed_segments(path, _meta)
        if num_lines == 0:
            logging.info('No click-to-click segments found in ' + str(path) + '; skipping.')
            continue

        red, green, blue = base_colors[count][:3]
        alphas = np.linspace(0.5, 1, num_lines)
        labels = [""] * num_lines
        labels[-1] = "{0} trial {1}".format(_meta["subID"], _meta["trial"])
        for x, y, z, label, alpha in zip(xs, ys, zs, labels, alphas):
            ax.plot(x, y, z, label=label, color=(red, green, blue, alpha))

    # After normalization every segment starts at (0, 0, 0) and ends at (1, 0, 0).
    ax.scatter([0, 1], [0, 0], [0, 0], color='k')
    ax.legend()
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Some matplotlib versions do not support an equal aspect ratio on 3D axes.
        logging.warning('Equal aspect ratio is not supported for 3D axes by this matplotlib version.')
    plt.show()

    # NOTE: an unfinished mayavi/KDE density visualization that used to live here (as a dead string literal) was
    # removed; see version control history if it is needed again.
    return True
if __name__ == '__main__':
    # The root logger defaults to WARNING, which would hide every logging.info message (progress and
    # keyboard-control feedback), so enable INFO output when run as a script.
    logging.basicConfig(level=logging.INFO)
    visualize_time_travel_data()
    # The commented visualization below needs testing and does not currently work.
    # item_path_visualization(search_directory=r"C:\Users\Kevin\Desktop\Work\Time Travel Task\v2",
    #                         file_regex=r"021_\d_1_\d_\d\d\d\d-\d\d-\d\d_\d\d-\d\d-\d\d.dat")