Merge branch 'main'

This commit is contained in:
Vedant Dave 2023-07-18 16:32:33 +02:00
commit 3b61469681
36 changed files with 22 additions and 19 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -122,7 +122,7 @@ class DMCWrapper(core.Env):
if img_source == "images": if img_source == "images":
self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=False, max_videos=100, random_bg=False) self._bg_source = natural_imgsource.RandomImageSource(shape2d, files, grayscale=False, max_videos=100, random_bg=False)
elif img_source == "video": elif img_source == "video":
self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=False,max_videos=100, random_bg=False) self._bg_source = natural_imgsource.RandomVideoSource(shape2d, files, grayscale=False, max_videos=100, random_bg=False)
else: else:
raise Exception("img_source %s not defined." % img_source) raise Exception("img_source %s not defined." % img_source)

View File

@ -308,6 +308,7 @@ class SeparationDreamer(Dreamer):
with tf.GradientTape(persistent=True) as model_tape: with tf.GradientTape(persistent=True) as model_tape:
# main # main
data["image"] = tf.transpose(data["image"], perm=[0, 1, 3, 4, 2])
embed = self._encode(data) embed = self._encode(data)
post, prior = self._dynamics.observe(embed, data['action']) post, prior = self._dynamics.observe(embed, data['action'])
feat = self._dynamics.get_feat(post) feat = self._dynamics.get_feat(post)

View File

@ -87,10 +87,12 @@ def video_summary(name, video, step=None, fps=20):
def encode_gif(frames, fps): def encode_gif(frames, fps):
from subprocess import Popen, PIPE from subprocess import Popen, PIPE
print(frames[0].shape) print(frames[0].shape)
if frames[0].shape[-1] != 3: if frames[0].shape[-1] > 3:
frames = np.transpose(frames, [0, 2, 3, 1]) frames = np.transpose(frames, [0, 2, 3, 1])
h, w, c = frames[0].shape h, w, c = frames[0].shape
print(h,w,c) print(frames[0].shape)
if c!=64:
pxfmt = {1: 'gray', 3: 'rgb24'}[c] pxfmt = {1: 'gray', 3: 'rgb24'}[c]
cmd = ' '.join([ cmd = ' '.join([
f'ffmpeg -y -f rawvideo -vcodec rawvideo', f'ffmpeg -y -f rawvideo -vcodec rawvideo',

View File

@ -1,6 +1,6 @@
dmc: dmc:
logdir: /home/vedant/tia/Dreamer/logdir logdir: /media/vedant/cpsDataStorageWK/Vedant/tia_logs
video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/ video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/
video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/ video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/
debug: False debug: False

View File

@ -1,6 +1,6 @@
dmc: dmc:
logdir: /home/vedant/tia/Dreamer/logdir logdir: /media/vedant/cpsDataStorageWK/Vedant/tia_logs
video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/ video_dir_train: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/train/
video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/ video_dir_test: /media/vedant/cpsDataStorageWK/Vedant/natural_video_setting/test/
debug: False debug: False