As far as I know, PyTorch does not inherently have masked tensor operations (such as those available in `numpy.ma`).
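For reference, here is a minimal sketch of the kind of masked reduction `numpy.ma` provides. The sample values are made up for illustration. Note that in `numpy.ma` a mask value of `True` marks an element as *invalid* (ignored). This is the opposite of the PyTorch functions in this post, where `True` marks the elements to keep.

```python
import numpy as np

a = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
# In numpy.ma, True means "masked out", i.e. ignored by reductions
invalid = np.array([[False, True, False],
                    [True, True, False]])

ma = np.ma.masked_array(a, mask=invalid)
row_means = ma.mean(axis=1)  # row 0 averages 1.0 and 3.0; row 1 keeps only 6.0
print(row_means)
```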
The other day, I needed to run some aggregation operations on a tensor while ignoring masked elements. Specifically, I needed to compute a `mean()` along a specific dimension, skipping the masked entries. Fortunately, these operations are easy enough to implement manually. Let's start with `mean()`.
Let's say you have a matrix `a` and a bool mask `m` with the same shape as `a`, where `m` is `True` for the elements you want to keep. You want to compute `a.mean(dim=1)`, but only over the elements that are not masked out. Here's a small function that does this for you:
```python
import torch

def masked_mean(tensor, mask, dim):
    masked = torch.mul(tensor, mask)  # Apply the mask using an element-wise multiply
    return masked.sum(dim=dim) / mask.sum(dim=dim)  # Find the average!
```
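A quick usage sketch, assuming a float tensor and a bool mask where `True` marks the elements to keep (the toy values are made up). One caveat: if every element along `dim` is masked, the denominator is zero and that row comes out as `nan` (0/0). The hypothetical variant `masked_mean_safe` below clamps the count to at least 1, so a fully masked row returns 0 instead. Whether 0 is the right answer there is an assumption that depends on your use case.

```python
import torch

def masked_mean_safe(tensor, mask, dim):
    masked = torch.mul(tensor, mask)  # zero out masked elements (mask=True keeps a value)
    count = mask.sum(dim=dim)         # number of kept elements along dim
    # Hypothetical guard: avoid 0/0 -> nan for rows where everything is masked
    return masked.sum(dim=dim) / count.clamp(min=1)

a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])
m = torch.tensor([[True, False, True],
                  [False, False, True],
                  [False, False, False]])  # last row is fully masked

print(masked_mean_safe(a, m, dim=1))  # tensor([2., 6., 0.])
```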
We can implement a similar function for finding, say, the `max()` along a specific dimension:
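```python
import math

import torch

def masked_max(tensor, mask, dim):
    masked = torch.mul(tensor, mask)  # Zero out the masked elements
    neg_inf = torch.zeros_like(tensor)
    neg_inf[~mask] = -math.inf  # Place the smallest value possible in masked positions
    return (masked + neg_inf).max(dim=dim)[0]  # [0] keeps the max values, drops the indices
```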
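An alternative sketch of the same idea uses `Tensor.masked_fill` to write `-inf` directly into the masked positions, instead of zeroing them out and adding a second tensor. Besides being shorter, it sidesteps an edge case of the multiply-then-add approach: a masked element that is itself `inf` becomes `nan` (`inf * 0`), which can then corrupt the result. It assumes a floating-point tensor, and a fully masked row still yields `-inf`. The name `masked_max_fill` and the sample values are illustrative only.

```python
import torch

def masked_max_fill(tensor, mask, dim):
    # mask=True keeps a value; everything else becomes -inf so it can never be the max
    filled = tensor.masked_fill(~mask, float("-inf"))
    return filled.max(dim=dim).values

a = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 9.0, 6.0]])
m = torch.tensor([[True, False, True],
                  [True, False, True]])

print(masked_max_fill(a, m, dim=1))  # tensor([3., 6.])
```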
Simple, no? 🙂