
77 lines
3.8 KiB
Raw Permalink Normal View History

2023-05-29 15:23:53 +00:00
import warnings
warnings.filterwarnings('ignore', category=Warning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
2021-10-09 00:33:47 +00:00
import tensorflow as tf
from baselines.common.distributions import make_pdtype
from utils import getsess, small_convnet, activ, fc, flatten_two_dims, unflatten_first_dim
class CnnPolicy(object):
    """Convolutional actor-critic policy built as a TensorFlow 1.x graph.

    Observations are fed through ``ph_ob`` with two leading (batch, time)
    axes, normalized with ``ob_mean``/``ob_std`` and encoded by
    ``small_convnet`` into ``feat_dim`` features. Two shared fully connected
    layers then feed a policy head (``pd``) and a value head (``vpred``).
    """

    def __init__(self, ob_space, ac_space, hidsize,
                 ob_mean, ob_std, feat_dim, layernormalize, nl, scope="policy"):
        """Build the policy graph under the variable scope ``scope``.

        Args:
            ob_space: observation space; only ``ob_space.shape`` is used, to
                size the observation placeholder.
            ac_space: action space; passed to ``make_pdtype`` to select the
                action distribution type.
            hidsize: width of the two shared hidden ``fc`` layers.
            ob_mean: value subtracted from observations before encoding.
            ob_std: value observations are divided by after mean subtraction.
            feat_dim: output size of the convolutional feature extractor.
            layernormalize: forwarded to ``small_convnet``; a warning is
                printed when it is true.
            nl: nonlinearity forwarded to ``small_convnet``.
            scope: variable scope for the policy; the feature extractor lives
                in ``scope + "_features"``.

        Example configuration from the original notes: ob_space (84, 84, 4);
        ac_space with 4 actions; ob_mean shape (84, 84, 4); ob_std 1.7;
        hidsize 512; feat_dim 512; layernormalize False; nl tf.nn.leaky_relu.
        NOTE(review): ``get_features`` transposes axis 1 to the end, which
        implies channels-first observations, so the channels-last shapes above
        may be stale. Confirm against the environment.
        """
        if layernormalize:
            print("Warning: policy is operating on top of layer-normed features. It might slow down the training.")
        self.layernormalize = layernormalize
        self.nl = nl
        self.ob_mean = ob_mean
        self.ob_std = ob_std
        with tf.variable_scope(scope):
            self.ob_space = ob_space
            self.ac_space = ac_space
            self.ac_pdtype = make_pdtype(ac_space)
            # Integer observations with shape (batch, time) + ob_space.shape.
            self.ph_ob = tf.placeholder(dtype=tf.int32,
                                        shape=(None, None) + ob_space.shape, name='ob')
            self.ph_ac = self.ac_pdtype.sample_placeholder([None, None], name='ac')
            self.pd = self.vpred = None
            self.hidsize = hidsize
            self.feat_dim = feat_dim
            self.scope = scope
            pdparamsize = self.ac_pdtype.param_shape()[0]

            # Fold (batch, time) into one axis for the encoder. `sh` is kept to
            # restore both axes afterwards (presumably flatten_two_dims /
            # unflatten_first_dim merge/split the two leading axes).
            sh = tf.shape(self.ph_ob)
            x = flatten_two_dims(self.ph_ob)
            self.flat_features = self.get_features(x, reuse=False)
            self.features = unflatten_first_dim(self.flat_features, sh)

            with tf.variable_scope(scope, reuse=False):
                # Two shared hidden layers feed both the policy and value heads.
                x = fc(self.flat_features, units=hidsize, activation=activ)
                x = fc(x, units=hidsize, activation=activ)
                # Flat distribution parameters (e.g. logits), pdparamsize wide.
                pdparam = fc(x, name='pd', units=pdparamsize, activation=None)
                vpred = fc(x, name='value_function_output', units=1, activation=None)
            pdparam = unflatten_first_dim(pdparam, sh)
            # Drop the size-1 output axis of the value head -> (batch, time).
            self.vpred = unflatten_first_dim(vpred, sh)[:, :, 0]
            self.pd = pd = self.ac_pdtype.pdfromflat(pdparam)
            self.a_samp = pd.sample()
            self.entropy = pd.entropy()
            # Negative log-probability of the sampled action: -log pi(a|s).
            self.nlp_samp = pd.neglogp(self.a_samp)

    def get_features(self, x, reuse):
        """Encode observations into ``feat_dim``-sized features.

        Args:
            x: integer observation tensor. If it is rank 5, its two leading
                axes are merged before encoding and restored afterwards.
                Otherwise it is encoded as is (presumably a rank-4 batch).
            reuse: passed to ``tf.variable_scope`` for the
                ``self.scope + "_features"`` scope, so callers can share the
                encoder variables across calls.

        Returns:
            The ``small_convnet`` output, with the two leading axes restored
            when ``x`` was rank 5.
        """
        x_has_timesteps = (x.get_shape().ndims == 5)
        if x_has_timesteps:
            sh = tf.shape(x)
            x = flatten_two_dims(x)

        with tf.variable_scope(self.scope + "_features", reuse=reuse):
            # Normalize in float32 (tf.cast replaces the deprecated tf.to_float).
            # ob_mean must broadcast against x in its pre-transpose layout.
            x = (tf.cast(x, tf.float32) - self.ob_mean) / self.ob_std
            # Move axis 1 to the end: (N, C, H, W) -> (N, H, W, C).
            # NOTE(review): assumes channels-first observations; confirm.
            x = tf.transpose(x, [0, 2, 3, 1])
            x = small_convnet(x, nl=self.nl, feat_dim=self.feat_dim, last_nl=None,
                              layernormalize=self.layernormalize)

        if x_has_timesteps:
            x = unflatten_first_dim(x, sh)
        return x

    def get_ac_value_nlp(self, ob):
        """Sample actions for a single timestep and evaluate them.

        Args:
            ob: batch of observations for one timestep. A length-1 time axis
                is inserted (``ob[:, None]``) to match the (batch, time, ...)
                layout of ``ph_ob``.

        Returns:
            Tuple ``(a, vpred, nlp)``: sampled actions, value predictions and
            negative log-probabilities of the sampled actions, each with the
            inserted time axis removed.
        """
        a, vpred, nlp = getsess().run(
            [self.a_samp, self.vpred, self.nlp_samp],
            feed_dict={self.ph_ob: ob[:, None]})
        return a[:, 0], vpred[:, 0], nlp[:, 0]