From 05dd20cdfab22fa2ed249a997a0f0e9263ffc467 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Mon, 10 Apr 2023 20:17:44 +0200 Subject: [PATCH] Add a class to freeze parameters --- DPI/utils.py | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/DPI/utils.py b/DPI/utils.py index 8c2cd14..a237e6f 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -17,6 +17,7 @@ import dmc2gym import cv2 from PIL import Image +from typing import Iterable class eval_mode(object): @@ -197,6 +198,12 @@ def make_env(args): ) return env +def soft_update_params(net, target_net, tau): + for param, target_param in zip(net.parameters(), target_net.parameters()): + target_param.data.copy_( + tau * param.data + (1 - tau) * target_param.data + ) + def save_image(array, filename): array = array.transpose(1, 2, 0) array = (array * 255).astype(np.uint8) @@ -256,4 +263,40 @@ class CorruptVideos: print(f"{filepath} is corrupt.") if delete: self._delete_corrupt_video(filepath) - print(f"Deleted {filepath}") \ No newline at end of file + print(f"Deleted {filepath}") + + +def get_parameters(modules: Iterable[nn.Module]): + """ + Given a list of torch modules, returns a list of their parameters. + :param modules: iterable of modules + :returns: a list of parameters + """ + model_parameters = [] + for module in modules: + model_parameters += list(module.parameters()) + return model_parameters + +class FreezeParameters: + def __init__(self, modules: Iterable[nn.Module]): + """ + Context manager to locally freeze gradients. + In some cases with can speed up computation because gradients aren't calculated for these listed modules. + example: + ``` + with FreezeParameters([module]): + output_tensor = module(input_tensor) + ``` + :param modules: iterable of modules. used to call .parameters() to freeze gradients. + """ + self.modules = modules + self.param_states = [p.requires_grad for p in get_parameters(self.modules)] + + def __enter__(self): + + for param in get_parameters(self.modules): + param.requires_grad = False + + def __exit__(self, exc_type, exc_val, exc_tb): + for i, param in enumerate(get_parameters(self.modules)): + param.requires_grad = self.param_states[i] \ No newline at end of file