Implementing ICLUB

This commit is contained in:
Vedant Dave 2023-03-24 20:39:14 +01:00
parent abaca2bea9
commit 641c9bd57c
3 changed files with 77 additions and 28 deletions

View File

@ -128,25 +128,27 @@ class TransitionModel(nn.Module):
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def __init__(self, x_dim, y_dim, hidden_size): def __init__(self, x_dim, y_dim, hidden_size):
super(CLUBSample, self).__init__() super(CLUBSample, self).__init__()
self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2), self.p_mu = nn.Sequential(
nn.ReLU(), nn.Linear(x_dim, hidden_size//2),
nn.Linear(hidden_size//2, y_dim)) nn.ReLU(),
nn.Linear(hidden_size//2, y_dim)
)
self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2), self.p_logvar = nn.Sequential(
nn.ReLU(), nn.Linear(x_dim, hidden_size//2),
nn.Linear(hidden_size//2, y_dim), nn.ReLU(),
nn.Tanh()) nn.Linear(hidden_size//2, y_dim),
nn.Tanh()
)
def get_mu_logvar(self, x_samples): def get_mu_logvar(self, x_samples):
mu = self.p_mu(x_samples) mu = self.p_mu(x_samples)
logvar = self.p_logvar(x_samples) logvar = self.p_logvar(x_samples)
return mu, logvar return mu, logvar
def loglikeli(self, x_samples, y_samples): def loglikeli(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples) mu, logvar = self.get_mu_logvar(x_samples)
return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0) return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)
def forward(self, x_samples, y_samples): def forward(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples) mu, logvar = self.get_mu_logvar(x_samples)
@ -165,8 +167,9 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
if __name__ == "__main__": if __name__ == "__main__":
encoder = ObservationEncoder((12,84,84), 256) encoder = ObservationEncoder((12,84,84), 256)
x = torch.randn(100, 12, 84, 84) x = torch.randn(5000, 12, 84, 84)
print(encoder(x).shape) print(encoder(x).shape)
exit()
club = CLUBSample(256, 256 , 512) club = CLUBSample(256, 256 , 512)
x = torch.randn(100, 256) x = torch.randn(100, 256)

View File

@ -9,7 +9,7 @@ import dmc2gym
import wandb import wandb
import utils import utils
from utils import ReplayBuffer, make_env from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
@ -34,18 +34,18 @@ def parse_args():
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
parser.add_argument('--total_frames', default=1000, type=int) parser.add_argument('--total_frames', default=1000, type=int)
# replay buffer # replay buffer
parser.add_argument('--replay_buffer_capacity', default=100000, type=int) parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
parser.add_argument('--episode_length', default=1000, type=int) parser.add_argument('--episode_length', default=50, type=int)
# train # train
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
parser.add_argument('--init_steps', default=1000, type=int) parser.add_argument('--init_steps', default=1000, type=int)
parser.add_argument('--num_train_steps', default=1000, type=int) parser.add_argument('--num_train_steps', default=1000, type=int)
parser.add_argument('--batch_size', default=512, type=int) parser.add_argument('--batch_size', default=200, type=int) #512
parser.add_argument('--state_size', default=256, type=int) parser.add_argument('--state_size', default=256, type=int)
parser.add_argument('--hidden_size', default=128, type=int) parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--history_size', default=128, type=int) parser.add_argument('--history_size', default=128, type=int)
parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
parser.add_argument('--load_encoder', default=None, type=str) parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagination_horizon', default=15, type=str)
# eval # eval
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--num_eval_episodes', default=20, type=int)
@ -79,7 +79,6 @@ def parse_args():
parser.add_argument('--alpha_beta', default=0.9, type=float) parser.add_argument('--alpha_beta', default=0.9, type=float)
# misc # misc
parser.add_argument('--seed', default=1, type=int) parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--seed_steps', default=5000, type=int)
parser.add_argument('--work_dir', default='.', type=str) parser.add_argument('--work_dir', default='.', type=str)
parser.add_argument('--save_tb', default=False, action='store_true') parser.add_argument('--save_tb', default=False, action='store_true')
parser.add_argument('--save_model', default=False, action='store_true') parser.add_argument('--save_model', default=False, action='store_true')
@ -117,20 +116,21 @@ class DPI:
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size),
action_size=self.env.action_space.shape[0], action_size=self.env.action_space.shape[0],
seq_len=self.args.episode_length, seq_len=self.args.episode_length,
batch_size=args.batch_size) batch_size=args.batch_size,
args=self.args)
# create work directory # create work directory
utils.make_dir(self.args.work_dir) utils.make_dir(self.args.work_dir)
video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video')) self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model')) self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer')) self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))
# create video recorder # create video recorder
#video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files) #video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files)
#video.init(enabled=True) #video.init(enabled=True)
# create models # create models
self.build_models(use_saved=False, saved_model_dir=model_dir) self.build_models(use_saved=False, saved_model_dir=self.model_dir)
def build_models(self, use_saved, saved_model_dir=None): def build_models(self, use_saved, saved_model_dir=None):
self.obs_encoder = ObservationEncoder( self.obs_encoder = ObservationEncoder(
@ -167,13 +167,14 @@ class DPI:
def collect_random_episodes(self, episodes): def collect_random_episodes(self, episodes):
obs = self.env.reset() obs = self.env.reset()
done = False done = False
for episode_count in range(episodes): for episode_count in range(episodes):
for i in range(self.args.episode_length): for i in range(self.args.episode_length):
action = self.env.action_space.sample() action = self.env.action_space.sample()
next_obs, _, done, _ = self.env.step(action) next_obs, _, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, episode_count+1, done) self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
if done: if done:
obs = self.env.reset() obs = self.env.reset()
done=False done=False
@ -185,12 +186,33 @@ class DPI:
#video.save('%d.mp4' % step) #video.save('%d.mp4' % step)
#video.close() #video.close()
def upper_bound_minimization(self): def train(self):
pass # collect experience
self.collect_random_episodes(self.args.batch_size)
# Group observations and next_observations by steps
observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
# Train encoder
for i in range(self.args.episode_length):
# Encode observations and next_observations
self.features = self.obs_encoder(observations[i]) # (N,128)
self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
# Calculate upper bound loss
past_loss = self.upper_bound_minimization(self.features, self.next_features)
def upper_bound_minimization(self, features, next_features):
club_sample = CLUBSample(self.args.state_size,
self.args.state_size,
self.args.hidden_size)
club_loss = club_sample(features, next_features)
return club_loss
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
dpi = DPI(args) dpi = DPI(args)
dpi.collect_random_episodes(episodes=5) dpi.train()

View File

@ -13,6 +13,7 @@ import gym
import dmc2gym import dmc2gym
import random import random
from PIL import Image
from collections import deque from collections import deque
@ -105,7 +106,7 @@ class FrameStack(gym.Wrapper):
class ReplayBuffer: class ReplayBuffer:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size): def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size self.size = size
self.obs_shape = obs_shape self.obs_shape = obs_shape
self.action_size = action_size self.action_size = action_size
@ -113,6 +114,7 @@ class ReplayBuffer:
self.batch_size = batch_size self.batch_size = batch_size
self.idx = 0 self.idx = 0
self.full = False self.full = False
self.args = args
self.observations = np.empty((size, *obs_shape), dtype=np.uint8) self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
self.actions = np.empty((size, action_size), dtype=np.float32) self.actions = np.empty((size, action_size), dtype=np.float32)
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8) self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
@ -152,6 +154,22 @@ class ReplayBuffer:
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,rews,terms return obs,acs,rews,terms
def group_steps(self, buffer, variable):
variable = getattr(buffer, variable)
non_zero_indices = np.nonzero(buffer.episode_count)[0]
variable = variable[non_zero_indices]
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
self.args.frame_stack*self.args.channels,
self.args.image_size,self.args.image_size)
return variable
def transform_grouped_steps(self, variable):
variable = variable.transpose((1, 0, 2, 3, 4))
variable = variable.reshape(self.args.batch_size*self.args.episode_length,self.args.frame_stack*self.args.channels,
self.args.image_size,self.args.image_size)
return variable
def make_env(args): def make_env(args):
env = dmc2gym.make( env = dmc2gym.make(
@ -167,4 +185,10 @@ def make_env(args):
width=args.image_size, width=args.image_size,
frame_skip=args.action_repeat frame_skip=args.action_repeat
) )
return env return env
def save_image(array, filename):
array = array.transpose(1, 2, 0)
array = (array * 255).astype(np.uint8)
image = Image.fromarray(array)
image.save(filename)