mirror of
https://github.com/zebrajr/faceswap.git
synced 2026-01-15 12:15:15 +00:00
Aligner updates
- Add filter re-feeds option - bugfix roll calculation
This commit is contained in:
@@ -86,6 +86,17 @@ class Config(FaceswapConfig):
|
||||
"degrees. Aligned faces should have a roll value close to zero. Values that are a "
|
||||
"significant distance from 0 degrees tend to be misaligned images. These can usually "
|
||||
"be safely disgarded.")
|
||||
self.add_item(
|
||||
section=section,
|
||||
title="filter_refeed",
|
||||
datatype=bool,
|
||||
default=True,
|
||||
group="filters",
|
||||
info="If enabled, and re-feed has been selected for extraction, then interim "
|
||||
"alignments will be filtered prior to averaging the final landmarks. This can "
|
||||
"help improve the final alignments by removing any obvious misaligns from the "
|
||||
"interim results, and may also help pick up difficult alignments. If disabled, "
|
||||
"then all re-feed results will be averaged.")
|
||||
self.add_item(
|
||||
section=section,
|
||||
title="save_filtered",
|
||||
|
||||
@@ -377,10 +377,46 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
||||
raise FaceswapError(msg) from err
|
||||
raise
|
||||
|
||||
def _get_mean_landmarks(self, landmarks: np.ndarray, masks: List[List[bool]]) -> np.ndarray:
|
||||
""" Obtain the averaged landmarks from the re-fed alignments. If config option
|
||||
'filter_refeed' is enabled, then average those results which have not been filtered out
|
||||
otherwise average all results
|
||||
|
||||
Parameters
|
||||
----------
|
||||
landmarks: :class:`numpy.ndarray`
|
||||
The batch of re-fed alignments
|
||||
masks: list
|
||||
List of boolean values indicating whether each re-fed alignments passed or failed
|
||||
the filter test
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class:`numpy.ndarray`
|
||||
The final averaged landmarks
|
||||
"""
|
||||
if not self.config["filter_refeed"]:
|
||||
return landmarks.mean(axis=0).astype("float32")
|
||||
|
||||
mask = np.array(masks)
|
||||
if any(np.all(masked) for masked in mask.T):
|
||||
# hacky fix for faces which entirely failed the filter
|
||||
# We just unmask one value as it is junk anyway and will be discarded on output
|
||||
for idx, masked in enumerate(mask.T):
|
||||
if np.all(masked):
|
||||
mask[0, idx] = False
|
||||
|
||||
mask = np.broadcast_to(np.reshape(mask, (*landmarks.shape[:2], 1, 1)),
|
||||
landmarks.shape)
|
||||
return np.ma.array(landmarks, mask=mask).mean(axis=0).data.astype("float32")
|
||||
|
||||
def _process_output(self, batch: BatchType) -> AlignerBatch:
|
||||
""" Process the output from the aligner model multiple times based on the user selected
|
||||
`re-feed amount` configuration option, then average the results for final prediction.
|
||||
|
||||
If the config option 'filter_refeed' is enabled, then mask out any returned alignments
|
||||
that fail a filter test
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch : :class:`AlignerBatch`
|
||||
@@ -392,7 +428,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
||||
The batch item with :attr:`landmarks` populated
|
||||
"""
|
||||
assert isinstance(batch, AlignerBatch)
|
||||
landmarks = []
|
||||
landmark_list: List[np.ndarray] = []
|
||||
masks: List[List[bool]] = []
|
||||
for idx in range(self._re_feed + 1):
|
||||
# Create a pseudo object that only populates the data, feed and prediction slots with
|
||||
# the current re-feed iteration
|
||||
@@ -403,8 +440,14 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
||||
prediction=batch.prediction[idx],
|
||||
data=[batch.data[idx]])
|
||||
self.process_output(subbatch)
|
||||
landmarks.append(subbatch.landmarks)
|
||||
batch.landmarks = np.average(landmarks, axis=0)
|
||||
landmark_list.append(subbatch.landmarks)
|
||||
|
||||
if self.config["filter_refeed"]:
|
||||
fcs = [DetectedFace(landmarks_xy=lm) for lm in subbatch.landmarks.copy()]
|
||||
min_sizes = [min(img.shape[:2]) for img in batch.image]
|
||||
masks.append(self._filter.filtered_mask(fcs, min_sizes))
|
||||
|
||||
batch.landmarks = self._get_mean_landmarks(np.array(landmark_list), masks)
|
||||
return batch
|
||||
|
||||
# <<< FACE NORMALIZATION METHODS >>> #
|
||||
@@ -574,7 +617,7 @@ class AlignedFilter():
|
||||
sub_folders[idx] = "_align_filt_distance"
|
||||
continue
|
||||
|
||||
if not -self._roll <= aligned.pose.roll <= self._roll:
|
||||
if not 0.0 < abs(aligned.pose.roll) < self._roll:
|
||||
self._counts["roll"] += 1
|
||||
if self._save_output:
|
||||
retval.append(face)
|
||||
@@ -619,7 +662,7 @@ class AlignedFilter():
|
||||
|
||||
return None
|
||||
|
||||
def filtered_mask(self, faces: List[DetectedFace], minimum_dimension: int) -> List[bool]:
|
||||
def filtered_mask(self, faces: List[DetectedFace], minimum_dimension: List[int]) -> List[bool]:
|
||||
""" Obtain a list of boolean values for the given faces indicating whether they pass the
|
||||
filter test.
|
||||
|
||||
@@ -627,8 +670,8 @@ class AlignedFilter():
|
||||
----------
|
||||
faces: list
|
||||
List of detected face objects to test the filters for
|
||||
minimum_dimension: int
|
||||
The minimum (height, width) of the original frame
|
||||
minimum_dimension: list
|
||||
The minimum (height, width) of the original frames that the faces come from
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -637,13 +680,13 @@ class AlignedFilter():
|
||||
test. ``False`` the face passed the test. ``True`` it failed
|
||||
"""
|
||||
retval = [True for _ in range(len(faces))]
|
||||
for idx, face in enumerate(faces):
|
||||
for idx, (face, dim) in enumerate(zip(faces, minimum_dimension)):
|
||||
aligned = AlignedFace(landmarks=face.landmarks_xy)
|
||||
if self._scale_test(aligned, minimum_dimension) is not None:
|
||||
if self._scale_test(aligned, dim) is not None:
|
||||
continue
|
||||
if 0.0 < self._distance < aligned.average_distance:
|
||||
continue
|
||||
if not -self._roll <= aligned.pose.roll <= self._roll:
|
||||
if not 0.0 < abs(aligned.pose.roll) < self._roll:
|
||||
continue
|
||||
retval[idx] = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user