typing: lib.gui.analysis.stats

This commit is contained in:
torzdf
2022-10-13 11:46:54 +01:00
parent 8910ae505b
commit 47867a0dd4
2 changed files with 63 additions and 52 deletions

View File

@@ -13,8 +13,7 @@ import warnings
from math import ceil
from threading import Event
from typing import List, Optional, Tuple, Union
from typing_extensions import Self
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -33,12 +32,12 @@ class GlobalSession():
"""
def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._state = None
self._state: Dict[str, Any] = {}
self._model_dir = ""
self._model_name = ""
self._tb_logs = None
self._summary = None
self._tb_logs: Optional[TensorBoardLogs] = None
self._summary: Optional[SessionsSummary] = None
self._is_training = False
self._is_querying = Event()
@@ -62,9 +61,9 @@ class GlobalSession():
return os.path.join(self._model_dir, self._model_name)
@property
def batch_sizes(self) -> dict:
def batch_sizes(self) -> Dict[int, int]:
""" dict: The batch sizes for each session_id for the model. """
if self._state is None:
if not self._state:
return {}
return {int(sess_id): sess["batchsize"]
for sess_id, sess in self._state.get("sessions", {}).items()}
@@ -72,19 +71,21 @@ class GlobalSession():
@property
def full_summary(self) -> List[dict]:
""" list: List of dictionaries containing summary statistics for each session id. """
assert self._summary is not None
return self._summary.get_summary_stats()
@property
def logging_disabled(self) -> bool:
""" bool: ``True`` if logging is disabled for the currently training session otherwise
``False``. """
if self._state is None:
if not self._state:
return True
return self._state["sessions"][str(self.session_ids[-1])]["no_logs"]
@property
def session_ids(self) -> List[int]:
""" list: The sorted list of all existing session ids in the state file """
assert self._tb_logs is not None
return self._tb_logs.session_ids
def _load_state_file(self) -> None:
@@ -96,8 +97,8 @@ class GlobalSession():
logger.debug("Loaded state: %s", self._state)
def initialize_session(self,
model_folder: Optional[str],
model_name: Optional[str],
model_folder: str,
model_name: str,
is_training: bool = False) -> None:
""" Initialize a Session.
@@ -106,12 +107,14 @@ class GlobalSession():
Parameters
----------
model_folder: str, optional
model_folder: str
If loading a session manually (e.g. for the analysis tab), then the path to the model
folder must be provided. For training sessions, this should be left at ``None``
folder must be provided. For training sessions, this should be passed through from the
launcher
model_name: str
If loading a session manually (e.g. for the analysis tab), then the model filename
must be provided. For training sessions, this should be left at ``None``
must be provided. For training sessions, this should be passed through from the
launcher
is_training: bool, optional
``True`` if the session is being initialized for a training session, otherwise
``False``. Default: ``False``
@@ -120,6 +123,7 @@ class GlobalSession():
if self._model_dir == model_folder and self._model_name == model_name:
if is_training:
assert self._tb_logs is not None
self._tb_logs.set_training(is_training)
self._load_state_file()
self._is_training = True
@@ -157,7 +161,7 @@ class GlobalSession():
self._is_training = False
def get_loss(self, session_id: Optional[int]) -> dict:
def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]:
""" Obtain the loss values for the given session_id.
Parameters
@@ -176,13 +180,15 @@ class GlobalSession():
if self._is_training:
self._is_querying.set()
assert self._tb_logs is not None
loss_dict = self._tb_logs.get_loss(session_id=session_id)
if session_id is None:
retval = {}
all_loss: Dict[str, List[float]] = {}
for key in sorted(loss_dict):
for loss_key, loss in loss_dict[key].items():
retval.setdefault(loss_key, []).extend(loss)
retval = {key: np.array(val, dtype="float32") for key, val in retval.items()}
all_loss.setdefault(loss_key, []).extend(loss)
retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
for key, val in all_loss.items()}
else:
retval = loss_dict.get(session_id, {})
@@ -190,7 +196,8 @@ class GlobalSession():
self._is_querying.clear()
return retval
def get_timestamps(self, session_id: Optional[int]) -> Union[dict, np.ndarray]:
def get_timestamps(self, session_id: Optional[int]) -> Union[Dict[int, np.ndarray],
np.ndarray]:
""" Obtain the time stamps keys for the given session_id.
Parameters
@@ -211,6 +218,7 @@ class GlobalSession():
if self._is_training:
self._is_querying.set()
assert self._tb_logs is not None
retval = self._tb_logs.get_timestamps(session_id=session_id)
if session_id is not None:
retval = retval[session_id]
@@ -249,16 +257,17 @@ class GlobalSession():
loss_keys = {int(sess_id): [name for name in session["loss_names"] if name != "total"]
for sess_id, session in self._state["sessions"].items()}
else:
assert self._tb_logs is not None
loss_keys = {sess_id: list(logs.keys())
for sess_id, logs
in self._tb_logs.get_loss(session_id=session_id).items()}
if session_id is None:
retval = list(set(loss_key
for session in loss_keys.values()
for loss_key in session))
retval: List[str] = list(set(loss_key
for session in loss_keys.values()
for loss_key in session))
else:
retval = loss_keys.get(session_id)
retval = loss_keys.get(session_id, [])
return retval
@@ -279,8 +288,8 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
self._session = session
self._state = session._state
self._time_stats = None
self._per_session_stats = None
self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {}
self._per_session_stats: List[Dict[str, Any]] = []
logger.debug("Initialized %s", self.__class__.__name__)
def get_summary_stats(self) -> List[dict]:
@@ -315,20 +324,21 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
If the main Session is currently training, then the training session ID is updated with the
latest stats.
"""
if self._time_stats is None:
if not self._time_stats:
logger.debug("Collating summary time stamps")
self._time_stats = {
sess_id: dict(start_time=np.min(timestamps) if np.any(timestamps) else 0,
end_time=np.max(timestamps) if np.any(timestamps) else 0,
iterations=timestamps.shape[0] if np.any(timestamps) else 0)
for sess_id, timestamps in self._session.get_timestamps(None).items()}
for sess_id, timestamps in cast(Dict[int, np.ndarray],
self._session.get_timestamps(None)).items()}
elif _SESSION.is_training:
logger.debug("Updating summary time stamps for training session")
session_id = _SESSION.session_ids[-1]
latest = self._session.get_timestamps(session_id)
latest = cast(np.ndarray, self._session.get_timestamps(session_id))
self._time_stats[session_id] = dict(
start_time=np.min(latest) if np.any(latest) else 0,
@@ -344,12 +354,12 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
If a training session is running, then updates the training sessions stats only.
"""
if self._per_session_stats is None:
if not self._per_session_stats:
logger.debug("Collating per session stats")
compiled = []
for session_id in self._time_stats:
logger.debug("Compiling session ID: %s", session_id)
if self._state is None:
if not self._state:
logger.debug("Session state dict doesn't exist. Most likely task has been "
"terminated during compilation")
return
@@ -377,7 +387,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
/ stats["elapsed"] if stats["elapsed"] > 0 else 0)
logger.debug("per_session_stats: %s", self._per_session_stats)
def _collate_stats(self, session_id: int) -> dict:
def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]:
""" Collate the session summary statistics for the given session ID.
Parameters
@@ -406,14 +416,14 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
logger.debug(retval)
return retval
def _total_stats(self) -> dict:
def _total_stats(self) -> Dict[str, Union[str, int, float]]:
""" Compile the Totals stats.
Totals are fully calculated each time as they will change on the basis of the training
session.
Returns
-------
dict:
dict
The Session name, start time, end time, elapsed time, rate, batch size and number of
iterations for all session ids within the loaded data.
"""
@@ -486,8 +496,8 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
tuple
(`hours`, `minutes`, `seconds`) as strings
"""
hrs = int(timestamp // 3600)
hrs = f"{hrs:02d}" if hrs < 10 else str(hrs)
ihrs = int(timestamp // 3600)
hrs = f"{ihrs:02d}" if ihrs < 10 else str(ihrs)
mins = f"{(int(timestamp % 3600) // 60):02d}"
secs = f"{(int(timestamp % 3600) % 60):02d}"
return hrs, mins, secs
@@ -536,13 +546,13 @@ class Calculations():
self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys]
self._selections = selections if isinstance(selections, list) else [selections]
self._is_totals = session_id is None
self._args = dict(avg_samples=avg_samples,
smooth_amount=smooth_amount,
flatten_outliers=flatten_outliers)
self._args: Dict[str, Union[int, float]] = dict(avg_samples=avg_samples,
smooth_amount=smooth_amount,
flatten_outliers=flatten_outliers)
self._iterations = 0
self._limit = 0
self._start_iteration = 0
self._stats = {}
self._stats: Dict[str, np.ndarray] = {}
self.refresh()
logger.debug("Initialized %s", self.__class__.__name__)
@@ -557,11 +567,11 @@ class Calculations():
return self._start_iteration
@property
def stats(self) -> dict:
def stats(self) -> Dict[str, np.ndarray]:
""" dict: The final calculated statistics """
return self._stats
def refresh(self) -> Optional[Self]:
def refresh(self) -> Optional["Calculations"]:
""" Refresh the stats """
logger.debug("Refreshing")
if not _SESSION.is_loaded:
@@ -658,11 +668,11 @@ class Calculations():
if len(iterations) > 1:
# Crop all losses to the same number of items
if self._iterations == 0:
self.stats = {lossname: np.array([], dtype=loss.dtype)
for lossname, loss in self.stats.items()}
self._stats = {lossname: np.array([], dtype=loss.dtype)
for lossname, loss in self.stats.items()}
else:
self.stats = {lossname: loss[:self._iterations]
for lossname, loss in self.stats.items()}
self._stats = {lossname: loss[:self._iterations]
for lossname, loss in self.stats.items()}
else: # Rate calculation
data = self._calc_rate_total() if self._is_totals else self._calc_rate()
@@ -719,8 +729,8 @@ class Calculations():
The training rate for each iteration of the selected session
"""
logger.debug("Calculating rate")
retval = (_SESSION.batch_sizes[self._session_id] * 2) / np.diff(_SESSION.get_timestamps(
self._session_id))
batch_size = _SESSION.batch_sizes[self._session_id] * 2
retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id)))
logger.debug("Calculated rate: Item_count: %s", len(retval))
return retval
@@ -740,8 +750,8 @@ class Calculations():
"""
logger.debug("Calculating totals rate")
batchsizes = _SESSION.batch_sizes
total_timestamps = _SESSION.get_timestamps(None)
rate = []
total_timestamps = cast(Dict[int, np.ndarray], _SESSION.get_timestamps(None))
rate: List[float] = []
for sess_id in sorted(total_timestamps.keys()):
batchsize = batchsizes[sess_id]
timestamps = total_timestamps[sess_id]
@@ -781,7 +791,7 @@ class Calculations():
The moving average for the given data
"""
logger.debug("Calculating Average. Data points: %s", len(data))
window = self._args["avg_samples"]
window = cast(int, self._args["avg_samples"])
pad = ceil(window / 2)
datapoints = data.shape[0]
@@ -968,8 +978,8 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
out /= scaling_factors[-2::-1] # cumulative sums / scaling
if offset != 0:
offset = np.array(offset, copy=False).astype(self._dtype, copy=False)
out += offset * scaling_factors[1:]
noffset = np.array(offset, copy=False).astype(self._dtype, copy=False)
out += noffset * scaling_factors[1:]
def _ewma_vectorized_2d(self, data: np.ndarray, out: np.ndarray) -> None:
""" Calculates the exponential moving average over the last axis.

View File

@@ -8,6 +8,7 @@ import tkinter as tk
from tkinter import colorchooser, ttk
from itertools import zip_longest
from functools import partial
from typing import Any, Dict
from _tkinter import Tcl_Obj, TclError
@@ -23,7 +24,7 @@ _ = _LANG.gettext
# We store Tooltips, ContextMenus and Commands globally when they are created
# Because we need to add them back to newly cloned widgets (they are not easily accessible from
# original config or are prone to getting destroyed when the original widget is destroyed)
_RECREATE_OBJECTS = dict(tooltips={}, commands={}, contextmenus={})
_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={})
def _get_tooltip(widget, text=None, text_variable=None):