cns-group-public / aff-seg

Commit 5bfb70ac, authored Sep 22, 2018 by timo
update
parent 198f1e95

Changes: 5 files
ade.py

 import inspect
 import json
 import os
 from collections import Counter
 from logging import warning
...
@@ -1094,15 +1095,12 @@ class ADEAffSynth12(Interleaved):
         super().__init__(subset, max_samples, dataset1, dataset2, model_config=model_config)


 OREGON_AFFORDANCE_PROPERTIES = ['walkable', 'sittable', 'lyable', 'reachable', 'movable']


 class OregonAffordances(_DatasetBySample):

-    def __init__(self, subset, max_samples, image_size=(352, 352), augmentation=False, cache=False, ade_format=False, split='nyu'):
+    def __init__(self, subset, max_samples, augmentation=False, split='nyu'):
         super().__init__(('image', 'affordance', 'mask'),
                          visualization_modes=('Image', ('Slices', {'maps_dim': 0, 'slice_labels': OREGON_AFFORDANCE_PROPERTIES}), None),
...
@@ -1112,14 +1110,7 @@ class OregonAffordances(_DatasetBySample):
         self.target_size = 352
         self.intermediate_size_range = (370, 550)

-        if ade_format:
-            # ade_original = ['obstruct', 'break', 'sit', 'grasp', 'pinch_pull', 'hook_pull', 'tip_push',
-            #                 'warmth', 'illumination', 'read/watch', 'support', 'place_on', 'clean_dry', 'roll',
-            #                 'walk']
-            self.properties = [None, None, 'sittable', 'reachable', None, None, None, None, None, None, None, 'movable', None, None, None, 'walkable']
-        else:
-            self.properties = OREGON_AFFORDANCE_PROPERTIES
+        self.properties = OREGON_AFFORDANCE_PROPERTIES

         if split == 'nyu':
             # ORIGINAL NYUv2 splits
             splits = loadmat(os.path.join(PATHS['AVA_DATA'], 'nyu_splits.mat'))
...
@@ -1134,7 +1125,8 @@ class OregonAffordances(_DatasetBySample):
         elif split == 'own':
             # MORE-TRAINING-DATA SPLIT
-            sample_ids = os.listdir(os.path.join(self.root, 'Annotations'))
+            sample_ids = json.load(open(os.path.join(PATHS['AVA_DATA'], 'sample_id_oregon.json')))
             seed(12345)
             shuffle(sample_ids)
...
@@ -1150,9 +1142,6 @@ class OregonAffordances(_DatasetBySample):
         self.sample_ids = tuple(self.sample_ids)

-        if cache:
-            self.load_img = MemoryCachedFunction(self.load_img)

         self.model_config.update({'with_mask': False, 'out_channels': len(self.properties), 'binary': True})
         self.parameter_variables = [split, self.target_size]
...
@@ -1161,6 +1150,7 @@ class OregonAffordances(_DatasetBySample):
         return img

     def __getitem__(self, index):
+        index = self.sample_ids[index]
         sample_path = os.path.join(self.root, 'Annotations')
...
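A note on the 'own' split above: the sample id list is shuffled with a fixed seed, so every run reproduces the same partition without storing the split on disk. A minimal self-contained sketch of that pattern; the ids and the 80/20 fraction are invented here for illustration, since the actual partitioning code is elided in the diff:

from random import seed, shuffle

sample_ids = ['img_{:04d}'.format(i) for i in range(10)]  # stand-in ids
seed(12345)                     # fixed seed: identical order on every run
shuffle(sample_ids)
n_train = int(0.8 * len(sample_ids))
train_ids, test_ids = sample_ids[:n_train], sample_ids[n_train:]
print(train_ids[:3])            # same three ids on every invocation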
common.py

 from functools import partial
 import time
 import torch
 from torch import nn
 import yaml
 import numpy as np
 from torch.optim import SGD
 from torch.utils.data import DataLoader
-from ade import ADEAff12, AffHQ12, OregonAffordances
+from ade import ADEAff12, AffHQ12, OregonAffordances, AffSynth12
 from ava.core.logging import *
 from ava.models.dense.dense_resnet import ResNet50Dense
 from ava.models.dense.pspnet import PSPNet
...
@@ -64,12 +62,16 @@ DATASET_INTERFACE = {
     }
 }


-def get_dataset_type(dataset_name, additional_args):
-    all_datasets = {'ADEAff12': ADEAff12, 'AffHQ12': AffHQ12, 'Oregon': OregonAffordances}
+def get_dataset_type(dataset_name):
+    all_datasets = {'ADEAff12': ADEAff12, 'AffHQ12': AffHQ12, 'Oregon': OregonAffordances,
+                    'Oregon-OWN': OregonAffordances, 'AffSynth12': AffSynth12}
+    add_args = {'Oregon-OWN': {'split': 'own'}, 'Oregon': {'split': 'nyu'}}

     if dataset_name in all_datasets:
         dataset_type = all_datasets[dataset_name]
+        additional_args = add_args[dataset_name] if dataset_name in add_args else {}
     else:
         raise FileNotFoundError('Dataset not found: {}'.format(dataset_name))
...
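After this change the dataset name alone selects both the class and its per-dataset default arguments. The function's return statement is elided in the diff; judging from the call sites in score.py and train.py it yields a (dataset_type, dataset_args, other_args) triple, so usage presumably looks like this sketch, written under that assumption:

# assumed behaviour of the rewritten dispatcher
dataset_type, dataset_args, other_args = get_dataset_type('Oregon-OWN')
# dataset_type is OregonAffordances; dataset_args carries {'split': 'own'} from add_args
d_train = dataset_type('train', None, **dataset_args)   # as called in train.py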
@@ -165,6 +167,7 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
     log_important('\nload pre-trained model', model_name)
     model_file = torch.load(model_name)
+    print('model created', model_file['creation_time'])
     loaded_model_name = model_file['model']

     try:
...
@@ -198,6 +201,9 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
     msg = "These arguments required by {} were not provided: {}".format(model_name, ', '.join(non_match_args))
     assert len(non_match_args) == 0, msg

+    if 'transfer_mode' in model_args:
+        model_args['transfer_mode'] = False
+
     # model = model_type(**model_args)
     model = model_type(**model_args)
     model_name = model.name()
...
@@ -207,15 +213,15 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
     if not no_weights:
         log_important('Set pre-loaded weights for ', model_name)
         weights = model_file['state_dict']
-        log_info('Included submodules', set(['.'.join(k.split('.')[:2]) for k in weights]))
+        # log_info('Included submodules', set(['.'.join(k.split('.')[:2]) for k in weights]))

         # for some reason weight keys often start with "module.". This is fixed here:
-        if all([k.startswith('module') for k in weights]):
-            weights = {k[7:]: v for k, v in weights.items()}
+        # if all([k.startswith('module') for k in weights]):
+        #     weights = {k[7:]: v for k, v in weights.items()}

         model.load_state_dict(weights, strict=True)

-        model = monkey_patch_model(model, MODEL_INTERFACE, sync_bn, multi_gpu)
+        # model = monkey_patch_model(model, MODEL_INTERFACE, sync_bn, multi_gpu)

     # compatibility mode
     if 'dataset_arguments' not in model_file:
         model_file['dataset_arguments'] = [{}]
...
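The "module." fix commented out above deals with checkpoints written from an nn.DataParallel wrapper, whose state_dict keys all carry a "module." prefix that a plain model will not accept. A minimal standalone sketch of that stripping step, using a toy model unrelated to this repository:

import torch
from torch import nn

model = nn.Linear(4, 2)
weights = nn.DataParallel(model).state_dict()   # keys: 'module.weight', 'module.bias'

# drop the 'module.' prefix so the unwrapped model accepts the keys
if all(k.startswith('module.') for k in weights):
    weights = {k[len('module.'):]: v for k, v in weights.items()}

model.load_state_dict(weights, strict=True)
print(sorted(weights))                          # ['bias', 'weight']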
paths.yaml

...
@@ -3,4 +3,4 @@ HQ_AFF: data/Expert
 OREGON_AFFORDANCES: data/Oregon_Affordance
 AVA_DATA: data
 CACHE: cache
-SYNTH_AFF: /home/timo/datasets/affordances_simulated
\ No newline at end of file
+SYNTH_AFF: data/affordances_simulated
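common.py indexes a global PATHS mapping (e.g. PATHS['AVA_DATA']) and imports yaml, so this file is presumably loaded into that mapping somewhere outside the diff. A hedged sketch of such a loader:

import yaml

# hypothetical loader; the code that actually populates PATHS is not shown in this diff
with open('paths.yaml') as f:
    PATHS = yaml.safe_load(f)

print(PATHS['AVA_DATA'])   # 'data'
print(PATHS['SYNTH_AFF'])  # 'data/affordances_simulated' after this commit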
score.py

...
@@ -7,37 +7,54 @@ from torch.utils.data import DataLoader
 from common import get_dataset_type, initialize_model, load_pretrained_model, collate, CONFIG


-def score(model_name, dataset_name, batch_size, additional_arguments):
-    dataset_type, dataset_args, other_args = get_dataset_type(dataset_name, additional_arguments)
-    d_train = dataset_type('test', None, **dataset_args)
+def score(model_name, dataset_name, batch_size):
+    dataset_type, dataset_args, other_args = get_dataset_type(dataset_name)
+    print('Initialize dataset {} with arguments\n{}'.format(
+        dataset_name, '\n'.join(' {}: {}'.format(k, v) for k, v in dataset_args.items())))
+    d_test = dataset_type('test', None, **dataset_args)

     if os.path.isfile(model_name):
-        model = load_pretrained_model(model_name)
+        import torch
+        model_file = torch.load(model_name)
+        weights = model_file['state_dict']
+        model_name = model_file['model']
+        model_args = model_file['arguments'][-1]
+        model_args['transfer_mode'] = False
+        model_args['pretrained'] = False
+        model = initialize_model(model_name, cli_args={}, model_config=model_args)
+        model.load_state_dict(weights, strict=True)
+        # model = load_pretrained_model(model_name)
     else:
-        model = initialize_model(model_name, cli_args=other_args, model_config=d_train.model_config)
+        model = initialize_model(model_name, cli_args=other_args, model_config=d_test.model_config)

     model.eval()
-    loader = DataLoader(d_train, batch_size=batch_size, num_workers=4, collate_fn=collate)
+    loader = DataLoader(d_test, batch_size=batch_size, num_workers=2, collate_fn=collate)

-    print('Score with batch size {} on {} samples'.format(batch_size, len(d_train)))
+    print('Score with batch size {} on {} samples'.format(batch_size, len(d_test)))
     time_start = time.time()

-    metrics = [m() for m in model.metrics()]
+    metrics = [m() for m in model.metrics()][:2]
+    print(metrics)

     for i, (vars_x, vars_y) in enumerate(loader):
-        vars_x = [v.cuda() for v in vars_x]
-        vars_y = [v.cuda() for v in vars_y]
+        if i % 30 == 29:
+            for j in range(len(metrics)):
+                score = metrics[j].value()
+                for name, s in zip(metrics[j].names(), score):
+                    print(name, s)

-        pred = model(*vars_x)
+        vars_x = [v.cuda() if v is not None else None for v in vars_x]
+        vars_y = [v.cuda() if v is not None else None for v in vars_y]

-        for i in range(len(metrics)):
-            metrics[i].add(pred, vars_y)
+        pred = model(*vars_x)

-        if i % 10 == 9:
-            print('.')
+        for j in range(len(metrics)):
+            metrics[j].add(pred, vars_y)

     for i in range(len(metrics)):
         score = metrics[i].value()
...
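The scoring loop assumes metric objects exposing add(pred, vars_y), value(), and names(). Those classes live in the ava package and are not part of this diff; a minimal illustrative stand-in that satisfies the same interface:

import torch

class PixelAccuracy:
    # toy metric matching the add()/value()/names() protocol used in score.py
    def __init__(self):
        self.correct, self.total = 0, 0

    def add(self, pred, vars_y):
        target = vars_y[0]   # assumes the label tensor comes first in vars_y
        self.correct += (pred.argmax(1) == target).sum().item()
        self.total += target.numel()

    def value(self):
        return [self.correct / max(self.total, 1)]

    def names(self):
        return ['pixel_accuracy']

m = PixelAccuracy()
pred = torch.randn(2, 5, 8, 8)                # logits for 5 classes
target = torch.randint(0, 5, (2, 8, 8))
m.add(pred, [target])
print(dict(zip(m.names(), m.value())))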
@@ -57,5 +74,4 @@ if __name__ == '__main__':
     args, unknown_args = parser.parse_known_args()

-    score(args.model_name, args.dataset_name, batch_size=1, additional_arguments=None)
-    # additional_arguments = parse_additional_arguments(unknown_args)
+    score(args.model_name, args.dataset_name, batch_size=1)
train.py

 import argparse
 import time
 import torch
 from torch.optim import SGD, RMSprop
 from torch.utils.data import DataLoader
...
@@ -10,9 +10,10 @@ import numpy as np
 from common import get_dataset_type, initialize_model, collate, CONFIG


-def train(model_name, dataset_name, batch_size=16, epochs=25, additional_arguments=None):
-    dataset_type, dataset_args, other_args = get_dataset_type(dataset_name, additional_arguments)
+def train(model_name, dataset_name, batch_size=16, epochs=25, decoder_shape='m'):
+    dataset_type, dataset_args, other_args = get_dataset_type(dataset_name)
+    other_args = {'decoder_shape': decoder_shape}

     d_train = dataset_type('train', None, **dataset_args)
     d_val = dataset_type('val', None, **dataset_args)
...
@@ -58,7 +59,7 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, additional_arguments=None):
     val_loss = np.mean(val_losses)
     if val_loss < min_val_loss:
         print('new best val loss {:.4f}! Saving model parameters'.format(val_loss))
-        torch.save(model.state_dict(), 'model-checkpoint.th')
+        torch.save({'model': model_name, 'args': args, 'state_dict': model.state_dict()}, 'model-checkpoint.th')
         min_val_loss = val_loss

     print(time.time() - time_start)
...
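Because torch.save now writes a dict rather than a bare state_dict, anything reading model-checkpoint.th must unpack it first. A sketch of the corresponding load, assuming only the keys saved above (note that score.py's loader reads model_file['arguments'] and model_file['creation_time'], which this checkpoint does not store, so the two formats do not fully line up yet):

import torch

checkpoint = torch.load('model-checkpoint.th', map_location='cpu')
print(checkpoint['model'])      # the model name string
print(checkpoint['args'])       # the argparse namespace stored at save time
# model.load_state_dict(checkpoint['state_dict'], strict=True)  # once the model is rebuilt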
@@ -71,8 +72,7 @@ if __name__ == '__main__':
     parser.add_argument('--n', type=int, default=10, help='Number of episodes')
     parser.add_argument('--batch-size', type=int, default=10, help='Batch size')
-    parser.add_argument('--precache', default=False, action='store_true')
+    parser.add_argument('--decoder-shape', default='m')

     args, unknown_args = parser.parse_known_args()
     train(args.model_name, args.dataset_name, batch_size=args.batch_size)
-    # additional_arguments = parse_additional_arguments(unknown_args)
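For reference, a typical invocation after this commit might look like the following. The positional model_name and dataset_name arguments are defined in a part of the file not shown in this diff, so the exact form is an assumption:

python train.py <model_name> Oregon-OWN --batch-size 10 --decoder-shape m
python score.py model-checkpoint.th Oregon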