From d27d07006bf0c72d35c081a0b3ff79c369817f69 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 24 Nov 2022 12:16:09 +0000
Subject: [PATCH] [Formatting] Upgrade formatting libs (#705)

* init

* amend

* amend

* amend

* amend

* lint

* itertools.islice
---
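Notes (this area between the --- and the diffstat is ignored by git am,
so these remarks do not become part of the commit): the patch pins
black 22.3.0, usort 1.0.3 and flake8 4.0.1 (with the bugbear and
comprehensions plugins) and applies the resulting mechanical fixes
tree-wide. One visible black 22 change is that the power operator
loses its spaces when both operands are simple, as in this sketch of
the rewrite that appears below in torchrl/_utils.py and
torchrl/data/postprocs/postprocs.py:

    max_seed_val = 2 ** 32 - 1  # black <= 21.x
    max_seed_val = 2**32 - 1    # black >= 22.1
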
 .pre-commit-config.yaml | 13 +++--
 .../benchmark_sample_latency_over_rpc.py | 6 +--
 build_tools/setup_helpers/__init__.py | 2 +-
 build_tools/setup_helpers/extension.py | 2 +-
 docs/source/conf.py | 2 +-
 examples/ddpg/ddpg.py | 16 ++----
 examples/dqn/dqn.py | 16 ++----
 examples/dreamer/dreamer.py | 16 ++----
 examples/dreamer/dreamer_utils.py | 12 ++---
 examples/ppo/ppo.py | 9 ++--
 examples/redq/redq.py | 16 ++----
 examples/sac/sac.py | 16 ++----
 .../torchrl_features/memmap_td_distributed.py | 20 +++----
 pyproject.toml | 3 +-
 setup.cfg | 1 +
 setup.py | 11 ++--
 test/_utils_internal.py | 2 +-
 test/mocking_classes.py | 4 +-
 test/smoke_test.py | 4 +-
 test/test_actors.py | 2 +-
 test/test_collector.py | 12 ++---
 test/test_cost.py | 12 ++---
 test/test_env.py | 8 +--
 test/test_helpers.py | 34 ++++++------
 test/test_libs.py | 12 ++---
 test/test_modules.py | 6 +--
 test/test_tensor_spec.py | 10 ++--
 test/test_tensordictmodules.py | 2 +-
 test/test_trainer.py | 10 ++--
 test/test_transforms.py | 19 +++---
 torchrl/_utils.py | 2 +-
 torchrl/collectors/__init__.py | 2 +-
 torchrl/collectors/collectors.py | 7 +--
 torchrl/data/__init__.py | 30 +++++------
 torchrl/data/postprocs/postprocs.py | 2 +-
 torchrl/data/replay_buffers/__init__.py | 6 +--
 torchrl/data/replay_buffers/rb_prototype.py | 13 ++---
 torchrl/data/replay_buffers/replay_buffers.py | 17 ++----
 torchrl/data/replay_buffers/samplers.py | 13 ++---
 torchrl/data/tensor_specs.py | 13 +++--
 torchrl/envs/__init__.py | 42 +++++++-------
 torchrl/envs/common.py | 7 +--
 torchrl/envs/env_creator.py | 4 +-
 torchrl/envs/gym_like.py | 5 +-
 torchrl/envs/libs/dm_control.py | 11 ++--
 torchrl/envs/libs/gym.py | 5 +-
 torchrl/envs/libs/jumanji.py | 10 ++--
 torchrl/envs/model_based/common.py | 2 +-
 torchrl/envs/model_based/dreamer.py | 3 +-
 torchrl/envs/transforms/__init__.py | 38 +++++------
 torchrl/envs/transforms/r3m.py | 8 +--
 torchrl/envs/transforms/transforms.py | 16 +++---
 torchrl/envs/transforms/vip.py | 8 +--
 torchrl/envs/vec_env.py | 14 ++---
 torchrl/modules/__init__.py | 52 +++++++++----------
 torchrl/modules/distributions/__init__.py | 11 ++--
 torchrl/modules/distributions/continuous.py | 4 +-
 .../modules/distributions/truncated_normal.py | 6 +--
 torchrl/modules/functional_modules.py | 16 +++---
 torchrl/modules/models/__init__.py | 12 ++---
 torchrl/modules/models/exploration.py | 2 +-
 torchrl/modules/models/models.py | 34 ++++++------
 torchrl/modules/models/utils.py | 1 +
 torchrl/modules/tensordict_module/__init__.py | 10 ++--
 torchrl/modules/tensordict_module/actors.py | 2 +-
 torchrl/modules/tensordict_module/common.py | 15 +----
 .../tensordict_module/probabilistic.py | 11 ++--
 torchrl/modules/tensordict_module/sequence.py | 24 ++++-----
 torchrl/modules/utils/__init__.py | 2 +-
 torchrl/objectives/__init__.py | 12 ++---
 torchrl/objectives/common.py | 16 +++---
 torchrl/objectives/ddpg.py | 9 ++--
 torchrl/objectives/deprecated.py | 4 +-
 torchrl/objectives/dqn.py | 8 ++-
 torchrl/objectives/dreamer.py | 9 ++--
 torchrl/objectives/ppo.py | 3 +-
 torchrl/objectives/redq.py | 4 +-
 torchrl/objectives/reinforce.py | 6 +--
 torchrl/objectives/sac.py | 10 ++--
 torchrl/objectives/utils.py | 14 ++--
 torchrl/objectives/value/__init__.py | 2 +-
 torchrl/objectives/value/advantages.py | 7 +--
 torchrl/objectives/value/functional.py | 2 +-
 torchrl/record/__init__.py | 2 +-
 torchrl/trainers/__init__.py | 4 +-
 torchrl/trainers/helpers/__init__.py | 24 ++++-----
 torchrl/trainers/helpers/collectors.py | 10 ++--
 torchrl/trainers/helpers/envs.py | 9 ++--
 torchrl/trainers/helpers/models.py | 7 +--
 torchrl/trainers/helpers/trainers.py | 18 +++----
 torchrl/trainers/trainers.py | 15 +++--
 91 files changed, 460 insertions(+), 513 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5f385f466b2..9d2f0de2a33 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,18 +11,23 @@ repos:
       - id: end-of-file-fixer

   - repo: https://github.com/omnilib/ufmt
-    rev: v1.3.2
+    rev: v2.0.0b2
     hooks:
       - id: ufmt
         additional_dependencies:
-          - black == 21.9b0
-          - usort == 0.6.4
+          - black == 22.3.0
+          - usort == 1.0.3
+          - libcst == 0.4.7

   - repo: https://github.com/pycqa/flake8
-    rev: 3.9.2
+    rev: 4.0.1
     hooks:
       - id: flake8
         args: [--config=setup.cfg]
+        additional_dependencies:
+          - flake8-bugbear==22.10.27
+          - flake8-comprehensions==3.10.1
+
   - repo: https://github.com/PyCQA/pydocstyle
     rev: 6.1.1

diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py
index ac88a557091..d922095de5f 100644
--- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py
+++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py
@@ -44,9 +44,9 @@
 }

 storage_arg_options = {
-    "LazyMemmapStorage": dict(scratch_dir="/tmp/", device=torch.device("cpu")),
-    "LazyTensorStorage": dict(),
-    "ListStorage": dict(),
+    "LazyMemmapStorage": {"scratch_dir": "/tmp/", "device": torch.device("cpu")},
+    "LazyTensorStorage": {},
+    "ListStorage": {},
 }
 parser = argparse.ArgumentParser(
     description="RPC Replay Buffer Example",
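A note on the dict() rewrites above: flake8-comprehensions (C408)
flags calls such as dict() and dict(a=1) in favour of literals, which
also skip a global name lookup. A minimal sketch of the pattern
applied throughout this patch (toy values):

    opts = dict(scratch_dir="/tmp/", device="cpu")    # flagged by C408
    opts = {"scratch_dir": "/tmp/", "device": "cpu"}  # literal form
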
diff --git a/build_tools/setup_helpers/__init__.py b/build_tools/setup_helpers/__init__.py
index 2a4da8d7b96..6c424ebba14 100644
--- a/build_tools/setup_helpers/__init__.py
+++ b/build_tools/setup_helpers/__init__.py
@@ -3,4 +3,4 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .extension import get_ext_modules, CMakeBuild  # noqa
+from .extension import CMakeBuild, get_ext_modules  # noqa

diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py
index 314d61128ad..6e950caa237 100644
--- a/build_tools/setup_helpers/extension.py
+++ b/build_tools/setup_helpers/extension.py
@@ -8,7 +8,7 @@
 import platform
 import subprocess
 from pathlib import Path
-from subprocess import check_output, STDOUT, CalledProcessError
+from subprocess import CalledProcessError, check_output, STDOUT

 import torch
 from setuptools import Extension

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 63e0146f452..05c4c0b9b91 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -150,7 +150,7 @@
 }


-aafig_default_options = dict(scale=1.5, aspect=1.0, proportional=True)
+aafig_default_options = {"scale": 1.5, "aspect": 1.0, "proportional": True}

 # -- Generate knowledge base references -----------------------------------
 current_path = os.path.dirname(os.path.realpath(__file__))

diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py
index 5fb84a2b25f..69bb54fbbd6 100644
--- a/examples/ddpg/ddpg.py
+++ b/examples/ddpg/ddpg.py
@@ -12,7 +12,7 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from torchrl.envs import ParallelEnv, EnvCreator
+from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
@@ -23,21 +23,15 @@
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
+    EnvConfig,
     get_stats_random_rollout,
     parallel_env_constructor,
     transformed_env_constructor,
-    EnvConfig,
 )
 from torchrl.trainers.helpers.logger import LoggerConfig
-from torchrl.trainers.helpers.losses import make_ddpg_loss, LossConfig
-from torchrl.trainers.helpers.models import (
-    make_ddpg_actor,
-    DDPGModelConfig,
-)
-from torchrl.trainers.helpers.replay_buffer import (
-    make_replay_buffer,
-    ReplayArgsConfig,
-)
+from torchrl.trainers.helpers.losses import LossConfig, make_ddpg_loss
+from torchrl.trainers.helpers.models import DDPGModelConfig, make_ddpg_actor
+from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

 config_fields = [

diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py
index fea51d0667c..96a81e533a4 100644
--- a/examples/dqn/dqn.py
+++ b/examples/dqn/dqn.py
@@ -12,7 +12,7 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from torchrl.envs import ParallelEnv, EnvCreator
+from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.modules import EGreedyWrapper
 from torchrl.record import VideoRecorder
@@ -22,21 +22,15 @@
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
+    EnvConfig,
     get_stats_random_rollout,
     parallel_env_constructor,
     transformed_env_constructor,
-    EnvConfig,
 )
 from torchrl.trainers.helpers.logger import LoggerConfig
-from torchrl.trainers.helpers.losses import make_dqn_loss, LossConfig
-from torchrl.trainers.helpers.models import (
-    make_dqn_actor,
-    DiscreteModelConfig,
-)
-from torchrl.trainers.helpers.replay_buffer import (
-    make_replay_buffer,
-    ReplayArgsConfig,
-)
+from torchrl.trainers.helpers.losses import LossConfig, make_dqn_loss
+from torchrl.trainers.helpers.models import DiscreteModelConfig, make_dqn_actor
+from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py
index f6926638182..140d7c83287 100644
--- a/examples/dreamer/dreamer.py
+++ b/examples/dreamer/dreamer.py
@@ -8,12 +8,12 @@
 import torch.cuda
 import tqdm
 from dreamer_utils import (
-    parallel_env_constructor,
-    transformed_env_constructor,
     call_record,
+    EnvConfig,
     grad_norm,
     make_recorder_env,
-    EnvConfig,
+    parallel_env_constructor,
+    transformed_env_constructor,
 )
 from hydra.core.config_store import ConfigStore
@@ -38,14 +38,8 @@
     get_stats_random_rollout,
 )
 from torchrl.trainers.helpers.logger import LoggerConfig
-from torchrl.trainers.helpers.models import (
-    make_dreamer,
-    DreamerConfig,
-)
-from torchrl.trainers.helpers.replay_buffer import (
-    make_replay_buffer,
-    ReplayArgsConfig,
-)
+from torchrl.trainers.helpers.models import DreamerConfig, make_dreamer
+from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import TrainerConfig
 from torchrl.trainers.trainers import Recorder, RewardNormalizer

diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py
index 341f97739cc..12b2e23842d 100644
--- a/examples/dreamer/dreamer_utils.py
+++ b/examples/dreamer/dreamer_utils.py
@@ -2,9 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from dataclasses import dataclass
-from dataclasses import field as dataclass_field
-from typing import Callable, Optional, Union, Any, Sequence
+from dataclasses import dataclass, field as dataclass_field
+from typing import Any, Callable, Optional, Sequence, Union

 from torchrl.data import NdUnboundedContinuousTensorSpec
 from torchrl.envs import ParallelEnv
@@ -14,6 +13,7 @@
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.envs.transforms import (
     CatFrames,
+    CenterCrop,
     DoubleToFloat,
     GrayScale,
     NoopResetEnv,
@@ -22,12 +22,8 @@
     RewardScaling,
     ToTensorImage,
     TransformedEnv,
-    CenterCrop,
-)
-from torchrl.envs.transforms.transforms import (
-    FlattenObservation,
-    TensorDictPrimer,
 )
+from torchrl.envs.transforms.transforms import FlattenObservation, TensorDictPrimer
 from torchrl.record.recorder import VideoRecorder
 from torchrl.trainers.loggers import Logger

diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py
index a5b7ea1da96..c2e63abafe4 100644
--- a/examples/ppo/ppo.py
+++ b/examples/ppo/ppo.py
@@ -12,7 +12,7 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from torchrl.envs import ParallelEnv, EnvCreator
+from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.objectives.value import GAE
@@ -23,17 +23,14 @@
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
+    EnvConfig,
     get_stats_random_rollout,
     parallel_env_constructor,
     transformed_env_constructor,
-    EnvConfig,
 )
 from torchrl.trainers.helpers.logger import LoggerConfig
 from torchrl.trainers.helpers.losses import make_ppo_loss, PPOLossConfig
-from torchrl.trainers.helpers.models import (
-    make_ppo_model,
-    PPOModelConfig,
-)
+from torchrl.trainers.helpers.models import make_ppo_model, PPOModelConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

 config_fields = [
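For readers puzzled by the import churn in the example scripts above:
usort 1.0.3 sorts names within an import statement case-insensitively,
so CamelCase classes interleave with snake_case functions (EnvConfig
lands between correct_for_frame_skip and get_stats_random_rollout),
and short multi-name imports collapse back onto one line. A sketch of
the resulting style, using made-up module and symbol names:

    from pkg.helpers import AConfig, make_a
    from pkg.helpers.sub import make_b, ZConfig
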
diff --git a/examples/redq/redq.py b/examples/redq/redq.py
index 398d8610368..a470d41e8fe 100644
--- a/examples/redq/redq.py
+++ b/examples/redq/redq.py
@@ -12,7 +12,7 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from torchrl.envs import ParallelEnv, EnvCreator
+from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
@@ -23,21 +23,15 @@
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
+    EnvConfig,
     get_stats_random_rollout,
     parallel_env_constructor,
     transformed_env_constructor,
-    EnvConfig,
 )
 from torchrl.trainers.helpers.logger import LoggerConfig
-from torchrl.trainers.helpers.losses import make_redq_loss, LossConfig
-from torchrl.trainers.helpers.models import (
-    make_redq_model,
-    REDQModelConfig,
-)
-from torchrl.trainers.helpers.replay_buffer import (
-    make_replay_buffer,
-    ReplayArgsConfig,
-)
+from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss
+from torchrl.trainers.helpers.models import make_redq_model, REDQModelConfig
+from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

 config_fields = [

diff --git a/examples/sac/sac.py b/examples/sac/sac.py
index e23f567dc36..9f477748293 100644
--- a/examples/sac/sac.py
+++ b/examples/sac/sac.py
@@ -12,7 +12,7 @@
 import hydra
 import torch.cuda
 from hydra.core.config_store import ConfigStore
-from torchrl.envs import ParallelEnv, EnvCreator
+from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.transforms import RewardScaling, TransformedEnv
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
@@ -23,21 +23,15 @@
 )
 from torchrl.trainers.helpers.envs import (
     correct_for_frame_skip,
+    EnvConfig,
     get_stats_random_rollout,
     parallel_env_constructor,
     transformed_env_constructor,
-    EnvConfig,
 )
 from torchrl.trainers.helpers.logger import LoggerConfig
-from torchrl.trainers.helpers.losses import make_sac_loss, LossConfig
-from torchrl.trainers.helpers.models import (
-    make_sac_model,
-    SACModelConfig,
-)
-from torchrl.trainers.helpers.replay_buffer import (
-    make_replay_buffer,
-    ReplayArgsConfig,
-)
+from torchrl.trainers.helpers.losses import LossConfig, make_sac_loss
+from torchrl.trainers.helpers.models import make_sac_model, SACModelConfig
+from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
 from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

 config_fields = [

diff --git a/examples/torchrl_features/memmap_td_distributed.py b/examples/torchrl_features/memmap_td_distributed.py
index 54dcac5cd12..98ca235f16f 100644
--- a/examples/torchrl_features/memmap_td_distributed.py
+++ b/examples/torchrl_features/memmap_td_distributed.py
@@ -89,9 +89,9 @@ def tensordict_add_noreturn():
         time.sleep(1)
         t0 = time.time()
         for w in range(1, args.world_size):
-            fut0 = rpc.rpc_async(f"worker{w}", get_tensordict, args=tuple())
+            fut0 = rpc.rpc_async(f"worker{w}", get_tensordict, args=())
             fut0.wait()
-            fut1 = rpc.rpc_async(f"worker{w}", tensordict_add, args=tuple())
+            fut1 = rpc.rpc_async(f"worker{w}", tensordict_add, args=())
             tensordict2 = fut1.wait()
             tensordict2.clone()
         print("time: ", time.time() - t0)
@@ -99,7 +99,7 @@ def tensordict_add_noreturn():
         time.sleep(1)
         t0 = time.time()
         waiters = [
-            rpc.remote(f"worker{w}", get_tensordict, args=tuple())
+            rpc.remote(f"worker{w}", get_tensordict, args=())
             for w in range(1, args.world_size)
         ]
         td = torch.stack([waiter.to_here() for waiter in waiters], 0).contiguous()
@@ -107,7 +107,7 @@ def tensordict_add_noreturn():

         t0 = time.time()
         waiters = [
-            rpc.remote(f"worker{w}", tensordict_add, args=tuple())
+            rpc.remote(f"worker{w}", tensordict_add, args=())
             for w in range(1, args.world_size)
         ]
         td = torch.stack([waiter.to_here() for waiter in waiters], 0).contiguous()
@@ -118,9 +118,9 @@ def tensordict_add_noreturn():
     elif args.task == 2:
         time.sleep(1)
         t0 = time.time()
-        # waiters = [rpc.rpc_async(f"worker{w}", get_tensordict, args=tuple()) for w in range(1, args.world_size)]
+        # waiters = [rpc.rpc_async(f"worker{w}", get_tensordict, args=()) for w in range(1, args.world_size)]
         waiters = [
-            rpc.remote(f"worker{w}", get_tensordict, args=tuple())
+            rpc.remote(f"worker{w}", get_tensordict, args=())
             for w in range(1, args.world_size)
         ]
         # td = torch.stack([waiter.wait() for waiter in waiters], 0).clone()
@@ -129,7 +129,7 @@ def tensordict_add_noreturn():
         t0 = time.time()
         if args.memmap:
             waiters = [
-                rpc.remote(f"worker{w}", tensordict_add_noreturn, args=tuple())
+                rpc.remote(f"worker{w}", tensordict_add_noreturn, args=())
                 for w in range(1, args.world_size)
             ]
             print("temp t: ", time.time() - t0)
@@ -139,7 +139,7 @@ def tensordict_add_noreturn():
             print("temp t: ", time.time() - t0)
         else:
             waiters = [
-                rpc.remote(f"worker{w}", tensordict_add, args=tuple())
+                rpc.remote(f"worker{w}", tensordict_add, args=())
                 for w in range(1, args.world_size)
             ]
             print("temp t: ", time.time() - t0)
@@ -153,14 +153,14 @@ def tensordict_add_noreturn():
         time.sleep(1)
         t0 = time.time()
         waiters = [
-            rpc.remote(f"worker{w}", get_tensordict, args=tuple())
+            rpc.remote(f"worker{w}", get_tensordict, args=())
             for w in range(1, args.world_size)
         ]
         td = torch.stack([waiter.to_here() for waiter in waiters], 0)
         print("time to receive objs: ", time.time() - t0)
         t0 = time.time()
         waiters = [
-            rpc.remote(f"worker{w}", tensordict_add, args=tuple())
+            rpc.remote(f"worker{w}", tensordict_add, args=())
             for w in range(1, args.world_size)
         ]
         print("temp t: ", time.time() - t0)

diff --git a/pyproject.toml b/pyproject.toml
index 08fbca85d57..714cc1ca64a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,4 +1,5 @@
-#[tool.usort]
+[tool.usort]
+first_party_detection = false

 [build-system]
 requires = ["setuptools", "wheel", "torch"]

diff --git a/setup.cfg b/setup.cfg
index adc31b72ec2..3f9ce9e3e4b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,6 +17,7 @@ per-file-ignores =
     test/opengl_rendering.py: F401
 exclude = venv
+extend-select = B901, C401, C408, C409

 [pydocstyle]
 ;select = D417  # Missing argument descriptions in the docstring
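The extend-select line opts into a handful of checks from the new
plugins: B901 (return with a value inside a generator), C401 (set()
wrapped around a generator expression), C408 (dict()/list()/tuple()
calls that could be literals) and C409 (a literal passed to tuple()).
Hypothetical one-liners showing what C401 and C409 flag, and the
fixed forms the patch below adopts:

    names = set(p for _, p in pairs)   # C401
    names = {p for _, p in pairs}
    point = tuple([1, 2])              # C409
    point = (1, 2)
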
"*.cpp")) - ) + } sources = list(extension_sources) ext_modules = [ diff --git a/test/_utils_internal.py b/test/_utils_internal.py index c5e7bb6ea45..a11bbf4ab1d 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -11,7 +11,7 @@ # this returns relative path from current file. import pytest import torch.cuda -from torchrl._utils import seed_generator, implement_for +from torchrl._utils import implement_for, seed_generator from torchrl.envs import EnvBase from torchrl.envs.libs.gym import _has_gym diff --git a/test/mocking_classes.py b/test/mocking_classes.py index e03b857c9de..db26b1c687b 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -581,9 +581,11 @@ def __new__( input_spec=None, reward_spec=None, from_pixels=True, - pixel_shape=[1, 7, 7], + pixel_shape=None, **kwargs, ): + if pixel_shape is None: + pixel_shape = [1, 7, 7] if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( diff --git a/test/smoke_test.py b/test/smoke_test.py index e091b11a9eb..630171d4082 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -1,9 +1,9 @@ def test_imports(): - from torchrl.data import ( # noqa: F401 + from torchrl.data import ( PrioritizedReplayBuffer, ReplayBuffer, TensorSpec, - ) + ) # noqa: F401 from torchrl.envs import Transform, TransformedEnv # noqa: F401 from torchrl.envs.gym_like import GymLikeEnv # noqa: F401 from torchrl.modules import TensorDictModule # noqa: F401 diff --git a/test/test_actors.py b/test/test_actors.py index fef0d3d4953..9fdf8bc3882 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -1,8 +1,8 @@ import pytest import torch from torchrl.modules.tensordict_module.actors import ( - QValueHook, DistributionalQValueHook, + QValueHook, ) diff --git a/test/test_collector.py b/test/test_collector.py index f7f94035e0e..9d384665106 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -17,7 +17,7 @@ DiscreteActionVecPolicy, MockSerialEnv, ) -from tensordict.tensordict import TensorDict, assert_allclose_td +from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn from torchrl._utils import seed_generator from torchrl.collectors import aSyncDataCollector, SyncDataCollector @@ -308,10 +308,10 @@ def make_env(): collector = SyncDataCollector( env, total_frames=10000, frames_per_batch=10000, split_trajs=False ) - for data in collector: + for _data in collector: continue - steps = data["step_count"][..., 1:, :] - done = data["done"][..., :-1, :] + steps = _data["step_count"][..., 1:, :] + done = _data["done"][..., :-1, :] # we don't want just one done assert done.sum() > 3 # check that after a done, the next step count is always 1 @@ -321,8 +321,8 @@ def make_env(): # check that if step is 1, then the env was done before assert (steps == 1)[done].all() # check that split traj has a minimum total reward of -21 (for pong only) - data = split_trajectories(data) - assert data["reward"].sum(-2).min() == -21 + _data = split_trajectories(_data) + assert _data["reward"].sum(-2).min() == -21 @pytest.mark.parametrize("num_env", [1, 3]) diff --git a/test/test_cost.py b/test/test_cost.py index 0a5b1785128..233bcce5788 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -976,8 +976,8 @@ def test_sac(self, delay_value, delay_actor, delay_qvalue, num_qvalue, device): named_parameters = list(loss_fn.named_parameters()) named_buffers = list(loss_fn.named_buffers()) - assert len(set(p for n, p in named_parameters)) == len(list(named_parameters)) - assert len(set(p for n, p 
in named_buffers)) == len(list(named_buffers)) + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" @@ -1287,8 +1287,8 @@ def test_redq(self, delay_qvalue, num_qvalue, device): named_parameters = list(loss_fn.named_parameters()) named_buffers = list(loss_fn.named_buffers()) - assert len(set(p for n, p in named_parameters)) == len(list(named_parameters)) - assert len(set(p for n, p in named_buffers)) == len(list(named_buffers)) + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" @@ -1365,8 +1365,8 @@ def test_redq_shared(self, delay_qvalue, num_qvalue, device): named_parameters = list(loss_fn.named_parameters()) named_buffers = list(loss_fn.named_buffers()) - assert len(set(p for n, p in named_parameters)) == len(list(named_parameters)) - assert len(set(p for n, p in named_buffers)) == len(list(named_buffers)) + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) for name, p in named_parameters: assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" diff --git a/test/test_env.py b/test/test_env.py index fa1607041ae..97b1cd5f8e8 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -12,11 +12,11 @@ import torch import yaml from _utils_internal import ( - get_available_devices, CARTPOLE_VERSIONED, + get_available_devices, + HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED, - HALFCHEETAH_VERSIONED, ) from mocking_classes import ( ActionObsMergeLinear, @@ -389,7 +389,9 @@ def env_make(): env_make = [lambda: DMControlEnv("humanoid", tasks[0])] * 3 else: single_task = False - env_make = [lambda: DMControlEnv("humanoid", task) for task in tasks] + env_make = [ + lambda task=task: DMControlEnv("humanoid", task) for task in tasks + ] if not share_individual_td and not single_task: with pytest.raises( diff --git a/test/test_helpers.py b/test/test_helpers.py index fd8b05982fd..77807effac6 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -88,9 +88,9 @@ def _assert_keys_match(td, expeceted_keys): @pytest.mark.skipif(not _has_tv, reason="No torchvision library found") @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") @pytest.mark.parametrize("device", get_available_devices()) -@pytest.mark.parametrize("noisy", [tuple(), ("noisy=True",)]) -@pytest.mark.parametrize("distributional", [tuple(), ("distributional=True",)]) -@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")]) +@pytest.mark.parametrize("noisy", [(), ("noisy=True",)]) +@pytest.mark.parametrize("distributional", [(), ("distributional=True",)]) +@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) @pytest.mark.parametrize( "categorical_action_encoding", [("categorical_action_encoding=True",), ("categorical_action_encoding=False",)], @@ -154,8 +154,8 @@ def test_dqn_maker( @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") @pytest.mark.skipif(not _has_gym, reason="No gym library found") @pytest.mark.parametrize("device", get_available_devices()) -@pytest.mark.parametrize("from_pixels", [("from_pixels=True", "catframes=4"), 
diff --git a/test/test_helpers.py b/test/test_helpers.py
index fd8b05982fd..77807effac6 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -88,9 +88,9 @@ def _assert_keys_match(td, expeceted_keys):
 @pytest.mark.skipif(not _has_tv, reason="No torchvision library found")
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize("noisy", [tuple(), ("noisy=True",)])
-@pytest.mark.parametrize("distributional", [tuple(), ("distributional=True",)])
-@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
+@pytest.mark.parametrize("noisy", [(), ("noisy=True",)])
+@pytest.mark.parametrize("distributional", [(), ("distributional=True",)])
+@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")])
 @pytest.mark.parametrize(
     "categorical_action_encoding",
     [("categorical_action_encoding=True",), ("categorical_action_encoding=False",)],
@@ -154,8 +154,8 @@ def test_dqn_maker(
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.skipif(not _has_gym, reason="No gym library found")
 @pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize("from_pixels", [("from_pixels=True", "catframes=4"), tuple()])
-@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
+@pytest.mark.parametrize("from_pixels", [("from_pixels=True", "catframes=4"), ()])
+@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
 def test_ddpg_maker(device, from_pixels, gsde, exploration):
     if not gsde and exploration != "random":
@@ -237,9 +237,9 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.skipif(not _has_gym, reason="No gym library found")
 @pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
-@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
-@pytest.mark.parametrize("shared_mapping", [tuple(), ("shared_mapping=True",)])
+@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")])
+@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
+@pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
 def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
     if not gsde and exploration != "random":
@@ -361,9 +361,9 @@ def test_ppo_maker(device, from_pixels, shared_mapping, gsde, exploration):
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.skipif(not _has_gym, reason="No gym library found")
 @pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
-@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
-@pytest.mark.parametrize("shared_mapping", [tuple(), ("shared_mapping=True",)])
+@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")])
+@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
+@pytest.mark.parametrize("shared_mapping", [(), ("shared_mapping=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
 def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
     A2CModelConfig.advantage_in_loss = False
@@ -493,9 +493,9 @@ def test_a2c_maker(device, from_pixels, shared_mapping, gsde, exploration):
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.skipif(not _has_gym, reason="No gym library found")
 @pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
-@pytest.mark.parametrize("from_pixels", [tuple()])
-@pytest.mark.parametrize("tanh_loc", [tuple(), ("tanh_loc=True",)])
+@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
+@pytest.mark.parametrize("from_pixels", [()])
+@pytest.mark.parametrize("tanh_loc", [(), ("tanh_loc=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
 def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
     if not gsde and exploration != "random":
@@ -624,8 +624,8 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.skipif(not _has_gym, reason="No gym library found")
 @pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize("from_pixels", [tuple(), ("from_pixels=True", "catframes=4")])
-@pytest.mark.parametrize("gsde", [tuple(), ("gSDE=True",)])
+@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")])
+@pytest.mark.parametrize("gsde", [(), ("gSDE=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
 def test_redq_make(device, from_pixels, gsde, exploration):
     if not gsde and exploration != "random":
@@ -729,7 +729,7 @@ def test_redq_make(device, from_pixels, gsde, exploration):
     to see torch < 1.11 supported for dreamer, please submit an issue.""",
 )
 @pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize("tanh_loc", [tuple(), ("tanh_loc=True",)])
+@pytest.mark.parametrize("tanh_loc", [(), ("tanh_loc=True",)])
 @pytest.mark.parametrize("exploration", ["random", "mode"])
 def test_dreamer_make(device, tanh_loc, exploration, dreamer_constructor_fixture):

diff --git a/test/test_libs.py b/test/test_libs.py
index bb853f9642d..06e09a0521b 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -12,8 +12,8 @@
     _test_fake_tensordict,
     get_available_devices,
     HALFCHEETAH_VERSIONED,
-    PONG_VERSIONED,
     PENDULUM_VERSIONED,
+    PONG_VERSIONED,
 )
 from packaging import version
 from tensordict.tensordict import assert_allclose_td
@@ -21,12 +21,10 @@
 from torchrl.collectors import MultiaSyncDataCollector
 from torchrl.collectors.collectors import RandomPolicy
 from torchrl.envs import EnvCreator, ParallelEnv
-from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper
-from torchrl.envs.libs.dm_control import _has_dmc
-from torchrl.envs.libs.gym import GymEnv, GymWrapper
-from torchrl.envs.libs.gym import _has_gym, _is_from_pixels
-from torchrl.envs.libs.habitat import HabitatEnv, _has_habitat
-from torchrl.envs.libs.jumanji import JumanjiEnv, _has_jumanji
+from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
+from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper
+from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
+from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv

 if _has_gym:
     import gym

diff --git a/test/test_modules.py b/test/test_modules.py
index fa5305ac75b..59822843835 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -14,8 +14,8 @@
 from torch import nn
 from torchrl.data.tensor_specs import (
     DiscreteTensorSpec,
-    OneHotDiscreteTensorSpec,
     NdBoundedTensorSpec,
+    OneHotDiscreteTensorSpec,
 )
 from torchrl.modules import (
     ActorValueOperator,
@@ -33,10 +33,10 @@
 from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear
 from torchrl.modules.models.model_based import (
     DreamerActor,
-    ObsEncoder,
     ObsDecoder,
-    RSSMPrior,
+    ObsEncoder,
     RSSMPosterior,
+    RSSMPrior,
     RSSMRollout,
 )
 from torchrl.modules.models.utils import SquashDims

diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py
index 3ca99d8ac73..0b64c6df52d 100644
--- a/test/test_tensor_spec.py
+++ b/test/test_tensor_spec.py
@@ -11,6 +11,7 @@
 from scipy.stats import chisquare
 from tensordict.tensordict import TensorDict, TensorDictBase
 from torchrl.data.tensor_specs import (
+    _keys_to_empty_composite_spec,
     BinaryDiscreteTensorSpec,
     BoundedTensorSpec,
     CompositeSpec,
@@ -20,7 +21,6 @@
     NdUnboundedContinuousTensorSpec,
     OneHotDiscreteTensorSpec,
     UnboundedContinuousTensorSpec,
-    _keys_to_empty_composite_spec,
 )


@@ -874,7 +874,7 @@ def test_one_hot_discrete_action_spec_rand(self):
         sample = torch.stack([action_spec.rand() for _ in range(10000)], 0)

         sample_list = sample.argmax(-1)
-        sample_list = list([sum(sample_list == i).item() for i in range(10)])
+        sample_list = [sum(sample_list == i).item() for i in range(10)]
         assert chisquare(sample_list).pvalue > 0.1

         sample = action_spec.to_numpy(sample)
@@ -888,7 +888,7 @@ def test_categorical_action_spec_rand(self):
         sample = action_spec.rand((10000,))

         sample_list = sample[:, 0]
-        sample_list = list([sum(sample_list == i).item() for i in range(10)])
+        sample_list = [sum(sample_list == i).item() for i in range(10)]
         print(sample_list)
         assert chisquare(sample_list).pvalue > 0.1

@@ -917,11 +917,11 @@ def test_mult_discrete_action_spec_rand(self):
         assert sample.ndim == 2, f"found shape: {sample.shape}"

         sample0 = sample[:, 0]
-        sample_list = list([sum(sample0 == i) for i in range(ns[0])])
+        sample_list = [sum(sample0 == i) for i in range(ns[0])]
         assert chisquare(sample_list).pvalue > 0.1

         sample1 = sample[:, 1]
-        sample_list = list([sum(sample1 == i) for i in range(ns[1])])
+        sample_list = [sum(sample1 == i) for i in range(ns[1])]
         assert chisquare(sample_list).pvalue > 0.1

     def test_categorical_action_spec_encode(self):

diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 6910ec2b497..5b80a7872cf 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -32,8 +32,8 @@
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.modules import NormalParamWrapper, TanhNormal, TensorDictModule
 from torchrl.modules.tensordict_module.common import (
-    is_tensordict_compatible,
     ensure_tensordict_compatible,
+    is_tensordict_compatible,
 )
 from torchrl.modules.tensordict_module.probabilistic import (
     ProbabilisticTensorDictModule,

diff --git a/test/test_trainer.py b/test/test_trainer.py
index 1fd576dd58d..5fbedaa8137 100644
--- a/test/test_trainer.py
+++ b/test/test_trainer.py
@@ -24,17 +24,18 @@
 from tensordict import TensorDict

 from torchrl.data import (
-    TensorDictPrioritizedReplayBuffer,
-    TensorDictReplayBuffer,
-    ListStorage,
     LazyMemmapStorage,
     LazyTensorStorage,
+    ListStorage,
+    TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
 )
 from torchrl.envs.libs.gym import _has_gym
 from torchrl.trainers import Recorder, Trainer
 from torchrl.trainers.helpers import transformed_env_constructor
 from torchrl.trainers.trainers import (
     _has_tqdm,
+    _has_ts,
     BatchSubSampler,
     CountFramesLog,
     LogReward,
@@ -43,7 +44,6 @@
     RewardNormalizer,
     SelectKeys,
     UpdateWeights,
-    _has_ts,
 )


@@ -72,7 +72,7 @@ def shutdown(self):
         pass

     def state_dict(self):
-        return dict()
+        return {}

     def load_state_dict(self, state_dict):
         pass

diff --git a/test/test_transforms.py b/test/test_transforms.py
index d5ec4da981c..9442c09bfc8 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import itertools
 from copy import copy, deepcopy
 from functools import partial

@@ -10,10 +11,10 @@
 import pytest
 import torch
 from _utils_internal import (  # noqa
-    get_available_devices,
-    retry,
     dtype_fixture,
+    get_available_devices,
     PENDULUM_VERSIONED,
+    retry,
 )
 from mocking_classes import (
     ContinuousActionVecMockEnv,
@@ -25,11 +26,11 @@
 from torch import multiprocessing as mp, Tensor
 from torchrl._utils import prod
 from torchrl.data import (
+    BoundedTensorSpec,
     CompositeSpec,
     NdBoundedTensorSpec,
     NdUnboundedContinuousTensorSpec,
     UnboundedContinuousTensorSpec,
-    BoundedTensorSpec,
 )
 from torchrl.envs import (
     BinarizeReward,
@@ -55,15 +56,15 @@
 from torchrl.envs.transforms import TransformedEnv, VecNorm
 from torchrl.envs.transforms.r3m import _R3MNet
 from torchrl.envs.transforms.transforms import (
-    DiscreteActionProjection,
     _has_tv,
     CenterCrop,
+    DiscreteActionProjection,
+    gSDENoise,
     NoopResetEnv,
     PinMemoryTransform,
     SqueezeTransform,
     TensorDictPrimer,
     UnsqueezeTransform,
-    gSDENoise,
 )
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform

@@ -811,7 +812,7 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4):
     def test_compose_inv(self, keys_inv_1, keys_inv_2, device):
         torch.manual_seed(0)
         keys_to_transform = set(keys_inv_1 + keys_inv_2)
-        keys_total = set(["action_1", "action_2", "dont_touch"])
+        keys_total = {"action_1", "action_2", "dont_touch"}
         double2float_1 = DoubleToFloat(in_keys_inv=keys_inv_1)
         double2float_2 = DoubleToFloat(in_keys_inv=keys_inv_2)
         compose = Compose(double2float_1, double2float_2)
@@ -1289,7 +1290,7 @@ def test_binarized_reward(self, device, batch):
     def test_reward_scaling(self, batch, scale, loc, keys, device, standard_normal):
         torch.manual_seed(0)
         if keys is None:
-            keys_total = set([])
+            keys_total = set()
         else:
             keys_total = set(keys)
         reward_scaling = RewardScaling(
@@ -1336,7 +1337,7 @@ def test_pin_mem(self, device):
     def test_append(self):
         env = ContinuousActionVecMockEnv()
         obs_spec = env.observation_spec
-        key = list(obs_spec.keys())[0]
+        (key,) = itertools.islice(obs_spec.keys(), 1)
         env = TransformedEnv(env)
         env.append_transform(CatFrames(N=4, cat_dim=-1, in_keys=[key]))

@@ -1350,7 +1351,7 @@ def test_insert(self):
         env = ContinuousActionVecMockEnv()
         obs_spec = env.observation_spec

-        key = list(obs_spec.keys())[0]
+        (key,) = itertools.islice(obs_spec.keys(), 1)
         env = TransformedEnv(env)

         # we start by asking the spec. That will create the private attributes
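The itertools.islice pattern above (also named in the commit message)
grabs the first key of a spec without materialising the full key list,
and the one-element tuple unpacking raises if the iterator is empty. A
standalone sketch with toy keys:

    import itertools

    keys = iter(["pixels", "reward"])
    (first,) = itertools.islice(keys, 1)  # first == "pixels"
    # vs. first = list(keys)[0], which builds the whole list first
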
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 044632b5c1f..bdaba813870 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -86,7 +86,7 @@ def seed_generator(seed):
     """
     max_seed_val = (
-        2 ** 32 - 1
+        2**32 - 1
     )  # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
     rng = np.random.default_rng(seed)
     seed = int.from_bytes(rng.bytes(8), "big")

diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py
index 7f72b75d990..07fc0542939 100644
--- a/torchrl/collectors/__init__.py
+++ b/torchrl/collectors/__init__.py
@@ -4,8 +4,8 @@
 # LICENSE file in the root directory of this source tree.

 from .collectors import (
-    SyncDataCollector,
     aSyncDataCollector,
     MultiaSyncDataCollector,
     MultiSyncDataCollector,
+    SyncDataCollector,
 )

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 31d6d021554..72d63332ca9 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -12,17 +12,18 @@
 from copy import deepcopy
 from multiprocessing import connection, queues
 from textwrap import indent
-from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, Any, Dict
+from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
-from tensordict.tensordict import TensorDictBase, TensorDict
+from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import multiprocessing as mp
 from torch.utils.data import IterableDataset
 from torchrl.envs.transforms import TransformedEnv
 from torchrl.envs.utils import set_exploration_mode, step_mdp
+
 from .._utils import _check_for_faulty_process, prod
 from ..data import TensorSpec
 from ..data.utils import CloudpickleWrapper, DEVICE_TYPING
@@ -142,7 +143,7 @@ def _get_policy_and_device(
         """
         # if create_env_fn is not None:
         #     if create_env_kwargs is None:
-        #         create_env_kwargs = dict()
+        #         create_env_kwargs = {}
         #     self.create_env_fn = create_env_fn
         #     if isinstance(create_env_fn, EnvBase):
         #         env = create_env_fn

diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index 370df9f01fd..44822cbfa7e 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -5,27 +5,27 @@

 from .postprocs import MultiStep
 from .replay_buffers import (
-    ReplayBuffer,
-    PrioritizedReplayBuffer,
-    TensorDictReplayBuffer,
-    TensorDictPrioritizedReplayBuffer,
-    Storage,
-    ListStorage,
     LazyMemmapStorage,
     LazyTensorStorage,
+    ListStorage,
+    PrioritizedReplayBuffer,
+    ReplayBuffer,
+    Storage,
+    TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
 )
 from .tensor_specs import (
-    TensorSpec,
+    BinaryDiscreteTensorSpec,
     BoundedTensorSpec,
-    OneHotDiscreteTensorSpec,
-    UnboundedContinuousTensorSpec,
-    UnboundedDiscreteTensorSpec,
+    CompositeSpec,
+    DEVICE_TYPING,
+    DiscreteTensorSpec,
+    MultOneHotDiscreteTensorSpec,
     NdBoundedTensorSpec,
     NdUnboundedContinuousTensorSpec,
     NdUnboundedDiscreteTensorSpec,
-    BinaryDiscreteTensorSpec,
-    MultOneHotDiscreteTensorSpec,
-    DiscreteTensorSpec,
-    CompositeSpec,
-    DEVICE_TYPING,
+    OneHotDiscreteTensorSpec,
+    TensorSpec,
+    UnboundedContinuousTensorSpec,
+    UnboundedDiscreteTensorSpec,
 )

diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py
index 98d69dfb07b..6d6ed79d4b0 100644
--- a/torchrl/data/postprocs/postprocs.py
+++ b/torchrl/data/postprocs/postprocs.py
@@ -132,7 +132,7 @@ def __init__(
         self.register_buffer(
             "gammas",
             torch.tensor(
-                [gamma ** i for i in range(n_steps_max + 1)],
+                [gamma**i for i in range(n_steps_max + 1)],
                 dtype=torch.float,
             ).reshape(1, 1, -1),
         )
diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py
index 1822637ec5e..53e363855ef 100644
--- a/torchrl/data/replay_buffers/__init__.py
+++ b/torchrl/data/replay_buffers/__init__.py
@@ -4,9 +4,9 @@
 # LICENSE file in the root directory of this source tree.

 from .replay_buffers import (
-    ReplayBuffer,
     PrioritizedReplayBuffer,
-    TensorDictReplayBuffer,
+    ReplayBuffer,
     TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
 )
-from .storages import Storage, ListStorage, LazyMemmapStorage, LazyTensorStorage
+from .storages import LazyMemmapStorage, LazyTensorStorage, ListStorage, Storage

diff --git a/torchrl/data/replay_buffers/rb_prototype.py b/torchrl/data/replay_buffers/rb_prototype.py
index a80e1abdf28..8534bba46b1 100644
--- a/torchrl/data/replay_buffers/rb_prototype.py
+++ b/torchrl/data/replay_buffers/rb_prototype.py
@@ -1,17 +1,18 @@
 import collections
 import threading
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Optional, Sequence, Union, Tuple, List
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

 import torch
-from tensordict.tensordict import TensorDictBase, LazyStackedTensorDict
+from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
 from torchrl.envs.transforms.transforms import Compose, Transform
+
 from .replay_buffers import pin_memory_output
-from .samplers import Sampler, RandomSampler
-from .storages import Storage, ListStorage, _get_default_collate
-from .utils import INT_CLASSES, _to_numpy, accept_remote_rref_udf_invocation
-from .writers import Writer, RoundRobinWriter
+from .samplers import RandomSampler, Sampler
+from .storages import _get_default_collate, ListStorage, Storage
+from .utils import _to_numpy, accept_remote_rref_udf_invocation, INT_CLASSES
+from .writers import RoundRobinWriter, Writer


 class ReplayBuffer:

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 8abd435738d..9dccba7ad07 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -7,14 +7,11 @@
 import concurrent.futures
 import threading
 from copy import deepcopy
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
-from tensordict.tensordict import (
-    TensorDictBase,
-    LazyStackedTensorDict,
-)
+from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
 from torch import Tensor

 from torchrl._torchrl import (
@@ -24,15 +21,11 @@
     SumSegmentTreeFp64,
 )
 from torchrl.data.replay_buffers.storages import (
-    Storage,
-    ListStorage,
     _get_default_collate,
+    ListStorage,
+    Storage,
 )
-from torchrl.data.replay_buffers.utils import INT_CLASSES
-from torchrl.data.replay_buffers.utils import (
-    _to_numpy,
-    _to_torch,
-)
+from torchrl.data.replay_buffers.utils import _to_numpy, _to_torch, INT_CLASSES
 from torchrl.data.utils import DEVICE_TYPING

diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index b45522b1e0e..2bb159d0b8d 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Union, Tuple
+from typing import Any, Tuple, Union

 import numpy as np
 import torch
@@ -10,8 +10,9 @@
     SumSegmentTreeFp32,
     SumSegmentTreeFp64,
 )
+
 from .storages import Storage
-from .utils import INT_CLASSES, _to_numpy
+from .utils import _to_numpy, INT_CLASSES


 class Sampler(ABC):
@@ -22,18 +23,18 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
         raise NotImplementedError

     def add(self, index: int) -> None:
-        pass
+        return

     def extend(self, index: torch.Tensor) -> None:
-        pass
+        return

     def update_priority(
         self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor]
     ) -> dict:
-        pass
+        return

     def mark_update(self, index: Union[int, torch.Tensor]) -> None:
-        pass
+        return

     @property
     def default_priority(self) -> float:

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index dd8df096854..ec347be62ed 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -12,19 +12,19 @@
 from typing import (
     Any,
     Dict,
+    ItemsView,
+    KeysView,
     List,
     Optional,
     Sequence,
     Tuple,
     Union,
-    ItemsView,
-    KeysView,
     ValuesView,
 )

 import numpy as np
 import torch
-from tensordict.tensordict import TensorDictBase, TensorDict
+from tensordict.tensordict import TensorDict, TensorDictBase

 from torchrl._utils import get_binary_env_var

@@ -60,7 +60,7 @@ class invertible_dict(dict):
     def __init__(self, *args, inv_dict=None, **kwargs):
         if inv_dict is None:
-            inv_dict = dict()
+            inv_dict = {}
         super().__init__(*args, **kwargs)
         self.inv_dict = inv_dict

@@ -1403,4 +1403,7 @@ def __iter__(
             yield key

     def __len__(self):
-        return len([k for k in self])
+        i = 0
+        for _ in self:
+            i += 1
+        return i
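The __len__ rewrite above counts by iterating instead of building a
throwaway list: len([k for k in self]) allocates the entire key list
just to measure it. An equivalent one-liner, should the explicit loop
feel verbose:

    def __len__(self):
        # sum over a generator: no intermediate list is allocated
        return sum(1 for _ in self)
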
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index 85f7ed25035..04be8e7320a 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -3,35 +3,35 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .common import Specs, make_tensordict, EnvBase, EnvMetaData
+from .common import EnvBase, EnvMetaData, make_tensordict, Specs
 from .env_creator import EnvCreator, get_env_metadata
-from .gym_like import GymLikeEnv, default_info_dict_reader
+from .gym_like import default_info_dict_reader, GymLikeEnv
 from .model_based import ModelBasedEnvBase
 from .transforms import (
-    R3MTransform,
-    Transform,
-    TransformedEnv,
-    RewardClipping,
-    Resize,
+    BinarizeReward,
+    CatFrames,
+    CatTensors,
     CenterCrop,
-    GrayScale,
     Compose,
-    ToTensorImage,
-    ObservationNorm,
-    FlattenObservation,
-    UnsqueezeTransform,
-    RewardScaling,
-    ObservationTransform,
-    CatFrames,
-    FiniteTensorDictCheck,
     DoubleToFloat,
-    CatTensors,
+    FiniteTensorDictCheck,
+    FlattenObservation,
+    GrayScale,
+    gSDENoise,
     NoopResetEnv,
-    BinarizeReward,
+    ObservationNorm,
+    ObservationTransform,
     PinMemoryTransform,
-    VecNorm,
-    gSDENoise,
+    R3MTransform,
+    Resize,
+    RewardClipping,
+    RewardScaling,
     TensorDictPrimer,
+    ToTensorImage,
+    Transform,
+    TransformedEnv,
+    UnsqueezeTransform,
+    VecNorm,
     VIPTransform,
 )
-from .vec_env import SerialEnv, ParallelEnv
+from .vec_env import ParallelEnv, SerialEnv

diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index d5b3517dd76..b9bf02fcdc5 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -8,15 +8,16 @@
 import abc
 from copy import deepcopy
 from numbers import Number
-from typing import Any, Callable, Iterator, Optional, Union, Dict, Sequence
+from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union

 import numpy as np
 import torch
 import torch.nn as nn
-from tensordict.tensordict import TensorDictBase, TensorDict
+from tensordict.tensordict import TensorDict, TensorDictBase

 from torchrl.data import CompositeSpec, TensorSpec
-from .._utils import seed_generator, prod
+
+from .._utils import prod, seed_generator
 from ..data.utils import DEVICE_TYPING
 from .utils import get_available_libraries, step_mdp

diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py
index 8d1a1de5120..0fbb2b15943 100644
--- a/torchrl/envs/env_creator.py
+++ b/torchrl/envs/env_creator.py
@@ -84,7 +84,7 @@ def __init__(
         self.create_env_fn = create_env_fn

         self.create_env_kwargs = (
-            create_env_kwargs if isinstance(create_env_kwargs, dict) else dict()
+            create_env_kwargs if isinstance(create_env_kwargs, dict) else {}
         )
         self.initialized = False
         self._meta_data = None
@@ -174,7 +174,7 @@ def get_env_metadata(
     ):
         # then env is a creator
         if kwargs is None:
-            kwargs = dict()
+            kwargs = {}
         env = env_or_creator(**kwargs)
         return EnvMetaData.build_metadata_from_env(env)
     elif isinstance(env_or_creator, EnvCreator):

diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index a0c61b484e1..7f832a10b74 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -6,8 +6,9 @@
 from __future__ import annotations

 import abc
+import itertools
 import warnings
-from typing import List, Optional, Sequence, Union, Tuple, Any, Dict
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -167,7 +168,7 @@ def read_obs(
         if isinstance(observations, dict):
             observations = {key: value for key, value in observations.items()}
         if not isinstance(observations, (TensorDict, dict)):
-            key = list(self.observation_spec.keys())[0]
+            (key,) = itertools.islice(self.observation_spec.keys(), 1)
             observations = {key: observations}
         observations = self.observation_spec.encode(observations)
         return observations

diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py
index 6fab27309ea..5ac931f7ee0 100644
--- a/torchrl/envs/libs/dm_control.py
+++ b/torchrl/envs/libs/dm_control.py
@@ -6,7 +6,7 @@
 import collections
 import os
-from typing import Optional, Tuple, Union, Dict, Any
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -15,10 +15,11 @@
     CompositeSpec,
     NdBoundedTensorSpec,
     NdUnboundedContinuousTensorSpec,
-    TensorSpec,
     NdUnboundedDiscreteTensorSpec,
+    TensorSpec,
 )
-from ...data.utils import numpy_to_torch_dtype_dict, DEVICE_TYPING
+
+from ...data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict
 from ..gym_like import GymLikeEnv

 if torch.has_cuda and torch.cuda.device_count() > 1:
@@ -80,10 +81,10 @@ def _dmcontrol_to_torchrl_spec_transform(

 def _get_envs(to_dict: bool = True) -> Dict[str, Any]:
     if not _has_dmc:
-        return dict()
+        return {}
     if not to_dict:
         return tuple(suite.BENCHMARKING) + tuple(suite.EXTRA)
-    d = dict()
+    d = {}
     for tup in suite.BENCHMARKING:
         env_name = tup[0]
         d.setdefault(env_name, []).append(tup[1])

diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py
index cf284f4e7db..1c4d0a2680c 100644
--- a/torchrl/envs/libs/gym.py
+++ b/torchrl/envs/libs/gym.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import warnings
 from types import ModuleType
-from typing import List, Dict
+from typing import Dict, List
 from warnings import warn

 import torch
@@ -19,9 +19,10 @@
     TensorSpec,
     UnboundedContinuousTensorSpec,
 )
+
 from ..._utils import implement_for
 from ...data.utils import numpy_to_torch_dtype_dict
-from ..gym_like import GymLikeEnv, default_info_dict_reader
+from ..gym_like import default_info_dict_reader, GymLikeEnv
 from ..utils import _classproperty

 try:

diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py
index 1b6f691cef1..ef4bf02b8d4 100644
--- a/torchrl/envs/libs/jumanji.py
+++ b/torchrl/envs/libs/jumanji.py
@@ -1,19 +1,19 @@
 import dataclasses
-from typing import Optional, Dict, Union
+from typing import Dict, Optional, Union

 import numpy as np
 import torch
-from tensordict.tensordict import TensorDict, TensorDictBase, make_tensordict
+from tensordict.tensordict import make_tensordict, TensorDict, TensorDictBase

 from torchrl.data import (
-    DEVICE_TYPING,
-    TensorSpec,
     CompositeSpec,
+    DEVICE_TYPING,
     DiscreteTensorSpec,
-    OneHotDiscreteTensorSpec,
     NdBoundedTensorSpec,
     NdUnboundedContinuousTensorSpec,
     NdUnboundedDiscreteTensorSpec,
+    OneHotDiscreteTensorSpec,
+    TensorSpec,
 )
 from torchrl.data.utils import numpy_to_torch_dtype_dict
 from torchrl.envs import GymLikeEnv

diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py
index 189ecb4bdcf..128408dd6f2 100644
--- a/torchrl/envs/model_based/common.py
+++ b/torchrl/envs/model_based/common.py
@@ -5,7 +5,7 @@
 import abc
 from copy import deepcopy
-from typing import Optional, Union, List
+from typing import List, Optional, Union

 import numpy as np
 import torch

diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py
index 9e968eee6f8..432682812b2 100644
--- a/torchrl/envs/model_based/dreamer.py
+++ b/torchrl/envs/model_based/dreamer.py
@@ -3,8 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Optional, Union
-from typing import Tuple
+from typing import Optional, Tuple, Union

 import numpy as np
 import torch

diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index cdb4d35f4ce..098c7545812 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -5,29 +5,29 @@

 from .r3m import R3MTransform
 from .transforms import (
-    Transform,
-    TransformedEnv,
-    RewardClipping,
-    Resize,
+    BinarizeReward,
+    CatFrames,
+    CatTensors,
     CenterCrop,
-    GrayScale,
     Compose,
-    ToTensorImage,
-    ObservationNorm,
-    FlattenObservation,
-    UnsqueezeTransform,
-    RewardScaling,
-    ObservationTransform,
-    CatFrames,
-    FiniteTensorDictCheck,
     DoubleToFloat,
-    CatTensors,
+    FiniteTensorDictCheck,
+    FlattenObservation,
+    GrayScale,
+    gSDENoise,
     NoopResetEnv,
-    BinarizeReward,
+    ObservationNorm,
+    ObservationTransform,
     PinMemoryTransform,
-    VecNorm,
-    gSDENoise,
-    TensorDictPrimer,
+    Resize,
+    RewardClipping,
+    RewardScaling,
     SqueezeTransform,
+    TensorDictPrimer,
+    ToTensorImage,
+    Transform,
+    TransformedEnv,
+    UnsqueezeTransform,
+    VecNorm,
 )
-from .vip import VIPTransform, VIPRewardTransform
+from .vip import VIPRewardTransform, VIPTransform

diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py
index 2d7a7ef8df8..43b3d61c069 100644
--- a/torchrl/envs/transforms/r3m.py
+++ b/torchrl/envs/transforms/r3m.py
@@ -11,19 +11,19 @@
 from torch.nn import Identity

 from torchrl.data.tensor_specs import (
-    TensorSpec,
     CompositeSpec,
     NdUnboundedContinuousTensorSpec,
+    TensorSpec,
 )
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs.transforms.transforms import (
-    ToTensorImage,
+    CatTensors,
     Compose,
+    FlattenObservation,
     ObservationNorm,
     Resize,
+    ToTensorImage,
     Transform,
-    CatTensors,
-    FlattenObservation,
     UnsqueezeTransform,
 )

diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 410cb639fb3..30fe1148a68 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -7,23 +7,23 @@

 import collections
 import multiprocessing as mp
-from copy import deepcopy, copy
+from copy import copy, deepcopy
 from textwrap import indent
 from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union

 import torch
-from tensordict.tensordict import TensorDictBase, TensorDict
+from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import nn, Tensor

 from torchrl.data.tensor_specs import (
+    BinaryDiscreteTensorSpec,
     BoundedTensorSpec,
     CompositeSpec,
     ContinuousBox,
+    DEVICE_TYPING,
     NdUnboundedContinuousTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
-    BinaryDiscreteTensorSpec,
-    DEVICE_TYPING,
 )
 from torchrl.envs.common import EnvBase, make_tensordict
 from torchrl.envs.transforms import functional as F
@@ -1723,7 +1723,7 @@ def __init__(
                 "Lazy call to CatTensors is only supported when `dim=-1`."
             )
         else:
-            in_keys = sorted(list(in_keys))
+            in_keys = sorted(in_keys)
         if type(out_key) != str:
             raise Exception("CatTensors requires out_key to be of type string")
         # super().__init__(in_keys=in_keys)
@@ -1767,8 +1767,8 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
         else:
             raise Exception(
                 f"CatTensor failed, as it expected input keys ="
-                f" {sorted(list(self.in_keys))} but got a TensorDict with keys"
-                f" {sorted(list(tensordict.keys(include_nested=True)))}"
+                f" {sorted(self.in_keys)} but got a TensorDict with keys"
+                f" {sorted(tensordict.keys(include_nested=True))}"
             )
         return tensordict

@@ -2340,7 +2340,7 @@ def build_td_for_shared_vecnorm(
         if keys is None:
             keys = ["next", "reward"]
         td = make_tensordict(env)
-        keys = set(key for key in td.keys() if key in keys)
+        keys = {key for key in td.keys() if key in keys}
         td_select = td.select(*keys)
         td_select = td_select.flatten_keys(sep)
         if td.batch_dims:

diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py
index 8167593f6ef..ecb571656a3 100644
--- a/torchrl/envs/transforms/vip.py
+++ b/torchrl/envs/transforms/vip.py
@@ -11,19 +11,19 @@
 from torch.hub import load_state_dict_from_url

 from torchrl.data.tensor_specs import (
-    TensorSpec,
     CompositeSpec,
     NdUnboundedContinuousTensorSpec,
+    TensorSpec,
 )
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs.transforms import (
-    ToTensorImage,
+    CatTensors,
     Compose,
+    FlattenObservation,
     ObservationNorm,
     Resize,
+    ToTensorImage,
     Transform,
-    CatTensors,
-    FlattenObservation,
     UnsqueezeTransform,
 )
{sorted(list(self.selected_keys))}" + f"found 0 action keys in {sorted(self.selected_keys)}" ) if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8d1cdd8203e..8460e40160b 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -4,58 +4,58 @@ # LICENSE file in the root directory of this source tree. from .distributions import ( - NormalParamWrapper, - TanhNormal, Delta, - TanhDelta, - TruncatedNormal, + distributions_maps, IndependentNormal, + NormalParamWrapper, OneHotCategorical, - distributions_maps, + TanhDelta, + TanhNormal, + TruncatedNormal, ) from .functional_modules import ( + extract_buffers, + extract_weights, FunctionalModule, FunctionalModuleWithBuffers, - extract_weights, - extract_buffers, ) from .models import ( - NoisyLinear, - NoisyLazyLinear, - reset_noise, - DreamerActor, - ObsEncoder, - ObsDecoder, - RSSMPrior, - RSSMPosterior, - MLP, ConvNet, - DuelingCnnDQNet, - DistributionalDQNnet, DdpgCnnActor, DdpgCnnQNet, DdpgMlpActor, DdpgMlpQNet, + DistributionalDQNnet, + DreamerActor, + DuelingCnnDQNet, LSTMNet, - SqueezeLayer, + MLP, + NoisyLazyLinear, + NoisyLinear, + ObsDecoder, + ObsEncoder, + reset_noise, + RSSMPosterior, + RSSMPrior, Squeeze2dLayer, + SqueezeLayer, ) from .tensordict_module import ( Actor, - ActorValueOperator, - ValueOperator, - ProbabilisticActor, - QValueActor, ActorCriticOperator, ActorCriticWrapper, + ActorValueOperator, + AdditiveGaussianWrapper, DistributionalQValueActor, - TensorDictModule, - TensorDictModuleWrapper, EGreedyWrapper, - AdditiveGaussianWrapper, OrnsteinUhlenbeckProcessWrapper, + ProbabilisticActor, ProbabilisticTensorDictModule, + QValueActor, + TensorDictModule, + TensorDictModuleWrapper, TensorDictSequential, + ValueOperator, WorldModelWrapper, ) from .planners import CEMPlanner, MPCPlannerBase # usort:skip diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 02ee42f0298..05069e19b52 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -4,16 +4,15 @@ # LICENSE file in the root directory of this source tree. from .continuous import ( - NormalParamWrapper, - TanhNormal, + __all__ as _all_continuous, Delta, + IndependentNormal, + NormalParamWrapper, TanhDelta, + TanhNormal, TruncatedNormal, - IndependentNormal, ) -from .continuous import __all__ as _all_continuous -from .discrete import OneHotCategorical -from .discrete import __all__ as _all_discrete +from .discrete import __all__ as _all_discrete, OneHotCategorical distributions_maps = { distribution_class.lower(): eval(distribution_class) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 090fc6af941..da136f67208 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
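Two flake8-comprehensions rules enabled by the new pre-commit config drive many of the hunks above: C414 (`sorted(list(x))` is redundant, because `sorted()` accepts any iterable and always returns a fresh list) and C401 (a generator passed to `set()` should be a set comprehension). A quick sanity check with throwaway values:

    keys = ("obs", "action", "reward")
    assert sorted(list(keys)) == sorted(keys)          # C414: the inner list() adds nothing
    assert set(k for k in keys) == {k for k in keys}   # C401: prefer the comprehension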
from numbers import Number -from typing import Dict, Sequence, Union, Optional, Tuple +from typing import Dict, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -144,7 +144,7 @@ def __init__( def forward(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor]: net_output = self.operator(*tensors) - others = tuple() + others = () if not isinstance(net_output, torch.Tensor): net_output, *others = net_output loc, scale = net_output.chunk(2, -1) diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index 0cccf120bcf..f733dcac5f7 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -98,7 +98,7 @@ def auc(self): @staticmethod def _little_phi(x): - return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI + return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI def _big_phi(self, x): phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) @@ -121,7 +121,7 @@ def icdf(self, value): def log_prob(self, value): if self._validate_args: self._validate_sample(value) - return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value ** 2) * 0.5 + return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 def rsample(self, sample_shape=None): if sample_shape is None: @@ -151,7 +151,7 @@ def __init__(self, loc, scale, a, b, validate_args=None): super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args) self._log_scale = self.scale.log() self._mean = self._mean * self.scale + self.loc - self._variance = self._variance * self.scale ** 2 + self._variance = self._variance * self.scale**2 self._entropy += self._log_scale def _to_std_rv(self, value): diff --git a/torchrl/modules/functional_modules.py b/torchrl/modules/functional_modules.py index 7e7160e0fef..b40a6a22242 100644 --- a/torchrl/modules/functional_modules.py +++ b/torchrl/modules/functional_modules.py @@ -21,14 +21,14 @@ # Monkey-patch functorch, mainly for cases where an "isinstance(obj, Tensor)" check is invoked if _has_functorch: from functorch._src.vmap import ( - _get_name, - tree_flatten, + _add_batch_dim, _broadcast_to_and_flatten, - Tensor, + _get_name, + _remove_batch_dim, _validate_and_get_batch_size, - _add_batch_dim, + Tensor, + tree_flatten, tree_unflatten, - _remove_batch_dim, ) # Monkey-patches @@ -100,7 +100,7 @@ def _create_batched_inputs(flat_in_dims, flat_args, vmap_level: int, args_spec): arg if in_dim is None else arg.apply( - lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level), + lambda _arg, in_dim=in_dim: _add_batch_dim(_arg, in_dim, vmap_level), batch_size=[b for i, b in enumerate(arg.batch_size) if i != in_dim], ) if isinstance(arg, TensorDictBase) @@ -155,7 +155,9 @@ def incompatible_error(): out = _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) else: out = batched_output.apply( - lambda x: _remove_batch_dim(x, vmap_level, batch_size, out_dim), + lambda x, out_dim=out_dim: _remove_batch_dim( + x, vmap_level, batch_size, out_dim + ), batch_size=[batch_size, *batched_output.batch_size], ) flat_outputs.append(out) diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 096a6560f89..8654d338c18 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -4,17 +4,17 @@ # LICENSE file in the root directory of this source tree.
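The `lambda _arg, in_dim=in_dim: ...` and `lambda x, out_dim=out_dim: ...` rewrites in functional_modules.py above are flake8-bugbear B023 fixes: a closure captures the variable itself, not its value, so every closure built in a loop or comprehension ends up seeing the last assignment. Binding through a default argument freezes the current value at definition time. A self-contained illustration of the pitfall:

    # Late binding: each lambda closes over the variable i, not its value.
    fns = [lambda: i for i in range(3)]
    print([f() for f in fns])        # [2, 2, 2]

    # Default-argument binding captures i's value when the lambda is defined.
    fns = [lambda i=i: i for i in range(3)]
    print([f() for f in fns])        # [0, 1, 2]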
-from .exploration import NoisyLinear, NoisyLazyLinear, reset_noise -from .model_based import DreamerActor, ObsEncoder, ObsDecoder, RSSMPrior, RSSMPosterior +from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise +from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior from .models import ( - MLP, ConvNet, - DuelingCnnDQNet, - DistributionalDQNnet, DdpgCnnActor, DdpgCnnQNet, DdpgMlpActor, DdpgMlpQNet, + DistributionalDQNnet, + DuelingCnnDQNet, LSTMNet, + MLP, ) -from .utils import SqueezeLayer, Squeeze2dLayer +from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 551eaf6a6bb..a18454fbbdf 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -7,7 +7,7 @@ from typing import Optional, Sequence, Union import torch -from torch import nn, distributions as d +from torch import distributions as d, nn from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import UninitializedBuffer, UninitializedParameter diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 9e57a7c6bc3..8c17c614826 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -174,14 +174,14 @@ def __init__( self._out_features_num = _out_features_num self.activation_class = activation_class self.activation_kwargs = ( - activation_kwargs if activation_kwargs is not None else dict() + activation_kwargs if activation_kwargs is not None else {} ) self.norm_class = norm_class - self.norm_kwargs = norm_kwargs if norm_kwargs is not None else dict() + self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.bias_last_layer = bias_last_layer self.single_bias_last_layer = single_bias_last_layer self.layer_class = layer_class - self.layer_kwargs = layer_kwargs if layer_kwargs is not None else dict() + self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} self.activate_last_layer = activate_last_layer if single_bias_last_layer: raise NotImplementedError @@ -363,10 +363,10 @@ def __init__( self.in_features = in_features self.activation_class = activation_class self.activation_kwargs = ( - activation_kwargs if activation_kwargs is not None else dict() + activation_kwargs if activation_kwargs is not None else {} ) self.norm_class = norm_class - self.norm_kwargs = norm_kwargs if norm_kwargs is not None else dict() + self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.bias_last_layer = bias_last_layer self.aggregator_class = aggregator_class self.aggregator_kwargs = ( @@ -516,7 +516,7 @@ def __init__( super().__init__() mlp_kwargs_feature = ( - mlp_kwargs_feature if mlp_kwargs_feature is not None else dict() + mlp_kwargs_feature if mlp_kwargs_feature is not None else {} ) _mlp_kwargs_feature = { "num_cells": [256, 256], @@ -533,9 +533,7 @@ def __init__( "num_cells": 512, "bias_last_layer": True, } - mlp_kwargs_output = ( - mlp_kwargs_output if mlp_kwargs_output is not None else dict() - ) + mlp_kwargs_output = mlp_kwargs_output if mlp_kwargs_output is not None else {} _mlp_kwargs_output.update(mlp_kwargs_output) self.out_features = out_features self.out_features_value = out_features_value @@ -598,7 +596,7 @@ def __init__( ): super().__init__() - cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else dict() + cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} _cnn_kwargs = { "num_cells": [32, 64, 64], "strides": [4, 2, 1], @@ -613,7 +611,7 @@ 
def __init__( "num_cells": 512, "bias_last_layer": True, } - mlp_kwargs = mlp_kwargs if mlp_kwargs is not None else dict() + mlp_kwargs = mlp_kwargs if mlp_kwargs is not None else {} _mlp_kwargs.update(mlp_kwargs) self.out_features = out_features self.out_features_value = out_features_value @@ -748,7 +746,7 @@ def __init__( else {"output_size": (1, 1)}, "squeeze_output": use_avg_pooling, } - conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else dict() + conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else {} conv_net_default_kwargs.update(conv_net_kwargs) mlp_net_default_kwargs = { "in_features": None, @@ -758,7 +756,7 @@ def __init__( "activation_class": nn.ELU, "bias_last_layer": True, } - mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else dict() + mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} mlp_net_default_kwargs.update(mlp_net_kwargs) self.convnet = ConvNet(device=device, **conv_net_default_kwargs) self.mlp = MLP(device=device, **mlp_net_default_kwargs) @@ -808,7 +806,7 @@ def __init__( "activation_class": nn.ELU, "bias_last_layer": True, } - mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else dict() + mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} mlp_net_default_kwargs.update(mlp_net_kwargs) self.mlp = MLP(device=device, **mlp_net_default_kwargs) ddpg_init_last_layer(self.mlp[-1], 6e-3, device=device) @@ -878,7 +876,7 @@ def __init__( else {"output_size": (1, 1)}, "squeeze_output": use_avg_pooling, } - conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else dict() + conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else {} conv_net_default_kwargs.update(conv_net_kwargs) mlp_net_default_kwargs = { "in_features": None, @@ -888,7 +886,7 @@ def __init__( "activation_class": nn.ELU, "bias_last_layer": True, } - mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else dict() + mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} mlp_net_default_kwargs.update(mlp_net_kwargs) self.convnet = ConvNet(device=device, **conv_net_default_kwargs) self.mlp = MLP(device=device, **mlp_net_default_kwargs) @@ -949,7 +947,7 @@ def __init__( "activate_last_layer": True, } mlp_net_kwargs_net1: Dict = ( - mlp_net_kwargs_net1 if mlp_net_kwargs_net1 is not None else dict() + mlp_net_kwargs_net1 if mlp_net_kwargs_net1 is not None else {} ) mlp1_net_default_kwargs.update(mlp_net_kwargs_net1) self.mlp1 = MLP(device=device, **mlp1_net_default_kwargs) @@ -964,7 +962,7 @@ def __init__( "bias_last_layer": True, } mlp_net_kwargs_net2 = ( - mlp_net_kwargs_net2 if mlp_net_kwargs_net2 is not None else dict() + mlp_net_kwargs_net2 if mlp_net_kwargs_net2 is not None else {} ) mlp2_net_default_kwargs.update(mlp_net_kwargs_net2) self.mlp2 = MLP(device=device, **mlp2_net_default_kwargs) diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 07b84072335..3e8515ab7aa 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -10,6 +10,7 @@ from torch import nn from torchrl.data.utils import DEVICE_TYPING + from .exploration import NoisyLazyLinear, NoisyLinear LazyMapping = { diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 47558bcdbc5..06cc365104f 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -5,18 +5,18 @@ from .actors import ( Actor, - ActorValueOperator, - ValueOperator, - 
ProbabilisticActor, - QValueActor, ActorCriticOperator, ActorCriticWrapper, + ActorValueOperator, DistributionalQValueActor, + ProbabilisticActor, + QValueActor, + ValueOperator, ) from .common import TensorDictModule, TensorDictModuleWrapper from .exploration import ( - EGreedyWrapper, AdditiveGaussianWrapper, + EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) from .probabilistic import ProbabilisticTensorDictModule diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 22384bdce53..1092064eab9 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -8,7 +8,7 @@ import torch from torch import nn -from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, TensorSpec +from torchrl.data import CompositeSpec, TensorSpec, UnboundedContinuousTensorSpec from torchrl.modules.models.models import DistributionalDQNnet from torchrl.modules.tensordict_module.common import ( TensorDictModule, diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 81ca1c33f1b..82294617a67 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,15 +9,7 @@ import warnings from copy import deepcopy from textwrap import indent -from typing import ( - Any, - Iterable, - List, - Optional, - Sequence, - Type, - Union, -) +from typing import Any, Iterable, List, Optional, Sequence, Type, Union import torch @@ -45,10 +37,7 @@ from tensordict.tensordict import TensorDictBase from torch import nn, Tensor -from torchrl.data import ( - TensorSpec, - CompositeSpec, -) +from torchrl.data import CompositeSpec, TensorSpec from torchrl.modules.functional_modules import ( FunctionalModule as rlFunctionalModule, FunctionalModuleWithBuffers as rlFunctionalModuleWithBuffers, diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 0e33039f89d..316f9536d25 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -6,16 +6,15 @@ import re from copy import deepcopy from textwrap import indent -from typing import List, Sequence, Union, Type, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Type, Union from tensordict.tensordict import TensorDictBase -from torch import Tensor -from torch import distributions as d +from torch import distributions as d, Tensor from torchrl.data import TensorSpec from torchrl.envs.utils import exploration_mode, set_exploration_mode -from torchrl.modules.distributions import distributions_maps, Delta -from torchrl.modules.tensordict_module.common import TensorDictModule, _check_all_str +from torchrl.modules.distributions import Delta, distributions_maps +from torchrl.modules.tensordict_module.common import _check_all_str, TensorDictModule class ProbabilisticTensorDictModule(TensorDictModule): @@ -185,7 +184,7 @@ def __init__( distribution_class = distributions_maps.get(distribution_class.lower()) self.distribution_class = distribution_class self.distribution_kwargs = ( - distribution_kwargs if distribution_kwargs is not None else dict() + distribution_kwargs if distribution_kwargs is not None else {} ) self.n_empirical_estimate = n_empirical_estimate self._dist = None diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index a1f3b96f8f2..057ad1e9afd 100644 --- 
a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -22,12 +22,8 @@ FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." import torch -from tensordict.tensordict import ( - LazyStackedTensorDict, - TensorDict, - TensorDictBase, -) -from torch import Tensor, nn +from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from torch import nn, Tensor from torchrl.data import CompositeSpec from torchrl.modules.tensordict_module.common import TensorDictModule @@ -243,7 +239,7 @@ def select_subsequence( in_keys = deepcopy(self.in_keys) if out_keys is None: out_keys = deepcopy(self.out_keys) - id_to_keep = {i for i in range(len(self.module))} + id_to_keep = set(range(len(self.module))) for i, module in enumerate(self.module): if all(key in in_keys for key in module.in_keys): in_keys.extend(module.out_keys) @@ -255,7 +251,7 @@ def select_subsequence( out_keys.extend(module.in_keys) else: id_to_keep.remove(i) - id_to_keep = sorted(list(id_to_keep)) + id_to_keep = sorted(id_to_keep) modules = [self.module[i] for i in id_to_keep] @@ -306,8 +302,8 @@ def forward( if params is not None and buffers is not None: if isinstance(params, TensorDictBase): # TODO: implement sorted values and items - param_splits = list(zip(*sorted(list(params.items()))))[1] - buffer_splits = list(zip(*sorted(list(buffers.items()))))[1] + param_splits = list(zip(*sorted(params.items())))[1] + buffer_splits = list(zip(*sorted(buffers.items())))[1] else: param_splits = self._split_param(params, "params") buffer_splits = self._split_param(buffers, "buffers") @@ -330,7 +326,7 @@ def forward( elif params is not None: if isinstance(params, TensorDictBase): # TODO: implement sorted values and items - param_splits = list(zip(*sorted(list(params.items()))))[1] + param_splits = list(zip(*sorted(params.items())))[1] else: param_splits = self._split_param(params, "params") for i, (module, param) in enumerate(zip(self.module, param_splits)): @@ -450,8 +446,8 @@ def get_dist( params = kwargs["params"] buffers = kwargs["buffers"] if isinstance(params, TensorDictBase): - param_splits = list(zip(*sorted(list(params.items()))))[1] - buffer_splits = list(zip(*sorted(list(buffers.items()))))[1] + param_splits = list(zip(*sorted(params.items())))[1] + buffer_splits = list(zip(*sorted(buffers.items())))[1] else: param_splits = self._split_param(kwargs["params"], "params") buffer_splits = self._split_param(kwargs["buffers"], "buffers") @@ -478,7 +474,7 @@ def get_dist( elif "params" in kwargs: params = kwargs["params"] if isinstance(params, TensorDictBase): - param_splits = list(zip(*sorted(list(params.items()))))[1] + param_splits = list(zip(*sorted(params.items())))[1] else: param_splits = self._split_param(kwargs["params"], "params") kwargs_pruned = { diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index 0bbf5182e08..4af16165f7c 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
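sequence.py above keeps the `list(zip(*sorted(params.items())))[1]` idiom, now without the redundant inner `list()`: sorting the (key, value) pairs and transposing them with `zip(*...)` puts the tuple of keys at index 0 and the tuple of values, ordered by key, at index 1. Likewise `{i for i in range(n)}` becomes the direct `set(range(n))`. A toy version of the transpose trick:

    params = {"module.1.weight": "w1", "module.0.weight": "w0"}
    keys, values = zip(*sorted(params.items()))  # transpose sorted (key, value) pairs
    print(values)                                # ('w0', 'w1'): values in key order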
-from .mappings import mappings, inv_softplus, biased_softplus +from .mappings import biased_softplus, inv_softplus, mappings diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 5b502e25dfd..2cf6a1e8eef 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -6,18 +6,18 @@ from .a2c import A2CLoss from .common import LossModule from .ddpg import DDPGLoss -from .dqn import DQNLoss, DistributionalDQNLoss -from .dreamer import DreamerValueLoss, DreamerActorLoss, DreamerModelLoss -from .ppo import PPOLoss, ClipPPOLoss, KLPENPPOLoss +from .dqn import DistributionalDQNLoss, DQNLoss +from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss +from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from .redq import REDQLoss from .sac import SACLoss from .utils import ( - SoftUpdate, - HardUpdate, distance_loss, + HardUpdate, + hold_out_net, hold_out_params, next_state_value, - hold_out_net, + SoftUpdate, ) # from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 7db260108da..5be98bd3215 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Iterator, Optional, Tuple, List, Union +from typing import Iterator, List, Optional, Tuple, Union import torch @@ -25,7 +25,7 @@ ) FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor from torch.nn import Parameter @@ -44,7 +44,7 @@ class LossModule(nn.Module): def __init__(self): super().__init__() - self._param_maps = dict() + self._param_maps = {} def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*". @@ -284,8 +284,8 @@ def _convert_to_functional_native( # module.module.param or such names. We assume that there is a constant prefix # and that, when sorted, all keys will match. We could check that the values # do match too. 
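The rename loop that follows leans entirely on the assumption spelled out in the comment above: the two key sets differ only by a constant prefix, so sorting puts them in pairwise correspondence. A toy sketch of that alignment (the names are illustrative only):

    flat = {"module.linear.weight": 1.0, "module.linear.bias": 0.0}
    vals = {"linear.weight": None, "linear.bias": None}
    for k1, k2 in zip(sorted(flat), sorted(vals)):  # sorting must align the pairs
        print(k2, "->", k1)
    # linear.bias -> module.linear.bias
    # linear.weight -> module.linear.weight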
- keys1 = sorted(list(params.flatten_keys(".").keys())) - keys2 = sorted(list(params_vals.keys())) + keys1 = sorted(params.flatten_keys(".").keys()) + keys2 = sorted(params_vals.keys()) for key1, key2 in zip(keys1, keys2): params_vals.rename_key(key2, key1) params = params_vals.unflatten_keys(".") @@ -341,7 +341,7 @@ def _convert_to_functional_native( name_buffers_target = "_target_" + buffer_name if create_target_params: target_params = getattr(self, param_name).detach().clone() - target_params_items = sorted(list(target_params.flatten_keys(".").items())) + target_params_items = sorted(target_params.flatten_keys(".").items()) target_params_list = [] for i, (key, val) in enumerate(target_params_items): name = "_".join([name_params_target, str(i)]) @@ -363,9 +363,7 @@ def _convert_to_functional_native( ) target_buffers = getattr(self, buffer_name).detach().clone() - target_buffers_items = sorted( - list(target_buffers.flatten_keys(".").items()) - ) + target_buffers_items = sorted(target_buffers.flatten_keys(".").items()) target_buffers_list = [] for i, (key, val) in enumerate(target_buffers_items): name = "_".join([name_buffers_target, str(i)]) diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 93b42ddb8c1..d2a54ee81fe 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -8,15 +8,12 @@ from typing import Tuple import torch -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.modules import TensorDictModule from torchrl.modules.tensordict_module.actors import ActorCriticWrapper -from torchrl.objectives.utils import ( - distance_loss, - hold_out_params, - next_state_value, -) +from torchrl.objectives.utils import distance_loss, hold_out_params, next_state_value + from ..envs.utils import set_exploration_mode from .common import LossModule diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 21a0bdd0620..9005112e7d7 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -1,6 +1,6 @@ import math from numbers import Number -from typing import Union, Tuple +from typing import Tuple, Union import numpy as np import torch @@ -11,9 +11,9 @@ from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import TensorDictModule from torchrl.objectives import ( + distance_loss, hold_out_params, next_state_value as get_next_state_value, - distance_loss, ) from torchrl.objectives.common import LossModule diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 29586e9fcdc..99b82f1404a 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -11,11 +11,9 @@ from torch import nn from torchrl.envs.utils import step_mdp -from torchrl.modules import ( - DistributionalQValueActor, - QValueActor, -) +from torchrl.modules import DistributionalQValueActor, QValueActor from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible + from .common import LossModule from .utils import distance_loss, next_state_value @@ -214,7 +212,7 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: done = tensordict.get("done") steps_to_next_obs = tensordict.get("steps_to_next_obs", 1) - discount = self.gamma ** steps_to_next_obs + discount = self.gamma**steps_to_next_obs # Calculate current state probabilities (online network noise already # sampled) diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 
863b34a2c56..c9f35b64649 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -8,11 +8,10 @@ from tensordict import TensorDict from torchrl.envs.model_based.dreamer import DreamerEnv -from torchrl.envs.utils import set_exploration_mode -from torchrl.envs.utils import step_mdp +from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import TensorDictModule from torchrl.objectives.common import LossModule -from torchrl.objectives.utils import hold_out_net, distance_loss +from torchrl.objectives.utils import distance_loss, hold_out_net from torchrl.objectives.value.functional import vec_td_lambda_return_estimate @@ -113,8 +112,8 @@ def kl_loss( ) -> torch.Tensor: kl = ( torch.log(prior_std / posterior_std) - + (posterior_std ** 2 + (prior_mean - posterior_mean) ** 2) - / (2 * prior_std ** 2) + + (posterior_std**2 + (prior_mean - posterior_mean) ** 2) + / (2 * prior_std**2) - 0.5 ) if not self.global_average: diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 510a4ca2a79..b711519492e 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -7,11 +7,12 @@ from typing import Callable, Optional, Tuple import torch -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torch import distributions as d from torchrl.modules import TensorDictModule from torchrl.objectives.utils import distance_loss + from ..modules.tensordict_module import ProbabilisticTensorDictModule from .common import LossModule diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 2f3c010f7ee..beb7c31c51a 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -9,12 +9,12 @@ import numpy as np import torch -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torch import Tensor from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.modules import TensorDictModule -from torchrl.objectives.common import LossModule, _has_functorch +from torchrl.objectives.common import _has_functorch, LossModule from torchrl.objectives.utils import ( distance_loss, hold_out_params, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index caca4c1722f..b68b7b981a0 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -1,10 +1,10 @@ -from typing import Optional, Callable +from typing import Callable, Optional import torch -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.envs.utils import step_mdp -from torchrl.modules import TensorDictModule, ProbabilisticTensorDictModule +from torchrl.modules import ProbabilisticTensorDictModule, TensorDictModule from torchrl.objectives import distance_loss from torchrl.objectives.common import LossModule diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 7dec8f525eb..9b0685e2178 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -9,15 +9,13 @@ import numpy as np import torch -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torch import Tensor -from torchrl.modules import ProbabilisticActor -from torchrl.modules import TensorDictModule -from torchrl.modules.tensordict_module.actors import ( - ActorCriticWrapper, -) +from torchrl.modules import 
ProbabilisticActor, TensorDictModule +from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.utils import distance_loss, next_state_value + from ..envs.utils import set_exploration_mode from .common import LossModule diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 3812a153a50..467c0cb7c7f 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -154,8 +154,8 @@ def init_(self) -> None: for source, target in zip(self._sources.values(), self._targets.values()): if isinstance(source, TensorDictBase) and not source.is_empty(): # native functional modules - source = list(zip(*sorted(list(source.items()))))[1] - target = list(zip(*sorted(list(target.items()))))[1] + source = list(zip(*sorted(source.items())))[1] + target = list(zip(*sorted(target.items())))[1] elif isinstance(source, TensorDictBase) and source.is_empty(): continue for p_source, p_target in zip(source, target): @@ -174,8 +174,8 @@ def step(self) -> None: for source, target in zip(self._sources.values(), self._targets.values()): if isinstance(source, TensorDictBase) and not source.is_empty(): # native functional modules - source = list(zip(*sorted(list(source.items()))))[1] - target = list(zip(*sorted(list(target.items()))))[1] + source = list(zip(*sorted(source.items())))[1] + target = list(zip(*sorted(target.items())))[1] elif isinstance(source, TensorDictBase) and source.is_empty(): continue for p_source, p_target in zip(source, target): @@ -191,8 +191,8 @@ def _step(self, p_source: Tensor, p_target: Tensor) -> None: def __repr__(self) -> str: string = ( - f"{self.__class__.__name__}(sources={[name for name in self._sources]}, targets=" - f"{[name for name in self._targets]})" + f"{self.__class__.__name__}(sources={list(self._sources)}, targets=" + f"{list(self._targets)})" ) return string @@ -341,5 +341,5 @@ def next_state_value( done = done.to(torch.float) target_value = (1 - done) * pred_next_val_detach rewards = rewards.to(torch.float) - target_value = rewards + (gamma ** steps_to_next_obs) * target_value + target_value = rewards + (gamma**steps_to_next_obs) * target_value return target_value diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index d226058da97..11e8f316f0b 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .advantages import GAE, TDLambdaEstimate, TDEstimate +from .advantages import GAE, TDEstimate, TDLambdaEstimate diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 279339d9e6a..6996252847e 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -3,19 +3,20 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
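The `__repr__` cleanup in objectives/utils.py above uses the fact that iterating a dict yields its keys, so `[name for name in d]` is just `list(d)`; black 22 similarly rewrites `gamma ** steps_to_next_obs` as `gamma**steps_to_next_obs`, since power expressions with simple operands now hug the operator. For the dict identity:

    d = {"actor": 0, "critic": 1}
    assert [name for name in d] == list(d) == ["actor", "critic"]  # dicts iterate over keys, in insertion order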
-from typing import Union, Optional, List +from typing import List, Optional, Union import torch from tensordict.tensordict import TensorDictBase -from torch import Tensor, nn +from torch import nn, Tensor from torchrl.envs.utils import step_mdp from torchrl.modules import TensorDictModule from torchrl.objectives.value.functional import ( - vec_generalized_advantage_estimate, td_lambda_advantage_estimate, + vec_generalized_advantage_estimate, vec_td_lambda_advantage_estimate, ) + from ..utils import hold_out_net from .functional import td_advantage_estimate diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 7906b24a0d9..20941f8499e 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple, Optional +from typing import Optional, Tuple import torch diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py index 6094190bf90..be720e7687c 100644 --- a/torchrl/record/__init__.py +++ b/torchrl/record/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .recorder import VideoRecorder, TensorDictRecorder +from .recorder import TensorDictRecorder, VideoRecorder diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index f325db0cb0a..62ea14a923b 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -4,16 +4,16 @@ # LICENSE file in the root directory of this source tree. from .trainers import ( - Trainer, BatchSubSampler, + ClearCudaCache, CountFramesLog, LogReward, Recorder, ReplayBuffer, RewardNormalizer, SelectKeys, + Trainer, UpdateWeights, - ClearCudaCache, ) # from .loggers import * diff --git a/torchrl/trainers/helpers/__init__.py b/torchrl/trainers/helpers/__init__.py index 13a6a554694..67466668164 100644 --- a/torchrl/trainers/helpers/__init__.py +++ b/torchrl/trainers/helpers/__init__.py @@ -4,35 +4,35 @@ # LICENSE file in the root directory of this source tree. from .collectors import ( - sync_sync_collector, - sync_async_collector, make_collector_offpolicy, make_collector_onpolicy, + sync_async_collector, + sync_sync_collector, ) from .envs import ( correct_for_frame_skip, - transformed_env_constructor, - parallel_env_constructor, get_stats_random_rollout, + parallel_env_constructor, + transformed_env_constructor, ) from .logger import LoggerConfig from .losses import ( - make_sac_loss, - make_dqn_loss, - make_ddpg_loss, - make_target_updater, make_a2c_loss, + make_ddpg_loss, + make_dqn_loss, make_ppo_loss, make_redq_loss, + make_sac_loss, + make_target_updater, ) from .models import ( - make_dqn_actor, - make_ddpg_actor, make_a2c_model, + make_ddpg_actor, + make_dqn_actor, + make_dreamer, make_ppo_model, - make_sac_model, make_redq_model, - make_dreamer, + make_sac_model, ) from .replay_buffer import make_replay_buffer from .trainers import make_trainer diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 1d72aab2643..9e38a00f4b0 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -4,20 +4,20 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass, field -from typing import Callable, List, Optional, Type, Union, Dict, Any +from typing import Any, Callable, Dict, List, Optional, Type, Union from tensordict.tensordict import TensorDictBase from torchrl.collectors.collectors import ( _DataCollector, - SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector, + SyncDataCollector, ) from torchrl.data import MultiStep from torchrl.envs import ParallelEnv from torchrl.envs.common import EnvBase -from torchrl.modules import TensorDictModuleWrapper, ProbabilisticTensorDictModule +from torchrl.modules import ProbabilisticTensorDictModule, TensorDictModuleWrapper def sync_async_collector( @@ -176,7 +176,7 @@ def _make_collector( **kwargs, ) -> _DataCollector: if env_kwargs is None: - env_kwargs = dict() + env_kwargs = {} if isinstance(env_fns, list): num_env = len(env_fns) if num_env_per_collector is None: @@ -219,7 +219,7 @@ def _make_collector( env_kwargs = [_env_kwargs[0] for _env_kwargs in env_kwargs_split] else: env_fns = [ - lambda: ParallelEnv( + lambda _env_fn=_env_fn, _env_kwargs=_env_kwargs: ParallelEnv( num_workers=len(_env_fn), create_env_fn=_env_fn, create_env_kwargs=_env_kwargs, diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 0046315aea6..09311e35bd8 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -3,9 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from dataclasses import field as dataclass_field -from typing import Callable, Optional, Union, Any, Sequence +from dataclasses import dataclass, field as dataclass_field +from typing import Any, Callable, Optional, Sequence, Union import torch @@ -17,6 +16,7 @@ from torchrl.envs.transforms import ( CatFrames, CatTensors, + CenterCrop, DoubleToFloat, FiniteTensorDictCheck, GrayScale, @@ -27,9 +27,8 @@ ToTensorImage, TransformedEnv, VecNorm, - CenterCrop, ) -from torchrl.envs.transforms.transforms import gSDENoise, FlattenObservation +from torchrl.envs.transforms.transforms import FlattenObservation, gSDENoise from torchrl.record.recorder import VideoRecorder from torchrl.trainers.loggers import Logger diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index f0d21e13997..f24771b0288 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
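Two things are worth noting in the collectors helper above. The `lambda _env_fn=_env_fn, _env_kwargs=_env_kwargs: ParallelEnv(...)` rewrite is the same B023 default-argument binding shown earlier, applied to environment factories built in a loop. And the recurring `dict()` to `{}` substitutions (flake8-comprehensions C408) are behavior-preserving; the literal skips a global name lookup and a call, so it is also marginally faster, which `timeit` can confirm:

    import timeit
    print(timeit.timeit("dict()"))  # name lookup plus a call
    print(timeit.timeit("{}"))      # a single bytecode op; typically a few times faster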
+import itertools from dataclasses import dataclass from typing import Optional, Sequence @@ -165,7 +166,7 @@ def make_dqn_actor( "mlp_kwargs_output": {"num_cells": 512, "layer_class": linear_layer_class}, } # automatically infer in key - in_key = list(env_specs["observation_spec"])[0] + (in_key,) = itertools.islice(env_specs["observation_spec"], 1) out_features = action_spec.shape[0] actor_class = QValueActor @@ -285,8 +286,8 @@ def make_ddpg_actor( from_pixels = cfg.from_pixels noisy = cfg.noisy - actor_net_kwargs = actor_net_kwargs if actor_net_kwargs is not None else dict() - value_net_kwargs = value_net_kwargs if value_net_kwargs is not None else dict() + actor_net_kwargs = actor_net_kwargs if actor_net_kwargs is not None else {} + value_net_kwargs = value_net_kwargs if value_net_kwargs is not None else {} linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 10333ad82f3..0a1016c2f70 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Optional, Union, List +from typing import List, Optional, Union from warnings import warn import torch @@ -14,21 +14,21 @@ from torchrl.collectors.collectors import _DataCollector from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase -from torchrl.modules import TensorDictModule, TensorDictModuleWrapper, reset_noise +from torchrl.modules import reset_noise, TensorDictModule, TensorDictModuleWrapper from torchrl.objectives.common import LossModule from torchrl.objectives.utils import TargetNetUpdater from torchrl.trainers.loggers import Logger from torchrl.trainers.trainers import ( - Trainer, - SelectKeys, - ReplayBufferTrainer, + BatchSubSampler, + ClearCudaCache, + CountFramesLog, LogReward, + Recorder, + ReplayBufferTrainer, RewardNormalizer, - BatchSubSampler, + SelectKeys, + Trainer, UpdateWeights, - Recorder, - CountFramesLog, - ClearCudaCache, ) OPTIMIZERS = { diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 7c647a35bdd..1909df1370a 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -8,19 +8,18 @@ import abc import pathlib import warnings -from collections import OrderedDict, defaultdict +from collections import defaultdict, OrderedDict from copy import deepcopy from textwrap import indent -from typing import Callable, Dict, Optional, Union, Sequence, Tuple, Type, List, Any +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import numpy as np import torch.nn -from tensordict.tensordict import TensorDictBase, pad +from tensordict.tensordict import pad, TensorDictBase from tensordict.utils import expand_right from torch import nn, optim -from torchrl._utils import KeyDependentDefaultDict -from torchrl._utils import _CKPT_BACKEND +from torchrl._utils import _CKPT_BACKEND, KeyDependentDefaultDict from torchrl.collectors.collectors import _DataCollector from torchrl.data import ( ReplayBuffer, @@ -42,7 +41,7 @@ _has_tqdm = False try: - from torchsnapshot import StateDict, Snapshot + from torchsnapshot import Snapshot, StateDict _has_ts = True except ImportError: @@ -406,7 +405,7 @@ def _post_steps_log_hook(self, batch: TensorDictBase) -> None: def train(self): if self.progress_bar: self._pbar = tqdm(total=self.total_frames) - self._pbar_str = dict() + 
self._pbar_str = {} for batch in self.collector: batch = self._process_batch_hook(batch) @@ -1082,7 +1081,7 @@ def __call__(self, batch: TensorDictBase) -> Dict: self.recorder.train() self.recorder.transform.dump(suffix=self.suffix) - out = dict() + out = {} for key in self.log_keys: value = td_record.get(key).float() if key == "reward":
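One non-mechanical change earlier in this patch deserves a note: make_dqn_actor used to build `list(env_specs["observation_spec"])[0]` just to read the first key, whereas `itertools.islice` now pulls a single key lazily, and the one-element tuple unpacking doubles as a check that a key was actually produced. A stand-in sketch (the key names here are invented):

    import itertools
    spec_keys = iter(["pixels", "observation_vector"])  # stand-in for env_specs["observation_spec"]
    (in_key,) = itertools.islice(spec_keys, 1)          # take exactly one key, lazily
    print(in_key)                                       # pixels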