Object detection from SavedModel API is working!

This commit is contained in:
davpapp 2018-02-21 18:01:35 -05:00
parent 8ccafa8c63
commit b146c683da
7 changed files with 263 additions and 8 deletions

185
src/DetectObjects.java Normal file
View File

@ -0,0 +1,185 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import javax.imageio.ImageIO;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
/**
* Java inference for the Object Detection API at:
* https://github.com/tensorflow/models/blob/master/research/object_detection/
*/
public class DetectObjects {
public static void main(String[] args) throws Exception {
/*if (args.length < 3) {
printUsage(System.err);
System.exit(1);
}*/
final String[] labels = loadLabels("don't care");
try (SavedModelBundle model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/", "serve")) {
// printSignature(model);
final String filename = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg";
List<Tensor<?>> outputs = null;
try (Tensor<UInt8> input = makeImageTensor(filename)) {
System.out.println("Image was converted to tensor.");
long startTime = System.currentTimeMillis();
outputs =
model
.session()
.runner()
.feed("image_tensor", input)
.fetch("detection_scores")
.fetch("detection_classes")
.fetch("detection_boxes")
.run();
System.out.println("Object detection took " + (System.currentTimeMillis() - startTime));
}
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
Tensor<Float> classesT = outputs.get(1).expect(Float.class);
Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) {
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
// This can be verified by looking at scoresT.shape() etc.
int maxObjects = (int) scoresT.shape()[1];
float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
// Print all objects whose score is at least 0.5.
System.out.printf("* %s\n", filename);
boolean foundSomething = false;
for (int i = 0; i < scores.length; ++i) {
if (scores[i] < 0.5) {
continue;
}
foundSomething = true;
System.out.printf("\tFound %-20s (score: %.4f)\n", "ironOre", 0.342); //labels[(int) classes[i]], scores[i]);
System.out.println("Location:");
System.out.println("X:" + boxes[i][0] + ", Y:" + boxes[i][1] + ", width:" + boxes[i][2] + ", height:" + boxes[i][3]);
}
if (!foundSomething) {
System.out.println("No objects detected with a high enough score.");
}
}
}
}
private static void printSignature(SavedModelBundle model) throws Exception {
/*MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}*/
System.out.println("-----------------------------------------------");
}
private static String[] loadLabels(String filename) throws Exception {
/*String text = new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8);
StringIntLabelMap.Builder builder = StringIntLabelMap.newBuilder();
TextFormat.merge(text, builder);
StringIntLabelMap proto = builder.build();
int maxId = 0;
for (StringIntLabelMapItem item : proto.getItemList()) {
if (item.getId() > maxId) {
maxId = item.getId();
}
}
String[] ret = new String[maxId + 1];
for (StringIntLabelMapItem item : proto.getItemList()) {
ret[item.getId()] = item.getDisplayName();
}*/
String[] label = {"ironOre"};
return label;
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
private static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
BufferedImage img = ImageIO.read(new File(filename));
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
throw new IOException(
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust",
img.getType(), filename));
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
bgr2rgb(data);
final long BATCH_SIZE = 1;
final long CHANNELS = 3;
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
private static void printUsage(PrintStream s) {
s.println("USAGE: <model> <label_map> <image> [<image>] [<image>]");
s.println("");
s.println("Where");
s.println("<model> is the path to the SavedModel directory of the model to use.");
s.println(" For example, the saved_model directory in tarballs from ");
s.println(
" https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)");
s.println("");
s.println(
"<label_map> is the path to a file containing information about the labels detected by the model.");
s.println(" For example, one of the .pbtxt files from ");
s.println(
" https://github.com/tensorflow/models/tree/master/research/object_detection/data");
s.println("");
s.println("<image> is the path to an image file.");
s.println(" Sample images can be found from the COCO, Kitti, or Open Images dataset.");
s.println(
" See: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md");
}
}

View File

@ -1,9 +1,13 @@
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.opencv.core.Core;
import javax.imageio.ImageIO;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
@ -16,6 +20,20 @@ import org.tensorflow.types.UInt8;
public class ObjectDetectionViaSavedModelBundle {
/*public Tensor<UInt8> getImage() {
byte[] imageBytes = readAllBytesOrExit(Paths.get("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg)");
try (Tensor<UInt8> image = constructAndExecuteGraphToNormalizeImage(imageBytes) {
/*float[] labelProbabilities = executeInceptionGraph(graphDef, image);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(
String.format("BEST MATCH: %s (%.2f%% likely)",
labels.get(bestLabelIdx),
labelProbabilities[bestLabelIdx] * 100f));
}
}*/
public static void main( String[] args ) throws Exception {
/*System.out.println("Reading model from TensorFlow...");
@ -33,11 +51,13 @@ public class ObjectDetectionViaSavedModelBundle {
final String value = "Hello from " + TensorFlow.version();
System.out.println(value);
byte[] imageBytes = readAllBytesOrExit(Paths.get("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg"));
Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
//byte[] imageBytes = readAllBytesOrExit(Paths.get("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg"));
//Tensor<UInt8> image = constructAndExecuteGraphToNormalizeImage(imageBytes);
//final long[] shape = {330, 510, 3};
//Tensor image = Tensor.create(DataType.UINT8, shape, ByteBuffer.wrap(imageBytes));
SavedModelBundle load = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/", "serve");
try (Graph g = load.graph()) {
try (Session s = load.session();
Tensor result = s.runner()
@ -52,6 +72,47 @@ public class ObjectDetectionViaSavedModelBundle {
System.out.println("Done...");
}
private static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
BufferedImage img = ImageIO.read(new File(filename));
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
throw new IOException(
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust",
img.getType(), filename));
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
bgr2rgb(data);
final long BATCH_SIZE = 1;
final long CHANNELS = 3;
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
}
/*
private static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
@ -62,7 +123,7 @@ public class ObjectDetectionViaSavedModelBundle {
return null;
}
private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
private static Tensor<UInt8> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
try (Graph g = new Graph()) {
GraphBuilder b = new GraphBuilder(g);
// Some constants specific to the pre-trained model at:
@ -85,13 +146,13 @@ public class ObjectDetectionViaSavedModelBundle {
b.sub(
b.resizeBilinear(
b.expandDims(
b.cast(b.decodeJpeg(input, 3), Float.class),
b.cast(b.decodeJpeg(input, 3), UInt8.class),
b.constant("make_batch", 0)),
b.constant("size", new int[] {H, W})),
b.constant("mean", mean)),
b.constant("scale", scale));
try (Session s = new Session(g)) {
return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
return s.runner().fetch(output.op().name()).run().get(0).expect(UInt8.class);
}
}
}
@ -168,4 +229,5 @@ public class ObjectDetectionViaSavedModelBundle {
}
private Graph g;
}
}
*/

View File

@ -0,0 +1,8 @@
public class StringIntLabelOuterClass {
public StringIntLabelOuterClass() {
// TODO Auto-generated constructor stub
}
}

Binary file not shown.

Binary file not shown.