active bo implemented

not working due to watch in RosBar.vue line 150
This commit is contained in:
Niko Feith 2023-03-20 17:46:59 +01:00
parent 9e61c3244b
commit 1db897b9d5
7 changed files with 130 additions and 21 deletions

View File

@ -9,7 +9,7 @@
<v-slider <v-slider
class="control-slider" class="control-slider"
label="maximum steps" label="maximum steps"
v-model="pstore.max_steps" v-model="maxStepComp"
:step="1" :step="1"
:min="10" :min="10"
:max="200" :max="200"
@ -53,12 +53,12 @@
<v-slider <v-slider
class="control-slider" class="control-slider"
label="greedy" label="greedy"
v-model="cstore.greedy" v-model="cstore.epsilon"
:step="0.01" :step="0.01"
:min="0" :min="0"
:max="1" :max="1"
thumb-label thumb-label
:disabled="greedyRef" :disabled="epsilonRef"
/> />
<v-row class="d-flex justify-space-evenly mb-9"> <v-row class="d-flex justify-space-evenly mb-9">
<v-btn color="primary" @click="cstore.setSendWeights()" <v-btn color="primary" @click="cstore.setSendWeights()"
@ -84,7 +84,12 @@ const acq_funs = cstore.acq_funs;
const episodeRef = ref(true); const episodeRef = ref(true);
const runRef = ref(true); const runRef = ref(true);
const acqRef = ref(true); const acqRef = ref(true);
const greedyRef = ref(true); const epsilonRef = ref(true);
const maxStepComp = computed({
get: () => pstore.max_steps,
set: (value) => pstore.setMaxSteps(value),
});
const baseFuncsComp = computed({ const baseFuncsComp = computed({
get: () => pstore.nr_weights, get: () => pstore.nr_weights,
@ -110,19 +115,19 @@ watch(
episodeRef.value = true; episodeRef.value = true;
runRef.value = true; runRef.value = true;
acqRef.value = true; acqRef.value = true;
greedyRef.value = true; epsilonRef.value = true;
} }
if (usrMode === "BO") { if (usrMode === "BO") {
episodeRef.value = false; episodeRef.value = false;
runRef.value = false; runRef.value = false;
acqRef.value = false; acqRef.value = false;
greedyRef.value = true; epsilonRef.value = true;
} }
if (usrMode === "active BO") { if (usrMode === "active BO") {
episodeRef.value = false; episodeRef.value = false;
runRef.value = false; runRef.value = false;
acqRef.value = false; acqRef.value = false;
greedyRef.value = false; epsilonRef.value = false;
} }
} }
); );

View File

@ -6,6 +6,7 @@
import { onMounted, watch } from "vue"; import { onMounted, watch } from "vue";
import { usePStore } from "@/store/PolicyStore"; import { usePStore } from "@/store/PolicyStore";
import { Chart } from "chart.js/auto"; import { Chart } from "chart.js/auto";
import { computeRbfCurve } from "@/js_funs/online_rbf_policy";
const store = usePStore(); const store = usePStore();
@ -85,6 +86,20 @@ watch(
chartHandle.update(); chartHandle.update();
} }
); );
watch(
() => store.getTrigger,
() => {
const policy = computeRbfCurve(store.getMaxSteps, store.getWeights);
store.setPolicy(policy);
chartHandle.options.scales.x.labels = Array(policy.length)
.fill(0)
.map((_, i) => i);
chartHandle.data.datasets[0].data = policy;
chartHandle.update();
}
);
</script> </script>
<style scoped></style> <style scoped></style>

View File

@ -110,14 +110,18 @@ const policy_service = new ROSLIB.Service({
watch( watch(
() => cstore.getSendWeights, () => cstore.getSendWeights,
() => { () => {
const policy_request = new ROSLIB.ServiceRequest({ const usr_mode = cstore.getUserMode;
weights: pstore.weights,
nr_steps: pstore.max_steps,
});
policy_service.callService(policy_request, function (result) { if (usr_mode === "manually") {
pstore.setPolicy(result.policy); const policy_request = new ROSLIB.ServiceRequest({
}); weights: pstore.weights,
nr_steps: pstore.max_steps,
});
policy_service.callService(policy_request, function (result) {
pstore.setPolicy(result.policy);
});
}
} }
); );
@ -133,6 +137,27 @@ rl_feedback_subscriber.subscribe((msg) => {
mcstore.setRgbArrays(msg.red, msg.green, msg.blue); mcstore.setRgbArrays(msg.red, msg.green, msg.blue);
}); });
const active_rl_eval_service = new ROSLIB.Service({
ros: ros,
name: "/active_rl_eval_srv",
serviceType: "active_bo_msgs/srv/ActiveRLEval",
});
active_rl_eval_service.advertise(function (request, response) {
pstore.setPolicy(request["old_policy"]);
pstore.setWeights(request["old_weights"]);
watch(
() => cstore.getSendWeights,
() => {
response["new_policy"] = pstore.getPolicy;
response["new_weights"] = pstore.getWeights;
return true;
}
);
});
const rl_service = new ROSLIB.Service({ const rl_service = new ROSLIB.Service({
ros: ros, ros: ros,
name: "/rl_srv", name: "/rl_srv",
@ -145,6 +170,12 @@ const bo_service = new ROSLIB.Service({
serviceType: "active_bo_msgs/srv/BO", serviceType: "active_bo_msgs/srv/BO",
}); });
const active_bo_service = new ROSLIB.Service({
ros: ros,
name: "/active_bo_srv",
serviceType: "active_bo_msgs/srv/ActiveBO",
});
watch( watch(
() => cstore.getRunner, () => cstore.getRunner,
() => { () => {
@ -177,6 +208,30 @@ watch(
policy: pstore.policy, policy: pstore.policy,
}); });
rl_service.callService(rl_request, () => {});
} else if (usr_mode === "active BO") {
const active_bo_request = new ROSLIB.ServiceRequest({
nr_weights: pstore.nr_weights,
max_steps: pstore.max_steps,
nr_episodes: cstore.nr_episodes,
nr_runs: cstore.nr_runs,
acquisition_function: cstore.acq_fun,
epsilon: cstore.epsilon,
});
active_bo_service.callService(
active_bo_request,
function (active_bo_response) {
pstore.setPolicy(active_bo_response.best_policy);
pstore.setWeights(active_bo_response.best_weights);
rstore.setMean(active_bo_response.reward_mean);
rstore.setStd(active_bo_response.reward_std);
}
);
const rl_request = new ROSLIB.ServiceRequest({
policy: pstore.policy,
});
rl_service.callService(rl_request, () => {}); rl_service.callService(rl_request, () => {});
} }
} }

