usability improvements and saving function
testing for bo and manual usage completed
This commit is contained in:
parent
c5f00a4069
commit
55b21d667a
@ -30,6 +30,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
|||||||
"msg/ActiveRLResponse.msg"
|
"msg/ActiveRLResponse.msg"
|
||||||
"msg/ActiveRL.msg"
|
"msg/ActiveRL.msg"
|
||||||
"msg/ImageFeedback.msg"
|
"msg/ImageFeedback.msg"
|
||||||
|
"msg/ActiveBOState.msg"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -6,4 +6,5 @@ uint16 max_steps
|
|||||||
uint16 nr_episodes
|
uint16 nr_episodes
|
||||||
uint16 nr_runs
|
uint16 nr_runs
|
||||||
string acquisition_function
|
string acquisition_function
|
||||||
float32 metric_parameter
|
float32 metric_parameter
|
||||||
|
bool save_result
|
4
src/active_bo_msgs/msg/ActiveBOState.msg
Normal file
4
src/active_bo_msgs/msg/ActiveBOState.msg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
uint16 current_run
|
||||||
|
uint16 current_episode
|
||||||
|
float32 best_reward
|
||||||
|
float32 last_user_reward
|
@ -2,6 +2,7 @@ from active_bo_msgs.msg import ActiveBORequest
|
|||||||
from active_bo_msgs.msg import ActiveBOResponse
|
from active_bo_msgs.msg import ActiveBOResponse
|
||||||
from active_bo_msgs.msg import ActiveRL
|
from active_bo_msgs.msg import ActiveRL
|
||||||
from active_bo_msgs.msg import ActiveRLResponse
|
from active_bo_msgs.msg import ActiveRLResponse
|
||||||
|
from active_bo_msgs.msg import ActiveBOState
|
||||||
|
|
||||||
import rclpy
|
import rclpy
|
||||||
from rclpy.node import Node
|
from rclpy.node import Node
|
||||||
@ -22,6 +23,7 @@ from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
class ActiveBOTopic(Node):
|
class ActiveBOTopic(Node):
|
||||||
@ -47,14 +49,15 @@ class ActiveBOTopic(Node):
|
|||||||
self.bo_metric = None
|
self.bo_metric = None
|
||||||
self.bo_fixed_seed = False
|
self.bo_fixed_seed = False
|
||||||
self.bo_nr_weights = None
|
self.bo_nr_weights = None
|
||||||
self.bo_steps = None
|
self.bo_steps = 0
|
||||||
self.bo_episodes = None
|
self.bo_episodes = 0
|
||||||
self.bo_runs = None
|
self.bo_runs = 0
|
||||||
self.bo_acq_fcn = None
|
self.bo_acq_fcn = None
|
||||||
self.bo_metric_parameter = None
|
self.bo_metric_parameter = None
|
||||||
self.current_run = 0
|
self.current_run = 0
|
||||||
self.current_episode = 0
|
self.current_episode = 0
|
||||||
self.seed = None
|
self.seed = None
|
||||||
|
self.save_result = False
|
||||||
|
|
||||||
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
|
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
|
||||||
self.active_rl_pub = self.create_publisher(ActiveRL,
|
self.active_rl_pub = self.create_publisher(ActiveRL,
|
||||||
@ -68,7 +71,10 @@ class ActiveBOTopic(Node):
|
|||||||
self.active_rl_pending = False
|
self.active_rl_pending = False
|
||||||
self.rl_weights = None
|
self.rl_weights = None
|
||||||
self.rl_final_step = None
|
self.rl_final_step = None
|
||||||
self.rl_reward = None
|
self.rl_reward = 0.0
|
||||||
|
|
||||||
|
# State Publisher
|
||||||
|
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
|
||||||
|
|
||||||
# RL Environments and BO
|
# RL Environments and BO
|
||||||
self.env = None
|
self.env = None
|
||||||
@ -76,6 +82,7 @@ class ActiveBOTopic(Node):
|
|||||||
self.BO = None
|
self.BO = None
|
||||||
self.nr_init = 3
|
self.nr_init = 3
|
||||||
self.reward = None
|
self.reward = None
|
||||||
|
self.best_reward = 0.0
|
||||||
self.best_pol_reward = None
|
self.best_pol_reward = None
|
||||||
self.best_policy = None
|
self.best_policy = None
|
||||||
self.best_weights = None
|
self.best_weights = None
|
||||||
@ -91,13 +98,14 @@ class ActiveBOTopic(Node):
|
|||||||
self.bo_metric = None
|
self.bo_metric = None
|
||||||
self.bo_fixed_seed = False
|
self.bo_fixed_seed = False
|
||||||
self.bo_nr_weights = None
|
self.bo_nr_weights = None
|
||||||
self.bo_steps = None
|
self.bo_steps = 0
|
||||||
self.bo_episodes = None
|
self.bo_episodes = 0
|
||||||
self.bo_runs = None
|
self.bo_runs = 0
|
||||||
self.bo_acq_fcn = None
|
self.bo_acq_fcn = None
|
||||||
self.bo_metric_parameter = None
|
self.bo_metric_parameter = None
|
||||||
self.current_run = 0
|
self.current_run = 0
|
||||||
self.current_episode = 0
|
self.current_episode = 0
|
||||||
|
self.save_result = False
|
||||||
|
|
||||||
def active_bo_callback(self, msg):
|
def active_bo_callback(self, msg):
|
||||||
if not self.active_bo_pending:
|
if not self.active_bo_pending:
|
||||||
@ -112,6 +120,7 @@ class ActiveBOTopic(Node):
|
|||||||
self.bo_runs = msg.nr_runs
|
self.bo_runs = msg.nr_runs
|
||||||
self.bo_acq_fcn = msg.acquisition_function
|
self.bo_acq_fcn = msg.acquisition_function
|
||||||
self.bo_metric_parameter = msg.metric_parameter
|
self.bo_metric_parameter = msg.metric_parameter
|
||||||
|
self.save_result = msg.save_result
|
||||||
|
|
||||||
# initialize
|
# initialize
|
||||||
self.reward = np.zeros((self.bo_episodes, self.bo_runs))
|
self.reward = np.zeros((self.bo_episodes, self.bo_runs))
|
||||||
@ -128,7 +137,6 @@ class ActiveBOTopic(Node):
|
|||||||
def reset_rl_response(self):
|
def reset_rl_response(self):
|
||||||
self.rl_weights = None
|
self.rl_weights = None
|
||||||
self.rl_final_step = None
|
self.rl_final_step = None
|
||||||
self.rl_reward = None
|
|
||||||
|
|
||||||
def active_rl_callback(self, msg):
|
def active_rl_callback(self, msg):
|
||||||
if self.active_rl_pending:
|
if self.active_rl_pending:
|
||||||
@ -180,6 +188,18 @@ class ActiveBOTopic(Node):
|
|||||||
bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
|
bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
|
||||||
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
|
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
|
||||||
|
|
||||||
|
if self.save_result:
|
||||||
|
home_dir = os.path.expanduser('~')
|
||||||
|
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
|
||||||
|
filename = self.bo_metric + '-'\
|
||||||
|
+ str(self.bo_metric_parameter) + '-' \
|
||||||
|
+ str(self.bo_nr_weights) + '-' + str(time.time())
|
||||||
|
filename = filename.replace('.', '_') + '.csv'
|
||||||
|
path = os.path.join(file_path, filename)
|
||||||
|
self.get_logger().info(path)
|
||||||
|
|
||||||
|
np.savetxt(path, self.reward, delimiter=',')
|
||||||
|
|
||||||
self.get_logger().info('Responding: Active BO')
|
self.get_logger().info('Responding: Active BO')
|
||||||
self.active_bo_pub.publish(bo_response)
|
self.active_bo_pub.publish(bo_response)
|
||||||
self.reset_bo_request()
|
self.reset_bo_request()
|
||||||
@ -249,6 +269,18 @@ class ActiveBOTopic(Node):
|
|||||||
self.current_run += 1
|
self.current_run += 1
|
||||||
self.get_logger().info(f'Current Run: {self.current_run}')
|
self.get_logger().info(f'Current Run: {self.current_run}')
|
||||||
|
|
||||||
|
# send the current states
|
||||||
|
|
||||||
|
if self.BO is not None and self.BO.Y is not None:
|
||||||
|
self.best_reward = np.max(self.BO.Y)
|
||||||
|
|
||||||
|
state_msg = ActiveBOState()
|
||||||
|
state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs
|
||||||
|
state_msg.current_episode = self.current_episode + 1\
|
||||||
|
if self.current_episode < self.bo_episodes else self.bo_episodes
|
||||||
|
state_msg.best_reward = self.best_reward
|
||||||
|
state_msg.last_user_reward = self.rl_reward
|
||||||
|
self.state_pub.publish(state_msg)
|
||||||
|
|
||||||
def main(args=None):
|
def main(args=None):
|
||||||
rclpy.init(args=args)
|
rclpy.init(args=args)
|
||||||
|
@ -172,6 +172,7 @@ class ActiveRLService(Node):
|
|||||||
if done:
|
if done:
|
||||||
self.best_pol_shown = True
|
self.best_pol_shown = True
|
||||||
self.rl_step = 0
|
self.rl_step = 0
|
||||||
|
self.rl_reward = 0.0
|
||||||
|
|
||||||
elif self.best_pol_shown:
|
elif self.best_pol_shown:
|
||||||
if not self.eval_response_received:
|
if not self.eval_response_received:
|
||||||
|
Loading…
Reference in New Issue
Block a user