Add a class to freeze parameters

This commit is contained in:
Vedant Dave 2023-04-10 20:17:44 +02:00
parent 8fd56ba94d
commit 05dd20cdfa

View File

@ -17,6 +17,7 @@ import dmc2gym
import cv2 import cv2
from PIL import Image from PIL import Image
from typing import Iterable
class eval_mode(object): class eval_mode(object):
@ -197,6 +198,12 @@ def make_env(args):
) )
return env return env
def soft_update_params(net, target_net, tau):
for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_(
tau * param.data + (1 - tau) * target_param.data
)
def save_image(array, filename): def save_image(array, filename):
array = array.transpose(1, 2, 0) array = array.transpose(1, 2, 0)
array = (array * 255).astype(np.uint8) array = (array * 255).astype(np.uint8)
@ -256,4 +263,40 @@ class CorruptVideos:
print(f"{filepath} is corrupt.") print(f"{filepath} is corrupt.")
if delete: if delete:
self._delete_corrupt_video(filepath) self._delete_corrupt_video(filepath)
print(f"Deleted {filepath}") print(f"Deleted {filepath}")
def get_parameters(modules: Iterable[nn.Module]):
"""
Given a list of torch modules, returns a list of their parameters.
:param modules: iterable of modules
:returns: a list of parameters
"""
model_parameters = []
for module in modules:
model_parameters += list(module.parameters())
return model_parameters
class FreezeParameters:
def __init__(self, modules: Iterable[nn.Module]):
"""
Context manager to locally freeze gradients.
In some cases with can speed up computation because gradients aren't calculated for these listed modules.
example:
```
with FreezeParameters([module]):
output_tensor = module(input_tensor)
```
:param modules: iterable of modules. used to call .parameters() to freeze gradients.
"""
self.modules = modules
self.param_states = [p.requires_grad for p in get_parameters(self.modules)]
def __enter__(self):
for param in get_parameters(self.modules):
param.requires_grad = False
def __exit__(self, exc_type, exc_val, exc_tb):
for i, param in enumerate(get_parameters(self.modules)):
param.requires_grad = self.param_states[i]