Merge branch 'staging'

This commit is contained in:
torzdf
2025-11-04 12:30:02 +00:00

View File

@@ -44,6 +44,9 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
batch_matrix: bool, Optional
``True`` to calculate the spectrum weight matrix using batch-based statistics otherwise
``False``. Default: ``False``
epsilon : float, Optional
Small epsilon for safer weights scaling division. Default: `1e-6`
References
----------
@@ -56,13 +59,15 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
patch_factor: int = 1,
ave_spectrum: bool = False,
log_matrix: bool = False,
batch_matrix: bool = False) -> None:
batch_matrix: bool = False,
epsilon: float = 1e-6) -> None:
self._alpha = alpha
# TODO Fix bug where FFT will be incorrect if patch_factor > 1
self._patch_factor = patch_factor
self._ave_spectrum = ave_spectrum
self._log_matrix = log_matrix
self._batch_matrix = batch_matrix
self._epsilon = epsilon
self._dims: tuple[int, int] = (0, 0)
def _get_patches(self, inputs: tf.Tensor) -> tf.Tensor:
@@ -145,11 +150,11 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
weights = K.log(weights + 1.0)
if self._batch_matrix: # calculate the spectrum weight matrix using batch-based statistics
weights = weights / K.max(weights)
scale = K.max(weights)
else:
weights = weights / K.max(K.max(weights, axis=-2), axis=-2)[..., None, None, :]
scale = K.max(weights, axis=(-2, -3), keepdims=True)
weights = weights / K.maximum(scale, self._epsilon)
weights = K.switch(tf.math.is_nan(weights), K.zeros_like(weights), weights)
weights = K.clip(weights, min_value=0.0, max_value=1.0)
return weights