preparation for 2d policy

This commit is contained in:
Niko Feith 2023-08-28 16:37:20 +02:00
parent 4770ffb12d
commit 9a70229171
3 changed files with 46 additions and 17 deletions

View File

@ -144,8 +144,19 @@ const active_rl_eval_sub = new ROS.Topic({
const pendingRequest = ref(false);
active_rl_eval_sub.subscribe((msg) => {
pstore.setPolicy(msg.policy);
pstore.setWeights(msg.weights);
const nr_steps = msg.nr_steps;
const nr_weights = msg.nr_weights;
const pol_x = msg.policy.slice(0, nr_steps);
const pol_y = msg.policy.slice(nr_steps);
const weights_x = msg.weights.slice(0, nr_weights);
const weights_y = msg.weights.slice(nr_weights);
pstore.setPolicy(pol_x);
pstore.setPolicy_y(pol_y);
pstore.setWeights(weights_x);
pstore.setWeights(weights_y);
pendingRequest.value = true;
});
@ -162,10 +173,13 @@ watch(
return;
}
console.log("Button pressed!");
const policy = pstore.policy.concat(pstore.policy_y);
const weights = pstore.weights.concat(pstore.weights_y);
const weights_fixed = pstore.weights_fixed.concat(pstore.weights_fixed_y);
const active_eval_response = new ROS.Message({
policy: pstore.policy,
weights: pstore.weights,
overwrite_weight: pstore.weights_fixed,
policy: policy,
weights: weights,
overwrite_weight: weights_fixed,
});
console.log(active_eval_response);
@ -178,12 +192,6 @@ watch(
const active_bo_pending = ref(false);
const active_bo_request = new ROS.Topic({
ros: ros,
name: "/active_bo_request",
messageType: "active_bo_msgs/msg/ActiveBORequest",
});
const active_bo_response = new ROS.Topic({
ros: ros,
name: "/active_bo_response",
@ -191,13 +199,30 @@ const active_bo_response = new ROS.Topic({
});
active_bo_response.subscribe((msg) => {
pstore.setPolicy(msg.best_policy);
pstore.setWeights(msg.best_weights);
const nr_steps = msg.nr_steps;
const nr_weights = msg.nr_weights;
const pol_x = msg.best_policy.slice(0, nr_steps);
const pol_y = msg.best_policy.slice(nr_steps);
const weights_x = msg.best_weights.slice(0, nr_weights);
const weights_y = msg.best_weights.slice(nr_weights);
pstore.setPolicy(pol_x);
pstore.setPolicy_y(pol_y);
pstore.setWeights(weights_x);
pstore.setWeights(weights_y);
rstore.setMean(msg.reward_mean);
rstore.setStd(msg.reward_std);
active_bo_pending.value = false;
});
const active_bo_request = new ROS.Topic({
ros: ros,
name: "/active_bo_request",
messageType: "active_bo_msgs/msg/ActiveBORequest",
});
watch(
() => cstore.getRunner,
() => {
@ -207,6 +232,7 @@ watch(
fixed_seed: cstore.fixed_seed,
nr_weights: pstore.nr_weights,
max_steps: pstore.max_steps,
nr_dims: cstore.env_dim[cstore.env],
nr_episodes: cstore.nr_episodes,
nr_runs: cstore.nr_runs,
acquisition_function: cstore.acq_fun,

View File

@ -1,10 +1,14 @@
import { defineStore } from "pinia";
import {state} from "vue-tsc/out/shared";
export const useCStore = defineStore("Control Store", {
state: () => {
return {
env: "Mountain Car",
envs: ["Mountain Car", "Cartpole", "Acrobot", "Pendulum"],
env: "Reacher",
envs: ["Reacher"],
env_dim: {
Reacher: 2,
},
metric: "random",
metrics: ["random", "regular", "improvement", "max_acquisition"],
metrics_label: {
@ -39,6 +43,7 @@ export const useCStore = defineStore("Control Store", {
},
getters: {
getEnv: (state) => state.env,
getEnvdims: (state) => state.env_dims,
getMetric: (state) => state.metric,
getNrEpisodes: (state) => state.nr_episodes,
getNrRuns: (state) => state.nr_runs,

View File

@ -57,8 +57,6 @@ export const usePStore = defineStore("Policy Store", {
},
resetWeights_Fixed() {
this.weights_fixed = Array(this.nr_weights).fill(false);
},
resetWeights_Fixed_y() {
this.weights_fixed_y = Array(this.nr_weights).fill(false);
},
},