Bugfix: convert - Gif Writer

- Fix non-launch error on Gif Writer
  - convert plugins - linting
  - convert/fs_media/preview/queue_manager - typing
  - Change convert items from dict to Dataclass
This commit is contained in:
torzdf
2022-08-26 23:56:03 +01:00
parent 9bd86eb810
commit 1022651eb8
13 changed files with 888 additions and 590 deletions

View File

@@ -2,15 +2,56 @@
""" Converter for Faceswap """
import logging
import sys
from dataclasses import dataclass
from typing import Callable, cast, List, Optional, Tuple, TYPE_CHECKING, Union
import cv2
import numpy as np
from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
from argparse import Namespace
from lib.align.aligned_face import AlignedFace, CenteringType
from lib.align.detected_face import DetectedFace
from lib.config import FaceswapConfig
from lib.queue_manager import EventQueue
from scripts.convert import ConvertItem
from plugins.convert.color._base import Adjustment as ColorAdjust
from plugins.convert.color.seamless_clone import Color as SeamlessAdjust
from plugins.convert.mask.mask_blend import Mask as MaskAdjust
from plugins.convert.scaling._base import Adjustment as ScalingAdjust
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@dataclass
class Adjustments:
    """ Dataclass to hold the optional processing plugins

    Every field defaults to ``None``, meaning that no plugin of that kind has been loaded.

    Parameters
    ----------
    color: :class:`~plugins.convert.color._base.Adjustment`, Optional
        The selected color processing plugin. Default: `None`
    mask: :class:`~plugins.convert.mask.mask_blend.Mask`, Optional
        The selected mask processing plugin. Default: `None`
    seamless: :class:`~plugins.convert.color.seamless_clone.Color`, Optional
        The selected seamless clone processing plugin. Default: `None`
    sharpening: :class:`~plugins.convert.scaling._base.Adjustment`, Optional
        The selected sharpening plugin. Default: `None`
    """
    # Annotations stay as strings: the plugin types are only imported under TYPE_CHECKING
    color: Optional["ColorAdjust"] = None
    mask: Optional["MaskAdjust"] = None
    seamless: Optional["SeamlessAdjust"] = None
    sharpening: Optional["ScalingAdjust"] = None
class Converter():
""" The converter is responsible for swapping the original face(s) in a frame with the output
of a trained Faceswap model.
@@ -37,8 +78,14 @@ class Converter():
Optional location of custom configuration ``ini`` file. If ``None`` then use the default
config location. Default: ``None``
"""
def __init__(self, output_size, coverage_ratio, centering, draw_transparent, pre_encode,
arguments, configfile=None):
def __init__(self,
output_size: int,
coverage_ratio: float,
centering: "CenteringType",
draw_transparent: bool,
pre_encode: Optional[Callable[[np.ndarray], List[bytes]]],
arguments: "Namespace",
configfile: Optional[str] = None) -> None:
logger.debug("Initializing %s: (output_size: %s, coverage_ratio: %s, centering: %s, "
"draw_transparent: %s, pre_encode: %s, arguments: %s, configfile: %s)",
self.__class__.__name__, output_size, coverage_ratio, centering,
@@ -52,18 +99,18 @@ class Converter():
self._configfile = configfile
self._scale = arguments.output_scale / 100
self._adjustments = dict(mask=None, color=None, seamless=None, sharpening=None)
self._adjustments = Adjustments()
self._load_plugins()
logger.debug("Initialized %s", self.__class__.__name__)
@property
def cli_arguments(self):
def cli_arguments(self) -> "Namespace":
""":class:`argparse.Namespace`: The command line arguments passed to the convert
process """
return self._args
def reinitialize(self, config):
def reinitialize(self, config: "FaceswapConfig") -> None:
""" Reinitialize this :class:`Converter`.
Called as part of the :mod:`~tools.preview` tool. Resets all adjustments then loads the
@@ -75,11 +122,13 @@ class Converter():
Pre-loaded :class:`lib.config.FaceswapConfig`. used over any configuration on disk.
"""
logger.debug("Reinitializing converter")
self._adjustments = dict(mask=None, color=None, seamless=None, sharpening=None)
self._adjustments = Adjustments()
self._load_plugins(config=config, disable_logging=True)
logger.debug("Reinitialized converter")
def _load_plugins(self, config=None, disable_logging=False):
def _load_plugins(self,
config: Optional["FaceswapConfig"] = None,
disable_logging: bool = False) -> None:
""" Load the requested adjustment plugins.
Loads the :mod:`plugins.converter` plugins that have been requested for this conversion
@@ -95,30 +144,32 @@ class Converter():
suppress these messages otherwise ``False``. Default: ``False``
"""
logger.debug("Loading plugins. config: %s", config)
self._adjustments["mask"] = PluginLoader.get_converter(
"mask",
"mask_blend",
disable_logging=disable_logging)(self._args.mask_type,
self._output_size,
self._coverage_ratio,
configfile=self._configfile,
config=config)
self._adjustments.mask = PluginLoader.get_converter("mask",
"mask_blend",
disable_logging=disable_logging)(
self._args.mask_type,
self._output_size,
self._coverage_ratio,
configfile=self._configfile,
config=config)
if self._args.color_adjustment != "none" and self._args.color_adjustment is not None:
self._adjustments["color"] = PluginLoader.get_converter(
"color",
self._args.color_adjustment,
disable_logging=disable_logging)(configfile=self._configfile, config=config)
self._adjustments.color = PluginLoader.get_converter("color",
self._args.color_adjustment,
disable_logging=disable_logging)(
configfile=self._configfile,
config=config)
sharpening = PluginLoader.get_converter(
"scaling",
"sharpen",
disable_logging=disable_logging)(configfile=self._configfile, config=config)
if sharpening.config.get("method", None) is not None:
self._adjustments["sharpening"] = sharpening
sharpening = PluginLoader.get_converter("scaling",
"sharpen",
disable_logging=disable_logging)(
configfile=self._configfile,
config=config)
if sharpening.config.get("method") is not None:
self._adjustments.sharpening = sharpening
logger.debug("Loaded plugins: %s", self._adjustments)
def process(self, in_queue, out_queue):
def process(self, in_queue: "EventQueue", out_queue: "EventQueue"):
""" Main convert process.
Takes items from the in queue, runs the relevant adjustments, patches faces to final frame
@@ -126,10 +177,10 @@ class Converter():
Parameters
----------
in_queue: :class:`queue.Queue`
in_queue: :class:`~lib.queue_manager.EventQueue`
The output from :class:`scripts.convert.Predictor`. Contains detected faces from the
Faceswap model as well as the frame to be patched.
out_queue: :class:`queue.Queue`
out_queue: :class:`~lib.queue_manager.EventQueue`
The queue to place patched frames into for writing by one of Faceswap's
:mod:`plugins.convert.writer` plugins.
"""
@@ -137,45 +188,44 @@ class Converter():
in_queue, out_queue)
log_once = False
while True:
items = in_queue.get()
if items == "EOF":
inbound: Union[Literal["EOF"], "ConvertItem", List["ConvertItem"]] = in_queue.get()
if inbound == "EOF":
logger.debug("EOF Received")
logger.debug("Patch queue finished")
# Signal EOF to other processes in pool
logger.debug("Putting EOF back to in_queue")
in_queue.put(items)
in_queue.put(inbound)
break
if isinstance(items, dict):
items = [items]
items = inbound if isinstance(inbound, list) else [inbound]
for item in items:
logger.trace("Patch queue got: '%s'", item["filename"])
logger.trace("Patch queue got: '%s'", item.inbound.filename) # type: ignore
try:
image = self._patch_image(item)
except Exception as err: # pylint: disable=broad-except
# Log error and output original frame
logger.error("Failed to convert image: '%s'. Reason: %s",
item["filename"], str(err))
image = item["image"]
item.inbound.filename, str(err))
image = item.inbound.image
loglevel = logger.trace if log_once else logger.warning
loglevel = logger.trace if log_once else logger.warning # type: ignore
loglevel("Convert error traceback:", exc_info=True)
log_once = True
# UNCOMMENT THIS CODE BLOCK TO PRINT TRACEBACK ERRORS
# import sys; import traceback
# exc_info = sys.exc_info(); traceback.print_exception(*exc_info)
logger.trace("Out queue put: %s", item["filename"])
out_queue.put((item["filename"], image))
logger.trace("Out queue put: %s", item.inbound.filename) # type: ignore
out_queue.put((item.inbound.filename, image))
logger.debug("Completed convert process")
def _patch_image(self, predicted):
def _patch_image(self, predicted: "ConvertItem") -> Union[np.ndarray, List[bytes]]:
""" Patch a swapped face onto a frame.
Run selected adjustments and swap the faces in a frame.
Parameters
----------
predicted: dict
predicted: :class:`~scripts.convert.ConvertItem`
The output from :class:`scripts.convert.Predictor`.
Returns
@@ -186,8 +236,8 @@ class Converter():
function (if it has one)
"""
logger.trace("Patching image: '%s'", predicted["filename"])
frame_size = (predicted["image"].shape[1], predicted["image"].shape[0])
logger.trace("Patching image: '%s'", predicted.inbound.filename) # type: ignore
frame_size = (predicted.inbound.image.shape[1], predicted.inbound.image.shape[0])
new_image, background = self._get_new_image(predicted, frame_size)
patched_face = self._post_warp_adjustments(background, new_image)
patched_face = self._scale_image(patched_face)
@@ -195,12 +245,16 @@ class Converter():
patched_face = np.rint(patched_face,
out=np.empty(patched_face.shape, dtype="uint8"),
casting='unsafe')
if self._writer_pre_encode is not None:
patched_face = self._writer_pre_encode(patched_face)
logger.trace("Patched image: '%s'", predicted["filename"])
return patched_face
if self._writer_pre_encode is None:
retval: Union[np.ndarray, List[bytes]] = patched_face
else:
retval = self._writer_pre_encode(patched_face)
logger.trace("Patched image: '%s'", predicted.inbound.filename) # type: ignore
return retval
def _get_new_image(self, predicted, frame_size):
def _get_new_image(self,
predicted: "ConvertItem",
frame_size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
""" Get the new face from the predictor and apply pre-warp manipulations.
Applies any requested adjustments to the raw output of the Faceswap model
@@ -208,7 +262,7 @@ class Converter():
Parameters
----------
predicted: dict
predicted: :class:`~scripts.convert.ConvertItem`
The output from :class:`scripts.convert.Predictor`.
frame_size: tuple
The (`width`, `height`) of the final frame in pixels
@@ -220,16 +274,16 @@ class Converter():
background: :class: `numpy.ndarray`
The original frame
"""
logger.trace("Getting: (filename: '%s', faces: %s)",
predicted["filename"], len(predicted["swapped_faces"]))
logger.trace("Getting: (filename: '%s', faces: %s)", # type: ignore
predicted.inbound.filename, len(predicted.swapped_faces))
placeholder = np.zeros((frame_size[1], frame_size[0], 4), dtype="float32")
background = predicted["image"] / np.array(255.0, dtype="float32")
background = predicted.inbound.image / np.array(255.0, dtype="float32")
placeholder[:, :, :3] = background
for new_face, detected_face, reference_face in zip(predicted["swapped_faces"],
predicted["detected_faces"],
predicted["reference_faces"]):
for new_face, detected_face, reference_face in zip(predicted.swapped_faces,
predicted.inbound.detected_faces,
predicted.reference_faces):
predicted_mask = new_face[:, :, -1] if new_face.shape[2] == 4 else None
new_face = new_face[:, :, :3]
interpolator = reference_face.interpolators[1]
@@ -247,12 +301,16 @@ class Converter():
flags=cv2.WARP_INVERSE_MAP | interpolator,
borderMode=cv2.BORDER_TRANSPARENT)
logger.trace("Got filename: '%s'. (placeholders: %s)",
predicted["filename"], placeholder.shape)
logger.trace("Got filename: '%s'. (placeholders: %s)", # type: ignore
predicted.inbound.filename, placeholder.shape)
return placeholder, background
def _pre_warp_adjustments(self, new_face, detected_face, reference_face, predicted_mask):
def _pre_warp_adjustments(self,
new_face: np.ndarray,
detected_face: "DetectedFace",
reference_face: "AlignedFace",
predicted_mask: Optional[np.ndarray]) -> np.ndarray:
""" Run any requested adjustments that can be performed on the raw output from the Faceswap
model.
@@ -277,21 +335,25 @@ class Converter():
The face output from the Faceswap Model with any requested pre-warp adjustments
performed.
"""
logger.trace("new_face shape: %s, predicted_mask shape: %s", new_face.shape,
predicted_mask.shape if predicted_mask is not None else None)
old_face = reference_face.face[..., :3] / 255.0
logger.trace("new_face shape: %s, predicted_mask shape: %s", # type: ignore
new_face.shape, predicted_mask.shape if predicted_mask is not None else None)
old_face = cast(np.ndarray, reference_face.face)[..., :3] / 255.0
new_face, raw_mask = self._get_image_mask(new_face,
detected_face,
predicted_mask,
reference_face)
if self._adjustments["color"] is not None:
new_face = self._adjustments["color"].run(old_face, new_face, raw_mask)
if self._adjustments["seamless"] is not None:
new_face = self._adjustments["seamless"].run(old_face, new_face, raw_mask)
logger.trace("returning: new_face shape %s", new_face.shape)
if self._adjustments.color is not None:
new_face = self._adjustments.color.run(old_face, new_face, raw_mask)
if self._adjustments.seamless is not None:
new_face = self._adjustments.seamless.run(old_face, new_face, raw_mask)
logger.trace("returning: new_face shape %s", new_face.shape) # type: ignore
return new_face
def _get_image_mask(self, new_face, detected_face, predicted_mask, reference_face):
def _get_image_mask(self,
new_face: np.ndarray,
detected_face: "DetectedFace",
predicted_mask: Optional[np.ndarray],
reference_face: "AlignedFace") -> Tuple[np.ndarray, np.ndarray]:
""" Return any selected image mask
Places the requested mask into the new face's Alpha channel.
@@ -312,23 +374,26 @@ class Converter():
-------
:class:`numpy.ndarray`
The swapped face with the requested mask added to the Alpha channel
:class:`numpy.ndarray`
The raw mask with no erosion or blurring applied
"""
logger.trace("Getting mask. Image shape: %s", new_face.shape)
logger.trace("Getting mask. Image shape: %s", new_face.shape) # type: ignore
if self._args.mask_type not in ("none", "predicted"):
mask_centering = detected_face.mask[self._args.mask_type].stored_centering
else:
mask_centering = "face" # Unused but requires a valid value
mask, raw_mask = self._adjustments["mask"].run(detected_face,
reference_face.pose.offset[mask_centering],
reference_face.pose.offset[self._centering],
self._centering,
predicted_mask=predicted_mask)
logger.trace("Adding mask to alpha channel")
assert self._adjustments.mask is not None
mask, raw_mask = self._adjustments.mask.run(detected_face,
reference_face.pose.offset[mask_centering],
reference_face.pose.offset[self._centering],
self._centering,
predicted_mask=predicted_mask)
logger.trace("Adding mask to alpha channel") # type: ignore
new_face = np.concatenate((new_face, mask), -1)
logger.trace("Got mask. Image shape: %s", new_face.shape)
logger.trace("Got mask. Image shape: %s", new_face.shape) # type: ignore
return new_face, raw_mask
def _post_warp_adjustments(self, background, new_image):
def _post_warp_adjustments(self, background: np.ndarray, new_image: np.ndarray) -> np.ndarray:
""" Perform any requested adjustments to the swapped faces after they have been transformed
into the final frame.
@@ -344,8 +409,8 @@ class Converter():
:class:`numpy.ndarray`
The final merged and swapped frame with any requested post-warp adjustments applied
"""
if self._adjustments["sharpening"] is not None:
new_image = self._adjustments["sharpening"].run(new_image)
if self._adjustments.sharpening is not None:
new_image = self._adjustments.sharpening.run(new_image)
if self._draw_transparent:
frame = new_image
@@ -360,7 +425,7 @@ class Converter():
np.clip(frame, 0.0, 1.0, out=frame)
return frame
def _scale_image(self, frame):
def _scale_image(self, frame: np.ndarray) -> np.ndarray:
""" Scale the final image if requested.
If output scale has been requested in command line arguments, scale the output
@@ -378,11 +443,11 @@ class Converter():
"""
if self._scale == 1:
return frame
logger.trace("source frame: %s", frame.shape)
logger.trace("source frame: %s", frame.shape) # type: ignore
interp = cv2.INTER_CUBIC if self._scale > 1 else cv2.INTER_AREA
dims = (round((frame.shape[1] / 2 * self._scale) * 2),
round((frame.shape[0] / 2 * self._scale) * 2))
frame = cv2.resize(frame, dims, interpolation=interp)
logger.trace("resized frame: %s", frame.shape)
logger.trace("resized frame: %s", frame.shape) # type: ignore
np.clip(frame, 0.0, 1.0, out=frame)
return frame

