pqagent.mlp module

class pqagent.mlp.MLP(in_features: int, out_features: int, fc_layers: list | str, activation_func: str, layer_normalization: bool = False, dropout: float = 0)[source]

Bases: Module
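
Example (a minimal sketch; the fc_layers value [128, 64] and the activation name "relu" are illustrative assumptions, not values confirmed by this page):

    import torch
    from pqagent.mlp import MLP

    # Hidden-layer widths and activation name below are assumed, not documented here.
    net = MLP(
        in_features=10,
        out_features=1,
        fc_layers=[128, 64],
        activation_func="relu",
        layer_normalization=True,
        dropout=0.1,
    )

    x = torch.randn(32, 10)   # batch of 32 samples with 10 features
    y = net(x)                # calls forward(); expected shape (32, 1)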

forward(x: Tensor) → Tensor[source]

Defines the computation performed at every call.

Args:

x (torch.Tensor): Input tensor.

Returns:

torch.Tensor: Output tensor.

freeze_group(group_name)[source]

Freeze the parameters of a specified group.

classmethod from_dict(param_space: dict, sweep_config: dict, in_features: int, out_features: int)[source]
get_num_of_layer_groups()[source]

Return the number of layer groups in the model.

reset_all_weights()[source]

Reset the weights of all layers in the model.

reset_layer_group_weights(group: str)[source]

Reset the weights of a specific layer group.

reset_layer_weights(m)[source]

Reset the weights of a given layer.
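
Reset sketch (the group identifier "0" is an assumption about how layer groups are labelled in this model):

    net.reset_all_weights()                     # re-initialize every layer
    n_groups = net.get_num_of_layer_groups()    # how many layer groups exist
    net.reset_layer_group_weights("0")          # re-initialize one group only (assumed name)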

static train_fn(param_space: dict, sweep_config: dict, data: TrainingDataDict)[source]
unfreeze_all()[source]

Unfreeze all layers in the model.

unfreeze_group(group_name)[source]

Unfreeze the parameters of a specified group.
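
Freezing sketch (the group name "0" is assumed; the module itself defines how groups are labelled):

    net.freeze_group("0")      # parameters in this group stop receiving gradients
    # ... train the remaining parameters ...
    net.unfreeze_group("0")    # re-enable gradients for that group
    net.unfreeze_all()         # or re-enable gradients everywhere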

class pqagent.mlp.TrainingDataDict(X_train: torch.Tensor, y_train: torch.Tensor, X_val: torch.Tensor, y_val: torch.Tensor)[source]

Bases: object

X_train: Tensor
X_val: Tensor
y_train: Tensor
y_val: Tensor
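
Construction sketch with arbitrary example shapes:

    import torch
    from pqagent.mlp import TrainingDataDict

    data = TrainingDataDict(
        X_train=torch.randn(800, 10),
        y_train=torch.randn(800, 1),
        X_val=torch.randn(200, 10),
        y_val=torch.randn(200, 1),
    )
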
pqagent.mlp.base_training_strategy(net: Module, data: TrainingDataDict, param_space: dict, sweep_config: dict, start_epoch: int = 1, epochs: int = None, optimizer=None) → (Module, torch.optim, any)[source]
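
Call sketch (the dictionary keys shown are assumptions, since this page does not document the param_space/sweep_config schema; the third return value is annotated only as any, so the name best below is a guess):

    from pqagent.mlp import base_training_strategy

    net, optimizer, best = base_training_strategy(
        net=net,
        data=data,
        param_space={"lr": 1e-3},      # assumed key
        sweep_config={"epochs": 50},   # assumed key
        start_epoch=1,
    )
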
pqagent.mlp.check_model_improvement(goal, new, old)[source]
pqagent.mlp.get_loss_function(loss_function: str)[source]
pqagent.mlp.get_optimizer(optimizer: str, params: list[dict]) → torch.optim[source]
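
Helper sketch (the strings "mse" and "adam" are assumed identifiers; the accepted names are not listed on this page):

    from pqagent.mlp import get_loss_function, get_optimizer

    loss_fn = get_loss_function("mse")
    optimizer = get_optimizer("adam", [{"params": net.parameters(), "lr": 1e-3}])
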
pqagent.mlp.report_train_progress(metrics, best_train_score, net, optimizer, loss_fn, save_checkpoint_based_on: str, metric_goal: str)[source]