Skip to content

Commit

Permalink
re-organize example
Browse files Browse the repository at this point in the history
  • Loading branch information
Ja4822 committed Apr 18, 2023
1 parent 039f775 commit 3cff82b
Show file tree
Hide file tree
Showing 21 changed files with 29 additions and 18 deletions.
Empty file added examples/__init__.py
Empty file.
6 changes: 0 additions & 6 deletions examples/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
from .bc_configs import BCTrainConfig, BC_DEFAULT_CONFIG
from .bcql_configs import BCQLTrainConfig, BCQL_DEFAULT_CONFIG
from .bearl_configs import BEARLTrainConfig, BEARL_DEFAULT_CONFIG
from .cdt_configs import CDTTrainConfig, CDT_DEFAULT_CONFIG
from .coptidice_configs import COptiDICETrainConfig, COptiDICE_DEFAULT_CONFIG
from .cpq_configs import CPQTrainConfig, CPQ_DEFAULT_CONFIG
2 changes: 1 addition & 1 deletion examples/configs/bc_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BCTrainConfig:
dataset: str = None
seed: int = 0
device: str = "cuda:0"
thread: int = 4
threads: int = 4
actor_lr: float = 0.001
cost_limit: int = 10
episode_len: int = 300
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/bcql_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BCQLTrainConfig:
dataset: str = None
seed: int = 0
device: str = "cuda:0"
thread: int = 4
threads: int = 4
reward_scale: float = 0.1
cost_scale: float = 1
actor_lr: float = 0.001
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/bearl_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BEARLTrainConfig:
dataset: str = None
seed: int = 0
device: str = "cuda:0"
thread: int = 4
threads: int = 4
reward_scale: float = 0.1
cost_scale: float = 1
actor_lr: float = 0.001
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/cdt_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CDTTrainConfig:
# general params
seed: int = 11
device: str = "cuda:0"
thread: int = 6
threads: int = 6
# augmentation param
deg: int = 4
pf_sample: bool = False
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/coptidice_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class COptiDICETrainConfig:
dataset: str = None
seed: int = 0
device: str = "cuda:0"
thread: int = 4
threads: int = 4
reward_scale: float = 0.1
cost_scale: float = 1
actor_lr: float = 0.0001
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/cpq_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class CPQTrainConfig:
dataset: str = None
seed: int = 0
device: str = "cuda:0"
thread: int = 4
threads: int = 4
reward_scale: float = 0.1
cost_scale: float = 1
actor_lr: float = 0.0001
Expand Down
1 change: 1 addition & 0 deletions examples/eval_bc.py → examples/eval/eval_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, dataclass

import gym # noqa
import dsrl
import numpy as np
import pyrallis
from pyrallis import field
Expand Down
1 change: 1 addition & 0 deletions examples/eval_bcql.py → examples/eval/eval_bcql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, dataclass

import gym # noqa
import dsrl
import numpy as np
import pyrallis
from pyrallis import field
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions examples/eval_cdt.py → examples/eval/eval_cdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, dataclass

import gym # noqa
import dsrl
import numpy as np
import pyrallis
from pyrallis import field
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, dataclass

import gym # noqa
import dsrl
import numpy as np
import pyrallis
from pyrallis import field
Expand Down
1 change: 1 addition & 0 deletions examples/eval_cpq.py → examples/eval/eval_cpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, dataclass

import gym # noqa
import dsrl
import numpy as np
import pyrallis
from pyrallis import field
Expand Down
Empty file added examples/train/__init__.py
Empty file.
4 changes: 3 additions & 1 deletion examples/train_bc.py → examples/train/train_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from osrl.common.dataset import process_bc_dataset
from osrl.algorithms import BC, BCTrainer
from saferl.utils.exp_util import auto_name, seed_all
from configs.bc_configs import BCTrainConfig, BC_DEFAULT_CONFIG
from examples.configs.bc_configs import BCTrainConfig, BC_DEFAULT_CONFIG


@pyrallis.wrap()
def train(args: BCTrainConfig):
seed_all(args.seed)
if args.device == "cpu":
torch.set_num_threads(args.threads)

# setup logger
cfg = asdict(args)
Expand Down
4 changes: 3 additions & 1 deletion examples/train_bcql.py → examples/train/train_bcql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from osrl.common import TransitionDataset
from osrl.algorithms import BCQL, BCQLTrainer
from saferl.utils.exp_util import auto_name, seed_all
from configs.bcql_configs import BCQLTrainConfig, BCQL_DEFAULT_CONFIG
from examples.configs.bcql_configs import BCQLTrainConfig, BCQL_DEFAULT_CONFIG


@pyrallis.wrap()
def train(args: BCQLTrainConfig):
seed_all(args.seed)
if args.device == "cpu":
torch.set_num_threads(args.threads)

# setup logger
cfg = asdict(args)
Expand Down
4 changes: 3 additions & 1 deletion examples/train_bearl.py → examples/train/train_bearl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from osrl.common import TransitionDataset
from osrl.algorithms import BEARL, BEARLTrainer
from saferl.utils.exp_util import auto_name, seed_all
from configs.bearl_configs import BEARLTrainConfig, BEARL_DEFAULT_CONFIG
from examples.configs.bearl_configs import BEARLTrainConfig, BEARL_DEFAULT_CONFIG


@pyrallis.wrap()
def train(args: BEARLTrainConfig):
seed_all(args.seed)
if args.device == "cpu":
torch.set_num_threads(args.threads)

# setup logger
cfg = asdict(args)
Expand Down
4 changes: 3 additions & 1 deletion examples/train_cdt.py → examples/train/train_cdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from osrl.common import SequenceDataset
from osrl.algorithms import CDT, CDTTrainer
from saferl.utils.exp_util import auto_name, seed_all
from configs.cdt_configs import CDTTrainConfig, CDT_DEFAULT_CONFIG
from examples.configs.cdt_configs import CDTTrainConfig, CDT_DEFAULT_CONFIG


@pyrallis.wrap()
def train(args: CDTTrainConfig):
seed_all(args.seed)
if args.device == "cpu":
torch.set_num_threads(args.threads)

# setup logger
cfg = asdict(args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from osrl.common import TransitionDataset
from osrl.algorithms import COptiDICE, COptiDICETrainer
from saferl.utils.exp_util import auto_name, seed_all
from configs.coptidice_configs import COptiDICETrainConfig, COptiDICE_DEFAULT_CONFIG
from examples.configs.coptidice_configs import COptiDICETrainConfig, COptiDICE_DEFAULT_CONFIG


@pyrallis.wrap()
def train(args: COptiDICETrainConfig):
seed_all(args.seed)
if args.device == "cpu":
torch.set_num_threads(args.threads)

# setup logger
cfg = asdict(args)
Expand Down
4 changes: 3 additions & 1 deletion examples/train_cpq.py → examples/train/train_cpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from osrl.common import TransitionDataset
from osrl.algorithms import CPQ, CPQTrainer
from saferl.utils.exp_util import auto_name, seed_all
from configs.cpq_configs import CPQTrainConfig, CPQ_DEFAULT_CONFIG
from examples.configs.cpq_configs import CPQTrainConfig, CPQ_DEFAULT_CONFIG


@pyrallis.wrap()
def train(args: CPQTrainConfig):
seed_all(args.seed)
if args.device == "cpu":
torch.set_num_threads(args.threads)

# setup logger
cfg = asdict(args)
Expand Down

0 comments on commit 3cff82b

Please sign in to comment.