diff --git a/lib/training/generator.py b/lib/training/generator.py index 966c886..2429f39 100644 --- a/lib/training/generator.py +++ b/lib/training/generator.py @@ -71,7 +71,7 @@ class DataGenerator(): self._batch_size = batch_size self._process_size = max(img[1] for img in model.input_shapes + model.output_shapes) - self._output_sizes = [shape[1] for shape in model.output_shapes if shape[-1] != 1] + self._output_sizes = self._get_output_sizes(model) self._coverage_ratio = model.coverage_ratio self._color_order = model.color_order.lower() @@ -103,6 +103,27 @@ class DataGenerator(): channels += len(mults) return channels + def _get_output_sizes(self, model: "ModelBase") -> List[int]: + """ Obtain the size of each output tensor for the model. + + Parameters + ---------- + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding + + Returns + ------- + list + A list of integers for the model output size for the current side + """ + out_shapes = model.output_shapes + split = len(out_shapes) // 2 + side_out = out_shapes[:split] if self._side == "a" else out_shapes[split:] + retval = [shape[1] for shape in side_out if shape[-1] != 1] + logger.debug("side: %s, model output shapes: %s, output sizes: %s", + self._side, model.output_shapes, retval) + return retval + def minibatch_ab(self, do_shuffle: bool = True) -> Generator[BatchType, None, None]: """ A Background iterator to return augmented images, samples and targets. @@ -313,7 +334,6 @@ class DataGenerator(): batch = self._buffer() self._crop_to_coverage(filenames, raw_faces, detected_faces, batch) self._apply_mask(detected_faces, batch) - return self.process_batch(filenames, raw_faces, detected_faces, batch) def process_batch(self,