Skip to content
Snippets Groups Projects
Unverified Commit 5e4f86a9 authored by Jay Morgan's avatar Jay Morgan
Browse files

Update readme to include predictions

parent dd9b317b
No related branches found
No related tags found
No related merge requests found
......@@ -64,4 +64,12 @@ model = CloudRemover(pretrained=True)
# create a model using a different wavelength
model = CloudRemover(wavelength="H-alpha", pretrained=True)
# test making of predictions
dataset = SyntheticClouds(download=True, transform=CloudsTransform())
model = CloudRemover(pretrained=True)
out = model(dataset[0].input[None,...])*dataset[0].mask[None,...]
import matplotlib.pyplot as plt
plt.imshow(out[0,0].detach().cpu().numpy(), cmap="Greys_r")
```
......@@ -37,7 +37,7 @@ class CloudAddition(CloudIdentity):
def forward(self, cloudy_input, cloud_pred):
cloud_pred = self.activation(cloud_pred)
if self.squash:
cloud_pred = cloud_pred * 0.5 + 0.5
cloud_pred = cloud_pred * 0.5 - 0.5
return cloudy_input - cloud_pred
......@@ -57,7 +57,7 @@ class UNet(nn.Module):
def __init__(
self,
n_blocks: int = 6,
cleaner=CloudAddition(activation=dfp.identity, squash=False),
cleaner=CloudAddition(),
in_channels=1,
out_channels=1,
init_features=16,
......@@ -148,7 +148,7 @@ class UNet(nn.Module):
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return self.conv(dec1)
return self.cleaner(x, self.conv(dec1))
@staticmethod
def _block(in_channels, features, name):
......
import dfp
from cloudremoval.model import CloudRemover
from cloudremoval.dataset import SyntheticClouds, CloudsTransform
# create a model
model = CloudRemover()
......@@ -8,3 +10,11 @@ model = CloudRemover(pretrained=True)
# create a model using a different wavelength
model = CloudRemover(wavelength="H-alpha", pretrained=True)
# test making of predictions
dataset = SyntheticClouds(download=True, transform=CloudsTransform())
model = CloudRemover(pretrained=True)
out = model(dataset[0].input[None,...])*dataset[0].mask[None,...]
import matplotlib.pyplot as plt
plt.imshow(out[0,0].detach().cpu().numpy(), cmap="Greys_r")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment