mirror of
https://github.com/zebrajr/faceswap.git
synced 2026-01-15 12:15:15 +00:00
Merge branch 'staging'
This commit is contained in:
@@ -44,6 +44,9 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
|
||||
batch_matrix: bool, Optional
|
||||
``True`` to calculate the spectrum weight matrix using batch-based statistics otherwise
|
||||
``False``. Default: ``False``
|
||||
epsilon : float, Optional
|
||||
Small epsilon for safer weights scaling division. Default: `1e-6`
|
||||
|
||||
|
||||
References
|
||||
----------
|
||||
@@ -56,13 +59,15 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
|
||||
patch_factor: int = 1,
|
||||
ave_spectrum: bool = False,
|
||||
log_matrix: bool = False,
|
||||
batch_matrix: bool = False) -> None:
|
||||
batch_matrix: bool = False,
|
||||
epsilon: float = 1e-6) -> None:
|
||||
self._alpha = alpha
|
||||
# TODO Fix bug where FFT will be incorrect if patch_factor > 1
|
||||
self._patch_factor = patch_factor
|
||||
self._ave_spectrum = ave_spectrum
|
||||
self._log_matrix = log_matrix
|
||||
self._batch_matrix = batch_matrix
|
||||
self._epsilon = epsilon
|
||||
self._dims: tuple[int, int] = (0, 0)
|
||||
|
||||
def _get_patches(self, inputs: tf.Tensor) -> tf.Tensor:
|
||||
@@ -145,11 +150,11 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
|
||||
weights = K.log(weights + 1.0)
|
||||
|
||||
if self._batch_matrix: # calculate the spectrum weight matrix using batch-based statistics
|
||||
weights = weights / K.max(weights)
|
||||
scale = K.max(weights)
|
||||
else:
|
||||
weights = weights / K.max(K.max(weights, axis=-2), axis=-2)[..., None, None, :]
|
||||
scale = K.max(weights, axis=(-2, -3), keepdims=True)
|
||||
weights = weights / K.maximum(scale, self._epsilon)
|
||||
|
||||
weights = K.switch(tf.math.is_nan(weights), K.zeros_like(weights), weights)
|
||||
weights = K.clip(weights, min_value=0.0, max_value=1.0)
|
||||
|
||||
return weights
|
||||
|
||||
Reference in New Issue
Block a user