Think Before You Act: Decision Transformers with Working Memory

McGill University · Mila · Microsoft Research
ICML 2024

Abstract

Decision Transformer-based decision-making agents have shown the ability to generalize across multiple tasks. However, their performance relies on massive amounts of data and computation. We argue that this inefficiency stems from the forgetting phenomenon, in which a model implicitly memorizes behaviors in its parameters throughout training. As a result, training on a new task may deteriorate the model's performance on previous tasks. In contrast to LLMs' implicit memory mechanism, the human brain utilizes distributed memory storage, which helps manage and organize multiple skills efficiently and mitigates forgetting. Inspired by this, we propose a working memory module to store, blend, and retrieve information for different downstream tasks. Evaluation results show that the proposed method improves training efficiency and generalization on Atari games and Meta-World object-manipulation tasks. Moreover, we demonstrate that memory fine-tuning further enhances the adaptability of the proposed architecture.

Motivation of Explicit Memory

Our motivation comes from how humans think before they act: they can reason over past experiences to generate appropriate behavior in new situations. We want to equip our robots with similar abilities. Imagine training a robot to play four different Atari games: Asteroids, Asteroids Deluxe, Space Invaders, and Space Invaders II (as shown below). Asteroids Deluxe is a sequel to Asteroids that introduces new boss fights and enemies; similarly, Space Invaders II is a sequel to Space Invaders. To play these four games, the robot must actively store what it has learned in memory and choose the appropriate strategy for each game. Throughout training, the robot's memory module continuously processes and updates relevant game information, allowing it to make informed decisions and adapt its strategies.

Figure: Illustrating how a robot can use its memory to guide its playing strategy.

Decision Transformer with Working Memory

Memory Update

To store incoming information and blend it with existing memory, we calculate an erasing vector, \(\epsilon^e\), and an adding vector, \(\epsilon^a\). The erasing vector erases the current memory, while the adding vector controls the flow of new information into the memory. We use the attention mechanism for this. First, we map the memory matrix \(M\) and the input embeddings \(E\) to query, key, and value vectors: \(\hat{Q}=M\hat{W}^q\), \(\hat{K}=E\hat{W}^k\), and \(\hat{V}=E\hat{W}^v\), where \(\hat{W}^q\), \(\hat{W}^k\), and \(\hat{W}^v\) are learned parameters. Next, we calculate the writing strength, \(\beta = \text{softmax}\Big(\frac{\hat{Q}\hat{K}^T}{\sqrt{d}}\Big)\).
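As a concrete reference, the following minimal NumPy sketch walks through this projection and writing-strength step. The shapes (`n_slots`, `n_tokens`, `d`) and the random parameters are illustrative assumptions, not the paper's settings.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Toy shapes (assumptions for illustration): M is the (n_slots, d) memory
# matrix, E the (n_tokens, d) input embeddings.
rng = np.random.default_rng(0)
n_slots, n_tokens, d = 8, 4, 16
M = rng.normal(size=(n_slots, d))
E = rng.normal(size=(n_tokens, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) * d**-0.5 for _ in range(3))

Q, K, V = M @ Wq, E @ Wk, E @ Wv               # project memory and input
beta = softmax(Q @ K.T / np.sqrt(d), axis=-1)  # writing strength, (n_slots, n_tokens)
```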

The erasing vector \(\epsilon^e = w \odot (1 - \beta)\), where \(w\) denotes the content-based address weights over memory slots and \(\odot\) indicates element-wise multiplication, selectively erases information from the memory matrix. The adding vector \(\epsilon^a = (w \odot \beta)\,\hat{V}\) selectively adds information to the memory matrix. Finally, the memory is updated as \(M_t = M_{t-1} \odot (1 - \epsilon^e) + \epsilon^a\). New information is stored if the selected memory slot is empty or erased; otherwise, it blends with the existing memory contents.
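Continuing the sketch, the erase/add/update step is below. Collapsing the content-based address \(w\) and the writing strength to one weight per slot is a simplifying assumption of this sketch; the paper's exact reduction may differ.

```python
# Content-based address over slots (one weight per slot, an assumption here).
w = softmax(beta.max(axis=-1, keepdims=True), axis=0)  # (n_slots, 1)
b = beta.max(axis=-1, keepdims=True)                   # per-slot write strength

eps_e = w * (1.0 - b)            # erasing vector: how much each slot forgets
eps_a = (w * b) * (beta @ V)     # adding vector: attended new content per slot
M_t = M * (1.0 - eps_e) + eps_a  # blend new information into existing memory
```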

Memory Retrieval

To utilize memory for decision-making, we retrieve information from the updated memory slots. Reading from the memory matrix is done by computing a read position vector, which can be obtained with the content-based addressing mechanism described above, comparing the query vector with the contents of the memory matrix. In other retrieval-based methods, nearest-neighbor search is the common way to retrieve related information; in our case, however, the working memory is considerably smaller than typical external memory, which makes attention-based retrieval feasible. Since the query information is the same as the input information, we reuse the same content address to retrieve the memory: \(E_{\text{out}} = w \odot M_t\).
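In the sketch, the read step simply reuses the address weights computed during the write:

```python
E_out = w * M_t     # E_out = w ⊙ M_t: gated read of each updated slot
print(E_out.shape)  # (8, 16), one retrieved vector per memory slot
```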

Figure: An overview of the proposed DT-Mem architecture. The encoder takes a fixed-length sequence of trajectories as input; together with the positional encoding module, it embeds the inputs and preserves the temporal correlations between states and actions. The primary role of the attention module is to capture dependencies and relationships between states, actions, and returns in a sequence. Note that multiple attention modules are stacked together; our design deconstructs this stack and manages the memory flow between the attention modules within each block. The output of the attention blocks flows to the action decoder, which decodes it back into actions.
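To make the memory flow around the attention modules concrete, here is a hedged sketch of one such block. The function `dt_mem_block`, its parameter names, and the way retrieved content is injected back into the token stream are assumptions of this sketch, not the authors' implementation.

```python
def dt_mem_block(H, M, params):
    """One DT-Mem-style block (a hedged reading of the figure): self-attention
    over the trajectory embeddings H, then a write/read round-trip with the
    working memory M. Parameter names and shapes are illustrative."""
    d = H.shape[-1]
    # Self-attention over state/action/return tokens.
    A = softmax((H @ params["Wq_a"]) @ (H @ params["Wk_a"]).T / np.sqrt(d), axis=-1)
    H = A @ (H @ params["Wv_a"])
    # Memory write, as in the update step above.
    Q, K, V = M @ params["Wq"], H @ params["Wk"], H @ params["Wv"]
    beta = softmax(Q @ K.T / np.sqrt(d), axis=-1)
    w = softmax(beta.max(axis=-1, keepdims=True), axis=0)
    b = beta.max(axis=-1, keepdims=True)
    M = M * (1.0 - w * (1.0 - b)) + (w * b) * (beta @ V)
    # Memory read; injecting the mean retrieved content back into the token
    # stream is a simplifying assumption of this sketch.
    H = H + (w * M).mean(axis=0, keepdims=True)
    return H, M

# Usage with the toy tensors defined earlier:
params = {k: rng.normal(size=(d, d)) * d**-0.5
          for k in ["Wq_a", "Wk_a", "Wv_a", "Wq", "Wk", "Wv"]}
H_out, M_out = dt_mem_block(E, M_t, params)
```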

Evaluation Results

Generalization Result

Figure: Evaluation results on 4 held-out games after pre-training on the other Atari games. Each value is the DQN-normalized score, computed with a 95% confidence interval.

Scaling Result

Figure: Scaling of IQM (interquartile mean) scores.

Fine-tuning Result

Figure: Fine-tuning performance with 10% of the dataset on unseen Atari games. For better visualization, the y-axis shows the logarithm of the DQN-normalized score.

Pre-training Result

Figure: Performance improvement on the training dataset.


BibTeX

@inproceedings{kang2023think,
  title={Think Before You Act: Decision Transformers with Working Memory},
  author={Kang, Jikun and Laroche, Romain and Yuan, Xingdi and Trischler, Adam and Liu, Xue and Fu, Jie},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}