From 14db9425ad4f7841b2efd6c676df8ef4e058f77d Mon Sep 17 00:00:00 2001
From: ceramisch <carlos.ramisch@lis-lab.fr>
Date: Mon, 18 Nov 2024 00:37:02 +0100
Subject: [PATCH] Update CM3

---
 cm-code/examples-torch-cm3.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/cm-code/examples-torch-cm3.py b/cm-code/examples-torch-cm3.py
index 0e5be39..bd4b30d 100755
--- a/cm-code/examples-torch-cm3.py
+++ b/cm-code/examples-torch-cm3.py
@@ -46,14 +46,34 @@ loss.backward()
 model.print_gradients()
 # ['a.grad=4.0', 'b.grad=3.0']
 
+################################################################################
+# torch.gather examples
+import torch
+
+a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
+mask = torch.tensor([[1],[1]])
+print(torch.gather(a, 1, mask))
+# tensor([[2.], 
+#         [5.]])
+print(torch.gather(a, 0, mask))
+# tensor([[4.], 
+#         [4.]])
+mask = torch.tensor([[1,0,1],[0,1,0]])
+print(torch.gather(a, 0, mask))
+#tensor([[4., 5., 3.],
+#        [1., 5., 6.]])
+mask = torch.tensor([[2,1],[0,2]])
+print(torch.gather(a, 1, mask))
+#tensor([[3., 2.],
+#        [4., 6.]])
+
 ################################################################################
 # Conv1D examples
 
-import torch
 
-mat = torch.rand(3,5)
-conv = nn.Conv1d(in_channels=3, out_channels=2, kernel_size=3)
-print(conv(mat).shape)
+#mat = torch.rand(3,5)
+#conv = nn.Conv1d(in_channels=3, out_channels=2, kernel_size=3)
+#print(conv(mat).shape)
 
 #conv_res = self.char_conv[str(k_s)](char_embs[:,word_i].transpose(1,2))
 #pool_res = nn.functional.max_pool1d(conv_res, conv_res.shape[-1])
-- 
GitLab