mirror of
https://github.com/zebrajr/faceswap.git
synced 2026-01-15 12:15:15 +00:00
typing: lib.gui.analysis.stats
This commit is contained in:
@@ -13,8 +13,7 @@ import warnings
|
||||
|
||||
from math import ceil
|
||||
from threading import Event
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing_extensions import Self
|
||||
from typing import Any, cast, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -33,12 +32,12 @@ class GlobalSession():
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
self._state = None
|
||||
self._state: Dict[str, Any] = {}
|
||||
self._model_dir = ""
|
||||
self._model_name = ""
|
||||
|
||||
self._tb_logs = None
|
||||
self._summary = None
|
||||
self._tb_logs: Optional[TensorBoardLogs] = None
|
||||
self._summary: Optional[SessionsSummary] = None
|
||||
|
||||
self._is_training = False
|
||||
self._is_querying = Event()
|
||||
@@ -62,9 +61,9 @@ class GlobalSession():
|
||||
return os.path.join(self._model_dir, self._model_name)
|
||||
|
||||
@property
|
||||
def batch_sizes(self) -> dict:
|
||||
def batch_sizes(self) -> Dict[int, int]:
|
||||
""" dict: The batch sizes for each session_id for the model. """
|
||||
if self._state is None:
|
||||
if not self._state:
|
||||
return {}
|
||||
return {int(sess_id): sess["batchsize"]
|
||||
for sess_id, sess in self._state.get("sessions", {}).items()}
|
||||
@@ -72,19 +71,21 @@ class GlobalSession():
|
||||
@property
|
||||
def full_summary(self) -> List[dict]:
|
||||
""" list: List of dictionaries containing summary statistics for each session id. """
|
||||
assert self._summary is not None
|
||||
return self._summary.get_summary_stats()
|
||||
|
||||
@property
|
||||
def logging_disabled(self) -> bool:
|
||||
""" bool: ``True`` if logging is enabled for the currently training session otherwise
|
||||
``False``. """
|
||||
if self._state is None:
|
||||
if not self._state:
|
||||
return True
|
||||
return self._state["sessions"][str(self.session_ids[-1])]["no_logs"]
|
||||
|
||||
@property
|
||||
def session_ids(self) -> List[int]:
|
||||
""" list: The sorted list of all existing session ids in the state file """
|
||||
assert self._tb_logs is not None
|
||||
return self._tb_logs.session_ids
|
||||
|
||||
def _load_state_file(self) -> None:
|
||||
@@ -96,8 +97,8 @@ class GlobalSession():
|
||||
logger.debug("Loaded state: %s", self._state)
|
||||
|
||||
def initialize_session(self,
|
||||
model_folder: Optional[str],
|
||||
model_name: Optional[str],
|
||||
model_folder: str,
|
||||
model_name: str,
|
||||
is_training: bool = False) -> None:
|
||||
""" Initialize a Session.
|
||||
|
||||
@@ -106,12 +107,14 @@ class GlobalSession():
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_folder: str, optional
|
||||
model_folder: str,
|
||||
If loading a session manually (e.g. for the analysis tab), then the path to the model
|
||||
folder must be provided. For training sessions, this should be left at ``None``
|
||||
folder must be provided. For training sessions, this should be passed through from the
|
||||
launcher
|
||||
model_name: str, optional
|
||||
If loading a session manually (e.g. for the analysis tab), then the model filename
|
||||
must be provided. For training sessions, this should be left at ``None``
|
||||
must be provided. For training sessions, this should be passed through from the
|
||||
launcher
|
||||
is_training: bool, optional
|
||||
``True`` if the session is being initialized for a training session, otherwise
|
||||
``False``. Default: ``False``
|
||||
@@ -120,6 +123,7 @@ class GlobalSession():
|
||||
|
||||
if self._model_dir == model_folder and self._model_name == model_name:
|
||||
if is_training:
|
||||
assert self._tb_logs is not None
|
||||
self._tb_logs.set_training(is_training)
|
||||
self._load_state_file()
|
||||
self._is_training = True
|
||||
@@ -157,7 +161,7 @@ class GlobalSession():
|
||||
|
||||
self._is_training = False
|
||||
|
||||
def get_loss(self, session_id: Optional[int]) -> dict:
|
||||
def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]:
|
||||
""" Obtain the loss values for the given session_id.
|
||||
|
||||
Parameters
|
||||
@@ -176,13 +180,15 @@ class GlobalSession():
|
||||
if self._is_training:
|
||||
self._is_querying.set()
|
||||
|
||||
assert self._tb_logs is not None
|
||||
loss_dict = self._tb_logs.get_loss(session_id=session_id)
|
||||
if session_id is None:
|
||||
retval = {}
|
||||
all_loss: Dict[str, List[float]] = {}
|
||||
for key in sorted(loss_dict):
|
||||
for loss_key, loss in loss_dict[key].items():
|
||||
retval.setdefault(loss_key, []).extend(loss)
|
||||
retval = {key: np.array(val, dtype="float32") for key, val in retval.items()}
|
||||
all_loss.setdefault(loss_key, []).extend(loss)
|
||||
retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
|
||||
for key, val in all_loss.items()}
|
||||
else:
|
||||
retval = loss_dict.get(session_id, {})
|
||||
|
||||
@@ -190,7 +196,8 @@ class GlobalSession():
|
||||
self._is_querying.clear()
|
||||
return retval
|
||||
|
||||
def get_timestamps(self, session_id: Optional[int]) -> Union[dict, np.ndarray]:
|
||||
def get_timestamps(self, session_id: Optional[int]) -> Union[Dict[int, np.ndarray],
|
||||
np.ndarray]:
|
||||
""" Obtain the time stamps keys for the given session_id.
|
||||
|
||||
Parameters
|
||||
@@ -211,6 +218,7 @@ class GlobalSession():
|
||||
if self._is_training:
|
||||
self._is_querying.set()
|
||||
|
||||
assert self._tb_logs is not None
|
||||
retval = self._tb_logs.get_timestamps(session_id=session_id)
|
||||
if session_id is not None:
|
||||
retval = retval[session_id]
|
||||
@@ -249,16 +257,17 @@ class GlobalSession():
|
||||
loss_keys = {int(sess_id): [name for name in session["loss_names"] if name != "total"]
|
||||
for sess_id, session in self._state["sessions"].items()}
|
||||
else:
|
||||
assert self._tb_logs is not None
|
||||
loss_keys = {sess_id: list(logs.keys())
|
||||
for sess_id, logs
|
||||
in self._tb_logs.get_loss(session_id=session_id).items()}
|
||||
|
||||
if session_id is None:
|
||||
retval = list(set(loss_key
|
||||
for session in loss_keys.values()
|
||||
for loss_key in session))
|
||||
retval: List[str] = list(set(loss_key
|
||||
for session in loss_keys.values()
|
||||
for loss_key in session))
|
||||
else:
|
||||
retval = loss_keys.get(session_id)
|
||||
retval = loss_keys.get(session_id, [])
|
||||
return retval
|
||||
|
||||
|
||||
@@ -279,8 +288,8 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
||||
self._session = session
|
||||
self._state = session._state
|
||||
|
||||
self._time_stats = None
|
||||
self._per_session_stats = None
|
||||
self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {}
|
||||
self._per_session_stats: List[Dict[str, Any]] = []
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def get_summary_stats(self) -> List[dict]:
|
||||
@@ -315,20 +324,21 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
||||
If the main Session is currently training, then the training session ID is updated with the
|
||||
latest stats.
|
||||
"""
|
||||
if self._time_stats is None:
|
||||
if not self._time_stats:
|
||||
logger.debug("Collating summary time stamps")
|
||||
|
||||
self._time_stats = {
|
||||
sess_id: dict(start_time=np.min(timestamps) if np.any(timestamps) else 0,
|
||||
end_time=np.max(timestamps) if np.any(timestamps) else 0,
|
||||
iterations=timestamps.shape[0] if np.any(timestamps) else 0)
|
||||
for sess_id, timestamps in self._session.get_timestamps(None).items()}
|
||||
for sess_id, timestamps in cast(Dict[int, np.ndarray],
|
||||
self._session.get_timestamps(None)).items()}
|
||||
|
||||
elif _SESSION.is_training:
|
||||
logger.debug("Updating summary time stamps for training session")
|
||||
|
||||
session_id = _SESSION.session_ids[-1]
|
||||
latest = self._session.get_timestamps(session_id)
|
||||
latest = cast(np.ndarray, self._session.get_timestamps(session_id))
|
||||
|
||||
self._time_stats[session_id] = dict(
|
||||
start_time=np.min(latest) if np.any(latest) else 0,
|
||||
@@ -344,12 +354,12 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
||||
|
||||
If a training session is running, then updates the training sessions stats only.
|
||||
"""
|
||||
if self._per_session_stats is None:
|
||||
if not self._per_session_stats:
|
||||
logger.debug("Collating per session stats")
|
||||
compiled = []
|
||||
for session_id in self._time_stats:
|
||||
logger.debug("Compiling session ID: %s", session_id)
|
||||
if self._state is None:
|
||||
if not self._state:
|
||||
logger.debug("Session state dict doesn't exist. Most likely task has been "
|
||||
"terminated during compilation")
|
||||
return
|
||||
@@ -377,7 +387,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
||||
/ stats["elapsed"] if stats["elapsed"] > 0 else 0)
|
||||
logger.debug("per_session_stats: %s", self._per_session_stats)
|
||||
|
||||
def _collate_stats(self, session_id: int) -> dict:
|
||||
def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]:
|
||||
""" Collate the session summary statistics for the given session ID.
|
||||
|
||||
Parameters
|
||||
@@ -406,14 +416,14 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
||||
logger.debug(retval)
|
||||
return retval
|
||||
|
||||
def _total_stats(self) -> dict:
|
||||
def _total_stats(self) -> Dict[str, Union[str, int, float]]:
|
||||
""" Compile the Totals stats.
|
||||
Totals are fully calculated each time as they will change on the basis of the training
|
||||
session.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
dict
|
||||
The Session name, start time, end time, elapsed time, rate, batch size and number of
|
||||
iterations for all session ids within the loaded data.
|
||||
"""
|
||||
@@ -486,8 +496,8 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
||||
tuple
|
||||
(`hours`, `minutes`, `seconds`) as strings
|
||||
"""
|
||||
hrs = int(timestamp // 3600)
|
||||
hrs = f"{hrs:02d}" if hrs < 10 else str(hrs)
|
||||
ihrs = int(timestamp // 3600)
|
||||
hrs = f"{ihrs:02d}" if ihrs < 10 else str(ihrs)
|
||||
mins = f"{(int(timestamp % 3600) // 60):02d}"
|
||||
secs = f"{(int(timestamp % 3600) % 60):02d}"
|
||||
return hrs, mins, secs
|
||||
@@ -536,13 +546,13 @@ class Calculations():
|
||||
self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys]
|
||||
self._selections = selections if isinstance(selections, list) else [selections]
|
||||
self._is_totals = session_id is None
|
||||
self._args = dict(avg_samples=avg_samples,
|
||||
smooth_amount=smooth_amount,
|
||||
flatten_outliers=flatten_outliers)
|
||||
self._args: Dict[str, Union[int, float]] = dict(avg_samples=avg_samples,
|
||||
smooth_amount=smooth_amount,
|
||||
flatten_outliers=flatten_outliers)
|
||||
self._iterations = 0
|
||||
self._limit = 0
|
||||
self._start_iteration = 0
|
||||
self._stats = {}
|
||||
self._stats: Dict[str, np.ndarray] = {}
|
||||
self.refresh()
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@@ -557,11 +567,11 @@ class Calculations():
|
||||
return self._start_iteration
|
||||
|
||||
@property
|
||||
def stats(self) -> dict:
|
||||
def stats(self) -> Dict[str, np.ndarray]:
|
||||
""" dict: The final calculated statistics """
|
||||
return self._stats
|
||||
|
||||
def refresh(self) -> Optional[Self]:
|
||||
def refresh(self) -> Optional["Calculations"]:
|
||||
""" Refresh the stats """
|
||||
logger.debug("Refreshing")
|
||||
if not _SESSION.is_loaded:
|
||||
@@ -658,11 +668,11 @@ class Calculations():
|
||||
if len(iterations) > 1:
|
||||
# Crop all losses to the same number of items
|
||||
if self._iterations == 0:
|
||||
self.stats = {lossname: np.array([], dtype=loss.dtype)
|
||||
for lossname, loss in self.stats.items()}
|
||||
self._stats = {lossname: np.array([], dtype=loss.dtype)
|
||||
for lossname, loss in self.stats.items()}
|
||||
else:
|
||||
self.stats = {lossname: loss[:self._iterations]
|
||||
for lossname, loss in self.stats.items()}
|
||||
self._stats = {lossname: loss[:self._iterations]
|
||||
for lossname, loss in self.stats.items()}
|
||||
|
||||
else: # Rate calculation
|
||||
data = self._calc_rate_total() if self._is_totals else self._calc_rate()
|
||||
@@ -719,8 +729,8 @@ class Calculations():
|
||||
The training rate for each iteration of the selected session
|
||||
"""
|
||||
logger.debug("Calculating rate")
|
||||
retval = (_SESSION.batch_sizes[self._session_id] * 2) / np.diff(_SESSION.get_timestamps(
|
||||
self._session_id))
|
||||
batch_size = _SESSION.batch_sizes[self._session_id] * 2
|
||||
retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id)))
|
||||
logger.debug("Calculated rate: Item_count: %s", len(retval))
|
||||
return retval
|
||||
|
||||
@@ -740,8 +750,8 @@ class Calculations():
|
||||
"""
|
||||
logger.debug("Calculating totals rate")
|
||||
batchsizes = _SESSION.batch_sizes
|
||||
total_timestamps = _SESSION.get_timestamps(None)
|
||||
rate = []
|
||||
total_timestamps = cast(Dict[int, np.ndarray], _SESSION.get_timestamps(None))
|
||||
rate: List[float] = []
|
||||
for sess_id in sorted(total_timestamps.keys()):
|
||||
batchsize = batchsizes[sess_id]
|
||||
timestamps = total_timestamps[sess_id]
|
||||
@@ -781,7 +791,7 @@ class Calculations():
|
||||
The moving average for the given data
|
||||
"""
|
||||
logger.debug("Calculating Average. Data points: %s", len(data))
|
||||
window = self._args["avg_samples"]
|
||||
window = cast(int, self._args["avg_samples"])
|
||||
pad = ceil(window / 2)
|
||||
datapoints = data.shape[0]
|
||||
|
||||
@@ -968,8 +978,8 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
|
||||
out /= scaling_factors[-2::-1] # cumulative sums / scaling
|
||||
|
||||
if offset != 0:
|
||||
offset = np.array(offset, copy=False).astype(self._dtype, copy=False)
|
||||
out += offset * scaling_factors[1:]
|
||||
noffset = np.array(offset, copy=False).astype(self._dtype, copy=False)
|
||||
out += noffset * scaling_factors[1:]
|
||||
|
||||
def _ewma_vectorized_2d(self, data: np.ndarray, out: np.ndarray) -> None:
|
||||
""" Calculates the exponential moving average over the last axis.
|
||||
|
||||
@@ -8,6 +8,7 @@ import tkinter as tk
|
||||
from tkinter import colorchooser, ttk
|
||||
from itertools import zip_longest
|
||||
from functools import partial
|
||||
from typing import Any, Dict
|
||||
|
||||
from _tkinter import Tcl_Obj, TclError
|
||||
|
||||
@@ -23,7 +24,7 @@ _ = _LANG.gettext
|
||||
# We store Tooltips, ContextMenus and Commands globally when they are created
|
||||
# Because we need to add them back to newly cloned widgets (they are not easily accessible from
|
||||
# original config or are prone to getting destroyed when the original widget is destroyed)
|
||||
_RECREATE_OBJECTS = dict(tooltips={}, commands={}, contextmenus={})
|
||||
_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={})
|
||||
|
||||
|
||||
def _get_tooltip(widget, text=None, text_variable=None):
|
||||
|
||||
Reference in New Issue
Block a user