Research shows that similarity-based weighted interleaved learning can effectively deal with the 'amnesia' problem in deep learning
Unlike humans, artificial neural networks quickly forget previously learned information when learning new things and must be retrained by interleaving old and new information; however, interleaving all old information is very time-consuming and may not be necessary. It may be sufficient to interleave only old information that is substantially similar to the new information.
Recently, the Proceedings of the National Academy of Sciences (PNAS) published a paper, "Learning in deep neural networks and brains with similarity-weighted interleaved learning", by the team of renowned neuroscientist Bruce McNaughton, a Fellow of the Royal Society of Canada. Their work found that when old information is interleaved with new information in proportion to its similarity to the new information, deep networks can learn new things quickly, not only reducing forgetting but also using significantly less data.
The authors also hypothesize that the brain could implement similarity-weighted interleaving by combining the ongoing excitability trajectories of recently active neurons with neural attractor dynamics. These findings may lead to further advances in both neuroscience and machine learning.
Understanding how the brain learns throughout life remains a long-term challenge.
In artificial neural networks (ANN), integrating new information too quickly can produce catastrophic interference, where previously acquired knowledge is suddenly lost. Complementary Learning Systems Theory (CLST) suggests that new memories can be gradually integrated into the neocortex by interleaving them with existing knowledge.
CLST states that the brain relies on complementary learning systems: the hippocampus (HC) for the rapid acquisition of new memories, and the neocortex (NC) for the gradual integration of new data into context-independent structured knowledge. During "offline periods," such as sleep and quiet waking rest, the HC triggers replay of recent experiences in the NC, while the NC spontaneously retrieves and interleaves representations of existing categories. Interleaved replay allows NC synaptic weights to be adjusted incrementally, in a gradient-descent-like manner, to create context-independent category representations that integrate new memories and overcome catastrophic interference. Many studies have successfully used interleaved replay to achieve lifelong learning in neural networks.
However, two important issues must be resolved when applying CLST in practice. First, how can comprehensive interleaving occur when the brain cannot access all old data? One possible solution is "pseudo-rehearsal," in which random inputs elicit generative replay of internal representations without explicit access to previously learned examples. Attractor-like dynamics may allow the brain to perform pseudo-rehearsal, but what is actually replayed during pseudo-rehearsal has not yet been clarified. Second, after each new learning episode, does the brain have enough time to interleave all previously learned information?
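To make the idea of pseudo-rehearsal concrete, here is a minimal Python sketch. The "network" is a hypothetical fixed linear map standing in for a trained model, and the helper names (`forward`, `make_pseudo_items`, `interleave`) are illustrative assumptions, not part of the paper. Random probe inputs are passed through the network, its own responses are stored as input/target pairs, and these pseudo-items are shuffled together with new items, without ever touching the original training data.

```python
import random

def forward(weights, x):
    """Hypothetical stand-in for a trained network: a single linear layer."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def make_pseudo_items(weights, n_items, n_inputs, rng):
    """Pseudo-rehearsal: probe the network with random inputs and keep its
    own outputs as targets, so no stored old examples are needed."""
    items = []
    for _ in range(n_items):
        x = [rng.uniform(-1.0, 1.0) for _ in range(n_inputs)]
        items.append((x, forward(weights, x)))
    return items

def interleave(new_items, pseudo_items, rng):
    """Mix new items with pseudo-items into one shuffled training set."""
    mixed = list(new_items) + list(pseudo_items)
    rng.shuffle(mixed)
    return mixed

rng = random.Random(0)
weights = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.4]]  # toy values: 2 outputs, 3 inputs
pseudo = make_pseudo_items(weights, n_items=5, n_inputs=3, rng=rng)
new = [([1.0, 0.0, 0.0], [1.0, 0.0])]            # one toy "new" item
training_set = interleave(new, pseudo, rng)
```

Training on `training_set` would push the network toward the new item while the pseudo-items anchor its previous input/output mapping.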
The Similarity-Weighted Interleaved Learning (SWIL) algorithm has been proposed as a solution to the second problem: interleaving only the old information that has substantial representational similarity to the new information may be sufficient. Empirical behavioral studies show that new information that is highly consistent with existing knowledge can be integrated quickly into NC structured knowledge with little interference, suggesting that the speed at which new information is integrated depends on its consistency with prior knowledge. Inspired by this behavioral result, and by re-examining previously obtained distributions of catastrophic interference across categories, McClelland et al. showed that, in a setting with two hypernym categories (e.g., "fruit", which includes "apple" and "banana"), SWIL can learn new information using 2.5 times less data per epoch while achieving the same performance as training the network on all the data. However, the researchers did not demonstrate similar effects on more complex datasets, raising concerns about the scalability of the algorithm.
Experiments show that deep nonlinear artificial neural networks can learn new information by interleaving only the subsets of old information that share substantial representational similarity with the new information. Using SWIL, the ANN learns new information quickly, with similar accuracy and minimal interference, while presenting only a very small amount of old information in each epoch; that is, with high data efficiency and rapid learning.
SWIL can also be applied in a sequential learning framework. Furthermore, learning a new category can be far more data-efficient when the new category has little similarity to previously learned categories, because much less old information then needs to be presented; this is likely the actual situation in human learning.
Finally, the authors propose a theoretical model of how SWIL might be implemented in the brain, in which neuronal excitability is biased in proportion to its overlap with the new information.
Experiments by McClelland et al. showed that, in a deep linear network with one hidden layer, SWIL can learn a new category as well as Fully Interleaved Learning (FIL), which interleaves all old categories with the new one, while using only 40% of the data.
However, the network was trained on a very simple dataset with only two hypernym categories, which raises questions about the scalability of the algorithm.
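The author team first explored how learning of different categories evolves in a deep linear neural network with one hidden layer on a more complex dataset, Fashion-MNIST. After the "boot" and "bag" categories were removed, the model reached a test accuracy of 87% on the remaining eight categories. The author team then retrained the model to learn the (new) "boot" class under two conditions, each repeated 10 times: focused learning (FoL), in which only images of the new class are presented, and fully interleaved learning (FIL), in which images of all old classes and the new class are presented with equal probability.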
The network was tested on a total of 9000 unseen images (1000 per class, excluding the "bag" class). Training was stopped when the network's performance reached asymptote.
As expected, FoL caused interference with the old categories, whereas FIL overcame it (Figure 1, column 2). The interference FoL causes varies by category, which was part of the original inspiration for SWIL and suggests a hierarchical similarity relationship between the new "boot" category and the old categories. For example, recall of "sneaker" and "sandal" drops faster than that of "trouser" (Figure 1, column 2), possibly because integrating the new "boot" class selectively changes the synaptic weights representing "sneaker" and "sandal", causing more interference.
Figure 1: Performance of the pre-trained network when learning the new "boot" class under two conditions: FoL (top) and FIL (bottom). From left to right: recall for the new "boot" class (olive), recall for the existing classes (different colors), overall accuracy (a high score means low error), and cross-entropy loss (an overall measure of error) on the held-out test dataset, each plotted as a function of the number of epochs.
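The per-class recall curves rest on a simple metric: for each class, the fraction of its test images that the network labels correctly. The minimal Python sketch below uses made-up toy labels (not the paper's data) to show how recall for a similar old class such as "sneaker" drops when some of its images are misread as the new "boot" class, while overall accuracy falls by less.

```python
from collections import Counter

def per_class_recall(y_true, y_pred):
    """Recall per class: correct predictions for a class / test images of that class."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: hits[c] / totals[c] for c in totals}

def overall_accuracy(y_true, y_pred):
    """Fraction of all test images labelled correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Toy labels: one "sneaker" image is misclassified as the new "boot" class.
y_true = ["boot", "boot", "sneaker", "sneaker", "trouser"]
y_pred = ["boot", "boot", "boot", "sneaker", "trouser"]
recall = per_class_recall(y_true, y_pred)    # {'boot': 1.0, 'sneaker': 0.5, 'trouser': 1.0}
accuracy = overall_accuracy(y_true, y_pred)  # 0.8
```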
When FoL is used to learn a new category, classification performance on similar old categories drops significantly.
The relationship between attribute similarity across categories and learning has been explored previously: deep linear networks can quickly acquire new information whose attributes are consistent with existing knowledge, whereas adding a new branch with inconsistent attributes to an existing category hierarchy requires slow, incremental, interleaved learning.
In the current work, the author team computes similarity at the feature level. Briefly, the activation vectors in a target hidden layer (usually the penultimate layer) are averaged for each existing class and for the new class, and the cosine similarity between these class-mean vectors is calculated. Figure 2A shows the similarity matrix for the new "boot" category and the old Fashion-MNIST categories, computed from the penultimate-layer activations of the pre-trained network.
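The similarity measure can be sketched in a few lines of Python. This sketch assumes the hidden-layer activations have already been collected per class (the toy two-unit vectors below are invented for illustration); each class is summarised by its mean activation vector, and every pair of classes is compared by cosine similarity.

```python
import math

def mean_vector(vectors):
    """Element-wise mean of a list of equal-length activation vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def similarity_matrix(activations_by_class):
    """activations_by_class: {class name: [activation vector, ...]} from the target layer.
    Returns sim[a][b], the cosine similarity of the class-mean vectors."""
    means = {c: mean_vector(vs) for c, vs in activations_by_class.items()}
    return {a: {b: cosine(means[a], means[b]) for b in means} for a in means}

# Toy two-unit "penultimate layer" activations (invented values).
acts = {
    "boot":    [[1.0, 0.0], [0.8, 0.2]],
    "sneaker": [[1.0, 0.2], [0.8, 0.0]],
    "trouser": [[0.0, 1.0], [0.1, 0.9]],
}
sim = similarity_matrix(acts)
```

Here "boot" and "sneaker" have identical class means, so their similarity is 1, while "trouser" points in a very different direction. The row `sim["boot"]` is the kind of input a similarity-weighted sampler would consume.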
The similarities between categories agree with our visual perception of the objects. For example, the hierarchical clustering dendrogram (Figure 2B) shows high similarity between the "boot" class and the "sneaker" and "sandal" classes, as well as between "shirt" and "T-shirt". The similarity matrix (Figure 2A) also closely matches the confusion matrix (Figure 2C): the higher the similarity, the more easily two classes are confused. For example, "shirt" is easily confused with "T-shirt", "pullover" and "coat", which shows that the similarity measure predicts the learning dynamics of the neural network.
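The dendrogram in Figure 2B comes from hierarchical clustering of the similarity matrix. The article does not say which linkage was used, so the sketch below assumes average-linkage agglomerative clustering in plain Python, with invented similarity values for four Fashion-MNIST classes. At each step it merges the two clusters with the highest mean pairwise similarity and records the merge.

```python
from itertools import combinations

def average_linkage(classes, sim):
    """Agglomerative clustering (average linkage) on pairwise similarities.
    sim maps (a, b) pairs to similarity; either key order is accepted.
    Returns the merges in order as (cluster_a, cluster_b, mean_similarity)."""
    def s(a, b):
        return sim[(a, b)] if (a, b) in sim else sim[(b, a)]

    clusters = [frozenset([c]) for c in classes]
    merges = []
    while len(clusters) > 1:
        best = None
        for x, y in combinations(clusters, 2):
            avg = sum(s(a, b) for a in x for b in y) / (len(x) * len(y))
            if best is None or avg > best[0]:
                best = (avg, x, y)
        avg, x, y = best
        clusters = [c for c in clusters if c != x and c != y] + [x | y]
        merges.append((sorted(x), sorted(y), avg))
    return merges

# Invented similarities: footwear classes are close, "trouser" is far from all.
sim = {
    ("boot", "sneaker"): 0.90, ("boot", "sandal"): 0.80, ("sneaker", "sandal"): 0.85,
    ("boot", "trouser"): 0.10, ("sneaker", "trouser"): 0.10, ("sandal", "trouser"): 0.10,
}
merges = average_linkage(["boot", "sneaker", "sandal", "trouser"], sim)
# First "boot" + "sneaker", then "sandal" joins them, and "trouser" joins last.
```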
The FoL results in Figure 1 show the same similarity gradient in the recall curves of the old categories: compared with dissimilar old categories (such as "trouser"), FoL quickly forgets the similar old categories ("sneaker" and "sandal") when learning the new "boot" category.
Figure 2: (A) Similarity matrix of the existing classes and the new "boot" class, computed from the penultimate-layer activations of the pre-trained network; diagonal values (self-similarity, plotted in white) are removed. (B) Hierarchical clustering of the similarity matrix in A. (C) Confusion matrix produced by the network after learning the "boot" class with FIL; diagonal values are removed for clarity of scale.
Next, three new conditions, partially interleaved learning (PIL), SWIL and equally weighted interleaved learning (EqWIL), were added to the first two, and the learning dynamics of the new class were studied under each; every condition was repeated 10 times.
The author team used the same test dataset as above (n=9000 images in total), and training in each condition was stopped when the network's performance reached asymptote. With PIL, although less training data is used per epoch, recall of the new "boot" class takes longer to reach asymptote and is lower than with FIL (H=7.27, P
For SWIL, the similarity calculation determines what proportion of the interleaved images each old category contributes, and input images are drawn at random from each old category with these weighted probabilities. Because "sneaker" and "sandal" are the categories most similar to "boot", they make up the largest share of the interleaved images (Figure 3A).
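A minimal Python sketch of this sampling step: each old class receives a share of a fixed per-epoch budget of old images proportional to its similarity to the new class, and that many images are then drawn at random from the class. The split into `n_new` and `n_old`, the clipping of negative similarities to zero and the largest-remainder rounding are assumptions made for this illustration; the article only states that images are drawn with similarity-weighted probabilities. The similarity values and image IDs are invented placeholders.

```python
import random

def swil_counts(similarity, n_old):
    """Split a per-epoch budget of n_old old-class images across old classes
    in proportion to each class's similarity to the new class.
    Assumptions: negative similarities are clipped to 0, and counts are
    rounded with the largest-remainder method so they sum to n_old."""
    weights = {c: max(s, 0.0) for c, s in similarity.items()}
    total = sum(weights.values())
    if total == 0:
        raise ValueError("at least one old class needs positive similarity")
    exact = {c: n_old * w / total for c, w in weights.items()}
    counts = {c: int(x) for c, x in exact.items()}
    leftover = n_old - sum(counts.values())
    by_remainder = sorted(exact, key=lambda c: exact[c] - counts[c], reverse=True)
    for c in by_remainder[:leftover]:
        counts[c] += 1
    return counts

def swil_epoch(images_by_class, new_class, n_new, counts, rng):
    """Build one shuffled epoch: n_new new-class images plus, for each old
    class, counts[c] images drawn at random without replacement."""
    batch = [(x, new_class) for x in rng.sample(images_by_class[new_class], n_new)]
    for c, k in counts.items():
        batch += [(x, c) for x in rng.sample(images_by_class[c], k)]
    rng.shuffle(batch)
    return batch

# Invented similarities of four old Fashion-MNIST classes to the new "boot" class.
sim_to_boot = {"sneaker": 0.8, "sandal": 0.7, "shirt": 0.4, "trouser": 0.1}
counts = swil_counts(sim_to_boot, n_old=100)  # roughly 40 / 35 / 20 / 5
rng = random.Random(0)
images = {c: list(range(200)) for c in ["boot", *sim_to_boot]}  # placeholder image IDs
epoch = swil_epoch(images, "boot", n_new=50, counts=counts, rng=rng)
```

Sampling each image's class with `rng.choices(classes, weights=sims, k=n_old)` instead would match "weighted probabilities" more literally, at the cost of per-epoch counts that fluctuate.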
Based on the dendrogram (Figure 2B), the author team refers to the "sneaker" and "sandal" classes as similar old classes and to the rest as dissimilar old classes. The new class recall (column 1 of Figure 3B and the "New class" column of Table 1), overall accuracy and loss of PIL (H=5.44, P0.05) are comparable to FIL. The learning of the new "boot" class in EqWIL (H=10.99, P
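The author team compares SWIL with FIL using two measures: the memory ratio (images per epoch under FIL divided by images per epoch under SWIL) and the speedup ratio (total images presented until performance stabilizes under FIL, divided by the same quantity under SWIL).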
SWIL learns the new content with much less data (memory ratio = 154.3x, i.e. 54000/350) and faster (speedup ratio = 77.1x, i.e. 54000/(350×2)). Even though far fewer images are presented per epoch, SWIL reaches the same performance by exploiting the hierarchical structure of the model's prior knowledge. SWIL sits between PIL and EqWIL, allowing a new category to be integrated with minimal disruption to existing categories.
Figure 3: (A) The pre-trained network learned the new "boot" category (olive green) until performance stabilized under five conditions: 1) FoL (n=6000 images/epoch in total); 2) FIL (n=54000 images/epoch); 3) PIL (n=350 images/epoch); 4) SWIL (n=350 images/epoch); and 5) EqWIL (n=350 images/epoch). (B) Recall for the new category, the similar old categories ("sneaker" and "sandal") and the dissimilar old categories, overall accuracy across all categories, and cross-entropy loss on the test dataset under FoL (black), FIL (blue), PIL (brown), SWIL (magenta) and EqWIL (gold), plotted against the number of epochs.
Next, to test whether SWIL works in more complex settings, the author team trained a 6-layer nonlinear CNN with a fully connected output layer (Figure 4A) to recognize images from 8 of the CIFAR10 categories (all except "cat" and "car"). They then retrained the model to learn the "cat" class under the five training conditions defined previously (FoL, FIL, PIL, SWIL and EqWIL). Figure 4C shows the distribution of images per category under each condition. SWIL, PIL and EqWIL used 2400 images per epoch in total, while FIL and FoL used 45000 and 5000, respectively. The network was trained separately under each condition until performance stabilized.
They tested the model on a total of 9000 previously unseen images (1000 images/class, excluding the "car" class). Figure 4B shows the similarity matrix computed for CIFAR10. The "cat" class is most similar to "dog" and falls in the same branch as the other animal classes (Figure 4B, left).
Based on the dendrogram (Figure 4B), "truck", "ship" and "plane" are called dissimilar old categories, and the animal categories other than "cat" are called similar old categories. Under FoL, the model learns the new "cat" class but forgets the old classes. As with Fashion-MNIST, interference follows a similarity gradient: "dog" (the class most similar to "cat") is forgotten fastest, while "truck" (the least similar) is forgotten slowest.
As shown in Figure 4D, FIL overcomes catastrophic interference when learning the new "cat" class. Under PIL, the model uses 18.75 times less data per epoch than FIL to learn the new "cat" class, but the recall rate of the "cat" class is higher than that of FIL (H=5.72, P0.05; see Table 2 and Figure 4D). SWIL has a higher recall rate for the new "cat" class than PIL (H=7.89, P
FIL, PIL, SWIL and EqWIL perform comparably on the dissimilar old categories (H=0.6, P>0.05). SWIL integrates the new "cat" class better than PIL and helps overcome the interference observed with EqWIL. Compared with FIL, SWIL learns the new category faster (speedup ratio = 31.25x, i.e. 45000×10/(2400×6)) while using less data (memory ratio = 18.75x). These results demonstrate that SWIL can learn new categories effectively even in nonlinear CNNs and on more realistic datasets.
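The memory and speedup ratios quoted throughout are simple arithmetic on the training budgets. Reading the formulas given in the text, memory ratio = FIL images per epoch / SWIL images per epoch, and speedup ratio = (FIL images per epoch × FIL epochs) / (SWIL images per epoch × SWIL epochs). The small Python helpers below (the names are illustrative) reproduce the CIFAR10 "cat" figures.

```python
def memory_ratio(n_fil, n_swil):
    """Images per epoch under FIL divided by images per epoch under SWIL."""
    return n_fil / n_swil

def speedup_ratio(n_fil, epochs_fil, n_swil, epochs_swil):
    """Total images presented until performance stabilizes: FIL over SWIL."""
    return (n_fil * epochs_fil) / (n_swil * epochs_swil)

# CIFAR10 "cat": FIL uses 45000 images/epoch for 10 epochs, SWIL 2400 for 6.
print(memory_ratio(45000, 2400))          # 18.75
print(speedup_ratio(45000, 10, 2400, 6))  # 31.25
```

The same helpers give 95.45x and 31.8x for the CIFAR100 "train" budgets reported later (45500×6/(1430×2) and 45500/1430).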
Figure 4: (A) The author team used a 6-layer nonlinear CNN with a fully connected output layer to learn 8 categories of the CIFAR10 dataset. (B) Similarity matrix (right) computed from last-convolutional-layer activations after presenting the new "cat" class. Hierarchical clustering of the similarity matrix (left) groups the classes into two hypernym categories, animals (olive green) and vehicles (blue), in a dendrogram. (C) The pre-trained CNN learned the new "cat" class (olive green) until performance stabilized under five conditions: 1) FoL (n=5000 images/epoch in total); 2) FIL (n=45000 images/epoch); 3) PIL (n=2400 images/epoch); 4) SWIL (n=2400 images/epoch); 5) EqWIL (n=2400 images/epoch). Each condition was repeated 10 times. (D) Recall for the new class, the similar old classes (the other animal classes in CIFAR10) and the dissimilar old classes ("plane", "ship" and "truck"), overall accuracy across all classes, and cross-entropy loss on the test dataset under FoL (black), FIL (blue), PIL (brown), SWIL (magenta) and EqWIL (gold), plotted against the number of epochs.
If new content can be added to previously learned categories without major changes to the network, the two are said to be consistent. Under this framework, a new category that interferes with few existing categories (high consistency) should be easier to integrate into the network than one that interferes with many existing categories (low consistency).
To test this prediction, the author team used the pre-trained CNN from the previous section to learn a new "car" category under all five learning conditions described previously. Figure 5A shows the similarity matrix for the "car" category: in the dendrogram, "car" falls under the same node as "truck", "ship" and "plane", indicating that it is more similar to them than to the other existing categories. To confirm this, the author team visualized the activations of the layer used for the similarity calculation with t-SNE (Figure 5B). The "car" class overlapped substantially with the other vehicle classes ("truck", "ship" and "plane"), while the "cat" class overlapped with the other animal classes ("dog", "frog", "horse", "bird" and "deer").
As the author team expected, FoL produced catastrophic interference when learning the "car" category, with greater interference for similar old categories, and FIL overcame this (Figure 5D). PIL, SWIL and EqWIL used a total of n=2000 images per epoch (Figure 5C). With SWIL, the model learned the new "car" category with accuracy similar to FIL (H=0.79, P>0.05) and with minimal interference to existing categories, both similar and dissimilar. As shown in column 2 of Figure 5D, with EqWIL the model learns the new "car" class as well as with SWIL, but interferes more with similar categories such as "truck" (H=53.81, P
Compared with FIL, SWIL learns the new content faster (speedup ratio = 48.75x, 45000×12/(2000×6)) and with lower memory requirements (memory ratio = 22.5x). "car" is learned faster than "cat" (48.75x vs. 31.25x) because it requires interleaving fewer classes ("truck", "ship" and "plane"), whereas "cat" overlaps with more categories ("dog", "frog", "horse", "bird" and "deer"). These simulations show that the amount of old-category data that must be interleaved, and how much the learning of a new category is accelerated, depend on the consistency of the new information with prior knowledge.
Figure 5: (A) Similarity matrix computed from penultimate-layer activations (left) and hierarchical clustering of the similarity matrix after presenting the new "car" category (right). (B) t-SNE visualization of last-convolutional-layer activations after the model learns the new "car" and "cat" categories, respectively. (C) The pre-trained CNN learned the new "car" class (olive green) until performance stabilized under five conditions: 1) FoL (n=5000 images/epoch in total); 2) FIL (n=45000 images/epoch); 3) PIL (n=2000 images/epoch); 4) SWIL (n=2000 images/epoch); 5) EqWIL (n=2000 images/epoch). (D) Recall for the new category, the similar old categories ("plane", "ship" and "truck") and the dissimilar old categories (the animal categories in CIFAR10), overall accuracy across all categories, and cross-entropy loss on the test dataset under FoL (black), FIL (blue), PIL (brown), SWIL (magenta) and EqWIL (gold), plotted against the number of epochs. Each graph shows the mean of 10 replicates; the shaded area is ±1 SEM.
Next, the author team tested whether SWIL can learn new content presented sequentially (a sequential learning framework). They took the CNN from Figure 4, which had learned the CIFAR10 "cat" class (Task 1) under FIL or SWIL and had therefore been trained on 9 CIFAR10 categories, and then trained it under each condition to learn the new "car" class (Task 2). The first column of Figure 6 shows the number of images from each category used when learning "car" under SWIL (n=2500 images/epoch in total). Note that the "cat" class is also interleaved when learning the new "car" class. Because the model performed best under FIL, SWIL was compared only with FIL.
As shown in Figure 6, SWIL predicts new and old categories as well as FIL does (H=14.3, P>0.05). With SWIL, the model learns the new "car" category faster, with a speedup ratio of 45x (50000×20/(2500×8)), and uses 20 times less memory per epoch than FIL. Across the two tasks, SWIL uses far fewer images per epoch than FIL, which presents the entire dataset each epoch: the memory ratios are 18.75x for "cat" and 20x for "car", and the speedup ratios are 31.25x and 45x, respectively, so new categories are still learned quickly. Extending this idea, as the number of learned categories keeps growing, the author team expects the model's learning time and data storage to be reduced exponentially, allowing it to learn new categories ever more efficiently, which may reflect what happens when the human brain actually learns.
The experimental results show that SWIL can integrate multiple new classes in a sequential learning framework, allowing the neural network to keep learning without interference.
Figure 6: The author team trained a 6-layer CNN to learn the new "cat" class (Task 1) and then the "car" class (Task 2) until performance stabilized, under two conditions: 1) FIL: images of all old classes (plotted in different colors) and the new class ("cat"/"car") are presented with equal probability; 2) SWIL: old-class examples are used in proportion to their similarity to the new class ("cat"/"car"); in Task 2, the "cat" class learned in Task 1 is also included and weighted by its similarity to "car". The first subfigure shows the number of images of each class used per epoch. The remaining subfigures show, for FIL (blue) and SWIL (magenta), recall for the new category, the similar old categories and the dissimilar old categories, overall accuracy across all categories, and cross-entropy loss on the test dataset, plotted against the number of epochs.
Finally, the author team tested how well SWIL generalizes: whether it can learn datasets with more categories, and whether it works with more complex network architectures.
They trained a deeper CNN, VGG19 (19 layers), on 90 categories of the CIFAR100 dataset (500 training and 100 test images per category), then retrained the network to learn a new category. Figure 7A shows the similarity matrix computed from penultimate-layer activations on CIFAR100. As shown in Figure 7B, the new "train" class is very similar to many existing vehicle classes, such as "bus", "streetcar" and "tractor".
Compared with FIL, SWIL learns the new category faster (speedup ratio = 95.45x, 45500×6/(1430×2)) and with far less data (memory ratio = 31.8x), while performing essentially the same (H=8.21, P>0.05). As shown in Figure 7C, under the conditions of PIL (H=10.34, P
In addition, to explore whether large distances between the representations of different categories are a key condition for accelerating learning, the author team trained two additional models, a 6-layer CNN and VGG11, on the same task.
As shown in Figure 7B, in these two models the new "train" class overlaps more with the vehicle classes, and the categories are less well separated than in VGG19. The speedup of SWIL over FIL grows roughly linearly with the number of layers (slope = 0.84). This suggests that increasing the representational distance between categories can accelerate learning and reduce memory load.
Figure 7: (A) Similarity matrix computed from penultimate-layer activations after VGG19 learns the new "train" class. The five categories most similar to "train" are "truck", "streetcar", "bus", "house" and "tractor". Diagonal elements (similarity = 1) are excluded. (B, left) t-SNE visualization of penultimate-layer activations for the 6-layer CNN, VGG11 and VGG19. (B, right) Speedup ratio (FIL/SWIL, vertical axis) against the number of layers of each of the three networks relative to the 6-layer CNN (horizontal axis). The black dotted line, red dotted line and solid blue line show the reference line with slope 1, the best-fit line and the simulation results, respectively. (C) Learning in VGG19: recall for the new "train" class, the similar old classes (vehicle classes) and the dissimilar old classes (all non-vehicle classes), overall accuracy across all classes, and cross-entropy loss on the test dataset under FoL (black), FIL (blue), PIL (brown), SWIL (magenta) and EqWIL (gold), plotted against the number of epochs. Each graph shows the mean of 10 replicates; the shaded area is ±1 SEM. (D) From left to right, recall of the Fashion-MNIST "boot" class (Fig. 3), the CIFAR10 "cat" class (Fig. 4), the CIFAR10 "car" class (Fig. 5) and the CIFAR100 "train" class as a function of the total number of images (log scale) used by SWIL (magenta) and FIL (blue). "N" is the total number of images (new and old categories) used per epoch under each learning condition.
Would learning speed up further if the network were trained on more non-overlapping classes, so that the distances between representations were larger?
To find out, the author team took the deep linear network used for Fashion-MNIST in Figures 1-3, trained it on a combined dataset of 8 Fashion-MNIST categories (excluding "bag" and "boot") and the 10 Digit-MNIST categories, and then trained it to learn the new "boot" category.
As the author team expected, "boot" is most similar to the old categories "sandal" and "sneaker", followed by the rest of the Fashion-MNIST categories (mostly clothing images) and finally the Digit-MNIST categories (handwritten digit images).
Accordingly, the author team weighted the interleaved old-category samples by similarity, drawing most heavily from the similar old categories, then from the other Fashion-MNIST categories, and least from the Digit-MNIST categories (n=350 images/epoch in total). The results show that SWIL, like FIL, learns the new category quickly without interference, but with a much smaller data subset: a memory ratio of 325.7x (114000/350) and a speedup ratio of 162.85x (228000/1400). Compared with the Fashion-MNIST-only results, the author team observed a 2.1x (162.85/77.1) increase in speedup for a 2.25x (18/8) increase in the number of categories.
The experiments in this section show that SWIL can be applied to more complex datasets (CIFAR100) and neural network models (VGG19), demonstrating that the algorithm generalizes. They also suggest that widening the distance between category representations, or increasing the number of non-overlapping categories, may further increase learning speed and reduce memory load.
Artificial neural networks face significant challenges in continual learning and often exhibit catastrophic interference. To overcome this, many studies use fully interleaved learning (FIL), in which new and old content are interleaved to train the network jointly. FIL requires interleaving all existing information every time new information is learned, which is biologically implausible and time-consuming. Recent research has suggested that FIL may not be necessary: interleaving only the old content that has substantial representational similarity with the new content, i.e. similarity-weighted interleaved learning (SWIL), can achieve the same learning effect. However, concerns have been raised about SWIL's scalability.
This paper extends the SWIL algorithm and tests it on different datasets (Fashion-MNIST, CIFAR10 and CIFAR100) and neural network models (deep linear networks and CNNs). Across all conditions, similarity-weighted interleaved learning (SWIL) and equally weighted interleaved learning (EqWIL) learned new categories better than partially interleaved learning (PIL). This is what the author team expected, because SWIL and EqWIL increase the relative frequency of the new category compared with the old categories.
The paper also shows that carefully selecting and interleaving similar content reduces catastrophic interference with similar old categories compared with subsampling the existing categories equally (the EqWIL method). SWIL performs similarly to FIL in predicting new and existing categories, but learns new content significantly faster (Figure 7D) while greatly reducing the training data required. SWIL can also learn new categories in a sequential learning framework, further demonstrating its generality.
Finally, a new category that overlaps little with previously learned categories (that is, lies at a larger distance from them) can be integrated faster and with greater data efficiency than one that shares similarities with many old categories. Overall, the experimental results offer a possible account of how the brain overcomes one of the major weaknesses of the original CLST model: its unrealistically long training time.