Channel: How to pick top q elements in each row and each column in a Tensorflow tensor? - Stack Overflow
How to pick top q elements in each row and each column in a Tensorflow tensor?

I am implementing a special loss function in TensorFlow. Here is NumPy code for a helper function that keeps the top q elements and zeroes out (masks) the other elements in each row and each column. Note that A is an n×n matrix and q is an integer less than n.

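def thresh(A, q):
    A_ = A.copy()  # work on a copy so A itself is not modified
    n = A_.shape[1]
    for i in range(n):
        # zero the n - q smallest entries of row i, keeping its top q
        A_[i, :][A_[i, :].argsort()[0:n - q]] = 0
        # then do the same for column i; it already contains the zeros
        # written while processing rows 0..i, so the order matters
        A_[:, i][A_[:, i].argsort()[0:n - q]] = 0
    return A_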
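A TensorFlow port has to produce the same result without writing into the array through indices. As a minimal sketch of that idea (the name `thresh_no_assign` is hypothetical), the loop below keeps the same row/column order but rebuilds the whole matrix with `np.where` at every step. It finds the entries to drop with a double `argsort` (the argsort of a permutation is its inverse, so "rank < n - q" selects exactly the indices in `argsort()[0:n - q]`), so it should match `thresh` entry for entry. The assumption is that each operation used here has a TensorFlow counterpart (`tf.argsort`, `tf.range`, `tf.where`); that translation is not shown or tested.

```python
import numpy as np


def thresh_no_assign(A, q):
    """Hypothetical assignment-free rewrite of thresh(): same row/column
    order, but every step returns a new array built with np.where."""
    A_ = np.asarray(A)
    n = A_.shape[1]
    rows = np.arange(n)[:, None]  # row index of every entry, shape (n, 1)
    cols = np.arange(n)[None, :]  # column index of every entry, shape (1, n)
    for i in range(n):
        # Rank of each entry within row i (0 = smallest). argsort of the
        # argsort permutation is its inverse, so rank < n - q picks exactly
        # the entries that argsort()[0:n - q] would pick.
        row_rank = A_[i, :].argsort().argsort()
        drop = (rows == i) & (row_rank[None, :] < n - q)
        A_ = np.where(drop, 0, A_)
        # Same for column i, using the matrix as updated so far.
        col_rank = A_[:, i].argsort().argsort()
        drop = (cols == i) & (col_rank[:, None] < n - q)
        A_ = np.where(drop, 0, A_)
    return A_


A = np.array([[1.0, 5.0, 3.0],
              [4.0, 2.0, 6.0],
              [7.0, 8.0, 9.0]])
print(thresh_no_assign(A, 2))
# [[0. 5. 0.]
#  [4. 0. 6.]
#  [0. 8. 9.]]
```

Because each step only compares, broadcasts and selects, nothing here relies on mutating an existing array, which is the restriction a `tf.Tensor` imposes.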

Now the problem is that I have a TensorFlow tensor A of shape (n, n), and I would like to implement the same logic as the NumPy version. However, TensorFlow tensors do not support item assignment, so I cannot assign values to A through indices directly. Does anyone have a solution for this?
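If the interleaved row/column order of `thresh` is not essential to the loss, a simpler, fully vectorized rule is possible. Note that it is a different rule, not the same function: keep an entry only when it is at least the q-th largest value of its row and of its column in the original matrix, and zero everything else with a mask. The NumPy sketch below (`thresh_mask` is a hypothetical name) uses only sorting, broadcast comparisons and `np.where`. The assumption is that in TensorFlow the same steps would map to `tf.sort` (or `tf.math.top_k`), broadcast comparisons, `tf.logical_and` and `tf.where`, which avoids index assignment entirely. With ties, the `>=` test can keep more than q entries in a row or column.

```python
import numpy as np


def thresh_mask(A, q):
    """Hypothetical vectorized variant: keep A[i, j] only if it is >= the
    q-th largest value of row i AND of column j of the original A."""
    A = np.asarray(A)
    # q-th largest value of each row, shape (n, 1)
    row_kth = np.sort(A, axis=1)[:, -q][:, None]
    # q-th largest value of each column, shape (1, n)
    col_kth = np.sort(A, axis=0)[-q, :][None, :]
    # Boolean mask built by broadcasting; nothing is assigned in place.
    keep = (A >= row_kth) & (A >= col_kth)
    return np.where(keep, A, 0)


A = np.array([[1.0, 5.0, 3.0],
              [4.0, 2.0, 6.0],
              [7.0, 8.0, 9.0]])
print(thresh_mask(A, 2))
# [[0. 5. 0.]
#  [4. 0. 6.]
#  [0. 8. 9.]]
```

On this matrix both rules agree, but they can differ in general. For example, with A = [[9, 8, 1], [1, 2, 3], [1, 7, 2]] and q = 1, `thresh` keeps the 7 at A[2, 1] (the 8 above it was already zeroed by row 0), while `thresh_mask` drops it because 8 is the largest value of that column in the original matrix.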

