Implementing ICLUB
This commit is contained in:
parent
abaca2bea9
commit
641c9bd57c
@ -128,26 +128,28 @@ class TransitionModel(nn.Module):
|
|||||||
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
||||||
def __init__(self, x_dim, y_dim, hidden_size):
|
def __init__(self, x_dim, y_dim, hidden_size):
|
||||||
super(CLUBSample, self).__init__()
|
super(CLUBSample, self).__init__()
|
||||||
self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
|
self.p_mu = nn.Sequential(
|
||||||
|
nn.Linear(x_dim, hidden_size//2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(hidden_size//2, y_dim))
|
nn.Linear(hidden_size//2, y_dim)
|
||||||
|
)
|
||||||
|
|
||||||
self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
|
self.p_logvar = nn.Sequential(
|
||||||
|
nn.Linear(x_dim, hidden_size//2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(hidden_size//2, y_dim),
|
nn.Linear(hidden_size//2, y_dim),
|
||||||
nn.Tanh())
|
nn.Tanh()
|
||||||
|
)
|
||||||
|
|
||||||
def get_mu_logvar(self, x_samples):
|
def get_mu_logvar(self, x_samples):
|
||||||
mu = self.p_mu(x_samples)
|
mu = self.p_mu(x_samples)
|
||||||
logvar = self.p_logvar(x_samples)
|
logvar = self.p_logvar(x_samples)
|
||||||
return mu, logvar
|
return mu, logvar
|
||||||
|
|
||||||
|
|
||||||
def loglikeli(self, x_samples, y_samples):
|
def loglikeli(self, x_samples, y_samples):
|
||||||
mu, logvar = self.get_mu_logvar(x_samples)
|
mu, logvar = self.get_mu_logvar(x_samples)
|
||||||
return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)
|
return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x_samples, y_samples):
|
def forward(self, x_samples, y_samples):
|
||||||
mu, logvar = self.get_mu_logvar(x_samples)
|
mu, logvar = self.get_mu_logvar(x_samples)
|
||||||
|
|
||||||
@ -165,8 +167,9 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
encoder = ObservationEncoder((12,84,84), 256)
|
encoder = ObservationEncoder((12,84,84), 256)
|
||||||
x = torch.randn(100, 12, 84, 84)
|
x = torch.randn(5000, 12, 84, 84)
|
||||||
print(encoder(x).shape)
|
print(encoder(x).shape)
|
||||||
|
exit()
|
||||||
|
|
||||||
club = CLUBSample(256, 256 , 512)
|
club = CLUBSample(256, 256 , 512)
|
||||||
x = torch.randn(100, 256)
|
x = torch.randn(100, 256)
|
||||||
|
50
DPI/train.py
50
DPI/train.py
@ -9,7 +9,7 @@ import dmc2gym
|
|||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
import utils
|
import utils
|
||||||
from utils import ReplayBuffer, make_env
|
from utils import ReplayBuffer, make_env, save_image
|
||||||
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
|
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
|
||||||
from logger import Logger
|
from logger import Logger
|
||||||
from video import VideoRecorder
|
from video import VideoRecorder
|
||||||
@ -34,18 +34,18 @@ def parse_args():
|
|||||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||||
parser.add_argument('--total_frames', default=1000, type=int)
|
parser.add_argument('--total_frames', default=1000, type=int)
|
||||||
# replay buffer
|
# replay buffer
|
||||||
parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
|
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
|
||||||
parser.add_argument('--episode_length', default=1000, type=int)
|
parser.add_argument('--episode_length', default=50, type=int)
|
||||||
# train
|
# train
|
||||||
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
|
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
|
||||||
parser.add_argument('--init_steps', default=1000, type=int)
|
parser.add_argument('--init_steps', default=1000, type=int)
|
||||||
parser.add_argument('--num_train_steps', default=1000, type=int)
|
parser.add_argument('--num_train_steps', default=1000, type=int)
|
||||||
parser.add_argument('--batch_size', default=512, type=int)
|
parser.add_argument('--batch_size', default=200, type=int) #512
|
||||||
parser.add_argument('--state_size', default=256, type=int)
|
parser.add_argument('--state_size', default=256, type=int)
|
||||||
parser.add_argument('--hidden_size', default=128, type=int)
|
parser.add_argument('--hidden_size', default=128, type=int)
|
||||||
parser.add_argument('--history_size', default=128, type=int)
|
parser.add_argument('--history_size', default=128, type=int)
|
||||||
parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
|
|
||||||
parser.add_argument('--load_encoder', default=None, type=str)
|
parser.add_argument('--load_encoder', default=None, type=str)
|
||||||
|
parser.add_argument('--imagination_horizon', default=15, type=str)
|
||||||
# eval
|
# eval
|
||||||
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
|
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
|
||||||
parser.add_argument('--num_eval_episodes', default=20, type=int)
|
parser.add_argument('--num_eval_episodes', default=20, type=int)
|
||||||
@ -79,7 +79,6 @@ def parse_args():
|
|||||||
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
||||||
# misc
|
# misc
|
||||||
parser.add_argument('--seed', default=1, type=int)
|
parser.add_argument('--seed', default=1, type=int)
|
||||||
parser.add_argument('--seed_steps', default=5000, type=int)
|
|
||||||
parser.add_argument('--work_dir', default='.', type=str)
|
parser.add_argument('--work_dir', default='.', type=str)
|
||||||
parser.add_argument('--save_tb', default=False, action='store_true')
|
parser.add_argument('--save_tb', default=False, action='store_true')
|
||||||
parser.add_argument('--save_model', default=False, action='store_true')
|
parser.add_argument('--save_model', default=False, action='store_true')
|
||||||
@ -117,20 +116,21 @@ class DPI:
|
|||||||
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size),
|
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size),
|
||||||
action_size=self.env.action_space.shape[0],
|
action_size=self.env.action_space.shape[0],
|
||||||
seq_len=self.args.episode_length,
|
seq_len=self.args.episode_length,
|
||||||
batch_size=args.batch_size)
|
batch_size=args.batch_size,
|
||||||
|
args=self.args)
|
||||||
|
|
||||||
# create work directory
|
# create work directory
|
||||||
utils.make_dir(self.args.work_dir)
|
utils.make_dir(self.args.work_dir)
|
||||||
video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
|
self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
|
||||||
model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
|
self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
|
||||||
buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))
|
self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))
|
||||||
|
|
||||||
# create video recorder
|
# create video recorder
|
||||||
#video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files)
|
#video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||||
#video.init(enabled=True)
|
#video.init(enabled=True)
|
||||||
|
|
||||||
# create models
|
# create models
|
||||||
self.build_models(use_saved=False, saved_model_dir=model_dir)
|
self.build_models(use_saved=False, saved_model_dir=self.model_dir)
|
||||||
|
|
||||||
def build_models(self, use_saved, saved_model_dir=None):
|
def build_models(self, use_saved, saved_model_dir=None):
|
||||||
self.obs_encoder = ObservationEncoder(
|
self.obs_encoder = ObservationEncoder(
|
||||||
@ -174,6 +174,7 @@ class DPI:
|
|||||||
next_obs, _, done, _ = self.env.step(action)
|
next_obs, _, done, _ = self.env.step(action)
|
||||||
|
|
||||||
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
done=False
|
done=False
|
||||||
@ -185,12 +186,33 @@ class DPI:
|
|||||||
#video.save('%d.mp4' % step)
|
#video.save('%d.mp4' % step)
|
||||||
#video.close()
|
#video.close()
|
||||||
|
|
||||||
def upper_bound_minimization(self):
|
def train(self):
|
||||||
pass
|
# collect experience
|
||||||
|
self.collect_random_episodes(self.args.batch_size)
|
||||||
|
|
||||||
|
# Group observations and next_observations by steps
|
||||||
|
observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()
|
||||||
|
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
|
||||||
|
|
||||||
|
# Train encoder
|
||||||
|
for i in range(self.args.episode_length):
|
||||||
|
# Encode observations and next_observations
|
||||||
|
self.features = self.obs_encoder(observations[i]) # (N,128)
|
||||||
|
self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
|
||||||
|
|
||||||
|
# Calculate upper bound loss
|
||||||
|
past_loss = self.upper_bound_minimization(self.features, self.next_features)
|
||||||
|
|
||||||
|
def upper_bound_minimization(self, features, next_features):
|
||||||
|
club_sample = CLUBSample(self.args.state_size,
|
||||||
|
self.args.state_size,
|
||||||
|
self.args.hidden_size)
|
||||||
|
club_loss = club_sample(features, next_features)
|
||||||
|
return club_loss
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
dpi = DPI(args)
|
dpi = DPI(args)
|
||||||
dpi.collect_random_episodes(episodes=5)
|
dpi.train()
|
26
DPI/utils.py
26
DPI/utils.py
@ -13,6 +13,7 @@ import gym
|
|||||||
import dmc2gym
|
import dmc2gym
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
from PIL import Image
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
@ -105,7 +106,7 @@ class FrameStack(gym.Wrapper):
|
|||||||
|
|
||||||
|
|
||||||
class ReplayBuffer:
|
class ReplayBuffer:
|
||||||
def __init__(self, size, obs_shape, action_size, seq_len, batch_size):
|
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.obs_shape = obs_shape
|
self.obs_shape = obs_shape
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
@ -113,6 +114,7 @@ class ReplayBuffer:
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.idx = 0
|
self.idx = 0
|
||||||
self.full = False
|
self.full = False
|
||||||
|
self.args = args
|
||||||
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
||||||
self.actions = np.empty((size, action_size), dtype=np.float32)
|
self.actions = np.empty((size, action_size), dtype=np.float32)
|
||||||
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
self.next_observations = np.empty((size, *obs_shape), dtype=np.uint8)
|
||||||
@ -152,6 +154,22 @@ class ReplayBuffer:
|
|||||||
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
|
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
|
||||||
return obs,acs,rews,terms
|
return obs,acs,rews,terms
|
||||||
|
|
||||||
|
def group_steps(self, buffer, variable):
|
||||||
|
variable = getattr(buffer, variable)
|
||||||
|
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
||||||
|
variable = variable[non_zero_indices]
|
||||||
|
|
||||||
|
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
|
||||||
|
self.args.frame_stack*self.args.channels,
|
||||||
|
self.args.image_size,self.args.image_size)
|
||||||
|
return variable
|
||||||
|
|
||||||
|
def transform_grouped_steps(self, variable):
|
||||||
|
variable = variable.transpose((1, 0, 2, 3, 4))
|
||||||
|
variable = variable.reshape(self.args.batch_size*self.args.episode_length,self.args.frame_stack*self.args.channels,
|
||||||
|
self.args.image_size,self.args.image_size)
|
||||||
|
return variable
|
||||||
|
|
||||||
|
|
||||||
def make_env(args):
|
def make_env(args):
|
||||||
env = dmc2gym.make(
|
env = dmc2gym.make(
|
||||||
@ -168,3 +186,9 @@ def make_env(args):
|
|||||||
frame_skip=args.action_repeat
|
frame_skip=args.action_repeat
|
||||||
)
|
)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
def save_image(array, filename):
|
||||||
|
array = array.transpose(1, 2, 0)
|
||||||
|
array = (array * 255).astype(np.uint8)
|
||||||
|
image = Image.fromarray(array)
|
||||||
|
image.save(filename)
|
Loading…
Reference in New Issue
Block a user