Must admit, I found the circuit diagram harder to interpret than the textual description of what the circuit is doing.
It's an interesting approach. I can see it being really useful for networks that are inherently smaller than an LLM: recommendation systems, fraud detection models, etc. For LLMs, I guess the most important follow-up line of research would be to ask whether a network trained in this special manner can then be distilled or densified in some way that retains the underlying decision making of the interpretable network with a more efficient runtime representation. Or alternatively, whether super sparse networks can be made efficient to inference.
There's also a question of expected outcomes. Mechanistic interpretability seems hard not only because of density and superposition but also because many of the deep concepts being represented are just inherently difficult to express in words. There are going to be a lot of groups of neurons encoding fuzzy intuitions that, at best, might take an entire essay to crudely put into words.
Starting from product goals and working backwards definitely seems like the best way to keep this stuff focused, but the product goal is going to depend heavily on the network being analyzed. Like, the goal of interpretability for a recommender is going to look very different from the interpretability goal for a general chat-focused LLM.
In theory, multiplying a matrix and a highly sparse vector should be much faster than the dense equivalent, because you only need to read the columns of the matrix that correspond to nonzero elements of the vector. But in this paper, the vectors are much less sparse than the matrices: "Our sparsest models have approximately 1 in 1000 nonzero weights. We also enforce mild activation sparsity at all node locations, with 1 in 4 nonzero activations. Note that this does not directly enforce sparsity of the residual stream, only of residual reads and writes." In addition, they're comparing to highly optimized dense matrix multiplication kernels on GPUs, which have dedicated hardware support (Tensor Cores) that isn't useful for sparse matmul.
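To make the column-reading point concrete, here's a minimal NumPy sketch (not anything from the paper, just an illustration): with a sparse activation vector you gather only the columns that matter and do a much smaller dense matvec.

    import numpy as np

    def sparse_matvec(W, x):
        # Multiply dense matrix W by sparse vector x, reading only the
        # columns of W whose corresponding entries of x are nonzero.
        nz = np.flatnonzero(x)      # indices of nonzero activations
        return W[:, nz] @ x[nz]     # small dense matvec over the gathered columns

    # Toy check with a 1-in-4-dense activation vector, like the paper describes
    rng = np.random.default_rng(0)
    W = rng.standard_normal((8, 16))
    x = np.zeros(16)
    x[rng.choice(16, size=4, replace=False)] = rng.standard_normal(4)
    assert np.allclose(sparse_matvec(W, x), W @ x)

Of course the gather itself is irregular memory access, which is exactly where the hand-tuned dense Tensor Core path wins in practice.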
Right. It's super interesting to me because some years ago I got dinner with a director of AI research at Google and he told me the LLMs at that time were super sparse. Not sure if something got lost in translation or stuff just changed, but it doesn't seem to be true anymore.
In theory, NVIDIA and others could optimize for sparse matrices, right? If the operands are that sparse, I wonder whether whole tiles could be trivially skipped without ever executing a matmul for them at all. The problem feels more like memory: how do you efficiently encode such a sparse matrix without wasting lots of memory and bandwidth transferring zeros around? You could use RLE, but if you have to unpack it into a dense layout to use the hardware anyway, maybe it's not a win in the end.
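For what it's worth, here's a toy sketch of the tile-skipping idea (purely illustrative, assuming unstructured sparsity stored densely, nothing to do with how real sparse kernels work): decompose the matmul into tiles along the reduction dimension and skip any tile of A that is all zeros. At 1-in-1000 weight sparsity most tiles would be empty, but the bookkeeping and irregular access are exactly the encoding/bandwidth problem you mention.

    import numpy as np

    def tile_skipping_matmul(A, B, tile=32):
        # Block-decomposed matmul that skips tiles of A that are entirely zero.
        # Assumes dimensions are multiples of `tile`; a sketch, not a real kernel.
        m, k = A.shape
        k2, n = B.shape
        assert k == k2 and m % tile == 0 and k % tile == 0
        C = np.zeros((m, n))
        for i in range(0, m, tile):
            for j in range(0, k, tile):
                a_tile = A[i:i+tile, j:j+tile]
                if not a_tile.any():          # whole tile is zero: skip the matmul
                    continue
                C[i:i+tile] += a_tile @ B[j:j+tile]
        return C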