Policy Service complete
This commit is contained in:
parent
378cae14db
commit
52fb25844b
@ -17,18 +17,14 @@ endif()
|
|||||||
|
|
||||||
# find dependencies
|
# find dependencies
|
||||||
find_package(ament_cmake REQUIRED)
|
find_package(ament_cmake REQUIRED)
|
||||||
# uncomment the following section in order to fill in
|
find_package(rosidl_default_generators REQUIRED)
|
||||||
# further dependencies manually.
|
|
||||||
# find_package(<dependency> REQUIRED)
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
|
"srv/WeightToPolicy.srv"
|
||||||
|
)
|
||||||
|
|
||||||
if(BUILD_TESTING)
|
if(BUILD_TESTING)
|
||||||
find_package(ament_lint_auto REQUIRED)
|
find_package(ament_lint_auto REQUIRED)
|
||||||
# the following line skips the linter which checks for copyrights
|
|
||||||
# uncomment the line when a copyright and license is not present in all source files
|
|
||||||
#set(ament_cmake_copyright_FOUND TRUE)
|
|
||||||
# the following line skips cpplint (only works in a git repo)
|
|
||||||
# uncomment the line when this package is not in a git repo
|
|
||||||
#set(ament_cmake_cpplint_FOUND TRUE)
|
|
||||||
ament_lint_auto_find_test_dependencies()
|
ament_lint_auto_find_test_dependencies()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -9,6 +9,12 @@
|
|||||||
|
|
||||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
|
||||||
|
<build_depend>rosidl_default_generators</build_depend>
|
||||||
|
|
||||||
|
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||||
|
|
||||||
|
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||||
|
|
||||||
<test_depend>ament_lint_auto</test_depend>
|
<test_depend>ament_lint_auto</test_depend>
|
||||||
<test_depend>ament_lint_common</test_depend>
|
<test_depend>ament_lint_common</test_depend>
|
||||||
|
|
||||||
|
4
src/active_bo_msgs/srv/WeightToPolicy.srv
Normal file
4
src/active_bo_msgs/srv/WeightToPolicy.srv
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
float32[] weights
|
||||||
|
uint16 nr_steps
|
||||||
|
---
|
||||||
|
float32[] policy
|
@ -0,0 +1,13 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def ConfidenceBound(gp, X, nr_test, nr_weights, lam=1.2, seed=None, lower=-1.0, upper=1.0):
    """Pick the next sample point by maximising an upper confidence bound.

    ``nr_test`` candidates are drawn uniformly from ``[lower, upper)`` in each
    of the ``nr_weights`` dimensions. Every candidate is scored with
    ``mu + lam * sigma`` and the best-scoring candidate is returned.

    Args:
        gp: Fitted model whose ``predict(X, return_std=True)`` returns
            ``(mean, std)`` -- presumably a scikit-learn
            ``GaussianProcessRegressor``; confirm against the caller.
        X: Previously evaluated points. Not used by this acquisition
            function; kept so the call signature stays unchanged.
        nr_test: Number of random candidates to score.
        nr_weights: Dimensionality of each candidate.
        lam: Weight of the standard deviation in the score.
        seed: Seed forwarded to ``np.random.default_rng``.
        lower: Lower bound of the uniform sampling range.
        upper: Upper bound of the uniform sampling range.

    Returns:
        The row of the candidate matrix (length ``nr_weights``) with the
        highest score.
    """
    # The score does not depend on the best observed value, so no prediction
    # over the training inputs X is needed here.
    rng = np.random.default_rng(seed=seed)
    X_test = rng.uniform(lower, upper, (nr_test, nr_weights))

    mu, sigma = gp.predict(X_test, return_std=True)
    cb = mu + lam * sigma

    return X_test[np.argmax(cb), :]
