Implicit Chain of Thought Reasoning via Knowledge Distillation
- Yuntian Deng,
- Kiran Prasad,
- Roland Fernandez,
- Paul Smolensky,
- Vishrav Chaudhary,
- Stuart Shieber
ArXiv, Vol. abs/2311.01460
To augment language models with the ability to reason, researchers usually prompt or finetune them to produce chain-of-thought reasoning steps before producing the final answer. However, although people use natural language to reason effectively, it may be that LMs could reason more effectively with some intermediate computation that is not in natural language. In this work, we explore an alternative reasoning approach: instead of explicitly producing the chain-of-thought reasoning steps, we use the language model's internal hidden states to perform implicit reasoning. The implicit reasoning steps are distilled from a teacher model trained on explicit chain-of-thought reasoning, and instead of doing reasoning "horizontally" by producing intermediate words one by one, we distill it such that the reasoning happens "vertically" among the hidden states in different layers. We conduct experiments on a multi-digit multiplication task and a grade school math problem dataset and find that this approach enables solving tasks previously not solvable without explicit chain-of-thought, at a speed comparable to no chain-of-thought.
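The sketch below is not the authors' implementation; it is a minimal illustration of the "vertical" distillation idea described in the abstract, assuming GPT-2 as both teacher and student, a single fixed teacher layer, and a simple MSE matching loss. The prompt, chain-of-thought string, and subsampling scheme are all illustrative assumptions.

```python
# Minimal sketch of vertical distillation: the student's per-layer hidden states at the
# final input position are trained to match teacher hidden states read off "horizontally"
# across the explicit chain-of-thought tokens. Model choice, layer choice, and the
# alignment scheme here are assumptions, not the paper's exact recipe.
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
teacher = GPT2LMHeadModel.from_pretrained("gpt2").eval()   # stand-in for a model trained on explicit CoT
student = GPT2LMHeadModel.from_pretrained("gpt2")          # will learn to reason implicitly

prompt = "12 * 34 ="
cot = "12 * 34 = 12 * 30 + 12 * 4 = 360 + 48 = 408"        # explicit reasoning trace (illustrative)

with torch.no_grad():
    t_ids = tokenizer(cot, return_tensors="pt").input_ids
    t_out = teacher(t_ids, output_hidden_states=True)
    # Teacher states gathered across CoT token positions, from the last layer
    # (a simplifying assumption; the paper distills from the teacher's states more generally).
    t_states = t_out.hidden_states[-1][0]                   # (cot_len, hidden)

s_ids = tokenizer(prompt, return_tensors="pt").input_ids
s_out = student(s_ids, output_hidden_states=True)
# Student states stacked "vertically": one vector per layer at the last prompt position.
s_states = torch.stack([h[0, -1] for h in s_out.hidden_states[1:]])  # (n_layers, hidden)

# Align the two sequences by evenly subsampling teacher positions down to the student's depth.
idx = torch.linspace(0, t_states.size(0) - 1, s_states.size(0)).long()
distill_loss = nn.functional.mse_loss(s_states, t_states[idx])
distill_loss.backward()   # in training, this would be combined with the answer-prediction loss
```

In this toy setup, the student never emits the intermediate tokens at inference time; the distillation loss pushes the layer-wise computation over the prompt to carry the information the teacher expressed as explicit chain-of-thought text, which is what allows the answer to be produced at roughly no-chain-of-thought speed.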