下面是一个完整的基于BERT的对话意图识别系统代码,用于检测"线下交易"嫌疑。包含数据生成、模型训练和评估全流程:
import torchimport pandas as pdimport numpy as npfrom sklearn.model_selection import train_test_splitfrom transformers import BertTokenizer, BertForSequenceClassification, AdamWfrom transformers import get_linear_schedule_with_warmupfrom torch.utils.data import Dataset, DataLoaderimport matplotlib.pyplot as pltfrom sklearn.metrics import classification_report, confusion_matriximport seaborn as sns# 设置随机种子确保可复现RANDOM_SEED = 42np.random.seed(RANDOM_SEED)torch.manual_seed(RANDOM_SEED)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 1. 生成模拟对话数据def generate_dialogue_data(num_samples=500): """生成包含正常对话和线下交易嫌疑的对话数据""" data = [] labels = [] # 正常对话模板 normal_templates = [ ("这个商品还有货吗?", "有货的,可以直接下单"), ("质量怎么样?", "我们是正品保证,支持7天无理由退货"), ("什么时候发货?", "24小时内发货"), ("能便宜点吗?", "可以给您包邮"), ("有优惠券吗?", "关注店铺领取5元优惠券"), ("怎么支付?", "平台支持支付宝、微信支付"), ("买多了有折扣吗?", "满100减10"), ("退换货流程是怎样的?", "申请售后我们会及时处理"), ("是正品吗?", "官方授权,假一赔十"), ("快递发什么?", "默认发顺丰快递") ] # 线下交易嫌疑对话模板 suspicious_templates = [ ("平台手续费太高了", "加我微信聊:XXX"), ("可以线下交易吗?", "现金交易更划算"), ("我们私下交易吧", "转我支付宝避免平台抽成"), ("加个微信详细聊?", "我的微信是123456"), ("平台监管太严", "我们直接银行转账"), ("绕过平台交易", "给你我的手机号138XXXX"), ("不走平台可以优惠", "现金支付给你9折"), ("加QQ发你位置", "QQ号:987654"), ("当面交易怎么样?", "明天下午3点地铁站见"), ("平台记录不方便", "我们短信联系") ] # 生成正常对话 for _ in range(num_samples // 2): template = normal_templates[np.random.choice(len(normal_templates))] dialogue = f"用户A: {template[0]}\n用户B: {template[1]}" data.append(dialogue) labels.append(0) # 正常对话标签为0 # 生成可疑对话 for _ in range(num_samples // 2): template = suspicious_templates[np.random.choice(len(suspicious_templates))] dialogue = f"用户A: {template[0]}\n用户B: {template[1]}" data.append(dialogue) labels.append(1) # 可疑对话标签为1 return pd.DataFrame({"dialogue": data, "label": labels})# 生成500条对话数据df = generate_dialogue_data(500)print("数据分布:\n", df["label"].value_counts())print("\n示例数据:")print(df.head())# 2. 数据预处理# 使用BERT的分词器tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')# 数据集类class DialogueDataset(Dataset): def __init__(self, dialogues, labels, tokenizer, max_len): self.dialogues = dialogues self.labels = labels self.tokenizer = tokenizer self.max_len = max_len def __len__(self): return len(self.dialogues) def __getitem__(self, idx): dialogue = str(self.dialogues[idx]) label = self.labels[idx] encoding = tokenizer.encode_plus( dialogue, add_special_tokens=True, max_length=self.max_len, return_token_type_ids=False, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt', ) return { 'dialogue_text': dialogue, 'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'labels': torch.tensor(label, dtype=torch.long) }# 数据分割train_df, test_df = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED)train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=RANDOM_SEED)# 参数设置MAX_LEN = 64BATCH_SIZE = 16# 创建数据加载器def create_data_loader(df, tokenizer, max_len, batch_size): ds = DialogueDataset( dialogues=df.dialogue.to_numpy(), labels=df.label.to_numpy(), tokenizer=tokenizer, max_len=max_len ) return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)train_data_loader = create_data_loader(train_df, tokenizer, MAX_LEN, BATCH_SIZE)val_data_loader = create_data_loader(val_df, tokenizer, MAX_LEN, BATCH_SIZE)test_data_loader = create_data_loader(test_df, tokenizer, MAX_LEN, BATCH_SIZE)# 3. 模型构建model = BertForSequenceClassification.from_pretrained( 'bert-base-chinese', num_labels=2)model = model.to(device)# 4. 训练设置EPOCHS = 5optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)total_steps = len(train_data_loader) * EPOCHSscheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=total_steps)loss_fn = torch.nn.CrossEntropyLoss().to(device)# 5. 训练函数def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler): model = model.train() losses = [] correct_predictions = 0 for batch in data_loader: input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) outputs = model( input_ids=input_ids, attention_mask=attention_mask, labels=labels ) loss = outputs.loss losses.append(loss.item()) _, preds = torch.max(outputs.logits, dim=1) correct_predictions += torch.sum(preds == labels) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() scheduler.step() optimizer.zero_grad() accuracy = correct_predictions.double() / len(data_loader.dataset) avg_loss = np.mean(losses) return avg_loss, accuracy# 6. 评估函数def eval_model(model, data_loader, device): model = model.eval() losses = [] correct_predictions = 0 all_preds = [] all_labels = [] with torch.no_grad(): for batch in data_loader: input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) outputs = model( input_ids=input_ids, attention_mask=attention_mask, labels=labels ) loss = outputs.loss losses.append(loss.item()) _, preds = torch.max(outputs.logits, dim=1) correct_predictions += torch.sum(preds == labels) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) accuracy = correct_predictions.double() / len(data_loader.dataset) avg_loss = np.mean(losses) return avg_loss, accuracy, all_preds, all_labels# 7. 训练循环history = { 'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}for epoch in range(EPOCHS): print(f'Epoch {epoch + 1}/{EPOCHS}') print('-' * 10) train_loss, train_acc = train_epoch( model, train_data_loader, loss_fn, optimizer, device, scheduler ) print(f'Train loss: {train_loss:.4f}, accuracy: {train_acc:.4f}') val_loss, val_acc, _, _ = eval_model(model, val_data_loader, device) print(f'Val loss: {val_loss:.4f}, accuracy: {val_acc:.4f}') print() history['train_loss'].append(train_loss) history['train_acc'].append(train_acc.cpu()) history['val_loss'].append(val_loss) history['val_acc'].append(val_acc.cpu())# 8. 评估模型test_loss, test_acc, y_pred, y_true = eval_model(model, test_data_loader, device)print(f'Test accuracy: {test_acc:.4f}')print(classification_report(y_true, y_pred, target_names=['正常', '可疑']))# 混淆矩阵def plot_confusion_matrix(y_true, y_pred): cm = confusion_matrix(y_true, y_pred) plt.figure(figsize=(8,6)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['正常', '可疑'], yticklabels=['正常', '可疑']) plt.xlabel('预测标签') plt.ylabel('真实标签') plt.title('混淆矩阵') plt.show()plot_confusion_matrix(y_true, y_pred)# 9. 预测函数def predict_intent(dialogue, model, tokenizer, device, max_len=64): model.eval() encoding = tokenizer.encode_plus( dialogue, add_special_tokens=True, max_length=max_len, return_token_type_ids=False, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt', ) input_ids = encoding['input_ids'].to(device) attention_mask = encoding['attention_mask'].to(device) with torch.no_grad(): outputs = model( input_ids=input_ids, attention_mask=attention_mask ) _, prediction = torch.max(outputs.logits, dim=1) probability = torch.nn.functional.softmax(outputs.logits, dim=1) return { "dialogue": dialogue, "prediction": "可疑对话" if prediction.item() == 1 else "正常对话", "confidence": probability[0][prediction.item()].item(), "suspicious_score": probability[0][1].item() }# 10. 测试示例test_dialogues = [ "用户A: 可以微信转账吗?\n用户B: 加我微信123456,转完发货", "用户A: 质量怎么样?\n用户B: 都是正品,支持平台验货", "用户A: 平台手续费太高了\n用户B: 我们私下交易,给你优惠价", "用户A: 什么时候能发货?\n用户B: 付款后24小时内发货"]for dialogue in test_dialogues: result = predict_intent(dialogue, model, tokenizer, device) print(f"对话: {result['dialogue']}") print(f"预测: {result['prediction']} (置信度: {result['confidence']:.4f})") print(f"可疑指数: {result['suspicious_score']:.4f}") print("-" * 50)
代码说明:
数据生成:
- 创建500个对话样本(50%正常,50%可疑)正常对话:普通购物咨询可疑对话:包含线下交易、第三方联系方式等关键词
数据处理:
- 使用
bert-base-chinese
分词器创建PyTorch数据集和数据加载器设置最大序列长度64模型构建:
- 基于预训练的BERT模型添加分类层使用交叉熵损失函数设置AdamW优化器
训练流程:
- 训练5个epoch记录训练/验证损失和准确率使用线性学习率调度
评估指标:
- 测试集准确率分类报告(精确率、召回率、F1值)混淆矩阵可视化
预测功能:
- 输入新对话,输出预测结果显示预测标签和置信度提供可疑指数(0-1之间的概率值)
关键特征:
- 中文BERT适配:使用
bert-base-chinese
预训练模型完整流程:数据生成→预处理→训练→评估→预测可视化分析:混淆矩阵展示模型性能实用预测:输出可疑指数和置信度预期输出示例:
对话: 用户A: 可以微信转账吗?\n用户B: 加我微信123456,转完发货预测: 可疑对话 (置信度: 0.9987)可疑指数: 0.9987--------------------------------------------------对话: 用户A: 质量怎么样?\n用户B: 都是正品,支持平台验货预测: 正常对话 (置信度: 0.9921)可疑指数: 0.0079--------------------------------------------------
扩展建议:
- 增加更多对话模式和关键词尝试不同预训练模型(如RoBERTa)添加实体识别强化关键词检测集成规则引擎提高准确率
此代码完整实现了从数据生成到模型部署的全流程,可以直接运行(需安装transformers、torch等库)。