Skip to content

Commit

Permalink
update scripts for metadrive envs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ja4822 committed Aug 27, 2023
1 parent 6ede2c2 commit f504dd4
Show file tree
Hide file tree
Showing 12 changed files with 24 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/eval/eval_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def eval(args: EvalConfig):
if args.device == "cpu":
torch.set_num_threads(args.threads)

if "Metadrive" in cfg["task"]:
import gym
env = gym.make(cfg["task"])
env.set_target_cost(cfg["cost_limit"])

Expand Down
2 changes: 2 additions & 0 deletions examples/eval/eval_bcql.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def eval(args: EvalConfig):
if args.device == "cpu":
torch.set_num_threads(args.threads)

if "Metadrive" in cfg["task"]:
import gym
env = wrap_env(
env=gym.make(cfg["task"]),
reward_scale=cfg["reward_scale"],
Expand Down
2 changes: 2 additions & 0 deletions examples/eval/eval_bearl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def eval(args: EvalConfig):
if args.device == "cpu":
torch.set_num_threads(args.threads)

if "Metadrive" in cfg["task"]:
import gym
env = wrap_env(
env=gym.make(cfg["task"]),
reward_scale=cfg["reward_scale"],
Expand Down
2 changes: 2 additions & 0 deletions examples/eval/eval_cdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def eval(args: EvalConfig):
if args.device == "cpu":
torch.set_num_threads(args.threads)

if "Metadrive" in cfg["task"]:
import gym
env = wrap_env(
env=gym.make(cfg["task"]),
reward_scale=cfg["reward_scale"],
Expand Down
2 changes: 2 additions & 0 deletions examples/eval/eval_coptidice.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def eval(args: EvalConfig):
if args.device == "cpu":
torch.set_num_threads(args.threads)

if "Metadrive" in cfg["task"]:
import gym
env = wrap_env(
env=gym.make(cfg["task"]),
reward_scale=cfg["reward_scale"],
Expand Down
2 changes: 2 additions & 0 deletions examples/eval/eval_cpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def eval(args: EvalConfig):
if args.device == "cpu":
torch.set_num_threads(args.threads)

if "Metadrive" in cfg["task"]:
import gym
env = wrap_env(
env=gym.make(cfg["task"]),
reward_scale=cfg["reward_scale"],
Expand Down
2 changes: 2 additions & 0 deletions examples/train/train_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def train(args: BCTrainConfig):
torch.set_num_threads(args.threads)

# the cost scale is down in trainer rollout
if "Metadrive" in args.task:
import gym
env = gym.make(args.task)
data = env.get_dataset()
env.set_target_cost(args.cost_limit)
Expand Down
2 changes: 2 additions & 0 deletions examples/train/train_bcql.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def train(args: BCQLTrainConfig):
torch.set_num_threads(args.threads)

# initialize environment
if "Metadrive" in args.task:
import gym
env = gym.make(args.task)

# pre-process offline dataset
Expand Down
2 changes: 2 additions & 0 deletions examples/train/train_bearl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def train(args: BEARLTrainConfig):
torch.set_num_threads(args.threads)

# initialize environment
if "Metadrive" in args.task:
import gym
env = gym.make(args.task)

# pre-process offline dataset
Expand Down
2 changes: 2 additions & 0 deletions examples/train/train_cdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def train(args: CDTTrainConfig):
torch.set_num_threads(args.threads)

# initialize environment
if "Metadrive" in args.task:
import gym
env = gym.make(args.task)

# pre-process offline dataset
Expand Down
2 changes: 2 additions & 0 deletions examples/train/train_coptidice.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def train(args: COptiDICETrainConfig):
torch.set_num_threads(args.threads)

# initialize environment
if "Metadrive" in args.task:
import gym
env = gym.make(args.task)

# pre-process offline dataset
Expand Down
2 changes: 2 additions & 0 deletions examples/train/train_cpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def train(args: CPQTrainConfig):
torch.set_num_threads(args.threads)

# initialize environment
if "Metadrive" in args.task:
import gym
env = gym.make(args.task)

# pre-process offline dataset
Expand Down

0 comments on commit f504dd4

Please sign in to comment.