From 5e4f86a9786fa0c6a781dce976fbaa25843b4a27 Mon Sep 17 00:00:00 2001 From: Jay Morgan <jaymiles17@gmail.com> Date: Thu, 25 May 2023 15:09:10 +0100 Subject: [PATCH] Update readme to include predictions --- README.md | 8 ++++++++ src/cloudremoval/model.py | 6 +++--- tests/test_model.py | 10 ++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b46af04..70f08f0 100644 --- a/README.md +++ b/README.md @@ -64,4 +64,12 @@ model = CloudRemover(pretrained=True) # create a model using a different wavelength model = CloudRemover(wavelength="H-alpha", pretrained=True) + +# test making of predictions +dataset = SyntheticClouds(download=True, transform=CloudsTransform()) +model = CloudRemover(pretrained=True) +out = model(dataset[0].input[None,...])*dataset[0].mask[None,...] + +import matplotlib.pyplot as plt +plt.imshow(out[0,0].detach().cpu().numpy(), cmap="Greys_r") ``` diff --git a/src/cloudremoval/model.py b/src/cloudremoval/model.py index da5d6f9..8e1cc45 100644 --- a/src/cloudremoval/model.py +++ b/src/cloudremoval/model.py @@ -37,7 +37,7 @@ class CloudAddition(CloudIdentity): def forward(self, cloudy_input, cloud_pred): cloud_pred = self.activation(cloud_pred) if self.squash: - cloud_pred = cloud_pred * 0.5 + 0.5 + cloud_pred = cloud_pred * 0.5 - 0.5 return cloudy_input - cloud_pred @@ -57,7 +57,7 @@ class UNet(nn.Module): def __init__( self, n_blocks: int = 6, - cleaner=CloudAddition(activation=dfp.identity, squash=False), + cleaner=CloudAddition(), in_channels=1, out_channels=1, init_features=16, @@ -148,7 +148,7 @@ class UNet(nn.Module): dec1 = self.upconv1(dec2) dec1 = torch.cat((dec1, enc1), dim=1) dec1 = self.decoder1(dec1) - return self.conv(dec1) + return self.cleaner(x, self.conv(dec1)) @staticmethod def _block(in_channels, features, name): diff --git a/tests/test_model.py b/tests/test_model.py index 244d7f9..a51e22d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,6 @@ +import dfp from cloudremoval.model import CloudRemover +from cloudremoval.dataset import SyntheticClouds, CloudsTransform # create a model model = CloudRemover() @@ -8,3 +10,11 @@ model = CloudRemover(pretrained=True) # create a model using a different wavelength model = CloudRemover(wavelength="H-alpha", pretrained=True) + +# test making of predictions +dataset = SyntheticClouds(download=True, transform=CloudsTransform()) +model = CloudRemover(pretrained=True) +out = model(dataset[0].input[None,...])*dataset[0].mask[None,...] + +import matplotlib.pyplot as plt +plt.imshow(out[0,0].detach().cpu().numpy(), cmap="Greys_r") -- GitLab