Adding preprocessing function

This commit is contained in:
Vedant Dave 2023-04-12 17:16:47 +02:00
parent 7d7387bd5d
commit 1f4667a08d

View File

@ -200,6 +200,10 @@ def make_env(args):
) )
return env return env
def preprocess_obs(obs):
obs = obs/255.0 - 0.5
return obs
def soft_update_params(net, target_net, tau): def soft_update_params(net, target_net, tau):
for param, target_param in zip(net.parameters(), target_net.parameters()): for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_( target_param.data.copy_(
@ -301,4 +305,4 @@ class FreezeParameters:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
for i, param in enumerate(get_parameters(self.modules)): for i, param in enumerate(get_parameters(self.modules)):
param.requires_grad = self.param_states[i] param.requires_grad = self.param_states[i]