Adding Files

This commit is contained in:
Vedant Dave 2023-05-29 13:34:41 +02:00
parent e45c6ab696
commit 061057fa11

View File

@ -206,7 +206,7 @@ class TransitionNetwork(tf.keras.Model):
a = flatten_two_dims(a) # shape=(None,4) a = flatten_two_dims(a) # shape=(None,4)
# #
x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256) x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256)
x = self.residual_block1([x, a]) # (None, 256) x = self.residual_block1([x, a]) # (None, 256)
x = self.residual_block2([x, a]) # (None, 256) x = self.residual_block2([x, a]) # (None, 256)
x = self.dense2(tf.concat([x, a], axis=-1)) # (None, 256) x = self.dense2(tf.concat([x, a], axis=-1)) # (None, 256)
@ -292,7 +292,7 @@ class GenerativeNetworkGaussian(tf.keras.Model):
return x return x
class ProjectionHead(tf.keras.Model): class ProjectionHead(tf.keras.Model):
def __init__(self, name=None): def __init__(self, name=None):
super(ProjectionHead, self).__init__(name=name) super(ProjectionHead, self).__init__(name=name)
self.dense1 = layers.Dense(256, activation=None) self.dense1 = layers.Dense(256, activation=None)