Aligner updates

- Add filter re-feeds option
  - bugfix roll calculation
This commit is contained in:
torzdf
2022-11-08 12:10:56 +00:00
parent 0f5d2e887c
commit cb8ec69789
2 changed files with 64 additions and 10 deletions

View File

@@ -86,6 +86,17 @@ class Config(FaceswapConfig):
"degrees. Aligned faces should have a roll value close to zero. Values that are a "
"significant distance from 0 degrees tend to be misaligned images. These can usually "
"be safely disgarded.")
self.add_item(
section=section,
title="filter_refeed",
datatype=bool,
default=True,
group="filters",
info="If enabled, and re-feed has been selected for extraction, then interim "
"alignments will be filtered prior to averaging the final landmarks. This can "
"help improve the final alignments by removing any obvious misaligns from the "
"interim results, and may also help pick up difficult alignments. If disabled, "
"then all re-feed results will be averaged.")
self.add_item(
section=section,
title="save_filtered",

View File

@@ -377,10 +377,46 @@ class Aligner(Extractor): # pylint:disable=abstract-method
raise FaceswapError(msg) from err
raise
def _get_mean_landmarks(self, landmarks: np.ndarray, masks: List[List[bool]]) -> np.ndarray:
""" Obtain the averaged landmarks from the re-fed alignments. If config option
'filter_refeed' is enabled, then average those results which have not been filtered out
otherwise average all results
Parameters
----------
landmarks: :class:`numpy.ndarray`
The batch of re-fed alignments
masks: list
List of boolean values indicating whether each re-fed alignments passed or failed
the filter test
Returns
-------
:class:`numpy.ndarray`
The final averaged landmarks
"""
if not self.config["filter_refeed"]:
return landmarks.mean(axis=0).astype("float32")
mask = np.array(masks)
if any(np.all(masked) for masked in mask.T):
# hacky fix for faces which entirely failed the filter
# We just unmask one value as it is junk anyway and will be discarded on output
for idx, masked in enumerate(mask.T):
if np.all(masked):
mask[0, idx] = False
mask = np.broadcast_to(np.reshape(mask, (*landmarks.shape[:2], 1, 1)),
landmarks.shape)
return np.ma.array(landmarks, mask=mask).mean(axis=0).data.astype("float32")
def _process_output(self, batch: BatchType) -> AlignerBatch:
""" Process the output from the aligner model multiple times based on the user selected
`re-feed amount` configuration option, then average the results for final prediction.
If the config option 'filter_refeed' is enabled, then mask out any returned alignments
that fail a filter test
Parameters
----------
batch : :class:`AlignerBatch`
@@ -392,7 +428,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
The batch item with :attr:`landmarks` populated
"""
assert isinstance(batch, AlignerBatch)
landmarks = []
landmark_list: List[np.ndarray] = []
masks: List[List[bool]] = []
for idx in range(self._re_feed + 1):
# Create a pseudo object that only populates the data, feed and prediction slots with
# the current re-feed iteration
@@ -403,8 +440,14 @@ class Aligner(Extractor): # pylint:disable=abstract-method
prediction=batch.prediction[idx],
data=[batch.data[idx]])
self.process_output(subbatch)
landmarks.append(subbatch.landmarks)
batch.landmarks = np.average(landmarks, axis=0)
landmark_list.append(subbatch.landmarks)
if self.config["filter_refeed"]:
fcs = [DetectedFace(landmarks_xy=lm) for lm in subbatch.landmarks.copy()]
min_sizes = [min(img.shape[:2]) for img in batch.image]
masks.append(self._filter.filtered_mask(fcs, min_sizes))
batch.landmarks = self._get_mean_landmarks(np.array(landmark_list), masks)
return batch
# <<< FACE NORMALIZATION METHODS >>> #
@@ -574,7 +617,7 @@ class AlignedFilter():
sub_folders[idx] = "_align_filt_distance"
continue
if not -self._roll <= aligned.pose.roll <= self._roll:
if not 0.0 < abs(aligned.pose.roll) < self._roll:
self._counts["roll"] += 1
if self._save_output:
retval.append(face)
@@ -619,7 +662,7 @@ class AlignedFilter():
return None
def filtered_mask(self, faces: List[DetectedFace], minimum_dimension: int) -> List[bool]:
def filtered_mask(self, faces: List[DetectedFace], minimum_dimension: List[int]) -> List[bool]:
""" Obtain a list of boolean values for the given faces indicating whether they pass the
filter test.
@@ -627,8 +670,8 @@ class AlignedFilter():
----------
faces: list
List of detected face objects to test the filters for
minimum_dimension: int
The minimum (height, width) of the original frame
minimum_dimension: list
The minimum (height, width) of the original frames that the faces come from
Returns
-------
@@ -637,13 +680,13 @@ class AlignedFilter():
test. ``False`` the face passed the test. ``True`` it failed
"""
retval = [True for _ in range(len(faces))]
for idx, face in enumerate(faces):
for idx, (face, dim) in enumerate(zip(faces, minimum_dimension)):
aligned = AlignedFace(landmarks=face.landmarks_xy)
if self._scale_test(aligned, minimum_dimension) is not None:
if self._scale_test(aligned, dim) is not None:
continue
if 0.0 < self._distance < aligned.average_distance:
continue
if not -self._roll <= aligned.pose.roll <= self._roll:
if not 0.0 < abs(aligned.pose.roll) < self._roll:
continue
retval[idx] = False