bugfix: Get correct output size for learn mask

This commit is contained in:
torzdf
2022-08-29 01:04:40 +01:00
parent 05077265d7
commit f3b88d5626
2 changed files with 18 additions and 12 deletions

View File

@@ -69,9 +69,10 @@ class DataGenerator():
self._side = side
self._images = images
self._batch_size = batch_size
self._process_size = max(img[1]
for img in model.model.input_shape + model.model.output_shape)
self._output_sizes = [shape[0] for shape in model.output_shapes[0] if shape[-1] != 1]
self._process_size = max(img[1] for img in model.input_shapes + model.output_shapes)
self._output_sizes = [shape[1] for shape in model.output_shapes if shape[-1] != 1]
self._coverage_ratio = model.coverage_ratio
self._color_order = model.color_order.lower()
self._use_mask = self._config["mask_type"] and (self._config["penalized_mask_loss"] or
@@ -399,7 +400,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
self._no_warp = model.command_line_arguments.no_warp
self._warp_to_landmarks = (not self._no_warp
and model.command_line_arguments.warp_to_landmarks)
self._model_input_size = max(img[1] for img in model.model.input_shape)
self._model_input_size = max(img[1] for img in model.input_shapes)
if self._warp_to_landmarks:
self._face_cache.pre_fill(images, side)

View File

@@ -10,7 +10,7 @@ import sys
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import numpy as np
@@ -204,13 +204,18 @@ class ModelBase():
return self.name
@property
def output_shapes(self) -> List[List[Tuple[int, int, int]]]:
""" list: A list of list of shape tuples for the outputs of the model with the batch
dimension removed. The outer list contains 2 sub-lists (one for each side "a" and "b").
The inner sub-lists contain the output shapes for that side. """
shapes: List[Tuple[int, int, int]] = [tuple(K.int_shape(output)[-3:]) # type: ignore
for output in self.model.outputs]
return [shapes[:len(shapes) // 2], shapes[len(shapes) // 2:]]
def input_shapes(self) -> List[Tuple[None, int, int, int]]:
""" list: A flattened list corresponding to all of the inputs to the model. """
shapes = [cast(Tuple[None, int, int, int], K.int_shape(inputs))
for inputs in self.model.inputs]
return shapes
@property
def output_shapes(self) -> List[Tuple[None, int, int, int]]:
""" list: A flattened list corresponding to all of the outputs of the model. """
shapes = [cast(Tuple[None, int, int, int], K.int_shape(output))
for output in self.model.outputs]
return shapes
@property
def iterations(self) -> int: