Submitted by Scared_Employer6992 t3_11dd59q in MachineLearning
I have to train a UNet-like architecture for semantic segmentation with 200 outcome classes. The final output map has shape 4x200x500x500 (batch size 4, 200 channels, one per semantic class), and it blows up my GPU memory (40 GB).
My first thought is simply to merge classes into broader categories to reduce their number. Does anyone have suggestions or tricks to accomplish this semantic segmentation task in a more memory-efficient way?
badabummbadabing t1_ja7yb9y wrote
The problem is likely the number of output channels at high resolution. Instead of computing the final layer's activations and gradients for all channels in parallel, you should be able to compute each channel's loss sequentially and add their gradients at the end. This works because the loss decomposes as a sum over the channels (and thus, so do the channels' gradients).
In PyTorch, this whole thing should be as simple as running the forward and backward passes for the channels of the final layer sequentially (before calling optimizer.step() and optimizer.zero_grad() once). You will probably also need to pass retain_graph=True on every backward call; otherwise the activations in the preceding layers will be freed before you get to the next channel.
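A minimal sketch of what this could look like, assuming a loss that actually decomposes per channel (e.g. per-class binary cross-entropy; a softmax cross-entropy over all 200 classes couples the channels and would not split this cleanly). The module names, shapes, and loss are placeholders, not from the original post:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the real network: `backbone` is everything up to the final
# layer, `head` is the last conv mapping features to one logit map per class.
# Shapes are scaled down so the sketch runs anywhere; swap in your own modules.
num_classes = 200
backbone = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU())
head = nn.Conv2d(64, num_classes, kernel_size=1)
optimizer = torch.optim.Adam(list(backbone.parameters()) + list(head.parameters()))

images = torch.randn(4, 3, 64, 64)                                # dummy batch
targets = torch.randint(0, 2, (4, num_classes, 64, 64)).float()   # per-class masks

optimizer.zero_grad()
features = backbone(images)  # shared activations, computed once

for c in range(num_classes):
    # Compute only channel c's logits by slicing the head's weights, so the
    # full (B, 200, H, W) output map is never materialized all at once.
    logits_c = F.conv2d(features, head.weight[c:c + 1], head.bias[c:c + 1])
    loss_c = F.binary_cross_entropy_with_logits(logits_c, targets[:, c:c + 1])
    # retain_graph=True keeps the backbone activations alive for the next
    # channel; gradients accumulate in .grad across the per-channel backwards.
    loss_c.backward(retain_graph=True)

optimizer.step()
```

The trade-off is 200 sequential backward passes through the final layer, but the backbone activations are reused across them and the full 4x200x500x500 tensor never has to live in memory at once.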