diff --git a/src/active_bo_msgs/msg/ActiveRLEvalResponse.msg b/src/active_bo_msgs/msg/ActiveRLEvalResponse.msg index 9a562fb..23eb44e 100644 --- a/src/active_bo_msgs/msg/ActiveRLEvalResponse.msg +++ b/src/active_bo_msgs/msg/ActiveRLEvalResponse.msg @@ -1,3 +1,3 @@ -bool[] overwrite_weight +bool[] weight_preference float64[] policy float64[] weights \ No newline at end of file diff --git a/src/active_bo_msgs/msg/ActiveRLResponse.msg b/src/active_bo_msgs/msg/ActiveRLResponse.msg index f844283..1d9da78 100644 --- a/src/active_bo_msgs/msg/ActiveRLResponse.msg +++ b/src/active_bo_msgs/msg/ActiveRLResponse.msg @@ -1,4 +1,4 @@ float64[] weights -bool[] overwrite_weight +bool[] weight_preference uint16 final_step float64 reward \ No newline at end of file diff --git a/src/active_bo_ros/active_bo_ros/BayesianOptimization/BO2D.py b/src/active_bo_ros/active_bo_ros/BayesianOptimization/BO2D.py new file mode 100644 index 0000000..073ca48 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/BayesianOptimization/BO2D.py @@ -0,0 +1,120 @@ +import numpy as np +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import Matern + +from active_bo_ros.PolicyModel.GaussianModelMultiDim import GaussianRBF + +from active_bo_ros.AcquisitionFunctions.ExpectedImprovement import ExpectedImprovement +from active_bo_ros.AcquisitionFunctions.ProbabilityOfImprovement import ProbabilityOfImprovement +from active_bo_ros.AcquisitionFunctions.ConfidenceBound import ConfidenceBound +from active_bo_ros.AcquisitionFunctions.PreferenceExpectedImprovement import PreferenceExpectedImprovement + +from sklearn.exceptions import ConvergenceWarning +import warnings + +warnings.filterwarnings('ignore', category=ConvergenceWarning) + + +class BayesianOptimization: + def __init__(self, nr_steps, nr_dims, nr_weights, acq='ei', seed=None): + self.acq = acq + self.episode = 0 + + self.nr_steps = nr_steps + self.nr_dims = nr_dims + self.nr_weights = nr_weights + self.weights = self.nr_weights * self.nr_dims + + self.lower_bound = -1.0 + self.upper_bound = 1.0 + self.seed = seed + + self.X = None + self.Y = None + self.gp = None + + self.policy_model = GaussianRBF(self.nr_steps, self.nr_weights, self.nr_dims, + lowerb=self.lower_bound, upperb=self.upper_bound, seed=seed) + + self.acq_sample_size = 100 + + self.best_reward = np.empty((1, 1)) + + if acq == "Preference Expected Improvement": + self.acq_fun = PreferenceExpectedImprovement(self.weights, + self.acq_sample_size, + self.lower_bound, + self.upper_bound, + initial_variance=10.0, + update_variance=0.05, + seed=seed) + + self.reset_bo() + + def reset_bo(self): + self.gp = GaussianProcessRegressor(Matern(nu=1.5, ), n_restarts_optimizer=5) #length_scale=(1e-8, 1e5) + self.best_reward = np.empty((1, 1)) + self.X = np.zeros((1, self.weights), dtype=np.float64) + self.Y = np.zeros((1, 1), dtype=np.float64) + self.episode = 0 + + def next_observation(self): + if self.acq == "Expected Improvement": + x_next = ExpectedImprovement(self.gp, + self.X, + self.acq_sample_size, + self.weights, + kappa=0, + seed=self.seed, + lower=self.lower_bound, + upper=self.upper_bound) + + elif self.acq == "Probability of Improvement": + x_next = ProbabilityOfImprovement(self.gp, + self.X, + self.acq_sample_size, + self.weights, + kappa=0, + seed=self.seed, + lower=self.lower_bound, + upper=self.upper_bound) + + elif self.acq == "Upper Confidence Bound": + x_next = ConfidenceBound(self.gp, + self.acq_sample_size, + self.weights, + beta=2.576, + seed=self.seed, + 
lower=self.lower_bound, + upper=self.upper_bound) + + elif self.acq == "Preference Expected Improvement": + x_next = self.acq_fun.expected_improvement(self.gp, + self.X, + kappa=0) + + else: + raise NotImplementedError + + return x_next + + def add_observation(self, reward, x): + if self.episode == 0: + self.X[0, :] = x + self.Y[0] = reward + self.best_reward[0] = np.max(self.Y) + else: + self.X = np.vstack((self.X, np.around(x, decimals=8)), dtype=np.float64) + self.Y = np.vstack((self.Y, reward), dtype=np.float64) + self.best_reward = np.vstack((self.best_reward, np.max(self.Y)), dtype=np.float64) + + self.gp.fit(self.X, self.Y) + self.episode += 1 + + def get_best_result(self): + y_max = np.max(self.Y) + idx = np.argmax(self.Y) + x_max = self.X[idx, :] + + return y_max, x_max, idx + diff --git a/src/active_bo_ros/active_bo_ros/PolicyModel/GaussianModelMultiDim.py b/src/active_bo_ros/active_bo_ros/PolicyModel/GaussianModelMultiDim.py new file mode 100644 index 0000000..3f8a841 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/PolicyModel/GaussianModelMultiDim.py @@ -0,0 +1,49 @@ +import numpy as np + + +class GaussianRBF: + def __init__(self, nr_steps, nr_weights, nr_dims, lowerb=-1.0, upperb=1.0, seed=None): + self.nr_weights = nr_weights + self.nr_steps = nr_steps + self.nr_dims = nr_dims + + self.weights = None + self.trajectory = None + + self.lowerb = lowerb + self.upperb = upperb + + self.rng = np.random.default_rng(seed=seed) + + # initialize + self.mid_points = np.linspace(0, self.nr_steps, self.nr_weights) + if nr_weights > 1: + self.std = self.mid_points[1] / (2 * np.sqrt(2 * np.log(2))) # Full width at half maximum + else: + self.std = self.nr_steps / 2 + + self.reset() + + def reset(self): + self.weights = np.zeros((self.nr_weights, self.nr_dims)) + self.trajectory = np.zeros((self.nr_steps, self.nr_dims)) + + def random_weights(self): + for dim in range(self.nr_dims): + self.weights[:, dim] = self.rng.uniform(self.lowerb, self.upperb, self.nr_weights) + + def rollout(self): + self.trajectory = np.zeros((self.nr_steps, self.nr_dims)) + for step in range(self.nr_steps): + for weight in range(self.nr_weights): + base_fun = np.exp(-0.5 * (step - self.mid_points[weight]) ** 2 / self.std ** 2) + for dim in range(self.nr_dims): + self.trajectory[step, dim] += base_fun * self.weights[weight, dim] + + return self.trajectory + + def set_weights(self, x): + self.weights = x.reshape(self.nr_weights, self.nr_dims) + + def get_x(self): + return self.weights.reshape(self.nr_weights * self.nr_dims, 1) diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo.py b/src/active_bo_ros/active_bo_ros/interactive_bo.py index 259cf0e..36f45fb 100644 --- a/src/active_bo_ros/active_bo_ros/interactive_bo.py +++ b/src/active_bo_ros/active_bo_ros/interactive_bo.py @@ -74,7 +74,7 @@ class ActiveBOTopic(Node): self.rl_weights = None self.rl_final_step = None self.rl_reward = 0.0 - self.overwrite_weight = None + self.weight_preference = None # State Publisher self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1) @@ -175,7 +175,7 @@ class ActiveBOTopic(Node): if self.rl_pending: # self.get_logger().info('Active Reinforcement Learning response received!') self.rl_weights = np.array(msg.weights, dtype=np.float64) - self.overwrite_weight = np.array(msg.overwrite_weight, dtype=bool) + self.weight_preference = np.array(msg.weight_preference, dtype=bool) self.rl_final_step = msg.final_step self.rl_reward = msg.reward @@ -195,7 +195,7 @@ class ActiveBOTopic(Node): if self.user_asked: 
self.last_user_reward = self.rl_reward if self.bo_acq_fcn == "Preference Expected Improvement": - self.BO.acq_fun.update_proposal_model(self.rl_weights, self.overwrite_weight) + self.BO.acq_fun.update_proposal_model(self.rl_weights, self.weight_preference) self.user_asked = False self.rl_pending = False @@ -371,14 +371,14 @@ class ActiveBOTopic(Node): # self.get_logger().info(f"X: {self.BO.X}") x_next = self.BO.next_observation() # self.get_logger().info(f'x_next: {x_next}') - # self.get_logger().info(f'overwrite: {self.overwrite_weight}') + # self.get_logger().info(f'overwrite: {self.weight_preference}') # self.get_logger().info(f'rl_weights: {self.rl_weights}') if self.overwrite: - if self.overwrite_weight is not None and self.rl_weights is not None: - x_next[self.overwrite_weight] = self.rl_weights[self.overwrite_weight] + if self.weight_preference is not None and self.rl_weights is not None: + x_next[self.weight_preference] = self.rl_weights[self.weight_preference] # self.get_logger().info(f'x_next: {x_next}') - # self.get_logger().info(f'overwrite: {self.overwrite_weight}') + # self.get_logger().info(f'overwrite: {self.weight_preference}') # self.get_logger().info(f'rl_weights: {self.rl_weights}') # self.get_logger().info('Next Observation BO!') self.BO.policy_model.weights = np.around(x_next, decimals=8) diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo_2d.py b/src/active_bo_ros/active_bo_ros/interactive_bo_2d.py new file mode 100644 index 0000000..eab7b11 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/interactive_bo_2d.py @@ -0,0 +1,446 @@ +from active_bo_msgs.msg import ActiveBORequest +from active_bo_msgs.msg import ActiveBOResponse +from active_bo_msgs.msg import ActiveRLRequest +from active_bo_msgs.msg import ActiveRLResponse +from active_bo_msgs.msg import ActiveBOState + +import rclpy +from rclpy.node import Node + +from rclpy.callback_groups import ReentrantCallbackGroup + +from active_bo_ros.BayesianOptimization.BO2D import BayesianOptimization + +from active_bo_ros.UserQuery.random_query import RandomQuery +from active_bo_ros.UserQuery.regular_query import RegularQuery +from active_bo_ros.UserQuery.improvement_query import ImprovementQuery +from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery + +import numpy as np +import time +import os + + +class ActiveBOTopic(Node): + def __init__(self): + super().__init__('active_bo_topic') + + bo_callback_group = ReentrantCallbackGroup() + rl_callback_group = ReentrantCallbackGroup() + mainloop_callback_group = ReentrantCallbackGroup() + + # Active Bayesian Optimization Publisher, Subscriber and Message attributes + self.active_bo_pub = self.create_publisher(ActiveBOResponse, + 'active_bo_response', + 1, callback_group=bo_callback_group) + + self.active_bo_sub = self.create_subscription(ActiveBORequest, + 'active_bo_request', + self.active_bo_callback, + 1, callback_group=bo_callback_group) + + self.active_bo_pending = False + self.bo_env = None + self.bo_metric = None + self.bo_fixed_seed = False + self.bo_nr_weights = None + self.bo_nr_dims = 2 + self.bo_steps = 0 + self.bo_episodes = 0 + self.bo_runs = 0 + self.bo_acq_fcn = None + self.bo_metric_parameter = None + self.bo_metric_parameter_2 = None + self.current_run = 0 + self.current_episode = 0 + self.seed = None + self.seed_array = None + self.save_result = False + + # Active Reinforcement Learning Publisher, Subscriber and Message attributes + self.active_rl_pub = self.create_publisher(ActiveRLRequest, + 'active_rl_request', + 1, 
callback_group=rl_callback_group) + self.active_rl_sub = self.create_subscription(ActiveRLResponse, + 'active_rl_response', + self.active_rl_callback, + 1, callback_group=rl_callback_group) + + self.rl_pending = False + self.rl_weights = None + self.rl_final_step = None + self.rl_reward = 0.0 + self.weight_preference = None + + # State Publisher + self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1) + + # RL Environments and BO + self.env = None + + self.BO = None + self.nr_init = 3 + self.init_step = 0 + self.init_pending = False + self.reward = None + self.best_reward = 0.0 + self.best_pol_reward = None + self.best_policy = None + self.best_weights = None + + # User Query + self.last_query = 0 + self.user_asked = False + self.last_user_reward = 0.0 + self.overwrite = False + + # Main loop timer object + self.mainloop_timer_period = 0.1 + self.mainloop = self.create_timer(self.mainloop_timer_period, + self.mainloop_callback, + callback_group=mainloop_callback_group) + + def reset_bo_request(self): + self.bo_env = None + self.bo_metric = None + self.bo_fixed_seed = False + self.bo_nr_weights = None + self.bo_steps = 0 + self.bo_episodes = 0 + self.bo_runs = 0 + self.bo_acq_fcn = None + self.bo_metric_parameter = None + self.bo_metric_parameter_2 = None + self.current_run = 0 + self.current_episode = 0 + self.save_result = False + self.seed_array = None + self.env = None + self.active_bo_pending = False + self.BO = None + self.overwrite = False + + def active_bo_callback(self, msg): + if not self.active_bo_pending: + # self.get_logger().info('Active Bayesian Optimization request pending!') + self.active_bo_pending = True + self.bo_env = msg.env + self.bo_metric = msg.metric + self.bo_fixed_seed = msg.fixed_seed + self.bo_nr_weights = msg.nr_weights + self.bo_steps = msg.max_steps + self.bo_episodes = msg.nr_episodes + self.bo_runs = msg.nr_runs + self.bo_acq_fcn = msg.acquisition_function + self.bo_metric_parameter = msg.metric_parameter + self.bo_metric_parameter_2 = msg.metric_parameter_2 + self.save_result = msg.save_result + self.seed_array = np.zeros((1, self.bo_runs)) + self.overwrite = msg.overwrite + + # initialize + self.reward = np.zeros((self.bo_episodes + self.nr_init - 1, self.bo_runs)) + self.best_pol_reward = np.zeros((1, self.bo_runs)) + self.best_policy = np.zeros((self.bo_steps, self.bo_runs)) + self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs)) + + # set the seed + if self.bo_fixed_seed: + self.seed = int(np.random.randint(1, 2147483647, 1)[0]) + # self.get_logger().info(str(self.seed)) + else: + self.seed = None + + # set rl environment + if self.bo_env == "Reacher": + pass + else: + raise NotImplementedError + + def reset_rl_response(self): + self.rl_weights = None + self.rl_final_step = None + + def active_rl_callback(self, msg): + if self.rl_pending: + # self.get_logger().info('Active Reinforcement Learning response received!') + self.rl_weights = np.array(msg.weights, dtype=np.float64) + self.weight_preference = np.array(msg.weight_preference, dtype=bool) + self.rl_final_step = msg.final_step + self.rl_reward = msg.reward + + try: + self.BO.add_new_observation(self.rl_reward, self.rl_weights) + # self.get_logger().info('Active Reinforcement Learning added new observation!') + except Exception as e: + self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}') + + if self.init_pending: + self.init_step += 1 + + if self.init_step == self.nr_init: + self.init_step = 0 + self.init_pending = False + + 
if self.user_asked: + self.last_user_reward = self.rl_reward + if self.bo_acq_fcn == "Preference Expected Improvement": + self.BO.acq_fun.update_proposal_model(self.rl_weights, self.weight_preference) + self.user_asked = False + + self.rl_pending = False + # self.reset_rl_response() + + def mainloop_callback(self): + if not self.active_bo_pending: + return + + else: + if self.rl_pending: + return + + if self.BO is None: + self.BO = BayesianOptimization(self.bo_steps, + 2, + self.bo_nr_weights, + acq=self.bo_acq_fcn) + + self.BO.reset_bo() + + # self.BO.initialize() + self.init_pending = True + self.get_logger().info('BO Initialization is starting!') + # self.get_logger().info(f'{self.rl_pending}') + + if self.init_pending: + if self.bo_fixed_seed: + seed = self.seed + else: + seed = int(np.random.randint(1, 2147483647, 1)[0]) + + rl_msg = ActiveRLRequest() + rl_msg.env = self.bo_env + rl_msg.seed = seed + rl_msg.display_run = False + rl_msg.interactive_run = 2 + rl_msg.weights = self.BO.policy_model.random_policy().tolist() + rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist() + + self.active_rl_pub.publish(rl_msg) + + self.rl_pending = True + + if self.current_run == self.bo_runs: + bo_response = ActiveBOResponse() + + best_policy_idx = np.argmax(self.best_pol_reward) + bo_response.best_policy = self.best_policy[:, best_policy_idx].tolist() + bo_response.best_weights = self.best_weights[:, best_policy_idx].tolist() + + # self.get_logger().info(f'Best Policy: {self.best_pol_reward}') + + self.get_logger().info(f'{best_policy_idx}, {int(self.seed_array[0, best_policy_idx])}') + + bo_response.reward_mean = np.mean(self.reward, axis=1).tolist() + bo_response.reward_std = np.std(self.reward, axis=1).tolist() + + if self.save_result: + if self.bo_env == "Mountain Car": + env = 'mc' + elif self.bo_env == "Cartpole": + env = 'cp' + elif self.bo_env == "Acrobot": + env = 'ab' + elif self.bo_env == "Pendulum": + env = 'pd' + else: + raise NotImplementedError + + if self.bo_acq_fcn == "Expected Improvement": + acq = 'ei' + elif self.bo_acq_fcn == "Probability of Improvement": + acq = 'pi' + elif self.bo_acq_fcn == "Upper Confidence Bound": + acq = 'cb' + elif self.bo_acq_fcn == "Preference Expected Improvement": + acq = 'pei' + else: + raise NotImplementedError + + home_dir = os.path.expanduser('~') + file_path = os.path.join(home_dir, 'Documents/IntRLResults') + filename = env + '-' + acq + '-' + self.bo_metric + '-' \ + + str(round(self.bo_metric_parameter, 2)) + '-' \ + + str(self.bo_nr_weights) + '-' + str(time.time()) + filename = filename.replace('.', '_') + '.csv' + path = os.path.join(file_path, filename) + + data = self.reward + + np.savetxt(path, data, delimiter=',') + + active_rl_request = ActiveRLRequest() + + if self.bo_fixed_seed: + seed = int(self.seed_array[0, best_policy_idx]) + # self.get_logger().info(f'Used seed{seed}') + else: + seed = int(np.random.randint(1, 2147483647, 1)[0]) + + active_rl_request.env = self.bo_env + active_rl_request.seed = seed + active_rl_request.display_run = True + active_rl_request.interactive_run = 1 + active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist() + active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist() + + self.active_rl_pub.publish(active_rl_request) + + self.get_logger().info('Responding: Active BO') + self.active_bo_pub.publish(bo_response) + self.reset_bo_request() + + else: + if self.init_pending: + return + else: + if self.current_episode < self.bo_episodes + self.nr_init - 1: + # 
metrics + if self.bo_metric == "random": + user_query = RandomQuery(self.bo_metric_parameter) + + elif self.bo_metric == "regular": + user_query = RegularQuery(self.bo_metric_parameter, self.current_episode) + + elif self.bo_metric == "improvement": + user_query = ImprovementQuery(self.bo_metric_parameter, + self.bo_metric_parameter_2, + self.last_query, + self.reward[:self.current_episode, self.current_run]) + + elif self.bo_metric == "max acquisition": + user_query = MaxAcqQuery(self.bo_metric_parameter, + self.BO.GP, + 100, + self.bo_nr_weights, + acq=self.bo_acq_fcn, + X=self.BO.X) + + else: + raise NotImplementedError + + if user_query.query(): + self.last_query = self.current_episode + self.user_asked = True + active_rl_request = ActiveRLRequest() + old_policy, y_max, old_weights, _ = self.BO.get_best_result() + + # self.get_logger().info(f'Best: {y_max}, w:{old_weights}') + # self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}') + + if self.bo_fixed_seed: + seed = self.seed + else: + seed = int(np.random.randint(1, 2147483647, 1)[0]) + + active_rl_request.env = self.bo_env + active_rl_request.seed = seed + active_rl_request.display_run = True + active_rl_request.interactive_run = 0 + active_rl_request.policy = old_policy.tolist() + active_rl_request.weights = old_weights.tolist() + + # self.get_logger().info('Calling: Active RL') + self.active_rl_pub.publish(active_rl_request) + self.rl_pending = True + + else: + # if self.bo_acq_fcn == "Preference Expected Improvement": + # self.get_logger().info(f"{self.BO.acq_fun.proposal_mean}") + # self.get_logger().info(f"X: {self.BO.X}") + x_next = self.BO.next_observation() + # self.get_logger().info(f'x_next: {x_next}') + # self.get_logger().info(f'overwrite: {self.weight_preference}') + # self.get_logger().info(f'rl_weights: {self.rl_weights}') + + if self.overwrite: + if self.weight_preference is not None and self.rl_weights is not None: + x_next[self.weight_preference] = self.rl_weights[self.weight_preference] + # self.get_logger().info(f'x_next: {x_next}') + # self.get_logger().info(f'overwrite: {self.weight_preference}') + # self.get_logger().info(f'rl_weights: {self.rl_weights}') + # self.get_logger().info('Next Observation BO!') + self.BO.policy_model.weights = np.around(x_next, decimals=8) + if self.bo_fixed_seed: + seed = self.seed + else: + seed = int(np.random.randint(1, 2147483647, 1)[0]) + + rl_msg = ActiveRLRequest() + rl_msg.env = self.bo_env + rl_msg.seed = seed + rl_msg.display_run = False + rl_msg.interactive_run = 2 + rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist() + rl_msg.weights = x_next.tolist() + + self.rl_pending = True + + self.active_rl_pub.publish(rl_msg) + + self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y) + self.get_logger().info(f'Current Episode: {self.current_episode},' + f' best reward: {self.reward[self.current_episode, self.current_run]}') + self.current_episode += 1 + + + else: + self.best_policy[:, self.current_run], \ + self.best_pol_reward[:, self.current_run], \ + self.best_weights[:, self.current_run], idx = self.BO.get_best_result() + + if self.current_run < self.bo_runs - 1: + self.BO = None + + self.current_episode = 0 + self.last_query = 0 + if self.bo_fixed_seed: + self.seed_array[0, self.current_run] = self.seed + self.seed = int(np.random.randint(1, 2147483647, 1)[0]) + # self.get_logger().info(f'{self.seed}') + self.current_run += 1 + self.get_logger().info(f'Current Run: {self.current_run}') + + # send the current 
states + + if self.BO is not None and self.BO.Y is not None: + self.best_reward = np.max(self.BO.Y) + + state_msg = ActiveBOState() + state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs + state_msg.current_episode = self.current_episode \ + if self.current_episode < self.bo_episodes else self.bo_episodes + state_msg.best_reward = float(self.best_reward) + state_msg.last_user_reward = self.last_user_reward + self.state_pub.publish(state_msg) + + +def main(args=None): + rclpy.init(args=args) + + active_bo_topic = ActiveBOTopic() + + rclpy.spin(active_bo_topic) + + try: + rclpy.spin(active_bo_topic) + except KeyboardInterrupt: + pass + + active_bo_topic.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() + diff --git a/src/active_bo_ros/active_bo_ros/interactive_rl_2d.py b/src/active_bo_ros/active_bo_ros/interactive_rl_2d.py new file mode 100644 index 0000000..0729c09 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/interactive_rl_2d.py @@ -0,0 +1,293 @@ +from active_bo_msgs.msg import ActiveRLRequest +from active_bo_msgs.msg import ActiveRLResponse +from active_bo_msgs.msg import ActiveRLEvalRequest +from active_bo_msgs.msg import ActiveRLEvalResponse + +from active_bo_msgs.msg import ImageFeedback + +import rclpy +from rclpy.node import Node + +from rclpy.callback_groups import ReentrantCallbackGroup + +from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv +from active_bo_ros.ReinforcementLearning.CartPole import CartPoleEnv +from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv +from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv + + +import numpy as np +import time +import copy + + +class ActiveRL(Node): + def __init__(self): + super().__init__('active_rl_service') + rl_callback_group = ReentrantCallbackGroup() + topic_callback_group = ReentrantCallbackGroup() + mainloop_callback_group = ReentrantCallbackGroup() + + # Active Reinforcement Learning Publisher, Subscriber and Message attributes + self.active_rl_pub = self.create_publisher(ActiveRLResponse, + 'active_rl_response', + 1, callback_group=rl_callback_group) + self.active_rl_sub = self.create_subscription(ActiveRLRequest, + 'active_rl_request', + self.active_rl_callback, + 1, callback_group=rl_callback_group) + + self.rl_env = None + self.rl_seed = None + self.rl_policy = None + self.rl_weights = None + self.rl_reward = 0.0 + self.rl_step = 0 + + # Image publisher to publish the rgb array from the gym environment + self.image_pub = self.create_publisher(ImageFeedback, + 'rl_feedback', + 1, callback_group=topic_callback_group) + + # Active RL Evaluation Publisher, Subscriber and Message attributes + self.eval_pub = self.create_publisher(ActiveRLEvalRequest, + 'active_rl_eval_request', + 1, + callback_group=topic_callback_group) + self.eval_sub = self.create_subscription(ActiveRLEvalResponse, + 'active_rl_eval_response', + self.active_rl_eval_callback, + 1, + callback_group=topic_callback_group) + + self.eval_response_received = False + self.eval_policy = None + self.eval_weights = None + self.overwrite_weight = None + + # RL Environments + self.env = None + + # State Machine Variables + self.best_pol_shown = False + self.policy_sent = False + self.rl_pending = False + self.interactive_run = 0 + self.display_run = False + + # Main loop timer object + self.mainloop_timer_period = 0.1 + self.mainloop = self.create_timer(self.mainloop_timer_period, + self.mainloop_callback, + 
callback_group=mainloop_callback_group) + + def reset_rl_request(self): + self.rl_env = None + self.rl_seed = None + self.rl_policy = None + self.rl_weights = None + self.interactive_run = 0 + self.display_run = False + + def active_rl_callback(self, msg): + self.rl_env = msg.env + self.rl_seed = msg.seed + self.display_run = msg.display_run + self.rl_policy = np.array(msg.policy, dtype=np.float64) + self.rl_weights = msg.weights + self.interactive_run = msg.interactive_run + + if self.rl_env == "Mountain Car": + self.env = Continuous_MountainCarEnv(render_mode="rgb_array") + elif self.rl_env == "Cartpole": + self.env = CartPoleEnv(render_mode="rgb_array") + elif self.rl_env == "Acrobot": + self.env = AcrobotEnv(render_mode="rgb_array") + elif self.rl_env == "Pendulum": + self.env = PendulumEnv(render_mode="rgb_array") + else: + raise NotImplementedError + + # self.get_logger().info('Active RL: Called!') + self.env.reset(seed=self.rl_seed) + self.rl_pending = True + self.policy_sent = False + self.rl_step = 0 + + def reset_eval_request(self): + self.eval_policy = None + self.eval_weights = None + + def active_rl_eval_callback(self, msg): + self.eval_policy = np.array(msg.policy, dtype=np.float64) + self.eval_weights = msg.weights + self.overwrite_weight = msg.overwrite_weight + + self.get_logger().info('Active RL Eval: Responded!') + self.env.reset(seed=self.rl_seed) + self.eval_response_received = True + + def next_image(self, policy, display_run): + action = policy[self.rl_step] + action_clipped = action.clip(min=-1.0, max=1.0) + output = self.env.step(action_clipped.astype(np.float64)) + + self.rl_reward += output[1] + done = output[2] + self.rl_step += 1 + + if display_run: + rgb_array = self.env.render() + rgb_shape = rgb_array.shape + + red = rgb_array[:, :, 0].flatten().tolist() + green = rgb_array[:, :, 1].flatten().tolist() + blue = rgb_array[:, :, 2].flatten().tolist() + + feedback_msg = ImageFeedback() + + feedback_msg.height = rgb_shape[0] + feedback_msg.width = rgb_shape[1] + feedback_msg.current_time = self.rl_step + feedback_msg.red = red + feedback_msg.green = green + feedback_msg.blue = blue + + self.image_pub.publish(feedback_msg) + + if not done and self.rl_step == len(policy): + done = True + + return done + + def complete_run(self, policy): + env_reward = 0.0 + step_count = 0 + + self.env.reset(seed=self.rl_seed) + + for i in range(len(policy)): + action = policy[i] + action_clipped = action.clip(min=-1.0, max=1.0) + output = self.env.step(action_clipped.astype(np.float64)) + + env_reward += output[1] + done = output[2] + step_count += 1 + + if done: + break + + self.env.reset(seed=self.rl_seed) + return env_reward, step_count + + def mainloop_callback(self): + if self.rl_pending: + if self.interactive_run == 0: + if not self.best_pol_shown: + if not self.policy_sent: + self.rl_step = 0 + self.rl_reward = 0.0 + self.env.reset(seed=self.rl_seed) + + eval_request = ActiveRLEvalRequest() + eval_request.policy = self.rl_policy.tolist() + eval_request.weights = self.rl_weights + + self.eval_pub.publish(eval_request) + self.get_logger().info('Active RL: Called!') + self.get_logger().info('Active RL: Waiting for Eval!') + + self.policy_sent = True + + done = self.next_image(self.rl_policy, self.display_run) + + if done: + self.best_pol_shown = True + self.rl_step = 0 + self.rl_reward = 0.0 + + elif self.best_pol_shown: + if not self.eval_response_received: + pass + + if self.eval_response_received: + done = self.next_image(self.eval_policy, self.display_run) + + if done: + 
rl_response = ActiveRLResponse() + rl_response.weights = self.eval_weights + rl_response.reward = self.rl_reward + rl_response.final_step = self.rl_step + rl_response.overwrite_weight = self.overwrite_weight + + self.active_rl_pub.publish(rl_response) + + self.env.reset(seed=self.rl_seed) + + # reset flags and attributes + self.reset_eval_request() + self.reset_rl_request() + + self.rl_step = 0 + self.rl_reward = 0.0 + + self.best_pol_shown = False + self.eval_response_received = False + self.rl_pending = False + + elif self.interactive_run == 1: + if not self.policy_sent: + self.rl_step = 0 + self.rl_reward = 0.0 + self.env.reset(seed=self.rl_seed) + + eval_request = ActiveRLEvalRequest() + eval_request.policy = self.rl_policy.tolist() + eval_request.weights = self.rl_weights + + self.eval_pub.publish(eval_request) + self.get_logger().info('Active RL: Called!') + self.get_logger().info('Active RL: Waiting for Eval!') + + self.policy_sent = True + + done = self.next_image(self.rl_policy, self.display_run) + + if done: + self.rl_step = 0 + self.rl_reward = 0.0 + self.rl_pending = False + + elif self.interactive_run == 2: + env_reward, step_count = self.complete_run(self.rl_policy) + + rl_response = ActiveRLResponse() + rl_response.weights = self.rl_weights + rl_response.reward = env_reward + rl_response.final_step = step_count + if self.overwrite_weight is None: + overwrite_weight = [False] * len(self.rl_weights) + else: + overwrite_weight = self.overwrite_weight + + rl_response.overwrite_weight = overwrite_weight + + self.active_rl_pub.publish(rl_response) + + self.reset_rl_request() + self.rl_pending = False + + +def main(args=None): + rclpy.init(args=args) + + active_rl_service = ActiveRL() + + rclpy.spin(active_rl_service) + + rclpy.shutdown() + + +if __name__ == '__main__': + main()
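
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the classes introduced in GaussianModelMultiDim.py and BO2D.py fit together when driven outside ROS — the same initialize/propose/observe cycle that interactive_bo_2d.py runs through interactive_rl_2d.py. The toy_reward function and all parameter values are assumptions made for this example; in the nodes, the reward comes from an environment rollout.

import numpy as np

from active_bo_ros.BayesianOptimization.BO2D import BayesianOptimization


def toy_reward(trajectory):
    # Hypothetical stand-in for the environment rollout performed by interactive_rl_2d.py.
    return -float(np.sum(trajectory ** 2))


# 5 RBF weights per dimension, 2 dimensions -> 10 BO parameters, as in BO2D.
bo = BayesianOptimization(nr_steps=50, nr_dims=2, nr_weights=5,
                          acq="Expected Improvement", seed=0)

# Initialization phase: a few random policies, mirroring nr_init in the BO node.
for _ in range(3):
    bo.policy_model.random_weights()
    x_init = bo.policy_model.get_x().reshape(-1,)   # flat vector of nr_weights * nr_dims entries
    bo.add_observation(toy_reward(bo.policy_model.rollout()), x_init)

# One BO iteration: propose weights, roll out the 2-D policy, feed the reward back.
# Assumes the acquisition returns a flat vector of length nr_weights * nr_dims.
x_next = bo.next_observation()
bo.policy_model.set_weights(x_next)
trajectory = bo.policy_model.rollout()              # shape (nr_steps, nr_dims)
bo.add_observation(toy_reward(trajectory), x_next)

best_reward, best_x, _ = bo.get_best_result()

With acq="Preference Expected Improvement", the acquisition additionally exposes update_proposal_model(weights, weight_preference), which is what active_rl_callback calls with the user's weight_preference array after a user query.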