View File

@@ -6,6 +6,7 @@
import logging
import threading
from typing import Dict
from queue import Queue, Empty as QueueEmpty # pylint: disable=unused-import; # noqa
from time import sleep
@@ -13,22 +14,42 @@ from time import sleep
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class QueueManager():
""" Manage queues for availabilty across processes
Don't import this class directly, instead
import the variable: queue_manager """
def __init__(self):
class EventQueue(Queue):
    """ Standard Queue object with a separate global shutdown parameter indicating that the main
    process, and by extension this queue, should be shut down.

    The event is only stored and exposed through the read-only :attr:`shutdown` property; this
    class never sets or clears it itself.

    Parameters
    ----------
    shutdown_event: :class:`threading.Event`
        The global shutdown event common to all managed queues
    maxsize: int, Optional
        Upperbound limit on the number of items that can be placed in the queue. Passed through
        unchanged to :class:`queue.Queue`. Default: `0`
    """
    def __init__(self, shutdown_event: threading.Event, maxsize: int = 0) -> None:
        super().__init__(maxsize=maxsize)
        # Hold a reference to the caller's event rather than creating a new one
        self._shutdown = shutdown_event

    @property
    def shutdown(self) -> threading.Event:
        """ :class:`threading.Event`: The global shutdown event """
        return self._shutdown
class _QueueManager():
""" Manage :class:`EventQueue` objects for availability across processes.
Notes
-----
Don't import this class directly, instead import via :func:`queue_manager` """
def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self.shutdown = threading.Event()
self.queues = dict()
self.queues: Dict[str, EventQueue] = {}
logger.debug("Initialized %s", self.__class__.__name__)
def add_queue(self, name, maxsize=0, create_new=False):
""" Add a queue to the manager.
Adds an event "shutdown" to the queue that can be used to indicate to a process that any
activity on the queue should cease.
def add_queue(self, name: str, maxsize: int = 0, create_new: bool = False) -> str:
""" Add a :class:`EventQueue` to the manager.
Parameters
----------
@@ -50,77 +71,108 @@ class QueueManager():
logger.debug("QueueManager adding: (name: '%s', maxsize: %s, create_new: %s)",
name, maxsize, create_new)
if not create_new and name in self.queues:
raise ValueError("Queue '{}' already exists.".format(name))
raise ValueError(f"Queue '{name}' already exists.")
if create_new and name in self.queues:
i = 0
while name in self.queues:
name = f"{name}{i}"
logger.debug("Duplicate queue name. Updated to: '%s'", name)
queue = Queue(maxsize=maxsize)
setattr(queue, "shutdown", self.shutdown)
self.queues[name] = queue
self.queues[name] = EventQueue(self.shutdown, maxsize=maxsize)
logger.debug("QueueManager added: (name: '%s')", name)
return name
def del_queue(self, name):
""" remove a queue from the manager """
def del_queue(self, name: str) -> None:
""" Remove a queue from the manager
Parameters
----------
name: str
The name of the queue to be deleted. Must exist within the queue manager.
"""
logger.debug("QueueManager deleting: '%s'", name)
del self.queues[name]
logger.debug("QueueManager deleted: '%s'", name)
def get_queue(self, name, maxsize=0):
""" Return a queue from the manager
If it doesn't exist, create it """
def get_queue(self, name: str, maxsize: int = 0) -> EventQueue:
""" Return a :class:`EventQueue` from the manager. If it doesn't exist, create it.
Parameters
----------
name: str
The name of the queue to obtain
maxsize: int, Optional
The maximum queue size. Set to `0` for unlimited. Only used if the requested queue
does not already exist. Default: `0`
"""
logger.debug("QueueManager getting: '%s'", name)
queue = self.queues.get(name, None)
queue = self.queues.get(name)
if not queue:
self.add_queue(name, maxsize)
queue = self.queues[name]
logger.debug("QueueManager got: '%s'", name)
return queue
def terminate_queues(self):
""" Set shutdown event, clear and send EOF to all queues
To be called if there is an error """
def terminate_queues(self) -> None:
""" Terminates all managed queues.
Sets the global shutdown event, clears and send EOF to all queues. To be called if there
is an error """
logger.debug("QueueManager terminating all queues")
self.shutdown.set()
self.flush_queues()
self._flush_queues()
for q_name, queue in self.queues.items():
logger.debug("QueueManager terminating: '%s'", q_name)
queue.put("EOF")
logger.debug("QueueManager terminated all queues")
def flush_queues(self):
""" Empty out all queues """
def _flush_queues(self):
""" Empty out the contents of every managed queue. """
for q_name in self.queues:
self.flush_queue(q_name)
logger.debug("QueueManager flushed all queues")
def flush_queue(self, name: str) -> None:
    """ Flush the contents from a managed queue.

    Items are discarded using non-blocking reads until the queue reports that it is empty.
    Checking ``empty()`` and then calling a blocking ``get`` is not safe when another thread
    may be consuming from the same queue: the last item can disappear between the two calls,
    leaving the ``get`` to time out and raise.

    Parameters
    ----------
    name: str
        The name of the managed :class:`EventQueue` to flush

    Raises
    ------
    KeyError
        If no queue called `name` is held by the manager
    """
    logger.debug("QueueManager flushing: '%s'", name)
    queue = self.queues[name]
    while True:
        try:
            queue.get_nowait()
        except QueueEmpty:
            # Queue is drained (possibly by a concurrent consumer), so we are done
            break