View File

@ -1,10 +1,11 @@
<template> <template>
<v-row no-gutters justify="center" class="weight-tuner"> <v-row no-gutters justify="center" class="weight-tuner">
<!-- eslint-disable-next-line --> <!-- eslint-disable-next-line -->
<v-col v-for="(_, idx) in weights"> <v-col v-for="(weight, idx) in weights">
<div class="weight-container"> <div class="weight-container">
<vue-slider <vue-slider
v-model="weights[idx]" :value="weight"
@change="updateWeight(idx, $event)"
direction="btt" direction="btt"
:height="100" :height="100"
:min="-1" :min="-1"
@ -26,6 +27,13 @@ import { computed } from "vue";
const store = usePStore(); const store = usePStore();
const weights = computed(() => store.weights); const weights = computed(() => store.weights);
const updateWeight = (index, newValue) => {
const newWeights = weights.value.slice();
newWeights[index] = newValue;
store.setWeights(newWeights);
store.setTrigger();
};
</script> </script>
<style scoped> <style scoped>

View File

@ -0,0 +1,19 @@
function rbf(x, centre, sigma) {
return Math.exp(-Math.pow(x - centre, 2) / (2 * Math.pow(sigma, 2)));
}
export function computeRbfCurve(nrPoints, weights) {
const centers = Array.from(
{ length: weights.length },
(_, i) => (i * nrPoints) / (weights.length - 1)
);
const sigma = centers[1] / (2 * Math.sqrt(2 * Math.log(2)));
let policy = Array(nrPoints).fill(0);
for (let x = 0; x < nrPoints; x++) {
for (let i = 0; i < centers.length; i++) {
policy[x] += weights[i] * rbf(x, centers[i], sigma);
}
}
return policy;
}

View File

@ -13,7 +13,7 @@ export const useCStore = defineStore("Control Store", {
], ],
nr_episodes: 50, nr_episodes: 50,
nr_runs: 10, nr_runs: 10,
greedy: 0, epsilon: 0,
sendWeights: false, sendWeights: false,
runner: false, runner: false,
}; };
@ -22,7 +22,7 @@ export const useCStore = defineStore("Control Store", {
getUserMode: (state) => state.mode, getUserMode: (state) => state.mode,
getNrEpisodes: (state) => state.nr_episodes, getNrEpisodes: (state) => state.nr_episodes,
getNrRuns: (state) => state.nr_runs, getNrRuns: (state) => state.nr_runs,
getGreedy: (state) => state.greedy, getEpsilon: (state) => state.epsilon,
getSendWeights: (state) => state.sendWeights, getSendWeights: (state) => state.sendWeights,
getRunner: (state) => state.runner, getRunner: (state) => state.runner,
getAcq: (state) => state.acq_fun, getAcq: (state) => state.acq_fun,
@ -37,8 +37,8 @@ export const useCStore = defineStore("Control Store", {
setNrRuns(value) { setNrRuns(value) {
this.nr_runs = value; this.nr_runs = value;
}, },
setGreedy(value) { setEpsilon(value) {
this.greedy = value; this.epsilon = value;
}, },
setSendWeights() { setSendWeights() {
this.sendWeights = !this.sendWeights; this.sendWeights = !this.sendWeights;

View File

@ -5,8 +5,9 @@ export const usePStore = defineStore("Policy Store", {
return { return {
policy: Array(10).fill(0), policy: Array(10).fill(0),
nr_weights: 5, nr_weights: 5,
weights: [-1, -1, 1, 0, 0], weights: [0, 0, 0, 0, 0],
max_steps: 100, max_steps: 100,
trigger: false,
}; };
}, },
getters: { getters: {
@ -14,6 +15,7 @@ export const usePStore = defineStore("Policy Store", {
getNrWeights: (state) => state.nr_weights, getNrWeights: (state) => state.nr_weights,
getWeights: (state) => state.weights, getWeights: (state) => state.weights,
getMaxSteps: (state) => state.max_steps, getMaxSteps: (state) => state.max_steps,
getTrigger: (state) => state.trigger,
}, },
actions: { actions: {
setPolicy(value) { setPolicy(value) {
@ -23,12 +25,17 @@ export const usePStore = defineStore("Policy Store", {
setNrWeights(value) { setNrWeights(value) {
this.nr_weights = value; this.nr_weights = value;
this.weights = Array(this.nr_weights).fill(0); this.weights = Array(this.nr_weights).fill(0);
this.trigger = !this.trigger;
}, },
setWeights(value) { setWeights(value) {
this.weights = value; this.weights = value;
}, },
setMaxSteps(value) { setMaxSteps(value) {
this.max_steps = value; this.max_steps = value;
this.trigger = !this.trigger;
},
setTrigger() {
this.trigger = !this.trigger;
}, },
}, },
}); });