active bo implemented
not working due to watch in RosBar.vue line 150
This commit is contained in:
parent
9e61c3244b
commit
1db897b9d5
@ -9,7 +9,7 @@
|
||||
<v-slider
|
||||
class="control-slider"
|
||||
label="maximum steps"
|
||||
v-model="pstore.max_steps"
|
||||
v-model="maxStepComp"
|
||||
:step="1"
|
||||
:min="10"
|
||||
:max="200"
|
||||
@ -53,12 +53,12 @@
|
||||
<v-slider
|
||||
class="control-slider"
|
||||
label="greedy"
|
||||
v-model="cstore.greedy"
|
||||
v-model="cstore.epsilon"
|
||||
:step="0.01"
|
||||
:min="0"
|
||||
:max="1"
|
||||
thumb-label
|
||||
:disabled="greedyRef"
|
||||
:disabled="epsilonRef"
|
||||
/>
|
||||
<v-row class="d-flex justify-space-evenly mb-9">
|
||||
<v-btn color="primary" @click="cstore.setSendWeights()"
|
||||
@ -84,7 +84,12 @@ const acq_funs = cstore.acq_funs;
|
||||
const episodeRef = ref(true);
|
||||
const runRef = ref(true);
|
||||
const acqRef = ref(true);
|
||||
const greedyRef = ref(true);
|
||||
const epsilonRef = ref(true);
|
||||
|
||||
const maxStepComp = computed({
|
||||
get: () => pstore.max_steps,
|
||||
set: (value) => pstore.setMaxSteps(value),
|
||||
});
|
||||
|
||||
const baseFuncsComp = computed({
|
||||
get: () => pstore.nr_weights,
|
||||
@ -110,19 +115,19 @@ watch(
|
||||
episodeRef.value = true;
|
||||
runRef.value = true;
|
||||
acqRef.value = true;
|
||||
greedyRef.value = true;
|
||||
epsilonRef.value = true;
|
||||
}
|
||||
if (usrMode === "BO") {
|
||||
episodeRef.value = false;
|
||||
runRef.value = false;
|
||||
acqRef.value = false;
|
||||
greedyRef.value = true;
|
||||
epsilonRef.value = true;
|
||||
}
|
||||
if (usrMode === "active BO") {
|
||||
episodeRef.value = false;
|
||||
runRef.value = false;
|
||||
acqRef.value = false;
|
||||
greedyRef.value = false;
|
||||
epsilonRef.value = false;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
@ -6,6 +6,7 @@
|
||||
import { onMounted, watch } from "vue";
|
||||
import { usePStore } from "@/store/PolicyStore";
|
||||
import { Chart } from "chart.js/auto";
|
||||
import { computeRbfCurve } from "@/js_funs/online_rbf_policy";
|
||||
|
||||
const store = usePStore();
|
||||
|
||||
@ -85,6 +86,20 @@ watch(
|
||||
chartHandle.update();
|
||||
}
|
||||
);
|
||||
watch(
|
||||
() => store.getTrigger,
|
||||
() => {
|
||||
const policy = computeRbfCurve(store.getMaxSteps, store.getWeights);
|
||||
store.setPolicy(policy);
|
||||
|
||||
chartHandle.options.scales.x.labels = Array(policy.length)
|
||||
.fill(0)
|
||||
.map((_, i) => i);
|
||||
chartHandle.data.datasets[0].data = policy;
|
||||
|
||||
chartHandle.update();
|
||||
}
|
||||
);
|
||||
</script>
|
||||
|
||||
<style scoped></style>
|
||||
|
@ -110,6 +110,9 @@ const policy_service = new ROSLIB.Service({
|
||||
watch(
|
||||
() => cstore.getSendWeights,
|
||||
() => {
|
||||
const usr_mode = cstore.getUserMode;
|
||||
|
||||
if (usr_mode === "manually") {
|
||||
const policy_request = new ROSLIB.ServiceRequest({
|
||||
weights: pstore.weights,
|
||||
nr_steps: pstore.max_steps,
|
||||
@ -119,6 +122,7 @@ watch(
|
||||
pstore.setPolicy(result.policy);
|
||||
});
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// RL Service + Feedback Suscriber
|
||||
@ -133,6 +137,27 @@ rl_feedback_subscriber.subscribe((msg) => {
|
||||
mcstore.setRgbArrays(msg.red, msg.green, msg.blue);
|
||||
});
|
||||
|
||||
const active_rl_eval_service = new ROSLIB.Service({
|
||||
ros: ros,
|
||||
name: "/active_rl_eval_srv",
|
||||
serviceType: "active_bo_msgs/srv/ActiveRLEval",
|
||||
});
|
||||
|
||||
active_rl_eval_service.advertise(function (request, response) {
|
||||
pstore.setPolicy(request["old_policy"]);
|
||||
pstore.setWeights(request["old_weights"]);
|
||||
|
||||
watch(
|
||||
() => cstore.getSendWeights,
|
||||
() => {
|
||||
response["new_policy"] = pstore.getPolicy;
|
||||
response["new_weights"] = pstore.getWeights;
|
||||
|
||||
return true;
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
const rl_service = new ROSLIB.Service({
|
||||
ros: ros,
|
||||
name: "/rl_srv",
|
||||
@ -145,6 +170,12 @@ const bo_service = new ROSLIB.Service({
|
||||
serviceType: "active_bo_msgs/srv/BO",
|
||||
});
|
||||
|
||||
const active_bo_service = new ROSLIB.Service({
|
||||
ros: ros,
|
||||
name: "/active_bo_srv",
|
||||
serviceType: "active_bo_msgs/srv/ActiveBO",
|
||||
});
|
||||
|
||||
watch(
|
||||
() => cstore.getRunner,
|
||||
() => {
|
||||
@ -177,6 +208,30 @@ watch(
|
||||
policy: pstore.policy,
|
||||
});
|
||||
|
||||
rl_service.callService(rl_request, () => {});
|
||||
} else if (usr_mode === "active BO") {
|
||||
const active_bo_request = new ROSLIB.ServiceRequest({
|
||||
nr_weights: pstore.nr_weights,
|
||||
max_steps: pstore.max_steps,
|
||||
nr_episodes: cstore.nr_episodes,
|
||||
nr_runs: cstore.nr_runs,
|
||||
acquisition_function: cstore.acq_fun,
|
||||
epsilon: cstore.epsilon,
|
||||
});
|
||||
|
||||
active_bo_service.callService(
|
||||
active_bo_request,
|
||||
function (active_bo_response) {
|
||||
pstore.setPolicy(active_bo_response.best_policy);
|
||||
pstore.setWeights(active_bo_response.best_weights);
|
||||
rstore.setMean(active_bo_response.reward_mean);
|
||||
rstore.setStd(active_bo_response.reward_std);
|
||||
}
|
||||
);
|
||||
const rl_request = new ROSLIB.ServiceRequest({
|
||||
policy: pstore.policy,
|
||||
});
|
||||
|
||||
rl_service.callService(rl_request, () => {});
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,11 @@
|
||||
<template>
|
||||
<v-row no-gutters justify="center" class="weight-tuner">
|
||||
<!-- eslint-disable-next-line -->
|
||||
<v-col v-for="(_, idx) in weights">
|
||||
<v-col v-for="(weight, idx) in weights">
|
||||
<div class="weight-container">
|
||||
<vue-slider
|
||||
v-model="weights[idx]"
|
||||
:value="weight"
|
||||
@change="updateWeight(idx, $event)"
|
||||
direction="btt"
|
||||
:height="100"
|
||||
:min="-1"
|
||||
@ -26,6 +27,13 @@ import { computed } from "vue";
|
||||
const store = usePStore();
|
||||
|
||||
const weights = computed(() => store.weights);
|
||||
|
||||
const updateWeight = (index, newValue) => {
|
||||
const newWeights = weights.value.slice();
|
||||
newWeights[index] = newValue;
|
||||
store.setWeights(newWeights);
|
||||
store.setTrigger();
|
||||
};
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
19
ActiveBOWeb/src/js_funs/online_rbf_policy.js
Normal file
19
ActiveBOWeb/src/js_funs/online_rbf_policy.js
Normal file
@ -0,0 +1,19 @@
|
||||
function rbf(x, centre, sigma) {
|
||||
return Math.exp(-Math.pow(x - centre, 2) / (2 * Math.pow(sigma, 2)));
|
||||
}
|
||||
|
||||
export function computeRbfCurve(nrPoints, weights) {
|
||||
const centers = Array.from(
|
||||
{ length: weights.length },
|
||||
(_, i) => (i * nrPoints) / (weights.length - 1)
|
||||
);
|
||||
const sigma = centers[1] / (2 * Math.sqrt(2 * Math.log(2)));
|
||||
|
||||
let policy = Array(nrPoints).fill(0);
|
||||
for (let x = 0; x < nrPoints; x++) {
|
||||
for (let i = 0; i < centers.length; i++) {
|
||||
policy[x] += weights[i] * rbf(x, centers[i], sigma);
|
||||
}
|
||||
}
|
||||
return policy;
|
||||
}
|
@ -13,7 +13,7 @@ export const useCStore = defineStore("Control Store", {
|
||||
],
|
||||
nr_episodes: 50,
|
||||
nr_runs: 10,
|
||||
greedy: 0,
|
||||
epsilon: 0,
|
||||
sendWeights: false,
|
||||
runner: false,
|
||||
};
|
||||
@ -22,7 +22,7 @@ export const useCStore = defineStore("Control Store", {
|
||||
getUserMode: (state) => state.mode,
|
||||
getNrEpisodes: (state) => state.nr_episodes,
|
||||
getNrRuns: (state) => state.nr_runs,
|
||||
getGreedy: (state) => state.greedy,
|
||||
getEpsilon: (state) => state.epsilon,
|
||||
getSendWeights: (state) => state.sendWeights,
|
||||
getRunner: (state) => state.runner,
|
||||
getAcq: (state) => state.acq_fun,
|
||||
@ -37,8 +37,8 @@ export const useCStore = defineStore("Control Store", {
|
||||
setNrRuns(value) {
|
||||
this.nr_runs = value;
|
||||
},
|
||||
setGreedy(value) {
|
||||
this.greedy = value;
|
||||
setEpsilon(value) {
|
||||
this.epsilon = value;
|
||||
},
|
||||
setSendWeights() {
|
||||
this.sendWeights = !this.sendWeights;
|
||||
|
@ -5,8 +5,9 @@ export const usePStore = defineStore("Policy Store", {
|
||||
return {
|
||||
policy: Array(10).fill(0),
|
||||
nr_weights: 5,
|
||||
weights: [-1, -1, 1, 0, 0],
|
||||
weights: [0, 0, 0, 0, 0],
|
||||
max_steps: 100,
|
||||
trigger: false,
|
||||
};
|
||||
},
|
||||
getters: {
|
||||
@ -14,6 +15,7 @@ export const usePStore = defineStore("Policy Store", {
|
||||
getNrWeights: (state) => state.nr_weights,
|
||||
getWeights: (state) => state.weights,
|
||||
getMaxSteps: (state) => state.max_steps,
|
||||
getTrigger: (state) => state.trigger,
|
||||
},
|
||||
actions: {
|
||||
setPolicy(value) {
|
||||
@ -23,12 +25,17 @@ export const usePStore = defineStore("Policy Store", {
|
||||
setNrWeights(value) {
|
||||
this.nr_weights = value;
|
||||
this.weights = Array(this.nr_weights).fill(0);
|
||||
this.trigger = !this.trigger;
|
||||
},
|
||||
setWeights(value) {
|
||||
this.weights = value;
|
||||
},
|
||||
setMaxSteps(value) {
|
||||
this.max_steps = value;
|
||||
this.trigger = !this.trigger;
|
||||
},
|
||||
setTrigger() {
|
||||
this.trigger = !this.trigger;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user