掘金 人工智能 05月22日 09:58
Hugging Face + Spark:打造高效的 NLP 大数据处理引擎(一)
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文探讨了如何将Hugging Face和Spark相结合,以实现高效的自然语言处理(NLP)大数据处理。文章介绍了两种方法:一是升级Spark版本至3.4及以上,利用Spark原生支持加载模型进行预测;二是基于Spark 3.3.1,通过手动封装接口实现Hugging Face和Spark的结合,包括封装分布式模型缓存和处理逻辑,以满足不同Spark版本用户的需求,并提供了详细的代码实现和关键说明。

✨Spark 3.4及以上版本原生支持加载Hugging Face模型进行预测,方便快捷。

🔑对于Spark 3.3.1版本,可以通过手动封装接口实现结合,增强灵活性。

📦手动封装方法中,关键在于封装分布式的模型缓存,高效管理模型加载与缓存。

💡提供了详细的代码实现,包括模型加载策略、文件夹管理,以及批量预测的Pandas UDF定义。

🛠️通过批量预测的Pandas UDF,可以处理大规模文本数据,实现高效的文本嵌入向量提取。

在自然语言处理(NLP)领域,Hugging Face 是不可或缺的处理库,而 Spark 则是大数据处理的必备工具。将两者的优势结合起来,可以实现高效的 NLP 大数据处理。以下是结合 Hugging Face 和 Spark 的两种方法,基于 Spark & PySpark 3.3.1 版本进行探索。

方法一:升级 Spark 版本至 3.4 及以上

如果你愿意升级 Spark 版本到 3.4 或更高版本,那么结合 Hugging Face 和 Spark 将变得非常方便。Spark 3.4 及以上版本天然支持加载模型进行预测。

关键步骤说明:

    模型加载策略:需要为每个 Worker 单独加载模型,确保模型在分布式环境中的可用性。文件夹管理:在加载 Hugging Face 预训练模型之前,务必删除之前的模型文件夹,防止加载失败。

注:如果图片无法显示,请检查链接合法性或稍后重试。

方法二:基于 Spark 3.3.1 的手动封装接口

如果你希望保持当前的 Spark 3.3.1 版本,那么可以通过手动封装接口来实现 Hugging Face 和 Spark 的结合。以下是详细的代码实现和关键说明。

封装分布式的模型缓存

为了高效管理模型加载和缓存,我们从spark3.4的源代码中抽取了一个分布式的模型缓存机制:

from collections import OrderedDictfrom threading import Lockfrom typing import Callable, Optionalfrom uuid import UUIDclass ModelCache:    """Cache for model prediction functions on executors.    This requires the `spark.python.worker.reuse` configuration to be set to `true`, otherwise a    new python worker (with an empty cache) will be started for every task.    If a python worker is idle for more than one minute (per the IDLE_WORKER_TIMEOUT_NS setting in    PythonWorkerFactory.scala), it will be killed, effectively clearing the cache until a new python    worker is started.    Caching large models can lead to out-of-memory conditions, which may require adjusting spark    memory configurations, e.g. `spark.executor.memoryOverhead`.    """    _models: OrderedDict = OrderedDict()    _capacity: int = 3  # "reasonable" default size for now, make configurable later, if needed    _lock: Lock = Lock()    @staticmethod    def add(uuid: UUID, predict_fn: Callable) -> None:        with ModelCache._lock:            ModelCache._models[uuid] = predict_fn            ModelCache._models.move_to_end(uuid)            if len(ModelCache._models) > ModelCache._capacity:                ModelCache._models.popitem(last=False)    @staticmethod    def get(uuid: UUID) -> Optional[Callable]:        with ModelCache._lock:            predict_fn = ModelCache._models.get(uuid)            if predict_fn:                ModelCache._models.move_to_end(uuid)            return predict_fn

封装处理逻辑

from __future__ import annotationsimport osimport argparseimport randomimport loggingimport pandas as pdfrom pyspark.sql import SparkSessionfrom pyspark.sql.functions import udf, column, encodefrom pyspark.sql.types import *from datetime import datetime, timedeltaimport requests as reqfrom io import BytesIOimport numpy as npimport uuidimport inspectfrom pyspark.sql.functions import pandas_udffrom pyspark.sql.types import (    ArrayType,    ByteType,    DataType,    DoubleType,    FloatType,    IntegerType,    LongType,    ShortType,    StringType,    StructType,)from typing import Any, Callable, Iterator, List, Mapping, TYPE_CHECKING, Tuple, Union, Optionalsupported_scalar_types = (    ByteType,    ShortType,    IntegerType,    LongType,    FloatType,    DoubleType,    StringType,)hadoop = os.path.join(os.environ['HADOOP_COMMON_HOME'], 'bin/hadoop')def init_spark():    """初始化 SparkSession 配置"""    spark = SparkSession.builder \        .config("spark.sql.caseSensitive", "false") \        .config("spark.shuffle.spill", "true") \        .config("spark.shuffle.spill.compress", "true") \        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \        .config("metastore.catalog.default", "hive") \        .config("spark.sql.hive.convertMetastoreOrc", "true") \        .config("spark.kryoserializer.buffer.max", "1024m") \        .config("spark.kryoserializer.buffer", "64m") \        .config("spark.driver.maxResultSize","4g") \        .config("spark.sql.broadcastTimeout", "36000") \        .enableHiveSupport() \        .getOrCreate()    return sparkdef system_command(command):    """执行系统命令"""    code = os.system(command)    if code != 0:        logging.error(f"Command: ({command}) execute failed.")    else:        logging.info(f"Command: ({command}) execute succeed.")def parse_args():    """解析命令行参数"""    parser = argparse.ArgumentParser(usage="it's usage tip.",                                     description="user tags prefer")    parser.add_argument("--db", default="", help="hive表")    parser.add_argument("--date", default="", help="日期")    parser.add_argument("--output_path", default="", help="输出路径")    parser.add_argument("--batch_size", default=16, help="输出路径")    return parser.parse_args()def _batched(    data: Union[pd.Series, pd.DataFrame, Tuple[pd.Series]], batch_size: int) -> Iterator[pd.DataFrame]:    """将 pandas dataframe/series 分批处理"""    if isinstance(data, pd.DataFrame):        df = data    elif isinstance(data, pd.Series):        df = pd.concat((data,), axis=1)    else:  # isinstance(data, Tuple[pd.Series])        df = pd.concat(data, axis=1)    index = 0    data_size = len(df)    while index < data_size:        yield df.iloc[index : index + batch_size]        index += batch_sizedef _is_tensor_col(data: Union[pd.Series, pd.DataFrame]) -> bool:    """检查数据是否为张量列"""    if isinstance(data, pd.Series):        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))    elif isinstance(data, pd.DataFrame):        return any(data.dtypes == np.object_) and any(            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]        )    else:        raise ValueError(            "Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data))        )def _has_tensor_cols(data: Union[pd.Series, pd.DataFrame, Tuple[pd.Series]]) -> bool:    """检查输入是否包含张量值列"""    if isinstance(data, (pd.Series, pd.DataFrame)):        return _is_tensor_col(data)    else:  # isinstance(data, Tuple)        return any(_is_tensor_col(elem) for elem in data)def _validate_and_transform_multiple_inputs(    batch: pd.DataFrame, input_shapes: List[Optional[List[int]]], num_input_cols: int) -> List[np.ndarray]:    """验证并转换多个输入"""    multi_inputs = [batch[col].to_numpy() for col in batch.columns]    if input_shapes:        if len(input_shapes) == num_input_cols:            multi_inputs = [                np.vstack(v).reshape([-1] + input_shapes[i])  # type: ignore                if input_shapes[i]                else v                for i, v in enumerate(multi_inputs)            ]            if not all([len(x) == len(batch) for x in multi_inputs]):                raise ValueError("Input data does not match expected shape.")        else:            raise ValueError("input_tensor_shapes must match columns")    return multi_inputsdef _validate_and_transform_single_input(    batch: pd.DataFrame,    input_shapes: List[Optional[List[int]]],    has_tensors: bool,    has_tuple: bool,) -> np.ndarray:    """验证并转换单个输入"""    # 处理逻辑省略(与原文一致)    return single_inputdef _validate_and_transform_prediction_result(    preds: Union[np.ndarray, Mapping[str, np.ndarray], List[Mapping[str, Any]]],    num_input_rows: int,    return_type: DataType,) -> Union[pd.DataFrame, pd.Series]:    """验证并转换预测结果"""    # 处理逻辑省略(与原文一致)    return pd.DataFrame(preds)def predict_batch_udf(    make_predict_fn: Callable[        [],        PredictBatchFunction,    ],    *,    return_type: DataType,    batch_size: int,    input_tensor_shapes: Optional[Union[List[Optional[List[int]]], Mapping[int, List[int]]]] = None,):    """定义批量预测的 Pandas UDF"""    model_uuid = uuid.uuid4()    def predict(data: Iterator[Union[pd.Series, pd.DataFrame]]) -> Iterator[pd.DataFrame]:        from model_cache import ModelCache        predict_fn = ModelCache.get(model_uuid)        if not predict_fn:            predict_fn = make_predict_fn()            ModelCache.add(model_uuid, predict_fn)        signature = inspect.signature(predict_fn)        num_expected_cols = len(signature.parameters)        input_shapes: List[Optional[List[int]]]        if isinstance(input_tensor_shapes, Mapping):            input_shapes = [None] * num_expected_cols            for index, shape in input_tensor_shapes.items():                input_shapes[index] = shape        else:            input_shapes = input_tensor_shapes  # type: ignore        for pandas_batch in data:            has_tuple = isinstance(pandas_batch, Tuple)  # type: ignore            has_tensors = _has_tensor_cols(pandas_batch)            if has_tensors and not input_shapes:                raise ValueError("Tensor columns require input_tensor_shapes")            for batch in _batched(pandas_batch, batch_size):                num_input_rows = len(batch)                num_input_cols = len(batch.columns)                if num_input_cols == num_expected_cols and num_expected_cols > 1:                    multi_inputs = _validate_and_transform_multiple_inputs(                        batch, input_shapes, num_input_cols                    )                    preds = predict_fn(*multi_inputs)                elif num_expected_cols == 1:                    single_input = _validate_and_transform_single_input(                        batch, input_shapes, has_tensors, has_tuple                    )                    preds = predict_fn(single_input)                else:                    msg = "Model expected {} inputs, but received {} columns"                    raise ValueError(msg.format(num_expected_cols, num_input_cols))                yield _validate_and_transform_prediction_result(                    preds, num_input_rows, return_type                )  # type: ignore    return pandas_udf(predict, return_type)  # type: ignore[call-overload]def extract_text_embedding(model, tokenizer, sentence):    """提取文本嵌入向量"""    inputs = tokenizer(sentence, return_tensors='pt', max_length=32, padding=True, truncation=True)    embeddings = model(**inputs)    embeddings = embeddings.pooler_output    embeddings = embeddings.tolist()    for i in range(len(embeddings)):        embeddings[i] = [round(c,4) for c in embeddings[i]]    return np.array(embeddings, dtype=np.float32)if __name__ == "__main__":    args = parse_args()    spark = init_spark()      ### 读取数据  df = spark.sql(f"""         select article_id, title         from xxx     """)    def predict_embedding():        system_command(f"""rm -rf ./bert-base-chinese""")        system_command(f"""{hadoop} fs -get /path/to/bert-base-chinese""")        from transformers import BertTokenizer, BertModel        tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')        text_model = BertModel.from_pretrained('./bert-base-chinese')        def predict(inputs):            sentence = inputs.tolist()            embeddings = extract_text_embedding(text_model, tokenizer, sentence)            return embeddings        return predict    predict_embedding_udf = predict_batch_udf(predict_embedding,                                  return_type=ArrayType(StringType()),                                  batch_size=100)    df.withColumn("title_embedding", predict_embedding_udf("title")).show(5)    spark.stop()    del spark

