掘金 人工智能 06月23日 23:38
Pytorch风格迁移的Android实现
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文介绍了如何在 Android 平台上部署 PyTorch 风格迁移模型,实现图像风格的转换。通过使用 PyTorch Lite,开发者能够将预训练的风格迁移模型集成到 Android 应用中,从而将内容图像转化为具有特定艺术风格的图像。文章详细阐述了模型部署的流程,包括图像预处理、模型输入、推理以及结果后处理,并提供了代码示例和性能分析。 最终实现了将照片转换为梵高、毕加索等艺术风格的效果。

🎨 **风格迁移技术核心**:风格迁移是一种基于深度学习的图像处理技术,它结合了内容图像的结构和风格图像的艺术特征,生成具有两者特点的新图像。其核心在于利用卷积神经网络提取图像的高层语义特征,并融合风格图像的纹理和色彩。

📱 **PyTorch Lite 在 Android 上的应用**:文章重点介绍了如何在 Android 平台上使用 PyTorch Lite 实现风格迁移。PyTorch Lite 使得在移动设备上部署模型成为可能,降低了对算力和存储的需求,从而实现了在手机上运行风格迁移模型。

🖼️ **图像预处理与后处理**:为了使图像适应 PyTorch 模型的输入,需要进行预处理,包括将图像转换为 Tensor 格式。在模型推理完成后,还需要进行后处理,将输出的 Tensor 转换为 Android 平台可用的 Bitmap 图像,以呈现风格转换后的结果。

⚙️ **模型部署流程**:文章详细介绍了风格迁移模型在 Android 上的部署流程,包括模型加载、图像预处理、模型推理、结果后处理等步骤。通过具体的代码示例,展示了如何将预训练的风格迁移模型集成到 Android 应用中。

⏱️ **性能评估**:文章还对风格迁移的处理时间进行了评估。通过对图像进行缩放,可以在保证效果的同时,缩短风格迁移的耗时,在测试手机上,耗时仅2秒左右。

前言

使用神经网络训练后的模型可以实现图像风格的迁移,比如前一阵非常火爆的吉卜力风格。本文尝试在 Android 实现 PyTorch 风格迁移模型的部署, 实现基于固定风格的迁移。

什么是图像风格迁移

图像风格迁移(Neural Style Transfer, NST) 是一种基于深度学习的图像处理技术,通过结合内容图像的结构和风格图像的艺术特征(如笔触、色彩),生成兼具两者特点的新图像。其核心思想是利用卷积神经网络(如VGG19)提取图像的高层语义特征:

典型应用包括将照片转化为梵高、毕加索等艺术风格,或生成吉卜力动画风格的画面。技术变体包括实时风格迁移(如AdaIN)和基于扩散模型的生成方法。

基于 PyTorchLite 的风格迁移

在上一篇 PyTorch对抗生成网络模型及Android端的实现 中,我们通过使用 PyTorch Android 端的 SDK 实现了 GAN 在手机端的部署。但是由于算力和存储的限制,在手机上部署图像生成模型压力还是比较大的,存储和推理速度都是瓶颈。然而, 风格迁移的实现相对来说就比较简单了,算力和存储占用大幅降低的情况之下,也能收获相对来说不错的效果。

可以看到,以上两个示例的效果还是不错的。

PyTorch Lite 处理输入输出

