Commit 5bfb70ac authored by timo's avatar timo
Browse files

update

parent 198f1e95
import inspect
import json
import os
from collections import Counter
from logging import warning
......@@ -1094,15 +1095,12 @@ class ADEAffSynth12(Interleaved):
super().__init__(subset, max_samples, dataset1, dataset2, model_config=model_config)
OREGON_AFFORDANCE_PROPERTIES = ['walkable', 'sittable', 'lyable', 'reachable', 'movable']
class OregonAffordances(_DatasetBySample):
def __init__(self, subset, max_samples, image_size=(352, 352), augmentation=False, cache=False, ade_format=False,
split='nyu'):
def __init__(self, subset, max_samples, augmentation=False, split='nyu'):
super().__init__(('image', 'affordance', 'mask'),
visualization_modes=('Image',
('Slices', {'maps_dim': 0, 'slice_labels': OREGON_AFFORDANCE_PROPERTIES}), None),
......@@ -1112,14 +1110,7 @@ class OregonAffordances(_DatasetBySample):
self.target_size = 352
self.intermediate_size_range = (370, 550)
if ade_format:
# ade_original = ['obstruct', 'break', 'sit', 'grasp', 'pinch_pull', 'hook_pull', 'tip_push',
# 'warmth', 'illumination', 'read/watch', 'support', 'place_on', 'clean_dry', 'roll',
# 'walk']
self.properties = [None, None, 'sittable', 'reachable', None, None, None, None, None, None, None,
'movable', None, None, None, 'walkable']
else:
self.properties = OREGON_AFFORDANCE_PROPERTIES
self.properties = OREGON_AFFORDANCE_PROPERTIES
if split == 'nyu': # ORIGINAL NYUv2 splits
splits = loadmat(os.path.join(PATHS['AVA_DATA'], 'nyu_splits.mat'))
......@@ -1134,7 +1125,8 @@ class OregonAffordances(_DatasetBySample):
elif split == 'own':
# MORE-TRAINING-DATA SPLIT
sample_ids = os.listdir(os.path.join(self.root, 'Annotations'))
sample_ids = json.load(open(os.path.join(PATHS['AVA_DATA'], 'sample_id_oregon.json')))
seed(12345)
shuffle(sample_ids)
......@@ -1150,9 +1142,6 @@ class OregonAffordances(_DatasetBySample):
self.sample_ids = tuple(self.sample_ids)
if cache:
self.load_img = MemoryCachedFunction(self.load_img)
self.model_config.update({'with_mask': False, 'out_channels': len(self.properties), 'binary': True})
self.parameter_variables = [split, self.target_size]
......@@ -1161,6 +1150,7 @@ class OregonAffordances(_DatasetBySample):
return img
def __getitem__(self, index):
index = self.sample_ids[index]
sample_path = os.path.join(self.root, 'Annotations')
......
from functools import partial
import time
import torch
from torch import nn
import yaml
import time
import numpy as np
from torch.optim import SGD
from torch.utils.data import DataLoader
from ade import ADEAff12, AffHQ12, OregonAffordances
from ade import ADEAff12, AffHQ12, OregonAffordances, AffSynth12
from ava.core.logging import *
from ava.models.dense.dense_resnet import ResNet50Dense
from ava.models.dense.pspnet import PSPNet
......@@ -64,12 +62,16 @@ DATASET_INTERFACE = {
}
}
def get_dataset_type(dataset_name, additional_args):
all_datasets = {'ADEAff12': ADEAff12, 'AffHQ12': AffHQ12, 'Oregon': OregonAffordances}
def get_dataset_type(dataset_name):
all_datasets = {'ADEAff12': ADEAff12, 'AffHQ12': AffHQ12, 'Oregon': OregonAffordances,
'Oregon-OWN': OregonAffordances, 'AffSynth12': AffSynth12}
add_args = {'Oregon-OWN': {'split': 'own'}, 'Oregon': {'split': 'nyu'}}
if dataset_name in all_datasets:
dataset_type = all_datasets[dataset_name]
additional_args = add_args[dataset_name] if dataset_name in add_args else {}
else:
raise FileNotFoundError('Dataset not found: {}'.format(dataset_name))
......@@ -165,6 +167,7 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
log_important('\nload pre-trained model', model_name)
model_file = torch.load(model_name)
print('model created', model_file['creation_time'])
loaded_model_name = model_file['model']
try:
......@@ -198,6 +201,9 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
msg = "These arguments required by {} were not provided: {}".format(model_name, ', '.join(non_match_args))
assert len(non_match_args) == 0, msg
if 'transfer_mode' in model_args:
model_args['transfer_mode'] = False
# model = model_type(**model_args)
model = model_type(**model_args)
model_name = model.name()
......@@ -207,15 +213,15 @@ def load_pretrained_model(model_name: str, cli_args=None, model_config=None,
if not no_weights:
log_important('Set pre-loaded weights for ', model_name)
weights = model_file['state_dict']
log_info('Included submodules', set(['.'.join(k.split('.')[:2]) for k in weights]))
#log_info('Included submodules', set(['.'.join(k.split('.')[:2]) for k in weights]))
# for some reason weight keys often start with "module.". This is fixed here:
if all([k.startswith('module') for k in weights]):
weights = {k[7:]: v for k, v in weights.items()}
#if all([k.startswith('module') for k in weights]):
# weights = {k[7:]: v for k, v in weights.items()}
model.load_state_dict(weights, strict=True)
model = monkey_patch_model(model, MODEL_INTERFACE, sync_bn, multi_gpu)
#model = monkey_patch_model(model, MODEL_INTERFACE, sync_bn, multi_gpu)
# compatibility mode
if 'dataset_arguments' not in model_file: model_file['dataset_arguments'] = [{}]
......
......@@ -3,4 +3,4 @@ HQ_AFF: data/Expert
OREGON_AFFORDANCES: data/Oregon_Affordance
AVA_DATA: data
CACHE: cache
SYNTH_AFF: /home/timo/datasets/affordances_simulated
\ No newline at end of file
SYNTH_AFF: data/affordances_simulated
......@@ -7,37 +7,54 @@ from torch.utils.data import DataLoader
from common import get_dataset_type, initialize_model, load_pretrained_model, collate, CONFIG
def score(model_name, dataset_name, batch_size, additional_arguments):
def score(model_name, dataset_name, batch_size):
dataset_type, dataset_args, other_args = get_dataset_type(dataset_name, additional_arguments)
dataset_type, dataset_args, other_args = get_dataset_type(dataset_name)
d_train = dataset_type('test', None, **dataset_args)
print('Initialize dataset {} with arguments\n{}'.format(dataset_name, '\n'.join(' {}: {}'.format(k, v) for k,v in dataset_args.items())))
d_test = dataset_type('test', None, **dataset_args)
if os.path.isfile(model_name):
model = load_pretrained_model(model_name)
import torch
model_file = torch.load(model_name)
weights = model_file['state_dict']
model_name = model_file['model']
model_args = model_file['arguments'][-1]
model_args['transfer_mode'] = False
model_args['pretrained'] = False
model = initialize_model(model_name, cli_args={}, model_config=model_args)
model.load_state_dict(weights, strict=True)
# model = load_pretrained_model(model_name)
else:
model = initialize_model(model_name, cli_args=other_args, model_config=d_train.model_config)
model = initialize_model(model_name, cli_args=other_args, model_config=d_test.model_config)
model.eval()
loader = DataLoader(d_train, batch_size=batch_size, num_workers=4, collate_fn=collate)
loader = DataLoader(d_test, batch_size=batch_size, num_workers=2, collate_fn=collate)
print('Score with batch size {} on {} samples'.format(batch_size, len(d_train)))
print('Score with batch size {} on {} samples'.format(batch_size, len(d_test)))
time_start = time.time()
metrics = [m() for m in model.metrics()]
metrics = [m() for m in model.metrics()][:2]
print(metrics)
for i, (vars_x, vars_y) in enumerate(loader):
vars_x = [v.cuda() for v in vars_x]
vars_y = [v.cuda() for v in vars_y]
if i % 30 == 29:
for j in range(len(metrics)):
score = metrics[j].value()
for name, s in zip(metrics[j].names(), score):
print(name, s)
pred = model(*vars_x)
vars_x = [v.cuda() if v is not None else None for v in vars_x]
vars_y = [v.cuda() if v is not None else None for v in vars_y]
for i in range(len(metrics)):
metrics[i].add(pred, vars_y)
pred = model(*vars_x)
if i % 10 == 9:
print('.')
for j in range(len(metrics)):
metrics[j].add(pred, vars_y)
for i in range(len(metrics)):
score = metrics[i].value()
......@@ -57,5 +74,4 @@ if __name__ == '__main__':
args, unknown_args = parser.parse_known_args()
score(args.model_name, args.dataset_name, batch_size=1, additional_arguments=None)
# additional_arguments = parse_additional_arguments(unknown_args)
score(args.model_name, args.dataset_name, batch_size=1)
import argparse
import time
import time
import torch
from torch.optim import SGD, RMSprop
from torch.utils.data import DataLoader
......@@ -10,9 +10,10 @@ import numpy as np
from common import get_dataset_type, initialize_model, collate, CONFIG
def train(model_name, dataset_name, batch_size=16, epochs=25, additional_arguments=None):
def train(model_name, dataset_name, batch_size=16, epochs=25, decoder_shape='m'):
dataset_type, dataset_args, other_args = get_dataset_type(dataset_name, additional_arguments)
dataset_type, dataset_args, other_args = get_dataset_type(dataset_name)
other_args = {'decoder_shape': decoder_shape}
d_train = dataset_type('train', None, **dataset_args)
d_val = dataset_type('val', None, **dataset_args)
......@@ -58,7 +59,7 @@ def train(model_name, dataset_name, batch_size=16, epochs=25, additional_argumen
val_loss = np.mean(val_losses)
if val_loss < min_val_loss:
print('new best val loss {:.4f}! Saving model parameters'.format(val_loss))
torch.save(model.state_dict(), 'model-checkpoint.th')
torch.save({'model': model_name, 'args': args, 'state_dict': model.state_dict()}, 'model-checkpoint.th')
min_val_loss = val_loss
print(time.time() - time_start)
......@@ -71,8 +72,7 @@ if __name__ == '__main__':
parser.add_argument('--n', type=int, default=10, help='Number of episodes')
parser.add_argument('--batch-size', type=int, default=10, help='Batch size')
parser.add_argument('--precache', default=False, action='store_true')
parser.add_argument('--decoder-shape', default='m')
args, unknown_args = parser.parse_known_args()
train(args.model_name, args.dataset_name, batch_size=args.batch_size)
# additional_arguments = parse_additional_arguments(unknown_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment