# Source code for pandas_visual_analysis.widgets.parcats
import ipywidgets as widgets
import plotly.graph_objects as go
import numpy as np
from pandas_visual_analysis import DataSource
from pandas_visual_analysis.utils.config import Config
from pandas_visual_analysis.widgets import BaseWidget, register_widget
from pandas_visual_analysis.widgets.helpers.multi_select import HasMultiSelect
@register_widget
class ParallelCategoriesWidget(BaseWidget, HasMultiSelect):
    """
    The ParallelCategoriesWidget is useful for high dimensional data and supports brushing.
    Only displays categorical columns, which can be reordered arbitrarily.
    Displays a multi column selection if there are too many columns to display them all at once.
    """

    def __init__(
        self,
        data_source: DataSource,
        row: int,
        index: int,
        relative_size: float,
        max_height: int,
    ):
        """
        :param data_source: :class:`pandas_visual_analysis.data_source.DataSource` for the widget.
        :param row: The row the widget is in.
        :param index: Index of the row the widget is in.
        :param relative_size: The space the widget has in a row which is then converted to the width. (e.g. 0.33 => 33%)
        :param max_height: height in pixels the plot has to have
        :raises ValueError: if the data source provides no categorical column to plot.
        """
        super(ParallelCategoriesWidget, self).__init__(
            data_source, row, index, relative_size, max_height
        )
        # Deliberately starts the MRO lookup *after* BaseWidget so that the
        # remaining base initialiser runs with the categorical columns.
        # NOTE(review): presumably this resolves to HasMultiSelect.__init__ and
        # is what sets columns / selected_columns / use_multi_select - confirm.
        super(BaseWidget, self).__init__(
            self.data_source.categorical_columns, relative_size, max_height
        )

        # len() rather than truthiness: the container type of self.columns is
        # not visible here and may not support a plain boolean test.
        if len(self.columns) < 1:
            raise ValueError(
                "The data contains too few categorical columns to display a parallel categories plot. "
                "Remove the widget from the layout!"
            )

        self.trace, self.figure_widget = self._get_figure_widget()
        self.set_observers()

        self.root: widgets.Widget = widgets.HBox(
            [self.figure_widget], layout=widgets.Layout(width="100%")
        )
        # Same flag that set_observers() checks before it touches
        # self.multi_select, so both decisions stay in sync.
        if self.use_multi_select:
            self.root = widgets.HBox([self.figure_widget, self.multi_select_widget])

        if self.data_source.few_num_cols:
            pass  # todo: implement behaviour for few numerical columns (HTML message)

    def observe_brush_indices_change(self, sender):
        """
        Recolour the plot after the brushed indices of the data source changed.

        Connected to ``data_source.on_indices_changed`` in :meth:`set_observers`.

        :param sender: the object that emitted the signal; not used.
        """
        brushed = list(self.data_source.brushed_indices)
        self._set_line_color(self._selection_mask(brushed))

    def set_observers(self):
        """
        Connect this widget to the signals it reacts to.

        Always subscribes to index changes of the data source; additionally
        subscribes to column selection changes when the multi select is in use.
        """
        self.data_source.on_indices_changed.connect(self.observe_brush_indices_change)
        if self.use_multi_select:
            self.multi_select.on_selected_options_changed.connect(
                self._on_selected_columns_changed
            )

    def on_selection(self, trace, points, state):
        """
        Click callback of the parallel categories trace.

        Highlights the clicked rows in the plot and then publishes them as the
        brushed indices of the data source.

        :param trace: the trace the callback was registered on; not used.
        :param points: selection object; only its ``point_inds`` attribute is read.
        :param state: input device state passed by plotly; not used.
        """
        self._set_line_color(self._selection_mask(points.point_inds))
        # NOTE(review): assigning brushed_indices presumably fires
        # on_indices_changed, which recolours once more via
        # observe_brush_indices_change - confirm against DataSource.
        self.data_source.brushed_indices = points.point_inds

    def _parcats_dimensions(self):
        """Return one ``{"label", "values"}`` dict per currently selected column."""
        return [
            {"label": col, "values": self.data_source.data[col]}
            for col in self.selected_columns
        ]

    def _selection_mask(self, indices):
        """
        Build the per-row colour values for the trace.

        :param indices: positions of the rows that should be highlighted.
        :return: ``uint8`` array of length ``data_source.len`` holding 1 at
            ``indices`` and 0 everywhere else.
        """
        mask = np.zeros(self.data_source.len, dtype="uint8")
        mask[indices] = 1
        return mask

    def _set_line_color(self, mask):
        """
        Assign ``mask`` as the line colour of the first trace.

        The assignment happens inside ``batch_update`` and
        ``hold_trait_notifications`` of the figure widget.

        :param mask: colour values, one per row (see :meth:`_selection_mask`).
        """
        with self.figure_widget.batch_update(), self.figure_widget.hold_trait_notifications():
            self.figure_widget.data[0].line.color = mask

    def _get_figure_widget(self):
        """
        Create the parallel categories trace and the figure widget showing it.

        The trace starts with a single line colour taken from the configured
        colour scale; ``cmin=0`` / ``cmax=1`` match the 0/1 values that the
        brushing callbacks later assign as line colour.

        :return: tuple ``(trace, figure_widget)``.
        """
        config = Config()
        trace = go.Parcats(
            dimensions=self._parcats_dimensions(),
            line=dict(
                color=config.color_scale[1][1],
                colorscale=config.color_scale,
                cmin=0,
                cmax=1,
                shape="hspline",
            ),
        )
        figure_widget = go.FigureWidget(
            data=[trace],
            layout=go.Layout(
                margin=dict(l=20, r=20, b=20, t=20, pad=5),
                autosize=True,
                showlegend=False,
            ),
        )
        # Register on the widget's own copy of the trace (data[0]), not on the
        # local ``trace`` object passed to the constructor.
        figure_widget.data[0].on_click(self.on_selection)
        return trace, figure_widget

    def _on_selected_columns_changed(self, sender):
        """
        Take over the column selection of the multi select and redraw.

        :param sender: object exposing the new selection as ``selected_options``.
        """
        self.selected_columns = sender.selected_options
        self._redraw_plot()

    def _redraw_plot(self):
        """
        Rebuild the dimensions from the selected columns and restore the brush colouring.
        """
        # Dimensions are replaced first (outside the batch update, as before),
        # then the current brush is re-applied as line colour.
        self.figure_widget.data[0].dimensions = self._parcats_dimensions()
        brushed = list(self.data_source.brushed_indices)
        self._set_line_color(self._selection_mask(brushed))