diff --git a/README.md b/README.md index b46af0496c97d0125e6bbfe0a38a4d9c2b94e01c..70f08f0e560a28a28685757e6cb0533199fb1449 100644 --- a/README.md +++ b/README.md @@ -64,4 +64,12 @@ model = CloudRemover(pretrained=True) # create a model using a different wavelength model = CloudRemover(wavelength="H-alpha", pretrained=True) + +# test making of predictions +dataset = SyntheticClouds(download=True, transform=CloudsTransform()) +model = CloudRemover(pretrained=True) +out = model(dataset[0].input[None,...])*dataset[0].mask[None,...] + +import matplotlib.pyplot as plt +plt.imshow(out[0,0].detach().cpu().numpy(), cmap="Greys_r") ``` diff --git a/src/cloudremoval/model.py b/src/cloudremoval/model.py index da5d6f98c0cba5ae1638a2260f34c8f897a9ade8..8e1cc45e54eb0c8561bcb9066628e8fec334a96b 100644 --- a/src/cloudremoval/model.py +++ b/src/cloudremoval/model.py @@ -37,7 +37,7 @@ class CloudAddition(CloudIdentity): def forward(self, cloudy_input, cloud_pred): cloud_pred = self.activation(cloud_pred) if self.squash: - cloud_pred = cloud_pred * 0.5 + 0.5 + cloud_pred = cloud_pred * 0.5 - 0.5 return cloudy_input - cloud_pred @@ -57,7 +57,7 @@ class UNet(nn.Module): def __init__( self, n_blocks: int = 6, - cleaner=CloudAddition(activation=dfp.identity, squash=False), + cleaner=CloudAddition(), in_channels=1, out_channels=1, init_features=16, @@ -148,7 +148,7 @@ class UNet(nn.Module): dec1 = self.upconv1(dec2) dec1 = torch.cat((dec1, enc1), dim=1) dec1 = self.decoder1(dec1) - return self.conv(dec1) + return self.cleaner(x, self.conv(dec1)) @staticmethod def _block(in_channels, features, name): diff --git a/tests/test_model.py b/tests/test_model.py index 244d7f9243389a8f23fd3f60f854b0369418a5a3..a51e22daf9507862f96edfdf5479be25789563a6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,6 @@ +import dfp from cloudremoval.model import CloudRemover +from cloudremoval.dataset import SyntheticClouds, CloudsTransform # create a model model = CloudRemover() @@ -8,3 +10,11 @@ model = CloudRemover(pretrained=True) # create a model using a different wavelength model = CloudRemover(wavelength="H-alpha", pretrained=True) + +# test making of predictions +dataset = SyntheticClouds(download=True, transform=CloudsTransform()) +model = CloudRemover(pretrained=True) +out = model(dataset[0].input[None,...])*dataset[0].mask[None,...] + +import matplotlib.pyplot as plt +plt.imshow(out[0,0].detach().cpu().numpy(), cmap="Greys_r")