
CUDA OOM leads to unhandled thrust::system exception #357

@evelkey


Describe the bug

MinkowskiEngine (ME) raises a C++ thrust::system::system_error exception that cannot be caught from Python, so it crashes the program. The error occurs non-deterministically during training (especially in long-running trainings, after a few days). Because it cannot be caught, the whole training pipeline fails.

Since parallel_for is not called directly anywhere in the repo, one of the functions used by MinkowskiConvolution most likely calls a thrust built-in function that uses it internally. That call should be wrapped with a THRUST_CHECK macro, analogous to CUDA_CHECK, so that the error is turned into an exception that Python can interpret.
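The actual THRUST_CHECK would be a C++ macro in ME's CUDA sources, and its exact form is not specified here. Purely as an illustration of the wrap-and-translate idea, the Python sketch below wraps a call and re-raises a hypothetical ThrustSystemError (a stand-in for the C++ exception, not a real ME or thrust class) as MemoryError when the message reports an out-of-memory condition, and as RuntimeError otherwise.

```python
import functools
from typing import Callable, TypeVar

T = TypeVar("T")


class ThrustSystemError(Exception):
    """Hypothetical stand-in for a C++ thrust::system::system_error."""


def thrust_check(fn: Callable[..., T]) -> Callable[..., T]:
    """Wrap fn so a ThrustSystemError surfaces as a standard Python exception."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except ThrustSystemError as err:
            # CUDA's error string for cudaErrorMemoryAllocation is "out of memory",
            # so map that case to MemoryError and everything else to RuntimeError.
            if "out of memory" in str(err):
                raise MemoryError(str(err)) from err
            raise RuntimeError(str(err)) from err

    return wrapper
```

The original exception is kept as `__cause__`, so the low-level message is still visible in the Python traceback.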


To Reproduce

The problem is GPU-dependent: the code below deterministically produces the error on a 16 GB Tesla V100 GPU. To reproduce it on other GPUs, you need to find a suitable point_count in the code below; the right value mostly depends on the VRAM size.

import MinkowskiEngine as ME
import torch
import torch.nn as nn
from MinkowskiEngine import SparseTensor


class TestNet(ME.MinkowskiNetwork):
    def __init__(self, in_feat, out_feat, D, layers=80):
        super(TestNet, self).__init__(D)
        # A deep stack of strided sparse convolutions to drive GPU memory use up.
        # nn.ModuleList registers every layer, so .cuda() and .parameters()
        # see them without a custom cuda() override.
        self.convs = nn.ModuleList()
        prev = in_feat
        for _ in range(layers):
            self.convs.append(
                nn.Sequential(
                    ME.MinkowskiConvolution(
                        in_channels=prev,
                        out_channels=out_feat,
                        kernel_size=3,
                        stride=2,
                        dilation=1,
                        bias=True,
                        dimension=D,
                    ),
                    ME.MinkowskiReLU(),
                )
            )
            prev = out_feat

    def forward(self, x):
        for convlayer in self.convs:
            x = convlayer(x)
        return x


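# Tune point_count to the GPU's VRAM: 6000000 triggers the error on a 16 GB V100.
point_count = 6000000
in_channels, out_channels, D = 2, 3, 3
# Random integer coordinates; column 0 is the batch index, set to 0 below.
coords, feats = (
    torch.randint(low=-1000, high=1000, size=(point_count, D + 1)).int().cuda(),
    torch.rand(size=(point_count, in_channels)).cuda(),
)
coords[:, 0] = 0  # put every point in batch 0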

testnetwork = TestNet(in_channels, 32, 3).cuda()


for i in range(5):
    print(f"starting {i}")
    xt = SparseTensor(feats, coordinates=coords, device="cuda")
    torch.cuda.synchronize()
    print("run forward")
    res = testnetwork(xt)
    loss = res.F.sum()
    torch.cuda.synchronize()
    print("run backward")
    loss.backward()
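Because the exception aborts the whole interpreter, a search for a suitable point_count on another GPU has to observe each attempt from outside the process. A minimal sketch, assuming a hypothetical copy of the script above saved as repro.py that reads its point count from the command line (for example `point_count = int(sys.argv[1])`): start from a guess scaled linearly from the 16 GB / 6,000,000 data point (linear scaling is an assumption, not a measured property of ME), then bisect, running each attempt in a child process and treating any non-zero exit status as a crash.

```python
import subprocess
import sys
from typing import Callable

REFERENCE_VRAM_BYTES = 16 * 1024**3  # the 16 GB V100 from the report
REFERENCE_POINT_COUNT = 6_000_000    # point_count that crashes on that GPU


def scaled_guess(total_vram_bytes: int) -> int:
    """Starting point_count, assuming memory use grows roughly linearly with points."""
    return int(REFERENCE_POINT_COUNT * total_vram_bytes / REFERENCE_VRAM_BYTES)


def crashes(point_count: int, script: str = "repro.py") -> bool:
    """Run the (hypothetical) repro script in a child process.

    The thrust exception terminates the interpreter, so it has to be observed
    from the parent: any non-zero exit status (e.g. SIGABRT) counts as a crash.
    """
    result = subprocess.run([sys.executable, script, str(point_count)])
    return result.returncode != 0


def smallest_crashing_count(
    crashes_at: Callable[[int], bool], low: int, high: int
) -> int:
    """Binary search for the smallest count in (low, high] that crashes.

    Assumes crashes_at(low) is False, crashes_at(high) is True, and that the
    outcome is monotonic in point_count.
    """
    while high - low > 1:
        mid = (low + high) // 2
        if crashes_at(mid):
            high = mid
        else:
            low = mid
    return high
```

With torch imported, `scaled_guess(torch.cuda.get_device_properties(0).total_memory)` gives a starting value; `smallest_crashing_count(crashes, low, high)` then needs a `low` that runs cleanly and a `high` that crashes.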

Expected behavior

A thrust::system::system_error exception should be converted to a Python RuntimeError or MemoryError so that it can be caught with a try ... except block in Python.
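Once the error surfaces as RuntimeError or MemoryError, a long-running job could skip the failing step instead of dying. A minimal sketch with a hypothetical guarded_step helper; the string check assumes the converted message contains "out of memory", as PyTorch's own CUDA OOM errors do. With the current behavior this does not help, because the thrust exception terminates the process before any except clause can run.

```python
from typing import Callable, Optional


def guarded_step(
    step: Callable[[], float], cleanup: Callable[[], None] = lambda: None
) -> Optional[float]:
    """Run one training step and return its loss, or None if it ran out of memory.

    Only out-of-memory failures are swallowed; any other error propagates.
    """
    try:
        return step()
    except MemoryError:
        pass
    except RuntimeError as err:
        # PyTorch reports CUDA OOM as a RuntimeError whose message contains
        # "out of memory"; any other RuntimeError is treated as a real bug.
        if "out of memory" not in str(err):
            raise
    cleanup()  # e.g. torch.cuda.empty_cache to release cached blocks
    return None
```

In the reproduction loop this could look like `guarded_step(lambda: train_once(xt), cleanup=torch.cuda.empty_cache)`, where `train_once` is a hypothetical function that runs the forward pass, loss and backward pass and returns `loss.item()`.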


Environment (server running inside NVIDIA Docker):

==========System==========
Linux-5.4.0-1047-aws-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.5 LTS"
3.8.5 (default, Sep 4 2020, 07:30:14)
[GCC 7.3.0]
==========Pytorch==========
1.7.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 460.73.01
CUDA Version 11.2
VBIOS Version 88.00.4F.00.09
Image Version G503.0201.00.03
==========NVCC==========
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
==========CC==========
/usr/bin/c++
c++ (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0

==========MinkowskiEngine==========
0.5.4 (master of 05/26/2021)
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 10020
CUDART version MinkowskiEngine is compiled: 10020
