usability improvements and saving function
testing for bo and manual usage completed
This commit is contained in:
parent
c5f00a4069
commit
55b21d667a
@ -30,6 +30,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"msg/ActiveRLResponse.msg"
|
||||
"msg/ActiveRL.msg"
|
||||
"msg/ImageFeedback.msg"
|
||||
"msg/ActiveBOState.msg"
|
||||
|
||||
)
|
||||
|
||||
|
@ -7,3 +7,4 @@ uint16 nr_episodes
|
||||
uint16 nr_runs
|
||||
string acquisition_function
|
||||
float32 metric_parameter
|
||||
bool save_result
|
4
src/active_bo_msgs/msg/ActiveBOState.msg
Normal file
4
src/active_bo_msgs/msg/ActiveBOState.msg
Normal file
@ -0,0 +1,4 @@
|
||||
uint16 current_run
|
||||
uint16 current_episode
|
||||
float32 best_reward
|
||||
float32 last_user_reward
|
@ -2,6 +2,7 @@ from active_bo_msgs.msg import ActiveBORequest
|
||||
from active_bo_msgs.msg import ActiveBOResponse
|
||||
from active_bo_msgs.msg import ActiveRL
|
||||
from active_bo_msgs.msg import ActiveRLResponse
|
||||
from active_bo_msgs.msg import ActiveBOState
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
@ -22,6 +23,7 @@ from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import os
|
||||
|
||||
|
||||
class ActiveBOTopic(Node):
|
||||
@ -47,14 +49,15 @@ class ActiveBOTopic(Node):
|
||||
self.bo_metric = None
|
||||
self.bo_fixed_seed = False
|
||||
self.bo_nr_weights = None
|
||||
self.bo_steps = None
|
||||
self.bo_episodes = None
|
||||
self.bo_runs = None
|
||||
self.bo_steps = 0
|
||||
self.bo_episodes = 0
|
||||
self.bo_runs = 0
|
||||
self.bo_acq_fcn = None
|
||||
self.bo_metric_parameter = None
|
||||
self.current_run = 0
|
||||
self.current_episode = 0
|
||||
self.seed = None
|
||||
self.save_result = False
|
||||
|
||||
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
|
||||
self.active_rl_pub = self.create_publisher(ActiveRL,
|
||||
@ -68,7 +71,10 @@ class ActiveBOTopic(Node):
|
||||
self.active_rl_pending = False
|
||||
self.rl_weights = None
|
||||
self.rl_final_step = None
|
||||
self.rl_reward = None
|
||||
self.rl_reward = 0.0
|
||||
|
||||
# State Publisher
|
||||
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
|
||||
|
||||
# RL Environments and BO
|
||||
self.env = None
|
||||
@ -76,6 +82,7 @@ class ActiveBOTopic(Node):
|
||||
self.BO = None
|
||||
self.nr_init = 3
|
||||
self.reward = None
|
||||
self.best_reward = 0.0
|
||||
self.best_pol_reward = None
|
||||
self.best_policy = None
|
||||
self.best_weights = None
|
||||
@ -91,13 +98,14 @@ class ActiveBOTopic(Node):
|
||||
self.bo_metric = None
|
||||
self.bo_fixed_seed = False
|
||||
self.bo_nr_weights = None
|
||||
self.bo_steps = None
|
||||
self.bo_episodes = None
|
||||
self.bo_runs = None
|
||||
self.bo_steps = 0
|
||||
self.bo_episodes = 0
|
||||
self.bo_runs = 0
|
||||
self.bo_acq_fcn = None
|
||||
self.bo_metric_parameter = None
|
||||
self.current_run = 0
|
||||
self.current_episode = 0
|
||||
self.save_result = False
|
||||
|
||||
def active_bo_callback(self, msg):
|
||||
if not self.active_bo_pending:
|
||||
@ -112,6 +120,7 @@ class ActiveBOTopic(Node):
|
||||
self.bo_runs = msg.nr_runs
|
||||
self.bo_acq_fcn = msg.acquisition_function
|
||||
self.bo_metric_parameter = msg.metric_parameter
|
||||
self.save_result = msg.save_result
|
||||
|
||||
# initialize
|
||||
self.reward = np.zeros((self.bo_episodes, self.bo_runs))
|
||||
@ -128,7 +137,6 @@ class ActiveBOTopic(Node):
|
||||
def reset_rl_response(self):
|
||||
self.rl_weights = None
|
||||
self.rl_final_step = None
|
||||
self.rl_reward = None
|
||||
|
||||
def active_rl_callback(self, msg):
|
||||
if self.active_rl_pending:
|
||||
@ -180,6 +188,18 @@ class ActiveBOTopic(Node):
|
||||
bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
|
||||
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
|
||||
|
||||
if self.save_result:
|
||||
home_dir = os.path.expanduser('~')
|
||||
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
|
||||
filename = self.bo_metric + '-'\
|
||||
+ str(self.bo_metric_parameter) + '-' \
|
||||
+ str(self.bo_nr_weights) + '-' + str(time.time())
|
||||
filename = filename.replace('.', '_') + '.csv'
|
||||
path = os.path.join(file_path, filename)
|
||||
self.get_logger().info(path)
|
||||
|
||||
np.savetxt(path, self.reward, delimiter=',')
|
||||
|
||||
self.get_logger().info('Responding: Active BO')
|
||||
self.active_bo_pub.publish(bo_response)
|
||||
self.reset_bo_request()
|
||||
@ -249,6 +269,18 @@ class ActiveBOTopic(Node):
|
||||
self.current_run += 1
|
||||
self.get_logger().info(f'Current Run: {self.current_run}')
|
||||
|
||||
# send the current states
|
||||
|
||||
if self.BO is not None and self.BO.Y is not None:
|
||||
self.best_reward = np.max(self.BO.Y)
|
||||
|
||||
state_msg = ActiveBOState()
|
||||
state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs
|
||||
state_msg.current_episode = self.current_episode + 1\
|
||||
if self.current_episode < self.bo_episodes else self.bo_episodes
|
||||
state_msg.best_reward = self.best_reward
|
||||
state_msg.last_user_reward = self.rl_reward
|
||||
self.state_pub.publish(state_msg)
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
|
@ -172,6 +172,7 @@ class ActiveRLService(Node):
|
||||
if done:
|
||||
self.best_pol_shown = True
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
|
||||
elif self.best_pol_shown:
|
||||
if not self.eval_response_received:
|
||||
|
Loading…
Reference in New Issue
Block a user