From 14db9425ad4f7841b2efd6c676df8ef4e058f77d Mon Sep 17 00:00:00 2001 From: ceramisch <carlos.ramisch@lis-lab.fr> Date: Mon, 18 Nov 2024 00:37:02 +0100 Subject: [PATCH] Update CM3 --- cm-code/examples-torch-cm3.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/cm-code/examples-torch-cm3.py b/cm-code/examples-torch-cm3.py index 0e5be39..bd4b30d 100755 --- a/cm-code/examples-torch-cm3.py +++ b/cm-code/examples-torch-cm3.py @@ -46,14 +46,34 @@ loss.backward() model.print_gradients() # ['a.grad=4.0', 'b.grad=3.0'] +################################################################################ +# torch.gather examples +import torch + +a = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) +mask = torch.tensor([[1],[1]]) +print(torch.gather(a, 1, mask)) +# tensor([[2.], +# [5.]]) +print(torch.gather(a, 0, mask)) +# tensor([[4.], +# [4.]]) +mask = torch.tensor([[1,0,1],[0,1,0]]) +print(torch.gather(a, 0, mask)) +#tensor([[4., 5., 3.], +# [1., 5., 6.]]) +mask = torch.tensor([[2,1],[0,2]]) +print(torch.gather(a, 1, mask)) +#tensor([[3., 2.], +# [4., 6.]]) + ################################################################################ # Conv1D examples -import torch -mat = torch.rand(3,5) -conv = nn.Conv1d(in_channels=3, out_channels=2, kernel_size=3) -print(conv(mat).shape) +#mat = torch.rand(3,5) +#conv = nn.Conv1d(in_channels=3, out_channels=2, kernel_size=3) +#print(conv(mat).shape) #conv_res = self.char_conv[str(k_s)](char_embs[:,word_i].transpose(1,2)) #pool_res = nn.functional.max_pool1d(conv_res, conv_res.shape[-1]) -- GitLab