ObjectDetector class created, BufferedImage reading needs work

This commit is contained in:
davpapp 2018-02-21 19:21:59 -05:00
parent b146c683da
commit f27f17b2c3
11 changed files with 213 additions and 719 deletions

View File

@ -1,185 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import javax.imageio.ImageIO;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
/**
* Java inference for the Object Detection API at:
* https://github.com/tensorflow/models/blob/master/research/object_detection/
*/
public class DetectObjects {
public static void main(String[] args) throws Exception {
/*if (args.length < 3) {
printUsage(System.err);
System.exit(1);
}*/
final String[] labels = loadLabels("don't care");
try (SavedModelBundle model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/", "serve")) {
// printSignature(model);
final String filename = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg";
List<Tensor<?>> outputs = null;
try (Tensor<UInt8> input = makeImageTensor(filename)) {
System.out.println("Image was converted to tensor.");
long startTime = System.currentTimeMillis();
outputs =
model
.session()
.runner()
.feed("image_tensor", input)
.fetch("detection_scores")
.fetch("detection_classes")
.fetch("detection_boxes")
.run();
System.out.println("Object detection took " + (System.currentTimeMillis() - startTime));
}
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
Tensor<Float> classesT = outputs.get(1).expect(Float.class);
Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) {
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
// This can be verified by looking at scoresT.shape() etc.
int maxObjects = (int) scoresT.shape()[1];
float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
// Print all objects whose score is at least 0.5.
System.out.printf("* %s\n", filename);
boolean foundSomething = false;
for (int i = 0; i < scores.length; ++i) {
if (scores[i] < 0.5) {
continue;
}
foundSomething = true;
System.out.printf("\tFound %-20s (score: %.4f)\n", "ironOre", 0.342); //labels[(int) classes[i]], scores[i]);
System.out.println("Location:");
System.out.println("X:" + boxes[i][0] + ", Y:" + boxes[i][1] + ", width:" + boxes[i][2] + ", height:" + boxes[i][3]);
}
if (!foundSomething) {
System.out.println("No objects detected with a high enough score.");
}
}
}
}
private static void printSignature(SavedModelBundle model) throws Exception {
/*MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}*/
System.out.println("-----------------------------------------------");
}
private static String[] loadLabels(String filename) throws Exception {
/*String text = new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8);
StringIntLabelMap.Builder builder = StringIntLabelMap.newBuilder();
TextFormat.merge(text, builder);
StringIntLabelMap proto = builder.build();
int maxId = 0;
for (StringIntLabelMapItem item : proto.getItemList()) {
if (item.getId() > maxId) {
maxId = item.getId();
}
}
String[] ret = new String[maxId + 1];
for (StringIntLabelMapItem item : proto.getItemList()) {
ret[item.getId()] = item.getDisplayName();
}*/
String[] label = {"ironOre"};
return label;
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
private static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
BufferedImage img = ImageIO.read(new File(filename));
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
throw new IOException(
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust",
img.getType(), filename));
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
bgr2rgb(data);
final long BATCH_SIZE = 1;
final long CHANNELS = 3;
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
private static void printUsage(PrintStream s) {
s.println("USAGE: <model> <label_map> <image> [<image>] [<image>]");
s.println("");
s.println("Where");
s.println("<model> is the path to the SavedModel directory of the model to use.");
s.println(" For example, the saved_model directory in tarballs from ");
s.println(
" https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)");
s.println("");
s.println(
"<label_map> is the path to a file containing information about the labels detected by the model.");
s.println(" For example, one of the .pbtxt files from ");
s.println(
" https://github.com/tensorflow/models/tree/master/research/object_detection/data");
s.println("");
s.println("<image> is the path to an image file.");
s.println(" Sample images can be found from the COCO, Kitti, or Open Images dataset.");
s.println(
" See: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md");
}
}

View File

