mirror of
https://github.com/zebrajr/faceswap.git
synced 2026-01-15 12:15:15 +00:00
scripts.train - type checking
This commit is contained in:
@@ -7,8 +7,10 @@ import sys
|
||||
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
from typing import cast, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lib.image import read_image_meta
|
||||
from lib.keypress import KBHit
|
||||
@@ -16,6 +18,16 @@ from lib.multithreading import MultiThread
|
||||
from lib.utils import (get_folder, get_image_paths, FaceswapError, _image_extensions)
|
||||
from plugins.plugin_loader import PluginLoader
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import argparse
|
||||
from plugins.train.model._base import ModelBase
|
||||
from plugins.train.trainer._base import TrainerBase
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -34,7 +46,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
The arguments to be passed to the training process as generated from Faceswap's command
|
||||
line arguments
|
||||
"""
|
||||
def __init__(self, arguments):
|
||||
def __init__(self, arguments: "argparse.Namespace") -> None:
|
||||
logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments)
|
||||
self._args = arguments
|
||||
if self._args.summary:
|
||||
@@ -47,16 +59,16 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache")
|
||||
self._gui_triggers = dict(update=os.path.join(gui_cache, ".preview_trigger"),
|
||||
mask_toggle=os.path.join(gui_cache, ".preview_mask_toggle"))
|
||||
self._stop = False
|
||||
self._save_now = False
|
||||
self._toggle_preview_mask = False
|
||||
self._refresh_preview = False
|
||||
self._preview_buffer = {}
|
||||
self._stop: bool = False
|
||||
self._save_now: bool = False
|
||||
self._toggle_preview_mask: bool = False
|
||||
self._refresh_preview: bool = False
|
||||
self._preview_buffer: Dict[str, np.ndarray] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def _get_images(self):
|
||||
def _get_images(self) -> Dict[Literal["a", "b"], List[str]]:
|
||||
""" Check the image folders exist and contains valid extracted faces. Obtain image paths.
|
||||
|
||||
Returns
|
||||
@@ -68,6 +80,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
logger.debug("Getting image paths")
|
||||
images = {}
|
||||
for side in ("a", "b"):
|
||||
side = cast(Literal["a", "b"], side)
|
||||
image_dir = getattr(self._args, f"input_{side}")
|
||||
if not os.path.isdir(image_dir):
|
||||
logger.error("Error: '%s' does not exist", image_dir)
|
||||
@@ -96,7 +109,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
return images
|
||||
|
||||
@classmethod
|
||||
def _validate_image_counts(cls, images):
|
||||
def _validate_image_counts(cls, images: Dict[Literal["a", "b"], List[str]]) -> None:
|
||||
""" Validate that there are sufficient images to commence training without raising an
|
||||
error.
|
||||
|
||||
@@ -124,7 +137,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
"Results are likely to be poor.")
|
||||
logger.warning(msg)
|
||||
|
||||
def _set_timelapse(self):
|
||||
def _set_timelapse(self) -> Dict[Literal["input_a", "input_b", "output"], str]:
|
||||
""" Set time-lapse paths if requested.
|
||||
|
||||
Returns
|
||||
@@ -136,7 +149,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
if (not self._args.timelapse_input_a and
|
||||
not self._args.timelapse_input_b and
|
||||
not self._args.timelapse_output):
|
||||
return None
|
||||
return {}
|
||||
if (not self._args.timelapse_input_a or
|
||||
not self._args.timelapse_input_b or
|
||||
not self._args.timelapse_output):
|
||||
@@ -147,6 +160,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
timelapse_output = get_folder(self._args.timelapse_output)
|
||||
|
||||
for side in ("a", "b"):
|
||||
side = cast(Literal["a", "b"], side)
|
||||
folder = getattr(self._args, f"timelapse_input_{side}")
|
||||
if folder is not None and not os.path.isdir(folder):
|
||||
raise FaceswapError(f"The Timelapse path '{folder}' does not exist")
|
||||
@@ -168,13 +182,14 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
raise FaceswapError(f"All images in the Timelapse folder '{folder}' must exist in "
|
||||
f"the training folder '{training_folder}'")
|
||||
|
||||
kwargs = {"input_a": self._args.timelapse_input_a,
|
||||
"input_b": self._args.timelapse_input_b,
|
||||
"output": timelapse_output}
|
||||
TKey = Literal["input_a", "input_b", "output"]
|
||||
kwargs = {cast(TKey, "input_a"): self._args.timelapse_input_a,
|
||||
cast(TKey, "input_b"): self._args.timelapse_input_b,
|
||||
cast(TKey, "output"): timelapse_output}
|
||||
logger.debug("Timelapse enabled: %s", kwargs)
|
||||
return kwargs
|
||||
|
||||
def process(self):
|
||||
def process(self) -> None:
|
||||
""" The entry point for triggering the Training Process.
|
||||
|
||||
Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
|
||||
@@ -190,7 +205,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
self._end_thread(thread, err)
|
||||
logger.debug("Completed Training Process")
|
||||
|
||||
def _start_thread(self):
|
||||
def _start_thread(self) -> MultiThread:
|
||||
""" Put the :func:`_training` into a background thread so we can keep control.
|
||||
|
||||
Returns
|
||||
@@ -204,7 +219,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
logger.debug("Launched Trainer thread")
|
||||
return thread
|
||||
|
||||
def _end_thread(self, thread, err):
|
||||
def _end_thread(self, thread: MultiThread, err: bool) -> None:
|
||||
""" Output message and join thread back to main on termination.
|
||||
|
||||
Parameters
|
||||
@@ -231,7 +246,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
sys.stdout.flush()
|
||||
logger.debug("Ended training thread")
|
||||
|
||||
def _training(self):
|
||||
def _training(self) -> None:
|
||||
""" The training process to be run inside a thread. """
|
||||
try:
|
||||
sleep(1) # Let preview instructions flush out to logger
|
||||
@@ -251,7 +266,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
except Exception as err:
|
||||
raise err
|
||||
|
||||
def _load_model(self):
|
||||
def _load_model(self) -> "ModelBase":
|
||||
""" Load the model requested for training.
|
||||
|
||||
Returns
|
||||
@@ -261,7 +276,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
"""
|
||||
logger.debug("Loading Model")
|
||||
model_dir = get_folder(self._args.model_dir)
|
||||
model = PluginLoader.get_model(self._args.trainer)(
|
||||
model: "ModelBase" = PluginLoader.get_model(self._args.trainer)(
|
||||
model_dir,
|
||||
self._args,
|
||||
predict=False)
|
||||
@@ -269,7 +284,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
logger.debug("Loaded Model")
|
||||
return model
|
||||
|
||||
def _load_trainer(self, model):
|
||||
def _load_trainer(self, model: "ModelBase") -> "TrainerBase":
|
||||
""" Load the trainer requested for training.
|
||||
|
||||
Parameters
|
||||
@@ -283,15 +298,15 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
The requested model trainer plugin
|
||||
"""
|
||||
logger.debug("Loading Trainer")
|
||||
trainer = PluginLoader.get_trainer(model.trainer)
|
||||
trainer = trainer(model,
|
||||
self._images,
|
||||
self._args.batch_size,
|
||||
self._args.configfile)
|
||||
base = PluginLoader.get_trainer(model.trainer)
|
||||
trainer: "TrainerBase" = base(model,
|
||||
self._images,
|
||||
self._args.batch_size,
|
||||
self._args.configfile)
|
||||
logger.debug("Loaded Trainer")
|
||||
return trainer
|
||||
|
||||
def _run_training_cycle(self, model, trainer):
|
||||
def _run_training_cycle(self, model: "ModelBase", trainer: "TrainerBase") -> None:
|
||||
""" Perform the training cycle.
|
||||
|
||||
Handles the background training, updating previews/time-lapse on each save interval,
|
||||
@@ -306,12 +321,12 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
"""
|
||||
logger.debug("Running Training Cycle")
|
||||
if self._args.write_image or self._args.redirect_gui or self._args.preview:
|
||||
display_func = self._show
|
||||
display_func: Optional[Callable] = self._show
|
||||
else:
|
||||
display_func = None
|
||||
|
||||
for iteration in range(1, self._args.iterations + 1):
|
||||
logger.trace("Training iteration: %s", iteration)
|
||||
logger.trace("Training iteration: %s", iteration) # type:ignore
|
||||
save_iteration = iteration % self._args.save_interval == 0 or iteration == 1
|
||||
|
||||
if self._toggle_preview_mask:
|
||||
@@ -323,7 +338,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
viewer = display_func
|
||||
else:
|
||||
viewer = None
|
||||
timelapse = self._timelapse if save_iteration else None
|
||||
timelapse = self._timelapse if save_iteration else {}
|
||||
trainer.train_one_step(viewer, timelapse)
|
||||
if self._stop:
|
||||
logger.debug("Stop received. Terminating")
|
||||
@@ -347,7 +362,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
trainer.clear_tensorboard()
|
||||
self._stop = True
|
||||
|
||||
def _output_startup_info(self):
|
||||
def _output_startup_info(self) -> None:
|
||||
""" Print the startup information to the console. """
|
||||
logger.debug("Launching Monitor")
|
||||
logger.info("===================================================")
|
||||
@@ -530,7 +545,7 @@ class Train(): # pylint:disable=too-few-public-methods
|
||||
logger.debug("Closed Monitor")
|
||||
return err
|
||||
|
||||
def _show(self, image, name=""):
|
||||
def _show(self, image: np.ndarray, name: str = "") -> None:
|
||||
""" Generate the preview and write preview file output.
|
||||
|
||||
Handles the output and display of preview images.
|
||||
|
||||
Reference in New Issue
Block a user