moved the Evaluation from BayesianOptimization.py to the active_rl_topic.py due to inconsistency in the evaluation of rl envs
This commit is contained in:
parent
1b9e099696
commit
e99b131ee9
@ -1,5 +1,6 @@
|
||||
string env
|
||||
uint32 seed
|
||||
bool final_run
|
||||
bool display_run
|
||||
uint8 interactive_run
|
||||
float64[] policy
|
||||
float64[] weights
|
@ -38,9 +38,11 @@ class BayesianOptimization:
|
||||
|
||||
def reset_bo(self):
|
||||
self.counter_array = np.empty((1, 1))
|
||||
self.GP = None
|
||||
self.GP = GaussianProcessRegressor(Matern(nu=1.5, length_scale_bounds=(1e-8, 1e5)), n_restarts_optimizer=5, )
|
||||
self.episode = 0
|
||||
self.best_reward = np.empty((1, 1))
|
||||
self.X = np.zeros((1, self.nr_policy_weights), dtype=np.float64)
|
||||
self.Y = np.zeros((1, 1), dtype=np.float64)
|
||||
|
||||
def runner(self, policy, seed=None):
|
||||
env_reward = 0.0
|
||||
@ -71,9 +73,6 @@ class BayesianOptimization:
|
||||
self.env.reset(seed=seed)
|
||||
self.reset_bo()
|
||||
|
||||
self.X = np.zeros((self.nr_init, self.nr_policy_weights), dtype=np.float64)
|
||||
self.Y = np.zeros((self.nr_init, 1), dtype=np.float64)
|
||||
|
||||
for i in range(self.nr_init):
|
||||
self.policy_model.random_policy()
|
||||
self.X[i, :] = self.policy_model.weights.T
|
||||
@ -141,16 +140,16 @@ class BayesianOptimization:
|
||||
return step_count
|
||||
|
||||
def add_new_observation(self, reward, x_new):
|
||||
self.X = np.vstack((self.X, np.around(x_new, decimals=8)), dtype=np.float64)
|
||||
self.Y = np.vstack((self.Y, reward), dtype=np.float64)
|
||||
|
||||
self.GP.fit(self.X, self.Y)
|
||||
|
||||
if self.episode == 0:
|
||||
self.X[0, :] = x_new
|
||||
self.Y[0] = reward
|
||||
self.best_reward[0] = np.max(self.Y)
|
||||
else:
|
||||
self.X = np.vstack((self.X, np.around(x_new, decimals=8)), dtype=np.float64)
|
||||
self.Y = np.vstack((self.Y, reward), dtype=np.float64)
|
||||
self.best_reward = np.vstack((self.best_reward, np.max(self.Y)), dtype=np.float64)
|
||||
|
||||
self.GP.fit(self.X, self.Y)
|
||||
self.episode += 1
|
||||
|
||||
def get_best_result(self):
|
||||
|
@ -25,6 +25,7 @@ class GaussianRBF:
|
||||
|
||||
def random_policy(self):
|
||||
self.weights = np.around(self.rng.uniform(self.low, self.upper, self.nr_weights), decimals=8)
|
||||
return self.weights.T
|
||||
|
||||
def rollout(self):
|
||||
self.policy = np.zeros((self.nr_steps, 1))
|
||||
|
@ -68,8 +68,9 @@ class ActiveRLService(Node):
|
||||
# State Machine Variables
|
||||
self.best_pol_shown = False
|
||||
self.policy_sent = False
|
||||
self.active_rl_pending = False
|
||||
self.final_run = False
|
||||
self.rl_pending = False
|
||||
self.interactive_run = False
|
||||
self.display_run = False
|
||||
|
||||
# Main loop timer object
|
||||
self.mainloop_timer_period = 0.05
|
||||
@ -86,9 +87,10 @@ class ActiveRLService(Node):
|
||||
def active_rl_callback(self, msg):
|
||||
self.rl_env = msg.env
|
||||
self.rl_seed = msg.seed
|
||||
self.display_run = msg.display_run
|
||||
self.rl_policy = np.array(msg.policy, dtype=np.float64)
|
||||
self.rl_weights = msg.weights
|
||||
self.final_run = msg.final_run
|
||||
self.interactive_run = msg.interactive_run
|
||||
|
||||
if self.rl_env == "Mountain Car":
|
||||
self.env = Continuous_MountainCarEnv(render_mode="rgb_array")
|
||||
@ -103,7 +105,7 @@ class ActiveRLService(Node):
|
||||
|
||||
self.get_logger().info('Active RL: Called!')
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
self.active_rl_pending = True
|
||||
self.rl_pending = True
|
||||
self.policy_sent = False
|
||||
self.rl_step = 0
|
||||
|
||||
@ -119,7 +121,7 @@ class ActiveRLService(Node):
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
self.eval_response_received = True
|
||||
|
||||
def next_image(self, policy):
|
||||
def next_image(self, policy, display_run):
|
||||
action = policy[self.rl_step]
|
||||
action_clipped = action.clip(min=-1.0, max=1.0)
|
||||
output = self.env.step(action_clipped.astype(np.float64))
|
||||
@ -128,6 +130,7 @@ class ActiveRLService(Node):
|
||||
done = output[2]
|
||||
self.rl_step += 1
|
||||
|
||||
if display_run:
|
||||
rgb_array = self.env.render()
|
||||
rgb_shape = rgb_array.shape
|
||||
|
||||
@ -152,8 +155,8 @@ class ActiveRLService(Node):
|
||||
return done
|
||||
|
||||
def mainloop_callback(self):
|
||||
if self.active_rl_pending:
|
||||
if not self.final_run:
|
||||
if self.rl_pending:
|
||||
if self.interactive_run == 0:
|
||||
if not self.best_pol_shown:
|
||||
if not self.policy_sent:
|
||||
self.rl_step = 0
|
||||
@ -170,7 +173,7 @@ class ActiveRLService(Node):
|
||||
|
||||
self.policy_sent = True
|
||||
|
||||
done = self.next_image(self.rl_policy)
|
||||
done = self.next_image(self.rl_policy, self.display_run)
|
||||
|
||||
if done:
|
||||
self.best_pol_shown = True
|
||||
@ -182,7 +185,7 @@ class ActiveRLService(Node):
|
||||
pass
|
||||
|
||||
if self.eval_response_received:
|
||||
done = self.next_image(self.eval_policy)
|
||||
done = self.next_image(self.eval_policy, self.display_run)
|
||||
|
||||
if done:
|
||||
rl_response = ActiveRLResponse()
|
||||
@ -203,8 +206,8 @@ class ActiveRLService(Node):
|
||||
|
||||
self.best_pol_shown = False
|
||||
self.eval_response_received = False
|
||||
self.active_rl_pending = False
|
||||
else:
|
||||
self.rl_pending = False
|
||||
elif self.interactive_run == 1:
|
||||
if not self.policy_sent:
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
@ -220,13 +223,39 @@ class ActiveRLService(Node):
|
||||
|
||||
self.policy_sent = True
|
||||
|
||||
done = self.next_image(self.rl_policy)
|
||||
done = self.next_image(self.rl_policy, self.display_run)
|
||||
|
||||
if done:
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
self.final_run = False
|
||||
self.active_rl_pending = False
|
||||
self.rl_pending = False
|
||||
|
||||
elif self.interactive_run == 2:
|
||||
if not self.policy_sent:
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
self.policy_sent = True
|
||||
done = self.next_image(self.rl_policy, self.display_run)
|
||||
|
||||
if done:
|
||||
rl_response = ActiveRLResponse()
|
||||
rl_response.weights = self.eval_weights
|
||||
rl_response.reward = self.rl_reward
|
||||
rl_response.final_step = self.rl_step
|
||||
|
||||
self.active_rl_pub.publish(rl_response)
|
||||
|
||||
# reset flags and attributes
|
||||
self.reset_eval_request()
|
||||
self.reset_rl_request()
|
||||
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
|
||||
self.rl_pending = False
|
||||
|
||||
|
||||
|
||||
|
||||
def main(args=None):
|
||||
|
399
src/active_bo_ros/active_bo_ros/interactive_bo.py
Normal file
399
src/active_bo_ros/active_bo_ros/interactive_bo.py
Normal file
@ -0,0 +1,399 @@
|
||||
from active_bo_msgs.msg import ActiveBORequest
|
||||
from active_bo_msgs.msg import ActiveBOResponse
|
||||
from active_bo_msgs.msg import ActiveRL
|
||||
from active_bo_msgs.msg import ActiveRLResponse
|
||||
from active_bo_msgs.msg import ActiveBOState
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from rclpy.callback_groups import ReentrantCallbackGroup
|
||||
|
||||
from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOptimization
|
||||
|
||||
from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
|
||||
from active_bo_ros.ReinforcementLearning.CartPole import CartPoleEnv
|
||||
from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv
|
||||
from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv
|
||||
|
||||
from active_bo_ros.UserQuery.random_query import RandomQuery
|
||||
from active_bo_ros.UserQuery.regular_query import RegularQuery
|
||||
from active_bo_ros.UserQuery.improvement_query import ImprovementQuery
|
||||
from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import os
|
||||
|
||||
|
||||
class ActiveBOTopic(Node):
|
||||
def __init__(self):
|
||||
super().__init__('active_bo_topic')
|
||||
|
||||
bo_callback_group = ReentrantCallbackGroup()
|
||||
rl_callback_group = ReentrantCallbackGroup()
|
||||
mainloop_callback_group = ReentrantCallbackGroup()
|
||||
|
||||
# Active Bayesian Optimization Publisher, Subscriber and Message attributes
|
||||
self.active_bo_pub = self.create_publisher(ActiveBOResponse,
|
||||
'active_bo_response',
|
||||
1, callback_group=bo_callback_group)
|
||||
|
||||
self.active_bo_sub = self.create_subscription(ActiveBORequest,
|
||||
'active_bo_request',
|
||||
self.active_bo_callback,
|
||||
1, callback_group=bo_callback_group)
|
||||
|
||||
self.active_bo_pending = False
|
||||
self.bo_env = None
|
||||
self.bo_metric = None
|
||||
self.bo_fixed_seed = False
|
||||
self.bo_nr_weights = None
|
||||
self.bo_steps = 0
|
||||
self.bo_episodes = 0
|
||||
self.bo_runs = 0
|
||||
self.bo_acq_fcn = None
|
||||
self.bo_metric_parameter = None
|
||||
self.current_run = 0
|
||||
self.current_episode = 0
|
||||
self.seed = None
|
||||
self.seed_array = None
|
||||
self.save_result = False
|
||||
|
||||
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
|
||||
self.active_rl_pub = self.create_publisher(ActiveRL,
|
||||
'active_rl_request',
|
||||
1, callback_group=rl_callback_group)
|
||||
self.active_rl_sub = self.create_subscription(ActiveRLResponse,
|
||||
'active_rl_response',
|
||||
self.active_rl_callback,
|
||||
1, callback_group=rl_callback_group)
|
||||
|
||||
self.rl_pending = False
|
||||
self.rl_weights = None
|
||||
self.rl_final_step = None
|
||||
self.rl_reward = 0.0
|
||||
|
||||
# State Publisher
|
||||
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
|
||||
|
||||
# RL Environments and BO
|
||||
self.env = None
|
||||
|
||||
self.BO = None
|
||||
self.nr_init = 3
|
||||
self.init_step = 0
|
||||
self.init_pending = False
|
||||
self.reward = None
|
||||
self.best_reward = 0.0
|
||||
self.best_pol_reward = None
|
||||
self.best_policy = None
|
||||
self.best_weights = None
|
||||
|
||||
# Main loop timer object
|
||||
self.mainloop_timer_period = 0.1
|
||||
self.mainloop = self.create_timer(self.mainloop_timer_period,
|
||||
self.mainloop_callback,
|
||||
callback_group=mainloop_callback_group)
|
||||
|
||||
def reset_bo_request(self):
|
||||
self.bo_env = None
|
||||
self.bo_metric = None
|
||||
self.bo_fixed_seed = False
|
||||
self.bo_nr_weights = None
|
||||
self.bo_steps = 0
|
||||
self.bo_episodes = 0
|
||||
self.bo_runs = 0
|
||||
self.bo_acq_fcn = None
|
||||
self.bo_metric_parameter = None
|
||||
self.current_run = 0
|
||||
self.current_episode = 0
|
||||
self.save_result = False
|
||||
self.seed_array = None
|
||||
|
||||
def active_bo_callback(self, msg):
|
||||
if not self.active_bo_pending:
|
||||
self.get_logger().info('Active Bayesian Optimization request pending!')
|
||||
self.active_bo_pending = True
|
||||
self.bo_env = msg.env
|
||||
self.bo_metric = msg.metric
|
||||
self.bo_fixed_seed = msg.fixed_seed
|
||||
self.bo_nr_weights = msg.nr_weights
|
||||
self.bo_steps = msg.max_steps
|
||||
self.bo_episodes = msg.nr_episodes
|
||||
self.bo_runs = msg.nr_runs
|
||||
self.bo_acq_fcn = msg.acquisition_function
|
||||
self.bo_metric_parameter = msg.metric_parameter
|
||||
self.save_result = msg.save_result
|
||||
self.seed_array = np.zeros((1, self.bo_runs))
|
||||
|
||||
# initialize
|
||||
self.reward = np.zeros((self.bo_episodes, self.bo_runs))
|
||||
self.best_pol_reward = np.zeros((1, self.bo_runs))
|
||||
self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
|
||||
self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
|
||||
|
||||
# set the seed
|
||||
if self.bo_fixed_seed:
|
||||
self.seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
self.get_logger().info(str(self.seed))
|
||||
else:
|
||||
self.seed = None
|
||||
|
||||
def reset_rl_response(self):
|
||||
self.rl_weights = None
|
||||
self.rl_final_step = None
|
||||
|
||||
def active_rl_callback(self, msg):
|
||||
if self.rl_pending:
|
||||
self.get_logger().info('Active Reinforcement Learning response received!')
|
||||
self.rl_weights = msg.weights
|
||||
self.rl_final_step = msg.final_step
|
||||
self.rl_reward = msg.reward
|
||||
|
||||
try:
|
||||
self.BO.add_new_observation(self.rl_reward, self.rl_weights)
|
||||
self.get_logger().info('Active Reinforcement Learning added new observation!')
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}')
|
||||
|
||||
if self.init_pending:
|
||||
self.init_step += 1
|
||||
if self.init_step == self.nr_init:
|
||||
self.init_step = 0
|
||||
self.init_pending = False
|
||||
|
||||
self.rl_pending = False
|
||||
self.reset_rl_response()
|
||||
|
||||
def mainloop_callback(self):
|
||||
if self.active_bo_pending:
|
||||
|
||||
# set rl environment
|
||||
if self.bo_env == "Mountain Car":
|
||||
self.env = Continuous_MountainCarEnv()
|
||||
elif self.bo_env == "Cartpole":
|
||||
self.env = CartPoleEnv()
|
||||
elif self.bo_env == "Acrobot":
|
||||
self.env = AcrobotEnv()
|
||||
elif self.bo_env == "Pendulum":
|
||||
self.env = PendulumEnv()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.BO is None:
|
||||
self.BO = BayesianOptimization(self.env,
|
||||
self.bo_steps,
|
||||
nr_init=self.nr_init,
|
||||
acq=self.bo_acq_fcn,
|
||||
nr_weights=self.bo_nr_weights)
|
||||
|
||||
# self.BO.initialize()
|
||||
self.init_pending = True
|
||||
|
||||
if self.init_pending and not self.rl_pending:
|
||||
|
||||
if self.bo_fixed_seed:
|
||||
seed = self.seed
|
||||
else:
|
||||
seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
|
||||
rl_msg = ActiveRL()
|
||||
rl_msg.env = self.bo_env
|
||||
rl_msg.seed = seed
|
||||
rl_msg.display_run = False
|
||||
rl_msg.interactive_run = 2
|
||||
rl_msg.weights = self.BO.policy_model.random_policy()
|
||||
rl_msg.policy = self.BO.policy_model.rollout()
|
||||
self.rl_pending = True
|
||||
|
||||
if self.current_run == self.bo_runs:
|
||||
bo_response = ActiveBOResponse()
|
||||
|
||||
best_policy_idx = np.argmax(self.best_pol_reward)
|
||||
bo_response.best_policy = self.best_policy[:, best_policy_idx].tolist()
|
||||
bo_response.best_weights = self.best_weights[:, best_policy_idx].tolist()
|
||||
|
||||
self.get_logger().info(f'Best Policy: {self.best_pol_reward}')
|
||||
|
||||
self.get_logger().info(f'{best_policy_idx}, {int(self.seed_array[0, best_policy_idx])}')
|
||||
|
||||
bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
|
||||
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
|
||||
|
||||
if self.save_result:
|
||||
if self.bo_env == "Mountain Car":
|
||||
env = 'mc'
|
||||
elif self.bo_env == "Cartpole":
|
||||
env = 'cp'
|
||||
elif self.bo_env == "Acrobot":
|
||||
env = 'ab'
|
||||
elif self.bo_env == "Pendulum":
|
||||
env = 'pd'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.bo_acq_fcn == "Expected Improvement":
|
||||
acq = 'ei'
|
||||
elif self.bo_acq_fcn == "Probability of Improvement":
|
||||
acq = 'pi'
|
||||
elif self.bo_acq_fcn == "Upper Confidence Bound":
|
||||
acq = 'cb'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
home_dir = os.path.expanduser('~')
|
||||
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
|
||||
filename = env + '-' + acq + '-' + self.bo_metric + '-' \
|
||||
+ str(round(self.bo_metric_parameter, 2)) + '-' \
|
||||
+ str(self.bo_nr_weights) + '-' + str(time.time())
|
||||
filename = filename.replace('.', '_') + '.csv'
|
||||
path = os.path.join(file_path, filename)
|
||||
|
||||
data = self.reward
|
||||
|
||||
np.savetxt(path, data, delimiter=',')
|
||||
|
||||
active_rl_request = ActiveRL()
|
||||
|
||||
if self.bo_fixed_seed:
|
||||
seed = int(self.seed_array[0, best_policy_idx])
|
||||
self.get_logger().info(f'Used seed{seed}')
|
||||
else:
|
||||
seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
|
||||
active_rl_request.env = self.bo_env
|
||||
active_rl_request.seed = seed
|
||||
active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
|
||||
active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
|
||||
active_rl_request.interactive_run = 1
|
||||
|
||||
self.active_rl_pub.publish(active_rl_request)
|
||||
|
||||
self.get_logger().info('Responding: Active BO')
|
||||
self.active_bo_pub.publish(bo_response)
|
||||
self.reset_bo_request()
|
||||
self.active_bo_pending = False
|
||||
self.BO = None
|
||||
|
||||
else:
|
||||
if self.rl_pending:
|
||||
pass
|
||||
else:
|
||||
if self.init_pending:
|
||||
pass
|
||||
elif self.current_episode < self.bo_episodes:
|
||||
# metrics
|
||||
if self.bo_metric == "random":
|
||||
user_query = RandomQuery(self.bo_metric_parameter)
|
||||
|
||||
elif self.bo_metric == "regular":
|
||||
user_query = RegularQuery(self.bo_metric_parameter, self.current_episode)
|
||||
|
||||
elif self.bo_metric == "max acquisition":
|
||||
user_query = MaxAcqQuery(self.bo_metric_parameter,
|
||||
self.BO.GP,
|
||||
100,
|
||||
self.bo_nr_weights,
|
||||
acq=self.bo_acq_fcn,
|
||||
X=self.BO.X)
|
||||
|
||||
elif self.bo_metric == "improvement":
|
||||
user_query = ImprovementQuery(self.bo_metric_parameter, 10)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if user_query.query():
|
||||
active_rl_request = ActiveRL()
|
||||
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
|
||||
|
||||
self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
|
||||
self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}')
|
||||
|
||||
if self.bo_fixed_seed:
|
||||
seed = self.seed
|
||||
else:
|
||||
seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
|
||||
active_rl_request.env = self.bo_env
|
||||
active_rl_request.seed = seed
|
||||
active_rl_request.display_run = True
|
||||
active_rl_request.policy = old_policy.tolist()
|
||||
active_rl_request.weights = old_weights.tolist()
|
||||
active_rl_request.interactive_run = 0
|
||||
|
||||
self.get_logger().info('Calling: Active RL')
|
||||
self.active_rl_pub.publish(active_rl_request)
|
||||
self.rl_pending = True
|
||||
|
||||
else:
|
||||
x_next = self.BO.next_observation()
|
||||
self.BO.policy_model.weights = np.around(x_next, decimals=8)
|
||||
if self.bo_fixed_seed:
|
||||
seed = self.seed
|
||||
else:
|
||||
seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
|
||||
rl_msg = ActiveRL()
|
||||
rl_msg.env = self.bo_env
|
||||
rl_msg.seed = seed
|
||||
rl_msg.display_run = False
|
||||
rl_msg.interactive_run = 2
|
||||
rl_msg.weights = x_next
|
||||
rl_msg.policy = self.BO.policy_model.rollout()
|
||||
self.rl_pending = True
|
||||
|
||||
self.current_episode += 1
|
||||
# self.get_logger().info(f'Current Episode: {self.current_episode}')
|
||||
else:
|
||||
self.best_policy[:, self.current_run], \
|
||||
self.best_pol_reward[:, self.current_run], \
|
||||
self.best_weights[:, self.current_run], idx = self.BO.get_best_result()
|
||||
|
||||
self.get_logger().info(f'best idx: {idx}')
|
||||
|
||||
self.reward[:, self.current_run] = self.BO.best_reward.T
|
||||
|
||||
self.BO = None
|
||||
|
||||
self.current_episode = 0
|
||||
if self.bo_fixed_seed:
|
||||
self.seed_array[0, self.current_run] = self.seed
|
||||
self.seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
self.get_logger().info(f'{self.seed}')
|
||||
self.current_run += 1
|
||||
self.get_logger().info(f'Current Run: {self.current_run}')
|
||||
|
||||
# send the current states
|
||||
|
||||
if self.BO is not None and self.BO.Y is not None:
|
||||
self.best_reward = np.max(self.BO.Y)
|
||||
|
||||
state_msg = ActiveBOState()
|
||||
state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs
|
||||
state_msg.current_episode = self.current_episode + 1 \
|
||||
if self.current_episode < self.bo_episodes else self.bo_episodes
|
||||
state_msg.best_reward = float(self.best_reward)
|
||||
state_msg.last_user_reward = float(self.rl_reward)
|
||||
self.state_pub.publish(state_msg)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
|
||||
active_bo_topic = ActiveBOTopic()
|
||||
|
||||
rclpy.spin(active_bo_topic)
|
||||
|
||||
try:
|
||||
rclpy.spin(active_bo_topic)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
active_bo_topic.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user