From 061057fa11b4ac21433fb794bf1510ef4514a752 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Mon, 29 May 2023 13:34:41 +0200 Subject: [PATCH] Adding Files --- utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils.py b/utils.py index 84e19a4..9d26bb9 100644 --- a/utils.py +++ b/utils.py @@ -206,7 +206,7 @@ class TransitionNetwork(tf.keras.Model): a = flatten_two_dims(a) # shape=(None,4) # - x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256) + x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256) x = self.residual_block1([x, a]) # (None, 256) x = self.residual_block2([x, a]) # (None, 256) x = self.dense2(tf.concat([x, a], axis=-1)) # (None, 256) @@ -292,7 +292,7 @@ class GenerativeNetworkGaussian(tf.keras.Model): return x -class ProjectionHead(tf.keras.Model): +class ProjectionHead(tf.keras.Model): def __init__(self, name=None): super(ProjectionHead, self).__init__(name=name) self.dense1 = layers.Dense(256, activation=None)