Adding Files
This commit is contained in:
parent
e45c6ab696
commit
061057fa11
4
utils.py
4
utils.py
@ -206,7 +206,7 @@ class TransitionNetwork(tf.keras.Model):
|
|||||||
a = flatten_two_dims(a) # shape=(None,4)
|
a = flatten_two_dims(a) # shape=(None,4)
|
||||||
|
|
||||||
#
|
#
|
||||||
x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256)
|
x = self.dense1(tf.concat([x, a], axis=-1)) # (None, 256)
|
||||||
x = self.residual_block1([x, a]) # (None, 256)
|
x = self.residual_block1([x, a]) # (None, 256)
|
||||||
x = self.residual_block2([x, a]) # (None, 256)
|
x = self.residual_block2([x, a]) # (None, 256)
|
||||||
x = self.dense2(tf.concat([x, a], axis=-1)) # (None, 256)
|
x = self.dense2(tf.concat([x, a], axis=-1)) # (None, 256)
|
||||||
@ -292,7 +292,7 @@ class GenerativeNetworkGaussian(tf.keras.Model):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ProjectionHead(tf.keras.Model):
|
class ProjectionHead(tf.keras.Model):
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None):
|
||||||
super(ProjectionHead, self).__init__(name=name)
|
super(ProjectionHead, self).__init__(name=name)
|
||||||
self.dense1 = layers.Dense(256, activation=None)
|
self.dense1 = layers.Dense(256, activation=None)
|
||||||
|
Loading…
Reference in New Issue
Block a user