diff --git a/lib/gui/analysis/stats.py b/lib/gui/analysis/stats.py index b055cfa..44f9651 100644 --- a/lib/gui/analysis/stats.py +++ b/lib/gui/analysis/stats.py @@ -60,10 +60,15 @@ class GlobalSession(): """ str: The full model filename """ return os.path.join(self._model_dir, self._model_name) + @property + def have_session_data(self) -> bool: + """ bool : ``True`` if session data is available otherwise ``False`` """ + return bool(self._state and self._state["sessions"]) + @property def batch_sizes(self) -> dict[int, int]: """ dict: The batch sizes for each session_id for the model. """ - if not self._state: + if not self.have_session_data: return {} return {int(sess_id): sess["batchsize"] for sess_id, sess in self._state.get("sessions", {}).items()} @@ -76,9 +81,9 @@ class GlobalSession(): @property def logging_disabled(self) -> bool: - """ bool: ``True`` if logging is enabled for the currently training session otherwise + """ bool: ``True`` if logging is disabled for the currently training session otherwise ``False``. """ - if not self._state: + if not self.have_session_data: return True max_id = str(max(int(idx) for idx in self._state["sessions"])) return self._state["sessions"][max_id]["no_logs"] @@ -311,6 +316,10 @@ class SessionsSummary(): within the loaded data as well as the totals. """ logger.debug("Compiling sessions summary data") + if not self._session.have_session_data: + logger.debug("Session data doesn't exist. Most likely task has been " + "terminated during compilation, or is from LR finder") + return [] self._get_time_stats() self._get_per_session_stats() if not self._per_session_stats: @@ -365,9 +374,9 @@ class SessionsSummary(): compiled = [] for session_id in self._time_stats: logger.debug("Compiling session ID: %s", session_id) - if not self._state: - logger.debug("Session state dict doesn't exist. Most likely task has been " - "terminated during compilation") + if not self._session.have_session_data: + logger.debug("Session data doesn't exist. Most likely task has been " + "terminated during compilation, or is from LR finder") return compiled.append(self._collate_stats(session_id)) @@ -435,6 +444,8 @@ class SessionsSummary(): iterations for all session ids within the loaded data. """ logger.debug("Compiling Totals") + starttime = 0.0 + endtime = 0.0 elapsed = 0 examples = 0 iterations = 0 @@ -450,13 +461,14 @@ class SessionsSummary(): batchset.add(summary["batch"]) iterations += summary["iterations"] batch = ",".join(str(bs) for bs in batchset) - totals = {"session": "Total", - "start": starttime, - "end": endtime, - "elapsed": elapsed, - "rate": examples / elapsed if elapsed != 0 else 0, - "batch": batch, - "iterations": iterations} + totals: dict[str, str | int | float] = { + "session": "Total", + "start": starttime, + "end": endtime, + "elapsed": elapsed, + "rate": examples / elapsed if elapsed != 0 else 0, + "batch": batch, + "iterations": iterations} logger.debug(totals) return totals @@ -533,7 +545,7 @@ class Calculations(): ``True`` if values significantly away from the average should be excluded, otherwise ``False``. Default: ``False`` """ - def __init__(self, session_id, + def __init__(self, session_id, # pylint:disable=too-many-positional-arguments display: str = "loss", loss_keys: list[str] | str = "loss", selections: list[str] | str = "raw", diff --git a/lib/gui/display_analysis.py b/lib/gui/display_analysis.py index 9dcef89..bf37123 100644 --- a/lib/gui/display_analysis.py +++ b/lib/gui/display_analysis.py @@ -195,12 +195,13 @@ class Analysis(DisplayPage): # pylint:disable=too-many-ancestors else: logger.debug("Retrieving data from thread") result = self._thread.get_result() - if result is None: + del self._thread + self._thread = None + if not result: logger.debug("No result from session summary. Clearing analysis view") self._clear_session() return self._summary = result - self._thread = None self.set_info(f"Session: {message}") self._stats.tree_insert_data(self._summary)