原创 江浩 2025-06-23 18:04 上海
CLIP+Milvus,如何用于以文搜图
安装方法2:下载Chinese-CLIP,从源代码安装。pip install cn_clip
cd Chinese-CLIP
!pip install -e .
# 创建模式
from pymilvus import MilvusClient, DataType
import torch
import time
milvus_client = MilvusClient(uri="http://localhost:19530")
def create_schema():
schema = milvus_client.create_schema(
auto_id=True,
enable_dynamic_field=True,
description=""
)
schema.add_field(field_name="id", datatype=DataType.INT64, descrition='ids', is_primary=True)
schema.add_field(field_name="vectors", datatype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=512)
schema.add_field(field_name="filepath", datatype=DataType.VARCHAR, descrition='file path', max_length=200)
return schema
schema = create_schema()
# 定义创建集合的函数
import time
def create_collection(collection_name, schema, timeout = 3):
# 创建集合
try:
milvus_client.create_collection(
collection_name=collection_name,
schema=schema,
shards_num=2
)
print(f"开始创建集合:{collection_name}")
except Exception as e:
print(f"创建集合的过程中出现了错误: {e}")
return False
# 检查集合是否创建成功
start_time = time.time()
while True:
if milvus_client.has_collection(collection_name):
print(f"集合 {collection_name} 创建成功")
return True
elif time.time() - start_time > timeout:
print(f"创建集合 {collection_name} 超时")
return False
time.sleep(1)
# 定义检查并且删除同名集合的函数
class CollectionDeletionError(Exception):
"""删除集合失败"""
def check_and_drop_collection(collection_name):
if milvus_client.has_collection(collection_name):
print(f"集合 {collection_name} 已经存在")
try:
milvus_client.drop_collection(collection_name)
print(f"删除集合:{collection_name}")
return True
except Exception as e:
print(f"删除集合时出现错误: {e}")
return False
return True
collection_name = "multimodal_chinese_clip"
uri="http://localhost:19530"
milvus_client = MilvusClient(uri=uri)
# 如果无法删除集合,抛出异常
if not check_and_drop_collection(collection_name):
raise CollectionDeletionError('删除集合失败')
else:
# 创建集合的模式
schema = create_schema()
# 创建集合并等待成功
create_collection(collection_name, schema)
import cn_clip.clip as clip
# 导入可用模型的函数
from cn_clip.clip import available_models
import torch
# 用于图片处理
from PIL import Image
# 查看 chinese-clip 中可用模型列表
print("Available models:", available_models())
Chinese-CLIP的embedding模型分成ViT(Vision Transformer)架构和RN(ResNet)架构两种。先介绍ViT系列模型,它的命名规律是,ViT-{参数规模}-{patch大小}-{输入图片分辨率}。第1个参数“ViT”表示模型的架构,第2个参数表示模型的参数规模,分成B(Base,中等规模)、L(Large,大规模)和H(Huge,超大规模),让我想起咖啡的中杯、大杯和超大杯。第3个参数指的是图片被分割成的patch的大小,14表示patch的尺寸是14 * 14像素。embedding模型在处理图片时,会先把图片分割成多个patch,类似于处理文本时,先对文本分块(详见[[03-鲁迅到底说没说?RAG之分块]])。输入图片的分辨率默认为224 * 224像素,否则会通过第4个参数指定。举个例子,“ViT-L-14-336”表示该embedding模型是ViT架构,参数规模为大规模,patch的尺寸是14 * 14,输入图片的分辨率是336 * 336。相比于ViT系列模型,RN系列模型的命名规律简单些:RN+层数。第1个参数“RN”同样表示模型架构,第2个参数表示层数。比如,RN50表示该模型基于50层的ResNet架构。为了方便演示,我们使用较小的“ViT-B-16”模型。通过clip.load_from_name函数下载、加载模型和预处理函数。Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
# 确定使用的设备:如果可用则使用GPU,否则使用CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# 指定模型名称
model_name = "ViT-B-16"
# 加载chinese-clip模型和对应的预处理函数
# model: 包含图片编码器(encode_image)和文本编码器(encode_text)
# preprocess: 图片预处理函数(包括归一化、缩放等操作)
# download_root: 设置模型下载后保存的位置
model, preprocess = clip.load_from_name(model_name, device=device, download_root='./chinese_clip_model')
# 将模型设置为评估模式,关闭dropout等训练特性
model.eval()
print("-"*50)
print(f"Model Loaded: {model_name}")
def encode_image(image_path):
# 关闭梯度计算,减少内存消耗,提高计算效率
with torch.no_grad():
# 打开图片文件
# 如果图片不是RGB格式,使用convert转换格式
raw_image = Image.open(image_path).convert('RGB')
processed_image = preprocess(raw_image).unsqueeze(0).to(device)
# 生成图片的向量
image_features = model.encode_image(processed_image)
# 特征归一化
image_features /= image_features.norm(dim=-1, keepdim=True)
# 以列表形式返回向量
return image_features.squeeze().tolist()
def encode_text(text_list):
# 关闭梯度计算,减少内存消耗,提高计算效率
with torch.no_grad():
# 文本分词和特殊符号处理
text_tokens = clip.tokenize(text_list).to(device)
# 生成文本的向量
text_features = model.encode_text(text_tokens)
# 特征归一化
text_features /= text_features.norm(dim=-1, keepdim=True)
# 以列表形式返回向量
return [f.squeeze().tolist() for f in text_features]
# 定义插入数据的函数
import os
from glob import glob
from tqdm import tqdm
import time
# 进度条显示一个变化的进度条,而不是多个不同进度的进度条
def process_images_and_insert(input_dir_path, ext_list, batch_size=100):
# 获取所有JPEG文件路径(递归图片检索)
image_paths = []
for ext in ext_list:
image_paths.extend(glob(os.path.join(input_dir_path, f"**/{ext}"), recursive=True))
total_images = len(image_paths)
print(f"总计需要处理 {total_images} 张图片")
# 初始化总计时器
total_start_time = time.time()
# 初始化进度条
with tqdm(total=total_images, desc="处理图片并插入数据") as progress_bar:
# 分批处理图片
for batch_start in range(0, total_images, batch_size):
batch_data = []
batch_paths = image_paths[batch_start: batch_start + batch_size]
batch_start_time = time.time()
# 当前批次的向量化处理
for image_path in batch_paths:
try:
image_embedding = encode_image(image_path)
batch_data.append({
"vectors": image_embedding,
"filepath": image_path
})
except Exception as e:
print(f"处理图片 {image_path} 时出错: {str(e)}")
continue
# 批量插入当前批次到Milvus
if batch_data:
try:
res = milvus_client.insert(
collection_name=collection_name,
data=batch_data
)
# 计算批次耗时
batch_duration = time.time() - batch_start_time
# 更新进度条:每次成功插入的图片数量
progress_bar.update(len(batch_data))
# 显示批次处理时间
progress_bar.set_postfix({
"批次耗时": batch_duration,
})
except Exception as e:
print(f"插入批次 {batch_start} 时失败: {str(e)}")
# 计算总耗时
total_duration = time.time() - total_start_time
print(f"\n所有图片处理完成!总耗时: total_duration)")
print(f"平均处理速度: {total_images/total_duration:.1f}张/秒")
# 插入数据
input_dir_path = "lhq_1024_jpg_5000"
# 每批处理数量
batch_size = 300
ext_list = ['*.JPEG', '*.jpg', '*.png']
process_images_and_insert(input_dir_path, ext_list, batch_size)
# 定义创建索引的函数
def create_index(collection_name):
# 准备索引参数
index_params = milvus_client.prepare_index_params()
index_params.add_index(
index_name="IVF_FLAT",
# 指定创建索引的字段
field_name="vectors",
index_type="IVF_FLAT",
metric_type="COSINE",
params={"nlist":512}
)
# 创建索引
milvus_client.create_index(
collection_name=collection_name,
index_params=index_params
)
create_index(collection_name)
集合加载成功了吗?验证下看看。# 加载集合
print(f"正在加载集合 {collection_name}")
milvus_client.load_collection(collection_name=collection_name)
print(f"集合 {collection_name} 加载完成")
# 验证加载状态
state = str(milvus_client.get_load_state(collection_name=collection_name)['state'])
if state == 'Loaded':
print("集合加载完成")
else:
print("集合加载失败")
print(milvus_client.query(
collection_name=collection_name,
output_fields=["count(*)"]
)
)
03 结果展示使用Chinese-CLIP可以实现以文搜图以及以图搜图,其实本质都是相同的,都是根据查询(文字或者图片)生成查询向量,再从Milvus中检索与查询向量最接近的图片的向量,最后返回该图片。先来试试以文搜图吧。以文搜图首先定义图片检索函数。输入查询向量(vector)、图片检索的字段(field_name)、返回结果的数量(limit)以及输出的字段(output_fields),返回图片检索结果。data: ["{'count(*)': 5000}"]
# 定义图片检索函数
def vector_search(vector, field_name, limit, output_fields):
# 执行向量图片检索
res = milvus_client.search(
collection_name=collection_name,
data=vector,
anns_field=field_name,
limit=limit,
output_fields=output_fields
)
return res
# 以文搜图
query_text = ["枯藤老树昏鸦"]
query_embedding = encode_text(query_text)[0]
field_name = "vectors"
limit = 10
output_fields = ["filepath"]
res = vector_search([query_embedding], field_name, limit, output_fields)
from IPython.display import display
from PIL import Image
# 定义显示图片检索结果的函数
def create_concatenated_image(res, images_per_row=2, images_per_column=2, image_size=(400, 400)):
# 设置拼接后的大图尺寸:
width = image_size[0] * images_per_row
height = image_size[1] * images_per_column
# 创建一个空白的大画布(RGB模式,白色背景)
concatenated_image = Image.new("RGB", (width, height))
# 存储所有结果图片的列表
result_images = []
# 遍历图片检索结果的每个hit对象(res是包含多个batch的列表)
for result in res: # 通常res是单batch列表
for hit in result:
# 从hit对象中获取图片文件路径
filename = hit["entity"]["filepath"]
# 打开图片文件并调整大小为指定尺寸
img = Image.open(filename)
# 保持宽高比的缩略图
img = img.resize(image_size)
# 将处理后的图片添加到列表
result_images.append(img)
# 将缩略图拼接到大画布上
for idx, img in enumerate(result_images):
# 计算当前图片应放置的网格位置:
# 列索引(每行显示images_per_row张图)
x = idx % images_per_row
# 行索引(整数除法)
y = idx // images_per_row
# 将图片粘贴到计算好的位置
concatenated_image.paste(img, (x * image_size[0], y * image_size[1]))
return concatenated_image
# 查询文本
print(f"查询文本: {query_text}")
# 图片检索结果
print(f"检索结果:")
display(create_concatenated_image(res, 2, 2, (400, 400)))
# 显示查询图片
def show_single_image(image_path, image_size=(300, 300)):
# 打开图片
img = Image.open(image_path)
# 保持宽高比的前提下缩小图片,图片缩小后的最大值不超过指定值
img.thumbnail(image_size)
# 缩放图片到指定尺寸
# img = img.resize(image_size)
# 显示图片
display(img)
# 定义查询图片
query_image = 'query_image.jpg'
query_embedding = encode_image(query_image)
field_name = "vectors"
limit = 10
output_fields = ["filepath"]
res = vector_search([query_embedding], field_name, limit, output_fields)
# 查询图片
print(f"查询图片")
show_single_image(query_image)
# 图片检索结果
print(f"图片检索结果:")
concatenated_image = create_concatenated_image(res)
display(concatenated_image)
作者介绍
Zilliz 黄金写手:江浩