Improved simulation time

This commit is contained in: parent e0451ab4e3, commit b902e07424
@@ -78,6 +78,10 @@ class ActiveRLService(Node):
            self.mainloop_callback,
            callback_group=mainloop_callback_group)

        # time measurements
        self.begin_time = None
        self.end_time = None

    def reset_rl_request(self):
        self.rl_env = None
        self.rl_seed = None
@@ -86,6 +90,7 @@ class ActiveRLService(Node):
        self.interactive_run = 0

    def active_rl_callback(self, msg):
        self.begin_time = time.time()
        self.rl_env = msg.env
        self.rl_seed = msg.seed
        self.display_run = msg.display_run
@@ -155,6 +160,27 @@ class ActiveRLService(Node):
        return done

    def complete_run(self, policy):
        env_reward = 0.0
        step_count = 0

        self.env.reset(seed=self.rl_seed)

        for i in range(len(policy)):
            action = policy[i]
            action_clipped = action.clip(min=-1.0, max=1.0)
            output = self.env.step(action_clipped.astype(np.float64))

            env_reward += output[1]
            done = output[2]
            step_count += 1

            if done:
                break

        self.env.reset(seed=self.rl_seed)
        return env_reward, step_count

    def mainloop_callback(self):
        if self.rl_pending:
            if self.interactive_run == 0:
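The complete_run method added above replays a fixed policy, accumulating reward and step count, and resets the environment before and after the rollout. The following is a minimal standalone sketch of the same pattern, assuming an old-style Gym environment whose step() returns (obs, reward, done, info) and a policy given as an array of per-step actions; the function name and signature are illustrative, not from the repository.

import numpy as np

def rollout(env, policy, seed=None):
    # Replay a fixed action sequence, accumulating reward and step count.
    env.reset(seed=seed)
    total_reward, steps = 0.0, 0
    for action in policy:
        clipped = np.clip(action, -1.0, 1.0)          # keep actions inside the Box bounds
        _, reward, done, *_ = env.step(clipped.astype(np.float64))
        total_reward += reward
        steps += 1
        if done:                                      # stop at the terminal state
            break
    env.reset(seed=seed)                              # leave the env fresh, as complete_run does
    return total_reward, steps

Timing such a call with time.time() before and after, as active_rl_callback and mainloop_callback now do, is what produces the new "RL Time" log line.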
@@ -195,6 +221,10 @@ class ActiveRLService(Node):
                rl_response.final_step = self.rl_step

                self.active_rl_pub.publish(rl_response)
                self.end_time = time.time()
                self.get_logger().info(f'RL Time: {self.end_time-self.begin_time}, mode: {self.interactive_run}')
                self.begin_time = None
                self.end_time = None

                self.env.reset(seed=self.rl_seed)
@@ -232,32 +262,57 @@ class ActiveRLService(Node):
                self.rl_reward = 0.0
                self.rl_pending = False

                self.end_time = time.time()
                self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
                self.begin_time = None
                self.end_time = None

            elif self.interactive_run == 2:
                if not self.policy_sent:
                    self.rl_step = 0
                    self.rl_reward = 0.0
                    self.env.reset(seed=self.rl_seed)
                    self.policy_sent = True
                done = self.next_image(self.rl_policy, self.display_run)
                env_reward, step_count = self.complete_run(self.rl_policy)

                if done:
                    rl_response = ActiveRLResponse()
                    rl_response.weights = self.rl_weights
                    rl_response.reward = self.rl_reward
                    rl_response.final_step = self.rl_step
                rl_response = ActiveRLResponse()
                rl_response.weights = self.rl_weights
                rl_response.reward = env_reward
                rl_response.final_step = step_count

                    self.active_rl_pub.publish(rl_response)
                self.active_rl_pub.publish(rl_response)

                # reset flags and attributes
                self.reset_eval_request()
                self.reset_rl_request()
                self.end_time = time.time()
                self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')

                self.rl_step = 0
                self.rl_reward = 0.0

                self.rl_pending = False
                self.begin_time = None
                self.end_time = None

                self.reset_rl_request()
                self.rl_pending = False

                # if not self.policy_sent:
                #     self.rl_step = 0
                #     self.rl_reward = 0.0
                #     self.env.reset(seed=self.rl_seed)
                #     self.policy_sent = True
                # done = self.next_image(self.rl_policy, self.display_run)
                #
                # if done:
                #     rl_response = ActiveRLResponse()
                #     rl_response.weights = self.rl_weights
                #     rl_response.reward = self.rl_reward
                #     rl_response.final_step = self.rl_step
                #
                #     self.active_rl_pub.publish(rl_response)
                #     self.end_time = time.time()
                #     self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
                #     self.begin_time = None
                #     self.end_time = None
                #
                #     # reset flags and attributes
                #     self.reset_eval_request()
                #     self.reset_rl_request()
                #
                #     self.rl_step = 0
                #     self.rl_reward = 0.0
                #
                #     self.rl_pending = False


def main(args=None):
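The timing instrumentation in this node is a begin/end pair of time.time() calls stored on the node and cleared after the "RL Time" log line. The same measurement could also be expressed as a small context manager; this is only a sketch of an alternative, not part of the commit, and it uses time.perf_counter because it is monotonic, unlike time.time().

import time
from contextlib import contextmanager

@contextmanager
def log_duration(logger, label):
    # Log the wall-clock time spent inside the with-block.
    begin = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f'{label}: {time.perf_counter() - begin:.4f} s')

# usage sketch inside a callback:
# with log_duration(self.get_logger(), f'RL Time, mode: {self.interactive_run}'):
#     ...handle the request...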
@@ -98,6 +98,15 @@ class ActiveBOTopic(Node):
            self.mainloop_callback,
            callback_group=mainloop_callback_group)

        # time measurements
        self.init_begin = None
        self.init_end = None

        self.rl_begin = None
        self.rl_end = None
        self.user_query_begin = None
        self.user_query_end = None

    def reset_bo_request(self):
        self.bo_env = None
        self.bo_metric = None
@@ -144,6 +153,18 @@ class ActiveBOTopic(Node):
        else:
            self.seed = None

        # set rl environment
        if self.bo_env == "Mountain Car":
            self.env = Continuous_MountainCarEnv()
        elif self.bo_env == "Cartpole":
            self.env = CartPoleEnv()
        elif self.bo_env == "Acrobot":
            self.env = AcrobotEnv()
        elif self.bo_env == "Pendulum":
            self.env = PendulumEnv()
        else:
            raise NotImplementedError

    def reset_rl_response(self):
        self.rl_weights = None
        self.rl_final_step = None
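The environment dispatch added to the request callback above is an if/elif chain over the environment name. A table-driven equivalent is sketched below; the constructor names come from this diff, but the helper itself is only an illustration and assumes the same environment imports as the node module.

# Hypothetical helper; the constructors correspond to the branches in the hunk above
# and are assumed to be imported exactly as in the node module.
ENVIRONMENTS = {
    "Mountain Car": Continuous_MountainCarEnv,
    "Cartpole": CartPoleEnv,
    "Acrobot": AcrobotEnv,
    "Pendulum": PendulumEnv,
}

def make_environment(name):
    # Instantiate the requested environment or fail loudly for unknown names.
    try:
        return ENVIRONMENTS[name]()
    except KeyError:
        raise NotImplementedError(f'Unknown environment: {name}')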
@@ -163,6 +184,11 @@ class ActiveBOTopic(Node):

        if self.init_pending:
            self.init_step += 1
            self.init_end = time.time()
            self.get_logger().info(f'Init Time: {self.init_end-self.init_begin}')
            self.init_begin = None
            self.init_end = None

            if self.init_step == self.nr_init:
                self.init_step = 0
                self.init_pending = False
@@ -175,20 +201,12 @@ class ActiveBOTopic(Node):
        self.reset_rl_response()

    def mainloop_callback(self):
        if self.active_bo_pending:
        if not self.active_bo_pending:
            return

            # set rl environment
            if self.env is None:
                if self.bo_env == "Mountain Car":
                    self.env = Continuous_MountainCarEnv()
                elif self.bo_env == "Cartpole":
                    self.env = CartPoleEnv()
                elif self.bo_env == "Acrobot":
                    self.env = AcrobotEnv()
                elif self.bo_env == "Pendulum":
                    self.env = PendulumEnv()
                else:
                    raise NotImplementedError
            else:
        if self.rl_pending:
            return

        if self.BO is None:
            self.BO = BayesianOptimization(self.env,
@@ -204,8 +222,8 @@ class ActiveBOTopic(Node):
            self.get_logger().info('BO Initialization is starting!')
            # self.get_logger().info(f'{self.rl_pending}')

            if self.init_pending and not self.rl_pending:
            if self.init_pending:
                self.init_begin = time.time()
                if self.bo_fixed_seed:
                    seed = self.seed
                else:

@@ -222,7 +240,6 @@ class ActiveBOTopic(Node):
            self.active_rl_pub.publish(rl_msg)

            self.rl_pending = True
            return

        if self.current_run == self.bo_runs:
            bo_response = ActiveBOResponse()
@@ -293,7 +310,7 @@ class ActiveBOTopic(Node):
            self.reset_bo_request()

        else:
            if self.rl_pending or self.init_pending:
            if self.init_pending:
                return
            else:
                if self.current_episode < self.bo_episodes: