Adding Files

This commit is contained in:
Vedant Dave 2023-05-29 17:23:53 +02:00
parent 5066bc1d08
commit da42ed9300
3 changed files with 13 additions and 2 deletions

View File

@ -1,3 +1,7 @@
import warnings
warnings.filterwarnings('ignore', category=Warning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
import tensorflow as tf
from baselines.common.distributions import make_pdtype
from utils import getsess, small_convnet, activ, fc, flatten_two_dims, unflatten_first_dim
@ -55,6 +59,7 @@ class CnnPolicy(object):
with tf.variable_scope(self.scope + "_features", reuse=reuse):
x = (tf.to_float(x) - self.ob_mean) / self.ob_std
x = tf.transpose(x, [0, 2, 3, 1]) # shape=(None,84,84,4)
x = small_convnet(x, nl=self.nl, feat_dim=self.feat_dim, last_nl=None, layernormalize=self.layernormalize)
if x_has_timesteps:

View File

@ -42,10 +42,10 @@ class DynamicBottleneck(object):
self.features = self.get_features(self.obs) # (None,None,512)
self.next_features = self.get_features(self.next_ob, momentum=True) # (None,None,512) stop gradient
self.ac = self.policy.ph_ac # (None, None)
self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2)
#self.ac_pad = tf.one_hot(self.ac, self.ac_space.n, axis=2)
# transition model
latent_params = self.transition_model([self.features, self.ac_pad]) # (None, None, 256)
latent_params = self.transition_model([self.features, self.ac]) # (None, None, 256)
self.latent_dis = normal_parse_params(latent_params, 1e-3) # Gaussian. mu, sigma=(None, None, 128)
# prior
@ -140,6 +140,7 @@ class DynamicBottleneck(object):
with tf.variable_scope(self.scope + "_features"):
x = (tf.to_float(x) - self.ob_mean) / self.ob_std
x = tf.transpose(x, [0, 2, 3, 1]) # shape=(None,84,84,4)
if momentum:
x = tf.stop_gradient(self.feature_conv_momentum(x)) # (None,512)
else:

5
run.py
View File

@ -3,6 +3,11 @@ try:
from OpenGL import GLU
except:
print("no OpenGL.GLU")
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
import functools
import os.path as osp
from functools import partial