prepare for 2d bo

This commit is contained in:
Niko Feith 2023-08-24 14:23:27 +02:00
parent d6195e13a1
commit 33b8093a49
7 changed files with 917 additions and 9 deletions

View File

@ -1,3 +1,3 @@
bool[] overwrite_weight
bool[] weight_preference
float64[] policy
float64[] weights

View File

@ -1,4 +1,4 @@
float64[] weights
bool[] overwrite_weight
bool[] weight_preference
uint16 final_step
float64 reward

View File

@ -0,0 +1,120 @@
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern
from active_bo_ros.PolicyModel.GaussianModelMultiDim import GaussianRBF
from active_bo_ros.AcquisitionFunctions.ExpectedImprovement import ExpectedImprovement
from active_bo_ros.AcquisitionFunctions.ProbabilityOfImprovement import ProbabilityOfImprovement
from active_bo_ros.AcquisitionFunctions.ConfidenceBound import ConfidenceBound
from active_bo_ros.AcquisitionFunctions.PreferenceExpectedImprovement import PreferenceExpectedImprovement
from sklearn.exceptions import ConvergenceWarning
import warnings
warnings.filterwarnings('ignore', category=ConvergenceWarning)
class BayesianOptimization:
def __init__(self, nr_steps, nr_dims, nr_weights, acq='ei', seed=None):
self.acq = acq
self.episode = 0
self.nr_steps = nr_steps
self.nr_dims = nr_dims
self.nr_weights = nr_weights
self.weights = self.nr_weights * self.nr_dims
self.lower_bound = -1.0
self.upper_bound = 1.0
self.seed = seed
self.X = None
self.Y = None
self.gp = None
self.policy_model = GaussianRBF(self.nr_steps, self.nr_weights, self.nr_dims,
lowerb=self.lower_bound, upperb=self.upper_bound, seed=seed)
self.acq_sample_size = 100
self.best_reward = np.empty((1, 1))
if acq == "Preference Expected Improvement":
self.acq_fun = PreferenceExpectedImprovement(self.weights,
self.acq_sample_size,
self.lower_bound,
self.upper_bound,
initial_variance=10.0,
update_variance=0.05,
seed=seed)
self.reset_bo()
def reset_bo(self):
self.gp = GaussianProcessRegressor(Matern(nu=1.5, ), n_restarts_optimizer=5) #length_scale=(1e-8, 1e5)
self.best_reward = np.empty((1, 1))
self.X = np.zeros((1, self.weights), dtype=np.float64)
self.Y = np.zeros((1, 1), dtype=np.float64)
self.episode = 0
def next_observation(self):
if self.acq == "Expected Improvement":
x_next = ExpectedImprovement(self.gp,
self.X,
self.acq_sample_size,
self.weights,
kappa=0,
seed=self.seed,
lower=self.lower_bound,
upper=self.upper_bound)
elif self.acq == "Probability of Improvement":
x_next = ProbabilityOfImprovement(self.gp,
self.X,
self.acq_sample_size,
self.weights,
kappa=0,
seed=self.seed,
lower=self.lower_bound,
upper=self.upper_bound)
elif self.acq == "Upper Confidence Bound":
x_next = ConfidenceBound(self.gp,
self.acq_sample_size,
self.weights,
beta=2.576,
seed=self.seed,
lower=self.lower_bound,
upper=self.upper_bound)
elif self.acq == "Preference Expected Improvement":
x_next = self.acq_fun.expected_improvement(self.gp,
self.X,
kappa=0)
else:
raise NotImplementedError
return x_next
def add_observation(self, reward, x):
if self.episode == 0:
self.X[0, :] = x
self.Y[0] = reward
self.best_reward[0] = np.max(self.Y)
else:
self.X = np.vstack((self.X, np.around(x, decimals=8)), dtype=np.float64)
self.Y = np.vstack((self.Y, reward), dtype=np.float64)
self.best_reward = np.vstack((self.best_reward, np.max(self.Y)), dtype=np.float64)
self.gp.fit(self.X, self.Y)
self.episode += 1
def get_best_result(self):
y_max = np.max(self.Y)
idx = np.argmax(self.Y)
x_max = self.X[idx, :]
return y_max, x_max, idx

