I have a paired dataset of binary images, A and B (A1 is paired with B1, A2 with B2, and so on), containing simple shapes such as rectangles and squares.

An external piece of software receives both images, A and B, and returns a number that represents the error.

I need a model that, given images A and B, can modify A into A' by adding or removing squares, so that the error from the software is minimized. I don't have access to the source code of the software so I don't know how it works.

I tried training a neural network that mimics the software (a surrogate for its error), together with a generative network that produces the modified image A', but I haven't gotten good results.

The software only accepts binary images, so I cannot use its output directly as a loss function. The last layer of my generator is a softmax; if I threshold its output to obtain a binary image, I lose track of the gradients, so I cannot apply gradient descent.

Someone told me that when you cannot calculate the gradient of the loss with respect to the weights, reinforcement learning with policy gradients is a good solution.
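To make the policy-gradient idea concrete, here is a minimal REINFORCE sketch in plain Python. Everything in it is an assumption for illustration: `external_error` is a hypothetical stand-in for the black-box software (here it just counts mismatched pixels), the images are tiny flattened lists, and the policy is one independent on/off probability per pixel rather than a generator that adds or removes whole squares. The point is only that the update uses the returned error number and never needs a gradient through the software.

```python
import math
import random


def external_error(a_prime, b):
    """Hypothetical stand-in for the black-box software: counts mismatched pixels."""
    return sum(int(p != q) for p, q in zip(a_prime, b))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_policy(a, b, steps=3000, lr=0.3, seed=0):
    """REINFORCE over an independent Bernoulli policy, one logit per pixel."""
    rng = random.Random(seed)
    # Initialise the logits so the policy starts close to image A.
    logits = [2.0 if pixel else -2.0 for pixel in a]
    # Baseline (running mean of the reward) reduces the variance of the update.
    baseline = -float(external_error(a, b))
    for _ in range(steps):
        probs = [sigmoid(z) for z in logits]
        # Sample a strictly binary image; only this is sent to the software.
        sample = [1 if rng.random() < p else 0 for p in probs]
        reward = -external_error(sample, b)
        advantage = reward - baseline
        baseline += 0.1 * (reward - baseline)
        # For a Bernoulli with logit z, d log p(x) / dz = x - sigmoid(z).
        for i in range(len(logits)):
            logits[i] += lr * advantage * (sample[i] - probs[i])
    # Return the most likely binary image under the learned policy.
    return [1 if z > 0 else 0 for z in logits]


a = [1, 1, 0, 0, 0, 0, 1, 0]
b = [1, 1, 1, 1, 0, 0, 0, 0]
a_prime = train_policy(a, b)
error_before = external_error(a, b)
error_after = external_error(a_prime, b)
print(error_before, error_after)
```

With a real generator the per-pixel logits would come from the network given A and B, and the same `advantage * grad log p` term would be backpropagated into its weights. Expect this to be much noisier and slower than ordinary gradient descent.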

I'm new to this field, so I want to be sure I'm going in the right direction.

This really doesn't sound like an RL problem to me. If the only issue you have is backpropagating through your thresholded softmax layer, I recommend you look into Gumbel Softmax. Maybe that can solve your problem? – harwiltz – 2020-06-23T14:07:01.190

Thank you for your reply, harwiltz! I checked the Gumbel Softmax paper you mentioned; it is very interesting. I took a look at the TensorFlow implementation, tfp.RelaxedOneHotCategorical. The problem is that when you sample from the Gumbel distribution you get only an approximation of the one-hot encoded vector, and my external software cannot accept an approximation. I'm forced either to cast to int or to modify the variables in a way that prevents backpropagation. – Marco Garcia Macias – 2020-06-26T07:27:11.783

Ah, that's unfortunate. In PyTorch the implementation gives you the option to get the one-hot representation. However, I think you can actually get away with casting; I believe this is called the "straight-through estimator" or something similar. – harwiltz – 2020-06-26T14:12:45.343
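For readers unfamiliar with the straight-through trick, here is a minimal sketch in plain Python with the backward pass written out by hand. It rests on assumptions: the forward pass thresholds to a strictly binary image, the backward pass pretends the threshold was the identity, and the gradient comes from a hypothetical differentiable surrogate of the error (`surrogate_error_and_grad`, a stand-in for something like the network the question describes), because the real software still provides no gradient.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def surrogate_error_and_grad(image, target):
    """Hypothetical differentiable surrogate: squared error and its gradient w.r.t. the image."""
    error = sum((x - t) ** 2 for x, t in zip(image, target))
    grad = [2.0 * (x - t) for x, t in zip(image, target)]
    return error, grad


def straight_through_step(logits, target, lr=0.5):
    soft = [sigmoid(z) for z in logits]
    # Forward pass: hard threshold, so downstream code sees only 0s and 1s.
    hard = [1.0 if s > 0.5 else 0.0 for s in soft]
    error, grad_hard = surrogate_error_and_grad(hard, target)
    # Backward pass: treat d(hard)/d(soft) as 1 (the straight-through assumption),
    # then chain through the sigmoid, whose derivative is s * (1 - s).
    new_logits = [z - lr * g * s * (1.0 - s)
                  for z, g, s in zip(logits, grad_hard, soft)]
    return new_logits, hard, error


logits = [2.0, 2.0, -2.0, -2.0, -2.0, -2.0, 2.0, -2.0]
target = [1, 1, 1, 1, 0, 0, 0, 0]
for _ in range(200):
    logits, hard, error = straight_through_step(logits, target)
print(hard, error)
```

In an autograd framework the same effect is usually obtained by letting the hard value flow forward while gradients flow through the soft value; in PyTorch, `torch.nn.functional.gumbel_softmax(logits, hard=True)` does this for the Gumbel-Softmax case.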