Nanonets · November 26, 2024
What is Test-Time Training?

Test-Time Training is a relatively novel technique that fine-tunes a model during the inference stage. It has reached a new level of accuracy on the ARC dataset and applies to many tasks and scenarios; this article explores how it works, where it applies, and its pros and cons.

Test-Time Training is learning at inference time, analogous to a human's task-specific learning phase: it involves cloning the model and fine-tuning it on points similar to the test data.

It has applications in linear regression, LLMs, and other domains: for example, updating a linear regression model's slope for more accurate predictions, or letting an LLM adapt to a new environment.

Performing TTT involves several practical considerations, such as parameter-efficient fine-tuning, transfer learning, embedding reuse, and test-time augmentation.

TTT offers advantages such as hyper-specialization and data efficiency, but it also carries drawbacks, including high computational cost and a poor fit for real-time applications.

Introduction

Back-propagation has been the engine driving the deep learning revolution, and successive advancements have carried us a long way.

However, the Abstraction and Reasoning Corpus (ARC) dataset, created over five years ago, has withstood wave after wave of new architectures without budging. It has remained one of the toughest benchmarks, one where even the best models could not beat human-level accuracy, an indication that true AGI is still far from our grasp.

Last week, a new paper, "The Surprising Effectiveness of Test-Time Training for Abstract Reasoning," pushed a relatively novel technique forward, reaching a new state-of-the-art accuracy on the ARC dataset and exciting the deep learning community much as AlexNet did 12 years ago.

TTT was invented five years ago. The idea is to train, at inference time, on very few samples, usually one or two, that resemble the test data point. The model is allowed to update its parameters based on these examples, hyper-adapting it to only those data points.

TTT is analogous to transforming a general physician into a surgeon who is now super specialized in only heart valve replacements.

In this post, we'll learn what TTT is, how we can apply it in various tasks, and discuss the advantages, disadvantages, and implications of using TTT in real-world scenarios.

What is Test Time Training?

Humans are highly adaptable. They follow two learning phases for any task—a general learning phase that starts from birth, and a task-specific learning phase, often known as task orientation. Similarly, TTT complements pre-training and fine-tuning as a second phase of learning that occurs during inference.

Simply put, Test-Time Training involves cloning a trained model during the testing phase and fine-tuning it on data points similar to the one you want to infer on. To break the process into steps, during inference, given a new test data point, we:

1. Clone the (general-purpose) model.
2. Gather the data points from the training set that are closest to the test point, either via prior knowledge or embedding similarity (see the sketch below).
3. Build a smaller training dataset with inputs and targets from the step above.
4. Decide on a loss function and train the cloned model on this small dataset.
5. Use the updated clone to predict on the test data point.
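For the neighbor-gathering step, one simple option is cosine similarity over precomputed embeddings. Below is a minimal Python sketch; the function name, array shapes, and choice of k are illustrative assumptions, not from the paper:

import numpy as np

def gather_nearest_examples(test_embedding, train_embeddings, train_data, k=5):
    # Rank training points by cosine similarity to the test embedding (step 2).
    test_unit = test_embedding / np.linalg.norm(test_embedding)
    train_units = train_embeddings / np.linalg.norm(train_embeddings, axis=1, keepdims=True)
    similarities = train_units @ test_unit
    top_k = np.argsort(similarities)[-k:][::-1]
    # Build the small (input, target) dataset for fine-tuning the clone (step 3).
    return [train_data[i] for i in top_k]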
TTT in linear regression

For a simple example, one can take a trained linear regression model, update its slope using a set of points in the vicinity of the test point, and use it to make more accurate predictions.
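Here is a minimal sketch of that idea with scikit-learn; the helper name and the choice of k are assumptions for illustration, and the "clone" is simply a fresh model fit on the local neighborhood:

import numpy as np
from sklearn.linear_model import LinearRegression

def ttt_linear_regression(X_train, y_train, x_test, k=20):
    # Find the k training points nearest to the test point in input space.
    distances = np.linalg.norm(X_train - x_test, axis=1)
    nearest = np.argsort(distances)[:k]
    # "Clone and fine-tune": fit a local model only on the neighborhood.
    local_model = LinearRegression().fit(X_train[nearest], y_train[nearest])
    return local_model.predict(x_test.reshape(1, -1))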

K-Nearest Neighbors is an extreme example of the TTT process, where the only training that happens is at test time.

In the domain of LLMs, TTT is especially useful when tasks are complex and fall outside what the LLM has seen before.

In-Context Learning, few-shot prompting, Chain-of-Thought reasoning, and Retrieval-Augmented Generation have been the standard ways of enhancing LLMs at inference. These techniques enrich the context before arriving at a final answer, but they all fall short in one respect: the model does not adapt to the new environment at test time. With TTT, we can make the model learn new concepts that would otherwise require capturing a vast amount of data.

Neural Network/LLM hyper-specializes during TTT

The ARC dataset is an ideal fit for this paradigm, as each data sample is a collection of few-shot examples followed by a question that can only be solved using those examples, much like SAT puzzles that ask you to find the next diagram in a sequence.

Example of a data point in ARC

As shown in the image above, one can use the first three examples for training at test time and predict on the fourth image.
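In the publicly released ARC format, each task is a JSON object with "train" demonstration pairs and a "test" input, so building the test-time training set is straightforward. A sketch, with the loader name as an assumption:

import json

def build_ttt_dataset(task_path):
    # Each ARC task file holds a few demonstration pairs plus a test input.
    with open(task_path) as f:
        task = json.load(f)
    # The demonstrations become the entire test-time training set for the task.
    train_pairs = [(ex["input"], ex["output"]) for ex in task["train"]]
    test_input = task["test"][0]["input"]
    return train_pairs, test_input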

How to Perform TTT

The brilliance of TTT lies in its simplicity; it extends learning into the test phase. Thus, any standard training techniques are applicable here, but there are practical aspects to consider.

Since training is computationally expensive, TTT adds real overhead: in theory, you need to run a training loop for every inference. To mitigate this cost, consider:

- Parameter-efficient fine-tuning (e.g., LoRA), so only a small set of adapter weights is updated.
- Transfer learning, freezing the backbone and training only a task head.
- Embedding reuse, caching precomputed embeddings so neighbor lookup stays cheap.
- Test-time augmentation, reusing one adapted clone across several views of the input.

The two pseudo-code sketches below illustrate the first two options.

def test_time_train(llm, test_input, nearest_examples, loss_fn, OptimizerClass):
    # Attach LoRA adapters so only a small set of weights is trained.
    lora_adapters = initialize_lora(llm)
    optimizer = OptimizerClass(lora_adapters, learning_rate)
    new_model = merge(llm, lora_adapters)
    # Fine-tune on the nearest training examples.
    for nearest_example_input, nearest_example_target in nearest_examples:
        nearest_example_prediction = new_model(nearest_example_input)
        loss = loss_fn(nearest_example_prediction, nearest_example_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Predict on the test point with the adapted model.
    predictions = new_model(test_input)
    return predictions

Pseudo-code for test time training with LLMs

def test_time_train(base_model, test_input, nearest_examples, loss_fn, OptimizerClass):
    # Clone only the task head; the backbone stays frozen as a feature extractor.
    new_head = clone(base_model.head)
    optimizer = OptimizerClass(new_head, learning_rate)
    for nearest_example_input, nearest_example_target in nearest_examples:
        nearest_example_feature = base_model.backbone(nearest_example_input)
        nearest_example_prediction = new_head(nearest_example_feature)
        loss = loss_fn(nearest_example_prediction, nearest_example_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Predict using frozen features plus the adapted head.
    test_features = base_model.backbone(test_input)
    predictions = new_head(test_features)
    return predictions

Pseudo-code for test time training with conventional transfer learning
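Test-time augmentation, listed above among the practical levers, pairs naturally with TTT: run the adapted clone on several transformed views of the input and take a vote. A sketch, where the (transform, inverse) pairs, for instance grid rotations and flips for ARC, are assumptions:

def predict_with_tta(model, test_input, augmentations):
    # augmentations: list of (transform, inverse_transform) pairs,
    # e.g. rotations/flips of an ARC grid with their inverses.
    votes = {}
    for augment, invert in augmentations:
        # Map each prediction back into the original input's frame.
        prediction = invert(model(augment(test_input)))
        key = repr(prediction)  # canonical key so equal predictions pool votes
        count, _ = votes.get(key, (0, prediction))
        votes[key] = (count + 1, prediction)
    # Return the prediction agreed on by the most augmented views.
    return max(votes.values(), key=lambda v: v[0])[1]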

Real-World Uses

TTT has already proven its value on abstract-reasoning benchmarks such as ARC, and it is a natural fit for adapting LLMs to complex tasks that fall outside their training distribution.

Advantages

- Hyper-specialization: the adapted clone can outperform a general model in the immediate neighborhood of the test point.
- Data efficiency: only a handful of nearby examples, sometimes just one or two, are needed.

Disadvantages

- High computational cost: in principle, a fresh training loop runs for every inference.
- Poor fit for real-time applications: the per-query training step adds significant latency.

TTT is a promising tool, but one that carries significant overhead and risk. Used wisely, it can push model performance in challenging scenarios beyond what conventional methods achieve.
