diff --git a/InteractiveBO/src/components/RosBar.vue b/InteractiveBO/src/components/RosBar.vue index 8cb2b19..0811e5a 100644 --- a/InteractiveBO/src/components/RosBar.vue +++ b/InteractiveBO/src/components/RosBar.vue @@ -144,8 +144,19 @@ const active_rl_eval_sub = new ROS.Topic({ const pendingRequest = ref(false); active_rl_eval_sub.subscribe((msg) => { - pstore.setPolicy(msg.policy); - pstore.setWeights(msg.weights); + const nr_steps = msg.nr_steps; + const nr_weights = msg.nr_weights; + + const pol_x = msg.policy.slice(0, nr_steps); + const pol_y = msg.policy.slice(nr_steps); + + const weights_x = msg.weights.slice(0, nr_weights); + const weights_y = msg.weights.slice(nr_weights); + + pstore.setPolicy(pol_x); + pstore.setPolicy_y(pol_y); + pstore.setWeights(weights_x); + pstore.setWeights(weights_y); pendingRequest.value = true; }); @@ -162,10 +173,13 @@ watch( return; } console.log("Button pressed!"); + const policy = pstore.policy.concat(pstore.policy_y); + const weights = pstore.weights.concat(pstore.weights_y); + const weights_fixed = pstore.weights_fixed.concat(pstore.weights_fixed_y); const active_eval_response = new ROS.Message({ - policy: pstore.policy, - weights: pstore.weights, - overwrite_weight: pstore.weights_fixed, + policy: policy, + weights: weights, + overwrite_weight: weights_fixed, }); console.log(active_eval_response); @@ -178,12 +192,6 @@ watch( const active_bo_pending = ref(false); -const active_bo_request = new ROS.Topic({ - ros: ros, - name: "/active_bo_request", - messageType: "active_bo_msgs/msg/ActiveBORequest", -}); - const active_bo_response = new ROS.Topic({ ros: ros, name: "/active_bo_response", @@ -191,13 +199,30 @@ const active_bo_response = new ROS.Topic({ }); active_bo_response.subscribe((msg) => { - pstore.setPolicy(msg.best_policy); - pstore.setWeights(msg.best_weights); + const nr_steps = msg.nr_steps; + const nr_weights = msg.nr_weights; + + const pol_x = msg.best_policy.slice(0, nr_steps); + const pol_y = msg.best_policy.slice(nr_steps); + + const weights_x = msg.best_weights.slice(0, nr_weights); + const weights_y = msg.best_weights.slice(nr_weights); + + pstore.setPolicy(pol_x); + pstore.setPolicy_y(pol_y); + pstore.setWeights(weights_x); + pstore.setWeights(weights_y); rstore.setMean(msg.reward_mean); rstore.setStd(msg.reward_std); active_bo_pending.value = false; }); +const active_bo_request = new ROS.Topic({ + ros: ros, + name: "/active_bo_request", + messageType: "active_bo_msgs/msg/ActiveBORequest", +}); + watch( () => cstore.getRunner, () => { @@ -207,6 +232,7 @@ watch( fixed_seed: cstore.fixed_seed, nr_weights: pstore.nr_weights, max_steps: pstore.max_steps, + nr_dims: cstore.env_dim[cstore.env], nr_episodes: cstore.nr_episodes, nr_runs: cstore.nr_runs, acquisition_function: cstore.acq_fun, diff --git a/InteractiveBO/src/store/ControlStore.js b/InteractiveBO/src/store/ControlStore.js index cf5c091..fc4410f 100644 --- a/InteractiveBO/src/store/ControlStore.js +++ b/InteractiveBO/src/store/ControlStore.js @@ -1,10 +1,14 @@ import { defineStore } from "pinia"; +import {state} from "vue-tsc/out/shared"; export const useCStore = defineStore("Control Store", { state: () => { return { - env: "Mountain Car", - envs: ["Mountain Car", "Cartpole", "Acrobot", "Pendulum"], + env: "Reacher", + envs: ["Reacher"], + env_dim: { + Reacher: 2, + }, metric: "random", metrics: ["random", "regular", "improvement", "max_acquisition"], metrics_label: { @@ -39,6 +43,7 @@ export const useCStore = defineStore("Control Store", { }, getters: { getEnv: (state) => state.env, + getEnvdims: (state) => state.env_dims, getMetric: (state) => state.metric, getNrEpisodes: (state) => state.nr_episodes, getNrRuns: (state) => state.nr_runs, diff --git a/InteractiveBO/src/store/PolicyStore.js b/InteractiveBO/src/store/PolicyStore.js index 8dec8d9..2595ca3 100644 --- a/InteractiveBO/src/store/PolicyStore.js +++ b/InteractiveBO/src/store/PolicyStore.js @@ -57,8 +57,6 @@ export const usePStore = defineStore("Policy Store", { }, resetWeights_Fixed() { this.weights_fixed = Array(this.nr_weights).fill(false); - }, - resetWeights_Fixed_y() { this.weights_fixed_y = Array(this.nr_weights).fill(false); }, },