Commit 5bfb70ac authored by timo's avatar timo
Browse files

update

parent 198f1e95
import inspect import inspect
import json
import os import os
from collections import Counter from collections import Counter
from logging import warning from logging import warning
...@@ -1094,15 +1095,12 @@ class ADEAffSynth12(Interleaved): ...@@ -1094,15 +1095,12 @@ class ADEAffSynth12(Interleaved):
super().__init__(subset, max_samples, dataset1, dataset2, model_config=model_config) super().__init__(subset, max_samples, dataset1, dataset2, model_config=model_config)
OREGON_AFFORDANCE_PROPERTIES = ['walkable', 'sittable', 'lyable', 'reachable', 'movable'] OREGON_AFFORDANCE_PROPERTIES = ['walkable', 'sittable', 'lyable', 'reachable', 'movable']
class OregonAffordances(_DatasetBySample): class OregonAffordances(_DatasetBySample):
def __init__(self, subset, max_samples, image_size=(352, 352), augmentation=False, cache=False, ade_format=False, def __init__(self, subset, max_samples, augmentation=False, split='nyu'):
split='nyu'):
super().__init__(('image', 'affordance', 'mask'), super().__init__(('image', 'affordance', 'mask'),
visualization_modes=('Image', visualization_modes=('Image',
('Slices', {'maps_dim': 0, 'slice_labels': OREGON_AFFORDANCE_PROPERTIES}), None), ('Slices', {'maps_dim': 0, 'slice_labels': OREGON_AFFORDANCE_PROPERTIES}), None),
...@@ -1112,14 +1110,7 @@ class OregonAffordances(_DatasetBySample): ...@@ -1112,14 +1110,7 @@ class OregonAffordances(_DatasetBySample):
self.target_size = 352 self.target_size = 352
self.intermediate_size_range = (370, 550) self.intermediate_size_range = (370, 550)
if ade_format: self.properties = OREGON_AFFORDANCE_PROPERTIES
# ade_original = ['obstruct', 'break', 'sit', 'grasp', 'pinch_pull', 'hook_pull', 'tip_push',
# 'warmth', 'illumination', 'read/watch', 'support', 'place_on', 'clean_dry', 'roll',
# 'walk']
self.properties = [None, None, 'sittable', 'reachable', None, None, None, None, None, None, None,
'movable', None, None, None, 'walkable']
else:
self.properties = OREGON_AFFORDANCE_PROPERTIES
if split == 'nyu': # ORIGINAL NYUv2 splits if split == 'nyu': # ORIGINAL NYUv2 splits
splits = loadmat(os.path.join(PATHS['AVA_DATA'], 'nyu_splits.mat')) splits = loadmat(os.path.join(PATHS['AVA_DATA'], 'nyu_splits.mat'))
...@@ -1134,7 +1125,8 @@ class OregonAffordances(_DatasetBySample): ...@@ -1134,7 +1125,8 @@ class OregonAffordances(_DatasetBySample):
elif split == 'own': elif split == 'own':
# MORE-TRAINING-DATA SPLIT # MORE-TRAINING-DATA SPLIT
sample_ids = os.listdir(os.path.join(self.root, 'Annotations'))
sample_ids = json.load(open(os.path.join(PATHS['AVA_DATA'], 'sample_id_oregon.json')))
seed(12345) seed(12345)
shuffle(sample_ids) shuffle(sample_ids)
...@@ -1150,9 +1142,6 @@ class OregonAffordances(_DatasetBySample): ...@@ -1150,9 +1142,6 @@ class OregonAffordances(_DatasetBySample):
self.sample_ids = tuple(self.sample_ids) self.sample_ids = tuple(self.sample_ids)
if cache:
self.load_img = MemoryCachedFunction(self.load_img)
self.model_config.update({'with_mask': False, 'out_channels': len(self.properties), 'binary': True}) self.model_config.update({'with_mask': False, 'out_channels': len(self.properties), 'binary': True})
self.parameter_variables = [split, self.target_size] self.parameter_variables = [split, self.target_size]
...@@ -1161,6 +1150,7 @@ class OregonAffordances(_DatasetBySample): ...@@ -1161,6 +1150,7 @@ class OregonAffordances(_DatasetBySample):
return img return img
def __getitem__(self, index): def __getitem__(self, index):
index = self.sample_ids[index] index = self.sample_ids[index]
sample_path = os.path.join(self.root, 'Annotations') sample_path = os.path.join(self.root, 'Annotations')
......
from functools import partial from functools import partial
import time
import torch import torch
from torch import nn from torch import nn
import yaml import yaml
import time
import numpy as np import numpy as np
from torch.optim import SGD from ade import ADEAff12, AffHQ12, OregonAffordances, AffSynth12
from torch.utils.data import DataLoader
from ade import ADEAff12, AffHQ12, OregonAffordances
from ava.core.logging import * from ava.core.logging import *
from ava.models.dense.dense_resnet import ResNet50Dense from ava.models.dense.dense_resnet import ResNet50Dense
from ava.models.dense.pspnet import PSPNet from ava.models.dense.pspnet import PSPNet
...@@ -64,12 +62,16 @@ DATASET_INTERFACE = { ...@@ -64,12 +62,16 @@ DATASET_INTERFACE = {
} }
} }
def get_dataset_type(dataset_name, additional_args):
all_datasets = {'ADEAff12': ADEAff12, 'AffHQ12': AffHQ12, 'Oregon': OregonAffordances} def get_dataset_type(dataset_name):
all_datasets = {'ADEAff12': ADEAff12, 'AffHQ12': AffHQ12, 'Oregon': OregonAffordances,
'Oregon-OWN': OregonAffordances, 'AffSynth12': AffSynth12}
add_args = {'Oregon-OWN': {'split': 'own'}, 'Oregon': {'split': 'nyu'}}
if dataset_name in all_datasets: if dataset_name in all_datasets:
dataset_type = all_datasets[dataset_name] dataset_type = all_datasets[dataset_name]
additional_args = add_args[dataset_name] if dataset_name in add_args else {}
else: else:
raise FileNotFoundError('Dataset not found: {}'.format(dataset_name)) raise FileNotFoundError('Dataset not found: {}'.format(dataset_name))
...@@ -165,6 +167,7 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None, ...@@ -165,6 +167,7 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
log_important('\nload pre-trained model', model_name) log_important('\nload pre-trained model', model_name)
model_file = torch.load(model_name) model_file = torch.load(model_name)
print('model created', model_file['creation_time'])
loaded_model_name = model_file['model'] loaded_model_name = model_file['model']
try: try:
...@@ -198,6 +201,9 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None, ...@@ -198,6 +201,9 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
msg = "These arguments required by {} were not provided: {}".format(model_name, ', '.join(non_match_args)) msg = "These arguments required by {} were not provided: {}".format(model_name, ', '.join(non_match_args))
assert len(non_match_args) == 0, msg assert len(non_match_args) == 0, msg
if 'transfer_mode' in model_args:
model_args['transfer_mode'] = False
# model = model_type(**model_args) # model = model_type(**model_args)
model = model_type(**model_args) model = model_type(**model_args)
model_name = model.name() model_name = model.name()
...@@ -207,15 +213,15 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None, ...@@ -207,15 +213,15 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
if not no_weights: if not no_weights:
log_important('Set pre-loaded weights for ', model_name) log_important('Set pre-loaded weights for ', model_name)
weights = model_file['state_dict'] weights = model_file['state_dict']
log_info('Included submodules', set(['.'.join(k.split('.')[:2]) for k in weights])) #log_info('Included submodules', set(['.'.join(k.split('.')[:2]) for k in weights]))
# for some reason weight keys often start with "module.". This is fixed here: # for some reason weight keys often start with "module.". This is fixed here:
if all([k.startswith('module') for k in weights]): #if all([k.startswith('module') for k in weights]):
weights = {k[7:]: v for k, v in weights.items()} # weights = {k[7:]: v for k, v in weights.items()}
model.load_state_dict(weights, strict=True) model.load_state_dict(weights, strict=True)
model = monkey_patch_model(model, MODEL_INTERFACE, sync_bn, multi_gpu) #model = monkey_patch_model(model, MODEL_INTERFACE, sync_bn, multi_gpu)
# compatibility mode # compatibility mode
if 'dataset_arguments' not in model_file: model_file['dataset_arguments'] = [{}] if 'dataset_arguments' not in model_file: model_file['dataset_arguments'] = [{}]
......
...@@ -3,4 +3,4 @@ HQ_AFF: data/Expert ...@@ -3,4 +3,4 @@ HQ_AFF: data/Expert
OREGON_AFFORDANCES: data/Oregon_Affordance OREGON_AFFORDANCES: data/Oregon_Affordance
AVA_DATA: data AVA_DATA: data
CACHE: cache CACHE: cache
SYNTH_AFF: /home/timo/datasets/affordances_simulated SYNTH_AFF: data/affordances_simulated
\ No newline at end of file
...@@ -7,37 +7,54 @@ from torch.utils.data import DataLoader ...@@ -7,37 +7,54 @@ from torch.utils.data import DataLoader
from common import get_dataset_type, initialize_model, load_pretrained_model, collate, CONFIG from common import get_dataset_type, initialize_model, load_pretrained_model, collate, CONFIG
def score(model_name, dataset_name, batch_size, additional_arguments): def score(model_name, dataset_name, batch_size):
dataset_type, dataset_args, other_args = get_dataset_type(dataset_name, additional_arguments) dataset_type, dataset_args, other_args = get_dataset_type(dataset_name)
d_train = dataset_type('test', None, **dataset_args) print('Initialize dataset {} with arguments\n{}'.format(dataset_name, '\n'.join(' {}: {}'.format(k, v) for k,v in dataset_args.items())))
d_test = dataset_type('test', None, **dataset_args)
if os.path.isfile(model_name): if os.path.isfile(model_name):
model = load_pretrained_model(model_name)
import torch
model_file = torch.load(model_name)
weights = model_file['state_dict']
model_name = model_file['model']
model_args = model_file['arguments'][-1]
model_args['transfer_mode'] = False
model_args['pretrained'] = False
model = initialize_model(model_name, cli_args={}, model_config=model_args)
model.load_state_dict(weights, strict=True)
# model = load_pretrained_model(model_name)
else: else:
model = initialize_model(model_name, cli_args=other_args, model_config=d_train.model_config) model = initialize_model(model_name, cli_args=other_args, model_config=d_test.model_config)
model.eval()
loader = DataLoader(d_train, batch_size=batch_size, num_workers=4, collate_fn=collate) loader = DataLoader(d_test, batch_size=batch_size, num_workers=2, collate_fn=collate)
print('Score with batch size {} on {} samples'.format(batch_size, len(d_train))) print('Score with batch size {} on {} samples'.format(batch_size, len(d_test)))
time_start = time.time() time_start = time.time()
metrics = [m() for m in model.metrics()] metrics = [m() for m in model.metrics()][:2]
print(metrics)
for i, (vars_x, vars_y) in enumerate(loader): for i, (vars_x, vars_y) in enumerate(loader):
vars_x = [v.cuda() for v in vars_x] if i % 30 == 29:
vars_y = [v.cuda() for v in vars_y] for j in range(len(metrics)):
score = metrics[j].value()
for name, s in zip(metrics[j].names(), score):
print(name, s)
pred = model(*vars_x) vars_x = [v.cuda() if v is not None else None for v in vars_x]
vars_y = [v.cuda() if v is not None else None for v in vars_y]
for i in range(len(metrics)): pred = model(*vars_x)
metrics[i].add(pred, vars_y)
if i % 10 == 9: for j in range(len(metrics)):
print('.') metrics[j].add(pred, vars_y)
for i in range(len(metrics)): for i in range(len(metrics)):
score = metrics[i].value() score = metrics[i].value()
...@@ -57,5 +74,4 @@ if __name__ == '__main__': ...@@ -57,5 +74,4 @@ if __name__ == '__main__':
args, unknown_args = parser.parse_known_args() args, unknown_args = parser.parse_known_args()
score(args.model_name, args.dataset_name, batch_size=1, additional_arguments=None) score(args.model_name, args.dataset_name, batch_size=1)
# additional_arguments = parse_additional_arguments(unknown_args)
import argparse import argparse
import time
import time
import torch import torch
from torch.optim import SGD, RMSprop from torch.optim import SGD, RMSprop
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -10,9 +10,10 @@ import numpy as np ...@@ -10,9 +10,10 @@ import numpy as np
from common import get_dataset_type, initialize_model, collate, CONFIG from common import get_dataset_type, initialize_model, collate, CONFIG
def train(model_name, dataset_name, batch_size=16, epochs=25, additional_arguments=None): def train(model_name, dataset_name, batch_size=16, epochs=25, decoder_shape='m'):
dataset_type, dataset_args, other_args = get_dataset_type(dataset_name, additional_arguments) dataset_type, dataset_args, other_args = get_dataset_type(dataset_name)
other_args = {'decoder_shape': decoder_shape}
d_train = dataset_type('train', None, **dataset_args) d_train = dataset_type('train', None, **dataset_args)
d_val = dataset_type('val', None, **dataset_args) d_val = dataset_type('val', None, **dataset_args)
...@@ -58,7 +59,7 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, additional_argumen ...@@ -58,7 +59,7 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, additional_argumen
val_loss = np.mean(val_losses) val_loss = np.mean(val_losses)
if val_loss < min_val_loss: if val_loss < min_val_loss:
print('new best val loss {:.4f}! Saving model parameters'.format(val_loss)) print('new best val loss {:.4f}! Saving model parameters'.format(val_loss))
torch.save(model.state_dict(), 'model-checkpoint.th') torch.save({'model': model_name, 'args': args, 'state_dict': model.state_dict()}, 'model-checkpoint.th')
min_val_loss = val_loss min_val_loss = val_loss
print(time.time() - time_start) print(time.time() - time_start)
...@@ -71,8 +72,7 @@ if __name__ == '__main__': ...@@ -71,8 +72,7 @@ if __name__ == '__main__':
parser.add_argument('--n', type=int, default=10, help='Number of episodes') parser.add_argument('--n', type=int, default=10, help='Number of episodes')
parser.add_argument('--batch-size', type=int, default=10, help='Batch size') parser.add_argument('--batch-size', type=int, default=10, help='Batch size')
parser.add_argument('--precache', default=False, action='store_true') parser.add_argument('--precache', default=False, action='store_true')
parser.add_argument('--decoder-shape', default='m')
args, unknown_args = parser.parse_known_args() args, unknown_args = parser.parse_known_args()
train(args.model_name, args.dataset_name, batch_size=args.batch_size) train(args.model_name, args.dataset_name, batch_size=args.batch_size)
# additional_arguments = parse_additional_arguments(unknown_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment