Last summer we started looking for deep learning solutions for tree-shaped data. Tree-LSTM* seemed to be a simple, proven model we could use as the core of our patent search. There are many working implementations, and though each targets a slightly different problem, converting one to work for us was no big trouble. In August we had a model that was clearly learning something. All was good, except it wasn’t. Training was horribly slow, about 20 iterations per second, and our data had millions of training samples. At that rate a single pass over one million samples takes roughly 14 hours. The model mostly utilised just a single CPU core, and training on a GPU was even slower. Further research painted a grim picture. The naive implementation of the model handles only a single sample at a time, when you usually want to process 10-1000 samples in parallel. To make matters worse, the model also processed each tree node separately.
Most of the speed in machine learning comes from parallelism, which is achieved by computing similar operations in big batches. This has three benefits. First, it lets us use GPUs, which can have over 1000 cores and, with neural networks, can give a 50x speedup over a CPU. Second, matrix multiplication, the heart of most models, is actually cheaper as one big operation than as many smaller ones. Third, parallel code lets us scale speed with money. The vanilla Tree-LSTM is basically unusable even on a supercomputer, as it can't utilise much more than a single CPU core during training. We needed to make the model work in parallel, but everyone seemed to view that as a very difficult challenge:
“This type of model, i.e. a recursive neural network, will have a different structure for each input sample, and it is extremely difficult (almost impossible) to find samples which might result in almost similar structures, so that we can batch them together.” GitHub issue about Tree-LSTM performance
“Neural models operating directly on parse trees are usually difficult to parallelise and thus computationally inefficient, because aligning trees for efficient batch training is usually nontrivial.” Zhang et al., 2018
We could have tried different models, but Tree-LSTM seemed to work, and something about the situation didn’t make sense. We know what we need to calculate, and much of it can be calculated in parallel, so what exactly is stopping us from doing just that? First we can calculate, across every tree in the batch, the nodes whose children are all leaves, then the nodes whose children have all been calculated, and so on up to the roots. Admittedly, there are a few reasons why this is not trivial. The index calculations are mind-bending by nature, and the simplest implementations seemed to ruin the performance of the backward pass, where the neural net does its learning. Luckily, it was 2018 and we had chosen PyTorch, which provides some wonderful indexing functions. If you manage to use these functions for the heavy parts, the code is pretty much automatically efficient.
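The level-by-level idea can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the implementation described here: each tree is assumed to be a list where entry `i` holds the child ids of node `i`, every node has at most one parent, and the helper names `merge_trees`, `schedule_levels` and `child_sum_forward` are hypothetical. The toy `child_sum_forward` replaces the real Tree-LSTM cell with "own feature plus sum of children" purely to show the scheduling. Level 0 holds every leaf in the batch, and each later level holds the nodes whose children are all in earlier levels, so one level becomes one batched step. In PyTorch, the per-level edge lists would presumably become index tensors for ops such as `index_select` and `index_add_`; which functions the original code uses is not stated here.

```python
from typing import List

# Hypothetical tree format: tree[i] lists the child ids of node i (ids local to the tree).
Tree = List[List[int]]


def merge_trees(trees: List[Tree]) -> Tree:
    """Concatenate a batch of trees into one forest with global node ids."""
    children: Tree = []
    for tree in trees:
        offset = len(children)
        children.extend([[c + offset for c in kids] for kids in tree])
    return children


def schedule_levels(children: Tree) -> List[List[int]]:
    """Group node ids into levels: level 0 is every leaf in the batch, and each
    later level holds the nodes whose children are all in earlier levels."""
    parent = [-1] * len(children)
    for p, kids in enumerate(children):
        for c in kids:
            parent[c] = p
    pending = [len(kids) for kids in children]  # children not yet scheduled
    frontier = [i for i, n in enumerate(pending) if n == 0]
    levels = []
    while frontier:
        levels.append(frontier)
        ready = []
        for i in frontier:
            p = parent[i]
            if p >= 0:
                pending[p] -= 1
                if pending[p] == 0:  # last child done: parent joins the next level
                    ready.append(p)
        frontier = ready
    return levels


def child_sum_forward(children: Tree, features: List[int],
                      levels: List[List[int]]) -> List[int]:
    """Toy stand-in for the Tree-LSTM cell: value = own feature + sum of children.
    All (parent, child) edges of one level are processed as a single batch."""
    values = list(features)
    for level in levels:
        # Flat edge lists for the whole level; a tensor library would take these
        # as index tensors for one gather/scatter-add instead of a Python loop.
        dst = [p for p in level for _ in children[p]]
        src = [c for p in level for c in children[p]]
        for d, s in zip(dst, src):
            values[d] += values[s]
    return values


# Usage: two small trees batched together.
tree_a = [[1, 2], [3], [], []]  # root 0 -> 1, 2; node 1 -> 3
tree_b = [[1], []]              # root 0 -> 1
forest = merge_trees([tree_a, tree_b])
levels = schedule_levels(forest)
print(levels)  # [[2, 3, 5], [1, 4], [0]]
print(child_sum_forward(forest, [1, 2, 3, 4, 5, 6], levels))  # [10, 6, 3, 4, 11, 6]
```

The number of sequential steps is the depth of the deepest tree in the batch rather than the total number of nodes, which is where the parallelism comes from.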
So, is this the fastest Tree-LSTM in the world? Not necessarily: we recently found out that someone brilliant had done something similar with the old TensorFlow. According to its README, nicolaspi’s implementation is 70x faster than the vanilla Tree-LSTM, a performance gain similar to what we have measured. Someone may have improved on that as well. Nevertheless, we still have room for improvement, e.g. we could calculate the indexes for the next batch on the CPU while the GPU processes the current one. The end result will be a Tree-LSTM model that is fast and scalable enough to be used with any kind of data and can fully utilise GPUs. Similar optimisation strategies have also proven handy with other models, but that is another story.
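Overlapping index calculation with GPU work could look roughly like the following producer/consumer sketch. It is an assumption about one possible design, not existing code: `prefetch` is a hypothetical helper in which a background thread runs `prepare` on upcoming batches while the training loop consumes already-prepared ones, with `depth` bounding how far ahead it works. This only helps if the training step does not hold the CPU the whole time. PyTorch typically queues CUDA work asynchronously, which leaves room for this, but in CPython the GIL limits how much pure-Python index code in a thread can truly overlap, so a process-based worker may be needed in practice.

```python
import queue
import threading
from typing import Callable, Iterable, Iterator, TypeVar

B = TypeVar("B")
R = TypeVar("R")


def prefetch(batches: Iterable[B], prepare: Callable[[B], R],
             depth: int = 2) -> Iterator[R]:
    """Yield prepare(batch) for each batch, in order, while a background
    thread prepares up to `depth` upcoming batches ahead of the consumer."""
    q: "queue.Queue" = queue.Queue(maxsize=depth)

    def worker() -> None:
        try:
            for batch in batches:
                q.put(("ok", prepare(batch)))  # blocks when `depth` results are waiting
            q.put(("done", None))
        except Exception as exc:  # hand errors to the consumer instead of hanging it
            q.put(("err", exc))

    threading.Thread(target=worker, daemon=True).start()
    while True:
        kind, payload = q.get()
        if kind == "done":
            return
        if kind == "err":
            raise payload
        yield payload


# Usage: a hypothetical index builder; a real loop would run a training step per item.
for indexes in prefetch([[1, 2], [3, 4], [5]], lambda b: [x * 2 for x in b]):
    print(indexes)
```

The bounded queue keeps memory use predictable: the CPU side never gets more than `depth` batches ahead of the GPU side.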
*) If you are not yet familiar with Tree-LSTM or recursive neural networks in general, the Stanford lecture on YouTube is a good watch.