|
@ -0,0 +1,15 @@
|
|||||||
|
import numpy as np
|
||||||
|
from scipy.stats import norm
|
||||||
|
|
||||||
|
def ExpectedImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0):
    """Pick the next sample point by maximising expected improvement.

    ``nr_test`` candidates are drawn uniformly from ``[lower, upper)`` in each
    of the ``nr_weights`` dimensions. With ``imp = mu - best_y - kappa`` and
    ``z = imp / sigma``, each candidate is scored as
    ``imp * cdf(z) + sigma * pdf(z)``.

    Args:
        gp: Fitted model whose ``predict(X)`` returns the mean and whose
            ``predict(X, return_std=True)`` returns ``(mean, std)`` --
            presumably a scikit-learn ``GaussianProcessRegressor``; confirm
            against the caller.
        X: Previously evaluated points; the largest predicted mean over
            them is used as the incumbent ``best_y``.
        nr_test: Number of random candidates to score.
        nr_weights: Dimensionality of each candidate.
        kappa: Margin subtracted from the improvement before scoring.
        seed: Seed forwarded to ``np.random.default_rng``.
        lower: Lower bound of the uniform sampling range.
        upper: Upper bound of the uniform sampling range.

    Returns:
        The row of the candidate matrix (length ``nr_weights``) with the
        highest expected improvement.
    """
    best_y = np.max(gp.predict(X))

    rng = np.random.default_rng(seed=seed)
    X_test = rng.uniform(lower, upper, (nr_test, nr_weights))
    mu, sigma = gp.predict(X_test, return_std=True)

    improvement = mu - best_y - kappa

    # Where sigma is zero, z would be +/-inf or 0/0 = NaN, and np.argmax
    # returns the first NaN it sees. Divide by a placeholder there and
    # substitute the sigma -> 0 limit of the formula afterwards.
    has_spread = sigma > 0
    z = improvement / np.where(has_spread, sigma, 1.0)
    ei = improvement * norm.cdf(z) + sigma * norm.pdf(z)
    ei = np.where(has_spread, ei, np.maximum(improvement, 0.0))

    return X_test[np.argmax(ei), :]
|
@ -0,0 +1,15 @@
|
|||||||
|
import numpy as np
|
||||||
|
from scipy.stats import norm
|
||||||
|
|
||||||
|
def ProbabilityOfImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0):
    """Pick the next sample point by maximising probability of improvement.

    ``nr_test`` candidates are drawn uniformly from ``[lower, upper)`` in each
    of the ``nr_weights`` dimensions. Each candidate is scored as
    ``cdf((mu - best_y - kappa) / sigma)``.

    Args:
        gp: Fitted model whose ``predict(X)`` returns the mean and whose
            ``predict(X, return_std=True)`` returns ``(mean, std)`` --
            presumably a scikit-learn ``GaussianProcessRegressor``; confirm
            against the caller.
        X: Previously evaluated points; the largest predicted mean over
            them is used as the incumbent ``best_y``.
        nr_test: Number of random candidates to score.
        nr_weights: Dimensionality of each candidate.
        kappa: Margin subtracted from the improvement before scoring.
        seed: Seed forwarded to ``np.random.default_rng``.
        lower: Lower bound of the uniform sampling range.
        upper: Upper bound of the uniform sampling range.

    Returns:
        The row of the candidate matrix (length ``nr_weights``) with the
        highest probability of improvement.
    """
    best_y = np.max(gp.predict(X))

    rng = np.random.default_rng(seed=seed)
    X_test = rng.uniform(lower, upper, (nr_test, nr_weights))
    mu, sigma = gp.predict(X_test, return_std=True)

    improvement = mu - best_y - kappa

    # Where sigma is zero the ratio is +/-inf or 0/0 = NaN, and np.argmax
    # returns the first NaN it sees. With no spread the outcome is certain:
    # probability 1 for a strictly positive improvement, 0 otherwise.
    has_spread = sigma > 0
    z = improvement / np.where(has_spread, sigma, 1.0)
    prob = np.where(has_spread, norm.cdf(z), (improvement > 0).astype(float))

    return X_test[np.argmax(prob), :]
|
@ -0,0 +1,53 @@
|
|||||||
|
import numpy as np
|
||||||
|
class GaussianRBF:
    """Policy built from a weighted sum of Gaussian radial basis functions.

    ``nr_weights`` Gaussian bumps are centred at evenly spaced positions in
    ``[0, nr_steps]``. The policy value at step ``i`` is
    ``sum_j weights[j] * exp(-0.5 * (i - mean[j])**2 / std**2)``.

    Attributes set by the constructor:
        nr_weights, nr_steps: sizes passed in by the caller.
        mean: centres of the basis functions.
        std: shared width of all basis functions.
        rng: NumPy random generator used by ``random_policy``.
        low, upper: bounds used when sampling random weights.
        weights: current weights (zeros after ``reset``).
        policy: last rolled-out policy, shape ``(nr_steps, 1)``.
    """

    def __init__(self, nr_weights, nr_steps, seed=None, lowerb=-1.0, upperb=1.0):
        """Set up the basis functions and reset weights and policy to zero.

        Args:
            nr_weights: Number of basis functions (one weight each).
            nr_steps: Number of time steps in a rolled-out policy.
            seed: Seed forwarded to ``np.random.default_rng``.
            lowerb: Lower bound for randomly sampled weights.
            upperb: Upper bound for randomly sampled weights.
        """
        self.nr_weights = nr_weights
        self.nr_steps = nr_steps
        self.weights = None
        self.policy = None

        # Centres are spread evenly over [0, nr_steps].
        self.mean = np.linspace(0, self.nr_steps, self.nr_weights)
        if nr_weights > 1:
            # mean[0] is 0, so mean[1] is the spacing between centres. The
            # width is chosen so that the full width at half maximum of each
            # Gaussian equals that spacing.
            self.std = self.mean[1] / (2 * np.sqrt(2 * np.log(2)))
        else:
            self.std = self.nr_steps / 2

        self.rng = np.random.default_rng(seed=seed)
        self.low = lowerb
        self.upper = upperb

        self.reset()

    def reset(self):
        """Zero the weights (``(nr_weights, 1)``) and policy (``(nr_steps, 1)``)."""
        self.weights = np.zeros((self.nr_weights, 1))
        self.policy = np.zeros((self.nr_steps, 1))

    def random_policy(self):
        """Replace the weights with ``nr_weights`` uniform samples from ``[low, upper)``."""
        self.weights = self.rng.uniform(self.low, self.upper, self.nr_weights)

    def policy_rollout(self):
        """Evaluate the weighted basis functions at every step.

        ``self.weights`` may be any array-like (list, ``array.array``, 1-D or
        ``(nr_weights, 1)`` NumPy array). Only the first ``nr_weights``
        entries are used.

        Returns:
            The policy as an array of shape ``(nr_steps, 1)``; the same
            object is stored in ``self.policy``.

        Raises:
            ValueError: If fewer than ``nr_weights`` weights are available.
        """
        weights = np.asarray(self.weights, dtype=float).reshape(-1)
        if weights.size < self.nr_weights:
            raise ValueError(
                f"expected at least {self.nr_weights} weights, got {weights.size}"
            )
        weights = weights[:self.nr_weights]

        # Rows index the time step, columns index the basis function.
        steps = np.arange(self.nr_steps, dtype=float)[:, np.newaxis]
        basis = np.exp(-0.5 * (steps - self.mean[np.newaxis, :]) ** 2 / self.std ** 2)

        self.policy = (basis @ weights).reshape(self.nr_steps, 1)
        return self.policy
|
||||||
|
|
||||||
|
# def plot_policy(self, finished=np.NAN):
|
||||||
|
# x = np.linspace(0, self.nr_steps, self.nr_steps)
|
||||||
|
# plt.plot(x, self.trajectory)
|
||||||
|
# if finished != np.NAN:
|
||||||
|
# plt.vlines(finished, -1, 1, colors='red')
|
||||||
|
# # for i in self.mean:
|
||||||
|
# # gaussian = np.exp(-0.5 * (x - i)**2 / self.std**2)
|
||||||
|
# # plt.plot(x, gaussian)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Smoke-run the model: one weight, 50 steps, random weights, print them."""
    # The class defined in this module is GaussianRBF; GaussianRBFModel is
    # only the module name and is not a defined name here.
    policy = GaussianRBF(1, 50)
    policy.random_policy()
    policy.policy_rollout()
    print(policy.weights)


if __name__ == "__main__":
    main()
|
37
src/active_bo_ros/active_bo_ros/policy_service.py
Normal file
37
src/active_bo_ros/active_bo_ros/policy_service.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from active_bo_msgs.srv import WeightToPolicy
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
|
||||||
|
from active_bo_ros.PolicyModel.GaussianRBFModel import GaussianRBF
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class PolicyService(Node):
    """Node offering the ``policy_srv`` service of type ``WeightToPolicy``."""

    def __init__(self):
        """Create the node ``policy_service`` and register the service."""
        super().__init__('policy_service')
        self.srv = self.create_service(
            WeightToPolicy,
            'policy_srv',
            self.policy_callback,
        )

    @staticmethod
    def policy_callback(request, response):
        """Roll out a Gaussian RBF policy for the requested weights.

        A model with one basis function per entry of ``request.weights`` and
        ``request.nr_steps`` steps is built, rolled out, and the flattened
        result is written to ``response.policy``.
        """
        rbf = GaussianRBF(len(request.weights), request.nr_steps)
        rbf.weights = request.weights
        rbf.policy_rollout()

        # The rollout is a column array; the response carries a flat list.
        response.policy = rbf.policy.flatten().tolist()
        return response
|
||||||
|
|
||||||
|
def main(args=None):
    """Initialise rclpy, spin a ``PolicyService`` node, then shut rclpy down.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = PolicyService()
    rclpy.spin(node)
    rclpy.shutdown()


if __name__ == "__main__":
    main()
|
11
src/active_bo_ros/launch/policy_service.launch.py
Executable file
11
src/active_bo_ros/launch/policy_service.launch.py
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Return a launch description that starts the ``policy_srv`` executable.

    The executable comes from the ``active_bo_ros`` package and the node is
    named ``policy_srv``.
    """
    policy_node = Node(
        package='active_bo_ros',
        executable='policy_srv',
        name='policy_srv'
    )
    return LaunchDescription([policy_node])
|
@ -7,6 +7,10 @@
|
|||||||
<maintainer email="nikolaus.feith@unileoben.ac.at">cpsfeith</maintainer>
|
<maintainer email="nikolaus.feith@unileoben.ac.at">cpsfeith</maintainer>
|
||||||
<license>TODO: License declaration</license>
|
<license>TODO: License declaration</license>
|
||||||
|
|
||||||
|
<exec_depend>example_interfaces</exec_depend>
|
||||||
|
<exec_depend>active_bo_msgs</exec_depend>
|
||||||
|
<exec_depend>rclpy</exec_depend>
|
||||||
|
|
||||||
<test_depend>ament_copyright</test_depend>
|
<test_depend>ament_copyright</test_depend>
|
||||||
<test_depend>ament_flake8</test_depend>
|
<test_depend>ament_flake8</test_depend>
|
||||||
<test_depend>ament_pep257</test_depend>
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
@ -1,17 +1,20 @@
|
|||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
package_name = 'active_bo_ros'
|
package_name = 'active_bo_ros'
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name=package_name,
|
name=package_name,
|
||||||
version='0.0.0',
|
version='0.0.0',
|
||||||
packages=[package_name],
|
packages=[package_name, package_name+'/PolicyModel'],
|
||||||
data_files=[
|
data_files=[
|
||||||
('share/ament_index/resource_index/packages',
|
('share/ament_index/resource_index/packages',
|
||||||
['resource/' + package_name]),
|
['resource/' + package_name]),
|
||||||
('share/' + package_name, ['package.xml']),
|
('share/' + package_name, ['package.xml']),
|
||||||
|
(os.path.join('share', package_name), glob('launch/*.launch.py')),
|
||||||
],
|
],
|
||||||
install_requires=['setuptools'],
|
install_requires=['setuptools', 'numpy'],
|
||||||
zip_safe=True,
|
zip_safe=True,
|
||||||
maintainer='cpsfeith',
|
maintainer='cpsfeith',
|
||||||
maintainer_email='nikolaus.feith@unileoben.ac.at',
|
maintainer_email='nikolaus.feith@unileoben.ac.at',
|
||||||
@ -20,6 +23,7 @@ setup(
|
|||||||
tests_require=['pytest'],
|
tests_require=['pytest'],
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': [
|
'console_scripts': [
|
||||||
|
'policy_srv = active_bo_ros.policy_service:main'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user