[D] How to efficiently implement local attention?
I’d like to implement a simple dot-product attention mechanism such that the output at each timestep is computed by attending to the preceding L elements. This is similar to the standard setup for autoregressive attention, but differs in that only a fixed-length window is attended to at each timestep.
Suppose we are training on sequences of length N and want to compute attention over windows of L elements. The options that I can think of are:
- Compute all N^2 elements of the attention matrix and apply a mask so that only the N*L elements of interest contribute. This is inefficient for L << N and often impractical for large N due to memory constraints.
- Manually window the inputs into N overlapping sequences of length L, then apply attention within each window. This only requires N*L dot products, but involves tiling/repeating the inputs (attention keys/values) L times over, which is impractical for large L.
- Manually loop over N and L and individually compute each of the N*L dot products. This is efficient in an algorithmic sense, but in practice a Python-level loop will be extremely slow in any high-level DL library.
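For concreteness, here is what I mean by the first (masking) option. This is a framework-agnostic NumPy sketch, assuming single-head attention over a single sequence of shape (N, d); the function and variable names are my own:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def local_attention_masked(q, k, v, L):
    """Option 1: full (N, N) score matrix plus a band mask. O(N^2) memory."""
    N, d = q.shape
    scores = q @ k.T / np.sqrt(d)                      # all N^2 dot products
    i = np.arange(N)
    # position t may attend to positions t-L+1 .. t (causal window of length L)
    band = (i[None, :] <= i[:, None]) & (i[None, :] > i[:, None] - L)
    scores = np.where(band, scores, -np.inf)           # mask everything else
    return softmax(scores) @ v
```

The band mask keeps only the ~N*L entries of interest, but the full (N, N) score matrix is still materialized first, which is exactly the memory problem above.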
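For the second (windowing) option, the explicit tiling can at least be avoided with a strided view, so the L-fold repetition is a view rather than a copy. A hedged NumPy sketch using `numpy.lib.stride_tricks.sliding_window_view` (the zero-padding/masking scheme here is one possible choice, not a standard API):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def local_attention_windowed(q, k, v, L):
    """Option 2: N overlapping windows of length L -> only N*L dot products."""
    N, d = q.shape
    pad = np.zeros((L - 1, d))
    # strided view: window t holds positions t-L+1 .. t of the padded arrays
    # (np.vstack makes one padded copy; the window axis itself is a view)
    k_win = sliding_window_view(np.vstack([pad, k]), (L, d)).reshape(N, L, d)
    v_win = sliding_window_view(np.vstack([pad, v]), (L, d)).reshape(N, L, d)
    scores = np.einsum('nd,nld->nl', q, k_win) / np.sqrt(d)
    # mask out the zero-padding that precedes the start of the sequence
    t = np.arange(N)[:, None]
    j = np.arange(L)[None, :]
    scores = np.where(j >= L - 1 - t, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum('nl,nld->nd', w, v_win)
```

In PyTorch I believe `torch.Tensor.unfold` gives a similar strided view, though whether the subsequent batched matmul avoids materializing the (N, L, d) windows is the part I’m unsure about.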
My question is whether or not this operation can be efficiently computed with high-level DL libraries.