Skip to content
Snippets Groups Projects
Commit 86f91af7 authored by Florian Lübbe's avatar Florian Lübbe
Browse files

Update README.md

parent c6d4f30c
Branches master
No related tags found
No related merge requests found
......@@ -37,9 +37,120 @@ RL4CO is built upon:
## Getting started
<a href="https://colab.research.google.com/github/kaist-silab/rl4co/blob/main/notebooks/1-quickstart.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
RL4CO is now available for installation on `pip`!
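What to do now: run the following cells in a notebook (for example Google Colab) to install the dependencies with `pip` and evaluate a trained checkpoint.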
```bash
pip install rl4co
# Notebook setup cell: fetch the project, install its dependencies, and
# authenticate with Weights & Biases.
!git clone https://gitlab.gwdg.de/srp_tim/rl4codynamicknn40.git
# tensordict and torchrl are pinned to specific versions; the remaining
# packages are installed at whatever version pip resolves.
!pip install tensordict==0.1.2
!pip install einops
!pip install torchrl==0.1.1
!pip install lightning
!pip install wandb
!pip install rich
!pip install omegaconf
!pip install pyrootutils
!pip install scipy
!pip install torch_geometric
!pip install hydra-core
!pip install hydra-colorlog
!pip install robust-downloader
# Work from the cloned repository so the relative checkpoint paths resolve.
%cd /content/rl4codynamicknn40
# Never put an API key in a README or notebook. Without an argument,
# `wandb login` prompts for the key; it can also be supplied through the
# WANDB_API_KEY environment variable.
!wandb login
# Evaluate one trained AttentionModel checkpoint on randomly generated TSP
# instances of increasing size and print the mean greedy reward for each size.
import torch
import torch.nn as nn
import wandb
from lightning.pytorch.loggers import WandbLogger
from rl4co.envs import TSPEnv
from rl4co.utils.trainer import RL4COTrainer
from rl4co.models.zoo.am.model import AttentionModel
#from rl4co.models.zoo.symnco.model import SymNCO

# Not done in rl4codynamicknn40: 8vny5ukd ehfbbfef ijlwua0q j8f4ript xfxmda8p
# Dynamic SymNCO does not work: yynmngby

CHECKPOINT_PATH = "rl4co/spjdvjx7/checkpoints/epoch=99-step=19600.ckpt"

# Number of random instances generated per problem size.
DATASET_SIZE = 10000

# (num_loc, batch_size) pairs to evaluate. The batch size shrinks as the
# instances grow -- presumably to keep the forward pass within memory.
EVAL_SETTINGS = [
    (20, 128),
    (50, 128),
    (100, 64),
    (500, 4),
    (1000, 2),
]

# Seed once so the generated instances are reproducible from run to run.
torch.manual_seed(1234)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

wandb.init(project="gen")
logger = WandbLogger(project="gen", name="attn20_gcn_knn20")

# Load the checkpoint a single time and reuse it for every problem size.
# strict=False tolerates state-dict keys that do not match the model exactly.
modelss = AttentionModel.load_from_checkpoint(CHECKPOINT_PATH, strict=False)
modelss = modelss.to(device)
# Inference mode for layers such as dropout / batch normalization.
modelss.eval()

for num_loc, batch_size in EVAL_SETTINGS:
    env = TSPEnv(num_loc=num_loc)

    # Generate DATASET_SIZE instances; only the first batch is scored below.
    new_dataset = env.dataset(DATASET_SIZE)
    dataloader = modelss._dataloader(new_dataset, batch_size=batch_size)
    init_states = next(iter(dataloader))
    td_init_generalization = env.reset(init_states).to(device)

    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        out = modelss(
            td_init_generalization.clone(),
            phase="test",
            decode_type="greedy",
            return_actions=False,
        )
    print(out['reward'].mean())

wandb.finish()
```
To get started, we recommend checking out our [quickstart notebook](notebooks/1-quickstart.ipynb) or the [minimalistic example](#minimalistic-example) below.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment