Gitlab Community Edition Instance

Skip to content
Snippets Groups Projects
Commit b2321744 authored by Dorothea Sommer's avatar Dorothea Sommer
Browse files

replace bmm with matmul due to driver issues

parent 1a6e0271
Branches
No related tags found
No related merge requests found
......@@ -73,13 +73,13 @@ class PointNet(nn.Module):
def forward(self, input):
matrix3x3 = self.input_transform(input)
# batch matrix multiplication
xb = torch.bmm(torch.transpose(input, 1, 2),
xb = torch.matmul(torch.transpose(input, 1, 2),
matrix3x3).transpose(1, 2)
xb = F.relu(self.bn1(self.conv1(xb)))
matrix64x64 = self.feature_transform(xb)
xb = torch.bmm(torch.transpose(xb, 1, 2),
xb = torch.matmul(torch.transpose(xb, 1, 2),
matrix64x64).transpose(1, 2)
xb = F.relu(self.bn2(self.conv2(xb)))
......@@ -104,8 +104,8 @@ def pointnetloss(outputs, labels, m3x3, m64x64, alpha=0.001, device=None):
# Calculate difference to identity matrix for regularization.
id3x3 = torch.eye(3, requires_grad=True, device=device).repeat(bs, 1, 1)
id64x64 = torch.eye(64, requires_grad=True, device=device).repeat(bs, 1, 1)
diff3x3 = id3x3 - torch.bmm(m3x3, m3x3.transpose(1, 2))
diff64x64 = id64x64 - torch.bmm(m64x64, m64x64.transpose(1, 2))
diff3x3 = id3x3 - torch.matmul(m3x3, m3x3.transpose(1, 2))
diff64x64 = id64x64 - torch.matmul(m64x64, m64x64.transpose(1, 2))
# Negative log likelihood criterion is already adapted to batch size.
return criterion(outputs, labels) + alpha * (torch.norm(diff3x3) + torch.norm(diff64x64)) / float(bs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment