Policy Service complete
This commit is contained in:
parent
378cae14db
commit
52fb25844b
@ -17,18 +17,14 @@ endif()
|
||||
|
||||
# find dependencies
|
||||
find_package(ament_cmake REQUIRED)
|
||||
# uncomment the following section in order to fill in
|
||||
# further dependencies manually.
|
||||
# find_package(<dependency> REQUIRED)
|
||||
find_package(rosidl_default_generators REQUIRED)
|
||||
|
||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"srv/WeightToPolicy.srv"
|
||||
)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
find_package(ament_lint_auto REQUIRED)
|
||||
# the following line skips the linter which checks for copyrights
|
||||
# uncomment the line when a copyright and license is not present in all source files
|
||||
#set(ament_cmake_copyright_FOUND TRUE)
|
||||
# the following line skips cpplint (only works in a git repo)
|
||||
# uncomment the line when this package is not in a git repo
|
||||
#set(ament_cmake_cpplint_FOUND TRUE)
|
||||
ament_lint_auto_find_test_dependencies()
|
||||
endif()
|
||||
|
||||
|
@ -9,6 +9,12 @@
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
|
||||
<build_depend>rosidl_default_generators</build_depend>
|
||||
|
||||
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||
|
||||
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||
|
||||
<test_depend>ament_lint_auto</test_depend>
|
||||
<test_depend>ament_lint_common</test_depend>
|
||||
|
||||
|
4
src/active_bo_msgs/srv/WeightToPolicy.srv
Normal file
4
src/active_bo_msgs/srv/WeightToPolicy.srv
Normal file
@ -0,0 +1,4 @@
|
||||
float32[] weights
|
||||
uint16 nr_steps
|
||||
---
|
||||
float32[] policy
|
@ -0,0 +1,13 @@
|
||||
import numpy as np
|
||||
|
||||
def ConfidenceBound(gp, X, nr_test, nr_weights, lam=1.2, seed=None, lower=-1.0, upper=1.0):
|
||||
y_hat = gp.predict(X)
|
||||
best_y = max(y_hat)
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
X_test = rng.uniform(lower, upper, (nr_test, nr_weights))
|
||||
mu, sigma = gp.predict(X_test, return_std=True)
|
||||
cb = mu + lam * sigma
|
||||
|
||||
idx = np.argmax(cb)
|
||||
X_next = X_test[idx, :]
|
||||
return X_next
|
@ -0,0 +1,15 @@
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
|
||||
def ExpectedImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0):
|
||||
y_hat = gp.predict(X)
|
||||
best_y = max(y_hat)
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
X_test = rng.uniform(lower, upper, (nr_test, nr_weights))
|
||||
mu, sigma = gp.predict(X_test, return_std=True)
|
||||
z = (mu - best_y - kappa) / sigma
|
||||
ei = (mu - best_y - kappa) * norm.cdf(z) + sigma * norm.pdf(z)
|
||||
|
||||
idx = np.argmax(ei)
|
||||
X_next = X_test[idx, :]
|
||||
return X_next
|
@ -0,0 +1,15 @@
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
|
||||
def ProbabilityOfImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0):
|
||||
y_hat = gp.predict(X)
|
||||
best_y = max(y_hat)
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
X_test = rng.uniform(lower, upper, (nr_test, nr_weights))
|
||||
mu, sigma = gp.predict(X_test, return_std=True)
|
||||
z = (mu - best_y - kappa) / sigma
|
||||
pi = norm.cdf(z)
|
||||
|
||||
idx = np.argmax(pi)
|
||||
X_next = X_test[idx, :]
|
||||
return X_next
|
@ -0,0 +1,53 @@
|
||||
import numpy as np
|
||||
class GaussianRBF:
|
||||
def __init__(self, nr_weights, nr_steps, seed=None, lowerb=-1.0, upperb=1.0):
|
||||
self.nr_weights = nr_weights
|
||||
self.nr_steps = nr_steps
|
||||
self.weights = None
|
||||
self.policy = None
|
||||
self.mean = np.linspace(0, self.nr_steps, self.nr_weights)
|
||||
if nr_weights > 1:
|
||||
self.std = self.mean[1] / (2 * np.sqrt(2 * np.log(2))) # Full width at half maximum
|
||||
else:
|
||||
self.std = self.nr_steps / 2
|
||||
|
||||
self.rng = np.random.default_rng(seed=seed)
|
||||
self.low = lowerb
|
||||
self.upper = upperb
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.weights = np.zeros((self.nr_weights, 1))
|
||||
self.policy = np.zeros((self.nr_steps, 1))
|
||||
|
||||
def random_policy(self):
|
||||
self.weights = self.rng.uniform(self.low, self.upper, self.nr_weights)
|
||||
|
||||
def policy_rollout(self):
|
||||
self.policy = np.zeros((self.nr_steps, 1))
|
||||
for i in range(self.nr_steps):
|
||||
for j in range(self.nr_weights):
|
||||
base_fun = np.exp(-0.5*(i - self.mean[j])**2 / self.std**2)
|
||||
self.policy[i] += base_fun * self.weights[j]
|
||||
|
||||
return self.policy
|
||||
|
||||
# def plot_policy(self, finished=np.NAN):
|
||||
# x = np.linspace(0, self.nr_steps, self.nr_steps)
|
||||
# plt.plot(x, self.trajectory)
|
||||
# if finished != np.NAN:
|
||||
# plt.vlines(finished, -1, 1, colors='red')
|
||||
# # for i in self.mean:
|
||||
# # gaussian = np.exp(-0.5 * (x - i)**2 / self.std**2)
|
||||
# # plt.plot(x, gaussian)
|
||||
|
||||
|
||||
def main():
|
||||
policy = GaussianRBFModel(1, 50)
|
||||
policy.random_policy()
|
||||
policy.policy_rollout()
|
||||
print(policy.weights)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
37
src/active_bo_ros/active_bo_ros/policy_service.py
Normal file
37
src/active_bo_ros/active_bo_ros/policy_service.py
Normal file
@ -0,0 +1,37 @@
|
||||
from active_bo_msgs.srv import WeightToPolicy
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from active_bo_ros.PolicyModel.GaussianRBFModel import GaussianRBF
|
||||
import numpy as np
|
||||
|
||||
class PolicyService(Node):
|
||||
def __init__(self):
|
||||
super().__init__('policy_service')
|
||||
self.srv = self.create_service(WeightToPolicy, 'policy_srv', self.policy_callback)
|
||||
|
||||
@staticmethod
|
||||
def policy_callback(request, response):
|
||||
weights = request.weights
|
||||
weight_len = len(weights)
|
||||
nr_steps = request.nr_steps
|
||||
|
||||
policy = GaussianRBF(weight_len, nr_steps)
|
||||
policy.weights = weights
|
||||
policy.policy_rollout()
|
||||
|
||||
response.policy = policy.policy.flatten().tolist()
|
||||
return response
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
|
||||
policy_service = PolicyService()
|
||||
|
||||
rclpy.spin(policy_service)
|
||||
|
||||
rclpy.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
11
src/active_bo_ros/launch/policy_service.launch.py
Executable file
11
src/active_bo_ros/launch/policy_service.launch.py
Executable file
@ -0,0 +1,11 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
|
||||
def generate_launch_description():
|
||||
return LaunchDescription([
|
||||
Node(
|
||||
package='active_bo_ros',
|
||||
executable='policy_srv',
|
||||
name='policy_srv'
|
||||
),
|
||||
])
|
@ -7,6 +7,10 @@
|
||||
<maintainer email="nikolaus.feith@unileoben.ac.at">cpsfeith</maintainer>
|
||||
<license>TODO: License declaration</license>
|
||||
|
||||
<exec_depend>example_interfaces</exec_depend>
|
||||
<exec_depend>active_bo_msgs</exec_depend>
|
||||
<exec_depend>rclpy</exec_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
|
@ -1,17 +1,20 @@
|
||||
from setuptools import setup
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
package_name = 'active_bo_ros'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.0.0',
|
||||
packages=[package_name],
|
||||
packages=[package_name, package_name+'/PolicyModel'],
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name), glob('launch/*.launch.py')),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
install_requires=['setuptools', 'numpy'],
|
||||
zip_safe=True,
|
||||
maintainer='cpsfeith',
|
||||
maintainer_email='nikolaus.feith@unileoben.ac.at',
|
||||
@ -20,6 +23,7 @@ setup(
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'policy_srv = active_bo_ros.policy_service:main'
|
||||
],
|
||||
},
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user