integrating the metrics

This commit is contained in:
Niko Feith 2023-05-26 16:10:05 +02:00
parent 08eac34c2a
commit 952ee10b67
3 changed files with 37 additions and 7 deletions

View File

@ -15,6 +15,11 @@ from active_bo_ros.ReinforcementLearning.CartPole import CartPoleEnv
from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv
from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv
from active_bo_ros.UserQuery.random_query import RandomQuery
from active_bo_ros.UserQuery.regular_query import RegularQuery
from active_bo_ros.UserQuery.improvement_query import ImprovementQuery
from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
import numpy as np import numpy as np
import time import time
@ -39,6 +44,7 @@ class ActiveBOTopic(Node):
self.active_bo_pending = False self.active_bo_pending = False
self.bo_env = None self.bo_env = None
self.bo_metric = None
self.bo_fixed_seed = False self.bo_fixed_seed = False
self.bo_nr_weights = None self.bo_nr_weights = None
self.bo_steps = None self.bo_steps = None
@ -82,6 +88,7 @@ class ActiveBOTopic(Node):
def reset_bo_request(self): def reset_bo_request(self):
self.bo_env = None self.bo_env = None
self.bo_metric = None
self.bo_fixed_seed = False self.bo_fixed_seed = False
self.bo_nr_weights = None self.bo_nr_weights = None
self.bo_steps = None self.bo_steps = None
@ -97,6 +104,7 @@ class ActiveBOTopic(Node):
self.get_logger().info('Active Bayesian Optimization request pending!') self.get_logger().info('Active Bayesian Optimization request pending!')
self.active_bo_pending = True self.active_bo_pending = True
self.bo_env = msg.env self.bo_env = msg.env
self.bo_metric = msg.metric
self.bo_fixed_seed = msg.fixed_seed self.bo_fixed_seed = msg.fixed_seed
self.bo_nr_weights = msg.nr_weights self.bo_nr_weights = msg.nr_weights
self.bo_steps = msg.max_steps self.bo_steps = msg.max_steps
@ -181,12 +189,24 @@ class ActiveBOTopic(Node):
self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}') self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}')
else: else:
if self.current_episode < self.bo_episodes: if self.current_episode < self.bo_episodes:
if np.random.uniform(0.0, 1.0, 1) < self.bo_metric_parameter: # metrics
if self.bo_metric == "RandomQuery":
user_query = RandomQuery(self.bo_metric_parameter)
else:
raise NotImplementedError
if user_query.query():
active_rl_request = ActiveRL() active_rl_request = ActiveRL()
old_policy, _, old_weights = self.BO.get_best_result() old_policy, _, old_weights = self.BO.get_best_result()
if self.seed is None:
seed = int(np.random.randint(1, 2147483647, 1)[0])
else:
seed = self.seed
active_rl_request.env = self.bo_env active_rl_request.env = self.bo_env
active_rl_request.seed = self.seed active_rl_request.seed = seed
active_rl_request.policy = old_policy.tolist() active_rl_request.policy = old_policy.tolist()
active_rl_request.weights = old_weights.tolist() active_rl_request.weights = old_weights.tolist()

View File

@ -35,7 +35,6 @@ class ActiveRLService(Node):
self.active_rl_callback, self.active_rl_callback,
1, callback_group=rl_callback_group) 1, callback_group=rl_callback_group)
self.active_rl_pending = False
self.rl_env = None self.rl_env = None
self.rl_seed = None self.rl_seed = None
self.rl_policy = None self.rl_policy = None
@ -66,7 +65,10 @@ class ActiveRLService(Node):
# RL Environments # RL Environments
self.env = None self.env = None
# State Machine Variables
self.best_pol_shown = False self.best_pol_shown = False
self.policy_sent = False
self.active_rl_pending = False
# Main loop timer object # Main loop timer object
self.mainloop_timer_period = 0.05 self.mainloop_timer_period = 0.05
@ -100,6 +102,8 @@ class ActiveRLService(Node):
self.get_logger().info('Active RL: Called!') self.get_logger().info('Active RL: Called!')
self.env.reset(seed=self.rl_seed) self.env.reset(seed=self.rl_seed)
self.active_rl_pending = True self.active_rl_pending = True
self.policy_sent = False
self.rl_step = 0
def reset_eval_request(self): def reset_eval_request(self):
self.eval_policy = None self.eval_policy = None
@ -148,9 +152,7 @@ class ActiveRLService(Node):
def mainloop_callback(self): def mainloop_callback(self):
if self.active_rl_pending: if self.active_rl_pending:
if not self.best_pol_shown: if not self.best_pol_shown:
done = self.next_image(self.rl_policy) if not self.policy_sent:
if done:
self.rl_step = 0 self.rl_step = 0
self.rl_reward = 0.0 self.rl_reward = 0.0
self.env.reset(seed=self.rl_seed) self.env.reset(seed=self.rl_seed)
@ -163,7 +165,13 @@ class ActiveRLService(Node):
self.get_logger().info('Active RL: Called!') self.get_logger().info('Active RL: Called!')
self.get_logger().info('Active RL: Waiting for Eval!') self.get_logger().info('Active RL: Waiting for Eval!')
self.policy_sent = True
done = self.next_image(self.rl_policy)
if done:
self.best_pol_shown = True self.best_pol_shown = True
self.rl_step = 0
elif self.best_pol_shown: elif self.best_pol_shown:
if not self.eval_response_received: if not self.eval_response_received:

View File

@ -11,7 +11,9 @@ setup(
package_name + '/PolicyModel', package_name + '/PolicyModel',
package_name + '/ReinforcementLearning', package_name + '/ReinforcementLearning',
package_name + '/AcquisitionFunctions', package_name + '/AcquisitionFunctions',
package_name + '/BayesianOptimization'], package_name + '/BayesianOptimization',
package_name + '/UserQuery',
],
data_files=[ data_files=[
('share/ament_index/resource_index/packages', ('share/ament_index/resource_index/packages',
['resource/' + package_name]), ['resource/' + package_name]),