掘金 人工智能 06月25日 14:37
[nlp]transformer+bert+意图识别(实战)
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文介绍了一个基于BERT的中文对话意图识别系统,用于检测“线下交易”嫌疑。系统包括数据生成、模型训练、评估和预测全流程,使用了BERT-base-chinese预训练模型,并提供了数据预处理、模型构建、训练、评估、预测等环节的完整代码实现。通过混淆矩阵和分类报告,展示了模型的性能,并提供了实用预测功能,输出预测标签、置信度和可疑指数。

💬系统首先通过`generate_dialogue_data`函数生成包含正常对话和线下交易嫌疑的对话数据,其中正常对话模拟常见购物咨询,可疑对话则包含引导线下交易的关键词。

⚙️数据预处理环节,使用`bert-base-chinese`分词器对对话进行分词,并构建`DialogueDataset`数据集类,将对话文本转换为模型可接受的输入格式,包括input_ids、attention_mask和labels。

🚀模型构建部分,基于预训练的BERT模型添加分类层,并使用交叉熵损失函数和AdamW优化器进行训练。训练过程中,系统记录训练集和验证集的损失和准确率,并使用线性学习率调度器。

📊评估阶段,系统在测试集上评估模型性能,输出测试集准确率、分类报告(包括精确率、召回率和F1值)以及混淆矩阵,用于全面了解模型的表现。

💡最后,系统提供`predict_intent`函数,用于输入新对话并预测其是否为可疑对话,同时输出预测标签、置信度和可疑指数,方便用户进行判断。

下面是一个完整的基于BERT的对话意图识别系统代码,用于检测"线下交易"嫌疑。包含数据生成、模型训练和评估全流程:

