preparation for 2d policy

This commit is contained in:
Niko Feith 2023-08-23 17:06:18 +02:00
parent d92b84b53f
commit 4770ffb12d
3 changed files with 95 additions and 62 deletions

View File

@ -14,46 +14,52 @@ const rlstore = useRLStore();
let chartHandle;
let verticalLinePlugin = {
id: "verticalLinePlugin",
afterDraw: function (chart, args, options) {
let y_Scale = chart.scales["y"];
let x_Scale = chart.scales["x"];
let ctx = chart.ctx;
let xValue = options.xValue;
let x = x_Scale.getPixelForValue(xValue);
let top = y_Scale.top;
let bottom = y_Scale.bottom;
ctx.save();
ctx.beginPath();
ctx.moveTo(x, top);
ctx.lineTo(x, bottom);
ctx.lineWidth = 1;
ctx.strokeStyle = "#00ff00";
ctx.stroke();
ctx.restore();
},
};
Chart.register(verticalLinePlugin);
// let verticalLinePlugin = {
// id: "verticalLinePlugin",
// afterDraw: function (chart, args, options) {
// let y_Scale = chart.scales["y"];
// let x_Scale = chart.scales["x"];
//
// let ctx = chart.ctx;
// let xValue = options.xValue;
// let x = x_Scale.getPixelForValue(xValue);
// let top = y_Scale.top;
// let bottom = y_Scale.bottom;
//
// ctx.save();
// ctx.beginPath();
// ctx.moveTo(x, top);
// ctx.lineTo(x, bottom);
// ctx.lineWidth = 1;
// ctx.strokeStyle = "#00ff00";
// ctx.stroke();
// ctx.restore();
// },
// };
// Chart.register(verticalLinePlugin);
function buildChart() {
const policy = store.getPolicy;
const policy_labels = Array(policy.length)
.fill(0)
.map((_, i) => i);
const policy_x = store.getPolicy;
const policy_y = store.getPolicy_y;
const policy_xy = [];
for (let i = 0; i < policy_x.length; i += 1) {
policy_xy[policy_xy.length] = { x: policy_x[i], y: policy_y[i] };
}
const RewardPlot = {
type: "scatter",
data: {
datasets: [
{
type: "line",
xAxisID: "x",
yAxisID: "y",
data: policy,
data: policy_xy,
borderColor: "#3DC47A",
borderWidth: 2,
pointRadius: 0,
pointRadius: 0.1,
},
],
},
@ -67,9 +73,9 @@ function buildChart() {
legend: {
display: false,
},
verticalLinePlugin: {
xValue: rlstore.getCurrentTime,
},
// verticalLinePlugin: {
// xValue: rlstore.getCurrentTime,
// },
},
lineTension: 0,
tooltip: {
@ -77,19 +83,18 @@ function buildChart() {
},
scales: {
x: {
display: true,
grid: {
display: false,
},
labels: policy_labels,
ticks: {
autoSkip: true,
},
min: -1,
max: 1,
},
y: {
type: "linear",
ticks: {
beginAtZero: true,
},
display: true,
min: -1,
max: 1,
},
},
},
@ -106,13 +111,18 @@ onMounted(() => {
watch(
() => rlstore.current_time,
() => {
const policy = store.getPolicy;
chartHandle.options.scales.x.labels = Array(policy.length)
.fill(0)
.map((_, i) => i);
chartHandle.data.datasets[0].data = policy;
chartHandle.options.plugins.verticalLinePlugin.xValue =
rlstore.current_time;
const policy_x = store.getPolicy;
const policy_y = store.getPolicy_y;
const policy_xy = [];
for (let i = 0; i < policy_x.length; i += 1) {
policy_xy[policy_xy.length] = { x: policy_y[i], y: policy_x[i] };
}
chartHandle.data.datasets[0].data = policy_xy;
// chartHandle.options.plugins.verticalLinePlugin.xValue =
// rlstore.current_time;
chartHandle.update();
}
@ -120,15 +130,20 @@ watch(
watch(
() => store.getTrigger,
() => {
const policy = computeRbfCurve(store.getMaxSteps, store.getWeights);
store.setPolicy(policy);
const policy_x = computeRbfCurve(store.getMaxSteps, store.getWeights);
store.setPolicy(policy_x);
const policy_y = computeRbfCurve(store.getMaxSteps, store.getWeights_y);
store.setPolicy_y(policy_y);
chartHandle.options.scales.x.labels = Array(policy.length)
.fill(0)
.map((_, i) => i);
chartHandle.data.datasets[0].data = policy;
chartHandle.options.plugins.verticalLinePlugin.xValue =
rlstore.current_time;
const policy_xy = [];
for (let i = 0; i < policy_x.length; i += 1) {
policy_xy[policy_xy.length] = { x: policy_y[i], y: policy_x[i] };
}
chartHandle.data.datasets[0].data = policy_xy;
// chartHandle.options.plugins.verticalLinePlugin.xValue =
// rlstore.current_time;
chartHandle.update();
}

View File

@ -1,7 +1,7 @@
<template>
<v-row no-gutters justify="center" class="v-weight-tuner">
<!-- eslint-disable-next-line -->
<v-col v-for="(_ , idx) in weights" :key="idx"
<v-col v-for="(_ , idx) in weights_y" :key="idx"
class="v-column"
:style="{ height: `calc(100% / ${nrweights})` }"
>
@ -10,7 +10,7 @@
<v-col class="v-column-content">
<vue-slider
class="v-slider-margin-bottom"
v-model="weights[idx]"
v-model="weights_y[idx]"
@change="updateWeight(idx, $event)"
direction="ltr"
:min="-1"
@ -23,7 +23,7 @@
<v-checkbox
density="compact"
class="ma-0 v-checkbox-bottom"
v-model="store.weights_fixed[idx]"
v-model="store.weights_fixed_y[idx]"
/>
</v-col>
</v-row>
@ -43,15 +43,15 @@ const store = usePStore();
const nrweights = computed(() => store.nr_weights);
const weights = computed({
get: () => store.weights,
set: (value) => store.setWeights(value),
const weights_y = computed({
get: () => store.weights_y,
set: (value) => store.setWeights_y(value),
});
const updateWeight = (index, newValue) => {
const newWeights = weights.value.slice();
const newWeights = weights_y.value.slice();
newWeights[index] = newValue;
store.setWeights(newWeights);
store.setWeights_y(newWeights);
store.setTrigger();
};
</script>

View File

@ -4,18 +4,24 @@ export const usePStore = defineStore("Policy Store", {
state: () => {
return {
policy: Array(10).fill(0),
policy_y: Array(10).fill(0),
nr_weights: 5,
weights: [0, 0, 0, 0, 0],
weights_y: [0, 0, 0, 0, 0],
weights_fixed: [false, false, false, false, false],
weights_fixed_y: [false, false, false, false, false],
max_steps: 100,
trigger: false,
};
},
getters: {
getPolicy: (state) => state.policy,
getPolicy_y: (state) => state.policy_y,
getNrWeights: (state) => state.nr_weights,
getWeights_Fixed: (state) => state.weights_fixed,
getWeights: (state) => state.weights,
getWeights_Fixed_y: (state) => state.weights_fixed_y,
getWeights_y: (state) => state.weights_y,
getMaxSteps: (state) => state.max_steps,
getTrigger: (state) => state.trigger,
},
@ -24,15 +30,24 @@ export const usePStore = defineStore("Policy Store", {
this.policy = null;
this.policy = value;
},
setPolicy_y(value) {
this.policy_y = null;
this.policy_y = value;
},
setNrWeights(value) {
this.nr_weights = value;
this.weights = Array(this.nr_weights).fill(0);
this.weights_fixed = Array(this.nr_weights).fill(false);
this.weights_y = Array(this.nr_weights).fill(0);
this.weights_fixed_y = Array(this.nr_weights).fill(false);
this.trigger = !this.trigger;
},
setWeights(value) {
this.weights = value;
},
setWeights_y(value) {
this.weights_y = value;
},
setMaxSteps(value) {
this.max_steps = value;
this.trigger = !this.trigger;
@ -43,5 +58,8 @@ export const usePStore = defineStore("Policy Store", {
resetWeights_Fixed() {
this.weights_fixed = Array(this.nr_weights).fill(false);
},
resetWeights_Fixed_y() {
this.weights_fixed_y = Array(this.nr_weights).fill(false);
},
},
});