Wayve - spatial pooling attention for feature selection

2 October 2020
This summer I worked at Wayve - a London-based start-up that is trying to solve autonomous driving in possibly the hardest, but hopefully the most scalable, way: with a single neural network that turns camera images into driving control. The technology is at an early stage, so most research models use only a single front camera image, which also allows faster iteration.

In the second half of my internship, I got my own research project: use attention to increase interpretability, evaluate task usefulness, and improve performance by dynamically choosing how features are sampled. I did this by adding a layer at the end of the convolutional stack that uses attention and pooling to select individual features (instead of downsampling and flattening). It was inspired by Towards Interpretable Reinforcement Learning Using Attention Augmented Agents. The project isn't complete, as I only had 6 weeks and was working on it alone, but it has shown good results, and with a few minor improvements I believe it will be added to the baseline.

Architecture overview

The standard model currently used at Wayve has a perception module, a sensor fusion module and a control module. The perception module outputs depth and segmentation for each input image; it is trained only on these tasks, and fine-tuning it for control doesn't seem very useful. The sensor fusion module takes the outputs from the perception module(s), information about the route and optionally other metadata, and fuses them into a single vector. That vector goes into the control module, which outputs speed and steering.
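To make the pipeline concrete, here is a minimal sketch of how the three modules compose; the class and argument names are my own illustrative assumptions, not Wayve's code.

```python
import torch
import torch.nn as nn

class DrivingModel(nn.Module):
    """Perception -> sensor fusion -> control, as described above (names are hypothetical)."""

    def __init__(self, perception: nn.Module, sensor_fusion: nn.Module, control: nn.Module):
        super().__init__()
        self.perception = perception        # image -> depth and segmentation outputs
        self.sensor_fusion = sensor_fusion  # perception outputs + route (+ metadata) -> single vector
        self.control = control              # fused vector -> speed and steering

    def forward(self, image: torch.Tensor, route: torch.Tensor, metadata=None):
        perception_out = self.perception(image)
        fused = self.sensor_fusion(perception_out, route, metadata)  # (B, D)
        speed, steering = self.control(fused)
        return speed, steering
```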
Overview of new submodules inside sensor fusion.
I made all my changes inside the sensor fusion module. I added SF towers, which enable the model to take different matrices for each image. Previously it would get one composite matrix; this way it can choose to additionally take the raw image (useful for seeing brake lights, traffic light colours etc.) or split the composite image apart so that task attention can be applied to it.

The spatial pooling attention uses a query network that generates a number of queries. Each query is turned into a map and softmaxed (so that it specifically pulls out one feature rather than doing any further operations), then broadcast across the features and spatially summed, resulting in a single feature vector for each query.
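Below is a minimal sketch of this layer in PyTorch. The tensor shapes, the number of queries and the choice of a small MLP on the globally pooled feature map as the query network are my own assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialPoolingAttention(nn.Module):
    def __init__(self, channels: int, num_queries: int = 32):
        super().__init__()
        self.num_queries = num_queries
        # Query network: produces num_queries query vectors per image
        # (here a small MLP on a globally pooled summary - an assumption).
        self.query_net = nn.Sequential(
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, num_queries * channels),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (B, C, H, W) output of the convolutional stack.
        B, C, H, W = features.shape
        keys = values = features                        # no spatial basis in this sketch

        summary = features.mean(dim=(2, 3))             # (B, C)
        queries = self.query_net(summary).view(B, self.num_queries, C)

        # Dot each query with the key at every spatial location -> attention logits.
        logits = torch.einsum("bqc,bchw->bqhw", queries, keys)
        maps = F.softmax(logits.view(B, self.num_queries, H * W), dim=-1)
        maps = maps.view(B, self.num_queries, H, W)     # one softmaxed map per query

        # Broadcast each map over the value channels and sum spatially:
        # each query pulls out a single C-dimensional feature vector.
        pooled = torch.einsum("bqhw,bchw->bqc", maps, values)
        return pooled                                   # (B, num_queries, C)
```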

If task attention is included, that means dynamically creating a softmaxed kernel for a 1x1 convolution and normalizing the features going into the attention layer. The convolution reduces the number of channels, and the softmaxed kernel weights show how much information it is taking from each of the tasks.
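A rough sketch of how such a task attention layer could look, assuming the dynamic kernel is generated from a global summary of the features (the generator network and the normalization choice are illustrative assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TaskAttention(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        # Generates the dynamic 1x1 kernel from a global summary of the features.
        self.kernel_net = nn.Linear(in_channels, out_channels * in_channels)
        self.norm = nn.InstanceNorm2d(in_channels)      # normalize features before attention

    def forward(self, features: torch.Tensor):
        # features: (B, C_in, H, W), channels grouped by task (depth, segmentation, raw image, ...).
        B, C, H, W = features.shape
        x = self.norm(features)

        summary = features.mean(dim=(2, 3))                       # (B, C_in)
        kernel = self.kernel_net(summary).view(B, self.out_channels, C)
        kernel = F.softmax(kernel, dim=-1)                        # convex weights over input channels

        # Apply the dynamic 1x1 convolution: (B, C_out, C_in) x (B, C_in, H*W).
        out = torch.bmm(kernel, x.view(B, C, H * W)).view(B, self.out_channels, H, W)
        # Summing `kernel` over each task's block of channels shows how much
        # information each output channel takes from each task.
        return out, kernel
```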

Spatial information

There is also a spatial basis (a set of usually 64 channels of 2D wave functions) that can be appended to the keys (the channels convolved with the queries to generate the maps) and to the values (the channels being passed on). When it is omitted, the model selects single-pixel features: since all spatial information is lost at this point, features that haven't been averaged over multiple pixels are likely to carry more information (the features also have a 40x40 receptive field on the depth and segmentation images, so they can encode a lot of local spatial information that the model wouldn't want blurred out). When the basis is added to both keys and values, the maps don't move much, as any query will find a very good match with the wave functions somewhere (likely better than with the other channels); and because the gradients through the query net aren't very good, having more than 8 channels gives worse results. The third option was to add the spatial basis only to the values. This led to the best results of the three, and the maps would always focus on a whole semantic area (e.g. a specific lane marking or the bumper of one car).
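For reference, a spatial basis of this kind can be built as a fixed stack of 2D cosine waves, as in the paper that inspired the project; the exact construction used in my code may differ, so treat this as an assumed sketch:

```python
import math
import torch

def spatial_basis(height: int, width: int, num_freqs: int = 8) -> torch.Tensor:
    """Returns a (num_freqs**2, height, width) stack of 2D cosine wave functions."""
    y = torch.linspace(0, 1, height).view(1, height, 1)
    x = torch.linspace(0, 1, width).view(1, 1, width)
    freqs = torch.arange(1, num_freqs + 1, dtype=torch.float32).view(-1, 1, 1)
    cos_y = torch.cos(math.pi * freqs * y)              # (F, H, 1)
    cos_x = torch.cos(math.pi * freqs * x)              # (F, 1, W)
    # Outer product over the two frequency axes gives F*F channels of 2D waves
    # (64 channels for num_freqs=8).
    return (cos_y.unsqueeze(1) * cos_x.unsqueeze(0)).reshape(-1, height, width)

# Example: append the basis only to the values (the variant that worked best):
# basis = spatial_basis(H, W).unsqueeze(0).expand(B, -1, -1, -1)
# values = torch.cat([features, basis], dim=1)
```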

But even though it was sampling meaningful information, focusing on important things and ignoring unimportant ones, it was still slightly worse than the models that sample information uniformly. Why is that? I believe it is because implicit spatial information is easier to reason from than explicit spatial information: instead of having a number that encodes the presence of a car directly in front of you, you have a lot of numbers that encode the presence of a car somewhere, and you have to work out which one is the one right in front of you.

So I see two ways of resolving this: 1) after the pooling attention, add a few layers that can deal with it, or 2) add implicit spatial information into the features.

1 - this could be done with a few self-attention + 1x1 conv layers: self-attention doesn't care about implicit spatial information, but it should be able to deal well with explicit spatial information. So after a few of these layers, the relative spatial information might get exchanged well enough that the fc layers after them work better.
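A rough sketch of what these mixing layers could look like on top of the pooled query vectors (layer sizes and the residual wiring are my own assumptions):

```python
import torch
import torch.nn as nn

class PooledFeatureMixer(nn.Module):
    def __init__(self, dim: int, num_layers: int = 2, num_heads: int = 4):
        super().__init__()
        self.attn_layers = nn.ModuleList(
            [nn.MultiheadAttention(dim, num_heads, batch_first=True) for _ in range(num_layers)]
        )
        # A 1x1 conv over a set of vectors is just a linear layer applied per token.
        self.mix = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # pooled: (B, num_queries, dim) from the spatial pooling attention layer.
        x = pooled
        for attn, mix in zip(self.attn_layers, self.mix):
            attended, _ = attn(x, x, x)         # exchange information between the query features
            x = torch.relu(mix(x + attended))   # residual + per-token linear ("1x1 conv")
        return x.flatten(1)                     # ready for the fc layers of the control module
```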

2 - this can be done by including the spatial basis with the keys. But how can we do this without running into the same problems as before? I believe that if we run the spatial basis through a few conv layers, we can reduce the number of channels, and it might also turn into more semantically meaningful structures (gates looking for lane markings, cars that are too close etc.). Visualizing what this static input looks like would also provide further insight into whether it is learning meaningful things and what it focuses on.
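A possible sketch of this, assuming the fixed wave-function basis from before is passed through a small learned conv stack to produce a few channels for the keys (channel counts and depth are illustrative):

```python
import torch
import torch.nn as nn

class LearnedSpatialBasis(nn.Module):
    def __init__(self, basis: torch.Tensor, out_channels: int = 8):
        super().__init__()
        # `basis` is the fixed (num_basis, H, W) stack of 2D wave functions.
        self.register_buffer("basis", basis.unsqueeze(0))  # (1, num_basis, H, W)
        self.convs = nn.Sequential(
            nn.Conv2d(basis.shape[0], 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, batch_size: int) -> torch.Tensor:
        # The input is static, so this could be cached; recomputed here for clarity.
        learned = self.convs(self.basis)                    # (1, out_channels, H, W)
        return learned.expand(batch_size, -1, -1, -1)       # concatenated to the keys downstream
```

Visualizing `learned` directly would show whether the conv stack turns the raw waves into the kind of semantically meaningful gates described above.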

Results

Ultimately it hasn't achieved a performance improvement on the road yet, but it gets better results on some tasks where you need to focus on small objects, such as traffic light intersections. It has also by far outperformed all other models on our simulation tasks. Even if the simulator were perfectly realistic, that would still not be representative of real driving performance, as the tests in the simulator target specific tasks, which might occur at different frequencies in real driving. It is visible in the images that it is able to find semantically meaningful areas, not just areas with a high-frequency signal (which often happened when the architecture was too weak or underfitted).
Part of the visualization, showing attention being paid to the lit-up traffic light bulb, the bumper of the van in front, and the position of the lane to aim towards. (Each colour represents 4 attention heads; overall there are 32, 8 of which are omitted in this image.)
Attention being paid to the road, with a greater focus towards the edges and the horizon. This image also shows that the system is able to deal well with many environmental conditions. Since the attention is at the end of the convolutional stack, the pixels slightly above the horizon may still contain some information about what is at the horizon. An interesting thing is that it learns to pay attention to vehicles and pedestrians only where they interface with the road.