diff --git a/utils.py b/utils.py index 84e19a4..9d26bb9 100644 --- a/utils.py +++ b/utils.py @@ -206,7 +206,7 @@ class TransitionNetwork(tf.keras.Model): a = flatten_two_dims(a) # shape=(None,4) # - x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256) + x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256) x = self.residual_block1([x, a]) # (None, 256) x = self.residual_block2([x, a]) # (None, 256) x = self.dense2(tf.concat([x, a], axis=-1)) # (None, 256) @@ -292,7 +292,7 @@ class GenerativeNetworkGaussian(tf.keras.Model): return x -class ProjectionHead(tf.keras.Model): +class ProjectionHead(tf.keras.Model): def __init__(self, name=None): super(ProjectionHead, self).__init__(name=name) self.dense1 = layers.Dense(256, activation=None)