diff --git a/cnn_policy.py b/cnn_policy.py index a8b4f44..3914db4 100644 --- a/cnn_policy.py +++ b/cnn_policy.py @@ -1,3 +1,7 @@ +import warnings +warnings.filterwarnings('ignore', category=Warning) +warnings.filterwarnings('ignore', category=DeprecationWarning) + import tensorflow as tf from baselines.common.distributions import make_pdtype from utils import getsess, small_convnet, activ, fc, flatten_two_dims, unflatten_first_dim @@ -55,6 +59,7 @@ class CnnPolicy(object): with tf.variable_scope(self.scope + "_features", reuse=reuse): x = (tf.to_float(x) - self.ob_mean) / self.ob_std + x = tf.transpose(x, [0, 2, 3, 1]) # shape=(None,84,84,4) x = small_convnet(x, nl=self.nl, feat_dim=self.feat_dim, last_nl=None, layernormalize=self.layernormalize) if x_has_timesteps: diff --git a/dynamic_bottleneck.py b/dynamic_bottleneck.py index 9481755..c1d86cb 100644 --- a/dynamic_bottleneck.py +++ b/dynamic_bottleneck.py @@ -42,10 +42,10 @@ class DynamicBottleneck(object): self.features = self.get_features(self.obs) # (None,None,512) self.next_features = self.get_features(self.next_ob, momentum=True) # (None,None,512) stop gradient self.ac = self.policy.ph_ac # (None, None) - self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2) + #self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2) # transition model - latent_params = self.transition_model([self.features, self.ac_pad]) # (None, None, 256) + latent_params = self.transition_model([self.features, self.ac]) # (None, None, 256) self.latent_dis = normal_parse_params(latent_params, 1e-3) # Gaussian. mu, sigma=(None, None, 128) # prior @@ -140,6 +140,7 @@ class DynamicBottleneck(object): with tf.variable_scope(self.scope + "_features"): x = (tf.to_float(x) - self.ob_mean) / self.ob_std + x = tf.transpose(x, [0, 2, 3, 1]) # shape=(None,84,84,4) if momentum: x = tf.stop_gradient(self.feature_conv_momentum(x)) # (None,512) else: diff --git a/run.py b/run.py index b3433e5..86ce047 100644 --- a/run.py +++ b/run.py @@ -3,6 +3,11 @@ try: from OpenGL import GLU except: print("no OpenGL.GLU") + +import warnings +warnings.filterwarnings('ignore', category=FutureWarning) +warnings.filterwarnings('ignore', category=DeprecationWarning) + import functools import os.path as osp from functools import partial