From 47867a0dd424b3e31d7beead0ffdb8b37c970a9e Mon Sep 17 00:00:00 2001 From: torzdf <36920800+torzdf@users.noreply.github.com> Date: Thu, 13 Oct 2022 11:46:54 +0100 Subject: [PATCH] typing: lib.gui.analysis.stats --- lib/gui/analysis/stats.py | 112 +++++++++++++++++++++----------------- lib/gui/control_helper.py | 3 +- 2 files changed, 63 insertions(+), 52 deletions(-) diff --git a/lib/gui/analysis/stats.py b/lib/gui/analysis/stats.py index b9c8612..a92a4b0 100644 --- a/lib/gui/analysis/stats.py +++ b/lib/gui/analysis/stats.py @@ -13,8 +13,7 @@ import warnings from math import ceil from threading import Event -from typing import List, Optional, Tuple, Union -from typing_extensions import Self +from typing import Any, cast, Dict, List, Optional, Tuple, Union import numpy as np @@ -33,12 +32,12 @@ class GlobalSession(): """ def __init__(self) -> None: logger.debug("Initializing %s", self.__class__.__name__) - self._state = None + self._state: Dict[str, Any] = {} self._model_dir = "" self._model_name = "" - self._tb_logs = None - self._summary = None + self._tb_logs: Optional[TensorBoardLogs] = None + self._summary: Optional[SessionsSummary] = None self._is_training = False self._is_querying = Event() @@ -62,9 +61,9 @@ class GlobalSession(): return os.path.join(self._model_dir, self._model_name) @property - def batch_sizes(self) -> dict: + def batch_sizes(self) -> Dict[int, int]: """ dict: The batch sizes for each session_id for the model. """ - if self._state is None: + if not self._state: return {} return {int(sess_id): sess["batchsize"] for sess_id, sess in self._state.get("sessions", {}).items()} @@ -72,19 +71,21 @@ class GlobalSession(): @property def full_summary(self) -> List[dict]: """ list: List of dictionaries containing summary statistics for each session id. """ + assert self._summary is not None return self._summary.get_summary_stats() @property def logging_disabled(self) -> bool: """ bool: ``True`` if logging is enabled for the currently training session otherwise ``False``. """ - if self._state is None: + if not self._state: return True return self._state["sessions"][str(self.session_ids[-1])]["no_logs"] @property def session_ids(self) -> List[int]: """ list: The sorted list of all existing session ids in the state file """ + assert self._tb_logs is not None return self._tb_logs.session_ids def _load_state_file(self) -> None: @@ -96,8 +97,8 @@ class GlobalSession(): logger.debug("Loaded state: %s", self._state) def initialize_session(self, - model_folder: Optional[str], - model_name: Optional[str], + model_folder: str, + model_name: str, is_training: bool = False) -> None: """ Initialize a Session. @@ -106,12 +107,14 @@ class GlobalSession(): Parameters ---------- - model_folder: str, optional + model_folder: str, If loading a session manually (e.g. for the analysis tab), then the path to the model - folder must be provided. For training sessions, this should be left at ``None`` + folder must be provided. For training sessions, this should be passed through from the + launcher model_name: str, optional If loading a session manually (e.g. for the analysis tab), then the model filename - must be provided. For training sessions, this should be left at ``None`` + must be provided. For training sessions, this should be passed through from the + launcher is_training: bool, optional ``True`` if the session is being initialized for a training session, otherwise ``False``. Default: ``False`` @@ -120,6 +123,7 @@ class GlobalSession(): if self._model_dir == model_folder and self._model_name == model_name: if is_training: + assert self._tb_logs is not None self._tb_logs.set_training(is_training) self._load_state_file() self._is_training = True @@ -157,7 +161,7 @@ class GlobalSession(): self._is_training = False - def get_loss(self, session_id: Optional[int]) -> dict: + def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]: """ Obtain the loss values for the given session_id. Parameters @@ -176,13 +180,15 @@ class GlobalSession(): if self._is_training: self._is_querying.set() + assert self._tb_logs is not None loss_dict = self._tb_logs.get_loss(session_id=session_id) if session_id is None: - retval = {} + all_loss: Dict[str, List[float]] = {} for key in sorted(loss_dict): for loss_key, loss in loss_dict[key].items(): - retval.setdefault(loss_key, []).extend(loss) - retval = {key: np.array(val, dtype="float32") for key, val in retval.items()} + all_loss.setdefault(loss_key, []).extend(loss) + retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32") + for key, val in all_loss.items()} else: retval = loss_dict.get(session_id, {}) @@ -190,7 +196,8 @@ class GlobalSession(): self._is_querying.clear() return retval - def get_timestamps(self, session_id: Optional[int]) -> Union[dict, np.ndarray]: + def get_timestamps(self, session_id: Optional[int]) -> Union[Dict[int, np.ndarray], + np.ndarray]: """ Obtain the time stamps keys for the given session_id. Parameters @@ -211,6 +218,7 @@ class GlobalSession(): if self._is_training: self._is_querying.set() + assert self._tb_logs is not None retval = self._tb_logs.get_timestamps(session_id=session_id) if session_id is not None: retval = retval[session_id] @@ -249,16 +257,17 @@ class GlobalSession(): loss_keys = {int(sess_id): [name for name in session["loss_names"] if name != "total"] for sess_id, session in self._state["sessions"].items()} else: + assert self._tb_logs is not None loss_keys = {sess_id: list(logs.keys()) for sess_id, logs in self._tb_logs.get_loss(session_id=session_id).items()} if session_id is None: - retval = list(set(loss_key - for session in loss_keys.values() - for loss_key in session)) + retval: List[str] = list(set(loss_key + for session in loss_keys.values() + for loss_key in session)) else: - retval = loss_keys.get(session_id) + retval = loss_keys.get(session_id, []) return retval @@ -279,8 +288,8 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods self._session = session self._state = session._state - self._time_stats = None - self._per_session_stats = None + self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {} + self._per_session_stats: List[Dict[str, Any]] = [] logger.debug("Initialized %s", self.__class__.__name__) def get_summary_stats(self) -> List[dict]: @@ -315,20 +324,21 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods If the main Session is currently training, then the training session ID is updated with the latest stats. """ - if self._time_stats is None: + if not self._time_stats: logger.debug("Collating summary time stamps") self._time_stats = { sess_id: dict(start_time=np.min(timestamps) if np.any(timestamps) else 0, end_time=np.max(timestamps) if np.any(timestamps) else 0, iterations=timestamps.shape[0] if np.any(timestamps) else 0) - for sess_id, timestamps in self._session.get_timestamps(None).items()} + for sess_id, timestamps in cast(Dict[int, np.ndarray], + self._session.get_timestamps(None)).items()} elif _SESSION.is_training: logger.debug("Updating summary time stamps for training session") session_id = _SESSION.session_ids[-1] - latest = self._session.get_timestamps(session_id) + latest = cast(np.ndarray, self._session.get_timestamps(session_id)) self._time_stats[session_id] = dict( start_time=np.min(latest) if np.any(latest) else 0, @@ -344,12 +354,12 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods If a training session is running, then updates the training sessions stats only. """ - if self._per_session_stats is None: + if not self._per_session_stats: logger.debug("Collating per session stats") compiled = [] for session_id in self._time_stats: logger.debug("Compiling session ID: %s", session_id) - if self._state is None: + if not self._state: logger.debug("Session state dict doesn't exist. Most likely task has been " "terminated during compilation") return @@ -377,7 +387,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods / stats["elapsed"] if stats["elapsed"] > 0 else 0) logger.debug("per_session_stats: %s", self._per_session_stats) - def _collate_stats(self, session_id: int) -> dict: + def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]: """ Collate the session summary statistics for the given session ID. Parameters @@ -406,14 +416,14 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods logger.debug(retval) return retval - def _total_stats(self) -> dict: + def _total_stats(self) -> Dict[str, Union[str, int, float]]: """ Compile the Totals stats. Totals are fully calculated each time as they will change on the basis of the training session. Returns ------- - dict: + dict The Session name, start time, end time, elapsed time, rate, batch size and number of iterations for all session ids within the loaded data. """ @@ -486,8 +496,8 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods tuple (`hours`, `minutes`, `seconds`) as strings """ - hrs = int(timestamp // 3600) - hrs = f"{hrs:02d}" if hrs < 10 else str(hrs) + ihrs = int(timestamp // 3600) + hrs = f"{ihrs:02d}" if ihrs < 10 else str(ihrs) mins = f"{(int(timestamp % 3600) // 60):02d}" secs = f"{(int(timestamp % 3600) % 60):02d}" return hrs, mins, secs @@ -536,13 +546,13 @@ class Calculations(): self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys] self._selections = selections if isinstance(selections, list) else [selections] self._is_totals = session_id is None - self._args = dict(avg_samples=avg_samples, - smooth_amount=smooth_amount, - flatten_outliers=flatten_outliers) + self._args: Dict[str, Union[int, float]] = dict(avg_samples=avg_samples, + smooth_amount=smooth_amount, + flatten_outliers=flatten_outliers) self._iterations = 0 self._limit = 0 self._start_iteration = 0 - self._stats = {} + self._stats: Dict[str, np.ndarray] = {} self.refresh() logger.debug("Initialized %s", self.__class__.__name__) @@ -557,11 +567,11 @@ class Calculations(): return self._start_iteration @property - def stats(self) -> dict: + def stats(self) -> Dict[str, np.ndarray]: """ dict: The final calculated statistics """ return self._stats - def refresh(self) -> Optional[Self]: + def refresh(self) -> Optional["Calculations"]: """ Refresh the stats """ logger.debug("Refreshing") if not _SESSION.is_loaded: @@ -658,11 +668,11 @@ class Calculations(): if len(iterations) > 1: # Crop all losses to the same number of items if self._iterations == 0: - self.stats = {lossname: np.array([], dtype=loss.dtype) - for lossname, loss in self.stats.items()} + self._stats = {lossname: np.array([], dtype=loss.dtype) + for lossname, loss in self.stats.items()} else: - self.stats = {lossname: loss[:self._iterations] - for lossname, loss in self.stats.items()} + self._stats = {lossname: loss[:self._iterations] + for lossname, loss in self.stats.items()} else: # Rate calculation data = self._calc_rate_total() if self._is_totals else self._calc_rate() @@ -719,8 +729,8 @@ class Calculations(): The training rate for each iteration of the selected session """ logger.debug("Calculating rate") - retval = (_SESSION.batch_sizes[self._session_id] * 2) / np.diff(_SESSION.get_timestamps( - self._session_id)) + batch_size = _SESSION.batch_sizes[self._session_id] * 2 + retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id))) logger.debug("Calculated rate: Item_count: %s", len(retval)) return retval @@ -740,8 +750,8 @@ class Calculations(): """ logger.debug("Calculating totals rate") batchsizes = _SESSION.batch_sizes - total_timestamps = _SESSION.get_timestamps(None) - rate = [] + total_timestamps = cast(Dict[int, np.ndarray], _SESSION.get_timestamps(None)) + rate: List[float] = [] for sess_id in sorted(total_timestamps.keys()): batchsize = batchsizes[sess_id] timestamps = total_timestamps[sess_id] @@ -781,7 +791,7 @@ class Calculations(): The moving average for the given data """ logger.debug("Calculating Average. Data points: %s", len(data)) - window = self._args["avg_samples"] + window = cast(int, self._args["avg_samples"]) pad = ceil(window / 2) datapoints = data.shape[0] @@ -968,8 +978,8 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods out /= scaling_factors[-2::-1] # cumulative sums / scaling if offset != 0: - offset = np.array(offset, copy=False).astype(self._dtype, copy=False) - out += offset * scaling_factors[1:] + noffset = np.array(offset, copy=False).astype(self._dtype, copy=False) + out += noffset * scaling_factors[1:] def _ewma_vectorized_2d(self, data: np.ndarray, out: np.ndarray) -> None: """ Calculates the exponential moving average over the last axis. diff --git a/lib/gui/control_helper.py b/lib/gui/control_helper.py index 3b471c2..a3dabc2 100644 --- a/lib/gui/control_helper.py +++ b/lib/gui/control_helper.py @@ -8,6 +8,7 @@ import tkinter as tk from tkinter import colorchooser, ttk from itertools import zip_longest from functools import partial +from typing import Any, Dict from _tkinter import Tcl_Obj, TclError @@ -23,7 +24,7 @@ _ = _LANG.gettext # We store Tooltips, ContextMenus and Commands globally when they are created # Because we need to add them back to newly cloned widgets (they are not easily accessible from # original config or are prone to getting destroyed when the original widget is destroyed) -_RECREATE_OBJECTS = dict(tooltips={}, commands={}, contextmenus={}) +_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={}) def _get_tooltip(widget, text=None, text_variable=None):