integration of all but acrobot envs and msg + srvs adaption

This commit is contained in:
Niko Feith 2023-05-24 17:46:18 +02:00
parent 75e5b034df
commit 6d09535113
6 changed files with 83 additions and 62 deletions

View File

@ -1,11 +1,18 @@
<template> <template>
<v-row> <v-row>
<v-col> <v-col>
<v-select <v-row class="d-flex justify-space-evenly">
:items="options" <v-select
label="User Mode" :items="envs"
v-model="modeSelector" label="Environments"
></v-select> v-model="envSelector"
></v-select>
<v-select
:items="metrics"
label="Metric"
v-model="metricSelector"
></v-select>
</v-row>
<v-slider <v-slider
class="control-slider" class="control-slider"
label="maximum steps" label="maximum steps"
@ -52,8 +59,8 @@
></v-select> ></v-select>
<v-slider <v-slider
class="control-slider" class="control-slider"
label="greedy" label="Metric Parameter"
v-model="cstore.epsilon" v-model="cstore.metric_parameter"
:step="0.01" :step="0.01"
:min="0" :min="0"
:max="1" :max="1"
@ -78,7 +85,8 @@ import { computed, watch, ref } from "vue";
const pstore = usePStore(); const pstore = usePStore();
const cstore = useCStore(); const cstore = useCStore();
const options = cstore.user_modes; const envs = cstore.envs;
const metrics = cstore.metircs;
const acq_funs = cstore.acq_funs; const acq_funs = cstore.acq_funs;
const episodeRef = ref(true); const episodeRef = ref(true);
@ -96,9 +104,14 @@ const baseFuncsComp = computed({
set: (value) => pstore.setNrWeights(value), set: (value) => pstore.setNrWeights(value),
}); });
const modeSelector = computed({ const envSelector = computed({
get: () => cstore.mode, get: () => cstore.env,
set: (value) => cstore.setUserMode(value), set: (value) => cstore.setEnv(value),
});
const metricSelector = computed({
get: () => cstore.metric,
set: (value) => cstore.setMetric(value),
}); });
const acqSelector = computed({ const acqSelector = computed({
@ -107,9 +120,9 @@ const acqSelector = computed({
}); });
watch( watch(
() => cstore.getUserMode, () => cstore.getMetric,
() => { () => {
const usrMode = cstore.getUserMode; const usrMode = cstore.getMetric;
if (usrMode === "manually") { if (usrMode === "manually") {
episodeRef.value = true; episodeRef.value = true;

View File

@ -4,16 +4,16 @@
<script setup> <script setup>
import { onMounted, ref, watch } from "vue"; import { onMounted, ref, watch } from "vue";
import { useMCStore } from "@/store/MountainCarStore"; import { useRLStore } from "@/store/RLStore";
const mcstore = useMCStore(); const rlstore = useRLStore();
const imageData = ref(null); const imageData = ref(null);
const renderImage = (width, height) => { const renderImage = (width, height) => {
const red = mcstore.getRed; const red = rlstore.getRed;
const green = mcstore.getGreen; const green = rlstore.getGreen;
const blue = mcstore.getBlue; const blue = rlstore.getBlue;
const rgbData = new Uint8ClampedArray(height * width * 4); const rgbData = new Uint8ClampedArray(height * width * 4);
for (let i = 0; i < height * width; i++) { for (let i = 0; i < height * width; i++) {
@ -38,17 +38,17 @@ const drawImage = (width, height) => {
}; };
onMounted(() => { onMounted(() => {
const height = mcstore.getHeight; const height = rlstore.getHeight;
const width = mcstore.getWidth; const width = rlstore.getWidth;
renderImage(width, height); renderImage(width, height);
drawImage(width, height); drawImage(width, height);
}); });
watch( watch(
() => mcstore.trigger, () => rlstore.trigger,
() => { () => {
const height = mcstore.getHeight; const height = rlstore.getHeight;
const width = mcstore.getWidth; const width = rlstore.getWidth;
renderImage(width, height); renderImage(width, height);
drawImage(width, height); drawImage(width, height);
} }

View File

@ -30,18 +30,16 @@ import * as ROS from "roslib";
import { useBWStore } from "@/store/BaseWebsiteStore"; import { useBWStore } from "@/store/BaseWebsiteStore";
import { usePStore } from "@/store/PolicyStore"; import { usePStore } from "@/store/PolicyStore";
import { useCStore } from "@/store/ControlStore"; import { useCStore } from "@/store/ControlStore";
import { useRStore } from "@/store/RewardStore";
import { useRLStore } from "@/store/RLStore";
import { reactive, ref, watch } from "vue"; import { reactive, ref, watch } from "vue";
import * as ROSLIB from "roslib";
import { useRStore } from "@/store/RewardStore";
import { useMCStore } from "@/store/MountainCarStore";
const store = useBWStore(); const store = useBWStore();
const pstore = usePStore(); const pstore = usePStore();
const cstore = useCStore(); const cstore = useCStore();
const rstore = useRStore(); const rstore = useRStore();
const mcstore = useMCStore(); const rlstore = useRLStore();
const formState = reactive({ const formState = reactive({
ipaddress: "localhost", ipaddress: "localhost",
@ -108,25 +106,25 @@ watch(stateCounter.value, (newValue) => {
}); });
// Policy Service // Policy Service
const policy_service = new ROSLIB.Service({ const policy_service = new ROS.Service({
ros: ros, ros: ros,
name: "/policy_srv", name: "/policy_srv",
serviceType: "active_bo_msgs/srv/WeightToPolicy", serviceType: "active_bo_msgs/srv/WeightToPolicy",
}); });
// RL Service + Feedback Suscriber // RL Service + Feedback Subscriber
const rl_feedback_subscriber = new ROSLIB.Topic({ const rl_feedback_subscriber = new ROS.Topic({
ros: ros, ros: ros,
name: "/rl_feedback", name: "/rl_feedback",
messageType: "active_bo_msgs/msg/ImageFeedback", messageType: "active_bo_msgs/msg/ImageFeedback",
}); });
rl_feedback_subscriber.subscribe((msg) => { rl_feedback_subscriber.subscribe((msg) => {
mcstore.setDim(msg.height, msg.width); rlstore.setDim(msg.height, msg.width);
mcstore.setRgbArrays(msg.red, msg.green, msg.blue); rlstore.setRgbArrays(msg.red, msg.green, msg.blue);
}); });
const active_rl_eval_sub = new ROSLIB.Topic({ const active_rl_eval_sub = new ROS.Topic({
ros: ros_eval, ros: ros_eval,
name: "/active_rl_eval_request", name: "/active_rl_eval_request",
messageType: "active_bo_msgs/msg/ActiveRL", messageType: "active_bo_msgs/msg/ActiveRL",
@ -139,7 +137,7 @@ active_rl_eval_sub.subscribe((msg) => {
pendingRequest.value = true; pendingRequest.value = true;
}); });
const active_rl_eval_pub = new ROSLIB.Topic({ const active_rl_eval_pub = new ROS.Topic({
ros: ros_eval, ros: ros_eval,
name: "/active_rl_eval_response", name: "/active_rl_eval_response",
messageType: "active_bo_msgs/msg/ActiveRL", messageType: "active_bo_msgs/msg/ActiveRL",
@ -159,10 +157,10 @@ const active_rl_eval_pub = new ROSLIB.Topic({
watch( watch(
() => cstore.getSendWeights, () => cstore.getSendWeights,
() => { () => {
const usr_mode = cstore.getUserMode; const metric = cstore.getMetric;
if (usr_mode === "manually") { if (metric === "manually") {
const policy_request = new ROSLIB.ServiceRequest({ const policy_request = new ROS.ServiceRequest({
weights: pstore.policy, weights: pstore.policy,
nr_steps: pstore.weights, nr_steps: pstore.weights,
}); });
@ -170,12 +168,12 @@ watch(
policy_service.callService(policy_request, function (result) { policy_service.callService(policy_request, function (result) {
pstore.setPolicy(result.policy); pstore.setPolicy(result.policy);
}); });
} else if (usr_mode === "active BO") { } else if (metric === "active BO") {
if (!pendingRequest.value) { if (!pendingRequest.value) {
return; return;
} }
console.log("Button pressed!"); console.log("Button pressed!");
const active_eval_response = new ROSLIB.Message({ const active_eval_response = new ROS.Message({
policy: pstore.policy, policy: pstore.policy,
weights: pstore.weights, weights: pstore.weights,
}); });
@ -190,13 +188,13 @@ watch(
} }
); );
const rl_service = new ROSLIB.Service({ const rl_service = new ROS.Service({
ros: ros, ros: ros,
name: "/rl_srv", name: "/rl_srv",
serviceType: "active_bo_msgs/srv/RLRollOut", serviceType: "active_bo_msgs/srv/RLRollOut",
}); });
const bo_service = new ROSLIB.Service({ const bo_service = new ROS.Service({
ros: ros, ros: ros,
name: "/bo_srv", name: "/bo_srv",
serviceType: "active_bo_msgs/srv/BO", serviceType: "active_bo_msgs/srv/BO",
@ -209,13 +207,13 @@ const bo_service = new ROSLIB.Service({
// }); // });
const active_bo_pending = ref(false); const active_bo_pending = ref(false);
const active_bo_request = new ROSLIB.Topic({ const active_bo_request = new ROS.Topic({
ros: ros, ros: ros,
name: "/active_bo_request", name: "/active_bo_request",
messageType: "active_bo_msgs/msg/ActiveBORequest", messageType: "active_bo_msgs/msg/ActiveBORequest",
}); });
const active_bo_response = new ROSLIB.Topic({ const active_bo_response = new ROS.Topic({
ros: ros, ros: ros,
name: "/active_bo_response", name: "/active_bo_response",
messageType: "active_bo_msgs/msg/ActiveBOResponse", messageType: "active_bo_msgs/msg/ActiveBOResponse",
@ -232,10 +230,11 @@ active_bo_response.subscribe((msg) => {
watch( watch(
() => cstore.getRunner, () => cstore.getRunner,
() => { () => {
const usr_mode = cstore.getUserMode; const usr_mode = cstore.getMetric;
if (usr_mode === "manually") { if (usr_mode === "manually") {
const rl_request = new ROSLIB.ServiceRequest({ const rl_request = new ROS.ServiceRequest({
env: cstore.env,
policy: pstore.policy, policy: pstore.policy,
}); });
@ -243,7 +242,7 @@ watch(
rstore.addMeanManually(rl_response.reward); rstore.addMeanManually(rl_response.reward);
}); });
} else if (usr_mode === "BO") { } else if (usr_mode === "BO") {
const bo_request = new ROSLIB.ServiceRequest({ const bo_request = new ROS.ServiceRequest({
nr_weights: pstore.nr_weights, nr_weights: pstore.nr_weights,
max_steps: pstore.max_steps, max_steps: pstore.max_steps,
nr_episodes: cstore.nr_episodes, nr_episodes: cstore.nr_episodes,
@ -257,23 +256,26 @@ watch(
rstore.setMean(bo_response.reward_mean); rstore.setMean(bo_response.reward_mean);
rstore.setStd(bo_response.reward_std); rstore.setStd(bo_response.reward_std);
}); });
const rl_request = new ROSLIB.ServiceRequest({ const rl_request = new ROS.ServiceRequest({
env: cstore.env,
policy: pstore.policy, policy: pstore.policy,
}); });
rl_service.callService(rl_request, () => {}); rl_service.callService(rl_request, () => {});
} else if (usr_mode === "active BO" && !active_bo_pending.value) { } else if (usr_mode === "active BO" && !active_bo_pending.value) {
const request_msg = new ROSLIB.Message({ const request_msg = new ROS.Message({
env: cstore.env,
nr_weights: pstore.nr_weights, nr_weights: pstore.nr_weights,
max_steps: pstore.max_steps, max_steps: pstore.max_steps,
nr_episodes: cstore.nr_episodes, nr_episodes: cstore.nr_episodes,
nr_runs: cstore.nr_runs, nr_runs: cstore.nr_runs,
acquisition_function: cstore.acq_fun, acquisition_function: cstore.acq_fun,
epsilon: cstore.epsilon, metric_parameter: cstore.metric_parameter,
}); });
active_bo_request.publish(request_msg); active_bo_request.publish(request_msg);
const rl_request = new ROSLIB.ServiceRequest({ const rl_request = new ROS.ServiceRequest({
env: cstore.env,
policy: pstore.policy, policy: pstore.policy,
}); });

View File

@ -58,7 +58,7 @@ import WeightTuner from "@/components/WeightTuner.vue";
import RewardPlot from "@/components/RewardPlot.vue"; import RewardPlot from "@/components/RewardPlot.vue";
import PolicyPlot from "@/components/PolicyPlot.vue"; import PolicyPlot from "@/components/PolicyPlot.vue";
import ControlPanel from "@/components/ControlPanel.vue"; import ControlPanel from "@/components/ControlPanel.vue";
import MountainCarCanvas from "@/components/MountainCarCanvas.vue"; import MountainCarCanvas from "@/components/RLCanvas.vue";
</script> </script>
<style scoped> <style scoped>

View File

@ -3,8 +3,10 @@ import { defineStore } from "pinia";
export const useCStore = defineStore("Control Store", { export const useCStore = defineStore("Control Store", {
state: () => { state: () => {
return { return {
mode: "manually", env: "Mountain Car",
user_modes: ["manually", "BO", "active BO"], envs: ["Mountain Car", "Cartpole", "Acrobot", "Pendulum"],
metric: "manually",
metircs: ["manually", "BO", "active BO"],
acq_fun: "Expected Improvement", acq_fun: "Expected Improvement",
acq_funs: [ acq_funs: [
"Expected Improvement", "Expected Improvement",
@ -13,23 +15,27 @@ export const useCStore = defineStore("Control Store", {
], ],
nr_episodes: 50, nr_episodes: 50,
nr_runs: 10, nr_runs: 10,
epsilon: 0, metric_parameter: 0,
sendWeights: false, sendWeights: false,
runner: false, runner: false,
}; };
}, },
getters: { getters: {
getUserMode: (state) => state.mode, getEnv: (state) => state.env,
getMetric: (state) => state.metric,
getNrEpisodes: (state) => state.nr_episodes, getNrEpisodes: (state) => state.nr_episodes,
getNrRuns: (state) => state.nr_runs, getNrRuns: (state) => state.nr_runs,
getEpsilon: (state) => state.epsilon, getMetricParameter: (state) => state.metric_parameter,
getSendWeights: (state) => state.sendWeights, getSendWeights: (state) => state.sendWeights,
getRunner: (state) => state.runner, getRunner: (state) => state.runner,
getAcq: (state) => state.acq_fun, getAcq: (state) => state.acq_fun,
}, },
actions: { actions: {
setUserMode(value) { setEnv(value) {
this.mode = value; this.env = value;
},
setMetric(value) {
this.metric = value;
}, },
setNrEpisodes(value) { setNrEpisodes(value) {
this.nr_episodes = value; this.nr_episodes = value;
@ -37,8 +43,8 @@ export const useCStore = defineStore("Control Store", {
setNrRuns(value) { setNrRuns(value) {
this.nr_runs = value; this.nr_runs = value;
}, },
setEpsilon(value) { setMetricParameter(value) {
this.epsilon = value; this.metric_parameter = value;
}, },
setSendWeights() { setSendWeights() {
this.sendWeights = !this.sendWeights; this.sendWeights = !this.sendWeights;

View File

@ -1,6 +1,6 @@
import { defineStore } from "pinia"; import { defineStore } from "pinia";
export const useMCStore = defineStore("Mountain Car Store", { export const useRLStore = defineStore("RL Store", {
state: () => { state: () => {
return { return {
red: Array(153600).fill(120), red: Array(153600).fill(120),