Commit f59a93cb authored by timo's avatar timo
Browse files

added batch output option for sample.py

parent 02a33b84
......@@ -646,6 +646,7 @@ class ADE(_DatasetBySample):
img_path = os.path.join(self.root_path, path, 'ADE_' + self.subset_path + '_' + sample_num + '.jpg')
img = cv2.imread(img_path)
assert img is not None, 'image not found: {}'.format(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = tensor_resize(img, (self.intermediate_size_range[1],) * 2, interpret_as_min_bound=True)
return img
......
......@@ -5,29 +5,60 @@ import torch
from common import get_dataset_type, initialize_model, load_pretrained_model, collate, CONFIG
from ava.core.visualize.color_overlay import color_overlay
from ava.core.transformations.resize import tensor_resize
# Affordance labels to visualize. Indices are resolved against the header row
# of affordances12.csv; the first two columns are metadata, hence the [2:] slice.
# Header columns (after the slice):
# obstruct;break;sit;grasp;pull;tip_push;illumination;read/watch;support;place_on;roll;walk
with open('data/ade/affordances12.csv') as _aff_file:  # close the handle instead of leaking it
    affordances12 = _aff_file.read().split('\n')[0].split(';')[2:]
AFFORDANCES = ['place_on', 'grasp', 'pull']
AFF_INDICES = [affordances12.index(aff) for aff in AFFORDANCES]
# Per-affordance overlay intensity passed to color_overlay (one entry per AFFORDANCES item).
INTENSITIES = [5, 5, 5]
# Image heights (max bound) to evaluate each input image at.
SIZES = [350, 400]
def sample_image(model_name, image_file_or_folder, crop=0.0):
    """Run the affordance model on one image (or every image in a folder).

    For each size in SIZES, each input image is resized, optionally cropped
    horizontally, passed through the model, and written — together with a
    colored affordance overlay — into an ``out/`` subfolder next to the input.

    Args:
        model_name: path to a pretrained model checkpoint file.
        image_file_or_folder: a single image file or a folder of images
            (``.jpg``/``.png``/``.JPG``, names starting with ``_`` are skipped).
        crop: fraction of the image width to cut from each side (0.0 = no crop).

    Raises:
        ValueError: if ``model_name`` is not a file, or if
            ``image_file_or_folder`` is neither a file nor a folder.
    """
    if os.path.isfile(model_name):
        model = load_pretrained_model(model_name)
        model = model.eval()
    else:
        raise ValueError('{} is not a valid file'.format(model_name))

    # Collect the list of images to process.
    image_filenames = []
    if os.path.isdir(image_file_or_folder):
        for filename in os.listdir(image_file_or_folder):
            if not filename.startswith('_') and filename[-4:] in {'.jpg', '.png', '.JPG'}:
                image_filenames += [os.path.join(image_file_or_folder, filename)]
    elif os.path.isfile(image_file_or_folder):
        image_filenames += [image_file_or_folder]
    else:
        raise ValueError('Invalid image file or path')

    for size in SIZES:
        for image_filename in image_filenames:
            print('Next image: {}'.format(image_filename))
            img = cv2.imread(image_filename)
            img = tensor_resize(img, (size, 999), interpret_as_max_bound=True)
            # Bug fix: with crop == 0.0 the old slice img[:, 0:-0] produced an
            # EMPTY image (since -0 == 0). Only crop when there is something to cut.
            crop_px = int(crop * size)
            if crop_px > 0:
                img = img[:, crop_px:-crop_px]
            img_t = torch.from_numpy(img.transpose([2, 0, 1]).astype('float32')).cuda()
            img_t = img_t.unsqueeze(0)
            out = model(img_t)[0].detach().cpu().numpy()
            colored = color_overlay(img, out[0].transpose([1, 2, 0]), 0, AFF_INDICES, False,
                                    intensities=INTENSITIES)
            # Bug fix: outputs go into <dirname>/out/, but only <dirname> was
            # created before — cv2.imwrite fails silently on a missing folder.
            out_dir = os.path.join(os.path.dirname(image_filename), 'out')
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)
            basename = os.path.basename(image_filename)
            basename = basename[:basename.index('.')]
            cv2.imwrite(os.path.join(out_dir, '{}_{}_orig.jpg'.format(basename, size)), img)
            cv2.imwrite(os.path.join(out_dir, '{}_{}_out.jpg'.format(basename, size)), colored)
if __name__ == '__main__':
......@@ -36,9 +67,10 @@ if __name__ == '__main__':
parser.add_argument('image', help='dataset name')
parser.add_argument('--n', type=int, default=10, help='Number of episodes')
parser.add_argument('--batch-size', type=int, default=10, help='Batch size')
parser.add_argument('--crop', type=float, default=0.0, help='Percentage of crop on each side (horizontally)')
parser.add_argument('--precache', default=False, action='store_true')
args, unknown_args = parser.parse_known_args()
sample(args.model_name, args.image)
sample_image(args.model_name, args.image, args.crop)
# additional_arguments = parse_additional_arguments(unknown_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment