Introduction
Use a CNN to perform sentiment classification on simple texts.
Requirements
numpy<2
torch and torchtext installed at mutually compatible versions
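As a concrete example, the following is one pairing that satisfies both constraints; the exact version numbers are an assumption, so check the torchtext compatibility table for your environment:

```
pip install "numpy<2" torch==2.3.0 torchtext==0.18.0
```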
Code
main.py
```python
import os
import argparse
import datetime
import torch
import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
import model
import train
import mydatasets

# Disable torchtext deprecation warnings
torchtext.disable_torchtext_deprecation_warning()

# Argument parsing
parser = argparse.ArgumentParser(description='CNN text classifier')
# Learning parameters
parser.add_argument('-lr', type=float, default=0.001, help='learning rate [default: 0.001]')
parser.add_argument('-epochs', type=int, default=256, help='number of training epochs [default: 256]')
parser.add_argument('-batch-size', type=int, default=64, help='batch size [default: 64]')
parser.add_argument('-log-interval', type=int, default=1, help='steps between log messages [default: 1]')
parser.add_argument('-test-interval', type=int, default=100, help='steps between evaluations [default: 100]')
parser.add_argument('-save-interval', type=int, default=500, help='steps between checkpoints [default: 500]')
parser.add_argument('-save-dir', type=str, default='snapshot', help='directory for saved models')
parser.add_argument('-early-stop', type=int, default=1000, help='steps without improvement before early stopping')
parser.add_argument('-save-best', type=bool, default=True, help='save the model with the best dev accuracy')
# Data parameters
parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data')
# Model parameters
parser.add_argument('-dropout', type=float, default=0.5, help='dropout probability [default: 0.5]')
parser.add_argument('-embed-dim', type=int, default=128, help='embedding dimension [default: 128]')
parser.add_argument('-kernel-num', type=int, default=100, help='number of kernels per size')
parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel sizes')
parser.add_argument('-static', action='store_true', default=False, help='freeze the embeddings')
# Device parameters
parser.add_argument('-device', type=int, default=-1, help='device index (-1 = CPU)')
parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the GPU')
# Mode parameters
parser.add_argument('-snapshot', type=str, default=None, help='path to a model snapshot')
parser.add_argument('-predict', type=str, default=None, help='text to classify')
parser.add_argument('-test', action='store_true', default=False, help='test mode')


def collate_batch(batch):
    """Pad the texts in a batch and stack the labels into tensors."""
    texts, labels = zip(*batch)
    # The dataset already returns tensors, so avoid re-wrapping them
    padded_texts = pad_sequence(texts, batch_first=True, padding_value=0)
    labels = torch.stack(labels)
    return padded_texts, labels


if __name__ == '__main__':
    args = parser.parse_args()

    # Build the vocabulary
    print("\nBuilding vocabulary...")
    tokenizer = get_tokenizer(lambda x: x.split())

    def yield_tokens(data_iter):
        for example in data_iter.examples:
            text, _ = example
            yield text

    # Identity transforms used only while building the vocabulary
    temp_text_transform = lambda x: x
    temp_label_transform = lambda x: x

    # Load the raw dataset to build the vocabulary
    train_data, _ = mydatasets.MR.splits(temp_text_transform, temp_label_transform, root='.')
    vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    # The real transforms
    text_transform = lambda x: vocab(x)
    label_transform = lambda x: 1 if x == 'positive' else 0

    # Load the datasets with the transforms applied
    print("\nLoading datasets...")
    train_data, dev_data = mydatasets.MR.splits(text_transform, label_transform,
                                                root='.', dev_ratio=0.1)

    # Create the DataLoaders
    train_iter = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        collate_fn=collate_batch, shuffle=args.shuffle
    )
    dev_iter = torch.utils.data.DataLoader(
        dev_data, batch_size=args.batch_size, collate_fn=collate_batch
    )

    # Fill in the model parameters derived from the data
    args.embed_num = len(vocab)
    args.class_num = 2
    args.cuda = (not args.no_cuda) and torch.cuda.is_available()
    args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    args.save_dir = os.path.join(args.save_dir,
                                 datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

    print("\nModel Parameters:")
    for attr, value in sorted(args.__dict__.items()):
        print(f"\t{attr.upper():<15} = {value}")

    # Initialize the model
    cnn = model.CNN_Text(args)
    if args.snapshot:
        print(f"\nLoading model from {args.snapshot}...")
        cnn.load_state_dict(torch.load(args.snapshot))
    if args.cuda:
        # set_device(-1) would raise, so only switch devices for a real index
        if args.device >= 0:
            torch.cuda.set_device(args.device)
        cnn = cnn.cuda()

    # Dispatch on the selected mode
    if args.predict:
        label = train.predict(args.predict, cnn, text_transform, label_transform, args.cuda)
        print(f'\n[Text]\t{args.predict}\n[Label]\t{label}\n')
    elif args.test:
        try:
            train.eval(dev_iter, cnn, args)
        except Exception as e:
            print(f"\nTest error: {str(e)}")
    else:
        try:
            train.train(train_iter, dev_iter, cnn, args)
        except KeyboardInterrupt:
            print('\n' + '-' * 50)
            print('Training stopped by user')
```
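A quick sanity check of `collate_batch` above, using made-up token ids, shows how `pad_sequence` right-pads the shorter sequence with the padding index 0:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two variable-length "sentences" as index tensors (values are illustrative)
batch = [(torch.tensor([4, 7, 2]), torch.tensor(1)),
         (torch.tensor([5, 9]), torch.tensor(0))]
texts, labels = zip(*batch)
padded = pad_sequence(texts, batch_first=True, padding_value=0)
print(padded)               # tensor([[4, 7, 2], [5, 9, 0]])
print(torch.stack(labels))  # tensor([1, 0])
```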
train.py
```python
import os
import sys
import torch
import torch.nn.functional as F


def train(train_iter, dev_iter, model, args):
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    best_acc = 0
    last_step = 0
    for epoch in range(1, args.epochs + 1):
        for batch in train_iter:
            model.train()
            feature, target = batch
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            optimizer.zero_grad()
            logit = model(feature)
            loss = F.cross_entropy(logit, target)
            loss.backward()
            optimizer.step()

            steps += 1
            if steps % args.log_interval == 0:
                corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
                accuracy = 100.0 * corrects / feature.size(0)
                sys.stdout.write(
                    '\rBatch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(
                        steps, loss.item(), accuracy.item(), corrects.item(), feature.size(0)))
            if steps % args.test_interval == 0:
                dev_acc = eval(dev_iter, model, args)
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    last_step = steps
                    if args.save_best:
                        save(model, args.save_dir, 'best', steps)
                else:
                    if steps - last_step >= args.early_stop:
                        print('early stop by {} steps.'.format(args.early_stop))
                        return  # actually stop training instead of only printing
            elif steps % args.save_interval == 0:
                save(model, args.save_dir, 'snapshot', steps)


def eval(data_iter, model, args):
    model.eval()
    corrects, avg_loss = 0, 0
    with torch.no_grad():
        for batch in data_iter:
            feature, target = batch
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()
            logit = model(feature)
            loss = F.cross_entropy(logit, target, reduction='sum')
            avg_loss += loss.item()
            corrects += (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()

    size = len(data_iter.dataset)
    avg_loss /= size
    accuracy = 100.0 * corrects / size
    print('\nEvaluation - loss: {:.6f} acc: {:.4f}%({}/{}) \n'.format(
        avg_loss, accuracy, corrects, size))
    return accuracy


def predict(text, model, text_transform, label_transform, cuda_flag):
    assert isinstance(text, str)
    model.eval()
    # text_transform expects a list of tokens, so split the raw string first
    token_ids = text_transform(text.split())
    x = torch.tensor([token_ids])
    if cuda_flag:
        x = x.cuda()
    output = model(x)
    _, predicted = torch.max(output, 1)
    label_mapping = {0: 'negative', 1: 'positive'}
    return label_mapping[predicted.item()]


def save(model, save_dir, save_prefix, steps):
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_prefix = os.path.join(save_dir, save_prefix)
    save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
    torch.save(model.state_dict(), save_path)
```
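`save()` writes state dicts named `<prefix>_steps_<N>.pt` under `save_dir`. A minimal sketch of reloading one outside `main.py` follows; the checkpoint path and `embed_num` are assumptions, and `embed_num` must equal `len(vocab)` from the original training run:

```python
import argparse
import torch
import model

# Hyperparameters must match the training configuration;
# embed_num is assumed here and must equal len(vocab) from training.
args = argparse.Namespace(embed_num=21108, embed_dim=128, class_num=2,
                          kernel_num=100, kernel_sizes=[3, 4, 5],
                          dropout=0.5, static=False)
cnn = model.CNN_Text(args)
# Hypothetical path written by save(model, save_dir, 'best', 3800)
cnn.load_state_dict(torch.load('snapshot/2025-05-01_14-36-47/best_steps_3800.pt',
                               map_location='cpu'))
cnn.eval()
```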
mydatasets.py
```python
# mydatasets.py
import re
import os
import random
import tarfile
import urllib.request
import torch
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer


class TarDataset(Dataset):
    @classmethod
    def download_or_unzip(cls, root):
        path = os.path.join(root, cls.dirname)
        if not os.path.isdir(path):
            tpath = os.path.join(root, cls.filename)
            if not os.path.isfile(tpath):
                print('Downloading dataset...')
                urllib.request.urlretrieve(cls.url, tpath)
            with tarfile.open(tpath, 'r') as tfile:
                print('Extracting...')

                def is_within_directory(directory, target):
                    abs_directory = os.path.abspath(directory)
                    abs_target = os.path.abspath(target)
                    prefix = os.path.commonprefix([abs_directory, abs_target])
                    return prefix == abs_directory

                def safe_extract(tar, path=".", members=None, numeric_owner=False):
                    # Guard against path-traversal entries in the archive
                    for member in tar.getmembers():
                        member_path = os.path.join(path, member.name)
                        if not is_within_directory(path, member_path):
                            raise Exception("Attempted Path Traversal in Tar File")
                    tar.extractall(path, members, numeric_owner=numeric_owner)

                safe_extract(tfile, root)
        return os.path.join(path, '')


class MR(TarDataset):
    url = 'https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
    filename = 'rt-polaritydata.tar.gz'
    dirname = 'rt-polaritydata'

    def __init__(self, text_transform, label_transform, path=None, examples=None):
        self.text_transform = text_transform
        self.label_transform = label_transform
        # Tokenizer: clean the string, then split on whitespace
        self.tokenizer = get_tokenizer(lambda x: self.clean_str(x).split())

        if examples is None:
            path = self.download_or_unzip(os.getcwd() if path is None else path)
            self.examples = []
            # Load negative reviews (stored as token lists)
            with open(os.path.join(path, 'rt-polarity.neg'), 'r', encoding='latin-1') as f:
                for line in f:
                    tokenized_text = self.tokenizer(line.strip())
                    self.examples.append((tokenized_text, 'negative'))
            # Load positive reviews
            with open(os.path.join(path, 'rt-polarity.pos'), 'r', encoding='latin-1') as f:
                for line in f:
                    tokenized_text = self.tokenizer(line.strip())
                    self.examples.append((tokenized_text, 'positive'))
        else:
            self.examples = examples

    @staticmethod
    def clean_str(string):
        string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
        string = re.sub(r"\'s", " \'s", string)
        string = re.sub(r"\'ve", " \'ve", string)
        string = re.sub(r"n\'t", " n\'t", string)
        string = re.sub(r"\'re", " \'re", string)
        string = re.sub(r"\'d", " \'d", string)
        string = re.sub(r"\'ll", " \'ll", string)
        string = re.sub(r",", " , ", string)
        string = re.sub(r"!", " ! ", string)
        string = re.sub(r"\(", " ( ", string)
        string = re.sub(r"\)", " ) ", string)
        string = re.sub(r"\?", " ? ", string)
        return re.sub(r"\s{2,}", " ", string).strip()

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        raw_text, raw_label = self.examples[idx]
        # raw_text is already a token list; the transform maps it to indices
        processed_text = self.text_transform(raw_text)
        processed_label = self.label_transform(raw_label)
        return (torch.tensor(processed_text, dtype=torch.long),
                torch.tensor(processed_label, dtype=torch.long))

    @classmethod
    def splits(cls, text_transform, label_transform, dev_ratio=0.1,
               shuffle=True, root='.', **kwargs):
        path = cls.download_or_unzip(root)
        full_dataset = cls(text_transform, label_transform, path=path, **kwargs)
        examples = full_dataset.examples
        if shuffle:
            random.shuffle(examples)
        split_idx = int(len(examples) * (1 - dev_ratio))
        return (cls(text_transform, label_transform, examples=examples[:split_idx]),
                cls(text_transform, label_transform, examples=examples[split_idx:]))
```
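A minimal sketch of using `MR.splits` directly, with identity transforms so the raw token lists and string labels come through unchanged (note that `.examples` is read directly, since `__getitem__` assumes the transforms produce numeric values):

```python
import mydatasets

# Identity transforms: keep the token lists and string labels unchanged
train_data, dev_data = mydatasets.MR.splits(lambda x: x, lambda x: x,
                                            root='.', dev_ratio=0.1)
print(len(train_data), len(dev_data))  # ~90/10 split of the 10662 MR sentences
tokens, label = train_data.examples[0]
print(label, tokens[:8])  # e.g. 'negative' plus the first few tokens
```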
model.py
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_Text(nn.Module):
    def __init__(self, args):
        super(CNN_Text, self).__init__()
        self.args = args

        # Embedding layer
        self.embed = nn.Embedding(args.embed_num, args.embed_dim)
        # Convolution layers, one per kernel size
        self.convs = nn.ModuleList([
            nn.Conv2d(
                in_channels=1,
                out_channels=args.kernel_num,
                kernel_size=(k, args.embed_dim)  # kernel shape: (height, width)
            ) for k in args.kernel_sizes
        ])
        # Classifier head
        self.dropout = nn.Dropout(args.dropout)
        self.fc = nn.Linear(len(args.kernel_sizes) * args.kernel_num, args.class_num)

        # Freeze the embeddings if requested
        if args.static:
            self.embed.weight.requires_grad = False

    def forward(self, x):
        # Input shape: (batch_size, seq_len)
        x = self.embed(x)  # (batch_size, seq_len, embed_dim)
        # Add a channel dimension: (batch_size, 1, seq_len, embed_dim)
        x = x.unsqueeze(1)

        conv_outputs = []
        for conv in self.convs:
            # Convolution: (batch, Co, seq_len - k + 1, 1)
            conv_out = F.relu(conv(x))
            # Drop the trailing dimension: (batch, Co, seq_len - k + 1)
            conv_out = conv_out.squeeze(3)
            # Max-over-time pooling: (batch, Co)
            pooled = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
            conv_outputs.append(pooled)

        # Concatenate the pooled features from all kernel sizes
        x = torch.cat(conv_outputs, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits
```
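A standalone shape check of the forward pass, using made-up hyperparameters (any `seq_len` no shorter than the largest kernel size works):

```python
import argparse
import torch
from model import CNN_Text

# Made-up hyperparameters for a quick shape test
args = argparse.Namespace(embed_num=1000, embed_dim=128, class_num=2,
                          kernel_num=100, kernel_sizes=[3, 4, 5],
                          dropout=0.5, static=False)
cnn = CNN_Text(args)
x = torch.randint(0, 1000, (4, 20))  # (batch_size=4, seq_len=20)
logits = cnn(x)
print(logits.shape)  # torch.Size([4, 2]) -> (batch_size, class_num)
```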
Training
Run the command `python main.py -epochs 30` (training with the default 256 epochs takes too long, so the epoch count is reduced to 30 to shorten training time).
Training process
After training you will end up with a directory structure similar to the one below. The `-snapshot` option tells the model which checkpoint to load from; if it is not specified, the model starts from scratch.
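For example, the layout might look like this (the timestamp and step counts are illustrative and will differ per run):

```
snapshot/
└── 2025-05-01_14-36-47/
    ├── snapshot_steps_500.pt
    ├── snapshot_steps_1000.pt
    ├── ...
    └── best_steps_3800.pt
```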
Testing
The test command is shown below; it asks the model to judge the sentiment of the sentence "Hello my dear , I love you so much .":
```
python main.py -predict="Hello my dear , I love you so much ." -snapshot="./snapshot/2025-05-01_14-36-47/best_steps_3800.pt"
```
Results
The model classifies the sentence as positive.