View File

@ -0,0 +1,49 @@
import numpy as np
class GaussianRBF:
def __init__(self, nr_steps, nr_weights, nr_dims, lowerb=-1.0, upperb=1.0, seed=None):
self.nr_weights = nr_weights
self.nr_steps = nr_steps
self.nr_dims = nr_dims
self.weights = None
self.trajectory = None
self.lowerb = lowerb
self.upperb = upperb
self.rng = np.random.default_rng(seed=seed)
# initialize
self.mid_points = np.linspace(0, self.nr_steps, self.nr_weights)
if nr_weights > 1:
self.std = self.mid_points[1] / (2 * np.sqrt(2 * np.log(2))) # Full width at half maximum
else:
self.std = self.nr_steps / 2
self.reset()
def reset(self):
self.weights = np.zeros((self.nr_weights, self.nr_dims))
self.trajectory = np.zeros((self.nr_steps, self.nr_dims))
def random_weights(self):
for dim in range(self.nr_dims):
self.weights[:, dim] = self.rng.uniform(self.lowerb, self.upperb, self.nr_weights)
def rollout(self):
self.trajectory = np.zeros((self.nr_steps, self.nr_dims))
for step in range(self.nr_steps):
for weight in range(self.nr_weights):
base_fun = np.exp(-0.5 * (step - self.mid_points[weight]) ** 2 / self.std ** 2)
for dim in range(self.nr_dims):
self.trajectory[step, dim] += base_fun * self.weights[weight, dim]
return self.trajectory
def set_weights(self, x):
self.weights = x.reshape(self.nr_weights, self.nr_dims)
def get_x(self):
return self.weights.reshape(self.nr_weights * self.nr_dims, 1)

View File

@ -74,7 +74,7 @@ class ActiveBOTopic(Node):
self.rl_weights = None
self.rl_final_step = None
self.rl_reward = 0.0
self.overwrite_weight = None
self.weight_preference = None
# State Publisher
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
@ -175,7 +175,7 @@ class ActiveBOTopic(Node):
if self.rl_pending:
# self.get_logger().info('Active Reinforcement Learning response received!')
self.rl_weights = np.array(msg.weights, dtype=np.float64)
self.overwrite_weight = np.array(msg.overwrite_weight, dtype=bool)
self.weight_preference = np.array(msg.weight_preference, dtype=bool)
self.rl_final_step = msg.final_step
self.rl_reward = msg.reward
@ -195,7 +195,7 @@ class ActiveBOTopic(Node):
if self.user_asked:
self.last_user_reward = self.rl_reward
if self.bo_acq_fcn == "Preference Expected Improvement":
self.BO.acq_fun.update_proposal_model(self.rl_weights, self.overwrite_weight)
self.BO.acq_fun.update_proposal_model(self.rl_weights, self.weight_preference)
self.user_asked = False
self.rl_pending = False
@ -371,14 +371,14 @@ class ActiveBOTopic(Node):
# self.get_logger().info(f"X: {self.BO.X}")
x_next = self.BO.next_observation()
# self.get_logger().info(f'x_next: {x_next}')
# self.get_logger().info(f'overwrite: {self.overwrite_weight}')
# self.get_logger().info(f'overwrite: {self.weight_preference}')
# self.get_logger().info(f'rl_weights: {self.rl_weights}')
if self.overwrite:
if self.overwrite_weight is not None and self.rl_weights is not None:
x_next[self.overwrite_weight] = self.rl_weights[self.overwrite_weight]
if self.weight_preference is not None and self.rl_weights is not None:
x_next[self.weight_preference] = self.rl_weights[self.weight_preference]
# self.get_logger().info(f'x_next: {x_next}')
# self.get_logger().info(f'overwrite: {self.overwrite_weight}')
# self.get_logger().info(f'overwrite: {self.weight_preference}')
# self.get_logger().info(f'rl_weights: {self.rl_weights}')
# self.get_logger().info('Next Observation BO!')
self.BO.policy_model.weights = np.around(x_next, decimals=8)

View File