关键点说明:

    模型加载与缓存:通过 predict_batch_udf 函数封装预测逻辑,利用模型缓存避免重复加载,提高效率。批量处理:使用 _batched 函数将数据分批处理,避免内存溢出,适合大数据场景。类型转换与验证:通过 _validate_and_transform 系列函数确保输入输出类型匹配,提高代码健壮性。Hugging Face 模型集成:在 predict_embedding 函数中加载 Hugging Face 的 BERT 模型,并定义预测逻辑。typing语法修改:python3.10以的typing是不支持|语法的,需要改成Union进行类型的或推断

方法比较

对比维度方法一:升级 Spark 至 3.4+方法二:基于 Spark 3.3.1 手动封装接口
实现难度较低,依托新版本特性较高,需手动实现缓存及接口封装
模型加载方式每个 Worker 单独加载模型每个 Worker 单独加载模型,并通过分布式缓存机制复用
文件管理要求需提前删除旧模型文件夹防止加载失败加载前删除旧模型文件夹
代码复用性可直接使用新版本 API,代码简洁需手动封装,代码量较大,但更具灵活性
性能优化新版本可能自带优化可通过调整缓存策略、批量处理逻辑等进行精细优化
适用场景适合可升级环境,追求快速开发适合无法升级环境,或对性能和资源管理有更高要求的场景
可维护性依赖新版本稳定性,升级后需充分测试自定义逻辑较多,需额外维护封装的接口和缓存机制
扩展性依赖 Spark 新版本的更新节奏可根据项目需求灵活扩展自定义功能
社区支持可直接参考官方文档和社区对新版本的案例需结合旧版本社区经验,同时参考自定义实现的维护文档
资源消耗新版本可能对硬件有新要求可通过优化缓存和批处理逻辑,更精细地控制资源使用

通过以上两种方法,可以在不同 Spark 版本环境下实现 Hugging Face 和 Spark 的结合,充分发挥两者在 NLP 和大数据处理中的优势,推荐第二种,更加可控一些。

本文由博客一文多发平台 OpenWrite 发布!

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Hugging Face Spark NLP 大数据处理 PySpark
相关文章