diff --git a/DPI/utils.py b/DPI/utils.py index 83e9b9d..3790104 100644 --- a/DPI/utils.py +++ b/DPI/utils.py @@ -200,6 +200,10 @@ def make_env(args): ) return env +def preprocess_obs(obs): + obs = obs/255.0 - 0.5 + return obs + def soft_update_params(net, target_net, tau): for param, target_param in zip(net.parameters(), target_net.parameters()): target_param.data.copy_( @@ -301,4 +305,4 @@ class FreezeParameters: def __exit__(self, exc_type, exc_val, exc_tb): for i, param in enumerate(get_parameters(self.modules)): - param.requires_grad = self.param_states[i] \ No newline at end of file + param.requires_grad = self.param_states[i] \ No newline at end of file