Source code for pqagent.pqmodel

import logging
from dataclasses import dataclass, field
from types import ModuleType
from typing import Optional, Sequence, Union

import pandas as pd
import torch
from absl.testing.parameterized import parameters
from matplotlib import pyplot as plt
from ray.air import Result

from dwrappr import save_file, load_file

from .mlp import MLP

logger = logging.getLogger(__name__)

@dataclass
class PqModel():
    """Pair a network with the Ray ``Result`` describing its training run.

    The network is supplied at construction time; the training result is
    attached afterwards with :meth:`set_result`.
    """

    # Network used by predict() and inspected by get_parameters().
    net: MLP = field(init=True)
    # Training result attached through set_result(). Until then this holds an
    # empty dict placeholder (default kept for backward compatibility), which
    # is why result-dependent accessors go through _require_result().
    result: Result = field(default_factory=dict, init=False)

    @classmethod
    def load_model(cls, filepath: str):
        """Load a previously saved model from ``filepath``.

        Args:
            filepath: Location handed unchanged to ``load_file``.

        Returns:
            Whatever ``load_file`` returns; no type check is applied.
        """
        # NOTE(review): if load_file deserializes with pickle, only load files
        # from trusted sources -- confirm against dwrappr.
        return load_file(filepath)

    def _require_result(self):
        """Return the attached result, or raise if none has been set yet."""
        attached = self.result
        if attached is None or isinstance(attached, dict):
            # AttributeError matches what the bare attribute access on the
            # dict placeholder raised before, so existing handlers still work.
            raise AttributeError(
                "No training result attached; call set_result() first."
            )
        return attached

    @property
    def metrics(self):
        """Metrics of the attached training result.

        Raises:
            AttributeError: If no result has been attached yet.
        """
        return self._require_result().metrics

    def set_result(self, ray_result: Result):
        """Attach ``ray_result`` as this model's training result."""
        self.result = ray_result

    def predict(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the network on ``X`` after switching it to eval mode.

        Gradient tracking is not disabled here. Extra positional and keyword
        arguments are accepted but ignored.

        Args:
            X: Input passed directly to the network.

        Returns:
            The network's output for ``X``.
        """
        self.net.eval()
        return self.net(X)

    def save(self, path: str, *args, **kwargs) -> None:
        """Persist this model to ``path`` via ``save_file``.

        Extra positional and keyword arguments are accepted but ignored.
        """
        save_file(self, path)

    def get_training_graph(
        self,
        metrics: Optional[Union[str, Sequence[str]]] = None,
        show_plot: bool = True,
    ) -> ModuleType:
        """Plot the requested training metrics against the ``epoch`` column.

        Lines are drawn on the current pyplot figure.

        Args:
            metrics: A metric column name, or an iterable of names, to plot.
                Required in practice; the ``None`` default only keeps the
                parameter optional in the signature.
            show_plot: When true, call ``plt.show()`` after drawing.

        Returns:
            The ``matplotlib.pyplot`` module, so callers can keep working
            with the current figure (for example to save it).

        Raises:
            TypeError: If ``metrics`` is not provided.
            AttributeError: If no result is attached, or the attached result
                carries no metrics dataframe.
            ValueError: If a requested metric is not a dataframe column.
        """
        if metrics is None:
            raise TypeError(
                "get_training_graph() requires 'metrics': a metric name or "
                "a sequence of metric names."
            )
        # A bare string would otherwise be iterated character by character.
        # Materialize the names because they are traversed twice below.
        names = [metrics] if isinstance(metrics, str) else list(metrics)

        frame = self._require_result().metrics_dataframe
        if frame is None:
            raise AttributeError(
                "The attached training result has no metrics dataframe."
            )

        # Validate everything first so a bad name leaves no partial plot.
        for name in names:
            if name not in frame.columns:
                raise ValueError(f"Entry '{name}' not found in metrics of model.")

        for name in names:
            plt.plot(frame['epoch'], frame[name], label=name)

        plt.xlabel("epoch")
        plt.legend()

        if show_plot:
            plt.show()
        return plt

    def get_parameters(self):
        """Print the weight and bias tensors of every named parameter.

        Parameters whose name contains neither ``'weight'`` nor ``'bias'``
        are skipped. Nothing is returned.
        """
        for name, param in self.net.named_parameters():
            if 'weight' in name:
                print(f"Layer: {name} | Weights: {param.data}")
            elif 'bias' in name:
                print(f"Layer: {name} | Biases: {param.data}")