首先,我们需要在Python中定义并导出一个已经训练好的验证码识别模型。以下是完整的Python代码:
- import string
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
# Characters the model can predict; one output class per character.
CHAR_SET = string.digits


class CaptchaModel(nn.Module):
    """CNN that classifies every character position of a fixed-length captcha.

    The network is three 3x3 convolutions (each followed by 2x2 max-pooling
    and ReLU), a 256-unit hidden layer with dropout, and a final linear layer
    that emits one score vector per character position.

    With the default arguments the layer names and shapes are identical to
    the original hard-coded architecture (4 characters, 10 classes,
    1x40x100 input), so previously saved state dicts still load.

    Args:
        num_chars: Number of characters in one captcha.
        num_classes: Number of classes per character position.
        input_size: ``(height, width)`` of the single-channel input image.
    """

    def __init__(self, num_chars=4, num_classes=len(CHAR_SET), input_size=(40, 100)):
        super().__init__()
        self.num_chars = num_chars
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)

        # The padded convolutions keep the spatial size; each of the three
        # 2x2 max-pools in forward() halves it (rounding down). For the
        # default 40x100 input this yields 5x12, i.e. 128 * 5 * 12 features.
        height, width = input_size
        for _ in range(3):
            height, width = height // 2, width // 2

        self.fc1 = nn.Linear(128 * height * width, 256)
        self.fc2 = nn.Linear(256, num_chars * num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return scores of shape ``(batch, num_chars, num_classes)``.

        Args:
            x: Tensor of shape ``(batch, 1, height, width)`` where
                ``(height, width)`` matches ``input_size``.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = x.view(x.size(0), -1)  # flatten everything except the batch axis
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        # Split the flat score vector into one row per character position.
        return x.view(-1, self.num_chars, self.num_classes)
-
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Assumes the trained weights have been saved to 'best_model.pth'.
model = CaptchaModel().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source (consider weights_only=True if the installed PyTorch
# version supports it).
model.load_state_dict(torch.load('best_model.pth', map_location=device))
# Put the model in inference mode so Dropout is inactive when it is traced.
model.eval()

# Example input; its shape must match the shape the model expects.
dummy_input = torch.randn(1, 1, 40, 100).to(device)

# Export to ONNX. Axis 0 of both input and output is declared dynamic so the
# exported model accepts any batch size.
torch.onnx.export(
    model,
    dummy_input,
    "captcha_model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

print("Model exported to captcha_model.onnx")
-
这段代码定义了验证码识别模型,从 best_model.pth 加载已训练好的权重,并将其导出为支持动态 batch 大小的 ONNX 文件 captcha_model.onnx,以便在Java中使用。
接下来,我们使用Java调用导出的ONNX模型进行验证码识别。以下是完整的Java代码:
- package com.tushuoit;
-
- import ai.onnxruntime.*;
- import javax.imageio.ImageIO;
- import java.awt.*;
- import java.awt.image.BufferedImage;
- import java.io.File;
- import java.nio.FloatBuffer;
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.Random;
- import java.util.List;
-
- public class CaptchaInference {
- private static final String CHAR_SET = "0123456789";
- private static final int INPUT_WIDTH = 100;
- private static final int INPUT_HEIGHT = 40;
- private static final Random random = new Random();
-
- public static void main(String[] args) throws Exception {
- // 随机生成4个字符的验证码文本
- String captchaText = generateRandomText(4);
- System.out.println("Generated Captcha Text: " + captchaText);
-
- // 生成包含文本的Bitmap (BufferedImage)
- BufferedImage captchaImage = generateCaptcha(captchaText, 36, INPUT_WIDTH, INPUT_HEIGHT);
-
- // 将Bitmap保存为文件(仅用于查看生成的图像,实际使用中可以省略)
- ImageIO.write(captchaImage, "png", new File("generated_captcha.png"));
-
- // 将图像转换为浮点数数组,并进行归一化处理
- float[] inputData = imageToFloatArray(captchaImage);
-
- // 创建ONNX Runtime环境
- OrtEnvironment env = OrtEnvironment.getEnvironment();
- OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
-
- // 加载ONNX模型
- OrtSession session = env.createSession("captcha_model.onnx", opts);
-
- // 创建输入张量
- FloatBuffer inputBuffer = FloatBuffer.wrap(inputData);
- OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputBuffer,
- new long[] { 1, 1, INPUT_HEIGHT, INPUT_WIDTH });
-
- // 进行推理
- OrtSession.Result result = session.run(Collections.singletonMap("input", inputTensor));
-
- // Extract output tensor and decode it
- float[][][] outputData = (float[][][]) result.get(0).getValue();
- List<String> decodedTexts = decodeOutput(outputData);
-
- // Print the decoded captcha text
- for (String text : decodedTexts) {
- System.out.println("Predicted Captcha Text: " + text);
- }
-
- System.out.println("Inference completed.");
- // 释放资源
- session.close();
- env.close();
- }
-
- // 随机生成指定长度的验证码文本
- private static String generateRandomText(int length) {
- StringBuilder text = new StringBuilder(length);
- for (int i = 0; i < length; i++) {
- text.append(CHAR_SET.charAt(random.nextInt(CHAR_SET.length())));
- }
- return text.toString();
- }
-
- // 生成包含文本的BufferedImage
- private static BufferedImage generateCaptcha(String text, int fontSize, int width, int height) {
- BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
- Graphics2D g2d = image.createGraphics();
-
- // 设置背景颜色为白色
- g2d.setColor(Color.WHITE);
- g2d.fillRect(0, 0, width, height);
-
- // 设置字体和颜色
- g2d.setFont(new Font("DroidSansMono", Font.PLAIN, fontSize));
- g2d.setColor(Color.BLACK);
-
- // 绘制文本
- FontMetrics fm = g2d.getFontMetrics();
- int x = 5; // 文字开始的X坐标
- int y = fm.getAscent() + 5; // 文字开始的Y坐标
- g2d.drawString(text, x, y);
-
- g2d.dispose();
- return image;
- }
-
- // 将BufferedImage转换为float数组,并进行归一化处理
- private static float[] imageToFloatArray(BufferedImage image) {
- int width = image.getWidth();
- int height = image.getHeight();
- float[] floatArray = new float[width * height];
-
- for (int y = 0; y < height; y++) {
- for (int x = 0; x < width; x++) {
- int rgb = image.getRGB(x, y);
- int gray = (rgb >> 16) & 0xFF; // 因为是灰度图,只需获取一个通道的值
- floatArray[y * width + x] = (gray / 255.0f - 0.5f) * 2.0f; // 归一化到[-1, 1]
- }
- }
-
- return floatArray;
- }
-
- private static List<String> decodeOutput(float[][][] outputData) {
- List<String> decodedTexts = new ArrayList<>();
- for (float[][] singleOutput : outputData) {
- StringBuilder decodedText = new StringBuilder();
- for (float[] charProbabilities : singleOutput) {
- int maxIndex = getMaxIndex(charProbabilities);
- decodedText.append(CHAR_SET.charAt(maxIndex));
- }
- decodedTexts.add(decodedText.toString());
- }
- return decodedTexts;
- }
-
- private static int getMaxIndex(float[] probabilities) {
- int maxIndex = 0;
- float maxProb = probabilities[0];
- for (int i = 1; i < probabilities.length; i++) {
- if (probabilities[i] > maxProb) {
- maxProb = probabilities[i];
- maxIndex = i;
- }
- }
- return maxIndex;
- }
- }
-
这段Java代码首先生成一张随机的验证码图像,将其像素值归一化到[-1, 1]并转换为形状为[1, 1, 40, 100]的输入张量,然后通过ONNX Runtime调用导出的模型进行推理,最后对输出中每个字符位置取得分最高的类别,解码得到识别出的验证码文本。
通过上述步骤,我们成功地在Python中导出了一个验证码识别模型,并在Java中调用该模型进行验证码识别。这种方法充分利用了Python在深度学习模型训练和导出方面的优势,以及Java在实际应用部署和性能方面的优势,实现了高效的验证码识别系统。