我们先回顾一下 PyTorch Lite 的用法,上一节中对于基于 GAN 框架的生成模型来说,按照训练期间模型的定义,我们的输入是 1x100 的随机数,输出是 1x64x64x3 (64 像素大小彩色图片)。

    private fun genImage(): Bitmap {        val zDim = intArrayOf(1, 100)        val outDims = intArrayOf(64, 64, 3)            val z = FloatArray(zDim[0] * zDim[1])              val rand = Random()        // 生成高斯随机数        for (c in 0 until zDim[0] * zDim[1]) {            z[c] = rand.nextGaussian().toFloat()        }            val tensor = Tensor.fromBlob(z, longArrayOf(1, 100))                val resultArray = module.forward(IValue.from(tensor)).toTensor().dataAsFloatArray        val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }        var index = 0        // 根据输出的一维数组,解析生成的卡通图像        ....        val bitmap = Utils.getBitmap(resultImg, outDims)        return bitmap    }

因此,我们做的工作就是构建适用于模型输入结构的数据 inputTensor ,同时根据模型返回的数据 resultTensor,将数据转换为我们需要的格式,比如对于生成模型来说,转化为对应平台可以渲染的图像数据即可,对于Android 来说就是常用的 Bitmap。

而对于风格迁移模型来说,一般情况下输入有两个,内容图片和风格图片。这里简单起见,风格类型由模型固化,我们只处理内容图片的输入,生成固定风格的图片。

模型初始化

这里我们直接使用 PyTorch 官方示例中 fast_neural_style 风格迁移所用到的模型。具体模型可以从 这里 下载,转换为 PyTorch Lite 可用的模型即可,我们以风格比较鲜明的 mosaic 为例。

module = LiteModuleLoader.load(AndroidAssetsFileUtil.assetFilePath(this, "mosaic.pt"))

对 Bitmap 进行风格迁移

    fun transferStyleAsync(        contentImage: Bitmap, scale: Float = 1.0f, cb: ((Bitmap) -> Unit)? = null    ): Bitmap {                // 1. Preprocess the content image (simple ToTensor + multiply by 255)        val (transformedImage, width, height) = preprocessImage(contentImage, scale)        // 2. Create input tensor        val inputTensor = Tensor.fromBlob(            transformedImage, longArrayOf(1, 3, height.toLong(), width.toLong())        )                // 3. Run the model        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()        // 4. Postprocessor the output        return postprocessingImage(outputTensor, width, height, cb)    }

HWC 和 CHW

这里需要注意的是,Android 中标准的 Bitmap 其数据是 RGB 的格式进行存储的,而 PyTorch 中为了方便内存优化进行计算,是按照 CHW 的格式排列数据的,因此将传统的 Bitmap 传入 PyTorch 进行处理之前需要进行数据转换。同时,由于数据数据的宽高涉及到后续推理模型的计算,因此还需要返回内容图片的宽高值。

返回结果的处理

有了原始 Bitmap 数据和宽高,我们就可以调用模型进行推理了。

        val inputTensor = Tensor.fromBlob(            transformedImage, longArrayOf(1, 3, height.toLong(), width.toLong())        )        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()

返回结果依然是 Tensor 类型,我们需要从中获取所需要的 Bitmap 图像数据

    private fun postprocessingImage(        outputTensor: Tensor, width: Int, height: Int): Bitmap {                val outputData = outputTensor.dataAsFloatArray        // Create output bitmap        val outputBitmap = createBitmap(width, height)        val pixels = IntArray(width * height)        // Convert from CHW to ARGB format        val channelSize = width * height               for (i in 0 until channelSize) {            // Get RGB values (scaled back from 0-255)            val r = outputData[i].toInt().coerceIn(0, 255)            val g = outputData[i + channelSize].toInt().coerceIn(0, 255)            val b = outputData[i + 2 * channelSize].toInt().coerceIn(0, 255)            // Combine into ARGB pixel            pixels[i] = 0xFF shl 24 or (r shl 16) or (g shl 8) or b        }        // Set pixels to bitmap        outputBitmap.setPixels(pixels, 0, width, 0, 0, width, height)        return outputBitmap    }

由于返回的 dataAsFloatArray 依然是 CHW 格式的数据,我们需要再执行一次逆向操作,将 CHW 格式的数据转换为 ARGB 格式的 Bitmap 数据。我们可以看一下效果(示例中最后一个就是 mosaic 风格)。

可以顺便看一眼耗时

16:49:39.471 StyleTransferProcessor   D  transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.516:49:39.471 StyleTransferProcessor   D  preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.516:49:39.477 StyleTransferProcessor   I  scale 468,83216:49:39.511 StyleTransferProcessor   D  transferStyle() step2 done16:49:41.332 StyleTransferProcessor   D  transferStyle() step3 done16:49:41.332 StyleTransferProcessor   D  postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 83216:49:41.354 StyleTransferProcessor   D  channelSize = 38937616:49:41.466 StyleTransferProcessor   D  bitmap is ok16:50:21.199 StyleTransferProcessor   D  transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.516:50:21.199 StyleTransferProcessor   D  preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.516:50:21.202 StyleTransferProcessor   I  scale 468,83216:50:21.228 StyleTransferProcessor   D  transferStyle() step2 done16:50:23.159 StyleTransferProcessor   D  transferStyle() step3 done16:50:23.160 StyleTransferProcessor   D  postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 83216:50:23.185 StyleTransferProcessor   D  channelSize = 38937616:50:23.307 StyleTransferProcessor   D  bitmap is ok16:50:31.999 StyleTransferProcessor   D  transferStyle() called with: contentImage = android.graphics.Bitmap@3e8f24, scale = 0.516:50:31.999 StyleTransferProcessor   D  preprocessImage() called with: bitmap = android.graphics.Bitmap@3e8f24, scale = 0.516:50:32.002 StyleTransferProcessor   I  scale 468,83216:50:32.028 StyleTransferProcessor   D  transferStyle() step2 done16:50:33.716 StyleTransferProcessor   D  transferStyle() step3 done16:50:33.716 StyleTransferProcessor   D  postprocessingImage() called with: outputTensor = Tensor([1, 3, 832, 468], dtype=torch.float32), width = 468, height = 83216:50:33.743 StyleTransferProcessor   D  channelSize = 38937616:50:33.856 StyleTransferProcessor   D  bitmap is ok

可以看到,对原始图片宽高按照 50% 的比例压缩之后,转换时间还是挺快的,2 秒基本上就可以完成一副图片的转换(测试手机为一加 8, 骁龙 865 ,8GB RAM)

总结

有兴趣的话,可以运行官方示例 fast_neural_style,对比一下各自的效果。还可以基于官方示例,训练自己的风格模型,尝试实现不同的风格。

在手机端除了由于内存的限制,以及转换后模型精度的变化,多多少少还是有一些损失。甚至对于一些像素比较高的图片,不进行压缩的话会出现内存不足无法运行的情况。当然,风格迁移的实现方式有很多种,现在很多手机自带的相册和相机都可以实时进行固定风格的迁移,技术总是在不断变化的进步,相信随着这一波生成式人工智能的发展,风格迁移模型可以变得更加强大。

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

PyTorch 风格迁移 Android 图像处理 深度学习
相关文章