diff --git a/tools/mask/mask.py b/tools/mask/mask.py index 92a5aa0..97ae49b 100644 --- a/tools/mask/mask.py +++ b/tools/mask/mask.py @@ -4,6 +4,7 @@ import logging import os import sys from argparse import Namespace +from multiprocessing import Process from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union import cv2 @@ -79,28 +80,28 @@ class Mask(): # pylint:disable=too-few-public-methods logger.debug("Returning output: '%s' for input: '%s'", retval, input_location) return retval - def _get_extractor(self) -> Optional[Extractor]: - """ Obtain a Mask extractor plugin and launch it + @staticmethod + def _run_mask_process(arguments: Namespace) -> None: + """ The mask process to be run in a spawned process. - Returns - ------- - :class:`plugins.extract.pipeline.Extractor`: - The launched Extractor + In some instances, batch-mode memory leaks. Launching each job in a separate process + prevents this leak. + + Parameters + ---------- + arguments: :class:`argparse.Namespace` + The :mod:`argparse` arguments to be used for the given job """ - if self._args.processing == "output": - logger.debug("Update type `output` selected. Not launching extractor") - return None - logger.debug("masker: %s", self._args.masker) - extractor = Extractor(None, None, self._args.masker, exclude_gpus=self._args.exclude_gpus) - logger.debug(extractor) - return extractor + logger.debug("Starting process: (arguments: %s)", arguments) + mask = _Mask(arguments) + mask.process() + logger.debug("Finished process: (arguments: %s)", arguments) def process(self) -> None: """ The entry point for triggering the Extraction Process. Should only be called from :class:`lib.cli.launcher.ScriptExecutor` """ - extractor = self._get_extractor() for idx, location in enumerate(self._input_locations): if self._args.batch_mode: logger.info("Processing job %s of %s: %s", @@ -115,14 +116,12 @@ class Mask(): # pylint:disable=too-few-public-methods else: arguments = self._args - if extractor is not None: - extractor.launch() - - mask = _Mask(arguments, extractor) - mask.process() - - if extractor is not None: - extractor.reset_phase_index() + if len(self._input_locations) > 1: + proc = Process(target=self._run_mask_process, args=(arguments, )) + proc.start() + proc.join() + else: + self._run_mask_process(arguments) class _Mask(): # pylint:disable=too-few-public-methods @@ -136,12 +135,9 @@ class _Mask(): # pylint:disable=too-few-public-methods ---------- arguments: :class:`argparse.Namespace` The :mod:`argparse` arguments as passed in from :mod:`tools.py` - extractor: :class:`plugins.extract.pipeline.Extractor`: - The launched Extractor """ - def __init__(self, arguments: Namespace, extractor: Optional[Extractor]) -> None: - logger.debug("Initializing %s: (arguments: %s, extractor: %s)", - self.__class__.__name__, arguments, extractor) + def __init__(self, arguments: Namespace) -> None: + logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) self._update_type = arguments.processing self._input_is_faces = arguments.input_type == "faces" self._mask_type = arguments.masker @@ -159,7 +155,7 @@ class _Mask(): # pylint:disable=too-few-public-methods self._faces_saver: Optional[ImagesSaver] = None self._alignments = self._get_alignments(arguments) - self._extractor = extractor + self._extractor = self._get_extractor(arguments.exclude_gpus) self._set_correct_mask_type() self._extractor_input_thread = self._feed_extractor() @@ -246,6 +242,27 @@ class _Mask(): # pylint:disable=too-few-public-methods return Alignments(folder, filename=filename) + def _get_extractor(self, exclude_gpus: List[int]) -> Optional[Extractor]: + """ Obtain a Mask extractor plugin and launch it + Parameters + ---------- + exclude_gpus: list or ``None`` + A list of indices correlating to connected GPUs that Tensorflow should not use. Pass + ``None`` to not exclude any GPUs. + Returns + ------- + :class:`plugins.extract.pipeline.Extractor`: + The launched Extractor + """ + if self._update_type == "output": + logger.debug("Update type `output` selected. Not launching extractor") + return None + logger.debug("masker: %s", self._mask_type) + extractor = Extractor(None, None, self._mask_type, exclude_gpus=exclude_gpus) + extractor.launch() + logger.debug(extractor) + return extractor + def _set_correct_mask_type(self): """ Some masks have multiple variants that they can be saved as depending on config options so update the :attr:`_mask_type` accordingly