scripts.train - type checking

This commit is contained in:
torzdf
2022-06-30 18:41:58 +01:00
parent 91fecc47b2
commit b1a8183ab4
3 changed files with 54 additions and 38 deletions

View File

@@ -7,8 +7,10 @@ import sys
from threading import Lock
from time import sleep
from typing import cast, Callable, Dict, List, Optional, TYPE_CHECKING
import cv2
import numpy as np
from lib.image import read_image_meta
from lib.keypress import KBHit
@@ -16,6 +18,16 @@ from lib.multithreading import MultiThread
from lib.utils import (get_folder, get_image_paths, FaceswapError, _image_extensions)
from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
import argparse
from plugins.train.model._base import ModelBase
from plugins.train.trainer._base import TrainerBase
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -34,7 +46,7 @@ class Train(): # pylint:disable=too-few-public-methods
The arguments to be passed to the training process as generated from Faceswap's command
line arguments
"""
def __init__(self, arguments):
def __init__(self, arguments: "argparse.Namespace") -> None:
logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments)
self._args = arguments
if self._args.summary:
@@ -47,16 +59,16 @@ class Train(): # pylint:disable=too-few-public-methods
os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache")
self._gui_triggers = dict(update=os.path.join(gui_cache, ".preview_trigger"),
mask_toggle=os.path.join(gui_cache, ".preview_mask_toggle"))
self._stop = False
self._save_now = False
self._toggle_preview_mask = False
self._refresh_preview = False
self._preview_buffer = {}
self._stop: bool = False
self._save_now: bool = False
self._toggle_preview_mask: bool = False
self._refresh_preview: bool = False
self._preview_buffer: Dict[str, np.ndarray] = {}
self._lock = Lock()
logger.debug("Initialized %s", self.__class__.__name__)
def _get_images(self):
def _get_images(self) -> Dict[Literal["a", "b"], List[str]]:
""" Check the image folders exist and contains valid extracted faces. Obtain image paths.
Returns
@@ -68,6 +80,7 @@ class Train(): # pylint:disable=too-few-public-methods
logger.debug("Getting image paths")
images = {}
for side in ("a", "b"):
side = cast(Literal["a", "b"], side)
image_dir = getattr(self._args, f"input_{side}")
if not os.path.isdir(image_dir):
logger.error("Error: '%s' does not exist", image_dir)
@@ -96,7 +109,7 @@ class Train(): # pylint:disable=too-few-public-methods
return images
@classmethod
def _validate_image_counts(cls, images):
def _validate_image_counts(cls, images: Dict[Literal["a", "b"], List[str]]) -> None:
""" Validate that there are sufficient images to commence training without raising an
error.
@@ -124,7 +137,7 @@ class Train(): # pylint:disable=too-few-public-methods
"Results are likely to be poor.")
logger.warning(msg)
def _set_timelapse(self):
def _set_timelapse(self) -> Dict[Literal["input_a", "input_b", "output"], str]:
""" Set time-lapse paths if requested.
Returns
@@ -136,7 +149,7 @@ class Train(): # pylint:disable=too-few-public-methods
if (not self._args.timelapse_input_a and
not self._args.timelapse_input_b and
not self._args.timelapse_output):
return None
return {}
if (not self._args.timelapse_input_a or
not self._args.timelapse_input_b or
not self._args.timelapse_output):
@@ -147,6 +160,7 @@ class Train(): # pylint:disable=too-few-public-methods
timelapse_output = get_folder(self._args.timelapse_output)
for side in ("a", "b"):
side = cast(Literal["a", "b"], side)
folder = getattr(self._args, f"timelapse_input_{side}")
if folder is not None and not os.path.isdir(folder):
raise FaceswapError(f"The Timelapse path '{folder}' does not exist")
@@ -168,13 +182,14 @@ class Train(): # pylint:disable=too-few-public-methods
raise FaceswapError(f"All images in the Timelapse folder '{folder}' must exist in "
f"the training folder '{training_folder}'")
kwargs = {"input_a": self._args.timelapse_input_a,
"input_b": self._args.timelapse_input_b,
"output": timelapse_output}
TKey = Literal["input_a", "input_b", "output"]
kwargs = {cast(TKey, "input_a"): self._args.timelapse_input_a,
cast(TKey, "input_b"): self._args.timelapse_input_b,
cast(TKey, "output"): timelapse_output}
logger.debug("Timelapse enabled: %s", kwargs)
return kwargs
def process(self):
def process(self) -> None:
""" The entry point for triggering the Training Process.
Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
@@ -190,7 +205,7 @@ class Train(): # pylint:disable=too-few-public-methods
self._end_thread(thread, err)
logger.debug("Completed Training Process")
def _start_thread(self):
def _start_thread(self) -> MultiThread:
""" Put the :func:`_training` into a background thread so we can keep control.
Returns
@@ -204,7 +219,7 @@ class Train(): # pylint:disable=too-few-public-methods
logger.debug("Launched Trainer thread")
return thread
def _end_thread(self, thread, err):
def _end_thread(self, thread: MultiThread, err: bool) -> None:
""" Output message and join thread back to main on termination.
Parameters
@@ -231,7 +246,7 @@ class Train(): # pylint:disable=too-few-public-methods
sys.stdout.flush()
logger.debug("Ended training thread")
def _training(self):
def _training(self) -> None:
""" The training process to be run inside a thread. """
try:
sleep(1) # Let preview instructions flush out to logger
@@ -251,7 +266,7 @@ class Train(): # pylint:disable=too-few-public-methods
except Exception as err:
raise err
def _load_model(self):
def _load_model(self) -> "ModelBase":
""" Load the model requested for training.
Returns
@@ -261,7 +276,7 @@ class Train(): # pylint:disable=too-few-public-methods
"""
logger.debug("Loading Model")
model_dir = get_folder(self._args.model_dir)
model = PluginLoader.get_model(self._args.trainer)(
model: "ModelBase" = PluginLoader.get_model(self._args.trainer)(
model_dir,
self._args,
predict=False)
@@ -269,7 +284,7 @@ class Train(): # pylint:disable=too-few-public-methods
logger.debug("Loaded Model")
return model
def _load_trainer(self, model):
def _load_trainer(self, model: "ModelBase") -> "TrainerBase":
""" Load the trainer requested for training.
Parameters
@@ -283,15 +298,15 @@ class Train(): # pylint:disable=too-few-public-methods
The requested model trainer plugin
"""
logger.debug("Loading Trainer")
trainer = PluginLoader.get_trainer(model.trainer)
trainer = trainer(model,
self._images,
self._args.batch_size,
self._args.configfile)
base = PluginLoader.get_trainer(model.trainer)
trainer: "TrainerBase" = base(model,
self._images,
self._args.batch_size,
self._args.configfile)
logger.debug("Loaded Trainer")
return trainer
def _run_training_cycle(self, model, trainer):
def _run_training_cycle(self, model: "ModelBase", trainer: "TrainerBase") -> None:
""" Perform the training cycle.
Handles the background training, updating previews/time-lapse on each save interval,
@@ -306,12 +321,12 @@ class Train(): # pylint:disable=too-few-public-methods
"""
logger.debug("Running Training Cycle")
if self._args.write_image or self._args.redirect_gui or self._args.preview:
display_func = self._show
display_func: Optional[Callable] = self._show
else:
display_func = None
for iteration in range(1, self._args.iterations + 1):
logger.trace("Training iteration: %s", iteration)
logger.trace("Training iteration: %s", iteration) # type:ignore
save_iteration = iteration % self._args.save_interval == 0 or iteration == 1
if self._toggle_preview_mask:
@@ -323,7 +338,7 @@ class Train(): # pylint:disable=too-few-public-methods
viewer = display_func
else:
viewer = None
timelapse = self._timelapse if save_iteration else None
timelapse = self._timelapse if save_iteration else {}
trainer.train_one_step(viewer, timelapse)
if self._stop:
logger.debug("Stop received. Terminating")
@@ -347,7 +362,7 @@ class Train(): # pylint:disable=too-few-public-methods
trainer.clear_tensorboard()
self._stop = True
def _output_startup_info(self):
def _output_startup_info(self) -> None:
""" Print the startup information to the console. """
logger.debug("Launching Monitor")
logger.info("===================================================")
@@ -530,7 +545,7 @@ class Train(): # pylint:disable=too-few-public-methods
logger.debug("Closed Monitor")
return err
def _show(self, image, name=""):
def _show(self, image: np.ndarray, name: str = "") -> None:
""" Generate the preview and write preview file output.
Handles the output and display of preview images.