Getting Started with TorchRL for Deep Reinforcement Learning
Reinforcement learning (RL) tackles complex problems, from autonomous vehicles to sophisticated language models, by training agents that learn from interaction with an environment and the rewards it returns; reinforcement learning from human feedback (RLHF), for example, uses human preferences as the reward signal when fine-tuning language models. While Python frameworks like Keras and TensorFlow remain well established, PyTorch and PyTorch Lightning dominate new projects.
TorchRL, an open-source library, simplifies RL development with PyTorch. This tutorial demonstrates TorchRL setup, core components, and building a basic RL agent. We'll explore pre-built algorithms like Proximal Policy Optimization (PPO), and essential logging and monitoring techniques.
This section guides you through installing and using TorchRL.
Before installing TorchRL, install the prerequisite packages (PyTorch, TensorDict, Gymnasium, and pygame):
!pip install torch tensordict gymnasium==0.29.1 pygame
Then install TorchRL itself with pip (a Conda environment is recommended on personal computers or servers):
!pip install torchrl
Test your installation by importing torchrl in a Python shell or notebook, and use check_env_specs() to verify environment compatibility (e.g., with CartPole):
import torchrl
from torchrl.envs import GymEnv
from torchrl.envs.utils import check_env_specs

check_env_specs(GymEnv("CartPole-v1"))
A successful installation displays:
[torchrl][INFO] check_env_specs succeeded!
Before agent creation, let's examine TorchRL's core elements.
TorchRL provides a consistent API for various environments, wrapping environment-specific functions into standard wrappers. This simplifies interaction:
Create a Gymnasium environment using GymEnv:
env = GymEnv("CartPole-v1")
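As a quick check (a minimal sketch, not part of the original article), you can interact with the wrapped environment entirely through TensorDicts:

td = env.reset()                     # TensorDict holding the initial observation
td = env.rand_step(td)               # take a random action; results land under the "next" key
rollout = env.rollout(max_steps=3)   # collect a short random trajectory
print(rollout)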
Enhance environments with add-ons (e.g., a step counter) using TransformedEnv:
from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
Observation normalization is achieved with ObservationNorm:
import torch
from torchrl.envs import Compose, GymEnv, ObservationNorm, StepCounter, TransformedEnv

device = "cuda" if torch.cuda.is_available() else "cpu"
base_env = GymEnv('CartPole-v1', device=device)
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),
        StepCounter()
    )
)
Multiple transforms are combined using Compose.
The agent uses a policy to select actions based on the environment's state, aiming to maximize cumulative rewards.
A simple random policy is created using RandomPolicy:

import torchrl
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import Bounded

action_spec = Bounded(-torch.ones(1), torch.ones(1))
actor = torchrl.envs.utils.RandomPolicy(action_spec=action_spec)
td = actor(TensorDict({}, batch_size=[]))
print(td.get("action"))
This section demonstrates building a simple RL agent: a deep Q-network (DQN) trained on CartPole.
Import the necessary packages:

import time
import torch
import matplotlib.pyplot as plt

from torchrl.envs import GymEnv, StepCounter, TransformedEnv
from tensordict.nn import TensorDictModule as TensorDict, TensorDictSequential as Seq
from torchrl.modules import EGreedyModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torch.optim import Adam
from torchrl._utils import logger as torchrl_logger
We'll use the CartPole environment with a step counter, and seed both PyTorch and the environment for reproducibility:

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
torch.manual_seed(0)
env.set_seed(0)
Define the hyperparameters:

INIT_RAND_STEPS = 5000
FRAMES_PER_BATCH = 100
OPTIM_STEPS = 10
EPS_0 = 0.5
BUFFER_LEN = 100_000
ALPHA = 0.05
TARGET_UPDATE_EPS = 0.95
REPLAY_BUFFER_SAMPLE = 128
LOG_EVERY = 1000
MLP_SIZE = 64
Define a simple neural network policy (an MLP that estimates Q-values, wrapped with an epsilon-greedy exploration module):

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[MLP_SIZE, MLP_SIZE])
value_net = TensorDict(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=BUFFER_LEN, eps_init=EPS_0
)
policy_explore = Seq(policy, exploration_module)
Create a data collector and replay buffer:

collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=FRAMES_PER_BATCH,
    total_frames=-1,
    init_random_frames=INIT_RAND_STEPS,
)
rb = ReplayBuffer(storage=LazyTensorStorage(BUFFER_LEN))
Define the training modules: the DQN loss, the optimizer, and a soft updater for the target network:

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=ALPHA)
updater = SoftUpdate(loss, eps=TARGET_UPDATE_EPS)
Implement the training loop (simplified for brevity; the optimization steps are sketched below):

total_count = 0
total_episodes = 0
t0 = time.time()
success_steps = []
for i, data in enumerate(collector):
    rb.extend(data)
    # ... (optimization steps; see the sketch below) ...
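The elided optimization steps are not reproduced in this excerpt. The following is a minimal sketch of what they typically look like in a TorchRL DQN loop (sampling from the replay buffer, computing the DQN loss, and stepping the optimizer, exploration schedule, and target-network updater), not the article's exact code:

    max_length = rb[:]["next", "step_count"].max()   # longest episode seen so far
    if len(rb) > INIT_RAND_STEPS:
        for _ in range(OPTIM_STEPS):
            sample = rb.sample(REPLAY_BUFFER_SAMPLE)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            exploration_module.step(data.numel())    # anneal epsilon
            updater.step()                           # soft-update target network
            total_count += data.numel()
            total_episodes += data["next", "done"].sum()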
Add evaluation and logging inside the training loop (simplified; max_length is the longest episode recorded so far):

    # ... (training steps) ...
    if total_count > 0 and total_count % LOG_EVERY == 0:
        torchrl_logger.info(f"Successful steps: {max_length}, episodes: {total_episodes}")
    if max_length > 475:
        print("TRAINING COMPLETE")
        break
Finally, print the training time and plot the results.
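The original plotting code is not included in this excerpt. A minimal sketch using the t0 timestamp and success_steps list from the loop above (assuming success_steps is appended with max_length on each iteration) might look like this:

t1 = time.time()
torchrl_logger.info(f"Training took {t1 - t0:.2f} seconds")

plt.plot(success_steps)   # assumes success_steps.append(max_length) each iteration
plt.title("Max episode length per training iteration")
plt.xlabel("Training iteration")
plt.ylabel("Steps")
plt.show()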
(The complete DQN implementation is available in the referenced DataLab workbook.)
TorchRL offers pre-built algorithms (DQN, DDPG, SAC, PPO, etc.). This section demonstrates using PPO.
Import the required modules, define the hyperparameters, and build the actor and critic networks; the rest of the PPO setup (data collection, loss function, optimization, and training loop) follows a structure similar to the DQN example above. The complete implementation is omitted here for brevity; refer to the original tutorial for the full code. A rough sketch of the core pieces is shown below.
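As an illustration only (a minimal sketch based on TorchRL's public PPO API, not the article's code; variable names and hyperparameter values are illustrative), the key ingredients are a probabilistic actor, a value operator, a GAE advantage module, and ClipPPOLoss:

import torch
from tensordict.nn import TensorDictModule
from torchrl.envs import GymEnv, StepCounter, TransformedEnv
from torchrl.modules import MLP, OneHotCategorical, ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())

# Actor: maps observations to action logits, then samples from a categorical distribution
actor_net = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
actor_module = TensorDictModule(actor_net, in_keys=["observation"], out_keys=["logits"])
policy = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)

# Critic: maps observations to a scalar state value
value_net = MLP(out_features=1, num_cells=[64, 64])
value_module = ValueOperator(value_net, in_keys=["observation"])

# Advantage estimation and the clipped PPO objective
advantage = GAE(gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True)
loss_module = ClipPPOLoss(policy, value_module, clip_epsilon=0.2)
optim = torch.optim.Adam(loss_module.parameters(), lr=3e-4)

# One illustrative update on a short on-policy rollout
data = env.rollout(max_steps=128, policy=policy)
advantage(data)   # writes "advantage" and "value_target" into the tensordict
losses = loss_module(data)
loss_val = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
loss_val.backward()
optim.step()
optim.zero_grad()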
Monitor training progress by logging metrics to TensorBoard.
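The article's logging code is not reproduced here; a minimal sketch using PyTorch's SummaryWriter (an assumption — the original may use TorchRL's own logger) that writes to the "training_logs" directory used in the command below could look like this:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="training_logs")

# Inside the training loop, after each collected batch (iteration i):
writer.add_scalar("train/max_episode_length", float(max_length), i)
writer.add_scalar("train/replay_buffer_size", len(rb), i)
writer.flush()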
Visualize with: tensorboard --logdir="training_logs"
Debugging often starts with checking the environment specifications.
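For example (a short sketch reusing the check shown earlier, not code from this excerpt):

from torchrl.envs.utils import check_env_specs

check_env_specs(env)            # verifies that the specs match what the env actually returns
print(env.observation_spec)     # inspect the observation space
print(env.action_spec)          # inspect the action space
print(env.reward_spec)          # inspect the reward space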
You can also sample observations and actions directly from the environment's specs to sanity-check shapes and dtypes.
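A minimal sketch (illustrative, not from the original article):

print(env.observation_spec.rand())   # a random observation with the correct shape/dtype
print(env.action_spec.rand())        # a random (one-hot) action
td = env.reset()
td = env.rand_step(td)               # step once with a random action
print(td["next", "observation"], td["next", "reward"])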
Visualize agent performance by rendering a video of a rollout (requires the torchvision and av packages).
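A sketch using TorchRL's CSVLogger and VideoRecorder transform (based on TorchRL's documented recording API; paths and tags are illustrative, and policy refers to the trained DQN policy from above):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.record import CSVLogger, VideoRecorder

logger = CSVLogger(exp_name="dqn_cartpole", log_dir="./training_logs", video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")

# A pixel-based copy of the environment, wrapped so frames are captured during the rollout
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False),
    video_recorder,
)
record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()   # writes the collected frames to an .mp4 file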
This tutorial provided a comprehensive introduction to TorchRL, showcasing its capabilities through DQN and PPO examples. Experiment with different environments and algorithms to further enhance your RL skills. The referenced resources provide additional learning opportunities.