A 67-million-parameter model comparable to the trillion-parameter behemoth GPT-4! Microsoft, MIT, and others team up to crack the code of Transformer causal reasoning

Deep Learning and NLP
2024/07/15 12:40

Source | Xin Zhiyuan (ID: AI-era)

"Causal reasoning" is definitely a niche field amid the current GenAI craze, but it has a heavyweight backer: Yann LeCun.

One of his regular activities on Twitter is blasting generative models like Sora while flying the flag for causal reasoning, the field he believes in.

As early as 2019, in an interview with VentureBeat, he argued that we need to build the causality of events into deep learning models to improve generalization and reduce the amount of training data they require.

Can we teach Transformer, the most popular model architecture at the moment, causal reasoning?

Recently, researchers from Microsoft, MIT, and other institutions have proposed a new paradigm for training large models: the axiomatic framework.

In the paper, the authors trained a 67-million-parameter model from scratch, using only simple causal chains as training data.

Surprisingly, when it comes to inferring causality in complex graphs, the 67M model outperforms billion-level parameter LLMs and is even comparable to GPT-4.

Paper address: https://arxiv.org/abs/2407.07612v1

The new approach from the Microsoft and MIT teams was inspired by Turing Award winner Judea Pearl.

Building on the axioms of causal irrelevance that Pearl formulated for structural causal models, the idea is to teach Transformer models directly from passive symbolic demonstrations of those axioms.

Unlike traditional machine learning setups, the training data here is derived from the axioms themselves.

As the results show, the study demonstrates that with axiomatic training a Transformer model can learn causal axioms, and thereby infer causal relationships and distinguish causation from correlation.

This implies that large models such as GPT-4 could learn causality from noisy axiom demonstrations found in web data, without interventional experiments.

Netizens praised, "The researchers' views are very intriguing, causal reasoning has always been the Achilles' heel of LLMs, and it is imperative to further develop this field."

"This type of research could be a path to semi-AGI."

Background

Causal reasoning is a reasoning process that adheres to predefined axioms or rules about specific causal relationships.

Turing Award winner Judea Pearl has defined the possible types of causal reasoning through the following "ladder of causation."

Usually, the axioms or rules are not given to the model directly; the model learns only from data, with the axioms incorporated as inductive biases through, for example, regularization, model architecture, or variable selection.

What this paper explores is whether models can learn axioms or rules directly from passive symbolic demonstrations. The authors call this method "axiomatic training".

Causal axioms can all be expressed in the form ⟨premise, hypothesis, conclusion⟩, where the conclusion takes only two values, "Yes" and "No".

This is broadly similar to the format of Aristotle's syllogism. For example, the "collider axiom" proposed by Judea Pearl can be expressed as:

Premise: A ⟂⟂ B, A ⟂̸⟂ C, B ⟂̸⟂ C

Hypothesis: Does A lead to C?

Conclusion: Yes
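In code, one such axiom instance can be represented as a plain ⟨premise, hypothesis, conclusion⟩ record. A minimal sketch (the class and field names are illustrative, not from the paper):

```python
from dataclasses import dataclass

@dataclass
class AxiomInstance:
    premise: str      # symbolic statements assumed to hold
    hypothesis: str   # the causal question being asked
    conclusion: str   # "Yes" or "No"

# The collider example above: A and B are independent,
# while A and C, and B and C, are each dependent.
collider = AxiomInstance(
    premise="A is independent of B; A and C are dependent; B and C are dependent",
    hypothesis="Does A cause C?",
    conclusion="Yes",
)
```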

This is only the representation of a single axiom. How, then, do we express combinations of multiple axioms in a complex system? Going further, can an arbitrary causal model be expressed with a finite set of axioms?

Here, the paper cites a study published by Judea Pearl and David Galles in 1997, which demonstrated that for a given stable probabilistic causal model, there is a finite set of axioms that adequately characterize the corresponding directed causal graph.

The causal model M = (X, U, F) is defined by a set of endogenous variables X, a set of exogenous variables U, and a set of structural equations F that describe the causal relationships among the variables in X and U.

An equivalent representation of model M is a directed graph G, in which a directed edge Vi ⭢ Vj represents a causal relationship from node Vi to node Vj.

The term "stable probabilistic" refers to a stability assumption placed on the model: every causal irrelevance in M, written (X ↛ Y | Z), is stable.

Under the stability assumption, Galles and Pearl describe six axioms in total; this paper focuses on the transitivity axiom. For a stable probabilistic causal model, given variables X, Y, Z in the system, transitivity can be written:

(X ↛ Y | Z) ⟹ (X ↛ A | Z) ∨ (A ↛ Y | Z), for any other variable A in the system

Taking the contrapositive yields a version stated in terms of causal relevance:

(X ⇝ A | Z) ∧ (A ⇝ Y | Z) ⟹ (X ⇝ Y | Z)

The left side of this expression is the premise, and the right side is the hypothesis.

From such an axiom, thousands of synthetic symbolic expressions can be derived and used to "teach" the axiom to Transformer models.

Axiomatic training

Each premise-hypothesis pair generated from the axiom above maps to a "Yes" or "No" label, so a training instance can be expressed as a tuple (P, H, L).

Given a ground-truth causal graph, a dataset D of N tuples {(P, H, L)} can be constructed by applying the transitivity axiom (one or more times) to enumerate all possibilities.

For example, if the causal graph contains X1 ⭢ X2 ⭢ X3 ⭢ ... ⭢ Xn, one possible premise is X1⭢X2 ∧ X2⭢X3; the corresponding hypothesis X1⭢X3 is labeled "Yes", while the hypothesis X3⭢X1 is labeled "No".

It is worth noting that mathematical notation is used in the paper for clarity, but the actual training dataset contains only natural language.

For example, the premise in the above example should be expressed as "X1 causes X2, and X2 causes X3".
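The enumeration described above can be sketched in code as follows; the function name and exact phrasing are assumptions, not the authors' implementation:

```python
def chain_tuples(nodes):
    """Enumerate (premise, hypothesis, label) training tuples for the
    sequential causal chain nodes[0] -> nodes[1] -> ... via transitivity."""
    # The premise states every direct edge in natural language.
    premise = ", and ".join(
        f"{a} causes {b}" for a, b in zip(nodes, nodes[1:])
    )
    data = []
    for i, a in enumerate(nodes):
        for j, b in enumerate(nodes):
            if i == j:
                continue
            # On a forward chain, a causal path from a to b exists iff i < j.
            label = "Yes" if i < j else "No"
            data.append((premise, f"Does {a} cause {b}?", label))
    return data

tuples = chain_tuples(["X1", "X2", "X3"])
# premise: "X1 causes X2, and X2 causes X3"
```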

Previous studies have shown that increasing the variability and diversity of training data in the form of "perturbation" can help improve the generalization ability of the model.

Therefore, the authors introduce structured perturbations to the training data at different levels to maximize the diversity of the dataset distribution.

1) Node names: the name of each node on the chain consists of 1~3 alphanumeric characters; both the length and the specific characters are randomly generated.

2) Causal graph topology: there are two main types:

- Sequential: all causal edges point in the same forward direction, forming a typical transitive chain, such as X⭢Y⭢Z

- Random flipping: starting from a sequential chain, some edges are randomly flipped to introduce complexity. For example, X⭢Y⭢Z can become X⭢Y⭠Z.

Random flipping adds fork structures (X⭠Y⭢Z) and collider structures (X⭢Y⭠Z) to an otherwise unidirectional chain; these are the basic building blocks of any directed causal graph, helping the model generalize across structures.

3) Chain length: chains of different lengths, from 3 to 6 nodes, are included in the training set.
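The three perturbation levels above might be implemented as follows (a sketch under assumed defaults, e.g. a 50% flip probability; none of these names come from the paper):

```python
import random
import string

def random_name():
    """Node name: 1-3 random alphanumeric characters."""
    k = random.randint(1, 3)
    return "".join(random.choices(string.ascii_letters + string.digits, k=k))

def random_chain(flip_prob=0.5):
    """A chain of 3-6 randomly named nodes; each edge is either kept in
    the forward direction or randomly flipped."""
    n = random.randint(3, 6)
    nodes = [random_name() for _ in range(n)]
    edges = []
    for a, b in zip(nodes, nodes[1:]):
        if random.random() < flip_prob:
            edges.append((b, a))  # flipped edge: may create forks/colliders
        else:
            edges.append((a, b))  # forward edge of the sequential chain
    return nodes, edges

nodes, edges = random_chain()
```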

Instead of the next-token prediction loss commonly used to train Transformer models, the paper defines the loss on the ground-truth label of each tuple in the dataset, i.e., the negative log-probability of the correct label given the premise and hypothesis: loss = −Σᵢ log P(Lᵢ | Pᵢ, Hᵢ).
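Conceptually this is just the per-example negative log-likelihood of the label. A minimal sketch in plain Python (illustrative only, not the authors' implementation):

```python
import math

def label_loss(prob_yes, label):
    """Negative log-likelihood of the ground-truth label, where
    prob_yes is the model's predicted probability of answering 'Yes'."""
    p = prob_yes if label == "Yes" else 1.0 - prob_yes
    return -math.log(p)

def dataset_loss(predictions, labels):
    """Average label loss over a dataset of (P, H, L) tuples, given one
    predicted 'Yes' probability per (premise, hypothesis) pair."""
    losses = [label_loss(p, l) for p, l in zip(predictions, labels)]
    return sum(losses) / len(losses)
```

In practice this would be computed as a cross-entropy over the model's logits at the answer position.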

In addition to the training data and the loss function, another important factor is the choice of positional encoding.

Previous studies have shown that the positional encoding mechanism has a significant impact on the length-generalization ability of Transformers, but different studies seem to have reached conflicting conclusions.

Therefore, the authors tried different methods in their study, including learnable positional encoding (LPE), sinusoidal positional encoding (SPE), and no positional encoding (NoPE).

The overall training and evaluation process is shown in Figure 1. The Transformer model is trained on sequential chains and chains with random flips, each 3~6 nodes long.

The trained model is then evaluated on more complex structures with more than 6 nodes, where nodes have larger average in- and out-degrees, sequences are longer, and complex variations such as branching and reversal are introduced.

Implementation details: architecture, tokenizer, and training process

Specifically, the researchers trained a decoder model with 67 million parameters based on GPT-2's architecture.

The model has 12 attention layers, 8 attention heads, and 512 embedding dimensions.

It is worth mentioning that the 67M model was trained from scratch on each of the various training datasets. To understand the impact of positional encoding (PE), the researchers considered three cases: sinusoidal positional encoding (SPE), learnable positional encoding (LPE), and no positional encoding (NoPE).

All models were trained using the AdamW optimizer with a learning rate of 1e-4 for 100 epochs.

Since the training dataset follows a specific structure, the researchers also developed a custom tokenizer.

Alphanumeric node names are tokenized at the character level, while special terms like "causes", "cause", "Does", "Yes", "No" are tokenized at the word level.


This approach avoids out-of-vocabulary (OOV) tokens during testing, as the alphanumeric node names in the test set may differ from those in the training set.
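A tokenizer with this two-level scheme can be sketched as follows (the token handling and helper are illustrative assumptions, not the authors' code):

```python
# Special terms tokenized at the word level, per the scheme described above.
SPECIAL = {"causes", "cause", "Does", "Yes", "No"}

def tokenize(text):
    """Word-level tokens for special terms and punctuation;
    character-level tokens for alphanumeric node names."""
    tokens = []
    for word in text.replace("?", " ?").replace(",", " ,").split():
        if word in SPECIAL or word in {"?", ","}:
            tokens.append(word)   # keep keyword/punctuation as one token
        else:
            tokens.extend(word)   # split node name into characters
    return tokens

print(tokenize("Does X1 cause X2?"))
# → ['Does', 'X', '1', 'cause', 'X', '2', '?']
```

Because node names decompose into single characters, any new alphanumeric name seen at test time maps onto the same small vocabulary.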

With this approach, the 67 million parameter Transformer model has a vocabulary size of 69.

Experimental results

The researchers first examined how the axiomatically trained Transformer generalizes to larger, more complex causal graphs, comparing it with pre-trained large models.

Sequence length generalization

Table 1 shows the accuracy of different models in evaluating longer causal chains that were not seen during training.

Among the baseline pre-trained language models, GPT-4 achieved the highest accuracy rate on both standard and randomly flipped causal chains.

Surprisingly, although the TS2 (NoPE) model has never seen longer sequences during training, its performance is comparable to that of the trillion-parameter GPT-4 model.

Although only causal chains of 3~6 nodes were used in training, on sequences of length 7~13 TS2 (NoPE) achieved accuracy higher than or equal to GPT-4 on both standard and randomly flipped chains.

For sequence lengths of 14~15, accuracy drops (0.85 on standard chains and 0.78 on randomly flipped chains) but remains significantly higher than that of the Gemini Pro and Phi-3 models.

Note that random prediction yields 50% accuracy; this suggests that the axiomatically trained TS2 (NoPE) model can generalize its reasoning ability to longer sequences.

Node name transformation

For the model trained on the TS2 dataset, the researchers also evaluated its ability to generalize to changes in variable names (Figure 3).

The results show that TS2 (NoPE) is robust to node-name changes, maintaining high accuracy even when new, longer names are introduced. It also generalizes to longer sequences with new node names, behaving similarly to GPT-4.

Causal sequence order

Unlike changes in length and node names, reversal and branching operations change the causal structure itself, making them a better test of whether the model has learned an accurate representation of that structure.

In Table 2b, TS2 (NoPE) has higher accuracy than Gemini Pro and Phi-3 on causal chains of length up to 8. At length 9, TS2 (NoPE) reaches 0.73 accuracy, comparable to Gemini Pro's 0.74.

In Table 2a, the researchers observed a similar pattern in the evaluation on fully reversed sequences.

In this task, the axiomatically trained TS2 (NoPE) outperformed GPT-4 for chain lengths of 3~6. In particular, its accuracy on chains of length 6 (0.94) is significantly higher than Gemini Pro's (0.62) and Phi-3's (0.69).

Branching

Branching may be the most challenging task, as it introduces new structures not seen during training.

While GPT-4 achieved the best accuracy as graph size increased, the TS2 (NoPE) model achieved higher accuracy than Gemini Pro at all graph sizes except one.

Even when evaluated on graphs with 12 nodes and a branching factor of 1.4, the TS2 (NoPE) model achieves 70% accuracy, significantly better than the random baseline (50%).

Summary

In all evaluation settings, the axiomatically trained TS2 (NoPE) model performed significantly better than the random baseline, even when the causal chains were longer than anything in its training data.

In particular, although the model was never trained on fully reversed chains, it performed on par with the much larger GPT-4 model (Figure 2).

In other tasks, its accuracy is often better than or comparable to billion-parameter scale models such as Gemini Pro and Phi-3.

These results suggest that models trained on axioms can learn to reason about more complex causal structures from the demonstration of simple causal sequences. This shows the potential of axiom training in causal graph reasoning.

The role of positional encoding

Comparing performance across the positional-encoding choices, the researchers found that the model without positional encoding generalized well both to longer sequences (chains of up to 15 nodes) and to complex unseen graph structures, despite being trained only on chains of 3~6 nodes.

Models using sinusoidal positional encoding (SPE) and learnable positional encoding (LPE) also perform well on longer chains, but degrade when node-name length increases, even on chains with fewer nodes (Figure 3).

This failure of SPE and LPE to generalize highlights the models' inability to handle even minor perturbations of the training sequences.

In addition, SPE performs poorly both across structural dimensions (such as branching) and in order-based settings (shuffling and reversal).

Learnable positional encoding performed well on linear chains up to length 9, but accuracy then dropped dramatically.

Overall, the findings extend earlier research on the effectiveness of omitting positional encoding (NoPE) to the task of understanding causal sequences, and to generalizing to longer lengths and complex structures at test time.

The importance of data perturbations

In addition to positional coding, the diversity of sequences in the training data also plays an important role.

A model trained only on sequential causal chains can generalize to longer chains (Table 1), but not to other DAG structures (see Flip in Figure 4, Invert in Figure 2, Branching in Table 3).

Models trained on TS1 or TS2 generalize in all settings, including random flipping, order permutation, and branching, highlighting the benefit of the edge-level variability introduced by random flipping.

However, across tasks the study found that TS2 achieved higher accuracy than TS1, even though TS1 contains more variation from random flipping.

This suggests that while perturbations aid structural generalization, excessive perturbation may hinder it.

Using axiomatic training to infer causation from correlation

Next, the authors investigate whether this ability can be transferred to other causal tasks.

To do this, the researchers applied axiomatic training to a task that inferred causality from correlation statements in observational data.

As shown in Figure 5, each data instance consists of correlational statements, expressed in natural language, about a graph of 3 to 6 nodes. The goal is to judge the truth of a hypothesis: whether a given pair of nodes has a direct or indirect causal relationship, and whether there are colliders or confounders.

This task is much more difficult than applying the axioms of transitivity.

Due to the task's complexity, pre-trained models like Gemini Pro and Phi-3 were found to perform no better than random guessing (52% accuracy).

While GPT-4 performs slightly better, its accuracy is still low (58%).

Notably, the researchers' small Transformer outperformed all baseline models with 64% accuracy, 6 percentage points higher than GPT-4.

With further exploration of different training settings, the axiomatically trained Transformer could likely be optimized further for this kind of causal inference task.

Overall, the researchers believe that axiomatic training is a promising way to teach Transformer models to learn causality.

Inspired by Judea Pearl's vision, this work represents a potential new frontier in science – at the intersection of causality research and language models.

Resources:

https://arxiv.org/abs/2407.07612v1

https://x.com/AniketVashisht8/status/1811752011399877014
