preparation for 2d policy

This commit is contained in:
Niko Feith 2023-08-23 17:06:18 +02:00
parent d92b84b53f
commit 4770ffb12d
3 changed files with 95 additions and 62 deletions

View File

@ -14,46 +14,52 @@ const rlstore = useRLStore();
let chartHandle; let chartHandle;
let verticalLinePlugin = { // let verticalLinePlugin = {
id: "verticalLinePlugin", // id: "verticalLinePlugin",
afterDraw: function (chart, args, options) { // afterDraw: function (chart, args, options) {
let y_Scale = chart.scales["y"]; // let y_Scale = chart.scales["y"];
let x_Scale = chart.scales["x"]; // let x_Scale = chart.scales["x"];
//
let ctx = chart.ctx; // let ctx = chart.ctx;
let xValue = options.xValue; // let xValue = options.xValue;
let x = x_Scale.getPixelForValue(xValue); // let x = x_Scale.getPixelForValue(xValue);
let top = y_Scale.top; // let top = y_Scale.top;
let bottom = y_Scale.bottom; // let bottom = y_Scale.bottom;
//
ctx.save(); // ctx.save();
ctx.beginPath(); // ctx.beginPath();
ctx.moveTo(x, top); // ctx.moveTo(x, top);
ctx.lineTo(x, bottom); // ctx.lineTo(x, bottom);
ctx.lineWidth = 1; // ctx.lineWidth = 1;
ctx.strokeStyle = "#00ff00"; // ctx.strokeStyle = "#00ff00";
ctx.stroke(); // ctx.stroke();
ctx.restore(); // ctx.restore();
}, // },
}; // };
Chart.register(verticalLinePlugin); // Chart.register(verticalLinePlugin);
function buildChart() { function buildChart() {
const policy = store.getPolicy; const policy_x = store.getPolicy;
const policy_labels = Array(policy.length) const policy_y = store.getPolicy_y;
.fill(0)
.map((_, i) => i); const policy_xy = [];
for (let i = 0; i < policy_x.length; i += 1) {
policy_xy[policy_xy.length] = { x: policy_x[i], y: policy_y[i] };
}
const RewardPlot = { const RewardPlot = {
type: "scatter",
data: { data: {
datasets: [ datasets: [
{ {
type: "line", type: "line",
xAxisID: "x", xAxisID: "x",
yAxisID: "y", yAxisID: "y",
data: policy, data: policy_xy,
borderColor: "#3DC47A", borderColor: "#3DC47A",
borderWidth: 2, borderWidth: 2,
pointRadius: 0, pointRadius: 0.1,
}, },
], ],
}, },
@ -67,9 +73,9 @@ function buildChart() {
legend: { legend: {
display: false, display: false,
}, },
verticalLinePlugin: { // verticalLinePlugin: {
xValue: rlstore.getCurrentTime, // xValue: rlstore.getCurrentTime,
}, // },
}, },
lineTension: 0, lineTension: 0,
tooltip: { tooltip: {
@ -77,19 +83,18 @@ function buildChart() {
}, },
scales: { scales: {
x: { x: {
display: true,
grid: { grid: {
display: false, display: false,
}, },
labels: policy_labels, min: -1,
ticks: { max: 1,
autoSkip: true,
},
}, },
y: { y: {
type: "linear", type: "linear",
ticks: { display: true,
beginAtZero: true, min: -1,
}, max: 1,
}, },
}, },
}, },
@ -106,13 +111,18 @@ onMounted(() => {
watch( watch(
() => rlstore.current_time, () => rlstore.current_time,
() => { () => {
const policy = store.getPolicy; const policy_x = store.getPolicy;
chartHandle.options.scales.x.labels = Array(policy.length) const policy_y = store.getPolicy_y;
.fill(0)
.map((_, i) => i); const policy_xy = [];
chartHandle.data.datasets[0].data = policy;
chartHandle.options.plugins.verticalLinePlugin.xValue = for (let i = 0; i < policy_x.length; i += 1) {
rlstore.current_time; policy_xy[policy_xy.length] = { x: policy_y[i], y: policy_x[i] };
}
chartHandle.data.datasets[0].data = policy_xy;
// chartHandle.options.plugins.verticalLinePlugin.xValue =
// rlstore.current_time;
chartHandle.update(); chartHandle.update();
} }
@ -120,15 +130,20 @@ watch(
watch( watch(
() => store.getTrigger, () => store.getTrigger,
() => { () => {
const policy = computeRbfCurve(store.getMaxSteps, store.getWeights); const policy_x = computeRbfCurve(store.getMaxSteps, store.getWeights);
store.setPolicy(policy); store.setPolicy(policy_x);
const policy_y = computeRbfCurve(store.getMaxSteps, store.getWeights_y);
store.setPolicy_y(policy_y);
chartHandle.options.scales.x.labels = Array(policy.length) const policy_xy = [];
.fill(0)
.map((_, i) => i); for (let i = 0; i < policy_x.length; i += 1) {
chartHandle.data.datasets[0].data = policy; policy_xy[policy_xy.length] = { x: policy_y[i], y: policy_x[i] };
chartHandle.options.plugins.verticalLinePlugin.xValue = }
rlstore.current_time;
chartHandle.data.datasets[0].data = policy_xy;
// chartHandle.options.plugins.verticalLinePlugin.xValue =
// rlstore.current_time;
chartHandle.update(); chartHandle.update();
} }

View File

@ -1,7 +1,7 @@
<template> <template>
<v-row no-gutters justify="center" class="v-weight-tuner"> <v-row no-gutters justify="center" class="v-weight-tuner">
<!-- eslint-disable-next-line --> <!-- eslint-disable-next-line -->
<v-col v-for="(_ , idx) in weights" :key="idx" <v-col v-for="(_ , idx) in weights_y" :key="idx"
class="v-column" class="v-column"
:style="{ height: `calc(100% / ${nrweights})` }" :style="{ height: `calc(100% / ${nrweights})` }"
> >
@ -10,7 +10,7 @@
<v-col class="v-column-content"> <v-col class="v-column-content">
<vue-slider <vue-slider
class="v-slider-margin-bottom" class="v-slider-margin-bottom"
v-model="weights[idx]" v-model="weights_y[idx]"
@change="updateWeight(idx, $event)" @change="updateWeight(idx, $event)"
direction="ltr" direction="ltr"
:min="-1" :min="-1"
@ -23,7 +23,7 @@
<v-checkbox <v-checkbox
density="compact" density="compact"
class="ma-0 v-checkbox-bottom" class="ma-0 v-checkbox-bottom"
v-model="store.weights_fixed[idx]" v-model="store.weights_fixed_y[idx]"
/> />
</v-col> </v-col>
</v-row> </v-row>
@ -43,15 +43,15 @@ const store = usePStore();
const nrweights = computed(() => store.nr_weights); const nrweights = computed(() => store.nr_weights);
const weights = computed({ const weights_y = computed({
get: () => store.weights, get: () => store.weights_y,
set: (value) => store.setWeights(value), set: (value) => store.setWeights_y(value),
}); });
const updateWeight = (index, newValue) => { const updateWeight = (index, newValue) => {
const newWeights = weights.value.slice(); const newWeights = weights_y.value.slice();
newWeights[index] = newValue; newWeights[index] = newValue;
store.setWeights(newWeights); store.setWeights_y(newWeights);
store.setTrigger(); store.setTrigger();
}; };
</script> </script>

View File

@ -4,18 +4,24 @@ export const usePStore = defineStore("Policy Store", {
state: () => { state: () => {
return { return {
policy: Array(10).fill(0), policy: Array(10).fill(0),
policy_y: Array(10).fill(0),
nr_weights: 5, nr_weights: 5,
weights: [0, 0, 0, 0, 0], weights: [0, 0, 0, 0, 0],
weights_y: [0, 0, 0, 0, 0],
weights_fixed: [false, false, false, false, false], weights_fixed: [false, false, false, false, false],
weights_fixed_y: [false, false, false, false, false],
max_steps: 100, max_steps: 100,
trigger: false, trigger: false,
}; };
}, },
getters: { getters: {
getPolicy: (state) => state.policy, getPolicy: (state) => state.policy,
getPolicy_y: (state) => state.policy_y,
getNrWeights: (state) => state.nr_weights, getNrWeights: (state) => state.nr_weights,
getWeights_Fixed: (state) => state.weights_fixed, getWeights_Fixed: (state) => state.weights_fixed,
getWeights: (state) => state.weights, getWeights: (state) => state.weights,
getWeights_Fixed_y: (state) => state.weights_fixed_y,
getWeights_y: (state) => state.weights_y,
getMaxSteps: (state) => state.max_steps, getMaxSteps: (state) => state.max_steps,
getTrigger: (state) => state.trigger, getTrigger: (state) => state.trigger,
}, },
@ -24,15 +30,24 @@ export const usePStore = defineStore("Policy Store", {
this.policy = null; this.policy = null;
this.policy = value; this.policy = value;
}, },
setPolicy_y(value) {
this.policy_y = null;
this.policy_y = value;
},
setNrWeights(value) { setNrWeights(value) {
this.nr_weights = value; this.nr_weights = value;
this.weights = Array(this.nr_weights).fill(0); this.weights = Array(this.nr_weights).fill(0);
this.weights_fixed = Array(this.nr_weights).fill(false); this.weights_fixed = Array(this.nr_weights).fill(false);
this.weights_y = Array(this.nr_weights).fill(0);
this.weights_fixed_y = Array(this.nr_weights).fill(false);
this.trigger = !this.trigger; this.trigger = !this.trigger;
}, },
setWeights(value) { setWeights(value) {
this.weights = value; this.weights = value;
}, },
setWeights_y(value) {
this.weights_y = value;
},
setMaxSteps(value) { setMaxSteps(value) {
this.max_steps = value; this.max_steps = value;
this.trigger = !this.trigger; this.trigger = !this.trigger;
@ -43,5 +58,8 @@ export const usePStore = defineStore("Policy Store", {
resetWeights_Fixed() { resetWeights_Fixed() {
this.weights_fixed = Array(this.nr_weights).fill(false); this.weights_fixed = Array(this.nr_weights).fill(false);
}, },
resetWeights_Fixed_y() {
this.weights_fixed_y = Array(this.nr_weights).fill(false);
},
}, },
}); });