def debug_monitor(self, update_secs=2):
""" Debug tool for monitoring queues """
thread = threading.Thread(target=self.debug_queue_sizes,
args=(update_secs, ))
def debug_monitor(self, update_interval: int = 2) -> None:
""" A debug tool for monitoring managed :class:`EventQueues`.
Prints queue sizes to the console for all managed queues.
Parameters
----------
update_interval: int, Optional
The number of seconds between printing information to the console. Default: 2
"""
thread = threading.Thread(target=self._debug_queue_sizes,
args=(update_interval, ))
thread.daemon = True
thread.start()
def debug_queue_sizes(self, update_secs):
""" Output the queue sizes
logged to INFO so it also displays in console
def _debug_queue_sizes(self, update_interval) -> None:
""" Print the queue size for each managed queue to console.
Parameters
----------
update_interval: int
The number of seconds between printing information to the console
"""
while True:
logger.info("====================================================")
for name in sorted(self.queues.keys()):
logger.info("%s: %s", name, self.queues[name].qsize())
sleep(update_secs)
sleep(update_interval)
queue_manager = QueueManager() # pylint: disable=invalid-name
queue_manager = _QueueManager() # pylint: disable=invalid-name

View File

@@ -8,8 +8,7 @@ from ._base import Adjustment
class Color(Adjustment):
""" Adjust the mean of the color channels to be the same for the swap and old frame """
@staticmethod
def process(old_face, new_face, raw_mask):
def process(self, old_face, new_face, raw_mask):
for _ in [0, 1]:
diff = old_face - new_face
avg_diff = np.sum(diff * raw_mask, axis=(0, 1))

View File

@@ -43,7 +43,7 @@ class Color(Adjustment):
""" Convert colorspace based on mode or back to bgr """
mode = self.config["colorspace"].lower()
colorspace = "YCrCb" if mode == "ycrcb" else mode.upper()
conversion = "{}2BGR".format(colorspace) if to_bgr else "BGR2{}".format(colorspace)
conversion = f"{colorspace}2BGR" if to_bgr else f"BGR2{colorspace}"
image = cv2.cvtColor(new_face.astype("uint8"), # pylint: disable=no-member
getattr(cv2, "COLOR_{}".format(conversion))).astype("float32") / 255.0
getattr(cv2, f"COLOR_{conversion}")).astype("float32") / 255.0
return image

View File

@@ -16,8 +16,7 @@ class Color(Adjustment):
and does not have a natural home, so here for now.
"""
@staticmethod
def process(old_face, new_face, raw_mask):
def process(self, old_face, new_face, raw_mask):
height, width, _ = old_face.shape
height = height // 2
width = width // 2

View File

@@ -137,7 +137,7 @@ class Output():
"""
raise NotImplementedError
def pre_encode(self, image: np.ndarray) -> Any: # pylint: disable=unused-argument,no-self-use
def pre_encode(self, image: np.ndarray) -> Any: # pylint: disable=unused-argument
""" Some writer plugins support the pre-encoding of images prior to saving out. As
patching is done in multiple threads, but writing is done in a single thread, it can
speed up the process to do any pre-encoding as part of the converter process.

View File

@@ -3,7 +3,7 @@
import os
from math import ceil
from subprocess import CalledProcessError, check_output, STDOUT
from typing import Optional, List, Tuple, Generator
from typing import cast, Generator, List, Optional, Tuple
import imageio
import imageio_ffmpeg as im_ffm
@@ -32,7 +32,7 @@ class Writer(Output):
def __init__(self,
output_folder: str,
total_count: int,
frame_ranges: Optional[List[Tuple[int]]],
frame_ranges: Optional[List[Tuple[int, int]]],
source_video: str,
**kwargs) -> None:
super().__init__(output_folder, **kwargs)
@@ -40,7 +40,7 @@ class Writer(Output):
total_count, frame_ranges, source_video)
self._source_video: str = source_video
self._output_filename: str = self._get_output_filename()
self._frame_ranges: Optional[List[Tuple[int]]] = frame_ranges
self._frame_ranges: Optional[List[Tuple[int, int]]] = frame_ranges
self.frame_order: List[int] = self._set_frame_order(total_count)
self._output_dimensions: Optional[str] = None # Fix dims on 1st received frame
# Need to know dimensions of first frame, so set writer then
@@ -90,7 +90,7 @@ class Writer(Output):
""" str or ``None``: The audio codec to use. This will either be ``"copy"`` (the default) or
``None`` if skip muxing has been selected in configuration options, or if frame ranges have
been passed in the command line arguments. """
retval = "copy"
retval: Optional[str] = "copy"
if self.config["skip_mux"]:
logger.info("Skipping audio muxing due to configuration settings.")
retval = None
@@ -128,9 +128,9 @@ class Writer(Output):
try:
out = check_output(cmd, stderr=STDOUT)
except CalledProcessError as err:
out = err.output.decode(errors="ignore")
raise ValueError("Error checking audio stream. Status: "
f"{err.returncode}\n{out}") from err
err_out = err.output.decode(errors="ignore")
msg = f"Error checking audio stream. Status: {err.returncode}\n{err_out}"
raise ValueError(msg) from err
retval = False
for line in out.splitlines():
@@ -191,7 +191,7 @@ class Writer(Output):
logger.debug("frame_order: %s", retval)
return retval
def _get_writer(self, frame_dims: Tuple[int]) -> Generator[None, np.ndarray, None]:
def _get_writer(self, frame_dims: Tuple[int, int]) -> Generator[None, np.ndarray, None]:
""" Add the requested encoding options and return the writer.
Parameters
@@ -235,15 +235,16 @@ class Writer(Output):
image: :class:`numpy.ndarray`
The converted image to be written
"""
logger.trace("Received frame: (filename: '%s', shape: %s", filename, image.shape)
logger.trace("Received frame: (filename: '%s', shape: %s", # type: ignore
filename, image.shape)
if not self._output_dimensions:
input_dims = image.shape[:2]
input_dims = cast(Tuple[int, int], image.shape[:2])
self._set_dimensions(input_dims)
self._writer = self._get_writer(input_dims)
self.cache_frame(filename, image)
self._save_from_cache()
def _set_dimensions(self, frame_dims: Tuple[int]) -> None:
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None:
""" Set the attribute :attr:`_output_dimensions` based on the first frame received.
This protects against different sized images coming in and ensures all images are written
to ffmpeg at the same size. Dimensions are mapped to a macro block size 8.
@@ -261,16 +262,18 @@ class Writer(Output):
def _save_from_cache(self) -> None:
""" Writes any consecutive frames to the video container that are ready to be output
from the cache. """
assert self._writer is not None
while self.frame_order:
if self.frame_order[0] not in self.cache:
logger.trace("Next frame not ready. Continuing")
logger.trace("Next frame not ready. Continuing") # type: ignore
break
save_no = self.frame_order.pop(0)
save_image = self.cache.pop(save_no)
logger.trace("Rendering from cache. Frame no: %s", save_no)
logger.trace("Rendering from cache. Frame no: %s", save_no) # type: ignore
self._writer.send(np.ascontiguousarray(save_image[:, :, ::-1]))
logger.trace("Current cache size: %s", len(self.cache))
logger.trace("Current cache size: %s", len(self.cache)) # type: ignore
def close(self) -> None:
""" Close the ffmpeg writer and mux the audio """
self._writer.close()
if self._writer is not None:
self._writer.close()

View File

@@ -1,13 +1,16 @@
#!/usr/bin/env python3
""" Animated GIF writer for faceswap.py converter """
import os
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple, TYPE_CHECKING
import cv2
import imageio
from ._base import Output, logger
if TYPE_CHECKING:
from imageio.plugins.pillowmulti import GIFFormat
class Writer(Output):
""" GIF output writer using imageio.
@@ -28,12 +31,12 @@ class Writer(Output):
def __init__(self,
output_folder: str,
total_count: int,
frame_ranges: Optional[List[Tuple[int]]],
frame_ranges: Optional[List[Tuple[int, int]]],
**kwargs) -> None:
logger.debug("total_count: %s, frame_ranges: %s", total_count, frame_ranges)
super().__init__(output_folder, **kwargs)
self.frame_order: List[int] = self._set_frame_order(total_count, frame_ranges)
self._output_dimensions: Optional[str] = None # Fix dims on 1st received frame
self._output_dimensions: Optional[Tuple[int, int]] = None # Fix dims on 1st received frame
# Need to know dimensions of first frame, so set writer then
self._writer: Optional[imageio.plugins.pillowmulti.GIFFormat.Writer] = None
self._gif_file: Optional[str] = None # Set filename based on first file seen
@@ -46,7 +49,8 @@ class Writer(Output):
return kwargs
@staticmethod
def _set_frame_order(total_count: int, frame_ranges: Optional[List[Tuple[int]]]) -> List[int]:
def _set_frame_order(total_count: int,
frame_ranges: Optional[List[Tuple[int, int]]]) -> List[int]:
""" Obtain the full list of frames to be converted in order.
Parameters
@@ -71,7 +75,7 @@ class Writer(Output):
logger.debug("frame_order: %s", retval)
return retval
def _get_writer(self) -> imageio.plugins.pillowmulti.GIFFormat.Writer:
def _get_writer(self) -> "GIFFormat.Writer":
""" Obtain the GIF writer with the requested GIF encoding options.
Returns
@@ -80,7 +84,6 @@ class Writer(Output):
The imageio GIF writer
"""
logger.debug("writer config: %s", self.config)
return imageio.get_writer(self._gif_file,
mode="i",
**self._gif_params)
@@ -96,7 +99,8 @@ class Writer(Output):
image: :class:`numpy.ndarray`
The converted image to be written
"""
logger.trace("Received frame: (filename: '%s', shape: %s", filename, image.shape)
logger.trace("Received frame: (filename: '%s', shape: %s", # type: ignore
filename, image.shape)
if not self._gif_file:
self._set_gif_filename(filename)
self._set_dimensions(image.shape[:2])
@@ -140,7 +144,7 @@ class Writer(Output):
self._gif_file = retval
logger.info("Outputting to: '%s'", self._gif_file)
def _set_dimensions(self, frame_dims: str) -> None:
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None:
""" Set the attribute :attr:`_output_dimensions` based on the first frame received. This
protects against different sized images coming in and ensures all images get written to the
Gif at the same dimensions. """
@@ -151,16 +155,18 @@ class Writer(Output):
def _save_from_cache(self) -> None:
""" Writes any consecutive frames to the GIF container that are ready to be output
from the cache. """
assert self._writer is not None
while self.frame_order:
if self.frame_order[0] not in self.cache:
logger.trace("Next frame not ready. Continuing")
logger.trace("Next frame not ready. Continuing") # type: ignore
break
save_no = self.frame_order.pop(0)
save_image = self.cache.pop(save_no)
logger.trace("Rendering from cache. Frame no: %s", save_no)
logger.trace("Rendering from cache. Frame no: %s", save_no) # type: ignore
self._writer.append_data(save_image[:, :, ::-1])
logger.trace("Current cache size: %s", len(self.cache))
logger.trace("Current cache size: %s", len(self.cache)) # type: ignore
def close(self) -> None:
""" Close the GIF writer on completion. """
self._writer.close()
if self._writer is not None:
self._writer.close()

View File

