I am implementing a special loss function in TensorFlow. Below is the NumPy-style code of a function that keeps the top `q` elements in each row and each column and sets the other elements to zero. Note that `A` is an `n*n` matrix and `q` is an integer less than `n`.
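```python
def thresh(A, q):
    A_ = A.copy()
    n = A_.shape[1]
    for i in range(n):
        # Zero the n - q smallest entries of row i (keep its q largest).
        A_[i, :][A_[i, :].argsort()[0:n - q]] = 0
        # Then do the same for column i, on the already-updated matrix.
        A_[:, i][A_[:, i].argsort()[0:n - q]] = 0
    return A_
```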
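For reference, the same computation can be expressed without writing through indices at all: each step builds a boolean mask of the entries to zero and produces a new array with `np.where`. The sketch below assumes that rewrite (the name `thresh_masked` is hypothetical, not part of the original code). It uses the same `argsort` on the same values, so it should return the same result as `thresh`. The small example also shows that, because rows and columns are processed in an interleaved loop, a row can end up with fewer than `q` nonzeros (row 2 here keeps only one).

```python
import numpy as np

def thresh_masked(A, q):
    """Same logic as thresh(A, q), but only builds masks and new arrays."""
    A_ = np.asarray(A)
    n = A_.shape[1]
    eye = np.eye(n, dtype=bool)  # eye[i] is True only at position i
    idx = np.arange(n)
    for i in range(n):
        # Row i: True at the n - q smallest entries (the ones to zero).
        drop = np.isin(idx, np.argsort(A_[i, :])[: n - q])
        A_ = np.where(eye[i][:, None] & drop[None, :], 0, A_)
        # Column i: same rule, applied to the updated array.
        drop = np.isin(idx, np.argsort(A_[:, i])[: n - q])
        A_ = np.where(drop[:, None] & eye[i][None, :], 0, A_)
    return A_

A = np.array([[9., 8., 1.],
              [7., 6., 5.],
              [4., 3., 2.]])
print(thresh_masked(A, 2))
# [[9. 8. 0.]
#  [7. 6. 0.]
#  [0. 0. 2.]]
```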
Now the problem is that I have a TensorFlow tensor `A` whose shape is `(n, n)`, and I would like to implement the same logic as in NumPy. However, I cannot use indices to assign values to the tensor `A` directly. Does anyone have a solution for this?