Commit 810a851d authored by timo's avatar timo
Browse files

added missing file

parent 699c00cd
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as nnf
class _PredictionOutput(nn.Module):
def __init__(self, input_channels, n_classes, with_softmax=False, upsample_mode='bilinear', dropout=None):
super().__init__()
self.upsample_mode = upsample_mode
self.with_softmax = with_softmax
self.n_classes = n_classes
self.input_channels = input_channels
self.dropout = nn.Dropout2d()
n_features = int(0.5 * input_channels + 0.5 * n_classes)
seq = [nn.Conv2d(input_channels, n_features, kernel_size=3, padding=1)]
if dropout is not None:
seq += [nn.Dropout2d(p=dropout)]
seq += [nn.Conv2d(n_features, n_classes, kernel_size=3, padding=1)]
self.convs = nn.Sequential(*seq)
# self.convs = nn.Sequential(
# nn.Conv2d(input_channels, n_features, kernel_size=3, padding=1),
# nn.Conv2d(n_features, n_classes, kernel_size=3, padding=1),
# )
def forward(self, x, target_size):
x_up = nnf.upsample(x, target_size, mode=self.upsample_mode)
pred = self.convs(x_up)
if self.with_softmax:
pred = nnf.softmax(pred, dim=1)
return pred
class _DecoderBlock(nn.Module):
def __init__(self, input_channels, output_channels, dropout=None):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
self.dropout = nn.Dropout2d(p=dropout) if dropout is not None else None
def forward(self, x, skip_x):
x_up = nnf.upsample(x, size=(skip_x.size(2), skip_x.size(3)), mode='bilinear', align_corners=False)
x = torch.cat([x_up, skip_x], dim=1)
x = self.conv(x)
if self.dropout is not None:
x = self.dropout(x)
x = nnf.relu(x)
return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment