Policy Service complete

This commit is contained in:
Niko Feith 2023-02-24 16:32:55 +01:00
parent 378cae14db
commit 52fb25844b
12 changed files with 169 additions and 11 deletions

View File

@ -17,18 +17,14 @@ endif()
# find dependencies
find_package(ament_cmake REQUIRED)
# uncomment the following section in order to fill in
# further dependencies manually.
# find_package(<dependency> REQUIRED)
find_package(rosidl_default_generators REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"srv/WeightToPolicy.srv"
)
if(BUILD_TESTING)
find_package(ament_lint_auto REQUIRED)
# the following line skips the linter which checks for copyrights
# uncomment the line when a copyright and license is not present in all source files
#set(ament_cmake_copyright_FOUND TRUE)
# the following line skips cpplint (only works in a git repo)
# uncomment the line when this package is not in a git repo
#set(ament_cmake_cpplint_FOUND TRUE)
ament_lint_auto_find_test_dependencies()
endif()

View File

@ -9,6 +9,12 @@
<buildtool_depend>ament_cmake</buildtool_depend>
<build_depend>rosidl_default_generators</build_depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>

View File

@ -0,0 +1,4 @@
# Request: weights of the policy's basis functions (one basis function is
# created per weight by the policy service) and the number of time steps
# the rolled-out policy should contain.
float32[] weights
uint16 nr_steps
---
# Response: the rolled-out policy, flattened to one value per time step.
float32[] policy

View File

@ -0,0 +1,13 @@
import numpy as np
def ConfidenceBound(gp, X, nr_test, nr_weights, lam=1.2, seed=None, lower=-1.0, upper=1.0):
    """Pick the next query point with an upper-confidence-bound acquisition.

    ``nr_test`` candidate points are drawn uniformly from
    ``[lower, upper)`` in ``nr_weights`` dimensions; the candidate that
    maximises ``mu + lam * sigma`` (posterior mean plus scaled posterior
    standard deviation) is returned.

    Args:
        gp: Fitted model exposing ``predict(X, return_std=True)`` that
            returns ``(mean, std)`` (e.g. scikit-learn's
            ``GaussianProcessRegressor`` -- presumably; confirm with caller).
        X: Already evaluated points. Unused by this acquisition function;
            kept so that all acquisition functions share one signature.
        nr_test: Number of random candidate points to score.
        nr_weights: Dimensionality of each candidate point.
        lam: Weight of the standard deviation (exploration) term.
        seed: Seed for ``numpy.random.default_rng``.
        lower: Lower bound of the uniform sampling range.
        upper: Upper bound of the uniform sampling range.

    Returns:
        1-D array of length ``nr_weights`` holding the selected candidate.
    """
    # NOTE: the previous implementation also predicted on X to derive the
    # best observed value, but never used it; that dead GP prediction is gone.
    rng = np.random.default_rng(seed=seed)
    X_test = rng.uniform(lower, upper, (nr_test, nr_weights))

    mu, sigma = gp.predict(X_test, return_std=True)
    cb = mu + lam * sigma

    return X_test[np.argmax(cb), :]

View File

@ -0,0 +1,15 @@
import numpy as np
from scipy.stats import norm
def ExpectedImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0):
    """Pick the next query point by maximising expected improvement.

    ``nr_test`` candidate points are drawn uniformly from
    ``[lower, upper)`` in ``nr_weights`` dimensions and scored with

        ei = d * cdf(d / sigma) + sigma * pdf(d / sigma),
        d  = mu - best_y - kappa,

    where ``best_y`` is the largest model prediction over ``X``.

    Args:
        gp: Fitted model exposing ``predict(X)`` and
            ``predict(X, return_std=True)`` (e.g. scikit-learn's
            ``GaussianProcessRegressor`` -- presumably; confirm with caller).
        X: Already evaluated points, used to determine ``best_y``.
        nr_test: Number of random candidate points to score.
        nr_weights: Dimensionality of each candidate point.
        kappa: Margin subtracted from the predicted improvement.
        seed: Seed for ``numpy.random.default_rng``.
        lower: Lower bound of the uniform sampling range.
        upper: Upper bound of the uniform sampling range.

    Returns:
        1-D array of length ``nr_weights`` holding the selected candidate.
    """
    best_y = np.max(gp.predict(X))

    rng = np.random.default_rng(seed=seed)
    X_test = rng.uniform(lower, upper, (nr_test, nr_weights))

    mu, sigma = gp.predict(X_test, return_std=True)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    improvement = mu - best_y - kappa

    # A zero predictive std would give 0/0 = NaN, and np.argmax would then
    # select that NaN entry. Divide by a placeholder there and substitute
    # the sigma -> 0 limit of the formula, max(improvement, 0).
    has_spread = sigma > 0
    z = improvement / np.where(has_spread, sigma, 1.0)
    ei = np.where(
        has_spread,
        improvement * norm.cdf(z) + sigma * norm.pdf(z),
        np.maximum(improvement, 0.0),
    )

    return X_test[np.argmax(ei), :]

View File

@ -0,0 +1,15 @@
import numpy as np
from scipy.stats import norm
def ProbabilityOfImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0):
    """Pick the next query point by maximising probability of improvement.

    ``nr_test`` candidate points are drawn uniformly from
    ``[lower, upper)`` in ``nr_weights`` dimensions and scored with
    ``cdf((mu - best_y - kappa) / sigma)``, where ``best_y`` is the largest
    model prediction over ``X``.

    Args:
        gp: Fitted model exposing ``predict(X)`` and
            ``predict(X, return_std=True)`` (e.g. scikit-learn's
            ``GaussianProcessRegressor`` -- presumably; confirm with caller).
        X: Already evaluated points, used to determine ``best_y``.
        nr_test: Number of random candidate points to score.
        nr_weights: Dimensionality of each candidate point.
        kappa: Margin subtracted from the predicted improvement.
        seed: Seed for ``numpy.random.default_rng``.
        lower: Lower bound of the uniform sampling range.
        upper: Upper bound of the uniform sampling range.

    Returns:
        1-D array of length ``nr_weights`` holding the selected candidate.
    """
    best_y = np.max(gp.predict(X))

    rng = np.random.default_rng(seed=seed)
    X_test = rng.uniform(lower, upper, (nr_test, nr_weights))

    mu, sigma = gp.predict(X_test, return_std=True)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    improvement = mu - best_y - kappa

    # A zero predictive std would give 0/0 = NaN, and np.argmax would then
    # select that NaN entry. With no uncertainty the probability is simply
    # 1 if the (margin-adjusted) mean improves on best_y, else 0.
    has_spread = sigma > 0
    z = improvement / np.where(has_spread, sigma, 1.0)
    prob = np.where(has_spread, norm.cdf(z), (improvement > 0).astype(float))

    return X_test[np.argmax(prob), :]

View File

@ -0,0 +1,53 @@
import numpy as np
class GaussianRBF:
    """Policy made of a weighted sum of evenly spaced Gaussian basis functions.

    The basis centres are spread evenly over ``[0, nr_steps]``. With more
    than one weight, the common standard deviation is chosen so that the
    full width at half maximum of each Gaussian equals the centre spacing;
    with a single weight it is ``nr_steps / 2``.

    Attributes are plain and may be assigned directly (the policy service
    sets ``weights`` from a request before calling ``policy_rollout``).
    """

    def __init__(self, nr_weights, nr_steps, seed=None, lowerb=-1.0, upperb=1.0):
        """Create the basis and initialise weights and policy to zeros.

        Args:
            nr_weights: Number of basis functions / weights.
            nr_steps: Number of time steps in a rolled-out policy.
            seed: Seed for the generator used by ``random_policy``.
            lowerb: Lower bound for randomly drawn weights.
            upperb: Upper bound for randomly drawn weights.
        """
        self.nr_weights = nr_weights
        self.nr_steps = nr_steps

        self.weights = None
        self.policy = None

        self.mean = np.linspace(0, self.nr_steps, self.nr_weights)
        if nr_weights > 1:
            # mean[0] is 0, so mean[1] is the centre spacing; dividing by
            # 2*sqrt(2*ln 2) makes that spacing the full width at half maximum.
            self.std = self.mean[1] / (2 * np.sqrt(2 * np.log(2)))
        else:
            self.std = self.nr_steps / 2

        self.rng = np.random.default_rng(seed=seed)
        self.low = lowerb
        self.upper = upperb

        self.reset()

    def reset(self):
        """Set weights and policy to zero column vectors."""
        self.weights = np.zeros((self.nr_weights, 1))
        self.policy = np.zeros((self.nr_steps, 1))

    def random_policy(self):
        """Draw ``nr_weights`` weights uniformly from ``[low, upper)``."""
        self.weights = self.rng.uniform(self.low, self.upper, self.nr_weights)

    def policy_rollout(self):
        """Evaluate the weighted basis at steps ``0 .. nr_steps - 1``.

        ``weights`` may be any flat or column-shaped numeric container
        (list, ``array.array``, 1-D or ``(n, 1)`` ndarray). Entries beyond
        ``nr_weights`` are ignored.

        Returns:
            The new ``self.policy``, an array of shape ``(nr_steps, 1)``.

        Raises:
            IndexError: If fewer than ``nr_weights`` weights are available.
        """
        weights = np.asarray(self.weights, dtype=float).reshape(-1)
        if weights.size < self.nr_weights:
            raise IndexError(
                f"expected {self.nr_weights} weights, got {weights.size}"
            )

        # Rows are time steps, columns are basis functions.
        steps = np.arange(self.nr_steps).reshape(-1, 1)
        basis = np.exp(-0.5 * (steps - self.mean) ** 2 / self.std ** 2)

        rollout = basis @ weights[:self.nr_weights]
        self.policy = rollout.reshape(self.nr_steps, 1)
        return self.policy
def main():
    """Smoke test: roll out a random single-weight policy and print its weights."""
    # The class defined in this module is GaussianRBF; the previous call to
    # GaussianRBFModel (the module's name) raised NameError.
    policy = GaussianRBF(1, 50)
    policy.random_policy()
    policy.policy_rollout()
    print(policy.weights)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,37 @@
from active_bo_msgs.srv import WeightToPolicy
import rclpy
from rclpy.node import Node
from active_bo_ros.PolicyModel.GaussianRBFModel import GaussianRBF
import numpy as np
class PolicyService(Node):
    """ROS 2 node offering the ``policy_srv`` service (``WeightToPolicy``)."""

    def __init__(self):
        super().__init__('policy_service')
        self.srv = self.create_service(
            WeightToPolicy,
            'policy_srv',
            self.policy_callback,
        )

    @staticmethod
    def policy_callback(request, response):
        """Roll out a Gaussian RBF policy for the requested weights.

        One basis function is created per entry of ``request.weights``; the
        rollout spans ``request.nr_steps`` steps and is written to
        ``response.policy`` as a flat list.
        """
        rbf = GaussianRBF(len(request.weights), request.nr_steps)
        rbf.weights = request.weights

        rollout = rbf.policy_rollout()
        response.policy = rollout.flatten().tolist()
        return response
def main(args=None):
    """Start rclpy, spin the policy service node, and clean up on exit.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    policy_service = PolicyService()

    try:
        rclpy.spin(policy_service)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the service; exit quietly.
        pass
    finally:
        # Always release the node, even when spin() raised.
        policy_service.destroy_node()
        # The context may already have been shut down by a signal handler;
        # shutting it down twice raises, so check first.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,11 @@
from launch import LaunchDescription
from launch_ros.actions import Node
def generate_launch_description():
    """Return a launch description that starts the ``policy_srv`` executable."""
    policy_node = Node(
        package='active_bo_ros',
        executable='policy_srv',
        name='policy_srv',
    )
    return LaunchDescription([policy_node])

View File

@ -7,6 +7,10 @@
<maintainer email="nikolaus.feith@unileoben.ac.at">cpsfeith</maintainer>
<license>TODO: License declaration</license>
<exec_depend>example_interfaces</exec_depend>
<exec_depend>active_bo_msgs</exec_depend>
<exec_depend>rclpy</exec_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>

View File

@ -1,17 +1,20 @@
from setuptools import setup
import os
from glob import glob
package_name = 'active_bo_ros'
setup(
name=package_name,
version='0.0.0',
packages=[package_name],
packages=[package_name, package_name+'/PolicyModel'],
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name), glob('launch/*.launch.py')),
],
install_requires=['setuptools'],
install_requires=['setuptools', 'numpy'],
zip_safe=True,
maintainer='cpsfeith',
maintainer_email='nikolaus.feith@unileoben.ac.at',
@ -20,6 +23,7 @@ setup(
tests_require=['pytest'],
entry_points={
'console_scripts': [
'policy_srv = active_bo_ros.policy_service:main'
],
},
)