Adding preprocessing function
This commit is contained in:
parent
7d7387bd5d
commit
1f4667a08d
@ -200,6 +200,10 @@ def make_env(args):
|
|||||||
)
|
)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
def preprocess_obs(obs):
|
||||||
|
obs = obs/255.0 - 0.5
|
||||||
|
return obs
|
||||||
|
|
||||||
def soft_update_params(net, target_net, tau):
|
def soft_update_params(net, target_net, tau):
|
||||||
for param, target_param in zip(net.parameters(), target_net.parameters()):
|
for param, target_param in zip(net.parameters(), target_net.parameters()):
|
||||||
target_param.data.copy_(
|
target_param.data.copy_(
|
||||||
@ -301,4 +305,4 @@ class FreezeParameters:
|
|||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
for i, param in enumerate(get_parameters(self.modules)):
|
for i, param in enumerate(get_parameters(self.modules)):
|
||||||
param.requires_grad = self.param_states[i]
|
param.requires_grad = self.param_states[i]
|
Loading…
Reference in New Issue
Block a user