"""Plotlywrapper: to make easy plots easy to make."""
from typing import Generator, Optional
from tempfile import NamedTemporaryFile
import plotly.offline as py
import plotly.graph_objs as go
from plotly.basedatatypes import BaseTraceType # pylint: disable=no-name-in-module,import-error
import numpy as np
import pandas as pd
__version__ = '0.2.0-dev'
try:
Figure = go.FigureWidget
except AttributeError:
Figure = go.Figure
def _labels(base='trace') -> Generator[str, None, None]:
i = 0
while True:
yield base + ' ' + str(i)
i += 1
def _detect_notebook() -> bool:
"""Detect if code is running in a Jupyter Notebook.
This isn't 100% correct but seems good enough
Returns
-------
bool
True if it detects this is a notebook, otherwise False.
"""
try:
from IPython import get_ipython
from ipykernel import zmqshell
except ImportError:
return False
kernel = get_ipython()
try:
from spyder.utils.ipython.spyder_kernel import SpyderKernel
if isinstance(kernel.kernel, SpyderKernel):
return False
except (ImportError, AttributeError):
pass
return isinstance(kernel, zmqshell.ZMQInteractiveShell)
def _merge_layout(x: go.Layout, y: go.Layout) -> go.Layout:
"""Merge attributes from two layouts."""
xjson = x.to_plotly_json()
yjson = y.to_plotly_json()
if 'shapes' in yjson and 'shapes' in xjson:
xjson['shapes'] += yjson['shapes']
yjson.update(xjson)
return go.Layout(yjson)
def _try_pydatetime(x):
"""Try to convert to pandas objects to datetimes.
Plotly doesn't know how to handle them.
"""
try:
# for datetimeindex
x = [y.isoformat() for y in x.to_pydatetime()]
except AttributeError:
pass
try:
# for generic series
x = [y.isoformat() for y in x.dt.to_pydatetime()]
except AttributeError:
pass
return x
[docs]class Chart(Figure):
"""Plotly chart base class.
Usually this object will get created by from a function.
"""
def __init__(self, data=None, layout=None):
"""Create a chart."""
super().__init__(data=data, layout=layout)
def __add__(self, other):
"""Add another chart or plot type to this chart."""
# pylint: disable=attribute-defined-outside-init
if isinstance(other, Chart):
self.add_traces(other.data)
self.layout = _merge_layout(self.layout, other.layout)
elif isinstance(other, BaseTraceType):
self.add_trace(other)
elif isinstance(other, go.Layout):
self.layout = _merge_layout(self.layout, other)
else:
raise ValueError('Cannot add {} to Chart'.format(other))
return self
def __radd__(self, other):
"""Add another chart or plot type to this chart."""
return self.__add__(other)
[docs] def group(self):
"""Set bar graph display mode to "grouped".
Returns
-------
Chart
"""
self.layout.barmode = 'group'
return self
[docs] def stack(self):
"""Set bar graph display mode to "stacked".
Returns
-------
Chart
"""
self.layout.barmode = 'stack'
return self
[docs] def legend(self, visible=True):
"""Make legend visible.
Parameters
----------
visible : bool, optional
Returns
-------
Chart
"""
self.layout.showlegend = visible
return self
@property
def xlabel(self):
"""Xaxis Label."""
return self.layout.xaxis.title
@xlabel.setter
def xlabel(self, label):
self.layout.xaxis.title = label
@property
def ylabel(self):
"""Left Yaxis Label."""
return self.layout.yaxis.title
@ylabel.setter
def ylabel(self, label):
self.layout.yaxis.title = label
@property
def zlabel(self):
"""Zaxis Label."""
return self.layout.zaxis.title
@zlabel.setter
def zlabel(self, label):
self.layout.zaxis.title = label
[docs] def closest(self):
"""Set hovermode to closest.
https://plot.ly/python/reference/#layout-hovermode
"""
self.layout.hovermode = 'closest'
return self
[docs] def xcompare(self):
"""Set hovermode to compare along x axis.
https://plot.ly/python/reference/#layout-hovermode
"""
self.layout.hovermode = 'x'
return self
[docs] def ycompare(self):
"""Set hovermode to compare along y axis.
https://plot.ly/python/reference/#layout-hovermode
"""
self.layout.hovermode = 'y'
return self
[docs] def xtickangle(self, angle):
"""Set the angle of the x-axis tick labels.
Parameters
----------
value : int
Angle in degrees
Returns
-------
Chart
"""
self.layout.xaxis.tickangle = angle
return self
[docs] def ytickangle(self, angle, index=1):
"""Set the angle of the y-axis tick labels.
Parameters
----------
value : int
Angle in degrees
index : int, optional
Y-axis index
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['tickangle'] = angle
return self
[docs] def xlabelsize(self, size):
"""Set the size of the label.
Parameters
----------
size : int
Returns
-------
Chart
"""
self.layout['xaxis']['titlefont']['size'] = size
return self
[docs] def ylabelsize(self, size, index=1):
"""Set the size of the label.
Parameters
----------
size : int
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['titlefont']['size'] = size
return self
[docs] def xticksize(self, size):
"""Set the tick font size.
Parameters
----------
size : int
Returns
-------
Chart
"""
self.layout['xaxis']['tickfont']['size'] = size
return self
[docs] def yticksize(self, size, index=1):
"""Set the tick font size.
Parameters
----------
size : int
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['tickfont']['size'] = size
return self
[docs] def ytickvals(self, values, index=1):
"""Set the tick values.
Parameters
----------
values : array-like
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['tickvals'] = values
return self
[docs] def yticktext(self, labels, index=1):
"""Set the tick labels.
Parameters
----------
labels : array-like
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['ticktext'] = labels
return self
[docs] def xlim(self, low, high):
"""Set xaxis limits.
Parameters
----------
low : number
high : number
Returns
-------
Chart
"""
self.layout['xaxis']['range'] = [low, high]
return self
[docs] def ylim(self, low, high, index=1):
"""Set yaxis limits.
Parameters
----------
low : number
high : number
index : int, optional
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['range'] = [low, high]
return self
[docs] def xdtick(self, dtick):
"""Set the tick distance."""
self.layout.xaxis.dtick = dtick
return self
[docs] def ydtick(self, dtick, index=1):
"""Set the tick distance."""
self.layout['yaxis' + str(index)]['dtick'] = dtick
return self
[docs] def xnticks(self, nticks):
"""Set the number of ticks."""
self.layout.xaxis.nticks = nticks
return self
[docs] def ynticks(self, nticks, index=1):
"""Set the number of ticks."""
self.layout['yaxis' + str(index)]['nticks'] = nticks
return self
[docs] def yaxis_left(self, index=1):
"""Put the yaxis on the left hand side.
Parameters
----------
index : int, optional
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['side'] = 'left'
[docs] def yaxis_right(self, index=1):
"""Put the yaxis on the right hand side.
Parameters
----------
index : int, optional
Returns
-------
Chart
"""
self.layout['yaxis' + str(index)]['side'] = 'right'
[docs] def show(
self,
filename: Optional[str] = None,
show_link: bool = True,
auto_open: bool = True,
detect_notebook: bool = True,
) -> None:
"""Display the chart.
Parameters
----------
filename : str, optional
Save plot to this filename, otherwise it's saved to a temporary file.
show_link : bool, optional
Show link to plotly.
auto_open : bool, optional
Automatically open the plot (in the browser).
detect_notebook : bool, optional
Try to detect if we're running in a notebook.
"""
kargs = {}
if detect_notebook and _detect_notebook():
py.init_notebook_mode()
plot = py.iplot
else:
plot = py.plot
if filename is None:
filename = NamedTemporaryFile(prefix='plotly', suffix='.html', delete=False).name
kargs['filename'] = filename
kargs['auto_open'] = auto_open
plot(self, show_link=show_link, **kargs)
[docs] def save(
self,
filename: Optional[str] = None,
show_link: bool = True,
auto_open: bool = False,
output: str = 'file',
plotlyjs: bool = True,
) -> str:
"""Save the chart to an html file."""
if filename is None:
filename = NamedTemporaryFile(prefix='plotly', suffix='.html', delete=False).name
# NOTE: this doesn't work for output 'div'
py.plot(
self,
show_link=show_link,
filename=filename,
auto_open=auto_open,
output_type=output,
include_plotlyjs=plotlyjs,
)
return filename
@property
def dict(self):
"""Convert Chart to a dict."""
return self.to_dict()
def spark_shape(points, shapes, fill=None, color='blue', width=5, yindex=0, heights=None):
"""TODO: Docstring for spark.
Parameters
----------
points : array-like
shapes : array-like
fill : array-like, optional
Returns
-------
Chart
"""
assert len(points) == len(shapes) + 1
data = [{'marker': {'color': 'white'}, 'x': [points[0], points[-1]], 'y': [yindex, yindex]}]
if fill is None:
fill = [False] * len(shapes)
if heights is None:
heights = [0.4] * len(shapes)
lays = []
for i, (shape, height) in enumerate(zip(shapes, heights)):
if shape is None:
continue
if fill[i]:
fillcolor = color
else:
fillcolor = 'white'
lays.append(
dict(
type=shape,
x0=points[i],
x1=points[i + 1],
y0=yindex - height,
y1=yindex + height,
xref='x',
yref='y',
fillcolor=fillcolor,
line=dict(color=color, width=width),
)
)
layout = dict(shapes=lays)
return Chart(data=data, layout=layout)
[docs]def vertical(x, ymin=0, ymax=1, color=None, width=None, dash=None, opacity=None):
"""Draws a vertical line from `ymin` to `ymax`.
Parameters
----------
xmin : int, optional
xmax : int, optional
color : str, optional
width : number, optional
Returns
-------
Chart
"""
lineattr = {}
if color:
lineattr['color'] = color
if width:
lineattr['width'] = width
if dash:
lineattr['dash'] = dash
layout = dict(
shapes=[dict(type='line', x0=x, x1=x, y0=ymin, y1=ymax, opacity=opacity, line=lineattr)]
)
return Chart(layout=layout)
[docs]def horizontal(y, xmin=0, xmax=1, color=None, width=None, dash=None, opacity=None):
"""Draws a horizontal line from `xmin` to `xmax`.
Parameters
----------
xmin : int, optional
xmax : int, optional
color : str, optional
width : number, optional
Returns
-------
Chart
"""
lineattr = {}
if color:
lineattr['color'] = color
if width:
lineattr['width'] = width
if dash:
lineattr['dash'] = dash
layout = dict(
shapes=[dict(type='line', x0=xmin, x1=xmax, y0=y, y1=y, opacity=opacity, line=lineattr)]
)
return Chart(layout=layout)
[docs]def line(
x=None,
y=None,
label=None,
color=None,
width=None,
dash=None,
opacity=None,
mode='lines+markers',
yaxis=1,
fill=None,
text="",
markersize=6,
):
"""Draws connected dots.
Parameters
----------
x : array-like, optional
y : array-like, optional
label : array-like, optional
Returns
-------
Chart
"""
assert x is not None or y is not None, "x or y must be something"
yn = 'y' + str(yaxis)
lineattr = {}
if color:
lineattr['color'] = color
if width:
lineattr['width'] = width
if dash:
lineattr['dash'] = dash
if y is None:
y = x
x = None
if x is None:
x = np.arange(len(y))
else:
x = _try_pydatetime(x)
x = np.atleast_1d(x)
y = np.atleast_1d(y)
assert x.shape[0] == y.shape[0]
if y.ndim == 2:
if not hasattr(label, '__iter__'):
if label is None:
label = _labels()
else:
label = _labels(label)
data = [
go.Scatter(
x=x,
y=yy,
name=ll,
line=lineattr,
mode=mode,
text=text,
fill=fill,
opacity=opacity,
yaxis=yn,
marker=dict(size=markersize, opacity=opacity),
)
for ll, yy in zip(label, y.T)
]
else:
data = [
go.Scatter(
x=x,
y=y,
name=label,
line=lineattr,
mode=mode,
text=text,
fill=fill,
opacity=opacity,
yaxis=yn,
marker=dict(size=markersize, opacity=opacity),
)
]
if yaxis == 1:
return Chart(data=data)
return Chart(data=data, layout={'yaxis' + str(yaxis): dict(overlaying='y')})
[docs]def line3d(
x, y, z, label=None, color=None, width=None, dash=None, opacity=None, mode='lines+markers'
):
"""Create a 3d line chart."""
x = np.atleast_1d(x)
y = np.atleast_1d(y)
z = np.atleast_1d(z)
assert x.shape == y.shape
assert y.shape == z.shape
lineattr = {}
if color:
lineattr['color'] = color
if width:
lineattr['width'] = width
if dash:
lineattr['dash'] = dash
if y.ndim == 2:
if not hasattr(label, '__iter__'):
if label is None:
label = _labels()
else:
label = _labels(label)
data = [
go.Scatter3d(x=xx, y=yy, z=zz, name=ll, line=lineattr, mode=mode, opacity=opacity)
for ll, xx, yy, zz in zip(label, x.T, y.T, z.T)
]
else:
data = [go.Scatter3d(x=x, y=y, z=z, name=label, line=lineattr, mode=mode, opacity=opacity)]
return Chart(data=data)
[docs]def scatter3d(x, y, z, label=None, color=None, width=None, dash=None, opacity=None, mode='markers'):
"""Create a 3D scatter Plot.
Parameters
----------
x : array-like
data on x-dimension
y : array-like
data on y-dimension
z : array-like
data on z-dimension
label : TODO, optional
mode : 'group' or 'stack', default 'group'
opacity : TODO, optional
Returns
-------
Chart
"""
x = np.atleast_1d(x)
y = np.atleast_1d(y)
z = np.atleast_1d(z)
assert x.shape == y.shape
assert y.shape == z.shape
lineattr = {}
if color:
lineattr['color'] = color
if width:
lineattr['width'] = width
if dash:
lineattr['dash'] = dash
if y.ndim == 2:
if not hasattr(label, '__iter__'):
if label is None:
label = _labels()
else:
label = _labels(label)
data = [
go.Scatter3d(x=xx, y=yy, z=zz, name=ll, line=lineattr, mode=mode, opacity=opacity)
for ll, xx, yy, zz in zip(label, x.T, y.T, z.T)
]
else:
data = [go.Scatter3d(x=x, y=y, z=z, name=label, line=lineattr, mode=mode, opacity=opacity)]
return Chart(data=data)
[docs]def scatter(
x=None,
y=None,
label=None,
color=None,
width=None,
dash=None,
opacity=None,
markersize=6,
yaxis=1,
fill=None,
text="",
mode='markers',
):
"""Draws dots.
Parameters
----------
x : array-like, optional
y : array-like, optional
label : array-like, optional
Returns
-------
Chart
"""
return line(
x=x,
y=y,
label=label,
color=color,
width=width,
dash=dash,
opacity=opacity,
mode=mode,
yaxis=yaxis,
fill=fill,
text=text,
markersize=markersize,
)
[docs]def bar(x=None, y=None, label=None, mode='group', yaxis=1, opacity=None):
"""Create a bar chart.
Parameters
----------
x : array-like, optional
y : TODO, optional
label : TODO, optional
mode : 'group' or 'stack', default 'group'
opacity : TODO, optional
Returns
-------
Chart
A Chart with bar graph data.
"""
assert x is not None or y is not None, "x or y must be something"
yn = 'y' + str(yaxis)
if y is None:
y = x
x = None
if x is None:
x = np.arange(len(y))
else:
x = _try_pydatetime(x)
x = np.atleast_1d(x)
y = np.atleast_1d(y)
if y.ndim == 2:
if not hasattr(label, '__iter__'):
if label is None:
label = _labels()
else:
label = _labels(label)
data = [go.Bar(x=x, y=yy, name=ll, yaxis=yn, opacity=opacity) for ll, yy in zip(label, y.T)]
else:
data = [go.Bar(x=x, y=y, name=label, yaxis=yn, opacity=opacity)]
if yaxis == 1:
return Chart(data=data, layout={'barmode': mode})
return Chart(data=data, layout={'barmode': mode, 'yaxis' + str(yaxis): dict(overlaying='y')})
[docs]def heatmap(z, x=None, y=None, colorscale='Viridis'):
"""Create a heatmap.
Parameters
----------
z : TODO
x : TODO, optional
y : TODO, optional
colorscale : TODO, optional
Returns
-------
Chart
"""
z = np.atleast_1d(z)
data = [go.Heatmap(z=z, x=x, y=y, colorscale=colorscale)]
return Chart(data=data)
[docs]def fill_zero(
x=None,
y=None,
label=None,
color=None,
width=None,
dash=None,
opacity=None,
mode='lines+markers',
**kargs
):
"""Fill to zero.
Parameters
----------
x : array-like, optional
y : TODO, optional
label : TODO, optional
Returns
-------
Chart
"""
return line(
x=x,
y=y,
label=label,
color=color,
width=width,
dash=dash,
opacity=opacity,
mode=mode,
fill='tozeroy',
**kargs
)
[docs]def fill_between(
x=None,
ylow=None,
yhigh=None,
label=None,
color=None,
width=None,
dash=None,
opacity=None,
mode='lines+markers',
**kargs
):
"""Fill between `ylow` and `yhigh`.
Parameters
----------
x : array-like, optional
ylow : TODO, optional
yhigh : TODO, optional
Returns
-------
Chart
"""
plot = line(
x=x,
y=ylow,
label=label,
color=color,
width=width,
dash=dash,
opacity=opacity,
mode=mode,
fill=None,
**kargs
)
plot += line(
x=x,
y=yhigh,
label=label,
color=color,
width=width,
dash=dash,
opacity=opacity,
mode=mode,
fill='tonexty',
**kargs
)
return plot
[docs]def rug(x, label=None, opacity=None):
"""Rug chart.
Parameters
----------
x : array-like, optional
label : TODO, optional
opacity : TODO, optional
Returns
-------
Chart
"""
x = _try_pydatetime(x)
x = np.atleast_1d(x)
data = [
go.Scatter(
x=x,
y=np.ones_like(x),
name=label,
opacity=opacity,
mode='markers',
marker=dict(symbol='line-ns-open'),
)
]
layout = dict(
barmode='overlay',
hovermode='closest',
legend=dict(traceorder='reversed'),
xaxis1=dict(zeroline=False),
yaxis1=dict(
domain=[0.85, 1],
showline=False,
showgrid=False,
zeroline=False,
anchor='free',
position=0.0,
showticklabels=False,
),
)
return Chart(data=data, layout=layout)
[docs]def surface(x, y, z):
"""Surface plot.
Parameters
----------
x : array-like, optional
y : array-like, optional
z : array-like, optional
Returns
-------
Chart
"""
data = [go.Surface(x=x, y=y, z=z)]
return Chart(data=data)
[docs]def hist(x, mode='overlay', label=None, opacity=None, horz=False, histnorm=None):
"""Histogram.
Parameters
----------
x : array-like
mode : str, optional
label : TODO, optional
opacity : float, optional
horz : bool, optional
histnorm : None, "percent", "probability", "density", "probability density", optional
Specifies the type of normalization used for this histogram trace.
If ``None``, the span of each bar corresponds to the number of occurrences
(i.e. the number of data points lying inside the bins). If "percent",
the span of each bar corresponds to the percentage of occurrences with
respect to the total number of sample points (here, the sum of all bin
area equals 100%). If "density", the span of each bar corresponds to the
number of occurrences in a bin divided by the size of the bin interval
(here, the sum of all bin area equals the total number of sample
points). If "probability density", the span of each bar corresponds to
the probability that an event will fall into the corresponding bin
(here, the sum of all bin area equals 1).
Returns
-------
Chart
"""
x = np.atleast_1d(x)
if horz:
kargs = dict(y=x)
else:
kargs = dict(x=x)
layout = dict(barmode=mode)
data = [go.Histogram(opacity=opacity, name=label, histnorm=histnorm, **kargs)]
return Chart(data=data, layout=layout)
[docs]def hist2d(x, y, label=None, opacity=None):
"""2D Histogram.
Parameters
----------
x : array-like, optional
y : array-like, optional
label : TODO, optional
opacity : float, optional
Returns
-------
Chart
"""
x = np.atleast_1d(x)
y = np.atleast_1d(y)
data = [go.Histogram2d(x=x, y=y, opacity=opacity, name=label)]
return Chart(data=data)
[docs]@pd.api.extensions.register_dataframe_accessor('plotly')
@pd.api.extensions.register_series_accessor('plotly')
class PandasPlotting:
"""Pandas plotly charting methods.
Examples
--------
Here's an example of how to do that.
>>> df = pd.DataFrame([[1, 2], [1, 4]])
>>> chart = df.plotly.line()
>>> chart.show()
"""
def __init__(self, data):
"""Create the pandas accessor."""
self._data = data
if isinstance(data, pd.DataFrame):
self._label = data.columns
elif isinstance(data, pd.Series):
self._label = data.name
[docs] def line(
self,
label=None,
color=None,
width=None,
dash=None,
opacity=None,
mode='lines+markers',
fill=None,
**kargs
):
"""Create a line chart."""
if label is None:
label = self._label
return line(
x=self._data.index,
y=self._data.values,
label=label,
color=color,
width=width,
dash=dash,
opacity=opacity,
mode=mode,
fill=fill,
**kargs
)
[docs] def scatter(
self, label=None, color=None, width=None, dash=None, opacity=None, mode='markers', **kargs
):
"""Create a bar chart.
Parameters
----------
label : list of strings, optional
list of labels to override column names
Returns
-------
Chart
"""
if label is None:
label = self._label
return scatter(
x=self._data.index,
y=self._data.values,
label=label,
color=color,
width=width,
dash=dash,
opacity=opacity,
mode=mode,
**kargs
)
[docs] def bar(self, label=None, mode='group', opacity=None, **kargs):
"""Create a bar chart.
Parameters
----------
label : list of strings, optional
list of labels to override column names
mode : str, optional
'group' or 'stack'
Returns
-------
Chart
"""
if label is None:
label = self._label
return bar(
x=self._data.index,
y=self._data.values,
label=label,
mode=mode,
opacity=opacity,
**kargs
)
[docs] def stack(self, mode='lines', label=None, **kargs):
"""Create a stacked area plot.
Parameters
----------
mode : string, optional
label : list of strings, optional
list of labels to override column names
Returns
-------
Chart
"""
if label is None:
label = self._label
cum = self._data.cumsum(axis=1)
chart = Chart()
for lab, (_, ser), (_, orig) in zip(label, cum.iteritems(), self._data.iteritems()):
chart += line(
x=ser.index,
y=ser.values,
label=lab,
fill='tonexty',
mode=mode,
text=orig.values,
**kargs
)
return chart
[docs] def sparklines(self, label=None, mode='lines', percent=90, epsilon=1e-3):
"""TODO: Docstring for sparklines.
Parameters
----------
label : array-like, optional
mode : str, optional
percent : number, optional
Returns
-------
Chart
"""
if label is None:
label = self._label
div = self._data.max(axis=0) - self._data.min(axis=0) + epsilon
center = div / 2.0 + self._data.min(axis=0)
normed = (self._data - center) / div
normed *= percent / 100.0
offset = np.arange(1, self._data.shape[1] + 1)
normed += offset
chart = line(x=self._data.index, y=normed, mode=mode, label=label)
chart.ytickvals(offset)
chart.yticktext(self._data.columns.values)
chart.legend(False)
return chart