Adding Files
This commit is contained in:
parent
5066bc1d08
commit
da42ed9300
@ -1,3 +1,7 @@
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore', category=Warning)
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
|
||||
import tensorflow as tf
|
||||
from baselines.common.distributions import make_pdtype
|
||||
from utils import getsess, small_convnet, activ, fc, flatten_two_dims, unflatten_first_dim
|
||||
@ -55,6 +59,7 @@ class CnnPolicy(object):
|
||||
|
||||
with tf.variable_scope(self.scope + "_features", reuse=reuse):
|
||||
x = (tf.to_float(x) - self.ob_mean) / self.ob_std
|
||||
x = tf.transpose(x, [0, 2, 3, 1]) # shape=(None,84,84,4)
|
||||
x = small_convnet(x, nl=self.nl, feat_dim=self.feat_dim, last_nl=None, layernormalize=self.layernormalize)
|
||||
|
||||
if x_has_timesteps:
|
||||
|
@ -42,10 +42,10 @@ class DynamicBottleneck(object):
|
||||
self.features = self.get_features(self.obs) # (None,None,512)
|
||||
self.next_features = self.get_features(self.next_ob, momentum=True) # (None,None,512) stop gradient
|
||||
self.ac = self.policy.ph_ac # (None, None)
|
||||
self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2)
|
||||
#self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2)
|
||||
|
||||
# transition model
|
||||
latent_params = self.transition_model([self.features, self.ac_pad]) # (None, None, 256)
|
||||
latent_params = self.transition_model([self.features, self.ac]) # (None, None, 256)
|
||||
self.latent_dis = normal_parse_params(latent_params, 1e-3) # Gaussian. mu, sigma=(None, None, 128)
|
||||
|
||||
# prior
|
||||
@ -140,6 +140,7 @@ class DynamicBottleneck(object):
|
||||
|
||||
with tf.variable_scope(self.scope + "_features"):
|
||||
x = (tf.to_float(x) - self.ob_mean) / self.ob_std
|
||||
x = tf.transpose(x, [0, 2, 3, 1]) # shape=(None,84,84,4)
|
||||
if momentum:
|
||||
x = tf.stop_gradient(self.feature_conv_momentum(x)) # (None,512)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user