@ -0,0 +1,446 @@
from active_bo_msgs.msg import ActiveBORequest
from active_bo_msgs.msg import ActiveBOResponse
from active_bo_msgs.msg import ActiveRLRequest
from active_bo_msgs.msg import ActiveRLResponse
from active_bo_msgs.msg import ActiveBOState
import rclpy
from rclpy.node import Node
from rclpy.callback_groups import ReentrantCallbackGroup
from active_bo_ros.BayesianOptimization.BO2D import BayesianOptimization
from active_bo_ros.UserQuery.random_query import RandomQuery
from active_bo_ros.UserQuery.regular_query import RegularQuery
from active_bo_ros.UserQuery.improvement_query import ImprovementQuery
from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
import numpy as np
import time
import os
class ActiveBOTopic(Node):
def __init__(self):
super().__init__('active_bo_topic')
bo_callback_group = ReentrantCallbackGroup()
rl_callback_group = ReentrantCallbackGroup()
mainloop_callback_group = ReentrantCallbackGroup()
# Active Bayesian Optimization Publisher, Subscriber and Message attributes
self.active_bo_pub = self.create_publisher(ActiveBOResponse,
'active_bo_response',
1, callback_group=bo_callback_group)
self.active_bo_sub = self.create_subscription(ActiveBORequest,
'active_bo_request',
self.active_bo_callback,
1, callback_group=bo_callback_group)
self.active_bo_pending = False
self.bo_env = None
self.bo_metric = None
self.bo_fixed_seed = False
self.bo_nr_weights = None
self.bo_nr_dims = 2
self.bo_steps = 0
self.bo_episodes = 0
self.bo_runs = 0
self.bo_acq_fcn = None
self.bo_metric_parameter = None
self.bo_metric_parameter_2 = None
self.current_run = 0
self.current_episode = 0
self.seed = None
self.seed_array = None
self.save_result = False
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
self.active_rl_pub = self.create_publisher(ActiveRLRequest,
'active_rl_request',
1, callback_group=rl_callback_group)
self.active_rl_sub = self.create_subscription(ActiveRLResponse,
'active_rl_response',
self.active_rl_callback,
1, callback_group=rl_callback_group)
self.rl_pending = False
self.rl_weights = None
self.rl_final_step = None
self.rl_reward = 0.0
self.weight_preference = None
# State Publisher
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
# RL Environments and BO
self.env = None
self.BO = None
self.nr_init = 3
self.init_step = 0
self.init_pending = False
self.reward = None
self.best_reward = 0.0
self.best_pol_reward = None
self.best_policy = None
self.best_weights = None
# User Query
self.last_query = 0
self.user_asked = False
self.last_user_reward = 0.0
self.overwrite = False
# Main loop timer object
self.mainloop_timer_period = 0.1
self.mainloop = self.create_timer(self.mainloop_timer_period,
self.mainloop_callback,
callback_group=mainloop_callback_group)
def reset_bo_request(self):
self.bo_env = None
self.bo_metric = None
self.bo_fixed_seed = False
self.bo_nr_weights = None
self.bo_steps = 0
self.bo_episodes = 0
self.bo_runs = 0
self.bo_acq_fcn = None
self.bo_metric_parameter = None
self.bo_metric_parameter_2 = None
self.current_run = 0
self.current_episode = 0
self.save_result = False
self.seed_array = None
self.env = None
self.active_bo_pending = False
self.BO = None
self.overwrite = False
def active_bo_callback(self, msg):
if not self.active_bo_pending:
# self.get_logger().info('Active Bayesian Optimization request pending!')
self.active_bo_pending = True
self.bo_env = msg.env
self.bo_metric = msg.metric
self.bo_fixed_seed = msg.fixed_seed
self.bo_nr_weights = msg.nr_weights
self.bo_steps = msg.max_steps
self.bo_episodes = msg.nr_episodes
self.bo_runs = msg.nr_runs
self.bo_acq_fcn = msg.acquisition_function
self.bo_metric_parameter = msg.metric_parameter
self.bo_metric_parameter_2 = msg.metric_parameter_2
self.save_result = msg.save_result
self.seed_array = np.zeros((1, self.bo_runs))
self.overwrite = msg.overwrite
# initialize
self.reward = np.zeros((self.bo_episodes + self.nr_init - 1, self.bo_runs))
self.best_pol_reward = np.zeros((1, self.bo_runs))
self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
# set the seed
if self.bo_fixed_seed:
self.seed = int(np.random.randint(1, 2147483647, 1)[0])
# self.get_logger().info(str(self.seed))
else:
self.seed = None
# set rl environment
if self.bo_env == "Reacher":
pass
else:
raise NotImplementedError
def reset_rl_response(self):
self.rl_weights = None
self.rl_final_step = None
def active_rl_callback(self, msg):
if self.rl_pending:
# self.get_logger().info('Active Reinforcement Learning response received!')
self.rl_weights = np.array(msg.weights, dtype=np.float64)
self.weight_preference = np.array(msg.weight_preference, dtype=bool)
self.rl_final_step = msg.final_step
self.rl_reward = msg.reward
try:
self.BO.add_new_observation(self.rl_reward, self.rl_weights)
# self.get_logger().info('Active Reinforcement Learning added new observation!')
except Exception as e:
self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}')
if self.init_pending:
self.init_step += 1
if self.init_step == self.nr_init:
self.init_step = 0
self.init_pending = False
if self.user_asked:
self.last_user_reward = self.rl_reward
if self.bo_acq_fcn == "Preference Expected Improvement":
self.BO.acq_fun.update_proposal_model(self.rl_weights, self.weight_preference)
self.user_asked = False
self.rl_pending = False
# self.reset_rl_response()
def mainloop_callback(self):
if not self.active_bo_pending:
return
else:
if self.rl_pending:
return
if self.BO is None:
self.BO = BayesianOptimization(self.bo_steps,
2,
self.bo_nr_weights,
acq=self.bo_acq_fcn)
self.BO.reset_bo()
# self.BO.initialize()
self.init_pending = True
self.get_logger().info('BO Initialization is starting!')
# self.get_logger().info(f'{self.rl_pending}')
if self.init_pending:
if self.bo_fixed_seed:
seed = self.seed
else:
seed = int(np.random.randint(1, 2147483647, 1)[0])
rl_msg = ActiveRLRequest()
rl_msg.env = self.bo_env
rl_msg.seed = seed
rl_msg.display_run = False
rl_msg.interactive_run = 2
rl_msg.weights = self.BO.policy_model.random_policy().tolist()
rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
self.active_rl_pub.publish(rl_msg)
self.rl_pending = True
if self.current_run == self.bo_runs:
bo_response = ActiveBOResponse()
best_policy_idx = np.argmax(self.best_pol_reward)
bo_response.best_policy = self.best_policy[:, best_policy_idx].tolist()
bo_response.best_weights = self.best_weights[:, best_policy_idx].tolist()
# self.get_logger().info(f'Best Policy: {self.best_pol_reward}')
self.get_logger().info(f'{best_policy_idx}, {int(self.seed_array[0, best_policy_idx])}')
bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
if self.save_result:
if self.bo_env == "Mountain Car":
env = 'mc'
elif self.bo_env == "Cartpole":
env = 'cp'
elif self.bo_env == "Acrobot":
env = 'ab'
elif self.bo_env == "Pendulum":
env = 'pd'
else:
raise NotImplementedError
if self.bo_acq_fcn == "Expected Improvement":
acq = 'ei'
elif self.bo_acq_fcn == "Probability of Improvement":
acq = 'pi'
elif self.bo_acq_fcn == "Upper Confidence Bound":
acq = 'cb'
elif self.bo_acq_fcn == "Preference Expected Improvement":
acq = 'pei'
else:
raise NotImplementedError
home_dir = os.path.expanduser('~')
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
filename = env + '-' + acq + '-' + self.bo_metric + '-' \
+ str(round(self.bo_metric_parameter, 2)) + '-' \
+ str(self.bo_nr_weights) + '-' + str(time.time())
filename = filename.replace('.', '_') + '.csv'
path = os.path.join(file_path, filename)
data = self.reward
np.savetxt(path, data, delimiter=',')
active_rl_request = ActiveRLRequest()
if self.bo_fixed_seed:
seed = int(self.seed_array[0, best_policy_idx])
# self.get_logger().info(f'Used seed{seed}')
else:
seed = int(np.random.randint(1, 2147483647, 1)[0])
active_rl_request.env = self.bo_env
active_rl_request.seed = seed
active_rl_request.display_run = True
active_rl_request.interactive_run = 1
active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
self.active_rl_pub.publish(active_rl_request)
self.get_logger().info('Responding: Active BO')
self.active_bo_pub.publish(bo_response)
self.reset_bo_request()
else:
if self.init_pending:
return
else:
if self.current_episode < self.bo_episodes + self.nr_init - 1:
# metrics
if self.bo_metric == "random":
user_query = RandomQuery(self.bo_metric_parameter)
elif self.bo_metric == "regular":
user_query = RegularQuery(self.bo_metric_parameter, self.current_episode)
elif self.bo_metric == "improvement":
user_query = ImprovementQuery(self.bo_metric_parameter,
self.bo_metric_parameter_2,
self.last_query,
self.reward[:self.current_episode, self.current_run])
elif self.bo_metric == "max acquisition":
user_query = MaxAcqQuery(self.bo_metric_parameter,
self.BO.GP,
100,
self.bo_nr_weights,
acq=self.bo_acq_fcn,
X=self.BO.X)
else:
raise NotImplementedError
if user_query.query():
self.last_query = self.current_episode
self.user_asked = True
active_rl_request = ActiveRLRequest()
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
# self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
# self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}')
if self.bo_fixed_seed:
seed = self.seed
else:
seed = int(np.random.randint(1, 2147483647, 1)[0])
active_rl_request.env = self.bo_env
active_rl_request.seed = seed
active_rl_request.display_run = True
active_rl_request.interactive_run = 0
active_rl_request.policy = old_policy.tolist()
active_rl_request.weights = old_weights.tolist()
# self.get_logger().info('Calling: Active RL')
self.active_rl_pub.publish(active_rl_request)
self.rl_pending = True
else:
# if self.bo_acq_fcn == "Preference Expected Improvement":
# self.get_logger().info(f"{self.BO.acq_fun.proposal_mean}")
# self.get_logger().info(f"X: {self.BO.X}")
x_next = self.BO.next_observation()
# self.get_logger().info(f'x_next: {x_next}')
# self.get_logger().info(f'overwrite: {self.weight_preference}')
# self.get_logger().info(f'rl_weights: {self.rl_weights}')
if self.overwrite:
if self.weight_preference is not None and self.rl_weights is not None:
x_next[self.weight_preference] = self.rl_weights[self.weight_preference]
# self.get_logger().info(f'x_next: {x_next}')
# self.get_logger().info(f'overwrite: {self.weight_preference}')
# self.get_logger().info(f'rl_weights: {self.rl_weights}')
# self.get_logger().info('Next Observation BO!')
self.BO.policy_model.weights = np.around(x_next, decimals=8)
if self.bo_fixed_seed:
seed = self.seed
else:
seed = int(np.random.randint(1, 2147483647, 1)[0])
rl_msg = ActiveRLRequest()
rl_msg.env = self.bo_env
rl_msg.seed = seed
rl_msg.display_run = False
rl_msg.interactive_run = 2
rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
rl_msg.weights = x_next.tolist()
self.rl_pending = True
self.active_rl_pub.publish(rl_msg)
self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
self.get_logger().info(f'Current Episode: {self.current_episode},'
f' best reward: {self.reward[self.current_episode, self.current_run]}')
self.current_episode += 1
else:
self.best_policy[:, self.current_run], \
self.best_pol_reward[:, self.current_run], \
self.best_weights[:, self.current_run], idx = self.BO.get_best_result()
if self.current_run < self.bo_runs - 1:
self.BO = None
self.current_episode = 0
self.last_query = 0
if self.bo_fixed_seed:
self.seed_array[0, self.current_run] = self.seed
self.seed = int(np.random.randint(1, 2147483647, 1)[0])
# self.get_logger().info(f'{self.seed}')
self.current_run += 1
self.get_logger().info(f'Current Run: {self.current_run}')
# send the current states
if self.BO is not None and self.BO.Y is not None:
self.best_reward = np.max(self.BO.Y)
state_msg = ActiveBOState()
state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs
state_msg.current_episode = self.current_episode \
if self.current_episode < self.bo_episodes else self.bo_episodes
state_msg.best_reward = float(self.best_reward)
state_msg.last_user_reward = self.last_user_reward
self.state_pub.publish(state_msg)
def main(args=None):
rclpy.init(args=args)
active_bo_topic = ActiveBOTopic()
rclpy.spin(active_bo_topic)
try:
rclpy.spin(active_bo_topic)
except KeyboardInterrupt:
pass
active_bo_topic.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,293 @@
from active_bo_msgs.msg import ActiveRLRequest
from active_bo_msgs.msg import ActiveRLResponse
from active_bo_msgs.msg import ActiveRLEvalRequest
from active_bo_msgs.msg import ActiveRLEvalResponse
from active_bo_msgs.msg import ImageFeedback
import rclpy
from rclpy.node import Node
from rclpy.callback_groups import ReentrantCallbackGroup
from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
from active_bo_ros.ReinforcementLearning.CartPole import CartPoleEnv
from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv
from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv
import numpy as np
import time
import copy
class ActiveRL(Node):
def __init__(self):
super().__init__('active_rl_service')
rl_callback_group = ReentrantCallbackGroup()
topic_callback_group = ReentrantCallbackGroup()
mainloop_callback_group = ReentrantCallbackGroup()
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
self.active_rl_pub = self.create_publisher(ActiveRLResponse,
'active_rl_response',
1, callback_group=rl_callback_group)
self.active_rl_sub = self.create_subscription(ActiveRLRequest,
'active_rl_request',
self.active_rl_callback,
1, callback_group=rl_callback_group)
self.rl_env = None
self.rl_seed = None
self.rl_policy = None
self.rl_weights = None
self.rl_reward = 0.0
self.rl_step = 0
# Image publisher to publish the rgb array from the gym environment
self.image_pub = self.create_publisher(ImageFeedback,
'rl_feedback',
1, callback_group=topic_callback_group)
# Active RL Evaluation Publisher, Subscriber and Message attributes
self.eval_pub = self.create_publisher(ActiveRLEvalRequest,
'active_rl_eval_request',
1,
callback_group=topic_callback_group)
self.eval_sub = self.create_subscription(ActiveRLEvalResponse,
'active_rl_eval_response',
self.active_rl_eval_callback,
1,
callback_group=topic_callback_group)
self.eval_response_received = False
self.eval_policy = None
self.eval_weights = None
self.overwrite_weight = None
# RL Environments
self.env = None
# State Machine Variables
self.best_pol_shown = False
self.policy_sent = False
self.rl_pending = False
self.interactive_run = 0
self.display_run = False
# Main loop timer object
self.mainloop_timer_period = 0.1
self.mainloop = self.create_timer(self.mainloop_timer_period,
self.mainloop_callback,
callback_group=mainloop_callback_group)
def reset_rl_request(self):
self.rl_env = None
self.rl_seed = None
self.rl_policy = None
self.rl_weights = None
self.interactive_run = 0
self.display_run = False
def active_rl_callback(self, msg):
self.rl_env = msg.env
self.rl_seed = msg.seed
self.display_run = msg.display_run
self.rl_policy = np.array(msg.policy, dtype=np.float64)
self.rl_weights = msg.weights
self.interactive_run = msg.interactive_run
if self.rl_env == "Mountain Car":
self.env = Continuous_MountainCarEnv(render_mode="rgb_array")
elif self.rl_env == "Cartpole":
self.env = CartPoleEnv(render_mode="rgb_array")
elif self.rl_env == "Acrobot":
self.env = AcrobotEnv(render_mode="rgb_array")
elif self.rl_env == "Pendulum":
self.env = PendulumEnv(render_mode="rgb_array")
else:
raise NotImplementedError
# self.get_logger().info('Active RL: Called!')
self.env.reset(seed=self.rl_seed)
self.rl_pending = True
self.policy_sent = False
self.rl_step = 0
def reset_eval_request(self):
self.eval_policy = None
self.eval_weights = None
def active_rl_eval_callback(self, msg):
self.eval_policy = np.array(msg.policy, dtype=np.float64)
self.eval_weights = msg.weights
self.overwrite_weight = msg.overwrite_weight
self.get_logger().info('Active RL Eval: Responded!')
self.env.reset(seed=self.rl_seed)
self.eval_response_received = True
def next_image(self, policy, display_run):
action = policy[self.rl_step]
action_clipped = action.clip(min=-1.0, max=1.0)
output = self.env.step(action_clipped.astype(np.float64))
self.rl_reward += output[1]
done = output[2]
self.rl_step += 1
if display_run:
rgb_array = self.env.render()
rgb_shape = rgb_array.shape
red = rgb_array[:, :, 0].flatten().tolist()
green = rgb_array[:, :, 1].flatten().tolist()
blue = rgb_array[:, :, 2].flatten().tolist()
feedback_msg = ImageFeedback()
feedback_msg.height = rgb_shape[0]
feedback_msg.width = rgb_shape[1]
feedback_msg.current_time = self.rl_step
feedback_msg.red = red
feedback_msg.green = green
feedback_msg.blue = blue
self.image_pub.publish(feedback_msg)
if not done and self.rl_step == len(policy):
done = True
return done
def complete_run(self, policy):
env_reward = 0.0
step_count = 0
self.env.reset(seed=self.rl_seed)
for i in range(len(policy)):
action = policy[i]
action_clipped = action.clip(min=-1.0, max=1.0)
output = self.env.step(action_clipped.astype(np.float64))
env_reward += output[1]
done = output[2]
step_count += 1
if done:
break
self.env.reset(seed=self.rl_seed)
return env_reward, step_count
def mainloop_callback(self):
if self.rl_pending:
if self.interactive_run == 0:
if not self.best_pol_shown:
if not self.policy_sent:
self.rl_step = 0
self.rl_reward = 0.0
self.env.reset(seed=self.rl_seed)
eval_request = ActiveRLEvalRequest()
eval_request.policy = self.rl_policy.tolist()
eval_request.weights = self.rl_weights
self.eval_pub.publish(eval_request)
self.get_logger().info('Active RL: Called!')
self.get_logger().info('Active RL: Waiting for Eval!')
self.policy_sent = True
done = self.next_image(self.rl_policy, self.display_run)
if done:
self.best_pol_shown = True
self.rl_step = 0
self.rl_reward = 0.0
elif self.best_pol_shown:
if not self.eval_response_received:
pass
if self.eval_response_received:
done = self.next_image(self.eval_policy, self.display_run)
if done:
rl_response = ActiveRLResponse()
rl_response.weights = self.eval_weights
rl_response.reward = self.rl_reward
rl_response.final_step = self.rl_step
rl_response.overwrite_weight = self.overwrite_weight
self.active_rl_pub.publish(rl_response)
self.env.reset(seed=self.rl_seed)
# reset flags and attributes
self.reset_eval_request()
self.reset_rl_request()
self.rl_step = 0
self.rl_reward = 0.0
self.best_pol_shown = False
self.eval_response_received = False
self.rl_pending = False
elif self.interactive_run == 1:
if not self.policy_sent:
self.rl_step = 0
self.rl_reward = 0.0
self.env.reset(seed=self.rl_seed)
eval_request = ActiveRLEvalRequest()
eval_request.policy = self.rl_policy.tolist()
eval_request.weights = self.rl_weights
self.eval_pub.publish(eval_request)
self.get_logger().info('Active RL: Called!')
self.get_logger().info('Active RL: Waiting for Eval!')
self.policy_sent = True
done = self.next_image(self.rl_policy, self.display_run)
if done:
self.rl_step = 0
self.rl_reward = 0.0
self.rl_pending = False
elif self.interactive_run == 2:
env_reward, step_count = self.complete_run(self.rl_policy)
rl_response = ActiveRLResponse()
rl_response.weights = self.rl_weights
rl_response.reward = env_reward
rl_response.final_step = step_count
if self.overwrite_weight is None:
overwrite_weight = [False] * len(self.rl_weights)
else:
overwrite_weight = self.overwrite_weight
rl_response.overwrite_weight = overwrite_weight
self.active_rl_pub.publish(rl_response)
self.reset_rl_request()
self.rl_pending = False
def main(args=None):
rclpy.init(args=args)
active_rl_service = ActiveRL()
rclpy.spin(active_rl_service)
rclpy.shutdown()
if __name__ == '__main__':
main()