Source code for pandas_visual_analysis.widgets.parcats

import ipywidgets as widgets
import plotly.graph_objects as go
import numpy as np

from pandas_visual_analysis import DataSource
from pandas_visual_analysis.utils.config import Config
from pandas_visual_analysis.widgets import BaseWidget, register_widget
from pandas_visual_analysis.widgets.helpers.multi_select import HasMultiSelect


@register_widget
class ParallelCategoriesWidget(BaseWidget, HasMultiSelect):
    """
    The ParallelCategoriesWidget is useful for high dimensional data and supports brushing.

    Only displays categorical columns, which can be reordered arbitrarily.
    Displays a multi column selection if there are too many columns to display them all at once.
    """

    def __init__(
        self,
        data_source: DataSource,
        row: int,
        index: int,
        relative_size: float,
        max_height: int,
    ):
        """
        :param data_source: :class:`pandas_visual_analysis.data_source.DataSource` for the widget.
        :param row: The row the widget is in.
        :param index: Index of the row the widget is in.
        :param relative_size: The space the widget has in a row which is then converted to the width.
            (e.g. 0.33 => 33%)
        :param max_height: height in pixels the plot has to have
        :raises ValueError: if the data source has no categorical columns.
        """
        super(ParallelCategoriesWidget, self).__init__(
            data_source, row, index, relative_size, max_height
        )
        # super(BaseWidget, self) skips BaseWidget in the MRO, so the next __init__
        # (presumably HasMultiSelect's, given the base order above) receives only
        # the categorical columns.
        super(BaseWidget, self).__init__(
            self.data_source.categorical_columns, relative_size, max_height
        )

        if len(self.columns) < 1:
            raise ValueError(
                "The data contains too few categorical columns to display a parallel categories plot. "
                "Remove the widget from the layout!"
            )

        self.trace, self.figure_widget = self._get_figure_widget()
        self.set_observers()

        # Use the same flag as set_observers() to decide whether the column
        # selector is part of the widget, so display and wiring never disagree.
        children = [self.figure_widget]
        if self.use_multi_select:
            children.append(self.multi_select_widget)
        self.root: widgets.Widget = widgets.HBox(
            children, layout=widgets.Layout(width="100%")
        )

        if self.data_source.few_num_cols:
            pass  # TODO: implement behaviour for few numerical columns (HTML message)

    def build(self) -> widgets.Widget:
        """Return the root widget with the size constraints applied."""
        return self.apply_size_constraints(self.root)

    def apply_size_constraints(self, widget):
        """Delegate to the parent implementation (kept as an explicit override)."""
        return super().apply_size_constraints(widget)

    def observe_brush_indices_change(self, sender):
        """
        Recolour the plot lines whenever the data source's brushed indices change.

        :param sender: the signal sender (unused).
        """
        self._set_line_color(self._brush_color(self.data_source.brushed_indices))

    def set_observers(self):
        """Connect to the data source's index signal and, if used, the column selector."""
        self.data_source.on_indices_changed.connect(self.observe_brush_indices_change)
        if self.use_multi_select:
            self.multi_select.on_selected_options_changed.connect(
                self._on_selected_columns_changed
            )

    def on_selection(self, trace, points, state):
        """
        Click handler registered on the parcats trace.

        Highlights the clicked points locally and publishes them as the new
        brushed indices of the data source.

        :param trace: the clicked trace (unused).
        :param points: plotly points object; ``points.point_inds`` are the selected rows.
        :param state: input device state (unused).
        """
        self._set_line_color(self._brush_color(points.point_inds))
        self.data_source.brushed_indices = points.point_inds

    def on_deselection(self, trace, points):
        """Deselection is not handled by this widget."""
        pass

    def _dimensions(self):
        """Build the parcats dimension specs for the currently selected columns."""
        return [
            {"label": col, "values": self.data_source.data[col]}
            for col in self.selected_columns
        ]

    def _brush_color(self, indices) -> np.ndarray:
        """
        Return a per-row colour array with 1 for rows in ``indices`` and 0 elsewhere.

        :param indices: iterable of integer row positions (converted to a list so
            that sets and similar iterables work with numpy fancy indexing).
        """
        color = np.zeros(self.data_source.len, dtype="uint8")
        color[list(indices)] = 1
        return color

    def _set_line_color(self, color):
        """Assign ``color`` to the trace's line colour in a single batched update."""
        with self.figure_widget.batch_update(), self.figure_widget.hold_trait_notifications():
            self.figure_widget.data[0].line.color = color

    def _get_figure_widget(self):
        """
        Create the parcats trace and the FigureWidget that displays it.

        :return: tuple ``(trace, figure_widget)``; the widget's first trace has
            :meth:`on_selection` registered as its click handler.
        """
        config = Config()
        trace = go.Parcats(
            dimensions=self._dimensions(),
            line=dict(
                # Initial colour: the colour of the second colour-scale stop
                # (presumably the "brushed" colour given cmin=0 / cmax=1).
                color=config.color_scale[1][1],
                colorscale=config.color_scale,
                cmin=0,
                cmax=1,
                shape="hspline",
            ),
        )
        figure_widget = go.FigureWidget(
            data=[trace],
            layout=go.Layout(
                margin=dict(l=20, r=20, b=20, t=20, pad=5),
                autosize=True,
                showlegend=False,
            ),
        )
        figure_widget.data[0].on_click(self.on_selection)
        return trace, figure_widget

    def _on_selected_columns_changed(self, sender):
        """Update the displayed columns from the multi-select and redraw."""
        self.selected_columns = sender.selected_options
        self._redraw_plot()

    def _redraw_plot(self):
        """Replace the plot's dimensions and reapply the current brush colouring."""
        new_dims = self._dimensions()
        new_color = self._brush_color(self.data_source.brushed_indices)
        # One batch so the front end receives a single update for both changes.
        with self.figure_widget.batch_update(), self.figure_widget.hold_trait_notifications():
            self.figure_widget.data[0].dimensions = new_dims
            self.figure_widget.data[0].line.color = new_color