Usability improvements and result-saving function

Testing for BO and manual usage completed
Niko Feith 2023-05-31 19:07:34 +02:00
parent c5f00a4069
commit 55b21d667a
5 changed files with 48 additions and 9 deletions

View File

@@ -30,6 +30,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
   "msg/ActiveRLResponse.msg"
   "msg/ActiveRL.msg"
   "msg/ImageFeedback.msg"
+  "msg/ActiveBOState.msg"
 )

View File

@@ -7,3 +7,4 @@ uint16 nr_episodes
 uint16 nr_runs
 string acquisition_function
 float32 metric_parameter
+bool save_result
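
For manual testing, the new flag can be set on a request like any other field. A minimal sketch; the node name, the topic name 'active_bo_request', and the acquisition-function string are assumptions, and request fields not shown in this diff keep their message defaults:

import time

import rclpy
from active_bo_msgs.msg import ActiveBORequest

rclpy.init()
node = rclpy.create_node('bo_request_demo')  # node name is illustrative
pub = node.create_publisher(ActiveBORequest, 'active_bo_request', 1)  # topic name assumed
time.sleep(1.0)  # crude wait for discovery so the one-shot publish is not dropped

msg = ActiveBORequest()
msg.nr_episodes = 50
msg.nr_runs = 10
msg.acquisition_function = 'Expected Improvement'  # value assumed
msg.metric_parameter = 0.5
msg.save_result = True  # new flag added in this commit
pub.publish(msg)

node.destroy_node()
rclpy.shutdown()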

View File

@@ -0,0 +1,4 @@
+uint16 current_run
+uint16 current_episode
+float32 best_reward
+float32 last_user_reward
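
This new message carries the live BO progress published on the 'active_bo_state' topic (see the publisher added below). A minimal sketch of a consumer, assuming only the topic name and queue size visible in this commit; the node and callback names are illustrative:

import rclpy
from rclpy.node import Node
from active_bo_msgs.msg import ActiveBOState


class StateListener(Node):
    def __init__(self):
        super().__init__('state_listener')  # node name is illustrative
        # topic and queue size match the publisher added in this commit
        self.create_subscription(ActiveBOState, 'active_bo_state',
                                 self.state_callback, 1)

    def state_callback(self, msg):
        self.get_logger().info(
            f'run {msg.current_run}, episode {msg.current_episode}: '
            f'best reward {msg.best_reward:.3f}, '
            f'last user reward {msg.last_user_reward:.3f}')


def main(args=None):
    rclpy.init(args=args)
    rclpy.spin(StateListener())


if __name__ == '__main__':
    main()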

View File

@@ -2,6 +2,7 @@ from active_bo_msgs.msg import ActiveBORequest
 from active_bo_msgs.msg import ActiveBOResponse
 from active_bo_msgs.msg import ActiveRL
 from active_bo_msgs.msg import ActiveRLResponse
+from active_bo_msgs.msg import ActiveBOState

 import rclpy
 from rclpy.node import Node
@@ -22,6 +23,7 @@ from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
 import numpy as np
 import time
+import os


 class ActiveBOTopic(Node):
@@ -47,14 +49,15 @@ class ActiveBOTopic(Node):
         self.bo_metric = None
         self.bo_fixed_seed = False
         self.bo_nr_weights = None
-        self.bo_steps = None
-        self.bo_episodes = None
-        self.bo_runs = None
+        self.bo_steps = 0
+        self.bo_episodes = 0
+        self.bo_runs = 0
         self.bo_acq_fcn = None
         self.bo_metric_parameter = None
         self.current_run = 0
         self.current_episode = 0
         self.seed = None
+        self.save_result = False

         # Active Reinforcement Learning Publisher, Subscriber and Message attributes
         self.active_rl_pub = self.create_publisher(ActiveRL,
@@ -68,7 +71,10 @@ class ActiveBOTopic(Node):
         self.active_rl_pending = False
         self.rl_weights = None
         self.rl_final_step = None
-        self.rl_reward = None
+        self.rl_reward = 0.0
+
+        # State Publisher
+        self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)

         # RL Environments and BO
         self.env = None
@@ -76,6 +82,7 @@ class ActiveBOTopic(Node):
         self.BO = None
         self.nr_init = 3
         self.reward = None
+        self.best_reward = 0.0
         self.best_pol_reward = None
         self.best_policy = None
         self.best_weights = None
@@ -91,13 +98,14 @@ class ActiveBOTopic(Node):
         self.bo_metric = None
         self.bo_fixed_seed = False
         self.bo_nr_weights = None
-        self.bo_steps = None
-        self.bo_episodes = None
-        self.bo_runs = None
+        self.bo_steps = 0
+        self.bo_episodes = 0
+        self.bo_runs = 0
         self.bo_acq_fcn = None
         self.bo_metric_parameter = None
         self.current_run = 0
         self.current_episode = 0
+        self.save_result = False

     def active_bo_callback(self, msg):
         if not self.active_bo_pending:
@@ -112,6 +120,7 @@ class ActiveBOTopic(Node):
             self.bo_runs = msg.nr_runs
             self.bo_acq_fcn = msg.acquisition_function
             self.bo_metric_parameter = msg.metric_parameter
+            self.save_result = msg.save_result

             # initialize
             self.reward = np.zeros((self.bo_episodes, self.bo_runs))
@@ -128,7 +137,6 @@ class ActiveBOTopic(Node):
     def reset_rl_response(self):
         self.rl_weights = None
         self.rl_final_step = None
-        self.rl_reward = None

     def active_rl_callback(self, msg):
         if self.active_rl_pending:
@@ -180,6 +188,18 @@ class ActiveBOTopic(Node):
             bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
             bo_response.reward_std = np.std(self.reward, axis=1).tolist()

+            if self.save_result:
+                home_dir = os.path.expanduser('~')
+                file_path = os.path.join(home_dir, 'Documents/IntRLResults')
+                filename = self.bo_metric + '-' \
+                           + str(self.bo_metric_parameter) + '-' \
+                           + str(self.bo_nr_weights) + '-' + str(time.time())
+                filename = filename.replace('.', '_') + '.csv'
+                path = os.path.join(file_path, filename)
+                self.get_logger().info(path)
+                np.savetxt(path, self.reward, delimiter=',')
+
             self.get_logger().info('Responding: Active BO')
             self.active_bo_pub.publish(bo_response)
             self.reset_bo_request()
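
One caveat with the new save block above: np.savetxt does not create missing directories, so the write raises FileNotFoundError unless ~/Documents/IntRLResults already exists. A defensive variant (sketch, same paths as in the diff) would create it first:

import os

home_dir = os.path.expanduser('~')
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
os.makedirs(file_path, exist_ok=True)  # no-op when the directory already exists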
@@ -249,6 +269,18 @@ class ActiveBOTopic(Node):
                 self.current_run += 1
                 self.get_logger().info(f'Current Run: {self.current_run}')

+        # send the current states
+        if self.BO is not None and self.BO.Y is not None:
+            self.best_reward = np.max(self.BO.Y)
+
+        state_msg = ActiveBOState()
+        state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs
+        state_msg.current_episode = self.current_episode + 1 \
+            if self.current_episode < self.bo_episodes else self.bo_episodes
+        state_msg.best_reward = self.best_reward
+        state_msg.last_user_reward = self.rl_reward
+        self.state_pub.publish(state_msg)
+

 def main(args=None):
     rclpy.init(args=args)
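
Each saved file holds the reward matrix with one row per episode and one column per run, matching the np.zeros((self.bo_episodes, self.bo_runs)) initialization above. A quick sketch for loading a result back for offline analysis; the filename is hypothetical, but the pattern follows the save code ('<metric>-<parameter>-<nr_weights>-<timestamp>' with dots replaced by underscores, plus '.csv'):

import numpy as np

# filename is hypothetical
reward = np.loadtxt('expected_improvement-0_5-8-1685552854_12.csv', delimiter=',')

reward_mean = reward.mean(axis=1)  # mean over runs, as in the BO response
reward_std = reward.std(axis=1)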

View File

@@ -172,6 +172,7 @@ class ActiveRLService(Node):
             if done:
                 self.best_pol_shown = True
                 self.rl_step = 0
+                self.rl_reward = 0.0

         elif self.best_pol_shown:
             if not self.eval_response_received: