integrating the metrics
parent 08eac34c2a
commit 952ee10b67
@@ -15,6 +15,11 @@ from active_bo_ros.ReinforcementLearning.CartPole import CartPoleEnv
 from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv
 from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv
 
+from active_bo_ros.UserQuery.random_query import RandomQuery
+from active_bo_ros.UserQuery.regular_query import RegularQuery
+from active_bo_ros.UserQuery.improvement_query import ImprovementQuery
+from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
+
 import numpy as np
 import time
 
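For orientation: the new imports suggest that each user-query metric is a small class constructed with the metric parameter and exposing a query() method that returns a bool. A minimal sketch of what random_query.RandomQuery might look like, assuming it simply wraps the uniform coin flip that a later hunk removes from ActiveBOTopic (the real implementation in active_bo_ros/UserQuery may differ):

import numpy as np

class RandomQuery:
    """Sketch only: ask the user with a fixed probability `threshold` each episode."""

    def __init__(self, threshold):
        self.threshold = threshold

    def query(self):
        # mirrors the replaced check: np.random.uniform(0.0, 1.0, 1) < self.bo_metric_parameter
        return bool(np.random.uniform(0.0, 1.0) < self.threshold)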
@@ -39,6 +44,7 @@ class ActiveBOTopic(Node):
 
         self.active_bo_pending = False
         self.bo_env = None
+        self.bo_metric = None
         self.bo_fixed_seed = False
         self.bo_nr_weights = None
         self.bo_steps = None
@@ -82,6 +88,7 @@ class ActiveBOTopic(Node):
 
     def reset_bo_request(self):
         self.bo_env = None
+        self.bo_metric = None
         self.bo_fixed_seed = False
         self.bo_nr_weights = None
         self.bo_steps = None
@@ -97,6 +104,7 @@ class ActiveBOTopic(Node):
         self.get_logger().info('Active Bayesian Optimization request pending!')
         self.active_bo_pending = True
         self.bo_env = msg.env
+        self.bo_metric = msg.metric
         self.bo_fixed_seed = msg.fixed_seed
         self.bo_nr_weights = msg.nr_weights
         self.bo_steps = msg.max_steps
@@ -181,12 +189,24 @@ class ActiveBOTopic(Node):
                     self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}')
             else:
                 if self.current_episode < self.bo_episodes:
-                    if np.random.uniform(0.0, 1.0, 1) < self.bo_metric_parameter:
+                    # metrics
+                    if self.bo_metric == "RandomQuery":
+                        user_query = RandomQuery(self.bo_metric_parameter)
+
+                    else:
+                        raise NotImplementedError
+
+                    if user_query.query():
                         active_rl_request = ActiveRL()
                         old_policy, _, old_weights = self.BO.get_best_result()
+
+                        if self.seed is None:
+                            seed = int(np.random.randint(1, 2147483647, 1)[0])
+                        else:
+                            seed = self.seed
+
                         active_rl_request.env = self.bo_env
-                        active_rl_request.seed = self.seed
+                        active_rl_request.seed = seed
                         active_rl_request.policy = old_policy.tolist()
                         active_rl_request.weights = old_weights.tolist()
 
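A note on the added seed handling: np.random.randint with size 1 returns a NumPy array, and ROS 2 message fields generally require native Python ints, which is presumably why the value is unpacked and cast before being assigned to active_rl_request.seed. The upper bound 2147483647 is the int32 maximum, consistent with a 32-bit seed field. A small standalone illustration:

import numpy as np

raw = np.random.randint(1, 2147483647, 1)  # ndarray of shape (1,), integer dtype
seed = int(raw[0])                         # plain Python int, suitable for a ROS 2 int32 field
assert isinstance(seed, int) and 1 <= seed < 2147483647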
@@ -35,7 +35,6 @@ class ActiveRLService(Node):
             self.active_rl_callback,
             1, callback_group=rl_callback_group)
 
-        self.active_rl_pending = False
         self.rl_env = None
         self.rl_seed = None
         self.rl_policy = None
@@ -66,7 +65,10 @@ class ActiveRLService(Node):
         # RL Environments
         self.env = None
 
+        # State Machine Variables
         self.best_pol_shown = False
+        self.policy_sent = False
+        self.active_rl_pending = False
 
         # Main loop timer object
         self.mainloop_timer_period = 0.05
@@ -100,6 +102,8 @@ class ActiveRLService(Node):
         self.get_logger().info('Active RL: Called!')
         self.env.reset(seed=self.rl_seed)
         self.active_rl_pending = True
+        self.policy_sent = False
+        self.rl_step = 0
 
     def reset_eval_request(self):
         self.eval_policy = None
@@ -148,9 +152,7 @@ class ActiveRLService(Node):
     def mainloop_callback(self):
         if self.active_rl_pending:
             if not self.best_pol_shown:
-                done = self.next_image(self.rl_policy)
-
-                if done:
+                if not self.policy_sent:
                     self.rl_step = 0
                     self.rl_reward = 0.0
                     self.env.reset(seed=self.rl_seed)
@@ -163,7 +165,13 @@ class ActiveRLService(Node):
                     self.get_logger().info('Active RL: Called!')
                     self.get_logger().info('Active RL: Waiting for Eval!')
 
+                    self.policy_sent = True
+
+                done = self.next_image(self.rl_policy)
+
+                if done:
                     self.best_pol_shown = True
+                    self.rl_step = 0
 
             elif self.best_pol_shown:
                 if not self.eval_response_received:
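Read together, the two mainloop hunks above turn the rollout display into a small state machine: policy_sent guards the one-time environment reset, next_image() then advances the rollout on every timer tick, and done flips best_pol_shown so the node moves on to waiting for the user's evaluation. Pieced together, the reworked part of mainloop_callback reads roughly as follows (lines outside the diff context are elided, so this is a reading aid rather than the exact file contents):

def mainloop_callback(self):
    if self.active_rl_pending:
        if not self.best_pol_shown:
            if not self.policy_sent:
                # one-time setup per request: reset counters and the environment
                self.rl_step = 0
                self.rl_reward = 0.0
                self.env.reset(seed=self.rl_seed)
                # ... unchanged logging/publishing code elided ...
                self.policy_sent = True

            # advance the rollout by one frame on every timer tick
            done = self.next_image(self.rl_policy)

            if done:
                self.best_pol_shown = True
                self.rl_step = 0

        elif self.best_pol_shown:
            if not self.eval_response_received:
                # ... unchanged: wait for the user's evaluation ...
                pass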
@@ -11,7 +11,9 @@ setup(
               package_name + '/PolicyModel',
               package_name + '/ReinforcementLearning',
               package_name + '/AcquisitionFunctions',
-              package_name + '/BayesianOptimization'],
+              package_name + '/BayesianOptimization',
+              package_name + '/UserQuery',
+              ],
     data_files=[
         ('share/ament_index/resource_index/packages',
             ['resource/' + package_name]),
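With the UserQuery directory added to the packages list, its modules are installed alongside the rest of active_bo_ros; the workspace presumably needs to be rebuilt before the new active_bo_ros.UserQuery imports resolve at runtime.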