[R] How can I improve my material segmentations?
I am trying to perform material segmentation (essentially semantic segmentation with respect to materials) on street-view imagery. My dataset only has ground truth for select regions, so not all pixels have a label, and I calculate loss and metrics only within these ground truth regions. I use Semantic FPN (with a ResNet-50 backbone pre-trained on ImageNet), a learning rate of 0.001, and a momentum of 0.8; the learning rate is divided by 4 if the validation loss has not improved after three epochs. My loss function is a per-pixel multiclass cross-entropy loss.
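To make "loss only within the ground truth regions" concrete, here is a minimal NumPy sketch of a masked per-pixel cross-entropy. It is not my actual training code: the function name, the `IGNORE` marker value of -1, and the `(C, H, W)` array layout are assumptions made for illustration.

```python
import numpy as np

IGNORE = -1  # assumed marker for pixels without ground truth


def masked_cross_entropy(logits, labels, ignore_index=IGNORE):
    """Mean cross-entropy over labeled pixels only.

    logits: float array of shape (C, H, W), raw class scores
    labels: int array of shape (H, W); ignore_index marks unlabeled pixels
    """
    mask = labels != ignore_index
    if not mask.any():
        return 0.0  # nothing labeled in this image
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    # Pick the log-probability of the true class at each labeled pixel.
    rows, cols = np.nonzero(mask)
    picked = log_probs[labels[rows, cols], rows, cols]
    return float(-picked.mean())
```

In PyTorch, the `ignore_index` argument of `torch.nn.CrossEntropyLoss` gives the same effect: unlabeled pixels contribute neither to the loss nor to the gradient.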
My dataset is extremely limited. Not only are not all pixels classified, I also only have 700 images and a severe class imbalance. I tried tackling this imbalance through loss class weighting (based on the number of ground truth pixels for each respective class, i.e. their area sizes), but it barely helps. I also possess, for every image, a depth map, which I (can) supply as a fourth channel to the input layer.
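For reference, this is roughly what I mean by class weighting based on area. The sketch below counts labeled pixels per class and turns the counts into inverse-frequency weights. The exact formula, the function name, and the -1 marker for unlabeled pixels are assumptions for illustration; other schemes (e.g. damping the weights with a square root) are possible.

```python
import numpy as np


def class_weights_from_labels(label_maps, num_classes, ignore_index=-1):
    """Inverse-frequency class weights computed from labeled pixels only.

    label_maps: iterable of int arrays of shape (H, W) with values in
                [0, num_classes) or ignore_index for unlabeled pixels
    Returns (counts, weights), both of length num_classes.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for labels in label_maps:
        valid = labels[labels != ignore_index]  # drop unlabeled pixels
        counts += np.bincount(valid, minlength=num_classes)
    weights = np.zeros(num_classes, dtype=np.float64)
    seen = counts > 0  # classes never observed keep weight 0
    # weight_c = total / (number of seen classes * count_c), so a class
    # with an average share of the labeled pixels gets weight 1.
    weights[seen] = counts.sum() / (seen.sum() * counts[seen])
    return counts, weights
```

With a severe imbalance, the rarest classes end up with very large weights under this formula, which can make training noisy rather than better.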
Visualizations of images trained only on RGB
Visualizations of images trained on RGBD
Visualizations of images trained only on RGB, but with class loss weighting
Visualizations of images trained on RGBD, and with class loss weighting
Performance is pretty crappy. What’s more, there is very little difference between the results of my four experiments. Why is this? I would expect the addition of depth information (which encodes surface normals and perhaps texture information; pretty discriminative information) to make a noticeable difference. Besides the overall metrics being rather low, the predictions are very messy, and the networks rarely, if ever, predict “small” classes (in terms of area size), e.g. plastic or gravel. This is to be expected with such a small amount of data, but I was wondering if there are any “performance hacks” that can boost my network, or if I am missing anything obvious? Or is data likely the only bottleneck here? Any suggestions are greatly appreciated!
PS. I also tried a simple ResNet-50 FCN (I simply upsample ResNet’s output until it matches the input resolution; there aren’t even skip connections), and the results are worse, but at least they are smooth. Why are these smoother?
submitted by /u/EmielBoss