Monday, May 1, 2023

How to Implement Image Classification Using TensorFlow, Maven, and Java

Here is an example of using TensorFlow with Java and Maven to perform image classification: 

1. Create a new Maven project in your favorite IDE.

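2. Add the TensorFlow Java dependency to your project by adding the following to your pom.xml file. The code in the next step uses the legacy TensorFlow 1.x Java API (Graph, Session, Tensor.create), which is published as org.tensorflow:tensorflow; its last release is 1.15.0:

    <dependencies>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.15.0</version>
        </dependency>
    </dependencies>
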

3. Create a new class, for example ImageClassifier.java, and add the following code:
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import javax.imageio.ImageIO;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;

public class ImageClassifier {
    // Packs the image's pixels into interleaved RGB bytes, row by row
    private static byte[] toRgbBytes(BufferedImage img) {
        int height = img.getHeight();
        int width = img.getWidth();
        int channels = 3;
        byte[] data = new byte[height * width * channels];
        int pixel = 0;
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = img.getRGB(x, y);
                data[pixel++] = (byte) ((rgb >> 16) & 0xFF); // red
                data[pixel++] = (byte) ((rgb >> 8) & 0xFF);  // green
                data[pixel++] = (byte) (rgb & 0xFF);         // blue
            }
        }
        return data;
    }

    public static void main(String[] args) throws Exception {
        // Read the serialized GraphDef (frozen model) from disk
        byte[] graphBytes = Files.readAllBytes(Paths.get("path/to/model.pb"));

        try (Graph g = new Graph()) {
            g.importGraphDef(graphBytes);

            // Create a new session to run the graph
            try (Session s = new Session(g)) {
                // Load the image data
                String imagePath = "path/to/image.jpg";
                BufferedImage img = ImageIO.read(new File(imagePath));
                if (img == null) {
                    throw new IOException("Unsupported or unreadable image: " + imagePath);
                }
                byte[] imageBytes = toRgbBytes(img);

                // Create a uint8 tensor of shape [batch, height, width, channels].
                // Assumption: the model accepts raw uint8 pixels at this size; many
                // models need the image resized and converted to floats first.
                long[] inputShape = {1, img.getHeight(), img.getWidth(), 3};
                try (Tensor<UInt8> inputTensor =
                             Tensor.create(UInt8.class, inputShape, ByteBuffer.wrap(imageBytes));
                     // Run the graph; "input" and "output" must match the operation
                     // names in your model
                     Tensor<?> outputTensor = s.runner()
                             .feed("input", inputTensor)
                             .fetch("output")
                             .run()
                             .get(0)) {

                    // Assumption: the output is a float tensor of shape [1, numClasses]
                    long[] outputShape = outputTensor.shape();
                    float[][] scores = new float[(int) outputShape[0]][(int) outputShape[1]];
                    outputTensor.copyTo(scores);

                    // Print the index of the highest-scoring class
                    int best = 0;
                    for (int i = 1; i < scores[0].length; i++) {
                        if (scores[0][i] > scores[0][best]) {
                            best = i;
                        }
                    }
                    System.out.println("Prediction: class " + best
                            + " (score " + scores[0][best] + ")");
                }
            }
        }
    }
}
4. Replace path/to/model.pb and path/to/image.jpg with the actual paths to your model and image files, and make sure the "input" and "output" names match the input and output operation names in your model's graph.

5. Run the ImageClassifier class. It should print the model's prediction for the input image.
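
Many image models expect a fixed input size and float pixel values rather than raw bytes at the image's original resolution. As a minimal sketch, assuming a model that wants RGB values scaled to the range [0, 1] at some fixed width and height (check your own model's requirements), the image could be prepared with the standard java.awt classes before the tensor is built. The class name ImagePreprocessor and the method toNormalizedRgb are illustrative and not part of TensorFlow.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

public class ImagePreprocessor {
    // Scales the image to width x height and returns interleaved RGB values in [0, 1].
    // The target size and the [0, 1] scaling are assumptions; use what your model expects.
    public static float[] toNormalizedRgb(BufferedImage src, int width, int height) {
        BufferedImage resized = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2d = resized.createGraphics();
        try {
            g2d.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                    RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            // Draw the source image stretched (or shrunk) onto the target canvas
            g2d.drawImage(src, 0, 0, width, height, null);
        } finally {
            g2d.dispose();
        }

        float[] data = new float[height * width * 3];
        int i = 0;
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = resized.getRGB(x, y);
                data[i++] = ((rgb >> 16) & 0xFF) / 255.0f; // red
                data[i++] = ((rgb >> 8) & 0xFF) / 255.0f;  // green
                data[i++] = (rgb & 0xFF) / 255.0f;         // blue
            }
        }
        return data;
    }

    public static void main(String[] args) {
        // A solid red 10x10 image stands in for a real photo
        BufferedImage red = new BufferedImage(10, 10, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < 10; y++) {
            for (int x = 0; x < 10; x++) {
                red.setRGB(x, y, 0xFF0000);
            }
        }
        float[] data = toNormalizedRgb(red, 4, 4);
        System.out.println(data.length); // 4 * 4 * 3 = 48 values
    }
}
```

With the TensorFlow 1.x Java API, the resulting array can be wrapped in a FloatBuffer and passed to Tensor.create together with the shape {1, height, width, 3}.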

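A classification model usually returns a vector of scores, one per class, so turning the result into a readable label needs a lookup step. The sketch below shows one way to do that, assuming a list of labels whose order matches the model's output indices (such a list is often shipped as a text file next to the model). The class name LabelLookup and the sample labels and scores are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class LabelLookup {
    // Returns the labels of the k highest scores, best first.
    // Assumes labels.get(i) describes the class at output index i.
    public static List<String> topK(float[] scores, List<String> labels, int k) {
        if (scores.length != labels.size()) {
            throw new IllegalArgumentException(
                    "Expected " + scores.length + " labels but got " + labels.size());
        }
        // Sort the class indices by descending score
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            indices.add(i);
        }
        indices.sort(Comparator.comparingDouble((Integer i) -> scores[i]).reversed());

        List<String> result = new ArrayList<>();
        for (int n = 0; n < Math.min(k, indices.size()); n++) {
            result.add(labels.get(indices.get(n)));
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical labels and scores, standing in for real model output
        List<String> labels = Arrays.asList("cat", "dog", "bird");
        float[] scores = {0.10f, 0.75f, 0.15f};
        System.out.println(topK(scores, labels, 2)); // prints [dog, bird]
    }
}
```

In a real program the scores would come from the output tensor and the labels from a file read with Files.readAllLines.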