@ -1,233 +0,0 @@
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import javax.imageio.ImageIO;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;
public class ObjectDetectionViaSavedModelBundle {
/*public Tensor<UInt8> getImage() {
byte[] imageBytes = readAllBytesOrExit(Paths.get("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg)");
try (Tensor<UInt8> image = constructAndExecuteGraphToNormalizeImage(imageBytes) {
/*float[] labelProbabilities = executeInceptionGraph(graphDef, image);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(
String.format("BEST MATCH: %s (%.2f%% likely)",
labels.get(bestLabelIdx),
labelProbabilities[bestLabelIdx] * 100f));
}
}*/
public static void main( String[] args ) throws Exception {
/*System.out.println("Reading model from TensorFlow...");
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
ObjectDetector objectDetector = new ObjectDetector();
objectDetector.testGetLayerTypes();
objectDetector.testGetLayer();
objectDetector.testImage();*/
final int IMG_SIZE = 128;
final String value = "Hello from " + TensorFlow.version();
System.out.println(value);
//byte[] imageBytes = readAllBytesOrExit(Paths.get("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg"));
//Tensor<UInt8> image = constructAndExecuteGraphToNormalizeImage(imageBytes);
//final long[] shape = {330, 510, 3};
//Tensor image = Tensor.create(DataType.UINT8, shape, ByteBuffer.wrap(imageBytes));
SavedModelBundle load = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/", "serve");
try (Graph g = load.graph()) {
try (Session s = load.session();
Tensor result = s.runner()
.feed("image_tensor:0", image)
.fetch("detection_boxes:0").run().get(0))
{
System.out.println(result.floatValue());
}
}
load.close();
System.out.println("Done...");
}
private static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
BufferedImage img = ImageIO.read(new File(filename));
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
throw new IOException(
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust",
img.getType(), filename));
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
bgr2rgb(data);
final long BATCH_SIZE = 1;
final long CHANNELS = 3;
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
}
/*
private static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(1);
}
return null;
}
private static Tensor<UInt8> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
try (Graph g = new Graph()) {
GraphBuilder b = new GraphBuilder(g);
// Some constants specific to the pre-trained model at:
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
//
// - The model was trained with images scaled to 224x224 pixels.
// - The colors, represented as R, G, B in 1-byte each were converted to
// float using (value - Mean)/Scale.
final int H = 224;
final int W = 224;
final float mean = 117f;
final float scale = 1f;
// Since the graph is being constructed once per execution here, we can use a constant for the
// input image. If the graph were to be re-used for multiple input images, a placeholder would
// have been more appropriate.
final Output<String> input = b.constant("input", imageBytes);
final Output<Float> output =
b.div(
b.sub(
b.resizeBilinear(
b.expandDims(
b.cast(b.decodeJpeg(input, 3), UInt8.class),
b.constant("make_batch", 0)),
b.constant("size", new int[] {H, W})),
b.constant("mean", mean)),
b.constant("scale", scale));
try (Session s = new Session(g)) {
return s.runner().fetch(output.op().name()).run().get(0).expect(UInt8.class);
}
}
}
static class GraphBuilder {
GraphBuilder(Graph g) {
this.g = g;
}
Output<Float> div(Output<Float> x, Output<Float> y) {
return binaryOp("Div", x, y);
}
<T> Output<T> sub(Output<T> x, Output<T> y) {
return binaryOp("Sub", x, y);
}
<T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
return binaryOp3("ResizeBilinear", images, size);
}
<T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
return binaryOp3("ExpandDims", input, dim);
}
<T, U> Output<U> cast(Output<T> value, Class<U> type) {
DataType dtype = DataType.fromClass(type);
return g.opBuilder("Cast", "Cast")
.addInput(value)
.setAttr("DstT", dtype)
.build()
.<U>output(0);
}
Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
return g.opBuilder("DecodeJpeg", "DecodeJpeg")
.addInput(contents)
.setAttr("channels", channels)
.build()
.<UInt8>output(0);
}
<T> Output<T> constant(String name, Object value, Class<T> type) {
try (Tensor<T> t = Tensor.<T>create(value, type)) {
return g.opBuilder("Const", name)
.setAttr("dtype", DataType.fromClass(type))
.setAttr("value", t)
.build()
.<T>output(0);
}
}
Output<String> constant(String name, byte[] value) {
return this.constant(name, value, String.class);
}
Output<Integer> constant(String name, int value) {
return this.constant(name, value, Integer.class);
}
Output<Integer> constant(String name, int[] value) {
return this.constant(name, value, Integer.class);
}
Output<Float> constant(String name, float value) {
return this.constant(name, value, Float.class);
}
private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
}
private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
}
private Graph g;
}
*/

View File

