Monday, May 1, 2023

How to Implement Image Classification Using TensorFlow, Maven, and Java

Here is an example of using TensorFlow with Java and Maven to classify an image with a pre-trained model saved as a frozen graph (a `.pb` file):

1. Create a new Maven project in your favorite IDE.

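2. Add the TensorFlow Java dependency to your pom.xml file. The code in step 3 uses the TensorFlow 1.x Java API (`Graph`, `Session`, `Tensor.create`), which is published as `org.tensorflow:tensorflow`; 1.15.0 is the last release of that artifact. Newer TensorFlow Java releases use a different artifact (`tensorflow-core-platform`) with a different API.

```xml
<dependencies>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.15.0</version>
    </dependency>
</dependencies>
```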

3. Create a new class, for example ImageClassifier.java, and add the following code:
```java
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import javax.imageio.ImageIO;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class ImageClassifier {
    // Unpacks the image into a flat RGB array in (height, width, channel) order.
    // Assumption: the model expects float32 values scaled to [0, 1]; adjust the
    // scaling if your model was trained with different preprocessing.
    private static float[] loadImage(BufferedImage img) {
        int height = img.getHeight();
        int width = img.getWidth();
        int channels = 3;
        float[] data = new float[height * width * channels];
        int pixel = 0;
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = img.getRGB(x, y);
                data[pixel++] = ((rgb >> 16) & 0xFF) / 255.0f; // red
                data[pixel++] = ((rgb >> 8) & 0xFF) / 255.0f;  // green
                data[pixel++] = (rgb & 0xFF) / 255.0f;         // blue
            }
        }
        return data;
    }

    public static void main(String[] args) throws IOException {
        // Import the frozen model (a serialized GraphDef) into a new graph
        try (Graph g = new Graph()) {
            byte[] graphBytes = Files.readAllBytes(Paths.get("path/to/model.pb"));
            g.importGraphDef(graphBytes);

            // Create a new session to run the graph
            try (Session s = new Session(g)) {
                // Load the image; ImageIO returns null for unsupported formats
                String imagePath = "path/to/image.jpg";
                BufferedImage img = ImageIO.read(new File(imagePath));
                if (img == null) {
                    throw new IOException("Could not decode " + imagePath);
                }
                float[] pixels = loadImage(img);

                // Shape [batch, height, width, channels]; the image must already have
                // the input size the model was trained on
                long[] shape = {1, img.getHeight(), img.getWidth(), 3};

                // "input" and "output" must match the operation names in your graph
                try (Tensor<Float> inputTensor = Tensor.create(shape, FloatBuffer.wrap(pixels));
                     Tensor<Float> outputTensor = s.runner()
                             .feed("input", inputTensor)
                             .fetch("output")
                             .run()
                             .get(0)
                             .expect(Float.class)) {
                    // Assumption: the output has shape [1, numClasses], one score per class
                    float[][] scores = new float[1][(int) outputTensor.shape()[1]];
                    outputTensor.copyTo(scores);

                    // The prediction is the class with the highest score
                    int best = 0;
                    for (int i = 1; i < scores[0].length; i++) {
                        if (scores[0][i] > scores[0][best]) {
                            best = i;
                        }
                    }
                    System.out.println("Predicted class index: " + best
                            + " (score " + scores[0][best] + ")");
                }
            }
        }
    }
}
```

4. Replace path/to/model.pb and path/to/image.jpg with the actual paths to your model and image files, and change "input" and "output" to the names of your graph's input and output operations.

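5. Run the ImageClassifier class. It prints the model's prediction for the input image. Classification models usually output one score per class, so the result identifies a class index that you look up in the label list that comes with your model to get a human-readable name.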
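Because a classifier's output identifies each class only by its position, a small helper can turn the score vector into readable top-k predictions. This is a minimal sketch, assuming your model ships with a plain-text labels file that has one class name per line, in the same order as the output scores; `LabelLookup`, `readLabels`, and `topK` are hypothetical names, not part of TensorFlow.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical helper: maps a classifier's score vector to human-readable labels.
final class LabelLookup {
    private LabelLookup() {}

    // Assumes one class name per line, so line i names output index i.
    static List<String> readLabels(Path labelsFile) throws IOException {
        return Files.readAllLines(labelsFile, StandardCharsets.UTF_8);
    }

    // Returns the k highest-scoring classes as "label (score)", best first.
    static List<String> topK(float[] scores, List<String> labels, int k) {
        if (scores.length != labels.size()) {
            throw new IllegalArgumentException(
                    scores.length + " scores but " + labels.size() + " labels");
        }
        return IntStream.range(0, scores.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
                .limit(k)
                .map(i -> String.format(Locale.ROOT, "%s (%.3f)", labels.get(i), scores[i]))
                .collect(Collectors.toList());
    }
}
```

To use it, pass the scores copied out of the output tensor (for example the first row of a `float[1][numClasses]` array) together with the labels, e.g. `LabelLookup.topK(scores[0], LabelLookup.readLabels(Paths.get("path/to/labels.txt")), 5)`, where the labels path is a hypothetical placeholder. The size check fails loudly when the labels file does not line up with the model, which is easier to debug than silently mislabeled output.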

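Most image classification models accept only the fixed input size they were trained on, while the code in step 3 passes the image at whatever resolution it has on disk. The following is a minimal sketch of resizing with the standard `java.awt` API before the pixels are converted; `ImageResizer` is a hypothetical helper name, and the 224x224 size in the usage line below is only a common example, so check what your model expects.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

// Hypothetical helper: scales an image to the model's expected input size.
final class ImageResizer {
    private ImageResizer() {}

    // Draws src into a new width x height TYPE_INT_RGB image. The RGB target has
    // no alpha channel, so any transparent areas are composited onto black.
    static BufferedImage resize(BufferedImage src, int width, int height) {
        BufferedImage dst = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = dst.createGraphics();
        try {
            // Bilinear interpolation is a reasonable default for photos
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                    RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g.drawImage(src, 0, 0, width, height, null);
        } finally {
            g.dispose();
        }
        return dst;
    }
}
```

For example: `BufferedImage img = ImageResizer.resize(ImageIO.read(new File("path/to/image.jpg")), 224, 224);`. This stretches the image to the target size; if your model was trained on center-cropped inputs, crop to the right aspect ratio first.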