import torchimport pandas as pdimport numpy as npfrom sklearn.model_selection import train_test_splitfrom transformers import BertTokenizer, BertForSequenceClassification, AdamWfrom transformers import get_linear_schedule_with_warmupfrom torch.utils.data import Dataset, DataLoaderimport matplotlib.pyplot as pltfrom sklearn.metrics import classification_report, confusion_matriximport seaborn as sns# 设置随机种子确保可复现RANDOM_SEED = 42np.random.seed(RANDOM_SEED)torch.manual_seed(RANDOM_SEED)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 1. 生成模拟对话数据def generate_dialogue_data(num_samples=500):    """生成包含正常对话和线下交易嫌疑的对话数据"""    data = []    labels = []        # 正常对话模板    normal_templates = [        ("这个商品还有货吗?", "有货的,可以直接下单"),        ("质量怎么样?", "我们是正品保证,支持7天无理由退货"),        ("什么时候发货?", "24小时内发货"),        ("能便宜点吗?", "可以给您包邮"),        ("有优惠券吗?", "关注店铺领取5元优惠券"),        ("怎么支付?", "平台支持支付宝、微信支付"),        ("买多了有折扣吗?", "满100减10"),        ("退换货流程是怎样的?", "申请售后我们会及时处理"),        ("是正品吗?", "官方授权,假一赔十"),        ("快递发什么?", "默认发顺丰快递")    ]        # 线下交易嫌疑对话模板    suspicious_templates = [        ("平台手续费太高了", "加我微信聊:XXX"),        ("可以线下交易吗?", "现金交易更划算"),        ("我们私下交易吧", "转我支付宝避免平台抽成"),        ("加个微信详细聊?", "我的微信是123456"),        ("平台监管太严", "我们直接银行转账"),        ("绕过平台交易", "给你我的手机号138XXXX"),        ("不走平台可以优惠", "现金支付给你9折"),        ("加QQ发你位置", "QQ号:987654"),        ("当面交易怎么样?", "明天下午3点地铁站见"),        ("平台记录不方便", "我们短信联系")    ]        # 生成正常对话    for _ in range(num_samples // 2):        template = normal_templates[np.random.choice(len(normal_templates))]        dialogue = f"用户A: {template[0]}\n用户B: {template[1]}"        data.append(dialogue)        labels.append(0)  # 正常对话标签为0        # 生成可疑对话    for _ in range(num_samples // 2):        template = suspicious_templates[np.random.choice(len(suspicious_templates))]        dialogue = f"用户A: {template[0]}\n用户B: {template[1]}"        data.append(dialogue)        labels.append(1)  # 可疑对话标签为1        return pd.DataFrame({"dialogue": data, "label": labels})# 生成500条对话数据df = generate_dialogue_data(500)print("数据分布:\n", df["label"].value_counts())print("\n示例数据:")print(df.head())# 2. 数据预处理# 使用BERT的分词器tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')# 数据集类class DialogueDataset(Dataset):    def __init__(self, dialogues, labels, tokenizer, max_len):        self.dialogues = dialogues        self.labels = labels        self.tokenizer = tokenizer        self.max_len = max_len            def __len__(self):        return len(self.dialogues)        def __getitem__(self, idx):        dialogue = str(self.dialogues[idx])        label = self.labels[idx]                encoding = tokenizer.encode_plus(            dialogue,            add_special_tokens=True,            max_length=self.max_len,            return_token_type_ids=False,            padding='max_length',            truncation=True,            return_attention_mask=True,            return_tensors='pt',        )                return {            'dialogue_text': dialogue,            'input_ids': encoding['input_ids'].flatten(),            'attention_mask': encoding['attention_mask'].flatten(),            'labels': torch.tensor(label, dtype=torch.long)        }# 数据分割train_df, test_df = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED)train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=RANDOM_SEED)# 参数设置MAX_LEN = 64BATCH_SIZE = 16# 创建数据加载器def create_data_loader(df, tokenizer, max_len, batch_size):    ds = DialogueDataset(        dialogues=df.dialogue.to_numpy(),        labels=df.label.to_numpy(),        tokenizer=tokenizer,        max_len=max_len    )    return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)train_data_loader = create_data_loader(train_df, tokenizer, MAX_LEN, BATCH_SIZE)val_data_loader = create_data_loader(val_df, tokenizer, MAX_LEN, BATCH_SIZE)test_data_loader = create_data_loader(test_df, tokenizer, MAX_LEN, BATCH_SIZE)# 3. 模型构建model = BertForSequenceClassification.from_pretrained(    'bert-base-chinese',     num_labels=2)model = model.to(device)# 4. 训练设置EPOCHS = 5optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)total_steps = len(train_data_loader) * EPOCHSscheduler = get_linear_schedule_with_warmup(    optimizer,    num_warmup_steps=0,    num_training_steps=total_steps)loss_fn = torch.nn.CrossEntropyLoss().to(device)# 5. 训练函数def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler):    model = model.train()    losses = []    correct_predictions = 0        for batch in data_loader:        input_ids = batch["input_ids"].to(device)        attention_mask = batch["attention_mask"].to(device)        labels = batch["labels"].to(device)                outputs = model(            input_ids=input_ids,            attention_mask=attention_mask,            labels=labels        )                loss = outputs.loss        losses.append(loss.item())                _, preds = torch.max(outputs.logits, dim=1)        correct_predictions += torch.sum(preds == labels)                loss.backward()        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)        optimizer.step()        scheduler.step()        optimizer.zero_grad()        accuracy = correct_predictions.double() / len(data_loader.dataset)    avg_loss = np.mean(losses)    return avg_loss, accuracy# 6. 评估函数def eval_model(model, data_loader, device):    model = model.eval()    losses = []    correct_predictions = 0    all_preds = []    all_labels = []        with torch.no_grad():        for batch in data_loader:            input_ids = batch["input_ids"].to(device)            attention_mask = batch["attention_mask"].to(device)            labels = batch["labels"].to(device)                        outputs = model(                input_ids=input_ids,                attention_mask=attention_mask,                labels=labels            )                        loss = outputs.loss            losses.append(loss.item())                        _, preds = torch.max(outputs.logits, dim=1)            correct_predictions += torch.sum(preds == labels)            all_preds.extend(preds.cpu().numpy())            all_labels.extend(labels.cpu().numpy())        accuracy = correct_predictions.double() / len(data_loader.dataset)    avg_loss = np.mean(losses)    return avg_loss, accuracy, all_preds, all_labels# 7. 训练循环history = {    'train_loss': [], 'train_acc': [],    'val_loss': [], 'val_acc': []}for epoch in range(EPOCHS):    print(f'Epoch {epoch + 1}/{EPOCHS}')    print('-' * 10)        train_loss, train_acc = train_epoch(        model, train_data_loader, loss_fn, optimizer, device, scheduler    )    print(f'Train loss: {train_loss:.4f}, accuracy: {train_acc:.4f}')        val_loss, val_acc, _, _ = eval_model(model, val_data_loader, device)    print(f'Val loss: {val_loss:.4f}, accuracy: {val_acc:.4f}')    print()        history['train_loss'].append(train_loss)    history['train_acc'].append(train_acc.cpu())    history['val_loss'].append(val_loss)    history['val_acc'].append(val_acc.cpu())# 8. 评估模型test_loss, test_acc, y_pred, y_true = eval_model(model, test_data_loader, device)print(f'Test accuracy: {test_acc:.4f}')print(classification_report(y_true, y_pred, target_names=['正常', '可疑']))# 混淆矩阵def plot_confusion_matrix(y_true, y_pred):    cm = confusion_matrix(y_true, y_pred)    plt.figure(figsize=(8,6))    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',                 xticklabels=['正常', '可疑'],                 yticklabels=['正常', '可疑'])    plt.xlabel('预测标签')    plt.ylabel('真实标签')    plt.title('混淆矩阵')    plt.show()plot_confusion_matrix(y_true, y_pred)# 9. 预测函数def predict_intent(dialogue, model, tokenizer, device, max_len=64):    model.eval()    encoding = tokenizer.encode_plus(        dialogue,        add_special_tokens=True,        max_length=max_len,        return_token_type_ids=False,        padding='max_length',        truncation=True,        return_attention_mask=True,        return_tensors='pt',    )        input_ids = encoding['input_ids'].to(device)    attention_mask = encoding['attention_mask'].to(device)        with torch.no_grad():        outputs = model(            input_ids=input_ids,            attention_mask=attention_mask        )        _, prediction = torch.max(outputs.logits, dim=1)    probability = torch.nn.functional.softmax(outputs.logits, dim=1)        return {        "dialogue": dialogue,        "prediction": "可疑对话" if prediction.item() == 1 else "正常对话",        "confidence": probability[0][prediction.item()].item(),        "suspicious_score": probability[0][1].item()    }# 10. 测试示例test_dialogues = [    "用户A: 可以微信转账吗?\n用户B: 加我微信123456,转完发货",    "用户A: 质量怎么样?\n用户B: 都是正品,支持平台验货",    "用户A: 平台手续费太高了\n用户B: 我们私下交易,给你优惠价",    "用户A: 什么时候能发货?\n用户B: 付款后24小时内发货"]for dialogue in test_dialogues:    result = predict_intent(dialogue, model, tokenizer, device)    print(f"对话: {result['dialogue']}")    print(f"预测: {result['prediction']} (置信度: {result['confidence']:.4f})")    print(f"可疑指数: {result['suspicious_score']:.4f}")    print("-" * 50)

代码说明:

    数据生成

      创建500个对话样本(50%正常,50%可疑)正常对话:普通购物咨询可疑对话:包含线下交易、第三方联系方式等关键词

    数据处理

      使用bert-base-chinese分词器创建PyTorch数据集和数据加载器设置最大序列长度64

    模型构建

      基于预训练的BERT模型添加分类层使用交叉熵损失函数设置AdamW优化器

    训练流程

      训练5个epoch记录训练/验证损失和准确率使用线性学习率调度

    评估指标

      测试集准确率分类报告(精确率、召回率、F1值)混淆矩阵可视化

    预测功能

      输入新对话,输出预测结果显示预测标签和置信度提供可疑指数(0-1之间的概率值)

关键特征:

    中文BERT适配:使用bert-base-chinese预训练模型完整流程:数据生成→预处理→训练→评估→预测可视化分析:混淆矩阵展示模型性能实用预测:输出可疑指数和置信度

预期输出示例:

对话: 用户A: 可以微信转账吗?\n用户B: 加我微信123456,转完发货预测: 可疑对话 (置信度: 0.9987)可疑指数: 0.9987--------------------------------------------------对话: 用户A: 质量怎么样?\n用户B: 都是正品,支持平台验货预测: 正常对话 (置信度: 0.9921)可疑指数: 0.0079--------------------------------------------------

扩展建议:

    增加更多对话模式和关键词尝试不同预训练模型(如RoBERTa)添加实体识别强化关键词检测集成规则引擎提高准确率

此代码完整实现了从数据生成到模型部署的全流程,可以直接运行(需安装transformers、torch等库)。

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

BERT 对话意图识别 线下交易检测 自然语言处理
相关文章