Commit 44389fce authored by Timo Lueddecke's avatar Timo Lueddecke
Browse files

bugfix in train.py

parent 8568f479
# Source code of Affordance Segmentation from Single Images # Affordance Segmentation from Single Images
This repository contains code to reproduce the experiments from the paper. This is a simplified version of This repository contains code to reproduce the experiments from the paper. This is a simplified version of
our actual codebase, which contains only the necessary modules and should be more comprehensible. We will our actual codebase, which contains only the necessary modules and should be more comprehensible. We will
...@@ -70,4 +70,32 @@ and a dataset name. ...@@ -70,4 +70,32 @@ and a dataset name.
`[Model]` can be either `ResNet50Dense` or `PSPNet` `[Model]` can be either `ResNet50Dense` or `PSPNet`
`[Dataset]` can be either `ADEAff12`, `AffHQ12`, `Oregon`, `Oregon-OWN` or `AffSynth12`: `[Dataset]` can be either `ADEAff12`, `AffHQ12`, `Oregon`, `Oregon-OWN` or `AffSynth12`:
## References
```
@article{luddecke2019context,
title={Context-based affordance segmentation from 2D images for robot actions},
author={L{\"u}ddecke, Timo and Kulvicius, Tomas and W{\"o}rg{\"o}tter, Florentin},
journal={Robotics and Autonomous Systems},
volume={119},
pages={92--107},
year={2019},
publisher={Elsevier}
}
```
```
@inproceedings{luddecke2017learning,
title={Learning to segment affordances},
author={L{\"u}ddecke, Timo and W{\"o}rg{\"o}tter, Florentin},
booktitle={Proceedings of the IEEE International Conference on Computer Vision Workshops},
pages={769--776},
year={2017}
}
```
Have Fun! Have Fun!
...@@ -3,13 +3,7 @@ from torch import nn ...@@ -3,13 +3,7 @@ from torch import nn
from ..blocks.blocks import _DecoderBlock from ..blocks.blocks import _DecoderBlock
from ..dense.base import _DenseBase from ..dense.base import _DenseBase
from drn.drn import drn_d_38, drn_d_54, drn_d_105
from torch.nn import functional as nnf, Parameter from torch.nn import functional as nnf, Parameter
# from third_party.sync_bn.sync_batchnorm import SynchronizedBatchNorm2d
# from ava.models.dense.pnas_wrapper import PNAS_Wrapper
# from ..feature_extraction.pnas_wrapper import PNASFeatures
BatchNorm2d = nn.BatchNorm2d BatchNorm2d = nn.BatchNorm2d
...@@ -56,6 +50,7 @@ class PSPNet(_DenseBase): ...@@ -56,6 +50,7 @@ class PSPNet(_DenseBase):
self.base = base self.base = base
self.use_act = use_act self.use_act = use_act
from drn.drn import drn_d_38, drn_d_54, drn_d_105
if base == 'drn105': if base == 'drn105':
self.base_model = drn_d_105(pretrained=pretrained, out_map=True, out_middle=True) self.base_model = drn_d_105(pretrained=pretrained, out_map=True, out_middle=True)
......
...@@ -28,8 +28,10 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, decoder_shape='m') ...@@ -28,8 +28,10 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, decoder_shape='m')
time_start = time.time() time_start = time.time()
min_val_loss = 99999 min_val_loss = 99999
for i_epoch in range(epochs): for i_epoch in range(epochs):
model.train()
print('start epoch {}'.format(i_epoch)) print('start epoch {}'.format(i_epoch))
for i, (vars_x, vars_y) in enumerate(loader): for i, (vars_x, vars_y) in enumerate(loader):
...@@ -46,13 +48,13 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, decoder_shape='m') ...@@ -46,13 +48,13 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, decoder_shape='m')
print('epoch {} / loss {}'.format(i_epoch, float(loss))) print('epoch {} / loss {}'.format(i_epoch, float(loss)))
print('Epoch done. Start validation...') print('Epoch done. Start validation...')
model.eval()
# validation time # validation time
val_losses = [] val_losses = []
for i, (vars_x, vars_y) in enumerate(loader_val): for i, (vars_x, vars_y) in enumerate(loader_val):
vars_x = [v.cuda() if v is not None else None for v in vars_x] vars_x = [v.cuda() if v is not None else None for v in vars_x]
vars_y = [v.cuda() if v is not None else None for v in vars_y] vars_y = [v.cuda() if v is not None else None for v in vars_y]
opt.zero_grad()
pred = model(*vars_x) pred = model(*vars_x)
val_losses.append(float(model.loss(pred, vars_y))) val_losses.append(float(model.loss(pred, vars_y)))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment