2025年3月28日 星期五 甲辰(龙)年 月廿七 设为首页 加入收藏
rss
您当前的位置:首页 > 计算机 > 编程开发 > 人工智能

利用深度学习实现验证码识别-2-使用Python导出ONNX模型并在Java中调用实现验证码识别

时间:09-24来源:作者:点击数:47
城东书院 www.cdsy.xyz
在这里插入图片描述
1. Python部分:导出ONNX模型

首先,我们需要在Python中定义并导出一个已经训练好的验证码识别模型。以下是完整的Python代码:

  • import string
  • import torch
  • import torch.nn as nn
  • import torch.nn.functional as F
  • CHAR_SET = string.digits
  • # 优化后的模型设计
  • class CaptchaModel(nn.Module):
  • def __init__(self):
  • super(CaptchaModel, self).__init__()
  • self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
  • self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  • self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
  • self.fc1 = nn.Linear(128 * 5 * 12, 256) # 调整为实际展平维度
  • self.fc2 = nn.Linear(256, 4 * len(CHAR_SET))
  • self.dropout = nn.Dropout(0.5)
  • def forward(self, x):
  • x = F.relu(F.max_pool2d(self.conv1(x), 2))
  • x = F.relu(F.max_pool2d(self.conv2(x), 2))
  • x = F.relu(F.max_pool2d(self.conv3(x), 2))
  • x = x.view(x.size(0), -1)
  • x = F.relu(self.fc1(x))
  • x = self.dropout(x)
  • x = self.fc2(x)
  • return x.view(-1, 4, len(CHAR_SET))
  • # 使用CUDA,如果可用的话
  • device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • print(f"Using device: {device}")
  • # 假设你的模型已经训练好并保存在 'best_model.pth'
  • model = CaptchaModel().to(device)
  • model.load_state_dict(torch.load('best_model.pth'))
  • # 生成一个测试输入 (示例输入的形状应与模型输入形状一致)
  • dummy_input = torch.randn(1, 1, 40, 100).to(device)
  • # 导出模型为 ONNX 格式
  • torch.onnx.export(model, dummy_input, "captcha_model.onnx",
  • input_names=["input"], output_names=["output"],
  • dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})
  • print("Model exported to captcha_model.onnx")

这段代码定义了一个验证码识别模型,并将其导出为ONNX格式,以便在Java中使用。

2. Java部分:调用ONNX模型进行验证码识别

接下来,我们使用Java调用导出的ONNX模型进行验证码识别。以下是完整的Java代码:

  • 引用onnxruntime-1.19.0.jar
  • package com.tushuoit;
  • import ai.onnxruntime.*;
  • import javax.imageio.ImageIO;
  • import java.awt.*;
  • import java.awt.image.BufferedImage;
  • import java.io.File;
  • import java.nio.FloatBuffer;
  • import java.util.ArrayList;
  • import java.util.Collections;
  • import java.util.Random;
  • import java.util.List;
  • public class CaptchaInference {
  • private static final String CHAR_SET = "0123456789";
  • private static final int INPUT_WIDTH = 100;
  • private static final int INPUT_HEIGHT = 40;
  • private static final Random random = new Random();
  • public static void main(String[] args) throws Exception {
  • // 随机生成4个字符的验证码文本
  • String captchaText = generateRandomText(4);
  • System.out.println("Generated Captcha Text: " + captchaText);
  • // 生成包含文本的Bitmap (BufferedImage)
  • BufferedImage captchaImage = generateCaptcha(captchaText, 36, INPUT_WIDTH, INPUT_HEIGHT);
  • // 将Bitmap保存为文件(仅用于查看生成的图像,实际使用中可以省略)
  • ImageIO.write(captchaImage, "png", new File("generated_captcha.png"));
  • // 将图像转换为浮点数数组,并进行归一化处理
  • float[] inputData = imageToFloatArray(captchaImage);
  • // 创建ONNX Runtime环境
  • OrtEnvironment env = OrtEnvironment.getEnvironment();
  • OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
  • // 加载ONNX模型
  • OrtSession session = env.createSession("captcha_model.onnx", opts);
  • // 创建输入张量
  • FloatBuffer inputBuffer = FloatBuffer.wrap(inputData);
  • OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputBuffer,
  • new long[] { 1, 1, INPUT_HEIGHT, INPUT_WIDTH });
  • // 进行推理
  • OrtSession.Result result = session.run(Collections.singletonMap("input", inputTensor));
  • // Extract output tensor and decode it
  • float[][][] outputData = (float[][][]) result.get(0).getValue();
  • List<String> decodedTexts = decodeOutput(outputData);
  • // Print the decoded captcha text
  • for (String text : decodedTexts) {
  • System.out.println("Predicted Captcha Text: " + text);
  • }
  • System.out.println("Inference completed.");
  • // 释放资源
  • session.close();
  • env.close();
  • }
  • // 随机生成指定长度的验证码文本
  • private static String generateRandomText(int length) {
  • StringBuilder text = new StringBuilder(length);
  • for (int i = 0; i < length; i++) {
  • text.append(CHAR_SET.charAt(random.nextInt(CHAR_SET.length())));
  • }
  • return text.toString();
  • }
  • // 生成包含文本的BufferedImage
  • private static BufferedImage generateCaptcha(String text, int fontSize, int width, int height) {
  • BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
  • Graphics2D g2d = image.createGraphics();
  • // 设置背景颜色为白色
  • g2d.setColor(Color.WHITE);
  • g2d.fillRect(0, 0, width, height);
  • // 设置字体和颜色
  • g2d.setFont(new Font("DroidSansMono", Font.PLAIN, fontSize));
  • g2d.setColor(Color.BLACK);
  • // 绘制文本
  • FontMetrics fm = g2d.getFontMetrics();
  • int x = 5; // 文字开始的X坐标
  • int y = fm.getAscent() + 5; // 文字开始的Y坐标
  • g2d.drawString(text, x, y);
  • g2d.dispose();
  • return image;
  • }
  • // 将BufferedImage转换为float数组,并进行归一化处理
  • private static float[] imageToFloatArray(BufferedImage image) {
  • int width = image.getWidth();
  • int height = image.getHeight();
  • float[] floatArray = new float[width * height];
  • for (int y = 0; y < height; y++) {
  • for (int x = 0; x < width; x++) {
  • int rgb = image.getRGB(x, y);
  • int gray = (rgb >> 16) & 0xFF; // 因为是灰度图,只需获取一个通道的值
  • floatArray[y * width + x] = (gray / 255.0f - 0.5f) * 2.0f; // 归一化到[-1, 1]
  • }
  • }
  • return floatArray;
  • }
  • private static List<String> decodeOutput(float[][][] outputData) {
  • List<String> decodedTexts = new ArrayList<>();
  • for (float[][] singleOutput : outputData) {
  • StringBuilder decodedText = new StringBuilder();
  • for (float[] charProbabilities : singleOutput) {
  • int maxIndex = getMaxIndex(charProbabilities);
  • decodedText.append(CHAR_SET.charAt(maxIndex));
  • }
  • decodedTexts.add(decodedText.toString());
  • }
  • return decodedTexts;
  • }
  • private static int getMaxIndex(float[] probabilities) {
  • int maxIndex = 0;
  • float maxProb = probabilities[0];
  • for (int i = 1; i < probabilities.length; i++) {
  • if (probabilities[i] > maxProb) {
  • maxProb = probabilities[i];
  • maxIndex = i;
  • }
  • }
  • return maxIndex;
  • }
  • }

这段Java代码首先生成一个随机的验证码图像,然后将其转换为模型输入格式,并通过ONNX Runtime调用导出的模型进行推理,最后解码模型的输出以获取识别的验证码文本。

在这里插入图片描述
总结

通过上述步骤,我们成功地在Python中导出了一个验证码识别模型,并在Java中调用该模型进行验证码识别。这种方法充分利用了Python在深度学习模型训练和导出方面的优势,以及Java在实际应用部署和性能方面的优势,实现了高效的验证码识别系统。

城东书院 www.cdsy.xyz
方便获取更多学习、工作、生活信息请关注本站微信公众号城东书院 微信服务号城东书院 微信订阅号
推荐内容
相关内容
栏目更新
栏目热门
本栏推荐