Program Listing for File PlotlyMakePlot.py
↰ Return to documentation for file (src/tfc/utils/PlotlyMakePlot.py)
import numpy as np
import plotly
import plotly.graph_objects as go
from .TFCUtils import TFCPrint
from .tfc_types import StrArrayLike, uint, Path, Literal
from typing import Optional
# Instantiate TFCPrint once at import time. What this sets up lives in TFCUtils and is
# not visible here (NOTE(review): presumably it configures TFCPrint's output; confirm).
TFCPrint()
class MakePlot:
    """
    A MakePlot class for Plotly.

    This class wraps common Plotly functions to ease figure creation. The
    underlying plotly figure is stored in ``self.fig``.
    """

    # Colorbar, colorscale, and lighting shared by the surface, volume, and
    # contour defaults of the template below.
    _colorbar = {"outlinewidth": 0, "ticks": ""}
    _colorscale = [
        [0.0, "black"],
        [0.2, "rebeccapurple"],
        [0.3, "blueviolet"],
        [0.7, "#4682B4"],
        [1.0, "aquamarine"],
    ]
    _lighting = {
        "ambient": 0.6,
        "specular": 0.05,
        "diffuse": 0.4,
        "fresnel": 3.0,
    }
    _template = {
        "data": {
            "surface": [
                {
                    "colorbar": _colorbar,
                    "colorscale": _colorscale,
                    "lighting": _lighting,
                    "type": "surface",
                }
            ],
            "volume": [
                {
                    "colorbar": _colorbar,
                    "colorscale": _colorscale,
                    "lighting": _lighting,
                    "type": "volume",
                }
            ],
            "contour": [
                {
                    "colorbar": _colorbar,
                    "colorscale": _colorscale,
                }
            ],
        },
        "layout": {
            "margin": {"t": 50, "b": 50, "r": 50, "l": 50},
            "autosize": True,
            "font": {"size": 16},
            "paper_bgcolor": "rgba(0,0,0,0)",
        },
    }
    _backgroundColor = "rgba(0,0,0,0)"
    _gridColor = "rgb(176, 176, 176)"

    def __init__(
        self,
        xlabs: StrArrayLike,
        ylabs: StrArrayLike,
        titles: Optional[StrArrayLike] = None,
        zlabs: Optional[StrArrayLike] = None,
    ):
        """
        This function initializes the plot/subplots based on the inputs provided.

        The labels are converted to 2D arrays (1D input becomes a single column)
        whose shape sets the subplot grid. A single x-label creates a plain
        figure; anything else creates a rows x cols grid via make_subplots.

        Parameters
        ----------
        xlabs: StrArrayLike
            The x-axes labels for the plots
        ylabs: StrArrayLike
            The y-axes labels for the plots
        titles: StrArrayLike, optional
            The titles for the plots. (Default value = None)
        zlabs: StrArrayLike, optional
            The z-axes labels for the plots. Setting this forces subplots to be 3D. (Default value = None)
        """
        # View distance used by the view method
        self.viewDistance = 2.5

        # Consistify all label types
        xlabs = self._labelsToArray(xlabs, "xlabels")
        ylabs = self._labelsToArray(ylabs, "ylabels")
        self._is3d = zlabs is not None
        if self._is3d:
            zlabs = self._labelsToArray(zlabs, "zlabels")
        titles = self._titlesToList(titles)

        nRows, nCols = xlabs.shape[0], xlabs.shape[1]
        self._hasSubplots = int(np.prod(xlabs.shape)) != 1

        if self._hasSubplots:
            from plotly.subplots import make_subplots

            # One fresh spec dict per cell (None means plotly's default 2D cells).
            specs = (
                [[{"is_3d": True} for _ in range(nCols)] for _ in range(nRows)]
                if self._is3d
                else None
            )
            self.fig = make_subplots(
                rows=nRows,
                cols=nCols,
                specs=specs,
                subplot_titles=titles,
            )
        else:
            self.fig = go.Figure()
            if titles is not None:
                self.fig["layout"]["title"] = titles[0]

        # Axis labels, grid colors, and background colors. Subplots are addressed
        # by 1-based (row, col) so plotly resolves the matching axis/scene; names
        # built by hand as "scene" + str(row + col + 1) collide on grids larger
        # than one row or column.
        if self._is3d:
            axisStyle = {
                "exponentformat": "e",
                "gridcolor": self._gridColor,
                "zerolinecolor": self._gridColor,
                "backgroundcolor": self._backgroundColor,
            }
            for row in range(nRows):
                for col in range(nCols):
                    self._updateScene(
                        row + 1,
                        col + 1,
                        xaxis=dict(title=xlabs[row, col], **axisStyle),
                        yaxis=dict(title=ylabs[row, col], **axisStyle),
                        zaxis=dict(title=zlabs[row, col], **axisStyle),
                    )
        else:
            for row in range(nRows):
                for col in range(nCols):
                    loc = self._gridKwargs(row + 1, col + 1)
                    self.fig.update_xaxes(
                        title=xlabs[row, col], exponentformat="e", **loc
                    )
                    self.fig.update_yaxes(
                        title=ylabs[row, col], exponentformat="e", **loc
                    )
            # Colors apply to every 2D axis at once.
            self.fig.update_xaxes(
                gridcolor=self._gridColor, linecolor="black", zerolinecolor=self._gridColor
            )
            self.fig.update_yaxes(
                gridcolor=self._gridColor, linecolor="black", zerolinecolor=self._gridColor
            )
            self.fig.update_layout(plot_bgcolor=self._backgroundColor)

        # Update layout
        self.fig.update_layout(
            template=self._template,
        )

    @staticmethod
    def _labelsToArray(labs, name: str) -> np.ndarray:
        """
        Convert axis labels (str, list, tuple, or ndarray) to an ndarray, making
        1D input a single column. ``name`` is used in the error message.
        """
        if isinstance(labs, str):
            labs = np.array([[labs]])
        elif isinstance(labs, (tuple, list)):
            labs = np.array(labs)
        elif not isinstance(labs, np.ndarray):
            TFCPrint.Error(
                f"The {name} provided are not of a valid type. Please provide valid {name}"
            )
        if labs.ndim == 1:
            labs = np.expand_dims(labs, 1)
        return labs

    @staticmethod
    def _titlesToList(titles):
        """Flatten titles (str, list, tuple, or ndarray) to a list; None passes through."""
        if titles is None:
            return None
        if isinstance(titles, np.ndarray):
            return titles.flatten().tolist()
        if isinstance(titles, str):
            return [titles]
        if isinstance(titles, (tuple, list)):
            return np.array(titles).flatten().tolist()
        TFCPrint.Error(
            "The titles provided are not of a valid type. Please provide valid titles."
        )
        return titles

    def _gridKwargs(self, row, col) -> dict:
        """Return row/col keyword arguments for plotly's subplot-aware update
        methods, or an empty dict when the figure has no subplot grid."""
        return {"row": row, "col": col} if self._hasSubplots else {}

    def _updateScene(self, row, col, **props) -> None:
        """Apply layout updates to the 3D scene at 1-based grid position
        (row, col), or to the figure's only scene when not using subplots."""
        if self._hasSubplots:
            self.fig.update_scenes(row=row, col=col, **props)
        else:
            self.fig.update_layout(scene=props)

    def Surface(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a plotly surface on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Surface

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Surface(**kwargs), row=row, col=col)

    def Scatter3d(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a 3d plotly scatter on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Scatter3d

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Scatter3d(**kwargs), row=row, col=col)

    def Scatter(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a plotly scatter on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Scatter

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Scatter(**kwargs), row=row, col=col)

    def Histogram(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a plotly histogram on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Histogram

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Histogram(**kwargs), row=row, col=col)

    def Contour(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a plotly contour on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Contour

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Contour(**kwargs), row=row, col=col)

    def Box(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a plotly box on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Box

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Box(**kwargs), row=row, col=col)

    def Violin(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a plotly violin on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Violin

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Violin(**kwargs), row=row, col=col)

    def Volume(
        self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
    ) -> plotly.graph_objects.Figure:
        """
        Creates a plotly volume on the subplot specified by row and col or on
        the main figure if not using subplots.

        Parameters
        ----------
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed on to plotly.graph_objects.Volume

        Returns
        -------
        fig : plotly.graph_objects.Figure
            The figure the trace was added to (the return value of Figure.add_trace).
        """
        return self.fig.add_trace(go.Volume(**kwargs), row=row, col=col)

    def show(self, **kwargs):
        """
        Calls the figure's show method.

        Parameters
        ----------
        **kwargs : dict, optional
            Keyword arguments passed on to fig.show
        """
        self.fig.show(**kwargs)

    def save(
        self,
        fileName: Path,
        tight: bool = True,
        fileType: Literal["pdf", "jpg", "png", "svg", "eps", "html", None] = None,
        **kwargs,
    ):
        """
        Saves the figure using the type specified. If HTML is specified, the figure will
        be saved as a dynamic html file. All other file types are static.

        Parameters
        ----------
        fileName : Path
            File name to save the figure as (str or pathlib path).
        tight : boolean, optional
            If the fileType is pdf or png and this value is true, then a tool is used to eliminate whitespace. pdfCropMargins is used for PDFs and PIL is used for png's. (Default value = True)
        fileType : Literal["pdf","jpg","png","svg","eps","html",None], optional
            File suffix to use; when given, "." + fileType is appended to fileName. If None, then the suffix will be inferred from the suffix of the fileName; if that fails, png is used and ".png" is appended. (Default value = None)
        **kwargs : dict, optional
            Keyword arguments passed onto fig.write_image or fig.write_html, depending on fileType.
        """
        from pathlib import PurePath

        # Accept pathlib paths as well as strings; the concatenation below needs a str.
        fileName = str(fileName)
        if fileType:
            fileNameFull = fileName + "." + fileType
        else:
            suffix = PurePath(fileName).suffix[1:]
            if suffix in ("pdf", "jpg", "png", "svg", "eps", "html"):
                # The name already ends with the suffix; do not append it again.
                fileType = suffix
                fileNameFull = fileName
            else:
                fileType = "png"
                TFCPrint.Warning(
                    f"Warning, file type could not be inferred from {fileName}. The file type has been set to png."
                )
                fileNameFull = fileName + "." + fileType

        if fileType == "html":
            self.fig.write_html(fileNameFull, **kwargs)
        elif fileType == "png" and tight:
            self._saveTightPng(fileNameFull, **kwargs)
        else:
            self.fig.write_image(fileNameFull, **kwargs)
            if fileType == "pdf" and tight:
                from pdfCropMargins import crop

                crop(["-p", "0", fileNameFull])

    def _saveTightPng(self, fileNameFull: str, **kwargs) -> None:
        """
        Render the figure to PNG, crop away the uniform border whose color
        matches the top-left pixel, and write the result to fileNameFull.
        kwargs are passed on to fig.to_image.
        """
        from io import BytesIO
        from PIL import Image

        with Image.open(BytesIO(self.fig.to_image("png", **kwargs))) as rendered:
            # Normalize to RGBA so RGB renders are handled too.
            pilImage = rendered.convert("RGBA")
        pixels = np.asarray(pilImage)
        # A pixel is content if any channel differs from the top-left pixel,
        # which is taken as the background color. Only rows/cols are cropped,
        # never the channel axis.
        content = np.any(pixels != pixels[0, 0], axis=-1)
        usedRows = np.flatnonzero(content.any(axis=1))
        usedCols = np.flatnonzero(content.any(axis=0))
        if usedRows.size:
            # PIL crop box is (left, upper, right, lower) with right/lower exclusive.
            pilImage = pilImage.crop(
                (
                    int(usedCols[0]),
                    int(usedRows[0]),
                    int(usedCols[-1]) + 1,
                    int(usedRows[-1]) + 1,
                )
            )
        # An entirely blank image is saved uncropped.
        pilImage.save(fileNameFull)

    def UploadToPlotly(self, username: str, apiKey: str, fileName: str, autoOpen: bool = False):
        """
        Upload your plot to Plotly.

        Parameters
        ----------
        username : str
            Plotly username
        apiKey : str
            Plotly api_key
        fileName : str
            Name of the file to save the plot as.
        autoOpen : bool, optional
            If true, plot will open in browser after saving. (Default value = False)

        Returns
        -------
        Whatever chart_studio.plotly.plot returns (presumably the plot URL; see chart_studio docs).
        """
        import chart_studio

        # Note: this writes the credentials via chart_studio's credentials-file helper.
        chart_studio.tools.set_credentials_file(username=username, api_key=apiKey)
        return chart_studio.plotly.plot(self.fig, filename=fileName, auto_open=autoOpen)

    def view(
        self,
        azimuth: float,
        elevation: float,
        row: Optional[uint] = None,
        col: Optional[uint] = None,
        viewDistance: Optional[float] = None,
    ):
        """
        Change the view on the subplot specified by row and col or on
        the main figure if not using subplots.

        The camera eye is placed at
        viewDistance * (-sin(az)cos(el), -cos(az)cos(el), sin(el)).

        Parameters
        ----------
        azimuth : float
            Azimuth value in degrees
        elevation : float
            Elevation value in degrees
        row : Optional[uint]
            Subplot row, 1-based (Default value = None)
        col : Optional[uint]
            Subplot column, 1-based (Default value = None)
        viewDistance : Optional[float]
            Distance from camera to plot. If given, it also becomes the new
            default stored in self.viewDistance. (Default value = self.viewDistance)
        """
        if not self._is3d:
            TFCPrint.Error("The view method is only for 3d plots.")
        if viewDistance:
            self.viewDistance = viewDistance
        az = np.deg2rad(azimuth)
        el = np.deg2rad(elevation)
        eye = dict(
            x=self.viewDistance * -np.sin(az) * np.cos(el),
            y=self.viewDistance * -np.cos(az) * np.cos(el),
            z=self.viewDistance * np.sin(el),
        )
        if row and col:
            # Let plotly resolve the scene at (row, col); "row + col - 1" picks the
            # wrong scene on grids with more than one row and column.
            self._updateScene(row, col, camera=dict(eye=eye))
        else:
            self.fig["layout"]["scene"]["camera"].eye = eye

    def FullScreen(self):
        """
        Make the plot full screen.

        The screen size in pixels is queried through a hidden tkinter root window.
        """
        import tkinter as tk

        root = tk.Tk()
        try:
            root.withdraw()
            width = root.winfo_screenwidth()
            height = root.winfo_screenheight()
        finally:
            # Release the Tk root so no hidden window/interpreter lingers.
            root.destroy()
        self.fig.update_layout(width=width, height=height)

    def PartScreen(self, width: float, height: float, units: Literal["in", "mm", "px"] = "in"):
        """
        Make the plot size equal to width x height.

        For "in" and "mm", the conversion to pixels uses the screen's pixel and
        millimeter sizes queried through a hidden tkinter root window.

        Parameters
        ----------
        width : float
            Width of the plot
        height : float
            Height of the plot
        units : Literal["in","mm","px"], optional
            Units width and height are given in. Any value other than "in" or
            "px" is treated as "mm". (Default value = "in")
        """
        if units == "px":
            self.fig.update_layout(width=width, height=height)
            return

        import tkinter as tk

        root = tk.Tk()
        try:
            root.withdraw()
            if units == "in":
                widthConvert = root.winfo_screenwidth() / (root.winfo_screenmmwidth() / 25.4)
                heightConvert = root.winfo_screenheight() / (root.winfo_screenmmheight() / 25.4)
            else:
                widthConvert = root.winfo_screenwidth() / root.winfo_screenmmwidth()
                heightConvert = root.winfo_screenheight() / root.winfo_screenmmheight()
        finally:
            # Release the Tk root so no hidden window/interpreter lingers.
            root.destroy()
        self.fig.update_layout(
            width=int(width * widthConvert), height=int(height * heightConvert)
        )

    def NormalizeColorScale(
        self,
        types: Optional[list[str]] = None,
        data: Optional[str] = None,
        cmax: Optional[float] = None,
        cmin: Optional[float] = None,
    ):
        """
        Normalizes the color scale for the plots whose type is in the types list.

        If cmax and/or cmin are given, then the data variable is not used to compute
        them. If cmax and cmin are not specified, then they are set by taking the max
        and min of the data that matches the data variable in all plots whose type is
        in the types list.

        The resulting cmax/cmin are assigned to every plot whose type is in the types
        list and, when data is given, that has a property named data. If nothing
        could be computed, the plots are left unchanged.

        Parameters
        ----------
        types: Optional[list[str]]
            Plot types (trace class names, e.g. "Surface") to set cmax and cmin for. (Default value = None, i.e. no types)
        data: Optional[str]
            Data type to use to calculate cmax and cmin if not already specified. (Default value = None)
        cmax: Optional[float]
            cmax value to use when setting the colorscale
        cmin: Optional[float]
            cmin value to use when setting the colorscale
        """
        types = [] if types is None else types
        traces = [k for k in self.fig.data if type(k).__name__ in types]

        if cmax is None and cmin is None and data is not None:
            for k in traces:
                # "data in k" only says the property exists; it may still be unset.
                if data not in k or k[data] is None:
                    continue
                # np.asarray handles the tuples plotly stores for list input.
                values = np.asarray(k[data])
                if values.size == 0:
                    continue
                kmax, kmin = values.max(), values.min()
                cmax = kmax if cmax is None else max(cmax, kmax)
                cmin = kmin if cmin is None else min(cmin, kmin)

        colorScaleUpdate = {}
        if cmax is not None:
            colorScaleUpdate["cmax"] = cmax
        if cmin is not None:
            colorScaleUpdate["cmin"] = cmin
        if not colorScaleUpdate:
            return

        for k in traces:
            if data is None or data in k:
                k.update(colorScaleUpdate)