From 8a5ee6d9bfb21c896e24513fd3c111c729219e26 Mon Sep 17 00:00:00 2001 From: torzdf <36920800+torzdf@users.noreply.github.com> Date: Fri, 10 Feb 2023 15:52:34 +0000 Subject: [PATCH] Bugfix: Patch extract memory leak in batch mode --- lib/image.py | 8 ++++++-- lib/multithreading.py | 7 ++++++- plugins/extract/_base.py | 1 - scripts/extract.py | 17 ++++++++++++++--- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/lib/image.py b/lib/image.py index f85ff99..dd9910f 100644 --- a/lib/image.py +++ b/lib/image.py @@ -895,9 +895,10 @@ class ImageIO(): def _set_thread(self): """ Set the background thread for the load and save iterators and launch it. """ - logger.debug("Setting thread") + logger.trace("Setting thread") # type:ignore[attr-defined] if self._thread is not None and self._thread.is_alive(): - logger.debug("Thread pre-exists and is alive: %s", self._thread) + logger.trace("Thread pre-exists and is alive: %s", # type:ignore[attr-defined] + self._thread) return self._thread = MultiThread(self._process, self._queue, @@ -921,6 +922,7 @@ class ImageIO(): logger.debug("Received Close") if self._thread is not None: self._thread.join() + del self._thread self._thread = None logger.debug("Closed") @@ -1461,6 +1463,8 @@ class ImagesSaver(ImageIO): logger.trace("Saved image: '%s'", filename) # type:ignore except Exception as err: # pylint: disable=broad-except logger.error("Failed to save image '%s'. Original Error: %s", filename, str(err)) + del image + del filename def save(self, filename: str, diff --git a/lib/multithreading.py b/lib/multithreading.py index a06f399..e85685a 100644 --- a/lib/multithreading.py +++ b/lib/multithreading.py @@ -206,7 +206,10 @@ class MultiThread(): return retval def join(self) -> None: - """ Join the running threads, catching and re-raising any errors """ + """ Join the running threads, catching and re-raising any errors + + Clear the list of threads for class instance re-use + """ logger.debug("Joining Threads: '%s'", self._name) for thread in self._threads: logger.debug("Joining Thread: '%s'", thread._name) # pylint: disable=protected-access @@ -215,6 +218,8 @@ class MultiThread(): logger.error("Caught exception in thread: '%s'", thread._name) # pylint: disable=protected-access raise thread.err[1].with_traceback(thread.err[2]) + del self._threads + self._threads = [] logger.debug("Joined all Threads: '%s'", self._name) diff --git a/plugins/extract/_base.py b/plugins/extract/_base.py index 188eb67..efb95b7 100644 --- a/plugins/extract/_base.py +++ b/plugins/extract/_base.py @@ -394,7 +394,6 @@ class Extractor(): """ for thread in self._threads: thread.join() - del thread def check_and_raise_error(self) -> None: """ Check all threads for errors diff --git a/scripts/extract.py b/scripts/extract.py index 4e9c189..35b5343 100644 --- a/scripts/extract.py +++ b/scripts/extract.py @@ -7,6 +7,7 @@ import logging import os import sys from argparse import Namespace +from multiprocessing import Process from typing import List, Dict, Optional, Tuple, TYPE_CHECKING, Union import numpy as np @@ -150,19 +151,29 @@ class Extract(): # pylint:disable=too-few-public-methods Should only be called from :class:`lib.cli.launcher.ScriptExecutor` """ logger.info('Starting, this may take a while...') - inputs = self._input_locations if self._args.batch_mode: logger.info("Batch mode selected processing: %s", self._input_locations) for job_no, location in enumerate(self._input_locations): if self._args.batch_mode: - logger.info("Processing job %s of %s: '%s'", job_no + 1, len(inputs), location) + logger.info("Processing job %s of %s: '%s'", + job_no + 1, len(self._input_locations), location) arguments = Namespace(**self._args.__dict__) arguments.input_dir = location arguments.output_dir = self._output_for_input(location) else: arguments = self._args extract = _Extract(self._extractor, arguments) - extract.process() + if len(self._input_locations) > 1: + # TODO - Running this in a process is hideously hacky. However, there is a memory + # leak in some instances when running in batch mode. Many days have been spent + # trying to track this down to no avail (most likely coming from C-code.) Running + # the extract job inside a process prevents the memory leak in testing. This should + # be replaced if/when the memory leak is found + proc = Process(target=extract.process) + proc.start() + proc.join() + else: + extract.process() self._extractor.reset_phase_index()