started adding user query interactions
This commit is contained in:
parent 6cdb7f8711
commit 2cea3b3c53
@@ -1,13 +1,11 @@
 import numpy as np


-def ConfidenceBound(gp, X, nr_test, nr_weights, lam=1.2, seed=None, lower=-1.0, upper=1.0):
-    y_hat = gp.predict(X)
-    best_y = max(y_hat)
+def ConfidenceBound(gp, nr_test, nr_weights, beta=1.2, seed=None, lower=-1.0, upper=1.0):
     rng = np.random.default_rng(seed=seed)
     X_test = rng.uniform(lower, upper, (nr_test, nr_weights))
     mu, sigma = gp.predict(X_test, return_std=True)
-    cb = mu + lam * sigma
+    cb = mu + beta * sigma

     idx = np.argmax(cb)
     X_next = X_test[idx, :]
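
For reference, a minimal usage sketch of the reworked ConfidenceBound as it stands after this hunk. The GP fitting below and the assumption that the function returns X_next further down in the file are illustrative, not part of the commit:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

# Toy data: a few previously evaluated weight vectors and their rewards.
rng = np.random.default_rng(0)
X_train = rng.uniform(-1.0, 1.0, (5, 4))
y_train = rng.normal(size=5)
gp = GaussianProcessRegressor().fit(X_train, y_train)

# Sample 100 candidate weight vectors and pick the one that maximises
# the upper confidence bound mu + beta * sigma.
x_next = ConfidenceBound(gp, nr_test=100, nr_weights=4, beta=2.576, seed=0)
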
@@ -109,10 +109,9 @@ class BayesianOptimization:

         elif self.acq == "Upper Confidence Bound":
             x_next = ConfidenceBound(self.GP,
-                                     self.X,
                                      self.eval_X,
                                      self.nr_policy_weights,
-                                     lam=2.576,
+                                     beta=2.576,
                                      seed=self.policy_seed,
                                      lower=self.lower_bound,
                                      upper=self.upper_bound)
@@ -0,0 +1,19 @@
+import numpy as np
+
+
+class ImprovementQuery:
+    def __init__(self, threshold, period):
+        self.threshold = threshold
+        self.period = period
+
+    def query(self, reward_array):
+        if reward_array.shape[0] < self.period:
+            return False
+
+        else:
+            first = reward_array[-self.period]
+            last = reward_array[-1]
+
+            slope = (last - first) / self.period
+
+            return slope < self.threshold
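
A short usage sketch of the new ImprovementQuery (the file name is not shown in this diff; the reward values are made up): the query fires once the slope of the reward curve over the last period episodes drops below threshold.

import numpy as np

# Ask for user input once learning stagnates: slope over the last
# 5 episodes below 0.01 reward per episode.
query = ImprovementQuery(threshold=0.01, period=5)

rewards = np.array([0.1, 0.5, 0.52, 0.53, 0.53, 0.53])
if query.query(rewards):
    print('reward is stagnating, request user feedback')  # fires for this toy curve
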
src/active_bo_ros/active_bo_ros/UserQuery/max_acq_query.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import numpy as np
+from scipy.stats import norm
+
+
+class MaxAcqQuery:
+    def __init__(self, threshold, gp,
+                 nr_test, nr_weights,
+                 lower=-1.0, upper=1.0,
+                 acq="Expected Improvement",
+                 **kwargs):
+        self.threshold = threshold
+        self.gp = gp
+        self.nr_test = nr_test
+        self.nr_weights = nr_weights
+        self.lower = lower
+        self.upper = upper
+        self.acq = acq
+
+        self.seed = kwargs.get('seed', None)
+        self.kappa = kwargs.get('kappa', 2.576)
+        self.beta = kwargs.get('beta', 1.2)
+        self.X = kwargs.get('X', None)
+
+        self.rng = np.random.default_rng(self.seed)
+
+    def query(self):
+        X_test = self.rng.uniform(self.lower, self.upper, (self.nr_test, self.nr_weights))
+        max_acq = 0
+
+        if self.acq == "Expected Improvement":
+            if self.X is None:
+                raise ValueError
+            y_hat = self.gp.predict(self.X)
+            best_y = max(y_hat)
+            mu, sigma = self.gp.predict(X_test, return_std=True)
+            z = (mu - best_y - self.kappa) / sigma
+            ei = (mu - best_y - self.kappa) * norm.cdf(z) + sigma * norm.pdf(z)
+            max_acq = np.max(ei)
+
+        if self.acq == "Probability of Improvement":
+            if self.X is None:
+                raise ValueError
+            y_hat = self.gp.predict(self.X)
+            best_y = max(y_hat)
+            mu, sigma = self.gp.predict(X_test, return_std=True)
+            z = (mu - best_y - self.kappa) / sigma
+            pi = norm.cdf(z)
+            max_acq = np.max(pi)
+
+        if self.acq == "Upper Confidence Bound":
+            mu, sigma = self.gp.predict(X_test, return_std=True)
+            cb = mu + self.beta * sigma
+            max_acq = np.max(cb)
+
+        return max_acq > self.threshold
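
A hedged usage sketch of MaxAcqQuery: the user is only queried while the maximum acquisition value over a random candidate set still exceeds the threshold, i.e. while the surrogate model still expects a worthwhile improvement. The GP, data and parameter values below are illustrative assumptions:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

rng = np.random.default_rng(1)
X_train = rng.uniform(-1.0, 1.0, (8, 3))   # evaluated weight vectors
y_train = rng.normal(size=8)               # their rewards
gp = GaussianProcessRegressor().fit(X_train, y_train)

query = MaxAcqQuery(threshold=0.1, gp=gp,
                    nr_test=100, nr_weights=3,
                    acq="Expected Improvement",
                    seed=1, kappa=0.01, X=X_train)

if query.query():
    print('expected improvement still high, ask the user')
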
src/active_bo_ros/active_bo_ros/UserQuery/random_query.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+import numpy as np
+
+
+class RandomQuery:
+    def __init__(self, threshold):
+        self.threshold = threshold
+        self.random = np.random.uniform(0.0, 1.0, 1)
+
+    def query(self):
+        return self.random > self.threshold
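
RandomQuery draws its random number once in __init__, so a fresh instance is needed for every decision; a minimal sketch (the 0.75 threshold is an arbitrary example):

# Query the user with probability 1 - threshold, here on roughly 25% of episodes.
for episode in range(10):
    if RandomQuery(threshold=0.75).query():
        print(f'episode {episode}: ask the user')
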
src/active_bo_ros/active_bo_ros/UserQuery/regular_query.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+class RegularQuery:
+    def __init__(self, regular):
+        self.regular = regular
+        self.counter = 0
+
+    def query(self):
+        if self.counter < self.regular:
+            self.counter += 1
+            return False
+
+        else:
+            self.counter = 0
+            return True
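
RegularQuery queries on a fixed schedule: the counter counts regular calls up and then resets, so query() returns True on every (regular + 1)-th call. A minimal sketch:

query = RegularQuery(regular=3)

# With regular=3 the user is asked on every 4th episode.
for episode in range(8):
    if query.query():
        print(f'episode {episode}: ask the user')   # fires at episodes 3 and 7
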
@@ -1,123 +0,0 @@
-from active_bo_msgs.srv import ActiveBO
-from active_bo_msgs.srv import ActiveRL
-
-import rclpy
-from rclpy.node import Node
-
-from rclpy.callback_groups import ReentrantCallbackGroup
-
-from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOptimization
-from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
-
-import numpy as np
-import time
-
-
-class ActiveBOService(Node):
-    def __init__(self):
-        super().__init__('active_bo_service')
-
-        bo_callback_group = ReentrantCallbackGroup()
-        rl_callback_group = ReentrantCallbackGroup()
-
-        self.srv = self.create_service(ActiveBO,
-                                       'active_bo_srv',
-                                       self.active_bo_callback,
-                                       callback_group=bo_callback_group)
-
-        self.active_rl_client = self.create_client(ActiveRL,
-                                                   'active_rl_srv',
-                                                   callback_group=rl_callback_group)
-
-        self.env = Continuous_MountainCarEnv()
-        self.distance_penalty = 0
-
-        self.nr_init = 3
-
-    def active_bo_callback(self, request, response):
-        self.get_logger().info('Active Bayesian Optimization Service started!')
-        nr_weights = request.nr_weights
-        max_steps = request.max_steps
-        nr_episodes = request.nr_episodes
-        nr_runs = request.nr_runs
-        acq = request.acquisition_function
-        epsilon = request.epsilon
-
-        reward = np.zeros((nr_episodes, nr_runs))
-        best_pol_reward = np.zeros((1, nr_runs))
-        best_policy = np.zeros((max_steps, nr_runs))
-        best_weights = np.zeros((nr_weights, nr_runs))
-
-        BO = BayesianOptimization(self.env,
-                                  max_steps,
-                                  nr_init=self.nr_init,
-                                  acq=acq,
-                                  nr_weights=nr_weights)
-
-        arl_request = ActiveRL.Request()
-        for i in range(nr_runs):
-            BO.initialize()
-
-            for j in range(nr_episodes):
-                # active part
-                if (j > 0) and (np.random.uniform(0.0, 1.0, 1) < epsilon):
-                    self.get_logger().info('Active User Input')
-                    old_policy, _, old_weights = BO.get_best_result()
-
-                    arl_request.old_policy = old_policy.tolist()
-                    arl_request.old_weights = old_weights.tolist()
-                    self.get_logger().info('Calling: Active RL')
-                    future_rl = self.active_rl_client.call_async(arl_request)
-                    self.get_logger().info(str(future_rl))
-
-                    timeout = 10
-                    start_time = time.time()
-                    while not future_rl.done():
-                        rclpy.spin_once(self, timeout_sec=0.1)
-                        # self.get_logger().info(f'{future_rl.result()}')
-                        if time.time() - start_time > timeout:
-                            self.get_logger().error('Service call timed out.')
-                            break
-                    # self.executor.spin_until_future_complete(future_rl)
-
-                    # arl_response = self.active_rl_client.call(arl_request)
-                    self.get_logger().info('Received: Active RL')
-
-                    try:
-                        arl_response = future_rl.result()
-                        BO.add_new_observation(arl_response.reward, arl_response.new_weights)
-                    except Exception as e:
-                        self.get_logger().error('active RL Service failed %r' % (e,))
-
-                # BO part
-                else:
-                    x_next = BO.next_observation()
-                    BO.eval_new_observation(x_next)
-
-                self.get_logger().info(str(j))
-
-            best_policy[:, i], best_pol_reward[:, i], best_weights[:, i] = BO.get_best_result()
-
-            reward[:, i] = BO.best_reward.T
-
-        response.reward_mean = np.mean(reward, axis=1).tolist()
-        response.reward_std = np.std(reward, axis=1).tolist()
-
-        best_policy_idx = np.argmax(best_pol_reward)
-        response.best_weights = best_weights[:, best_policy_idx].tolist()
-        response.best_policy = best_policy[:, best_policy_idx].tolist()
-        return response
-
-
-def main(args=None):
-    rclpy.init(args=args)
-
-    active_bo_service = ActiveBOService()
-
-    rclpy.spin(active_bo_service)
-
-    rclpy.shutdown()
-
-
-if __name__ == '__main__':
-    main()
@@ -1,163 +0,0 @@
-from active_bo_msgs.srv import ActiveRL
-from active_bo_msgs.msg import ImageFeedback
-from active_bo_msgs.msg import ActiveRL as ActiveRLEval
-
-import rclpy
-from rclpy.node import Node
-
-from rclpy.callback_groups import ReentrantCallbackGroup
-
-from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
-
-import numpy as np
-import time
-import copy
-
-
-class ActiveRLService(Node):
-    def __init__(self):
-        super().__init__('active_rl_service')
-        srv_callback_group = ReentrantCallbackGroup()
-        topic_callback_group = ReentrantCallbackGroup()
-
-        self.srv = self.create_service(ActiveRL,
-                                       'active_rl_srv',
-                                       self.active_rl_callback,
-                                       callback_group=srv_callback_group)
-
-        self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=topic_callback_group)
-
-        self.eval_pub = self.create_publisher(ActiveRLEval,
-                                              'active_rl_eval_request',
-                                              1,
-                                              callback_group=topic_callback_group)
-        self.eval_sub = self.create_subscription(ActiveRLEval,
-                                                 'active_rl_eval_response',
-                                                 self.active_rl_eval_callback,
-                                                 1,
-                                                 callback_group=topic_callback_group)
-        # active_rl_eval_response
-        self.eval_response_received = False
-        self.eval_response = None
-
-        self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
-        self.distance_penalty = 0
-
-    def active_rl_eval_callback(self, response):
-        self.eval_response = response
-        self.eval_response_received = True
-
-    def active_rl_callback(self, request, response):
-
-        self.get_logger().info('Active RL: Called')
-
-        feedback_msg = ImageFeedback()
-
-        reward = 0
-        step_count = 0
-        old_policy = request.old_policy
-        old_weights = request.old_weights
-
-        eval_request = ActiveRLEval()
-        eval_request.policy = old_policy
-        eval_request.weights = old_weights
-
-        self.env.reset()
-
-        self.get_logger().info('Best policy so far!')
-
-        for i in range(len(old_policy)):
-            action = old_policy[i]
-            output = self.env.step(action)
-
-            done = output[2]
-
-            rgb_array = self.env.render()
-            rgb_shape = rgb_array.shape
-
-            red = rgb_array[:, :, 0].flatten().tolist()
-            green = rgb_array[:, :, 1].flatten().tolist()
-            blue = rgb_array[:, :, 2].flatten().tolist()
-
-            feedback_msg.height = rgb_shape[0]
-            feedback_msg.width = rgb_shape[1]
-            feedback_msg.red = red
-            feedback_msg.green = green
-            feedback_msg.blue = blue
-
-            self.publisher.publish(feedback_msg)
-
-            if done:
-                break
-
-        self.get_logger().info('Enter new solution!')
-        self.eval_pub.publish(eval_request)
-
-        while rclpy.ok():
-            rclpy.spin_once(self, timeout_sec=0.1)
-            if self.eval_response_received:
-                break
-
-        self.get_logger().info('Topic responded!')
-        new_policy = copy.deepcopy(self.eval_response.policy)
-        new_weights = copy.deepcopy(self.eval_response.weights)
-        self.eval_response_received = False
-        self.eval_response = None
-
-        reward = 0
-        step_count = 0
-        done = False
-        self.env.reset()
-
-        for i in range(len(new_policy)):
-            action = new_policy[i]
-            output = self.env.step(action)
-
-            reward += output[1]
-            done = output[2]
-            step_count += 1
-
-            rgb_array = self.env.render()
-            rgb_shape = rgb_array.shape
-
-            red = rgb_array[:, :, 0].flatten().tolist()
-            green = rgb_array[:, :, 1].flatten().tolist()
-            blue = rgb_array[:, :, 2].flatten().tolist()
-
-            feedback_msg.height = rgb_shape[0]
-            feedback_msg.width = rgb_shape[1]
-            feedback_msg.red = red
-            feedback_msg.green = green
-            feedback_msg.blue = blue
-
-            self.publisher.publish(feedback_msg)
-
-            if done:
-                break
-
-        if not done and i == len(new_policy):
-            distance = -(self.env.goal_position - output[0][0])
-            reward += distance * self.distance_penalty
-
-        self.get_logger().info(str(reward))
-        response.new_weights = new_weights
-        response.reward = reward
-        response.final_step = step_count
-
-        self.get_logger().info(f'{response}')
-
-        return response
-
-
-def main(args=None):
-    rclpy.init(args=args)
-
-    active_rl_service = ActiveRLService()
-
-    rclpy.spin(active_rl_service)
-
-    rclpy.shutdown()
-
-
-if __name__ == '__main__':
-    main()
@@ -1,38 +0,0 @@
-from active_bo_msgs.srv import WeightToPolicy
-
-import rclpy
-from rclpy.node import Node
-
-from active_bo_ros.PolicyModel.GaussianRBFModel import GaussianRBF
-import numpy as np
-
-
-class PolicyService(Node):
-    def __init__(self):
-        super().__init__('policy_service')
-        self.srv = self.create_service(WeightToPolicy, 'policy_srv', self.policy_callback)
-
-    @staticmethod
-    def policy_callback(request, response):
-        weights = request.weights
-        weight_len = len(weights)
-        nr_steps = request.nr_steps
-
-        policy = GaussianRBF(weight_len, nr_steps)
-        policy.weights = weights
-        policy.rollout()
-
-        response.policy = policy.policy.flatten().tolist()
-        return response
-
-def main(args=None):
-    rclpy.init(args=args)
-
-    policy_service = PolicyService()
-
-    rclpy.spin(policy_service)
-
-    rclpy.shutdown()
-
-if __name__ == "__main__":
-    main()
@@ -15,14 +15,7 @@ def generate_launch_description():
             )
         )
     )
-    policy_launch = IncludeLaunchDescription(
-        PythonLaunchDescriptionSource(
-            os.path.join(
-                get_package_share_directory('active_bo_ros'),
-                'policy_service.launch.py'
-            )
-        )
-    )
+
     rl_launch = IncludeLaunchDescription(
         PythonLaunchDescriptionSource(
             os.path.join(
@@ -31,6 +24,7 @@ def generate_launch_description():
             )
        )
     )
+
     bo_launch = IncludeLaunchDescription(
         PythonLaunchDescriptionSource(
             os.path.join(
@@ -39,9 +33,9 @@ def generate_launch_description():
             )
         )
     )
+
     return LaunchDescription([
         websocket_launch,
-        policy_launch,
         rl_launch,
         bo_launch
     ])
@@ -1,12 +0,0 @@
-from launch import LaunchDescription
-from launch_ros.actions import Node
-
-
-def generate_launch_description():
-    return LaunchDescription([
-        Node(
-            package='active_bo_ros',
-            executable='policy_srv',
-            name='policy_srv'
-        ),
-    ])
@@ -27,12 +27,9 @@ setup(
     tests_require=['pytest'],
     entry_points={
         'console_scripts': [
-            'policy_srv = active_bo_ros.policy_service:main',
             'rl_srv = active_bo_ros.rl_service:main',
             'bo_srv = active_bo_ros.bo_service:main',
             'bo_torch_srv = active_bo_ros.bo_torch_service:main',
-            'active_bo_srv = active_bo_ros.active_bo_service:main',
-            'active_rl_srv = active_bo_ros.active_rl_service:main',
             'active_bo_topic = active_bo_ros.active_bo_topic:main',
             'active_rl_topic = active_bo_ros.active_rl_topic:main',
         ],