Franka added

This commit is contained in:
Niko Feith 2023-09-19 14:09:14 +02:00
parent d856697d06
commit 943faa48f0
6 changed files with 82 additions and 38 deletions

View File

@ -5,45 +5,25 @@
<script setup> <script setup>
import { onMounted, watch } from "vue"; import { onMounted, watch } from "vue";
import { usePStore } from "@/store/PolicyStore"; import { usePStore } from "@/store/PolicyStore";
import { useCStore } from "@/store/ControlStore";
import { useRLStore } from "@/store/RLStore";
import { Chart } from "chart.js/auto"; import { Chart } from "chart.js/auto";
import { computeRbfCurve } from "@/js_funs/online_rbf_policy"; import { computeRbfCurve } from "@/js_funs/online_rbf_policy";
import { useRLStore } from "@/store/RLStore";
const store = usePStore(); const store = usePStore();
const rlstore = useRLStore(); const rlstore = useRLStore();
const cstore = useCStore();
let chartHandle; let chartHandle;
// let verticalLinePlugin = {
// id: "verticalLinePlugin",
// afterDraw: function (chart, args, options) {
// let y_Scale = chart.scales["y"];
// let x_Scale = chart.scales["x"];
//
// let ctx = chart.ctx;
// let xValue = options.xValue;
// let x = x_Scale.getPixelForValue(xValue);
// let top = y_Scale.top;
// let bottom = y_Scale.bottom;
//
// ctx.save();
// ctx.beginPath();
// ctx.moveTo(x, top);
// ctx.lineTo(x, bottom);
// ctx.lineWidth = 1;
// ctx.strokeStyle = "#00ff00";
// ctx.stroke();
// ctx.restore();
// },
// };
// Chart.register(verticalLinePlugin);
function buildChart() { function buildChart() {
const policy_x = store.getPolicy; const policy_x = store.getPolicy;
const policy_y = store.getPolicy_y; const policy_y = store.getPolicy_y;
const current_x = policy_x[rlstore.current_time]; const current_x =
const current_y = policy_y[rlstore.current_time]; (cstore.bounds[cstore.env][2] + cstore.bounds[cstore.env][3]) / 2;
const current_y =
(cstore.bounds[cstore.env][3] + cstore.bounds[cstore.env][1]) / 2;
const policy_xy = []; const policy_xy = [];
@ -74,6 +54,34 @@ function buildChart() {
pointRadius: 3, pointRadius: 3,
order: 0, order: 0,
}, },
{
type: "scatter",
xAxisID: "x",
yAxisID: "y",
data: [{ x: 0.5, y: 0.3 }],
backgroundColor: "green",
pointRadius: 5,
order: 1,
},
{
type: "scatter",
xAxisID: "x",
yAxisID: "y",
data: [{ x: 0.5, y: -0.3 }],
backgroundColor: "blue",
pointRadius: 5,
order: 1,
},
{
type: "scatter",
xAxisID: "x",
yAxisID: "y",
data: [{ x: 0.5, y: 0 }],
backgroundColor: "red",
pointRadius: 75,
pointStyle: 'rect',
order: 0,
},
], ],
}, },
options: { options: {
@ -100,14 +108,14 @@ function buildChart() {
grid: { grid: {
display: false, display: false,
}, },
min: -1.1, min: cstore.bounds[cstore.env][0] - 0.1,
max: 1.1, max: cstore.bounds[cstore.env][2] + 0.1,
}, },
y: { y: {
type: "linear", type: "linear",
display: true, display: true,
min: -1.1, min: cstore.bounds[cstore.env][1] - 0.1,
max: 1.1, max: cstore.bounds[cstore.env][3] + 0.1,
}, },
}, },
}, },
@ -162,6 +170,27 @@ watch(
chartHandle.update(); chartHandle.update();
} }
); );
watch(
() => cstore.env,
() => {
const x_axis_lim = {
min: cstore.bounds[cstore.env][0] - 0.1,
max: cstore.bounds[cstore.env][2] + 0.1,
};
const y_axis_lim = {
min: cstore.bounds[cstore.env][1] - 0.1,
max: cstore.bounds[cstore.env][3] + 0.1,
};
chartHandle.options.scales.x.min = x_axis_lim.min;
chartHandle.options.scales.x.max = x_axis_lim.max;
chartHandle.options.scales.y.min = y_axis_lim.min;
chartHandle.options.scales.y.max = y_axis_lim.max;
chartHandle.update();
}
);
</script> </script>
<style scoped></style> <style scoped></style>

View File

@ -153,6 +153,7 @@ active_rl_eval_sub.subscribe((msg) => {
const weights_x = msg.weights.slice(0, nr_weights); const weights_x = msg.weights.slice(0, nr_weights);
const weights_y = msg.weights.slice(nr_weights); const weights_y = msg.weights.slice(nr_weights);
pstore.setTrigger();
pstore.setPolicy(pol_x); pstore.setPolicy(pol_x);
pstore.setPolicy_y(pol_y); pstore.setPolicy_y(pol_y);
pstore.setWeights(weights_x); pstore.setWeights(weights_x);
@ -210,6 +211,7 @@ active_bo_response.subscribe((msg) => {
pstore.setPolicy(pol_x); pstore.setPolicy(pol_x);
pstore.setPolicy_y(pol_y); pstore.setPolicy_y(pol_y);
pstore.setTrigger();
pstore.setWeights(weights_x); pstore.setWeights(weights_x);
pstore.setWeights_y(weights_y); pstore.setWeights_y(weights_y);
rstore.setMean(msg.reward_mean); rstore.setMean(msg.reward_mean);

View File

@ -9,9 +9,10 @@
@change="updateWeight(idx, $event)" @change="updateWeight(idx, $event)"
direction="btt" direction="btt"
:height="100" :height="100"
:min="-1" :min="cstore.bounds[cstore.env][0]"
:max="1" :max="cstore.bounds[cstore.env][2]"
:interval="0.01" :interval="0.01"
:disabled="idx === 0"
/> />
<v-checkbox <v-checkbox
class="ma-0 h-checkbox-bottom" class="ma-0 h-checkbox-bottom"
@ -24,12 +25,14 @@
<script setup> <script setup>
import { usePStore } from "@/store/PolicyStore"; import { usePStore } from "@/store/PolicyStore";
import { useCStore } from "@/store/ControlStore";
import VueSlider from "vue-slider-component"; import VueSlider from "vue-slider-component";
import "vue-slider-component/theme/default.css"; import "vue-slider-component/theme/default.css";
import { computed } from "vue"; import { computed } from "vue";
const store = usePStore(); const store = usePStore();
const cstore = useCStore();
const weights = computed({ const weights = computed({
get: () => store.weights, get: () => store.weights,

View File

@ -13,10 +13,11 @@
v-model="weights_y[idx]" v-model="weights_y[idx]"
@change="updateWeight(idx, $event)" @change="updateWeight(idx, $event)"
direction="ltr" direction="ltr"
:min="-1" :min="cstore.bounds[cstore.env][1]"
:max="1" :max="cstore.bounds[cstore.env][3]"
:interval="0.01" :interval="0.01"
:style="{}" :style="{}"
:disabled="idx === 0"
/> />
</v-col> </v-col>
<v-col class="v-column-content" :style="{ flexGrow: 0 }"> <v-col class="v-column-content" :style="{ flexGrow: 0 }">
@ -34,12 +35,14 @@
<script setup> <script setup>
import { usePStore } from "@/store/PolicyStore"; import { usePStore } from "@/store/PolicyStore";
import { useCStore } from "@/store/ControlStore";
import VueSlider from "vue-slider-component"; import VueSlider from "vue-slider-component";
import "vue-slider-component/theme/default.css"; import "vue-slider-component/theme/default.css";
import { computed } from "vue"; import { computed } from "vue";
const store = usePStore(); const store = usePStore();
const cstore = useCStore();
const nrweights = computed(() => store.nr_weights); const nrweights = computed(() => store.nr_weights);

View File

@ -3,11 +3,17 @@ import { defineStore } from "pinia";
export const useCStore = defineStore("Control Store", { export const useCStore = defineStore("Control Store", {
state: () => { state: () => {
return { return {
env: "Reacher", env: "Franka",
envs: ["Reacher", "Finger"], envs: ["Reacher", "Finger", "Franka"],
env_dim: { env_dim: {
Reacher: 2, Reacher: 2,
Finger: 2, Finger: 2,
Franka: 2,
},
bounds: {
Reacher: [-1, -1, 1, 1],
Finger: [-1, -1, 1, 1],
Franka: [0.3, -0.35, 0.7, 0.35],
}, },
metric: "random", metric: "random",
metrics: ["random", "regular", "improvement", "max_acquisition"], metrics: ["random", "regular", "improvement", "max_acquisition"],
@ -23,7 +29,7 @@ export const useCStore = defineStore("Control Store", {
improvement: [0.0, 1.0, 0.01], improvement: [0.0, 1.0, 0.01],
max_acquisition: [1, 200, 1], max_acquisition: [1, 200, 1],
}, },
acq_fun: "Expected Improvement", acq_fun: "Preference Expected Improvement",
acq_funs: [ acq_funs: [
"Expected Improvement", "Expected Improvement",
"Probability of Improvement", "Probability of Improvement",

View File

@ -53,6 +53,7 @@ export const usePStore = defineStore("Policy Store", {
this.trigger = !this.trigger; this.trigger = !this.trigger;
}, },
setTrigger() { setTrigger() {
console.log('Toggling trigger');
this.trigger = !this.trigger; this.trigger = !this.trigger;
}, },
resetWeights_Fixed() { resetWeights_Fixed() {