
.. _program_listing_file_src_tfc_utils_PlotlyMakePlot.py:

Program Listing for File PlotlyMakePlot.py
==========================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_tfc_utils_PlotlyMakePlot.py>` (``src/tfc/utils/PlotlyMakePlot.py``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: py

   import numpy as np
   import plotly
   import plotly.graph_objects as go
   
   from .TFCUtils import TFCPrint
   from .tfc_types import StrArrayLike, uint, Path, Literal
   from typing import Optional
   
   # NOTE(review): instantiating TFCPrint at import time presumably performs one-time
   # setup for the TFCPrint.Error / TFCPrint.Warning helpers used below -- confirm
   # against TFCUtils.
   TFCPrint()
   
   
   class MakePlot:
       """
       A MakePlot class for Plotly.
       This class wraps common Plotly functions to ease figure creation.

       The class-level attributes below hold the shared styling: ``_template`` is
       applied to every figure through ``fig.update_layout(template=...)`` in
       ``__init__``, while ``_backgroundColor`` and ``_gridColor`` are used there
       for the plot/scene background and the grid and zero lines.
       """
   
       # Plotly template applied to every figure created by this class.
       _template = {
           # Per-trace-type defaults. The surface, volume and contour entries all use
           # the same black -> purple -> blue -> aquamarine colorscale.
           "data": {
               "surface": [
                   {
                       "colorbar": {"outlinewidth": 0, "ticks": ""},
                       "colorscale": [
                           [0.0, "black"],
                           [0.2, "rebeccapurple"],
                           [0.3, "blueviolet"],
                           [0.7, "#4682B4"],
                           [1.0, "aquamarine"],
                       ],
                       "lighting": {
                           "ambient": 0.6,
                           "specular": 0.05,
                           "diffuse": 0.4,
                           "fresnel": 3.0,
                       },
                       "type": "surface",
                   }
               ],
               "volume": [
                   {
                       "colorbar": {"outlinewidth": 0, "ticks": ""},
                       "colorscale": [
                           [0.0, "black"],
                           [0.2, "rebeccapurple"],
                           [0.3, "blueviolet"],
                           [0.7, "#4682B4"],
                           [1.0, "aquamarine"],
                       ],
                       "lighting": {
                           "ambient": 0.6,
                           "specular": 0.05,
                           "diffuse": 0.4,
                           "fresnel": 3.0,
                       },
                       "type": "volume",
                   }
               ],
               # The contour entry has no lighting settings and no "type" key.
               "contour": [
                   {
                       "colorbar": {"outlinewidth": 0, "ticks": ""},
                       "colorscale": [
                           [0.0, "black"],
                           [0.2, "rebeccapurple"],
                           [0.3, "blueviolet"],
                           [0.7, "#4682B4"],
                           [1.0, "aquamarine"],
                       ],
                   }
               ],
           },
           # Figure-wide layout defaults: uniform 50-unit margins, automatic sizing,
           # font size 16 and a fully transparent paper background.
           "layout": {
               "margin": {"t": 50, "b": 50, "r": 50, "l": 50},
               "autosize": True,
               "font": {"size": 16},
               "paper_bgcolor": "rgba(0,0,0,0)",
           },
       }
       # Fully transparent background used for 2D plot areas and 3D scene axes.
       _backgroundColor = "rgba(0,0,0,0)"
       # Gray used for grid lines and zero lines.
       _gridColor = "rgb(176, 176, 176)"
   
       def __init__(
           self,
           xlabs: StrArrayLike,
           ylabs: StrArrayLike,
           titles: Optional[StrArrayLike] = None,
           zlabs: Optional[StrArrayLike] = None,
       ):
           """
           This function initializes the plot/subplots based on the inputs provided.
   
           The shape of ``xlabs`` decides the layout: a single label creates one plain
           figure, anything else creates a grid of subplots with ``xlabs.shape[0]`` rows
           and ``xlabs.shape[1]`` columns (1-D label sequences are treated as a column).
   
           Parameters
           ----------
           xlabs: StrArrayLike
               The x-axes labels for the plots
           ylabs: StrArrayLike
               The y-axes labels for the plots
           titles: StrArrayLike, optional
               The titles for the plots. (Default value = None)
           zlabs: StrArrayLike, optional
               The z-axes labels for the plots. Setting this forces subplots to be 3D. (Default value = None)
           """
   
           # View distance used by the view method
           self.viewDistance = 2.5
   
           # Consistify all label types
           xlabs = self._consistifyLabels(xlabs, "x")
           ylabs = self._consistifyLabels(ylabs, "y")
           self._is3d = zlabs is not None
           if self._is3d:
               zlabs = self._consistifyLabels(zlabs, "z")
   
           # Titles are always handed to plotly as a flat list
           if titles is not None:
               if isinstance(titles, np.ndarray):
                   titles = titles.flatten().tolist()
               elif isinstance(titles, str):
                   titles = [titles]
               elif isinstance(titles, (tuple, list)):
                   titles = np.array(titles).flatten().tolist()
               else:
                   TFCPrint.Error(
                       "The titles provided are not of a valid type. Please provide valid titles."
                   )
   
           numRows = xlabs.shape[0]
           numCols = xlabs.shape[1]
   
           # Create either a grid of subplots or a single figure
           self._hasSubplots = int(np.prod(xlabs.shape)) != 1
           if self._hasSubplots:
               from plotly.subplots import make_subplots
   
               subplotKwargs = dict(rows=numRows, cols=numCols, subplot_titles=titles)
               if self._is3d:
                   # One independent spec dict per cell (no shared references)
                   subplotKwargs["specs"] = [
                       [{"is_3d": True} for _ in range(numCols)] for _ in range(numRows)
                   ]
               self.fig = make_subplots(**subplotKwargs)
           else:
               self.fig = go.Figure()
               if titles is not None:
                   self.fig["layout"]["title"] = titles[0]
   
           if self._is3d:
               # Label and color the axes of every scene. A single 3D figure is simply
               # the 1x1 case, whose only scene is stored under the key "scene".
               for row in range(numRows):
                   for col in range(numCols):
                       scene = self.fig["layout"][self._subplotKey("scene", row, col, numCols)]
                       for axisName, labs in (("xaxis", xlabs), ("yaxis", ylabs), ("zaxis", zlabs)):
                           scene[axisName].update(
                               title=labs[row, col],
                               exponentformat="e",
                               gridcolor=self._gridColor,
                               zerolinecolor=self._gridColor,
                               backgroundcolor=self._backgroundColor,
                           )
           else:
               if self._hasSubplots:
                   for row in range(numRows):
                       for col in range(numCols):
                           self.fig["layout"][self._subplotKey("xaxis", row, col, numCols)].update(
                               title=xlabs[row, col], exponentformat="e"
                           )
                           self.fig["layout"][self._subplotKey("yaxis", row, col, numCols)].update(
                               title=ylabs[row, col], exponentformat="e"
                           )
               else:
                   self.fig.update_xaxes(title=xlabs[0, 0], exponentformat="e")
                   self.fig.update_yaxes(title=ylabs[0, 0], exponentformat="e")
   
               # Update grid and background colors
               self.fig.update_xaxes(
                   gridcolor=self._gridColor, linecolor="black", zerolinecolor=self._gridColor
               )
               self.fig.update_yaxes(
                   gridcolor=self._gridColor, linecolor="black", zerolinecolor=self._gridColor
               )
               self.fig.update_layout(plot_bgcolor=self._backgroundColor)
   
           # Update layout
           self.fig.update_layout(
               template=self._template,
           )
   
       @staticmethod
       def _consistifyLabels(labs, axisName: str):
           """
           Convert the labels of one axis to a numpy array; 1-D inputs become a column.
   
           Parameters
           ----------
           labs : StrArrayLike
               Labels given as a string, tuple, list or numpy array.
           axisName : str
               Axis letter ("x", "y" or "z") used in the error message.
   
           Returns
           -------
           labs : np.ndarray
               The labels as an array.
           """
           if isinstance(labs, np.ndarray):
               pass
           elif isinstance(labs, str):
               labs = np.array([[labs]])
           elif isinstance(labs, (tuple, list)):
               labs = np.array(labs)
           else:
               TFCPrint.Error(
                   f"The {axisName}labels provided are not of a valid type. Please provide valid {axisName}labels"
               )
           if len(labs.shape) == 1:
               labs = np.expand_dims(labs, 1)
           return labs
   
       @staticmethod
       def _subplotKey(prefix: str, row: int, col: int, numCols: int) -> str:
           """
           Return the layout key (e.g. "xaxis3" or "scene2") of the subplot at the
           zero-based position (row, col) in a grid with numCols columns.
           """
           # make_subplots numbers its axes/scenes in row-major order starting at 1,
           # and the first one carries no numeric suffix.
           index = row * numCols + col + 1
           return prefix if index == 1 else prefix + str(index)
   
       def Surface(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a plotly surface trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Surface
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Surface(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def Scatter3d(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a 3d plotly scatter trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Scatter3d
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Scatter3d(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def Scatter(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a plotly scatter trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Scatter
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Scatter(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def Histogram(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a plotly histogram trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Histogram
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Histogram(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def Contour(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a plotly contour trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Contour
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Contour(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def Box(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a plotly box trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Box
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Box(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def Violin(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a plotly violin trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Violin
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Violin(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def Volume(
           self, row: Optional[uint] = None, col: Optional[uint] = None, **kwargs
       ) -> go.Figure:
           """
           Adds a plotly volume trace to the subplot specified by row and col, or to
           the main figure if not using subplots.
   
           Parameters
           ----------
           row : Optional[uint]
               Subplot row (Default value = None)
           col : Optional[uint]
               Subplot column (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed on to plotly.graph_objects.Volume
   
           Returns
           -------
           fig : plotly.graph_objects.Figure
               The return value of ``self.fig.add_trace``, i.e. the figure the trace
               was added to (not the trace itself).
           """
           trace = go.Volume(**kwargs)
           return self.fig.add_trace(trace, row=row, col=col)
   
       def show(self, **kwargs):
           """
           Displays the wrapped figure by forwarding to its own show method.
   
           Parameters
           ----------
           **kwargs : dict, optional
               Keyword arguments handed unchanged to fig.show
           """
           figure = self.fig
           figure.show(**kwargs)
   
       def save(
           self,
           fileName: Path,
           tight: bool = True,
           fileType: Literal["pdf", "jpg", "png", "svg", "eps", "html", None] = None,
           **kwargs,
       ):
           """
           Saves the figure using the type specified. If HTML is specified, the figure will
           be saved as a dynamic html file. All other file types are static.
   
           Parameters
           ----------
           fileName : Path
               File name to save the figure as. Strings and path-like objects are accepted.
           tight : boolean, optional
               If the fileType is pdf or png and this value is true, then a tool is used to eliminate whitespace. pdfCropMargins is used for PDFs and PIL is used for png's. (Default value = True)
           fileType : Literal["pdf","jpg","png","svg","eps","html",None], optional
               File suffix to use. If None, then the suffix will be inferred from the suffix of the fileName;
               if it cannot be inferred, png is used and ".png" is appended. If given, "." + fileType is
               always appended to fileName. (Default value = None)
           **kwargs : dict, optional
               Keyword arguments passed onto fig.write_image (fig.to_image for tightly cropped png's) or
               fig.write_html, depending on fileType.
           """
           import os
           import pathlib
   
           # Work on a plain string so a suffix can be appended even when a path-like
           # object was passed in (path objects do not support "+" with a string).
           fileNameFull = os.fspath(fileName)
   
           if not fileType:
               suffix = pathlib.Path(fileNameFull).suffix[1:]
               if suffix in ["pdf", "jpg", "png", "svg", "eps", "html"]:
                   fileType = suffix
               else:
                   fileType = "png"
                   TFCPrint.Warning(
                       f"Warning, file type could not be inferred from {fileName}. The file type has been set to png."
                   )
                   fileNameFull += "." + fileType
           else:
               fileNameFull += "." + fileType
   
           if fileType == "html":
               self.fig.write_html(fileNameFull, **kwargs)
           elif fileType == "png" and tight:
               from io import BytesIO
               from PIL import Image
   
               # Render to memory, then crop away the border that matches the top-left pixel
               pilImage = Image.open(BytesIO(self.fig.to_image(fileType, **kwargs)))
               pixels = np.array(pilImage)
               differs = pixels != pixels[0, 0]
               if differs.ndim == 3:
                   # A pixel counts as content if ANY channel differs from the background;
                   # only rows/columns are cropped, the channels are always kept intact.
                   differs = differs.any(axis=2)
               rowIdx, colIdx = np.nonzero(differs)
               if rowIdx.size:
                   # PIL boxes are (left, upper, right, lower) with exclusive right/lower
                   pilImage = pilImage.crop(
                       (
                           int(colIdx.min()),
                           int(rowIdx.min()),
                           int(colIdx.max()) + 1,
                           int(rowIdx.max()) + 1,
                       )
                   )
               # A completely uniform image has nothing to crop and is saved as rendered
               pilImage.save(fileNameFull)
           else:
               self.fig.write_image(fileNameFull, **kwargs)
               if fileType == "pdf" and tight:
                   from pdfCropMargins import crop
   
                   crop(["-p", "0", fileNameFull])
   
       def UploadToPlotly(self, username: str, apiKey: str, fileName: str, autoOpen: bool = False):
           """
           Store the given Chart Studio credentials and upload the figure.
   
           Parameters
           ----------
           username : str
               Plotly username
           apiKey : str
               Plotly api_key
           fileName : str
               Name of the file to save the plot as.
           autoOpen : bool, optional
               If true, plot will open in browser after saving. (Default value = False)
   
           Returns
           -------
           result
               Whatever chart_studio.plotly.plot returns for the upload.
           """
   
           import chart_studio
   
           credentials = dict(username=username, api_key=apiKey)
           chart_studio.tools.set_credentials_file(**credentials)
   
           uploadResult = chart_studio.plotly.plot(self.fig, filename=fileName, auto_open=autoOpen)
           return uploadResult
   
       def view(
           self,
           azimuth: float,
           elevation: float,
           row: Optional[uint] = None,
           col: Optional[uint] = None,
           viewDistance: Optional[float] = None,
       ):
           """
           Change the view on the subplot specified by row and col or on
           the main figure if not using subplots.
   
           Parameters
           ----------
           azimuth : float
               Azimuth value in degrees
           elevation : float
               Elevation value in degrees
           row : Optional[uint]
               Subplot row, counted from 1 (Default value = None)
           col : Optional[uint]
               Subplot column, counted from 1 (Default value = None)
           viewDistance : Optional[float]
               Distance from camera to plot. If given, it is also stored in self.viewDistance
               and reused by later calls. (Default value = self.viewDistance)
           """
   
           if not self._is3d:
               TFCPrint.Error("The view method is only for 3d plots.")
   
           if viewDistance:
               self.viewDistance = viewDistance
   
           # Convert to radians without modifying the caller's arguments in place
           azimuthRad = azimuth * (np.pi / 180.0)
           elevationRad = elevation * (np.pi / 180.0)
           eye = self.viewDistance * np.array(
               [
                   -np.sin(azimuthRad) * np.cos(elevationRad),
                   -np.cos(azimuthRad) * np.cos(elevationRad),
                   np.sin(elevationRad),
               ]
           )
   
           if row and col and self._hasSubplots:
               # Let plotly resolve the scene for this grid cell; deriving the scene
               # number from row + col picks the wrong scene on grids that have more
               # than one row and more than one column.
               scene = self.fig.get_subplot(row, col)
           else:
               scene = self.fig["layout"]["scene"]
   
           scene["camera"].eye = dict(x=eye[0], y=eye[1], z=eye[2])
   
       def FullScreen(self):
           """
           Make the plot full screen.
   
           The screen size in pixels is queried through a temporary, hidden tkinter
           root window, and the figure's width and height are set to that size.
           """
           import tkinter as tk
   
           root = tk.Tk()
           try:
               root.withdraw()
               width = root.winfo_screenwidth()
               height = root.winfo_screenheight()
           finally:
               # Release the hidden root window so repeated calls do not accumulate Tk instances
               root.destroy()
   
           self.fig.update_layout(width=width, height=height)
   
       def PartScreen(self, width: float, height: float, units: Literal["in", "mm", "px"] = "in"):
           """
           Make the plot size equal to width x height.
   
           Parameters
           ----------
           width : float
               Width of the plot
           height : float
               Height of the plot
           units : Literal["in","mm","px"], optional
               Units width and height are given in. Pixel values are used as given; inches and
               millimeters are converted to whole pixels using the screen size reported by
               tkinter. Any value other than "px" or "in" is treated as millimeters.
               (Default value = inches)
           """
           if units == "px":
               self.fig.update_layout(width=width, height=height)
               return
   
           import tkinter as tk
   
           root = tk.Tk()
           try:
               root.withdraw()
               if units == "in":
                   # 25.4 mm per inch
                   widthConvert = root.winfo_screenwidth() / (root.winfo_screenmmwidth() / 25.4)
                   heightConvert = root.winfo_screenheight() / (root.winfo_screenmmheight() / 25.4)
               else:
                   widthConvert = root.winfo_screenwidth() / root.winfo_screenmmwidth()
                   heightConvert = root.winfo_screenheight() / root.winfo_screenmmheight()
           finally:
               # Release the hidden root window so repeated calls do not accumulate Tk instances
               root.destroy()
   
           self.fig.update_layout(
               width=int(width * widthConvert), height=int(height * heightConvert)
           )
   
       def NormalizeColorScale(
           self,
           types: list[str] = [],
           data: Optional[str] = None,
           cmax: Optional[float] = None,
           cmin: Optional[float] = None,
       ):
           """
           Normalizes the color scale for the plots whose type is in the types list.
   
           If cmax and/or cmin are given, then the data variable is not used, and all
           plots whose type is in the types list are assigned that cmax/cmin value.
   
           If cmax and cmin are not specified, then they are set by taking the max and min
           of the data that matches the data variable in all plots whose type is in the
           types list. Only the plots that actually hold such data are updated, and nothing
           is changed if no plot does.
   
           Parameters
           ----------
           types: list[str]
               Plot types (trace class names, e.g. "Surface") to set cmax and cmin for.
               The list is only read, never modified.
           data: Optional[str]
               Data type to use to calculate cmax and cmin if not already specified. (Default value = None)
           cmax: Optional[float]
               cmax value to use when setting the colorscale
           cmin: Optional[float]
               cmin value to use when setting the colorscale
           """
   
           # Traces selected purely by their class name
           selected = [k for k in self.fig.data if k.__class__.__name__ in types]
   
           if cmax is None and cmin is None:
               if data is None:
                   TFCPrint.Warning(
                       "Warning, NormalizeColorScale needs either data or cmax/cmin. The color scales were not changed."
                   )
                   return
   
               cmax = -np.inf
               cmin = np.inf
               withData = []
               for k in selected:
                   if data not in k:
                       continue
                   values = k[data]
                   if values is None:
                       continue
                   values = np.asarray(values)
                   if values.size == 0:
                       continue
                   withData.append(k)
                   # Compare only against values taken from this very trace
                   kmax = values.max()
                   kmin = values.min()
                   if kmax > cmax:
                       cmax = kmax
                   if kmin < cmin:
                       cmin = kmin
   
               if not withData:
                   # No trace contributed data, so there are no limits to apply
                   return
               selected = withData
   
           colorScaleUpdate = {}
           if cmax is not None:
               colorScaleUpdate["cmax"] = cmax
           if cmin is not None:
               colorScaleUpdate["cmin"] = cmin
   
           for k in selected:
               k.update(colorScaleUpdate)
