Introduction
As internet technology enters the Web 2.0 era, the amount of data on the internet is exploding. While people enjoy the convenience brought by massive data, retrieving and acquiring the data they need has become difficult because of its complexity and diversity. Text classification is an important method that helps users organize data efficiently and improves the speed and quality of data retrieval. It is not only an essential branch of data mining but also a classic problem in natural language processing [1], with a wide range of applications such as language translation, text summarization, news recommendation, and spam filtering. Text classification uses natural language processing, data mining, and machine learning to effectively classify different text types and discover rules [2]. In practice, however, data rarely carry labels, and manually labeling data is time-consuming and error-prone. Therefore, it is particularly urgent to study semi-supervised text classification under the condition of a small amount of labeled data.
Text classification extracts semantic features from the original text corpus and predicts the topic categories of text data based on these features. Over the past few decades, various models for text classification have emerged. Early methods relied mainly on classification rules manually discovered and formulated by domain experts. With the emergence of statistical learning and machine learning, text classification based on machine learning made a breakthrough. Common machine learning-based classifiers include support vector machines [3], naive Bayes [4], K-nearest neighbors [5], decision trees [6], and logistic regression [7]. Compared with earlier rule-based methods, text classification methods based on statistical models have apparent advantages in accuracy and stability. However, these methods still require laborious and time-consuming feature engineering. In addition, they typically do not consider natural sequential structure or contextual information, which makes it difficult for such models to learn semantic information.
After the 2010s, text classification methods gradually shifted from statistical methods to deep learning methods. A deep learning model uses an artificial neural network that simulates the way the human brain works to classify data. Compared with methods based on statistical models, deep learning methods avoid manually designed rules and features and can automatically mine a large number of rich semantic representations from text.
Feedforward neural networks and recursive neural networks (RNNs) were the first two deep learning methods applied to text classification tasks. Many researchers have improved text classification performance on different tasks by improving convolutional neural networks (CNNs), RNNs, attention mechanisms, model fusion, and multitask methods. However, text classification methods based on CNNs and RNNs assume that texts are independent of one another. In reality, there are correlation relationships among texts, such as citation networks, social relationship networks, and biomolecular structures, and these correlations play an important role in discriminating text categories. Therefore, researchers have paid more attention to designing more efficient text classification methods by combining graph neural networks with convolutional neural networks and attention mechanisms. Kipf et al. proposed a semi-supervised text classification method [8] based on a graph convolutional network (GCN) for classifying citation datasets. Yao et al. proposed the TextGCN model [9] for classifying text datasets without explicit correlations. Wu et al. proposed the simplified graph convolution model SGC [10], which has also been applied to text classification.
Compared with neural network models such as CNNs and RNNs, the GCN model significantly improves classification accuracy. However, it also has obvious shortcomings. First, the GCN model lacks flexibility, which leads to poor scalability. Second, GCN training is full-batch, which makes it difficult to extend to large-scale networks and slows convergence. In addition, in the GCN model, the weights of adjacent node features depend on the graph's structure and have nothing to do with the features of the nodes themselves, which runs counter to the goal of text classification. In contrast, the graph attention network (GAT) [11] retains the advantages of GCN while making the weights of adjacent node features depend entirely on the node features, independent of the graph's structure. These advantages make GAT more suitable for text classification problems.
To overcome the problems of large-scale text data classification, this paper designs a semi-supervised multilabel text classification model based on a multilayer dual-level graph attention network (DGAT). First, the model allows text features to propagate along correlation relationships and to fuse with the features of associated texts, which improves the expressive ability of text features. Second, we design a dual-level attention model to distinguish the influence of associated nodes on target nodes. Furthermore, we apply network stacking and residual connections to extend the range of information transmission while preserving the learning ability of the model. Given a small amount of labeled text and a text association graph, the classification model can classify unlabeled text. The main contributions of this paper are as follows:
In this paper, we propose a feature fusion framework based on DGAT. By combining multilayer DGAT with a residual network, the framework fuses the features of nodes connected across multiple hops, which improves the expressive ability of node features and reduces the risk of overfitting.
Using the idea of information transmission, we aggregate the features of nodes in the local neighborhood into the target node and differentiate their weights from two aspects: the node's type and its feature values. In addition, we control the range of neighborhood nodes by changing the number of DGAT layers. This effectively improves the model's accuracy and its ability to handle large-scale text data.
The proposed method is implemented and compared with many existing works on five benchmark datasets. Experimental results show that it has clear advantages in classification accuracy.
Problem Definition and Related Work
This section first formulates the problem and then summarizes the relevant work on text classification based on graph attention networks.
A. Problem Definition
Let $G=(V,E)$ denote a text information graph, where each vertex $v_i \in V$ represents a document with feature vector $h_i$ and each edge in $E$ represents a correlation (such as a citation) between two documents. Only a small subset of the documents is labeled with categories from a category set $C$.
For any document whose category is unknown, the semi-supervised text classification problem is to predict its category based on the graph structure, the document features, and the small number of labeled documents.
B. Deep Learning for Text Classification
In recent years, deep learning has made breakthroughs in text, images, and other fields. Compared with traditional machine learning models for text classification, deep learning methods can obtain better semantic representations and reduce the incompleteness and complexity of the feature extraction process. Research on deep learning for text classification can be divided into two categories. One is the word embedding model studied by Mikolov et al., which aggregates unsupervised word embeddings into document embeddings and then feeds these document embeddings into a classifier. The other is to use neural networks directly, such as CNNs [12]. A CNN is a deep neural network with a convolutional structure, a design that reduces the memory requirements of deep text classification methods; related approaches can be roughly divided into four categories: reinforcement learning, ensemble learning, transfer learning, and deep learning. This section summarizes the characteristics of the method proposed in this article based on an analysis of related work.
In terms of applying convolutional neural networks to text classification tasks, Kim et al. proposed a CNN-based model. This model trains a neural network on labeled data to predict data categories and consists of a word vector layer, a convolutional layer, a top (max) pooling layer, a fully connected layer, and a softmax layer. A convolutional neural network has many hyperparameters to choose from; Gori et al. [13] conducted comparative experiments under different hyperparameter settings for this model and gave practical suggestions for parameter tuning. Many researchers have studied and improved the model proposed by Kim et al. in terms of structure, depth, and training speed. For example, Wang et al. [14] performed semantic feature extraction on the word vector matrix before the convolutional layer and then performed classification. Kalchbrenner et al. [15] proposed a dynamic convolutional neural network (DCNN), which uses a multilayer convolutional structure and can handle variable-length inputs. However, these methods all use shallow networks; deep neural networks were applied to text classification only later. For example, the deep pyramid convolutional neural network (DPCNN) [16], proposed by Tencent AI Lab at ACL 2017, captures long-distance dependencies in text by deepening the network.
Although these methods are effective, they mainly focus on local word sequences and ignore the global word co-occurrence information of the corpus.
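For concreteness, the CNN architecture discussed above (word vectors, convolution, max pooling, a fully connected layer, and softmax) can be sketched as follows. The layer sizes and names are illustrative defaults of our own, not the settings of any particular cited model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    """Minimal Kim-style CNN for text classification (illustrative sketch)."""

    def __init__(self, vocab_size, embed_dim=100, num_filters=100,
                 kernel_sizes=(3, 4, 5), num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)                 # word vector layer
        self.convs = nn.ModuleList(
            [nn.Conv1d(embed_dim, num_filters, k) for k in kernel_sizes])    # convolutional layer
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)    # fully connected layer

    def forward(self, token_ids):                      # token_ids: (batch, seq_len)
        x = self.embedding(token_ids).transpose(1, 2)  # (batch, embed_dim, seq_len)
        pooled = [F.relu(conv(x)).max(dim=2).values for conv in self.convs]  # top (max) pooling
        return self.fc(torch.cat(pooled, dim=1))       # class scores; softmax applied in the loss
```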
C. Semi-Supervised Text Classification
Due to the lack of annotated datasets and the inaccuracy of manual annotation, semi-supervised methods have been proposed. They can be divided into two classes: (1) latent variable models and (2) embedding-based models [17]. The former extends a topic model with user-provided seed information and then infers the documents' labels from the posterior category-topic assignment. The latter uses seed information to derive embeddings of the documents and label names for text classification. For example, Yin et al. [18] used an SVM-based semi-supervised learning method to label unlabeled documents iteratively. GCNs have recently achieved good results in semi-supervised classification and have received extensive attention.
D. Graph Convolutional Networks
Scarselli et al. [19] were among the first to extend neural networks to graph structures [13], [20], but this work did not initially attract wide attention. CNNs have made extraordinary progress in computer vision and other fields, opening a new era of deep learning [21], [22]. The critical points of a CNN are local connections, weight sharing, and the use of deep networks [23]. However, the pixels in the image or video data it processes are neatly arranged in matrices with a Euclidean structure. Data with a non-Euclidean structure, such as a social network, correspond to an abstract topological graph in the sense of graph theory, which a CNN finds difficult to handle.
In the last few years, there has been much interest in applying convolution to graphs. These methods are based on neighborhood aggregation schemes and can be further divided into spectral methods and spatial methods. Spectral methods define parametric filters based on graph spectral theory. Bruna et al. [24] first defined a parametric filter in the Fourier domain, but its heavy computation limits its scalability [25]. To improve efficiency, Defferrard et al. proposed the ChebNet algorithm, which approximates a K-order polynomial filter through the Chebyshev expansion of the graph Laplacian. Kipf and Welling further simplified ChebNet by restricting the Chebyshev polynomial to a first-order approximation.
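The resulting first-order propagation rule is the familiar $H^{(l+1)}=\sigma(\hat{A}H^{(l)}W^{(l)})$, where $\hat{A}$ is the symmetrically normalized adjacency matrix with self-loops. A minimal sketch of one such step (variable names are ours):

```python
import torch

def gcn_layer(A_hat, H, W):
    """One simplified (first-order) GCN propagation step: H' = ReLU(A_hat @ H @ W)."""
    return torch.relu(A_hat @ H @ W)

# toy usage: 4 nodes, 8-dimensional input features, 3-dimensional output
A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 0., 1.],
                  [1., 0., 0., 0.],
                  [0., 1., 0., 0.]]) + torch.eye(4)     # adjacency with self-loops
D_inv_sqrt = torch.diag(A.sum(dim=1).pow(-0.5))
A_hat = D_inv_sqrt @ A @ D_inv_sqrt                     # symmetric normalization
H0, W0 = torch.rand(4, 8), torch.rand(8, 3)
H1 = gcn_layer(A_hat, H0, W0)                           # shape (4, 3)
```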
The spatial method combines the neighborhood information in the vertex domain to generate node embeddings. MoNet [26] and SplineCNN [27] integrate local signals by designing specific operators. In [8], Kipf and Welling proposed a scalable semi-supervised learning method based on a GCN, a multilayer neural network that operates directly on graphs. The model scales linearly in the number of edges and learns hidden-layer representations that encode both the local graph structure and node features. Later, Henaff et al. [28] explored graph neural networks for text classification. However, these models treat documents or sentences as nodes or rely on standardized citation relationships to build graphs, which has significant limitations. Defferrard et al. [29] and Peng et al. [30] have also explored text classification algorithms based on graph convolutional networks. Recently, Yao et al. [9] used a GCN for text classification in their TextGCN model. The model embeds the entire corpus into a single graph, which uses documents and words as nodes and links between documents and words and between words as edges. Given a predefined weight for each edge, the GCN model can be trained on the labeled data, and the unlabeled data can then be predicted by the model. Gao et al. select a fixed number of neighborhood nodes for each feature, allowing conventional convolution operations in Euclidean space.
Feature Fusion Based on Text Information Graph
This section represents text datasets and their reference relationships as a text information graph and discusses the correlation between texts and the text feature fusion method.
A. Text Information Graph
Given a text set, we construct a text information graph in which each vertex represents a text together with its feature vector and each edge represents a reference relation between two texts.
According to the reference relationship between texts, as shown in Figure 1, there may be the following possible relations between any two texts: a direct reference, an indirect reference through one or more intermediate texts, or no reference at all.
B. Information Transmission and Feature Fusion
In a text information graph, the reference relations between texts not only represent the correlations between texts but also indicate the direction of text feature fusion. If there are reference relations between a text $v_i$ and its neighboring texts $N(v_i)$, the feature of $v_i$ can be updated by fusing the features of its neighbors:\begin{equation*} \hat {h}_{i}=\sum _{v_{j}\in N(v_{i})}{f(v_{j},v_{i})\times h_{j}}\tag{1}\end{equation*}
According to Formula (1), we can fuse the features of a vertex with those of the adjacent vertices directly connected to it. However, in an actual application, the characteristics of a text are related not only to its directly connected neighbors but also to texts that are farther away; the fusion can therefore be iterated layer by layer:\begin{equation*} \hat {h}_{i}^{l}=\sum _{v_{j}\in N(v_{i})}{f(v_{j},v_{i})\times \hat {h}_{j}^{l-1}}\tag{2}\end{equation*}
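To make Formulas (1) and (2) concrete, the sketch below implements the layered fusion with a uniform weight $f(v_j,v_i)=1/|N(v_i)|$; this uniform weighting is an illustrative placeholder for the attention weights introduced later.

```python
import torch

def fuse_features(adj, H, num_layers=1):
    """Iteratively fuse each vertex's features with those of its neighbors.

    adj: binary adjacency matrix of the text information graph, shape (N, N)
    H:   feature matrix with one row h_i per vertex, shape (N, d)
    """
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)   # |N(v_i)|
    H_l = H.clone()
    for _ in range(num_layers):                       # Formula (2): iterate layer by layer
        H_l = (adj @ H_l) / deg                       # Formula (1): sum_j f(v_j, v_i) * h_j
    return H_l
```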
C. Feature Fusion With Residual Connection
The previous section discussed the method of fusing the features of vertices within $l$ hops of the target vertex.
However, with the increase in the number of fusion layers, the original features of the vertex itself are gradually diluted by those of the surrounding vertices, which weakens the vertex's ability to express its own category. We therefore preserve the original features of each vertex during fusion.
According to this idea, we revise Formula (2) as follows so that the original features of the vertex itself are superimposed after fusing the features of the vertex and its surrounding vertices:\begin{equation*} \hat {h}_{i}^{l}=h_{i}^{0}+\sum _{v_{j}\in N(v_{i})}{f(v_{j},v_{i})\times \hat {h}_{j}^{l-1}}\tag{3}\end{equation*}
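A sketch of the revised update in Formula (3), again with the uniform placeholder weights; the vertex's original features are added back after every fusion step so that they are not washed out:

```python
import torch

def fuse_with_residual(adj, H0, num_layers):
    """Formula (3): superimpose the original features h_i^0 after each fusion step."""
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
    H_l = H0.clone()
    for _ in range(num_layers):
        H_l = H0 + (adj @ H_l) / deg   # hat{h}_i^l = h_i^0 + sum_j f(v_j, v_i) * hat{h}_j^{l-1}
    return H_l
```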
Classification Based on Multilayer DGATs
The previous section discussed the method of information transmission and feature fusion in the text information graph, which integrates the features of each vertex with those of its associated vertices. For each vertex in the text information graph, the vertices associated with it may belong to different types and have diverse features; therefore, these associated vertices have unequal influence on the classification of the vertex. Next, we use a two-level attention mechanism to subdivide the relationships among the vertices in the text information graph, distinguish the differences in the connections between related vertices, and label the categories of unlabeled vertices on this basis.
A. Model Framework
According to the idea of information transmission and feature fusion based on the text information graph, we design a model framework based on a multilayer graph neural network and attention mechanisms for multilabel document classification, as shown in Figure 2.
In the proposed model framework, we use a stack of multiple DGAT layers with residual connections to fuse node features, followed by a softmax classifier that predicts the category of each unlabeled node.
B. Type-Level Attention
Given a vertex in the text information graph, we call it the target vertex. Many vertices may connect to the target vertex, and they may belong to different categories. Generally, the features of vertices in the same category are more similar. Based on this, when fusing the features of the target vertex and its related vertices, more attention should be given to the vertices belonging to the same category as the target vertex and less attention to the vertices belonging to other categories. In this way, we can enhance the ability of the target vertex to represent its category.
Given a target vertex $v_t$ with feature $h_t$ and a type $c$ of its neighboring vertices with type feature $h_c$, the attention score of the target vertex on type $c$ is computed as\begin{equation*} score_{c}=\sigma (\mu _{c}^{T}\cdot [W_{c}h_{t}\circleddash W_{c}h_{c}])\tag{4}\end{equation*} where $W_c$ is a trainable transformation matrix and $\mu_c$ is the type-level attention vector.
Finally, the softmax function is used to normalize the attention scores of the target vertex on all types of the surrounding nodes and thus obtain the attention coefficient of the target vertex on every type.\begin{equation*} \alpha _{c}=\frac {\exp (score_{c})}{\sum _{c'\in C_{N(v_{t})}}{\exp (score_{c'})}}\tag{5}\end{equation*}
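Read together, Formulas (4) and (5) amount to scoring each neighbor type against the target vertex and normalizing the scores. The sketch below treats the $\circleddash$ operator as concatenation and takes each type feature $h_c$ as given; both are assumptions made only for illustration, and $\sigma$ is taken to be the sigmoid.

```python
import torch
import torch.nn.functional as F

def type_level_attention(h_t, type_feats, W_c, mu_c):
    """Formulas (4)-(5): attention coefficients of a target vertex over neighbor types.

    h_t:        target vertex feature, shape (d,)
    type_feats: dict mapping each neighbor type c to its aggregated feature h_c, shape (d,)
    W_c:        shared linear transform (torch.nn.Linear) for type-level attention
    mu_c:       type-level attention vector, shape (2 * d_out,)
    """
    types, scores = list(type_feats), []
    for c in types:
        z = torch.cat([W_c(h_t), W_c(type_feats[c])])   # [W_c h_t (circled-minus) W_c h_c]
        scores.append(torch.sigmoid(mu_c @ z))          # Formula (4)
    alpha = F.softmax(torch.stack(scores), dim=0)       # Formula (5)
    return dict(zip(types, alpha))
```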
C. Node-Level Attention
The previous section discussed how to calculate the attention coefficients of the target node on different types of nodes. This section discusses how to calculate the attention coefficient of the target node on each adjacent node. The node-level attention mechanism differentiates the attention that the target node pays to different nodes during feature fusion. In this way, it can capture the vital feature information of adjacent nodes, reduce the noise introduced by secondary nodes, and improve the target node's ability to express its category.
Given a target node $v_i$ with feature $h_i$ and an adjacent node $v_j$ with feature $h_j$, the node-level attention coefficient of $v_i$ on $v_j$ is computed as\begin{equation*} \beta _{ij}=\frac {\exp \big (\mu _{v}^{T}\cdot [W_{v}h_{i}\circleddash W_{v}h_{j}]\big)} {\sum _{v_{k}\in N(v_{i})}{\exp \big (\mu _{v}^{T}\cdot [W_{v}h_{i}\circleddash W_{v}h_{k}]\big)}}\tag{6}\end{equation*} where $W_v$ is a trainable transformation matrix and $\mu_v$ is the node-level attention vector.
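Formula (6) is a GAT-style normalized attention over the neighbors of the target node; a sketch under the same concatenation assumption as above:

```python
import torch
import torch.nn.functional as F

def node_level_attention(h_i, neigh, W_v, mu_v):
    """Formula (6): attention coefficients beta_ij of target node i over its neighbors.

    h_i:   target node feature, shape (d,)
    neigh: stacked neighbor features h_j, shape (k, d)
    W_v:   shared linear transform (torch.nn.Linear) for node-level attention
    mu_v:  node-level attention vector, shape (2 * d_out,)
    """
    Wh_i = W_v(h_i).unsqueeze(0).expand(neigh.size(0), -1)  # repeat W_v h_i for each neighbor
    Wh_j = W_v(neigh)
    e = torch.cat([Wh_i, Wh_j], dim=1) @ mu_v                # unnormalized scores, shape (k,)
    return F.softmax(e, dim=0)                               # beta_ij
```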
D. Feature Fusion With DGAT
Using the two-level attention, we distinguish the attention of the target node to different nodes and node types and fuse the features of the target node with those of the surrounding nodes. According to the relationships among the nodes in the text information graph, we can use the dual-level attention graph attention network (DGAT for short) to learn the type-level and node-level attention coefficients and fuse the features of node $v_i$ as\begin{equation*} h_{i}=\sigma \left({\sum _{c\in C_{N(v_{i})}}{\alpha _{c} W_{c}h_{c}}}\right)\oplus \sigma \left({\sum _{v_{j}\in N(v_{i})}{\beta _{ij} W_{v}h_{j}}}\right)\tag{7}\end{equation*}
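Combining the two attention levels, Formula (7) joins a type-weighted aggregation with a node-weighted aggregation. In the sketch below the $\oplus$ operator is read as concatenation, which is an assumption on our part:

```python
import torch

def dgat_fuse(alpha, type_feats, W_c, beta, neigh, W_v):
    """Formula (7): combine type-level and node-level aggregations for one target node.

    alpha:      type-level attention coefficients from Formulas (4)-(5), dict {type: weight}
    type_feats: aggregated feature h_c per neighbor type, dict {type: tensor of shape (d,)}
    beta:       node-level attention coefficients from Formula (6), shape (k,)
    neigh:      neighbor features h_j, shape (k, d)
    """
    type_part = torch.sigmoid(sum(alpha[c] * W_c(type_feats[c]) for c in alpha))   # sum_c alpha_c W_c h_c
    node_part = torch.sigmoid((beta.unsqueeze(1) * W_v(neigh)).sum(dim=0))         # sum_j beta_ij W_v h_j
    return torch.cat([type_part, node_part])                                       # (+) read as concatenation
```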
By applying one layer of DGAT, the target node can aggregate the features of the nodes that are directly connected to it, i.e., the nodes one hop away.
On the basis that the fusion feature matrix $H^{l}$ produced by the $l$-th DGAT layer can serve as the input of the next layer, the stacked layers are connected as\begin{equation*} H^{(l+1)}=\sigma \big (H^{l}\big)\tag{8}\end{equation*}
Therefore, further iterations of feature fusion can be completed by stacking more DGAT layers, which realizes feature fusion between the target node and its surrounding nodes over a wider range.
E. Residual Network
In the previous subsection, we discussed applying the multilayer DGAT network to fuse node features over increasingly wide neighborhoods. However, simply stacking many layers degrades the expression of node features and causes the gradient to vanish during training.
When applying multilayer DGAT networks for text classification in our model, we use a residual network to address the degradation of node feature expression and the vanishing gradient problem. Specifically, when producing the final fused features from the DGAT network of the last layer $l$, we add the original feature matrix $H^{0}$ back through a skip connection:\begin{equation*} H^{f}=ReLU(H^{l})+H^{0}\tag{9}\end{equation*}
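A sketch of the overall forward pass, stacking hypothetical DGAT layer callables and applying the skip connection of Formula (9); the sketch assumes the output dimension of the last layer matches that of $H^{0}$ so that the addition is well defined:

```python
import torch

def forward_with_residual(adj, H0, dgat_layers):
    """Stack DGAT layers (Formula (8)) and add the skip connection of Formula (9)."""
    H = H0
    for layer in dgat_layers:          # hypothetical callables implementing one DGAT layer each
        H = layer(adj, H)              # output of layer l becomes the input of layer l+1
    return torch.relu(H) + H0          # Formula (9): H^f = ReLU(H^l) + H^0
```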
F. Cost Function
After obtaining the final fusion feature matrix $H^{f}$, we feed it into a softmax classifier to predict the category distribution of each node:\begin{equation*} Z=softmax(H^{f})\tag{10}\end{equation*}
Given the category set, denoted by $C$, the model is trained by minimizing the cross-entropy loss over the labeled documents:\begin{equation*} Cost=-\frac {1}{N}\sum _{v\in V}{\sum _{c\in C}{v^{c}\cdot \log (Z_{v}^{c})}}\tag{11}\end{equation*} where $v^{c}$ denotes the ground-truth label indicator of document $v$ for category $c$.
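Formulas (10) and (11) correspond to a softmax classifier trained with cross-entropy; a minimal sketch, assuming single-label integer targets and that only the labeled documents contribute to the loss:

```python
import torch
import torch.nn.functional as F

def classification_cost(H_f, labels, labeled_mask):
    """Formulas (10)-(11): softmax over the fused features and mean cross-entropy.

    H_f:          final fused feature matrix, shape (N, |C|)
    labels:       integer class index per document, shape (N,)
    labeled_mask: boolean mask selecting the documents with known labels
    """
    Z = F.log_softmax(H_f, dim=1)                              # log of Formula (10)
    return F.nll_loss(Z[labeled_mask], labels[labeled_mask])   # Formula (11)
```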
Experiments
This section first introduces the datasets used in the experiments, the settings of the experimental environment parameters, and the related methods used for performance comparison. Finally, we report the experimental results.
A. Datasets
We used five benchmark datasets: three citation datasets (Cora, Citeseer, and PubMed), one protein-protein interaction dataset (PPI), and one knowledge graph dataset (NELL).
1) Cora
The Cora dataset consists of 2708 papers in the field of machine learning. It has been a popular deep learning graph dataset in recent years. The papers were selected in such a way that in the final corpus every paper cited or was cited by at least one other paper and the papers were classified into seven classes. After stemming and removing stopwords, we were left with a vocabulary of 1433 unique words. All words with a document frequency less than 10 were removed.
2) Citeseer
Citeseer is a linked dataset built with permission from the CiteSeer web database. It consists of 3312 scientific publications classified into one of six classes, and its citation network consists of 4732 links. Each publication is described by a 0/1-valued word vector indicating the absence or presence of the corresponding word from a dictionary of 3703 unique words; the data are stored in a sparse format.
3) PubMed
The PubMed dataset consists of 19717 scientific publications from the PubMed database pertaining to diabetes classified into one of three classes. The citation network consists of 44338 links. Each publication in the dataset is described by a TF-IDF weighted word vector from a dictionary that consists of 500 unique words.
4) PPI
We use a protein-protein interaction (PPI) dataset consisting of graphs that correspond to different human tissues. The dataset contains 20 graphs for training, 2 for validation, and 2 for testing; critically, the testing graphs remain completely unobserved during training. To construct the graphs, we used the preprocessed data provided by [32]. The average number of nodes per graph is 2372. Each node has 50 features composed of positional gene sets, motif gene sets, and immunological signatures. There are 121 labels per node, taken from the Gene Ontology and collected from the Molecular Signatures Database [33], and a node can possess several labels simultaneously.
5) NELL
NELL is a dataset extracted from the knowledge graph introduced in [34]. A knowledge graph is a directed graph of entities with relations connecting its nodes. We follow the preprocessing scheme in [35] and assign separate relation nodes r1 and r2 to each triplet (e1, r, e2), yielding the edges (e1, r1) and (e2, r2). The entity nodes are represented by sparse feature vectors, and the feature space is extended by assigning a unique one-hot code to each node.
B. Parameter Settings
In the experiment, for each dataset, we randomly selected 20 texts from each document type, which together with 500 labeled documents formed the training sample set; the remaining documents were used as the testing sample set. We used a three-layer DGAT structure to construct our model, trained it for more than 200 epochs on the training sample set, and tested the correctness of the model on the testing sample set. In model training, all word embeddings were set to 100 dimensions, and document features and type features were set to 512 and 8 dimensions, respectively.
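For reference, the settings described above amount to a configuration along the following lines (names are ours; only values stated in the text are included):

```python
# illustrative summary of the hyperparameters reported above
config = {
    "num_dgat_layers": 3,        # three-layer DGAT structure
    "word_embedding_dim": 100,
    "document_feature_dim": 512,
    "type_feature_dim": 8,
    "training_epochs": 200,      # "more than 200 epochs of training"
    "labeled_per_class": 20,     # plus 500 additional labeled documents
    "extra_labeled_docs": 500,
}
```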
C. Baselines
For comparative analysis, we first compare the proposed method with several benchmark methods on the three citation network datasets and the NELL dataset. The benchmark methods include manifold regularization (ManiReg) [37], label propagation (LP) [38], ICA [39], SemiEmb [40], DeepWalk [41], GCN [8], sMGC [42], and GAT-MH [11]. Among these, ManiReg, LP, and ICA represent traditional machine learning methods; SemiEmb and DeepWalk combine traditional machine learning with deep learning; GCN and sMGC are deep learning methods based on graph convolutional networks; and GAT-MH is a deep learning method based on a graph attention network.
Moreover, we also compared against the GraphSAGE method [32] in experiments on the PPI dataset. GraphSAGE, presented by Hamilton et al., is a general inductive framework that leverages node feature information such as text attributes to efficiently generate node embeddings for previously unseen data. In the comparative experiments, four variants of GraphSAGE with different aggregator functions are compared: GraphSAGE-GCN, GraphSAGE-mean, GraphSAGE-LSTM, and GraphSAGE-pool.
D. Results
First, we conduct experiments on the three citation network datasets and the NELL dataset to evaluate the performance of our method. In the experiments, we construct our model with three DGAT layers and report the mean classification accuracy (with standard deviation) on the test nodes after 100 runs. As a comparison, the GAT-MH algorithm adopts a two-layer GAT network: the first layer consists of 8 attention heads computing 8 features each, and the second layer is used for classification. The results reported in [8] and [26] were reused for the other baseline methods in the comparison.
As shown in Table 2, our method far outperforms traditional machine learning methods and those that combine machine learning with deep learning. The main reason is that traditional machine learning methods mainly use statistical methods to learn the external characteristics of data and cannot obtain the deep correlation relationship contained in the data. Therefore, the classification accuracy of this kind of method is generally low. For those methods that combine machine learning with deep learning, although the application of neural networks can improve the model’s generalization ability and computing power, the fundamental logic of the model is still machine learning. Therefore, the accuracy of these methods has been improved, but the extent of improvement is relatively limited.
Deep learning methods based on GCN and GAT can not only learn the features of the data itself but also aggregate the features of neighboring nodes in the graph into the central node and thus learn the internal relationships among the data. Therefore, their accuracy on the four datasets is much higher than that of the other methods. Comparatively speaking, in GCN-based methods the influence weight of a neighbor node on the central node is determined by node degree, whereas in GAT-based methods it is determined by the correlation between the features of neighboring nodes. Thus, GAT-based methods achieve better performance because they can mine deeper relationships among the data.
Our model adopts a multilayer DGAT structure in which each DGAT layer is composed of two GAT networks. A DGAT layer is therefore equivalent to a graph attention network with two attention heads, which can simultaneously learn the attention coefficients of the data on categories and on connection relationships. At the same time, multilayer DGATs achieve high-dimensional node feature aggregation, which improves the feature expression ability of the model. Although GAT-MH uses a graph attention network with eight attention heads and can learn more attention coefficients on features, the feature dimensions that positively affect the classification results are still limited in most cases. In addition, GAT-MH only aggregates the features of nodes directly connected to the central node, which may lose some vital feature information. The experimental results show that, compared with the GAT-MH method, our method improves accuracy on the four datasets by 2.48%, 4.84%, 0.9%, and 1.53%, respectively.
Second, we also conducted several experiments on the protein-protein interaction (PPI) dataset to compare the proposed model with GraphSAGE and other methods. In these experiments, the results of Random and the four GraphSAGE models were obtained from the literature [35]. The GAT-MH model here is a two-layer GAT network with four attention heads in each layer.
The GraphSAGE model constructs the graph structure of the dataset based on the data's geometric features, performs feature aggregation over the local subgraph of each node, and proposes four types of aggregation functions. GraphSAGE is similar to GAT-based methods in that it aggregates node features using data correlations. However, its four aggregation functions mainly aggregate mathematical features of the data and lack aggregation of semantic features. As shown in Table 3, the accuracy of GAT-MH and our proposed model is far better than that of the GraphSAGE variants. The GAT-MH model uses a two-layer neural network, which can fuse well the features of nodes within two hops of the central node. Compared with our model, however, GAT-MH does not solve the vanishing gradient problem in multilayer neural networks and therefore cannot use a deeper network to learn deeper features. Our model applies residual networks to improve the aggregation ability of multilayer neural networks and thus achieves better performance than GAT-MH.
In addition, we performed multiple experiments on the four datasets to analyze the influence of the number of DGAT layers on the model's correctness, varying the number of layers from 1 to 7.
As shown in Table 4, the experimental results on the four datasets show that the model achieves its highest accuracy when the number of DGAT layers is set to 3. This indicates that when the central node aggregates the features of nodes within three hops, the fused features of the central node have the strongest expression ability. As the number of DGAT layers increases further, the central node integrates the characteristics of more nodes, but the correlation between nodes decreases significantly as their distance grows. Aggregating the features of nodes far away from the central node may introduce irrelevant features and consequently reduce the model's accuracy.
Conclusion
According to the intrinsic correlations of text data, this paper constructs a dual-level attention graph neural network (DGAT) to fuse the features of the target node and its local neighborhood nodes. In addition, we use a residual network to stack multiple DGAT layers and control the range of local neighborhood nodes participating in feature fusion. The model aggregates the features of the target node, the features of its neighborhood nodes, and their internal relationships, which effectively improves the expressive ability of node features and the classification accuracy. Moreover, the model only fuses the features of nodes in the local neighborhood, dramatically reducing computational complexity. We implemented the model and conducted extensive experiments on five benchmark datasets. Experimental results show that the proposed model has obvious advantages over other models and methods.
In the future, we will continue this work in two directions. First, we will further refine the model to improve its accuracy and generalization ability and reduce its computational complexity. Second, we will apply the model to specific application problems and verify and optimize it in practice.