@@ -48,10 +48,10 @@ class Writer(Output):
filetype = self.config["format"]
args: Tuple[int, ...] = tuple()
if filetype == "jpg" and self.config["jpg_quality"] > 0:
args = (cv2.IMWRITE_JPEG_QUALITY, # pylint: disable=no-member
args = (cv2.IMWRITE_JPEG_QUALITY,
self.config["jpg_quality"])
if filetype == "png" and self.config["png_compress_level"] > -1:
args = (cv2.IMWRITE_PNG_COMPRESSION, # pylint: disable=no-member
args = (cv2.IMWRITE_PNG_COMPRESSION,
self.config["png_compress_level"])
logger.debug(args)
return args
@@ -99,11 +99,11 @@ class Writer(Output):
mask = image[..., -1]
image = image[..., :3]
retval.append(cv2.imencode(self._extension, # pylint: disable=no-member
retval.append(cv2.imencode(self._extension,
mask,
self._args)[1])
retval.insert(0, cv2.imencode(self._extension, # pylint: disable=no-member
retval.insert(0, cv2.imencode(self._extension,
image,
self._args)[1])
return retval

View File

@@ -11,6 +11,8 @@ plugins either in parallel or in series, giving easy access to input and output.
"""
import logging
import sys
from typing import cast, List, Optional, Tuple, TYPE_CHECKING
import cv2
@@ -19,6 +21,15 @@ from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import get_backend
from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
import numpy as np
from lib.align.detected_face import DetectedFace
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
_INSTANCES = -1 # Tracking for multiple instances of pipeline
@@ -577,15 +588,15 @@ class Extractor():
if get_backend() != "nvidia":
logger.debug("Backend is not Nvidia. Not updating batchsize requirements")
return
if sum([plugin.vram for plugin in self._active_plugins]) == 0:
if sum(plugin.vram for plugin in self._active_plugins) == 0:
logger.debug("No plugins use VRAM. Not updating batchsize requirements.")
return
batch_required = sum([plugin.vram_per_batch * plugin.batchsize
for plugin in self._active_plugins])
batch_required = sum(plugin.vram_per_batch * plugin.batchsize
for plugin in self._active_plugins)
gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0]
scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback)
plugins_required = sum([self._vram_per_phase[p] for p in gpu_plugins]) * scaling
plugins_required = sum(self._vram_per_phase[p] for p in gpu_plugins) * scaling
if plugins_required + batch_required <= self._vram_stats["vram_free"]:
logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, "
"vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"])
@@ -674,44 +685,50 @@ class ExtractMedia():
The original frame
detected_faces: list, optional
A list of :class:`~lib.align.DetectedFace` objects. Detected faces can be added
later with :func:`add_detected_faces`. Default: ``None``
later with :func:`add_detected_faces`. Setting ``None`` will default to an empty list.
Default: ``None``
"""
def __init__(self, filename, image, detected_faces=None):
logger.trace("Initializing %s: (filename: '%s', image shape: %s, detected_faces: %s)",
self.__class__.__name__, filename, image.shape, detected_faces)
def __init__(self,
filename: str,
image: "np.ndarray",
detected_faces: Optional[List["DetectedFace"]] = None) -> None:
logger.trace("Initializing %s: (filename: '%s', image shape: %s, " # type: ignore
"detected_faces: %s)", self.__class__.__name__, filename, image.shape,
detected_faces)
self._filename = filename
self._image = image
self._image_shape = image.shape
self._detected_faces = detected_faces
self._image: Optional["np.ndarray"] = image
self._image_shape = cast(Tuple[int, int, int], image.shape)
self._detected_faces: List["DetectedFace"] = ([] if detected_faces is None
else detected_faces)
@property
def filename(self):
def filename(self) -> str:
""" str: The base name of the :attr:`image` filename. """
return self._filename
@property
def image(self):
def image(self) -> "np.ndarray":
""" :class:`numpy.ndarray`: The source frame for this object. """
assert self._image is not None
return self._image
@property
def image_shape(self):
def image_shape(self) -> Tuple[int, int, int]:
""" tuple: The shape of the stored :attr:`image`. """
return self._image_shape
@property
def image_size(self):
def image_size(self) -> Tuple[int, int]:
""" tuple: The (`height`, `width`) of the stored :attr:`image`. """
return self._image_shape[:2]
@property
def detected_faces(self):
"""list: A list of :class:`~lib.align.DetectedFace` objects in the
:attr:`image`. """
def detected_faces(self) -> List["DetectedFace"]:
"""list: A list of :class:`~lib.align.DetectedFace` objects in the :attr:`image`. """
return self._detected_faces
def get_image_copy(self, color_format):
def get_image_copy(self, color_format: Literal["BGR", "RGB", "GRAY"]) -> "np.ndarray":
""" Get a copy of the image in the requested color format.
Parameters
@@ -724,11 +741,12 @@ class ExtractMedia():
:class:`numpy.ndarray`:
A copy of :attr:`image` in the requested :attr:`color_format`
"""
logger.trace("Requested color format '%s' for frame '%s'", color_format, self._filename)
logger.trace("Requested color format '%s' for frame '%s'", # type: ignore
color_format, self._filename)
image = getattr(self, f"_image_as_{color_format.lower()}")()
return image
def add_detected_faces(self, faces):
def add_detected_faces(self, faces: List["DetectedFace"]) -> None:
""" Add detected faces to the object. Called at the end of each extraction phase.
Parameters
@@ -736,21 +754,21 @@ class ExtractMedia():
faces: list
A list of :class:`~lib.align.DetectedFace` objects
"""
logger.trace("Adding detected faces for filename: '%s'. (faces: %s, lrtb: %s)",
self._filename, faces,
logger.trace("Adding detected faces for filename: '%s'. " # type: ignore
"(faces: %s, lrtb: %s)", self._filename, faces,
[(face.left, face.right, face.top, face.bottom) for face in faces])
self._detected_faces = faces
def remove_image(self):
def remove_image(self) -> None:
""" Delete the image and reset :attr:`image` to ``None``.
Required for multi-phase extraction to avoid the frames stacking RAM.
"""
logger.trace("Removing image for filename: '%s'", self._filename)
logger.trace("Removing image for filename: '%s'", self._filename) # type: ignore
del self._image
self._image = None
def set_image(self, image):
def set_image(self, image: "np.ndarray") -> None:
""" Add the image back into :attr:`image`
Required for multi-phase extraction adds the image back to this object.
@@ -760,33 +778,33 @@ class ExtractMedia():
image: :class:`numpy.ndarray`
The original frame to be re-applied for this :attr:`filename`
"""
logger.trace("Reapplying image: (filename: `%s`, image shape: %s)",
logger.trace("Reapplying image: (filename: `%s`, image shape: %s)", # type: ignore
self._filename, image.shape)
self._image = image
def _image_as_bgr(self):
def _image_as_bgr(self) -> "np.ndarray":
""" Get a copy of the source frame in BGR format.
Returns
-------
:class:`numpy.ndarray`:
A copy of :attr:`image` in BGR color format """
return self._image[..., :3].copy()
return self.image[..., :3].copy()
def _image_as_rgb(self):
def _image_as_rgb(self) -> "np.ndarray":
""" Get a copy of the source frame in RGB format.
Returns
-------
:class:`numpy.ndarray`:
A copy of :attr:`image` in RGB color format """
return self._image[..., 2::-1].copy()
return self.image[..., 2::-1].copy()
def _image_as_gray(self):
def _image_as_gray(self) -> "np.ndarray":
""" Get a copy of the source frame in gray-scale format.
Returns
-------
:class:`numpy.ndarray`:
A copy of :attr:`image` in gray-scale color format """
return cv2.cvtColor(self._image.copy(), cv2.COLOR_BGR2GRAY)
return cv2.cvtColor(self.image.copy(), cv2.COLOR_BGR2GRAY)

View File

@@ -1,12 +1,14 @@
#!/usr/bin python3
""" Main entry point to the convert process of FaceSwap """
from dataclasses import dataclass, field
import logging
import re
import os
import sys
from threading import Event
from time import sleep
from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import cv2
import numpy as np
@@ -24,9 +26,46 @@ from lib.utils import FaceswapError, get_backend, get_folder, get_image_paths
from plugins.extract.pipeline import Extractor, ExtractMedia
from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if TYPE_CHECKING:
from argparse import Namespace
from plugins.convert.writer._base import Output
from plugins.train.model._base import ModelBase
from lib.align.aligned_face import CenteringType
from lib.queue_manager import EventQueue
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@dataclass
class ConvertItem:
    """ A single frame with associated objects passing through the convert process.

    Parameters
    ----------
    inbound: :class:`~plugins.extract.pipeline.ExtractMedia`
        The ExtractMedia object holding the :attr:`filename`, :attr:`image` and :attr:`list` of
        :class:`~lib.align.DetectedFace` objects loaded from disk
    feed_faces: list, Optional
        list of :class:`lib.align.AlignedFace` objects for feeding into the model's predict
        function. Default: empty list
    reference_faces: list, Optional
        list of :class:`lib.align.AlignedFace` objects at model output size for using as
        reference in the convert function. Default: empty list
    swapped_faces: :class:`np.ndarray`, Optional
        The swapped faces returned from the model's predict function. Default: empty array
    """
    inbound: ExtractMedia
    feed_faces: List[AlignedFace] = field(default_factory=list)
    reference_faces: List[AlignedFace] = field(default_factory=list)
    # A bare ``np.array([])`` default is a single object shared by every instance, and
    # Python 3.11+ dataclasses refuse unhashable defaults (ndarray is unhashable) with a
    # ValueError when the class is created. Use a factory to get a fresh array per item.
    swapped_faces: np.ndarray = field(default_factory=lambda: np.array([]))
class Convert(): # pylint:disable=too-few-public-methods
""" The Faceswap Face Conversion Process.
@@ -45,11 +84,10 @@ class Convert(): # pylint:disable=too-few-public-methods
The arguments to be passed to the convert process as generated from Faceswap's command
line arguments
"""
def __init__(self, arguments):
def __init__(self, arguments: "Namespace") -> None:
logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
self._args = arguments
self._patch_threads = None
self._images = ImagesLoader(self._args.input_dir, fast_count=True)
self._alignments = Alignments(self._args, False, self._images.is_video)
if self._alignments.version == 1.0:
@@ -74,12 +112,13 @@ class Convert(): # pylint:disable=too-few-public-methods
self._disk_io.pre_encode,
arguments,
configfile=configfile)
self._patch_threads = self._get_threads()
logger.debug("Initialized %s", self.__class__.__name__)
@property
def _queue_size(self):
def _queue_size(self) -> int:
""" int: Size of the converter queues. 16 for single process otherwise 32 """
# TODO why do we need such big queues?
if self._args.singleprocess:
retval = 16
else:
@@ -88,7 +127,7 @@ class Convert(): # pylint:disable=too-few-public-methods
return retval
@property
def _pool_processes(self):
def _pool_processes(self) -> int:
""" int: The number of threads to run in parallel. Based on user options and number of
available processors. """
if self._args.singleprocess:
@@ -101,7 +140,7 @@ class Convert(): # pylint:disable=too-few-public-methods
logger.debug(retval)
return retval
def _validate(self):
def _validate(self) -> None:
""" Validate the Command Line Options.
Ensure that certain cli selections are valid and won't result in an error. Checks:
@@ -133,12 +172,12 @@ class Convert(): # pylint:disable=too-few-public-methods
if (not self._args.on_the_fly and
self._args.mask_type not in ("none", "predicted") and
not self._alignments.mask_is_valid(self._args.mask_type)):
msg = ("You have selected the Mask Type `{}` but at least one face does not have this "
"mask stored in the Alignments File.\nYou should generate the required masks "
"with the Mask Tool or set the Mask Type option to an existing Mask Type.\nA "
"summary of existing masks is as follows:\nTotal faces: {}, Masks: "
"{}".format(self._args.mask_type, self._alignments.faces_count,
self._alignments.mask_summary))
msg = (f"You have selected the Mask Type `{self._args.mask_type}` but at least one "
"face does not have this mask stored in the Alignments File.\nYou should "
"generate the required masks with the Mask Tool or set the Mask Type option to "
"an existing Mask Type.\nA summary of existing masks is as follows:\nTotal "
f"faces: {self._alignments.faces_count}, "
f"Masks: {self._alignments.mask_summary}")
raise FaceswapError(msg)
if self._args.mask_type == "predicted" and not self._predictor.has_predicted_mask:
@@ -154,16 +193,34 @@ class Convert(): # pylint:disable=too-few-public-methods
"mask. Selecting first available mask: '%s'", mask_type)
self._args.mask_type = mask_type
def _add_queues(self):
def _add_queues(self) -> None:
""" Add the queues for in, patch and out. """
logger.debug("Adding queues. Queue size: %s", self._queue_size)
for qname in ("convert_in", "convert_out", "patch"):
queue_manager.add_queue(qname, self._queue_size)
def process(self):
def _get_threads(self) -> MultiThread:
    """ Create the threads that patch the converted faces onto the frames.

    The threads are created here but are not started.

    Returns
    -------
    :class:`lib.multithreading.MultiThread`
        The threads that perform the patching of swapped faces onto the output frames
    """
    # TODO Check if multiple threads actually speeds anything up
    out_queue = queue_manager.get_queue("convert_out")
    in_queue = queue_manager.get_queue("patch")
    patch_threads = MultiThread(self._converter.process,
                                in_queue,
                                out_queue,
                                thread_count=self._pool_processes,
                                name="patch")
    return patch_threads
def process(self) -> None:
""" The entry point for triggering the Conversion Process.
Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
Raises
------
FaceswapError
Error raised if the process runs out of memory
"""
logger.debug("Starting Conversion")
# queue_manager.debug_monitor(5)
@@ -184,15 +241,10 @@ class Convert(): # pylint:disable=too-few-public-methods
"'singleprocess' flag (-sp) or lowering the number of parallel jobs (-j).")
raise FaceswapError(msg) from err
def _convert_images(self):
def _convert_images(self) -> None:
""" Start the multi-threaded patching process, monitor all threads for errors and join on
completion. """
logger.debug("Converting images")
save_queue = queue_manager.get_queue("convert_out")
patch_queue = queue_manager.get_queue("patch")
self._patch_threads = MultiThread(self._converter.process, patch_queue, save_queue,
thread_count=self._pool_processes, name="patch")
self._patch_threads.start()
while True:
self._check_thread_error()
@@ -206,11 +258,17 @@ class Convert(): # pylint:disable=too-few-public-methods
self._patch_threads.join()
logger.debug("Putting EOF")
save_queue.put("EOF")
queue_manager.get_queue("convert_out").put("EOF")
logger.debug("Converted images")
def _check_thread_error(self):
""" Monitor all running threads for errors, and raise accordingly. """
def _check_thread_error(self) -> None:
""" Monitor all running threads for errors, and raise accordingly.
Raises
------
Error
Re-raises any error encountered within any of the running threads
"""
for thread in (self._predictor.thread,
self._disk_io.load_thread,
self._disk_io.save_thread,
@@ -236,7 +294,8 @@ class DiskIO():
line arguments
"""
def __init__(self, alignments, images, arguments):
def __init__(self,
alignments: Alignments, images: ImagesLoader, arguments: "Namespace") -> None:
logger.debug("Initializing %s: (alignments: %s, images: %s, arguments: %s)",
self.__class__.__name__, alignments, images, arguments)
self._alignments = alignments
@@ -253,61 +312,63 @@ class DiskIO():
# Extractor for on the fly detection
self._extractor = self._load_extractor()
self._queues = dict(load=None, save=None)
self._threads = dict(oad=None, save=None)
self._queues: Dict[Literal["load", "save"], "EventQueue"] = {}
self._threads: Dict[Literal["load", "save"], MultiThread] = {}
self._init_threads()
logger.debug("Initialized %s", self.__class__.__name__)
@property
def completion_event(self):
def completion_event(self) -> Event:
""" :class:`event.Event`: Event is set when the DiskIO Save task is complete """
return self._completion_event
@property
def draw_transparent(self):
def draw_transparent(self) -> bool:
""" bool: ``True`` if the selected writer's Draw_transparent configuration item is set
otherwise ``False`` """
return self._writer.config.get("draw_transparent", False)
@property
def pre_encode(self):
def pre_encode(self) -> Optional[Callable[[np.ndarray], List[bytes]]]:
""" python function: Selected writer's pre-encode function, if it has one,
otherwise ``None`` """
dummy = np.zeros((20, 20, 3), dtype="uint8")
test = self._writer.pre_encode(dummy)
retval = None if test is None else self._writer.pre_encode
retval: Optional[Callable[[np.ndarray],
List[bytes]]] = None if test is None else self._writer.pre_encode
logger.debug("Writer pre_encode function: %s", retval)
return retval
@property
def save_thread(self):
def save_thread(self) -> MultiThread:
""" :class:`lib.multithreading.MultiThread`: The thread that is running the image writing
operation. """
return self._threads["save"]
@property
def load_thread(self):
def load_thread(self) -> MultiThread:
""" :class:`lib.multithreading.MultiThread`: The thread that is running the image loading
operation. """
return self._threads["load"]
@property
def load_queue(self):
""" :class:`queue.Queue()`: The queue that images and detected faces are loaded into. """
def load_queue(self) -> "EventQueue":
""" :class:`~lib.queue_manager.EventQueue`: The queue that images and detected faces are
loaded into. """
return self._queues["load"]
@property
def _total_count(self):
def _total_count(self) -> int:
""" int: The total number of frames to be converted """
if self._frame_ranges and not self._args.keep_unchanged:
retval = sum([fr[1] - fr[0] + 1 for fr in self._frame_ranges])
retval = sum(fr[1] - fr[0] + 1 for fr in self._frame_ranges)
else:
retval = self._images.count
logger.debug(retval)
return retval
# Initialization
def _get_writer(self):
def _get_writer(self) -> "Output":
""" Load the selected writer plugin.
Returns
@@ -328,7 +389,7 @@ class DiskIO():
return PluginLoader.get_converter("writer", self._args.writer)(*args,
configfile=configfile)
def _get_frame_ranges(self):
def _get_frame_ranges(self) -> Optional[List[Tuple[int, int]]]:
""" Obtain the frame ranges that are to be converted.
If frame ranges have been specified, then split the command line formatted arguments into
@@ -357,7 +418,7 @@ class DiskIO():
raise FaceswapError("Frame Ranges specified, but could not determine frame numbering "
"from filenames")
retval = list()
retval = []
for rng in self._args.frame_ranges:
if "-" not in rng:
raise FaceswapError("Frame Ranges not specified in the correct format")
@@ -366,7 +427,7 @@ class DiskIO():
logger.debug("frame ranges: %s", retval)
return retval
def _load_extractor(self):
def _load_extractor(self) -> Optional[Extractor]:
""" Load the CV2-DNN Face Extractor Chain.
For On-The-Fly conversion we use a CPU based extractor to avoid stacking the GPU.
@@ -405,18 +466,18 @@ class DiskIO():
logger.debug("Loaded extractor")
return extractor
def _init_threads(self):
def _init_threads(self) -> None:
""" Initialize queues and threads.
Creates the load and save queues and the load and save threads. Starts the threads.
"""
logger.debug("Initializing DiskIO Threads")
for task in ("load", "save"):
for task in get_args(Literal["load", "save"]):
self._add_queue(task)
self._start_thread(task)
logger.debug("Initialized DiskIO Threads")
def _add_queue(self, task):
def _add_queue(self, task: Literal["load", "save"]) -> None:
""" Add the queue to queue_manager and to :attr:`self._queues` for the given task.
Parameters
@@ -434,7 +495,7 @@ class DiskIO():
self._queues[task] = queue_manager.get_queue(q_name)
logger.debug("Added queue for task: '%s'", task)
def _start_thread(self, task):
def _start_thread(self, task: Literal["load", "save"]) -> None:
""" Create the thread for the given task, add it to :attr:`self._threads` and start it.
Parameters
@@ -444,14 +505,14 @@ class DiskIO():
"""
logger.debug("Starting thread: '%s'", task)
args = self._completion_event if task == "save" else None
func = getattr(self, "_{}".format(task))
func = getattr(self, f"_{task}")
io_thread = MultiThread(func, args, thread_count=1)
io_thread.start()
self._threads[task] = io_thread
logger.debug("Started thread: '%s'", task)
# Loading tasks
def _load(self, *args): # pylint: disable=unused-argument
def _load(self, *args) -> None: # pylint: disable=unused-argument
""" Load frames from disk.
In a background thread:
@@ -474,23 +535,23 @@ class DiskIO():
continue
if self._check_skipframe(filename):
if self._args.keep_unchanged:
logger.trace("Saving unchanged frame: %s", filename)
logger.trace("Saving unchanged frame: %s", filename) # type:ignore
out_file = os.path.join(self._args.output_dir, os.path.basename(filename))
self._queues["save"].put((out_file, image))
else:
logger.trace("Discarding frame: '%s'", filename)
logger.trace("Discarding frame: '%s'", filename) # type:ignore
continue
detected_faces = self._get_detected_faces(filename, image)
item = dict(filename=filename, image=image, detected_faces=detected_faces)
self._pre_process.do_actions(item)
item = ConvertItem(ExtractMedia(filename, image, detected_faces))
self._pre_process.do_actions(item.inbound)
self._queues["load"].put(item)
logger.debug("Putting EOF")
self._queues["load"].put("EOF")
logger.debug("Load Images: Complete")
def _check_skipframe(self, filename):
def _check_skipframe(self, filename: str) -> bool:
""" Check whether a frame is to be skipped.
Parameters
@@ -504,18 +565,18 @@ class DiskIO():
``True`` if the frame is to be skipped otherwise ``False``
"""
if not self._frame_ranges:
return None
return False
indices = self._imageidxre.findall(filename)
if not indices:
logger.warning("Could not determine frame number. Frame will be converted: '%s'",
filename)
return False
idx = int(indices[0]) if indices else None
idx = int(indices[0])
skipframe = not any(map(lambda b: b[0] <= idx <= b[1], self._frame_ranges))
logger.trace("idx: %s, skipframe: %s", idx, skipframe)
logger.trace("idx: %s, skipframe: %s", idx, skipframe) # type: ignore
return skipframe
def _get_detected_faces(self, filename, image):
def _get_detected_faces(self, filename: str, image: np.ndarray) -> List[DetectedFace]:
""" Return the detected faces for the given image.
If we have an alignments file, then the detected faces are created from that file. If
@@ -533,15 +594,15 @@ class DiskIO():
list
List of :class:`lib.align.DetectedFace` objects
"""
logger.trace("Getting faces for: '%s'", filename)
logger.trace("Getting faces for: '%s'", filename) # type:ignore
if not self._extractor:
detected_faces = self._alignments_faces(os.path.basename(filename), image)
else:
detected_faces = self._detect_faces(filename, image)
logger.trace("Got %s faces for: '%s'", len(detected_faces), filename)
logger.trace("Got %s faces for: '%s'", len(detected_faces), filename) # type:ignore
return detected_faces
def _alignments_faces(self, frame_name, image):
def _alignments_faces(self, frame_name: str, image: np.ndarray) -> List[DetectedFace]:
""" Return detected faces from an alignments file.
Parameters
@@ -557,10 +618,10 @@ class DiskIO():
List of :class:`lib.align.DetectedFace` objects
"""
if not self._check_alignments(frame_name):
return list()
return []
faces = self._alignments.get_faces_in_frame(frame_name)
detected_faces = list()
detected_faces = []
for rawface in faces:
face = DetectedFace()
@@ -568,7 +629,7 @@ class DiskIO():
detected_faces.append(face)
return detected_faces
def _check_alignments(self, frame_name):
def _check_alignments(self, frame_name: str) -> bool:
""" Ensure that we have alignments for the current frame.
If we have no alignments for this image, skip it and output a message.
@@ -585,11 +646,10 @@ class DiskIO():
"""
have_alignments = self._alignments.frame_exists(frame_name)
if not have_alignments:
tqdm.write("No alignment found for {}, "
"skipping".format(frame_name))
tqdm.write(f"No alignment found for {frame_name}, skipping")
return have_alignments
def _detect_faces(self, filename, image):
def _detect_faces(self, filename: str, image: np.ndarray) -> List[DetectedFace]:
""" Extract the face from a frame for On-The-Fly conversion.
Pulls detected faces out of the Extraction pipeline.
@@ -606,12 +666,13 @@ class DiskIO():
list
List of :class:`lib.align.DetectedFace` objects
"""
assert self._extractor is not None
self._extractor.input_queue.put(ExtractMedia(filename, image))
faces = next(self._extractor.detected_faces())
return faces.detected_faces
# Saving tasks
def _save(self, completion_event):
def _save(self, completion_event: Event) -> None:
""" Save the converted images.
Puts the selected writer into a background thread and feeds it from the output of the
@@ -650,7 +711,7 @@ class Predict():
Parameters
----------
in_queue: :class:`queue.Queue`
in_queue: :class:`~lib.queue_manager.EventQueue`
The queue that contains images and detected faces for feeding the model
queue_size: int
The maximum size of the input queue
@@ -658,7 +719,7 @@ class Predict():
The arguments that were passed to the convert process as generated from Faceswap's command
line arguments
"""
def __init__(self, in_queue, queue_size, arguments):
def __init__(self, in_queue: "EventQueue", queue_size: int, arguments: "Namespace") -> None:
logger.debug("Initializing %s: (args: %s, queue_size: %s, in_queue: %s)",
self.__class__.__name__, arguments, queue_size, in_queue)
self._args = arguments
@@ -678,52 +739,52 @@ class Predict():
logger.debug("Initialized %s: (out_queue: %s)", self.__class__.__name__, self._out_queue)
@property
def thread(self):
def thread(self) -> MultiThread:
""" :class:`~lib.multithreading.MultiThread`: The thread that is running the prediction
function from the Faceswap model. """
return self._thread
@property
def in_queue(self):
""" :class:`queue.Queue`: The input queue to the predictor. """
def in_queue(self) -> "EventQueue":
""" :class:`~lib.queue_manager.EventQueue`: The input queue to the predictor. """
return self._in_queue
@property
def out_queue(self):
""" :class:`queue.Queue`: The output queue from the predictor. """
def out_queue(self) -> "EventQueue":
""" :class:`~lib.queue_manager.EventQueue`: The output queue from the predictor. """
return self._out_queue
@property
def faces_count(self):
def faces_count(self) -> int:
""" int: The total number of faces seen by the Predictor. """
return self._faces_count
@property
def verify_output(self):
def verify_output(self) -> bool:
""" bool: ``True`` if multiple faces have been found in frames, otherwise ``False``. """
return self._verify_output
@property
def coverage_ratio(self):
def coverage_ratio(self) -> float:
""" float: The coverage ratio that the model was trained at. """
return self._coverage_ratio
@property
def centering(self):
""" str: The centering that the model was trained on (`"face"` or `"legacy"`) """
def centering(self) -> "CenteringType":
""" str: The centering that the model was trained on (`"head"`, `"face"` or `"legacy"`) """
return self._centering
@property
def has_predicted_mask(self):
def has_predicted_mask(self) -> bool:
""" bool: ``True`` if the model was trained to learn a mask, otherwise ``False``. """
return bool(self._model.config.get("learn_mask", False))
@property
def output_size(self):
def output_size(self) -> int:
""" int: The size in pixels of the Faceswap model output. """
return self._sizes["output"]
def _get_io_sizes(self):
def _get_io_sizes(self) -> Dict[str, int]:
""" Obtain the input size and output size of the model.
Returns
@@ -739,7 +800,7 @@ class Predict():
logger.debug(retval)
return retval
def _load_model(self):
def _load_model(self) -> "ModelBase":
""" Load the Faceswap model.
Returns
@@ -757,7 +818,7 @@ class Predict():
logger.debug("Loaded Model")
return model
def _get_batchsize(self, queue_size):
def _get_batchsize(self, queue_size: int) -> int:
""" Get the batch size for feeding the model.
Sets the batch size to 1 if inference is being run on CPU, otherwise the minimum of the
@@ -781,7 +842,7 @@ class Predict():
logger.debug("Got batchsize: %s", batchsize)
return batchsize
def _get_model_name(self, model_dir):
def _get_model_name(self, model_dir: str) -> str:
""" Return the name of the Faceswap model used.
If a "trainer" option has been selected in the command line arguments, use that value,
@@ -802,13 +863,13 @@ class Predict():
logger.debug("Trainer name provided: '%s'", self._args.trainer)
return self._args.trainer
statefile = [fname for fname in os.listdir(str(model_dir))
if fname.endswith("_state.json")]
if len(statefile) != 1:
raise FaceswapError("There should be 1 state file in your model folder. {} were "
"found. Specify a trainer with the '-t', '--trainer' "
"option.".format(len(statefile)))
statefile = os.path.join(str(model_dir), statefile[0])
statefiles = [fname for fname in os.listdir(str(model_dir))
if fname.endswith("_state.json")]
if len(statefiles) != 1:
raise FaceswapError("There should be 1 state file in your model folder. "
f"{len(statefiles)} were found. Specify a trainer with the '-t', "
"'--trainer' option.")
statefile = os.path.join(str(model_dir), statefiles[0])
state = self._serializer.load(statefile)
trainer = state.get("name", None)
@@ -819,7 +880,7 @@ class Predict():
logger.debug("Trainer from state file: '%s'", trainer)
return trainer
def _launch_predictor(self):
def _launch_predictor(self) -> MultiThread:
""" Launch the prediction process in a background thread.
Starts the prediction thread and returns the thread.
@@ -833,7 +894,7 @@ class Predict():
thread.start()
return thread
def _predict_faces(self):
def _predict_faces(self) -> None:
""" Run Prediction on the Faceswap model in a background thread.
Reads from the :attr:`self._in_queue`, prepares images for prediction
@@ -841,64 +902,63 @@ class Predict():
"""
faces_seen = 0
consecutive_no_faces = 0
batch = list()
batch: List[ConvertItem] = []
is_amd = get_backend() == "amd"
while True:
item = self._in_queue.get()
if item != "EOF":
logger.trace("Got from queue: '%s'", item["filename"])
faces_count = len(item["detected_faces"])
item: Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
if item == "EOF":
logger.debug("EOF Received")
break
logger.trace("Got from queue: '%s'", item.inbound.filename) # type:ignore
faces_count = len(item.inbound.detected_faces)
# Safety measure. If a large stream of frames appear that do not have faces,
# these will stack up into RAM. Keep a count of consecutive frames with no faces.
# If self._batchsize number of frames appear, force the current batch through
# to clear RAM.
consecutive_no_faces = consecutive_no_faces + 1 if faces_count == 0 else 0
self._faces_count += faces_count
if faces_count > 1:
self._verify_output = True
logger.verbose("Found more than one face in an image! '%s'",
os.path.basename(item["filename"]))
# Safety measure. If a large stream of frames appear that do not have faces,
# these will stack up into RAM. Keep a count of consecutive frames with no faces.
# If self._batchsize number of frames appear, force the current batch through
# to clear RAM.
consecutive_no_faces = consecutive_no_faces + 1 if faces_count == 0 else 0
self._faces_count += faces_count
if faces_count > 1:
self._verify_output = True
logger.verbose("Found more than one face in an image! '%s'", # type:ignore
os.path.basename(item.inbound.filename))
self.load_aligned(item)
self.load_aligned(item)
faces_seen += faces_count
faces_seen += faces_count
batch.append(item)
batch.append(item)
if item != "EOF" and (faces_seen < self._batchsize and
consecutive_no_faces < self._batchsize):
logger.trace("Continuing. Current batchsize: %s, consecutive_no_faces: %s",
faces_seen, consecutive_no_faces)
if faces_seen < self._batchsize and consecutive_no_faces < self._batchsize:
logger.trace("Continuing. Current batchsize: %s, " # type:ignore
"consecutive_no_faces: %s", faces_seen, consecutive_no_faces)
continue
if batch:
logger.trace("Batching to predictor. Frames: %s, Faces: %s",
logger.trace("Batching to predictor. Frames: %s, Faces: %s", # type:ignore
len(batch), faces_seen)
feed_batch = [feed_face for item in batch
for feed_face in item["feed_faces"]]
for feed_face in item.feed_faces]
if faces_seen != 0:
feed_faces = self._compile_feed_faces(feed_batch)
batch_size = None
if is_amd and feed_faces.shape[0] != self._batchsize:
logger.verbose("Fallback to BS=1")
logger.verbose("Fallback to BS=1") # type:ignore
batch_size = 1
predicted = self._predict(feed_faces, batch_size)
else:
predicted = list()
predicted = np.array([])
self._queue_out_frames(batch, predicted)
consecutive_no_faces = 0
faces_seen = 0
batch = list()
if item == "EOF":
logger.debug("EOF Received")
break
batch = []
logger.debug("Putting EOF")
self._out_queue.put("EOF")
logger.debug("Load queue complete")
def load_aligned(self, item):
def load_aligned(self, item: ConvertItem) -> None:
""" Load the model's feed faces and the reference output faces.
For each detected face in the incoming item, load the feed face and reference face
@@ -906,18 +966,15 @@ class Predict():
Parameters
----------
item: dict
The incoming image, list of :class:`~lib.align.DetectedFace` objects and list of
:class:`~lib.align.AlignedFace` objects for the feed face(s) and list of
:class:`~lib.align.AlignedFace` objects for the reference face(s)
item: :class:`~scripts.convert.ConvertItem`
The convert item object, containing the ExtractMedia for the current image
"""
logger.trace("Loading aligned faces: '%s'", item["filename"])
logger.trace("Loading aligned faces: '%s'", item.inbound.filename) # type:ignore
feed_faces = []
reference_faces = []
for detected_face in item["detected_faces"]:
for detected_face in item.inbound.detected_faces:
feed_face = AlignedFace(detected_face.landmarks_xy,
image=item["image"],
image=item.inbound.image,
centering=self._centering,
size=self._sizes["input"],
coverage_ratio=self._coverage_ratio,
@@ -926,18 +983,18 @@ class Predict():
reference_faces.append(feed_face)
else:
reference_faces.append(AlignedFace(detected_face.landmarks_xy,
image=item["image"],
image=item.inbound.image,
centering=self._centering,
size=self._sizes["output"],
coverage_ratio=self._coverage_ratio,
dtype="float32"))
feed_faces.append(feed_face)
item["feed_faces"] = feed_faces
item["reference_faces"] = reference_faces
logger.trace("Loaded aligned faces: '%s'", item["filename"])
item.feed_faces = feed_faces
item.reference_faces = reference_faces
logger.trace("Loaded aligned faces: '%s'", item.inbound.filename) # type:ignore
@staticmethod
def _compile_feed_faces(feed_faces):
def _compile_feed_faces(feed_faces: List[AlignedFace]) -> np.ndarray:
""" Compile a batch of faces for feeding into the Predictor.
Parameters
@@ -950,12 +1007,13 @@ class Predict():
:class:`numpy.ndarray`
A batch of faces ready for feeding into the Faceswap model.
"""
logger.trace("Compiling feed face. Batchsize: %s", len(feed_faces))
retval = np.stack([feed_face.face[..., :3] for feed_face in feed_faces]) / 255.0
logger.trace("Compiled Feed faces. Shape: %s", retval.shape)
logger.trace("Compiling feed face. Batchsize: %s", len(feed_faces)) # type:ignore
retval = np.stack([cast(np.ndarray, feed_face.face)[..., :3]
for feed_face in feed_faces]) / 255.0
logger.trace("Compiled Feed faces. Shape: %s", retval.shape) # type:ignore
return retval
def _predict(self, feed_faces, batch_size=None):
def _predict(self, feed_faces: np.ndarray, batch_size: Optional[int] = None) -> np.ndarray:
""" Run the Faceswap models' prediction function.
Parameters
@@ -971,32 +1029,33 @@ class Predict():
:class:`numpy.ndarray`
The swapped faces for the given batch
"""
logger.trace("Predicting: Batchsize: %s", len(feed_faces))
logger.trace("Predicting: Batchsize: %s", len(feed_faces)) # type:ignore
if self._model.color_order.lower() == "rgb":
feed_faces = feed_faces[..., ::-1]
feed = [feed_faces]
logger.trace("Input shape(s): %s", [item.shape for item in feed])
logger.trace("Input shape(s): %s", [item.shape for item in feed]) # type:ignore
predicted = self._model.model.predict(feed, verbose=0, batch_size=batch_size)
predicted = predicted if isinstance(predicted, list) else [predicted]
inbound = self._model.model.predict(feed, verbose=0, batch_size=batch_size)
predicted: List[np.ndarray] = inbound if isinstance(inbound, list) else [inbound]
if self._model.color_order.lower() == "rgb":
predicted[0] = predicted[0][..., ::-1]
logger.trace("Output shape(s): %s", [predict.shape for predict in predicted])
logger.trace("Output shape(s): %s", # type:ignore
[predict.shape for predict in predicted])
# Only take last output(s)
if predicted[-1].shape[-1] == 1: # Merge mask to alpha channel
predicted = np.concatenate(predicted[-2:], axis=-1).astype("float32")
retval = np.concatenate(predicted[-2:], axis=-1).astype("float32")
else:
predicted = predicted[-1].astype("float32")
retval = predicted[-1].astype("float32")
logger.trace("Final shape: %s", predicted.shape)
return predicted
logger.trace("Final shape: %s", retval.shape) # type:ignore
return retval
def _queue_out_frames(self, batch, swapped_faces):
def _queue_out_frames(self, batch: List[ConvertItem], swapped_faces: np.ndarray) -> None:
""" Compile the batch back to original frames and put to the Out Queue.
For batching, faces are split away from their frames. This compiles all detected faces
@@ -1009,21 +1068,20 @@ class Predict():
swapped_faces: :class:`numpy.ndarray`
The predictions returned from the model's predict function
"""
logger.trace("Queueing out batch. Batchsize: %s", len(batch))
logger.trace("Queueing out batch. Batchsize: %s", len(batch)) # type:ignore
pointer = 0
for item in batch:
num_faces = len(item["detected_faces"])
if num_faces == 0:
item["swapped_faces"] = np.array(list())
else:
item["swapped_faces"] = swapped_faces[pointer:pointer + num_faces]
num_faces = len(item.inbound.detected_faces)
if num_faces != 0:
item.swapped_faces = swapped_faces[pointer:pointer + num_faces]
logger.trace("Putting to queue. ('%s', detected_faces: %s, reference_faces: %s, "
"swapped_faces: %s)", item["filename"], len(item["detected_faces"]),
len(item["reference_faces"]), item["swapped_faces"].shape[0])
logger.trace("Putting to queue. ('%s', detected_faces: %s, " # type:ignore
"reference_faces: %s, swapped_faces: %s)", item.inbound.filename,
len(item.inbound.detected_faces), len(item.reference_faces),
item.swapped_faces.shape[0])
pointer += num_faces
self._out_queue.put(batch)
logger.trace("Queued out batch. Batchsize: %s", len(batch))
logger.trace("Queued out batch. Batchsize: %s", len(batch)) # type:ignore
class OptionalActions(): # pylint:disable=too-few-public-methods
@@ -1041,8 +1099,10 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
alignments: :class:`lib.align.Alignments`
The alignments file for this conversion
"""
def __init__(self, arguments, input_images, alignments):
def __init__(self,
arguments: "Namespace",
input_images: List[np.ndarray],
alignments: Alignments) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._args = arguments
self._input_images = input_images
@@ -1052,7 +1112,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
logger.debug("Initialized %s", self.__class__.__name__)
# SKIP FACES #
def _remove_skipped_faces(self):
def _remove_skipped_faces(self) -> None:
""" If the user has specified an input aligned directory, remove any non-matching faces
from the alignments file. """
logger.debug("Filtering Faces")
@@ -1064,7 +1124,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
self._alignments.filter_faces(accept_dict, filter_out=False)
logger.info("Faces filtered out: %s", pre_face_count - self._alignments.faces_count)
def _get_face_metadata(self):
def _get_face_metadata(self) -> Dict[str, List[int]]:
""" Check for the existence of an aligned directory for identifying which faces in the
target frames should be swapped. If it exists, scan the folder for face's metadata
@@ -1073,12 +1133,12 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
dict
Dictionary of source frame names with a list of associated face indices to be skipped
"""
retval = dict()
retval: Dict[str, List[int]] = {}
input_aligned_dir = self._args.input_aligned_dir
if input_aligned_dir is None:
logger.verbose("Aligned directory not specified. All faces listed in the "
"alignments file will be converted")
logger.verbose("Aligned directory not specified. All faces listed in " # type:ignore
"the alignments file will be converted")
return retval
if not os.path.isdir(input_aligned_dir):
logger.warning("Aligned directory not found. All faces listed in the "
@@ -1100,13 +1160,13 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
data = update_legacy_png_header(fullpath, self._alignments)
if not data:
raise FaceswapError(
"Some of the faces being passed in from '{}' could not be matched to the "
"alignments file '{}'\nPlease double check your sources and try "
"again.".format(input_aligned_dir, self._alignments.file))
f"Some of the faces being passed in from '{input_aligned_dir}' could not "
f"be matched to the alignments file '{self._alignments.file}'\n"
"Please double check your sources and try again.")
meta = data["source"]
else:
meta = metadata["itxt"]["source"]
retval.setdefault(meta["source_filename"], list()).append(meta["face_index"])
retval.setdefault(meta["source_filename"], []).append(meta["face_index"])
if not retval:
raise FaceswapError("Aligned directory is empty, no faces will be converted!")

View File

@@ -9,6 +9,7 @@ Holds optional pre/post processing functions for convert and extract.
import logging
import os
import sys
from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
import cv2
import numpy as np
@@ -19,10 +20,19 @@ from lib.face_filter import FaceFilter as FilterFunc
from lib.image import count_frames, read_image
from lib.utils import (camel_case_split, get_image_paths, _video_extensions)
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if TYPE_CHECKING:
from argparse import Namespace
from plugins.extract.pipeline import ExtractMedia
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def finalize(images_found, num_faces_detected, verify_output):
def finalize(images_found: int, num_faces_detected: int, verify_output: bool) -> None:
""" Output summary statistics at the end of the extract or convert processes.
Parameters
@@ -62,7 +72,10 @@ class Alignments(AlignmentsBase):
``True`` if the input to the process is a video, ``False`` if it is a folder of images.
Default: False
"""
def __init__(self, arguments, is_extract, input_is_video=False):
def __init__(self,
arguments: "Namespace",
is_extract: bool,
input_is_video: bool = False) -> None:
logger.debug("Initializing %s: (is_extract: %s, input_is_video: %s)",
self.__class__.__name__, is_extract, input_is_video)
self._args = arguments
@@ -71,7 +84,7 @@ class Alignments(AlignmentsBase):
super().__init__(folder, filename=filename)
logger.debug("Initialized %s", self.__class__.__name__)
def _set_folder_filename(self, input_is_video):
def _set_folder_filename(self, input_is_video: bool) -> Tuple[str, str]:
""" Return the folder and the filename for the alignments file.
If the input is a video, the alignments file will be stored in the same folder
@@ -106,7 +119,7 @@ class Alignments(AlignmentsBase):
logger.debug("Setting Alignments: (folder: '%s' filename: '%s')", folder, filename)
return folder, filename
def _load(self):
def _load(self) -> Dict[str, Any]:
""" Override the parent :func:`~lib.align.Alignments._load` to handle skip existing
frames and faces on extract.
@@ -119,10 +132,10 @@ class Alignments(AlignmentsBase):
Any alignments that have already been extracted if skip existing has been selected
otherwise an empty dictionary
"""
data = {}
data: Dict[str, Any] = {}
if not self._is_extract and not self.have_alignments_file:
return data
if not self._is_extract:
if not self.have_alignments_file:
return data
data = super()._load()
return data
@@ -146,7 +159,7 @@ class Alignments(AlignmentsBase):
logger.debug("Frames with no faces selected for redetection: %s", len(del_keys))
for key in del_keys:
if key in data:
logger.trace("Selected for redetection: '%s'", key)
logger.trace("Selected for redetection: '%s'", key) # type: ignore
del data[key]
return data
@@ -160,7 +173,7 @@ class Images():
arguments: :class:`argparse.Namespace`
The command line arguments that were passed to Faceswap
"""
def __init__(self, arguments):
def __init__(self, arguments: "Namespace") -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._args = arguments
self._is_video = self._check_input_folder()
@@ -169,22 +182,22 @@ class Images():
logger.debug("Initialized %s", self.__class__.__name__)
@property
def is_video(self):
def is_video(self) -> bool:
"""bool: ``True`` if the input is a video file otherwise ``False``. """
return self._is_video
@property
def input_images(self):
def input_images(self) -> Union[str, List[str]]:
"""str or list: Path to the video file if the input is a video otherwise list of
image paths. """
return self._input_images
@property
def images_found(self):
def images_found(self) -> int:
"""int: The number of frames that exist in the video file, or the folder of images. """
return self._images_found
def _count_images(self):
def _count_images(self) -> int:
""" Get the number of Frames from a video file or folder of images.
Returns
@@ -198,7 +211,7 @@ class Images():
retval = len(self._input_images)
return retval
def _check_input_folder(self):
def _check_input_folder(self) -> bool:
""" Check whether the input is a folder or video.
Returns
@@ -218,7 +231,7 @@ class Images():
retval = False
return retval
def _get_input_images(self):
def _get_input_images(self) -> Union[str, List[str]]:
""" Return the list of images or path to video file that is to be processed.
Returns
@@ -233,7 +246,7 @@ class Images():
return input_images
def load(self):
def load(self) -> Generator[Tuple[str, np.ndarray], None, None]:
""" Generator to load frames from a folder of images or from a video file.
Yields
@@ -247,7 +260,7 @@ class Images():
for filename, image in iterator():
yield filename, image
def _load_disk_frames(self):
def _load_disk_frames(self) -> Generator[Tuple[str, np.ndarray], None, None]:
""" Generator to load frames from a folder of images.
Yields
@@ -264,7 +277,7 @@ class Images():
continue
yield filename, image
def _load_video_frames(self):
def _load_video_frames(self) -> Generator[Tuple[str, np.ndarray], None, None]:
""" Generator to load frames from a video file.
Yields
@@ -281,11 +294,11 @@ class Images():
# Convert to BGR for cv2 compatibility
frame = frame[:, :, ::-1]
filename = f"{vidname}_{i + 1:06d}.png"
logger.trace("Loading video frame: '%s'", filename)
logger.trace("Loading video frame: '%s'", filename) # type: ignore
yield filename, frame
reader.close()
def load_one_image(self, filename):
def load_one_image(self, filename) -> np.ndarray:
""" Obtain a single image for the given filename.
Parameters
@@ -299,19 +312,20 @@ class Images():
The image for the requested filename.
"""
logger.trace("Loading image: '%s'", filename)
logger.trace("Loading image: '%s'", filename) # type: ignore
if self._is_video:
if filename.isdigit():
frame_no = filename
else:
frame_no = os.path.splitext(filename)[0][filename.rfind("_") + 1:]
logger.trace("Extracted frame_no %s from filename '%s'", frame_no, filename)
logger.trace("Extracted frame_no %s from filename '%s'", # type: ignore
frame_no, filename)
retval = self._load_one_video_frame(int(frame_no))
else:
retval = read_image(filename, raise_error=True)
return retval
def _load_one_video_frame(self, frame_no):
def _load_one_video_frame(self, frame_no: int) -> np.ndarray:
""" Obtain a single frame from a video file.
Parameters
@@ -324,7 +338,7 @@ class Images():
:class:`numpy.ndarray`
The image for the requested frame index.
"""
logger.trace("Loading video frame: %s", frame_no)
logger.trace("Loading video frame: %s", frame_no) # type: ignore
reader = imageio.get_reader(self._args.input_dir, "ffmpeg")
reader.set_image_index(frame_no - 1)
frame = reader.get_next_data()[:, :, ::-1]
@@ -343,13 +357,13 @@ class PostProcess(): # pylint:disable=too-few-public-methods
arguments: :class:`argparse.Namespace`
The command line arguments that were passed to Faceswap
"""
def __init__(self, arguments):
def __init__(self, arguments: "Namespace") -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._args = arguments
self._actions = self._set_actions()
logger.debug("Initialized %s", self.__class__.__name__)
def _set_actions(self):
def _set_actions(self) -> List["PostProcessAction"]:
""" Compile the requested actions to be performed into a list
Returns
@@ -358,7 +372,7 @@ class PostProcess(): # pylint:disable=too-few-public-methods
The list of :class:`PostProcessAction` to be performed
"""
postprocess_items = self._get_items()
actions = []
actions: List["PostProcessAction"] = []
for action, options in postprocess_items.items():
options = {} if options is None else options
args = options.get("args", tuple())
@@ -370,13 +384,13 @@ class PostProcess(): # pylint:disable=too-few-public-methods
logger.debug("Adding Postprocess action: '%s'", task)
actions.append(task)
for action in actions:
action_name = camel_case_split(action.__class__.__name__)
for ppaction in actions:
action_name = camel_case_split(ppaction.__class__.__name__)
logger.info("Adding post processing item: %s", " ".join(action_name))
return actions
def _get_items(self):
def _get_items(self) -> Dict[str, Optional[Dict[str, Union[tuple, dict]]]]:
""" Check the passed in command line arguments for requested actions,
For any requested actions, add the item to the actions list along with
@@ -388,7 +402,7 @@ class PostProcess(): # pylint:disable=too-few-public-methods
The name of the action to be performed as the key. Any action specific
arguments and keyword arguments as the value.
"""
postprocess_items = {}
postprocess_items: Dict[str, Optional[Dict[str, Union[tuple, dict]]]] = {}
# Debug Landmarks
if (hasattr(self._args, 'debug_landmarks') and self._args.debug_landmarks):
postprocess_items["DebugLandmarks"] = None
@@ -423,7 +437,7 @@ class PostProcess(): # pylint:disable=too-few-public-methods
logger.debug("Postprocess Items: %s", postprocess_items)
return postprocess_items
def do_actions(self, extract_media):
def do_actions(self, extract_media: "ExtractMedia") -> None:
""" Perform the requested optional post-processing actions on the given image.
Parameters
@@ -455,19 +469,19 @@ class PostProcessAction(): # pylint: disable=too-few-public-methods
kwargs: dict
Varies for specific post process action
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
logger.debug("Initializing %s: (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
self._valid = True # Set to False if invalid parameters passed in to disable
logger.debug("Initialized base class %s", self.__class__.__name__)
@property
def valid(self):
def valid(self) -> bool:
"""bool: ``True`` if the parameters passed in for this action are valid,
otherwise ``False`` """
return self._valid
def process(self, extract_media):
def process(self, extract_media: "ExtractMedia") -> None:
""" Override for specific post processing action
Parameters
@@ -481,12 +495,12 @@ class PostProcessAction(): # pylint: disable=too-few-public-methods
class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-methods
""" Draw debug landmarks on face output. Extract Only """
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(self, *args, **kwargs)
self._face_size = 0
self._legacy_size = 0
def process(self, extract_media):
def process(self, extract_media: "ExtractMedia") -> None:
""" Draw landmarks on a face.
Parameters
@@ -494,12 +508,6 @@ class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-metho
extract_media: :class:`~plugins.extract.pipeline.ExtractMedia`
The :class:`~plugins.extract.pipeline.ExtractMedia` object that contains the faces to
draw the landmarks on to
Returns
-------
:class:`~plugins.extract.pipeline.ExtractMedia`
The original :class:`~plugins.extract.pipeline.ExtractMedia` with landmarks drawn
onto the face
"""
frame = os.path.splitext(os.path.basename(extract_media.filename))[0]
for idx, face in enumerate(extract_media.detected_faces):
@@ -514,12 +522,12 @@ class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-metho
face.aligned.size)
logger.debug("set legacy size: %s", self._legacy_size)
logger.trace("Drawing Landmarks. Frame: '%s'. Face: %s", frame, idx)
logger.trace("Drawing Landmarks. Frame: '%s'. Face: %s", frame, idx) # type: ignore
# Landmarks
for (pos_x, pos_y) in face.aligned.landmarks.astype("int32"):
cv2.circle(face.aligned.face, (pos_x, pos_y), 1, (0, 255, 255), -1)
# Pose
center = tuple(np.int32((face.aligned.size / 2, face.aligned.size / 2)))
center = (face.aligned.size // 2, face.aligned.size // 2)
points = (face.aligned.pose.xyz_2d * face.aligned.size).astype("int32")
cv2.line(face.aligned.face, center, tuple(points[1]), (0, 255, 0), 1)
cv2.line(face.aligned.face, center, tuple(points[0]), (255, 0, 0), 1)
@@ -554,13 +562,18 @@ class FaceFilter(PostProcessAction):
* **filter_lists** (`dict`) - The filter and nfilter image paths
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
logger.info("Extracting and aligning face for Face Filter...")
self._filter = self._load_face_filter(**kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
def _load_face_filter(self, filter_lists, ref_threshold, aligner, detector, multiprocess):
def _load_face_filter(self,
filter_lists: Dict[str, str],
ref_threshold: float,
aligner: str,
detector: str,
multiprocess: bool) -> Optional[FilterFunc]:
""" Set up and load the :class:`~lib.face_filter.FaceFilter`.
Parameters
@@ -586,7 +599,7 @@ class FaceFilter(PostProcessAction):
facefilter = None
filter_files = [self._set_face_filter(f_type, filter_lists[f_type])
for f_type in ("filter", "nfilter")]
for f_type in get_args(Literal["filter", "nfilter"])]
if any(filters for filters in filter_files):
facefilter = FilterFunc(filter_files[0],
@@ -597,11 +610,13 @@ class FaceFilter(PostProcessAction):
ref_threshold)
logger.debug("Face filter: %s", facefilter)
else:
self.valid = False
self._valid = False
return facefilter
@staticmethod
def _set_face_filter(f_type, f_args):
@classmethod
def _set_face_filter(cls,
f_type: Literal["filter", "nfilter"],
f_args: Union[str, List[str]]) -> List[str]:
""" Check filter files exist and add the filter file paths to a list.
Parameters
@@ -621,14 +636,14 @@ class FaceFilter(PostProcessAction):
logger.info("%s: %s", f_type.title(), f_args)
filter_files = f_args if isinstance(f_args, list) else [f_args]
filter_files = list(filter(lambda fpath: os.path.exists(fpath), filter_files))
filter_files = [fpath for fpath in filter_files if os.path.exists(fpath)]
if not filter_files:
logger.warning("Face %s files were requested, but no files could be found. This "
"filter will not be applied.", f_type)
logger.debug("Face Filter files: %s", filter_files)
return filter_files
def process(self, extract_media):
def process(self, extract_media: "ExtractMedia") -> None:
""" Filters in or out any wanted or unwanted faces based on command line arguments.
Parameters
@@ -636,12 +651,6 @@ class FaceFilter(PostProcessAction):
extract_media: :class:`~plugins.extract.pipeline.ExtractMedia`
The :class:`~plugins.extract.pipeline.ExtractMedia` object to perform the
face filtering on.
Returns
-------
:class:`~plugins.extract.pipeline.ExtractMedia`
The original :class:`~plugins.extract.pipeline.ExtractMedia` with any requested filters
applied
"""
if not self._filter:
return
@@ -649,10 +658,10 @@ class FaceFilter(PostProcessAction):
for idx, detect_face in enumerate(extract_media.detected_faces):
check_item = detect_face["face"] if isinstance(detect_face, dict) else detect_face
if not self._filter.check(extract_media.image, check_item):
logger.verbose("Skipping not recognized face: (Frame: %s Face %s)",
logger.verbose("Skipping not recognized face: (Frame: %s Face %s)", # type: ignore
extract_media.filename, idx)
continue
logger.trace("Accepting recognised face. Frame: %s. Face: %s",
logger.trace("Accepting recognised face. Frame: %s. Face: %s", # type: ignore
extract_media.filename, idx)
ret_faces.append(detect_face)
extract_media.add_detected_faces(ret_faces)

File diff suppressed because it is too large Load Diff