diff --git a/lib/training/generator.py b/lib/training/generator.py index 3ae9e0e..966c886 100644 --- a/lib/training/generator.py +++ b/lib/training/generator.py @@ -69,9 +69,10 @@ class DataGenerator(): self._side = side self._images = images self._batch_size = batch_size - self._process_size = max(img[1] - for img in model.model.input_shape + model.model.output_shape) - self._output_sizes = [shape[0] for shape in model.output_shapes[0] if shape[-1] != 1] + + self._process_size = max(img[1] for img in model.input_shapes + model.output_shapes) + self._output_sizes = [shape[1] for shape in model.output_shapes if shape[-1] != 1] + self._coverage_ratio = model.coverage_ratio self._color_order = model.color_order.lower() self._use_mask = self._config["mask_type"] and (self._config["penalized_mask_loss"] or @@ -399,7 +400,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met self._no_warp = model.command_line_arguments.no_warp self._warp_to_landmarks = (not self._no_warp and model.command_line_arguments.warp_to_landmarks) - self._model_input_size = max(img[1] for img in model.model.input_shape) + self._model_input_size = max(img[1] for img in model.input_shapes) if self._warp_to_landmarks: self._face_cache.pre_fill(images, side) diff --git a/plugins/train/model/_base/model.py b/plugins/train/model/_base/model.py index c7bab46..1f890b8 100644 --- a/plugins/train/model/_base/model.py +++ b/plugins/train/model/_base/model.py @@ -10,7 +10,7 @@ import sys import time from collections import OrderedDict -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import numpy as np @@ -204,13 +204,18 @@ class ModelBase(): return self.name @property - def output_shapes(self) -> List[List[Tuple[int, int, int]]]: - """ list: A list of list of shape tuples for the outputs of the model with the batch - dimension removed. The outer list contains 2 sub-lists (one for each side "a" and "b"). - The inner sub-lists contain the output shapes for that side. """ - shapes: List[Tuple[int, int, int]] = [tuple(K.int_shape(output)[-3:]) # type: ignore - for output in self.model.outputs] - return [shapes[:len(shapes) // 2], shapes[len(shapes) // 2:]] + def input_shapes(self) -> List[Tuple[None, int, int, int]]: + """ list: A flattened list corresponding to all of the inputs to the model. """ + shapes = [cast(Tuple[None, int, int, int], K.int_shape(inputs)) + for inputs in self.model.inputs] + return shapes + + @property + def output_shapes(self) -> List[Tuple[None, int, int, int]]: + """ list: A flattened list corresponding to all of the outputs of the model. """ + shapes = [cast(Tuple[None, int, int, int], K.int_shape(output)) + for output in self.model.outputs] + return shapes @property def iterations(self) -> int: