diff --git a/pylatex/figure.py b/pylatex/figure.py index 0e3add51..36fd7f4a 100644 --- a/pylatex/figure.py +++ b/pylatex/figure.py @@ -22,7 +22,7 @@ def add_image( filename, *, width=NoEscape(r"0.8\textwidth"), - placement=NoEscape(r"\centering") + placement=NoEscape(r"\centering"), ): """Add an image to the figure. @@ -50,7 +50,7 @@ def add_image( StandAloneGraphic(image_options=width, filename=fix_filename(filename)) ) - def _save_plot(self, *args, extension="pdf", **kwargs): + def _save_plot(self, *args, figure=None, extension="pdf", **kwargs): """Save the plot. Returns @@ -64,11 +64,12 @@ def _save_plot(self, *args, extension="pdf", **kwargs): filename = "{}.{}".format(str(uuid.uuid4()), extension.strip(".")) filepath = posixpath.join(tmp_path, filename) - plt.savefig(filepath, *args, **kwargs) + fig = figure or plt.gcf() + fig.savefig(filepath, *args, **kwargs) return filepath - def add_plot(self, *args, extension="pdf", **kwargs): - """Add the current Matplotlib plot to the figure. + def add_plot(self, *args, figure=None, extension="pdf", **kwargs): + """Add a Matplotlib plot to the figure. The plot that gets added is the one that would normally be shown when using ``plt.show()``. @@ -77,6 +78,8 @@ def add_plot(self, *args, extension="pdf", **kwargs): ---- args: Arguments passed to plt.savefig for displaying the plot. + figure: + Optional matplotlib figure. If None add the current figure. extension : str extension of image file indicating figure file type kwargs: @@ -92,7 +95,7 @@ def add_plot(self, *args, extension="pdf", **kwargs): if key in kwargs: add_image_kwargs[key] = kwargs.pop(key) - filename = self._save_plot(*args, extension=extension, **kwargs) + filename = self._save_plot(*args, figure=figure, extension=extension, **kwargs) self.add_image(filename, **add_image_kwargs) diff --git a/tests/test_pictures.py b/tests/test_pictures.py index 1b222e35..072a5ed4 100644 --- a/tests/test_pictures.py +++ b/tests/test_pictures.py @@ -2,18 +2,40 @@ import os +import matplotlib.pyplot as plt + from pylatex import Document, Section from pylatex.figure import Figure -def test(): +def test_add_image(): doc = Document() - section = Section("Multirow Test") + section = Section("Add image Test") figure = Figure() image_filename = os.path.join(os.path.dirname(__file__), "../examples/kitten.jpg") figure.add_image(image_filename) - figure.add_caption("Whoooo an imagage of a pdf") + figure.add_caption("Whoooo an image of a kitty") + section.append(figure) + doc.append(section) + + doc.generate_pdf() + + +def test_add_plot(): + doc = Document() + section = Section("Add plot Test") + mplfig = plt.figure() + + figure = Figure() + figure.add_plot() + figure.add_caption("Whoooo current matplotlib fig") section.append(figure) + + figure = Figure() + figure.add_plot(figure=mplfig) + figure.add_caption("Whoooo image from figure handle") + section.append(figure) + doc.append(section) doc.generate_pdf()