@ -1,308 +1,209 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
import java.awt.List;
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.DictValue;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Layer;
import org.opencv.dnn.Net;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import java.util.List;
import java.util.Map;
import javax.imageio.ImageIO;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;
/**
* Java inference for the Object Detection API at:
* https://github.com/tensorflow/models/blob/master/research/object_detection/
*/
public class ObjectDetector {
SavedModelBundle model;
String inputImagePath;
String inputModelPath;
String inputModelArgumentsPath;
Net net;
public ObjectDetector() throws Exception {
this.inputImagePath = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg";
this.inputModelPath = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/frozen_graph_inference.pb";
this.inputModelArgumentsPath = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/training/graph.pbtxt";
File f = new File(inputImagePath);
if(!f.exists()) throw new Exception("Test image is missing: " + inputImagePath);
File f1 = new File(inputModelPath);
if(!f1.exists()) throw new Exception("Test image is missing: " + inputModelPath);
File f2 = new File(inputModelArgumentsPath);
if(!f2.exists()) throw new Exception("Test image is missing: " + inputModelArgumentsPath);
net = Dnn.readNetFromTensorflow(inputModelPath, inputModelArgumentsPath);
public ObjectDetector() {
model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/", "serve");
}
public void testGetLayerTypes() {
ArrayList<String> layertypes = new ArrayList();
net.getLayerTypes(layertypes);
assertFalse("No layer types returned!", layertypes.isEmpty());
}
public void testGetLayer() {
ArrayList<String> layernames = (ArrayList<String>) net.getLayerNames();
assertFalse("Test net returned no layers!", layernames.isEmpty());
String testLayerName = layernames.get(0);
DictValue layerId = new DictValue(testLayerName);
assertEquals("DictValue did not return the string, which was used in constructor!", testLayerName, layerId.getStringValue());
Layer layer = net.getLayer(layerId);
assertEquals("Layer name does not match the expected value!", testLayerName, layer.get_name());
}
public Mat testImage() throws Exception {
final int IN_WIDTH = 300;
final int IN_HEIGHT = 300;
final float WH_RATIO = (float)IN_WIDTH / IN_HEIGHT;
final double IN_SCALE_FACTOR = 0.007843;
final double MEAN_VAL = 127.5;
final double THRESHOLD = 0.2;
Mat frame = Imgcodecs.imread(inputImagePath);
Imgproc.cvtColor(frame, frame, Imgproc.COLOR_RGBA2RGB);
assertNotNull("Loading image from file failed!", frame);
Mat blob = Dnn.blobFromImage(frame, IN_SCALE_FACTOR,
new Size(IN_WIDTH, IN_HEIGHT),
new Scalar(MEAN_VAL, MEAN_VAL, MEAN_VAL), false, false);
net.setInput(blob);
Mat detections = net.forward();
int cols = frame.cols();
int rows = frame.rows();
Size cropSize;
if ((float)cols / rows > WH_RATIO) {
cropSize = new Size(rows * WH_RATIO, rows);
} else {
cropSize = new Size(cols, cols / WH_RATIO);
public void getIronOreLocationsFromImage(BufferedImage image) throws IOException {
List<Tensor<?>> outputs = null;
try (Tensor<UInt8> input = makeImageTensor(image)) {
outputs =
model
.session()
.runner()
.feed("image_tensor", input)
.fetch("detection_scores")
.fetch("detection_classes")
.fetch("detection_boxes")
.run();
}
int y1 = (int)(rows - cropSize.height) / 2;
int y2 = (int)(y1 + cropSize.height);
int x1 = (int)(cols - cropSize.width) / 2;
int x2 = (int)(x1 + cropSize.width);
Mat subFrame = frame.submat(y1, y2, x1, x2);
cols = subFrame.cols();
rows = subFrame.rows();
detections = detections.reshape(1, (int)detections.total() / 7);
for (int i = 0; i < detections.rows(); ++i) {
double confidence = detections.get(i, 2)[0];
if (confidence > THRESHOLD) {
int classId = (int)detections.get(i, 1)[0];
int xLeftBottom = (int)(detections.get(i, 3)[0] * cols);
int yLeftBottom = (int)(detections.get(i, 4)[0] * rows);
int xRightTop = (int)(detections.get(i, 5)[0] * cols);
int yRightTop = (int)(detections.get(i, 6)[0] * rows);
// Draw rectangle around detected object.
Imgproc.rectangle(subFrame, new Point(xLeftBottom, yLeftBottom),
new Point(xRightTop, yRightTop),
new Scalar(0, 255, 0));
String label = "ironOre" + ": " + confidence;
int[] baseLine = new int[1];
Size labelSize = Imgproc.getTextSize(label, Core.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine);
// Draw background for label.
Imgproc.rectangle(subFrame, new Point(xLeftBottom, yLeftBottom - labelSize.height),
new Point(xLeftBottom + labelSize.width, yLeftBottom + baseLine[0]),
new Scalar(255, 255, 255), Core.FILLED);
// Write class name and confidence.
Imgproc.putText(subFrame, label, new Point(xLeftBottom, yLeftBottom),
Core.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(0, 0, 0));
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
Tensor<Float> classesT = outputs.get(1).expect(Float.class);
Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) {
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
// This can be verified by looking at scoresT.shape() etc.
int maxObjects = (int) scoresT.shape()[1];
float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
// Print all objects whose score is at least 0.5.
boolean foundSomething = false;
for (int i = 0; i < scores.length; ++i) {
if (scores[i] < 0.5) {
continue;
}
foundSomething = true;
System.out.printf("\tFound %-20s (score: %.4f)\n", "ironOre", 1.0000); //labels[(int) classes[i]], scores[i]);
System.out.println("Location:");
System.out.println("X:" + 510 * boxes[i][1] + ", Y:" + 330 * boxes[i][0] + ", width:" + 510 * boxes[i][3] + ", height:" + 330 * boxes[i][2]);
}
if (!foundSomething) {
System.out.println("No objects detected with a high enough score.");
}
}
return frame;
}
public static void main( String[] args ) throws Exception {
System.out.println("Reading model from TensorFlow...");
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
ObjectDetector objectDetector = new ObjectDetector();
objectDetector.testGetLayerTypes();
objectDetector.testGetLayer();
objectDetector.testImage();
/*
final int IMG_SIZE = 128;
final String value = "Hello from " + TensorFlow.version();
byte[] imageBytes = readAllBytesOrExit(Paths.get("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg"));
Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
SavedModelBundle load = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/SavedModel/saved_model.pb");
long[] sitio2;
try (Graph g = load.graph()) {
try (Session s = load.session();
Tensor result = s.runner()
.feed("image_tensor", image)
.fetch("detection_boxes").run().get(0))
{
sitio2 = (long[]) result.copyTo(new long[1]);
System.out.print(sitio2[0]+"\n");
}
}
/*public void test() {
try (SavedModelBundle model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/", "serve")) {
// printSignature(model);
final String filename = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg";
List<Tensor<?>> outputs = null;
try (Tensor<UInt8> input = makeImageTensor(filename)) {
System.out.println("Image was converted to tensor.");
long startTime = System.currentTimeMillis();
outputs =
model
.session()
.runner()
.feed("image_tensor", input)
.fetch("detection_scores")
.fetch("detection_classes")
.fetch("detection_boxes")
.run();
System.out.println("Object detection took " + (System.currentTimeMillis() - startTime));
}
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
Tensor<Float> classesT = outputs.get(1).expect(Float.class);
Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) {
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
// This can be verified by looking at scoresT.shape() etc.
int maxObjects = (int) scoresT.shape()[1];
float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
// Print all objects whose score is at least 0.5.
System.out.printf("* %s\n", filename);
boolean foundSomething = false;
for (int i = 0; i < scores.length; ++i) {
if (scores[i] < 0.5) {
continue;
}
foundSomething = true;
System.out.printf("\tFound %-20s (score: %.4f)\n", "ironOre", 1.0000); //labels[(int) classes[i]], scores[i]);
System.out.println("Location:");
System.out.println("X:" + 510 * boxes[i][1] + ", Y:" + 330 * boxes[i][0] + ", width:" + 510 * boxes[i][3] + ", height:" + 330 * boxes[i][2]);
}
if (!foundSomething) {
System.out.println("No objects detected with a high enough score.");
}
}
}
}*/
private static void printSignature(SavedModelBundle model) throws Exception {
/*MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
System.out.println("MODEL SIGNATURE");
System.out.println("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
}
int numOutputs = sig.getOutputsCount();
i = 1;
System.out.println("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
System.out.printf(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
}*/
System.out.println("-----------------------------------------------");
}
/*private static String[] loadLabels(String filename) throws Exception {
String text = new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8);
StringIntLabelMap.Builder builder = StringIntLabelMap.newBuilder();
TextFormat.merge(text, builder);
StringIntLabelMap proto = builder.build();
int maxId = 0;
for (StringIntLabelMapItem item : proto.getItemList()) {
if (item.getId() > maxId) {
maxId = item.getId();
}
load.close();
*/
System.out.println("Done...");
}
private static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(1);
}
return null;
}
private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
try (Graph g = new Graph()) {
GraphBuilder b = new GraphBuilder(g);
// Some constants specific to the pre-trained model at:
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
//
// - The model was trained with images scaled to 224x224 pixels.
// - The colors, represented as R, G, B in 1-byte each were converted to
// float using (value - Mean)/Scale.
final int H = 224;
final int W = 224;
final float mean = 117f;
final float scale = 1f;
}
String[] ret = new String[maxId + 1];
for (StringIntLabelMapItem item : proto.getItemList()) {
ret[item.getId()] = item.getDisplayName();
}
String[] label = {"ironOre"};
return label;
}*/
// Since the graph is being constructed once per execution here, we can use a constant for the
// input image. If the graph were to be re-used for multiple input images, a placeholder would
// have been more appropriate.
final Output<String> input = b.constant("input", imageBytes);
final Output<Float> output =
b.div(
b.sub(
b.resizeBilinear(
b.expandDims(
b.cast(b.decodeJpeg(input, 3), Float.class),
b.constant("make_batch", 0)),
b.constant("size", new int[] {H, W})),
b.constant("mean", mean)),
b.constant("scale", scale));
try (Session s = new Session(g)) {
return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
}
}
}
static class GraphBuilder {
GraphBuilder(Graph g) {
this.g = g;
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
Output<Float> div(Output<Float> x, Output<Float> y) {
return binaryOp("Div", x, y);
}
<T> Output<T> sub(Output<T> x, Output<T> y) {
return binaryOp("Sub", x, y);
}
<T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
return binaryOp3("ResizeBilinear", images, size);
}
<T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
return binaryOp3("ExpandDims", input, dim);
}
<T, U> Output<U> cast(Output<T> value, Class<U> type) {
DataType dtype = DataType.fromClass(type);
return g.opBuilder("Cast", "Cast")
.addInput(value)
.setAttr("DstT", dtype)
.build()
.<U>output(0);
}
Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
return g.opBuilder("DecodeJpeg", "DecodeJpeg")
.addInput(contents)
.setAttr("channels", channels)
.build()
.<UInt8>output(0);
}
<T> Output<T> constant(String name, Object value, Class<T> type) {
try (Tensor<T> t = Tensor.<T>create(value, type)) {
return g.opBuilder("Const", name)
.setAttr("dtype", DataType.fromClass(type))
.setAttr("value", t)
.build()
.<T>output(0);
}
}
Output<String> constant(String name, byte[] value) {
return this.constant(name, value, String.class);
}
Output<Integer> constant(String name, int value) {
return this.constant(name, value, Integer.class);
}
Output<Integer> constant(String name, int[] value) {
return this.constant(name, value, Integer.class);
}
Output<Float> constant(String name, float value) {
return this.constant(name, value, Float.class);
}
private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
}
private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
}
private Graph g;
}
}
private static Tensor<UInt8> makeImageTensor(BufferedImage img) throws IOException {
//BufferedImage img = ImageIO.read(new File(filename));
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
throw new IOException(
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust"));
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
bgr2rgb(data);
final long BATCH_SIZE = 1;
final long CHANNELS = 3;
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
}

View File

@ -1,8 +0,0 @@
public class StringIntLabelOuterClass {
public StringIntLabelOuterClass() {
// TODO Auto-generated constructor stub
}
}

View File

@ -1,23 +1,35 @@
import java.awt.AWTException;
import java.awt.Rectangle;
import java.awt.Robot;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
public class WillowChopper {
Cursor cursor;
CursorTask cursorTask;
Inventory inventory;
ObjectDetector objectDetector;
Robot robot;
public WillowChopper() throws AWTException, IOException
{
cursor = new Cursor();
cursorTask = new CursorTask();
inventory = new Inventory();
objectDetector = new ObjectDetector();
robot = new Robot();
}
public void run() throws Exception {
System.out.println("Starting ironMiner...");
while (true) {
Thread.sleep(250);
BufferedImage image = captureScreenshotGameWindow();
objectDetector.getIronOreLocationsFromImage(image);
System.out.println("--------------------------------\n\n");
/*
if (character.isCharacterEngaged()) {
// DO NOTHING
@ -28,15 +40,22 @@ public class WillowChopper {
chop willow tree
}
*/
inventory.update();
/*inventory.update();
if (inventory.isInventoryFull()) {
long startTime = System.currentTimeMillis();
System.out.println("Inventory is full! Dropping...");
cursorTask.optimizedDropAllItemsInInventory(cursor, inventory);
System.out.println("Dropping took " + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds.");
//cursorTask.dropAllItemsInInventory(cursor, inventory);
}
}*/
}
}
private BufferedImage captureScreenshotGameWindow() throws IOException {
Rectangle area = new Rectangle(103, 85, 510, 330);
return robot.createScreenCapture(area);
}
}

Binary file not shown.

Binary file not shown.

Binary file not shown.