bo 2d finished - now testing

parent 33b8093a49
commit 597579bd98
@@ -4,6 +4,7 @@ string metric
 uint16 nr_weights
 uint16 max_steps
 uint16 nr_episodes
+uint16 nr_dims
 uint16 nr_runs
 string acquisition_function
 float32 metric_parameter
@@ -3,4 +3,5 @@ uint32 seed
 bool display_run
 uint8 interactive_run
 float64[] policy
 float64[] weights
+uint16 nr_dims
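The `nr_dims` field added to both messages lets the flat `policy` and `weights` arrays be rebuilt into per-dimension columns on the receiving side. A minimal sketch of that convention, not part of the commit and using made-up sizes:

    import numpy as np

    # sender: flatten an (n_steps, nr_dims) policy before filling the message
    policy_2d = np.zeros((50, 2))             # e.g. 50 steps, 2 action dimensions
    msg_policy = policy_2d.flatten().tolist()
    msg_nr_dims = policy_2d.shape[1]

    # receiver: rebuild the 2-D array from the flat list and nr_dims
    flat = np.array(msg_policy, dtype=np.float64)
    policy_back = flat.reshape((len(flat) // msg_nr_dims, msg_nr_dims))
    assert policy_back.shape == policy_2d.shape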
@@ -36,7 +36,7 @@ class ActiveBOTopic(Node):
 
         self.active_bo_sub = self.create_subscription(ActiveBORequest,
                                                       'active_bo_request',
-                                                      self.active_bo_callback,
+                                                      self.bo_callback,
                                                       1, callback_group=bo_callback_group)
 
         self.active_bo_pending = False
@@ -63,7 +63,7 @@ class ActiveBOTopic(Node):
                                                     1, callback_group=rl_callback_group)
         self.active_rl_sub = self.create_subscription(ActiveRLResponse,
                                                       'active_rl_response',
-                                                      self.active_rl_callback,
+                                                      self.rl_callback,
                                                       1, callback_group=rl_callback_group)
 
         self.rl_pending = False
@@ -105,6 +105,7 @@ class ActiveBOTopic(Node):
         self.bo_metric = None
         self.bo_fixed_seed = False
         self.bo_nr_weights = None
+        self.bo_nr_dims = None
         self.bo_steps = 0
         self.bo_episodes = 0
         self.bo_runs = 0
@@ -120,7 +121,7 @@ class ActiveBOTopic(Node):
         self.BO = None
         self.overwrite = False
 
-    def active_bo_callback(self, msg):
+    def bo_callback(self, msg):
         if not self.active_bo_pending:
             # self.get_logger().info('Active Bayesian Optimization request pending!')
             self.active_bo_pending = True
@@ -128,6 +129,7 @@ class ActiveBOTopic(Node):
             self.bo_metric = msg.metric
             self.bo_fixed_seed = msg.fixed_seed
             self.bo_nr_weights = msg.nr_weights
+            self.bo_nr_dims = msg.nr_dims
             self.bo_steps = msg.max_steps
             self.bo_episodes = msg.nr_episodes
             self.bo_runs = msg.nr_runs
@@ -135,19 +137,18 @@ class ActiveBOTopic(Node):
             self.bo_metric_parameter = msg.metric_parameter
             self.bo_metric_parameter_2 = msg.metric_parameter_2
             self.save_result = msg.save_result
-            self.seed_array = np.zeros((1, self.bo_runs))
+            self.seed_array = np.zeros((self.bo_runs, 1))
             self.overwrite = msg.overwrite
 
             # initialize
-            self.reward = np.zeros((self.bo_episodes + self.nr_init - 1, self.bo_runs))
-            self.best_pol_reward = np.zeros((1, self.bo_runs))
-            self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
-            self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
+            self.reward = np.zeros((self.bo_runs, self.bo_episodes + self.nr_init - 1))
+            self.best_pol_reward = np.zeros((self.bo_runs, 1))
+            self.best_policy = np.zeros((self.bo_runs, self.bo_steps, self.bo_nr_dims))
+            self.best_weights = np.zeros((self.bo_runs, self.bo_nr_weights, self.bo_nr_dims))
 
             # set the seed
             if self.bo_fixed_seed:
                 self.seed = int(np.random.randint(1, 2147483647, 1)[0])
-                # self.get_logger().info(str(self.seed))
             else:
                 self.seed = None
 
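These buffers change from column-per-run 2-D arrays to run-major arrays, with an extra trailing axis for the action dimensions wherever a whole policy or weight set is stored per run. A rough sketch of the new layout, not part of the commit, with made-up sizes:

    import numpy as np

    bo_runs, bo_steps, nr_weights, nr_dims = 5, 50, 8, 2

    best_policy = np.zeros((bo_runs, bo_steps, nr_dims))     # one (steps x dims) policy per run
    best_weights = np.zeros((bo_runs, nr_weights, nr_dims))  # one (weights x dims) set per run
    best_pol_reward = np.zeros((bo_runs, 1))                 # one scalar reward per run

    run = 3
    best_policy[run, :, :] = np.random.rand(bo_steps, nr_dims)    # store a finished run
    best_run = int(np.argmax(best_pol_reward))                    # pick the best run
    flat_policy = best_policy[best_run, :, :].flatten().tolist()  # flatten for the message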
@@ -161,7 +162,7 @@ class ActiveBOTopic(Node):
         self.rl_weights = None
         self.rl_final_step = None
 
-    def active_rl_callback(self, msg):
+    def rl_callback(self, msg):
         if self.rl_pending:
             # self.get_logger().info('Active Reinforcement Learning response received!')
             self.rl_weights = np.array(msg.weights, dtype=np.float64)
@@ -201,7 +202,7 @@ class ActiveBOTopic(Node):
 
             if self.BO is None:
                 self.BO = BayesianOptimization(self.bo_steps,
-                                               2,
+                                               self.bo_nr_dims,
                                                self.bo_nr_weights,
                                                acq=self.bo_acq_fcn)
 
@@ -223,8 +224,9 @@ class ActiveBOTopic(Node):
                 rl_msg.seed = seed
                 rl_msg.display_run = False
                 rl_msg.interactive_run = 2
-                rl_msg.weights = self.BO.policy_model.random_policy().tolist()
-                rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
+                rl_msg.weights = self.BO.policy_model.random_policy().flatten().tolist()
+                rl_msg.policy = self.BO.policy_model.rollout().flatten().tolist()
+                rl_msg.nr_dims = self.bo_nr_dims
 
                 self.active_rl_pub.publish(rl_msg)
 
@@ -234,25 +236,15 @@ class ActiveBOTopic(Node):
                 bo_response = ActiveBOResponse()
 
                 best_policy_idx = np.argmax(self.best_pol_reward)
-                bo_response.best_policy = self.best_policy[:, best_policy_idx].tolist()
-                bo_response.best_weights = self.best_weights[:, best_policy_idx].tolist()
+                bo_response.best_policy = self.best_policy[best_policy_idx, :, :].flatten().tolist()
+                bo_response.best_weights = self.best_weights[best_policy_idx, :, :].flatten().tolist()
 
-                # self.get_logger().info(f'Best Policy: {self.best_pol_reward}')
-
-                self.get_logger().info(f'{best_policy_idx}, {int(self.seed_array[0, best_policy_idx])}')
-
                 bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
                 bo_response.reward_std = np.std(self.reward, axis=1).tolist()
 
                 if self.save_result:
-                    if self.bo_env == "Mountain Car":
-                        env = 'mc'
-                    elif self.bo_env == "Cartpole":
-                        env = 'cp'
-                    elif self.bo_env == "Acrobot":
-                        env = 'ab'
-                    elif self.bo_env == "Pendulum":
-                        env = 'pd'
+                    if self.bo_env == "Reacher":
+                        env = 're'
                     else:
                         raise NotImplementedError
 
@@ -283,7 +275,6 @@ class ActiveBOTopic(Node):
 
                 if self.bo_fixed_seed:
                     seed = int(self.seed_array[0, best_policy_idx])
-                    # self.get_logger().info(f'Used seed{seed}')
                 else:
                     seed = int(np.random.randint(1, 2147483647, 1)[0])
 
@@ -291,8 +282,9 @@ class ActiveBOTopic(Node):
                 active_rl_request.seed = seed
                 active_rl_request.display_run = True
                 active_rl_request.interactive_run = 1
-                active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
-                active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
+                active_rl_request.policy = self.best_policy[best_policy_idx, :, :].flatten().tolist()
+                active_rl_request.weights = self.best_weights[best_policy_idx, :, :].flatten().tolist()
+                active_rl_request.nr_dims = self.bo_nr_dims
 
                 self.active_rl_pub.publish(active_rl_request)
 
@@ -316,7 +308,7 @@ class ActiveBOTopic(Node):
                     user_query = ImprovementQuery(self.bo_metric_parameter,
                                                   self.bo_metric_parameter_2,
                                                   self.last_query,
-                                                  self.reward[:self.current_episode, self.current_run])
+                                                  self.reward[self.current_run, :self.current_episode])
 
                 elif self.bo_metric == "max acquisition":
                     user_query = MaxAcqQuery(self.bo_metric_parameter,
@@ -347,30 +339,17 @@ class ActiveBOTopic(Node):
                     active_rl_request.seed = seed
                     active_rl_request.display_run = True
                     active_rl_request.interactive_run = 0
-                    active_rl_request.policy = old_policy.tolist()
-                    active_rl_request.weights = old_weights.tolist()
+                    active_rl_request.policy = old_policy.flatten().tolist()
+                    active_rl_request.weights = old_weights.flatten().tolist()
+                    active_rl_request.nr_dims = self.bo_nr_dims
 
                     # self.get_logger().info('Calling: Active RL')
                     self.active_rl_pub.publish(active_rl_request)
                     self.rl_pending = True
 
                 else:
-                    # if self.bo_acq_fcn == "Preference Expected Improvement":
-                    # self.get_logger().info(f"{self.BO.acq_fun.proposal_mean}")
-                    # self.get_logger().info(f"X: {self.BO.X}")
                     x_next = self.BO.next_observation()
-                    # self.get_logger().info(f'x_next: {x_next}')
-                    # self.get_logger().info(f'overwrite: {self.weight_preference}')
-                    # self.get_logger().info(f'rl_weights: {self.rl_weights}')
-
-                    if self.overwrite:
-                        if self.weight_preference is not None and self.rl_weights is not None:
-                            x_next[self.weight_preference] = self.rl_weights[self.weight_preference]
-                            # self.get_logger().info(f'x_next: {x_next}')
-                            # self.get_logger().info(f'overwrite: {self.weight_preference}')
-                            # self.get_logger().info(f'rl_weights: {self.rl_weights}')
-                    # self.get_logger().info('Next Observation BO!')
-                    self.BO.policy_model.weights = np.around(x_next, decimals=8)
+                    self.BO.policy_model.set_weights(np.around(x_next, decimals=8))
                     if self.bo_fixed_seed:
                         seed = self.seed
                     else:
@@ -381,23 +360,23 @@ class ActiveBOTopic(Node):
                     rl_msg.seed = seed
                     rl_msg.display_run = False
                     rl_msg.interactive_run = 2
-                    rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
-                    rl_msg.weights = x_next.tolist()
+                    rl_msg.policy = self.BO.policy_model.rollout().flatten().tolist()
+                    rl_msg.weights = x_next.flatten().tolist()
+                    rl_msg.nr_dims = self.bo_nr_dims
 
                     self.rl_pending = True
 
                     self.active_rl_pub.publish(rl_msg)
 
-                self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
+                self.reward[self.current_run, self.current_episode] = np.max(self.BO.Y)
                 self.get_logger().info(f'Current Episode: {self.current_episode},'
-                                       f' best reward: {self.reward[self.current_episode, self.current_run]}')
+                                       f' best reward: {self.reward[self.current_run, self.current_episode]}')
                 self.current_episode += 1
 
             else:
-                self.best_policy[:, self.current_run], \
-                    self.best_pol_reward[:, self.current_run], \
-                    self.best_weights[:, self.current_run], idx = self.BO.get_best_result()
+                self.best_policy[self.current_run, :, :], \
+                    self.best_pol_reward[self.current_run, :], \
+                    self.best_weights[self.current_run, :, :], idx = self.BO.get_best_result()
 
                 if self.current_run < self.bo_runs - 1:
                     self.BO = None
@@ -405,7 +384,8 @@ class ActiveBOTopic(Node):
                     self.current_episode = 0
                     self.last_query = 0
                     if self.bo_fixed_seed:
-                        self.seed_array[0, self.current_run] = self.seed
+                        self.seed_array[self.current_run, 0] = self.seed
+                    else:
                         self.seed = int(np.random.randint(1, 2147483647, 1)[0])
                         # self.get_logger().info(f'{self.seed}')
                     self.current_run += 1
@@ -443,4 +423,3 @@ def main(args=None):
 
 if __name__ == '__main__':
     main()
-
@@ -10,11 +10,7 @@ from rclpy.node import Node
 
 from rclpy.callback_groups import ReentrantCallbackGroup
 
-from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
-from active_bo_ros.ReinforcementLearning.CartPole import CartPoleEnv
-from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv
-from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv
+from dm_control import suite
 
 import numpy as np
 import time
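The hand-written Gym-style environments are dropped in favour of the dm_control suite, so environments are now built through `suite.load`. A minimal sketch of loading the reacher task used below, not part of the commit:

    import numpy as np
    from dm_control import suite

    # seeded RandomState so repeated runs see the same task instance
    random_state = np.random.RandomState(seed=42)
    env = suite.load('reacher', 'hard', task_kwargs={'random': random_state})

    spec = env.action_spec()   # BoundedArray with shape, minimum and maximum
    print(spec.shape, spec.minimum, spec.maximum)
    time_step = env.reset()    # dm_control returns TimeStep objects, not (obs, reward, done, info) tuples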
@@ -34,7 +30,7 @@ class ActiveRL(Node):
                                                     1, callback_group=rl_callback_group)
         self.active_rl_sub = self.create_subscription(ActiveRLRequest,
                                                       'active_rl_request',
-                                                      self.active_rl_callback,
+                                                      self.rl_callback,
                                                       1, callback_group=rl_callback_group)
 
         self.rl_env = None
@@ -56,17 +52,20 @@ class ActiveRL(Node):
                                                  callback_group=topic_callback_group)
         self.eval_sub = self.create_subscription(ActiveRLEvalResponse,
                                                  'active_rl_eval_response',
-                                                 self.active_rl_eval_callback,
+                                                 self.rl_eval_callback,
                                                  1,
                                                  callback_group=topic_callback_group)
 
         self.eval_response_received = False
         self.eval_policy = None
         self.eval_weights = None
-        self.overwrite_weight = None
+        self.weight_preference = None
 
         # RL Environments
         self.env = None
+        self.rl_spec = None
+        self.rl_dims = None
+        self.pol_dims = None
 
         # State Machine Variables
         self.best_pol_shown = False
@@ -83,61 +82,65 @@ class ActiveRL(Node):
 
     def reset_rl_request(self):
         self.rl_env = None
+        self.rl_spec = None
         self.rl_seed = None
         self.rl_policy = None
         self.rl_weights = None
         self.interactive_run = 0
         self.display_run = False
+        self.rl_dims = None
+        self.pol_dims = None
 
-    def active_rl_callback(self, msg):
+    def rl_callback(self, msg):
         self.rl_env = msg.env
         self.rl_seed = msg.seed
         self.display_run = msg.display_run
-        self.rl_policy = np.array(msg.policy, dtype=np.float64)
+        self.rl_dims = msg.nr_dims
         self.rl_weights = msg.weights
+        pol = msg.policy
+        self.pol_dims = (len(pol)/self.rl_dims, self.rl_dims)
+        self.rl_policy = np.array(pol, dtype=np.float64).reshape(self.pol_dims)
         self.interactive_run = msg.interactive_run
 
-        if self.rl_env == "Mountain Car":
-            self.env = Continuous_MountainCarEnv(render_mode="rgb_array")
-        elif self.rl_env == "Cartpole":
-            self.env = CartPoleEnv(render_mode="rgb_array")
-        elif self.rl_env == "Acrobot":
-            self.env = AcrobotEnv(render_mode="rgb_array")
-        elif self.rl_env == "Pendulum":
-            self.env = PendulumEnv(render_mode="rgb_array")
+        if self.rl_env == "Reacher":
+            random_state = np.random.RandomState(seed=self.rl_seed)
+            self.env = suite.load('reacher', 'hard', task_kwargs={'random': random_state})
+            self.rl_spec = self.env.action_spec()
+            self.env.reset()
         else:
             raise NotImplementedError
 
-        # self.get_logger().info('Active RL: Called!')
-        self.env.reset(seed=self.rl_seed)
         self.rl_pending = True
         self.policy_sent = False
-        self.rl_step = 0
 
     def reset_eval_request(self):
         self.eval_policy = None
         self.eval_weights = None
 
-    def active_rl_eval_callback(self, msg):
-        self.eval_policy = np.array(msg.policy, dtype=np.float64)
+    def rl_eval_callback(self, msg):
+        self.eval_policy = np.array(msg.policy, dtype=np.float64).reshape(self.pol_dims)
         self.eval_weights = msg.weights
-        self.overwrite_weight = msg.overwrite_weight
+        self.weight_preference = msg.weight_preference
 
         self.get_logger().info('Active RL Eval: Responded!')
-        self.env.reset(seed=self.rl_seed)
+        self.env.reset()
         self.eval_response_received = True
 
-    def next_image(self, policy, display_run):
-        action = policy[self.rl_step]
+    def step(self, policy, display_run):
+        done = False
 
+        action = policy[self.rl_step, :]
         action_clipped = action.clip(min=-1.0, max=1.0)
         output = self.env.step(action_clipped.astype(np.float64))
 
-        self.rl_reward += output[1]
-        done = output[2]
-        self.rl_step += 1
+        if output.reward != 0.0:
+            self.rl_reward += output.reward * 10
+            done = True
+        else:
+            self.rl_step -= 1.0
 
         if display_run:
-            rgb_array = self.env.render()
+            rgb_array = self.env.physics.render(camera_id=0, height=400, width=600)
             rgb_shape = rgb_array.shape
 
             red = rgb_array[:, :, 0].flatten().tolist()
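dm_control's `env.step` returns a `TimeStep` instead of a Gym tuple, which is why the code above reads `output.reward` and renders through `env.physics`. A rough sketch of that interaction, not part of the commit; the camera id and image size mirror the values used above:

    import numpy as np
    from dm_control import suite

    env = suite.load('reacher', 'hard')
    env.reset()

    action = np.zeros(env.action_spec().shape).clip(-1.0, 1.0)
    time_step = env.step(action)

    # reward is None only on the restart step, otherwise a float (0.0 until the target is reached)
    if time_step.reward:
        total_reward = time_step.reward * 10   # same scaling the node applies

    # frames come from the physics engine rather than a render_mode flag
    frame = env.physics.render(camera_id=0, height=400, width=600)
    print(frame.shape)                         # (400, 600, 3) uint8 RGB array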
@@ -160,25 +163,27 @@ class ActiveRL(Node):
 
         return done
 
-    def complete_run(self, policy):
+    def runner(self, policy):
         env_reward = 0.0
         step_count = 0
+        done = False
 
-        self.env.reset(seed=self.rl_seed)
+        self.env.reset()
 
-        for i in range(len(policy)):
-            action = policy[i]
+        for i in range(policy.shape[0]):
+            action = policy[i, :]
             action_clipped = action.clip(min=-1.0, max=1.0)
             output = self.env.step(action_clipped.astype(np.float64))
 
-            env_reward += output[1]
-            done = output[2]
             step_count += 1
 
+            if output.reward != 0.0:
+                self.rl_reward += output.reward * 10
+                done = True
+            else:
+                self.rl_step -= 1.0
 
             if done:
                 break
 
-        self.env.reset(seed=self.rl_seed)
         return env_reward, step_count
 
     def mainloop_callback(self):
@@ -188,10 +193,10 @@ class ActiveRL(Node):
                 if not self.policy_sent:
                     self.rl_step = 0
                     self.rl_reward = 0.0
-                    self.env.reset(seed=self.rl_seed)
+                    self.env.reset()
 
                     eval_request = ActiveRLEvalRequest()
-                    eval_request.policy = self.rl_policy.tolist()
+                    eval_request.policy = self.rl_policy.flatten().tolist()
                     eval_request.weights = self.rl_weights
 
                     self.eval_pub.publish(eval_request)
@@ -200,7 +205,7 @@ class ActiveRL(Node):
 
                     self.policy_sent = True
 
-                done = self.next_image(self.rl_policy, self.display_run)
+                done = self.step(self.rl_policy, self.display_run)
 
                 if done:
                     self.best_pol_shown = True
@@ -212,18 +217,18 @@ class ActiveRL(Node):
                     pass
 
                 if self.eval_response_received:
-                    done = self.next_image(self.eval_policy, self.display_run)
+                    done = self.step(self.eval_policy, self.display_run)
 
                     if done:
                         rl_response = ActiveRLResponse()
                         rl_response.weights = self.eval_weights
                         rl_response.reward = self.rl_reward
                         rl_response.final_step = self.rl_step
-                        rl_response.overwrite_weight = self.overwrite_weight
+                        rl_response.weight_preference = self.weight_preference
 
                         self.active_rl_pub.publish(rl_response)
 
-                        self.env.reset(seed=self.rl_seed)
+                        self.env.reset()
 
                         # reset flags and attributes
                         self.reset_eval_request()
@@ -240,10 +245,10 @@ class ActiveRL(Node):
                 if not self.policy_sent:
                     self.rl_step = 0
                     self.rl_reward = 0.0
-                    self.env.reset(seed=self.rl_seed)
+                    self.env.reset()
 
                     eval_request = ActiveRLEvalRequest()
-                    eval_request.policy = self.rl_policy.tolist()
+                    eval_request.policy = self.rl_policy.flatten().tolist()
                     eval_request.weights = self.rl_weights
 
                     self.eval_pub.publish(eval_request)
@@ -252,7 +257,7 @@ class ActiveRL(Node):
 
                     self.policy_sent = True
 
-                done = self.next_image(self.rl_policy, self.display_run)
+                done = self.step(self.rl_policy, self.display_run)
 
                 if done:
                     self.rl_step = 0
@@ -260,18 +265,18 @@ class ActiveRL(Node):
                     self.rl_pending = False
 
             elif self.interactive_run == 2:
-                env_reward, step_count = self.complete_run(self.rl_policy)
+                env_reward, step_count = self.runner(self.rl_policy)
 
                 rl_response = ActiveRLResponse()
                 rl_response.weights = self.rl_weights
                 rl_response.reward = env_reward
                 rl_response.final_step = step_count
-                if self.overwrite_weight is None:
-                    overwrite_weight = [False] * len(self.rl_weights)
+                if self.weight_preference is None:
+                    weight_preference = [False] * len(self.rl_weights)
                 else:
-                    overwrite_weight = self.overwrite_weight
+                    weight_preference = self.weight_preference
 
-                rl_response.overwrite_weight = overwrite_weight
+                rl_response.weight_preference = weight_preference
 
                 self.active_rl_pub.publish(rl_response)
 