Object tracking framework is working, but is too inaccurate. Also added some tests

This commit is contained in:
davpapp 2018-02-24 10:49:49 -05:00
parent 8dbc60a00f
commit f742cf3201
12 changed files with 201 additions and 45 deletions

View File

@ -7,7 +7,7 @@ import org.opencv.core.Rect2d;
public class DetectedObject {
private Rectangle boundingBox;
private Rect2d boundingBox;
private float detectionScore;
private String detectionClass;
@ -17,14 +17,13 @@ public class DetectedObject {
this.detectionClass = initializeLabel(detectionClass);
}
private Rectangle initializeBoundingBox(float[] detectionBox) {
// TODO: migrate this all to a Rect2d data type
private Rect2d initializeBoundingBox(float[] detectionBox) {
int offset_x = (int) (detectionBox[1] * Constants.GAME_WINDOW_WIDTH);
int offset_y = (int) (detectionBox[0] * Constants.GAME_WINDOW_HEIGHT);
int width = (int) (detectionBox[3] * Constants.GAME_WINDOW_WIDTH) - offset_x;
int height = (int) (detectionBox[2] * Constants.GAME_WINDOW_HEIGHT) - offset_y;
//System.out.println(detectionBox[0] + ", " + detectionBox[1] + ", " + detectionBox[2] + ", " + detectionBox[3]);
return new Rectangle(offset_x, offset_y, width, height);
return new Rect2d(offset_x, offset_y, width, height);
}
private String initializeLabel(float detectionClass) {
@ -37,16 +36,12 @@ public class DetectedObject {
return detectionClass;
}
public Rectangle getBoundingRectangle() {
return boundingBox;
}
public Rect2d getBoundingRect2d() {
return new Rect2d(boundingBox.x, boundingBox.y, boundingBox.x + boundingBox.width, boundingBox.y + boundingBox.height);
return new Rect2d(boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height);
}
public Point getCenterForClicking() {
return new Point(boundingBox.x + boundingBox.width / 2 + Constants.GAME_WINDOW_OFFSET_X, boundingBox.y + boundingBox.height / 2 + Constants.GAME_WINDOW_OFFSET_Y);
return new Point((int) (boundingBox.x + boundingBox.width / 2 + Constants.GAME_WINDOW_OFFSET_X), (int) (boundingBox.y + boundingBox.height / 2 + Constants.GAME_WINDOW_OFFSET_Y));
}
public void display() {

View File

@ -1,4 +1,5 @@
import java.awt.AWTException;
import java.awt.Graphics2D;
import java.awt.Point;
import java.awt.Rectangle;
import java.awt.Robot;
@ -10,6 +11,8 @@ import java.util.ArrayList;
import javax.imageio.ImageIO;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect2d;
import org.opencv.tracking.Tracker;
@ -30,29 +33,25 @@ public class IronMiner {
public IronMiner() throws AWTException, IOException
{
//cursor = new Cursor();
//cursorTask = new CursorTask();
//inventory = new Inventory();
cursor = new Cursor();
cursorTask = new CursorTask();
inventory = new Inventory();
objectDetector = new ObjectDetector();
robot = new Robot();
randomizer = new Randomizer();
}
public void run() throws Exception {
int count = 0;
long mineStartTime = System.currentTimeMillis();
while (System.currentTimeMillis() - 60000 < mineStartTime) {
count++;
while (true) {
BufferedImage screenCapture = objectDetector.captureScreenshotGameWindow();
ArrayList<DetectedObject> detectedObjects = objectDetector.getObjectsInImage(screenCapture);
//ArrayList<DetectedObject> ironOres = objectDetector.getObjectsOfClassInList(detectedObjects, "ironOre");
System.out.println("Count: " + count);
System.out.println(detectedObjects.size());
/*DetectedObject closestIronOre = getClosestObjectToCharacter(ironOres);
ArrayList<DetectedObject> ironOres = objectDetector.getObjectsOfClassInList(detectedObjects, "ironOre");
DetectedObject closestIronOre = getClosestObjectToCharacter(ironOres);
if (closestIronOre != null) {
//Tracker objectTracker = TrackerKCF.create();
//Rect2d boundingBox = closestIronOre.getBoundingRect2d();
System.out.println("Found iron ore! Starting tracking!");
Tracker objectTracker = TrackerKCF.create();
Rect2d boundingBox = closestIronOre.getBoundingRect2d();
objectTracker.init(getMatFromBufferedImage(screenCapture), boundingBox);
cursor.moveAndLeftClickAtCoordinatesWithRandomness(closestIronOre.getCenterForClicking(), 10, 10);
@ -61,17 +60,26 @@ public class IronMiner {
int maxTimeToMine = randomizer.nextGaussianWithinRange(3500, 5000);
// track until either we lose the object or too much time passes
while ((System.currentTimeMillis() - mineStartTime) < maxTimeToMine) {
int lostTrackCounter = 0;
while (((System.currentTimeMillis() - mineStartTime) < maxTimeToMine) && lostTrackCounter < 3) {
screenCapture = objectDetector.captureScreenshotGameWindow();
detectedObjects = objectDetector.getObjectsInImage(screenCapture);
boolean ok = objectTracker.update(getMatFromBufferedImage(screenCapture), boundingBox);
if (!ok || !objectDetector.isObjectPresentInBoundingBoxInImage(screenCapture, boundingBox, "ironOre")) {
System.out.println("Lost track! Finding new ore.");
break;
if (!ok || !objectDetector.isObjectPresentInBoundingBoxInImage(detectedObjects, boundingBox, "ironOre")) {
System.out.println("Lost track for + " + lostTrackCounter + "! Finding new ore soon.");
lostTrackCounter++;
}
else if (ok) {
lostTrackCounter = 0;
System.out.println("Tracking at " + boundingBox.x + ", " + boundingBox.y + ", " + boundingBox.width + ", " + boundingBox.height);
}
}
}
dropInventoryIfFull();*/
dropInventoryIfFull();
}
}
@ -108,13 +116,32 @@ public class IronMiner {
return null;
}
public Mat getMatFromBufferedImage(BufferedImage image) {
Mat matImage = new Mat();
byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
/**
 * Converts a {@link BufferedImage} into an 8-bit, 3-channel OpenCV {@link Mat}.
 *
 * <p>The image is first redrawn as {@code TYPE_3BYTE_BGR} so its raster is backed by a
 * single byte array; those bytes are then copied into the Mat in the raster's own order.
 *
 * @param image the image to convert
 * @return a {@code CV_8UC3} Mat with one row per image line and one column per image pixel
 */
private Mat getMatFromBufferedImage(BufferedImage image) {
    BufferedImage formattedImage = convertBufferedImage(image, BufferedImage.TYPE_3BYTE_BGR);
    // Mat's constructor takes (rows, cols, type): rows are the image height, cols the width.
    // Passing (width, height) keeps the byte count the same but scrambles non-square frames.
    Mat matImage = new Mat(formattedImage.getHeight(), formattedImage.getWidth(), CvType.CV_8UC3);
    // getRaster() exposes the image's backing bytes. The previous bgr2rgb() call operated on
    // the copy returned by getData() and never affected these bytes, so it has been dropped.
    byte[] pixels = ((DataBufferByte) formattedImage.getRaster().getDataBuffer()).getData();
    matImage.put(0, 0, pixels);
    return matImage;
}
/**
 * Redraws an image onto a freshly allocated {@link BufferedImage} of the requested type.
 *
 * @param sourceImage the image to copy from
 * @param bufferedImageType one of the {@code BufferedImage.TYPE_*} constants
 * @return a new image with the source's dimensions and the requested type
 */
private static BufferedImage convertBufferedImage(BufferedImage sourceImage, int bufferedImageType) {
    final int width = sourceImage.getWidth();
    final int height = sourceImage.getHeight();
    BufferedImage converted = new BufferedImage(width, height, bufferedImageType);

    // Painting the source at the origin lets Java2D perform the pixel-format conversion.
    Graphics2D graphics = converted.createGraphics();
    graphics.drawImage(sourceImage, 0, 0, null);
    graphics.dispose();

    return converted;
}
/**
 * Swaps the first and third byte of every consecutive 3-byte group, in place.
 *
 * @param data interleaved 3-byte pixel data; modified by this call
 */
private static void bgr2rgb(byte[] data) {
    int first = 0;
    while (first < data.length) {
        int third = first + 2;
        byte swapped = data[third];
        data[third] = data[first];
        data[first] = swapped;
        first += 3;
    }
}
/**
 * Computes the straight-line distance between two points, truncated to an int.
 *
 * @param startingPoint the point to measure from
 * @param goalPoint the point to measure to
 * @return the Euclidean distance with the fractional part discarded
 */
public int getDistanceBetweenPoints(Point startingPoint, Point goalPoint) {
    int deltaX = goalPoint.x - startingPoint.x;
    int deltaY = goalPoint.y - startingPoint.y;
    double distance = Math.hypot(deltaX, deltaY);
    return (int) distance;
}

View File

@ -33,6 +33,7 @@ import java.util.List;
import java.util.Map;
import javax.imageio.ImageIO;
import org.opencv.core.Core;
import org.opencv.core.Rect2d;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
@ -48,7 +49,7 @@ public class ObjectDetector {
Robot robot;
public ObjectDetector() throws AWTException {
this.model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_22948/saved_model/", "serve");
this.model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_56749/saved_model/", "serve");
this.robot = new Robot();
}
@ -94,10 +95,26 @@ public class ObjectDetector {
return detectedObjectsInImage;
}
public boolean isObjectPresentInBoundingBoxInImage(BufferedImage image, Rect2d boundingBox, String objectClass) throws Exception {
/*public boolean isObjectPresentInBoundingBoxInImage(BufferedImage image, Rect2d boundingBox, String objectClass) throws Exception {
BufferedImage subImage = image.getSubimage((int) boundingBox.x, (int) boundingBox.y, (int) boundingBox.width, (int) boundingBox.height);
ArrayList<DetectedObject> detectedObjectsInSubImage = getObjectsInImage(subImage);
return (getObjectsOfClassInList(detectedObjectsInSubImage, objectClass).size() != 0);
}*/
/**
 * Reports whether any detected object of the given class has a bounding box close to the
 * supplied one. "Close" means x, y, width and height each differ by less than 10.
 *
 * @param detectedObjects detections to search
 * @param boundingBox the box to compare each detection against
 * @param objectClass the detection class that a match must have
 * @return {@code true} as soon as one matching detection is found, otherwise {@code false}
 */
public boolean isObjectPresentInBoundingBoxInImage(ArrayList<DetectedObject> detectedObjects, Rect2d boundingBox, String objectClass) throws Exception {
    for (DetectedObject candidate : detectedObjects) {
        if (!candidate.getDetectionClass().equals(objectClass)) {
            continue;
        }

        Rect2d detectedBox = candidate.getBoundingRect2d();
        boolean xMatches = Math.abs(detectedBox.x - boundingBox.x) < 10;
        boolean yMatches = Math.abs(detectedBox.y - boundingBox.y) < 10;
        boolean widthMatches = Math.abs(detectedBox.width - boundingBox.width) < 10;
        boolean heightMatches = Math.abs(detectedBox.height - boundingBox.height) < 10;

        if (xMatches && yMatches && widthMatches && heightMatches) {
            return true;
        }
    }
    return false;
}
public ArrayList<DetectedObject> getObjectsOfClassInList(ArrayList<DetectedObject> detectedObjects, String objectClass) {
@ -111,17 +128,11 @@ public class ObjectDetector {
}
private static Tensor<UInt8> makeImageTensor(BufferedImage image) throws IOException {
/*if (image.getType() != BufferedImage.TYPE_3BYTE_BGR) {
throw new IOException(
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust"));
}*/
BufferedImage formattedImage = convertBufferedImage(image, BufferedImage.TYPE_3BYTE_BGR);
byte[] data = ((DataBufferByte) formattedImage.getData().getDataBuffer()).getData();
bgr2rgb(data);
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
// BufferedImage and ImageIO.read() seems to produce BGR-encoded images, but the model expects RGB.
final long BATCH_SIZE = 1;
final long CHANNELS = 3;
long[] shape = new long[] {BATCH_SIZE, formattedImage.getHeight(), formattedImage.getWidth(), CHANNELS};

View File

@ -1,15 +1,39 @@
import static org.junit.jupiter.api.Assertions.*;
import java.awt.AWTException;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import javax.imageio.ImageIO;
import org.junit.jupiter.api.Test;
import org.opencv.core.Rect2d;
class ObjectDetectorTest {
@Test
void testObjectDetection() throws AWTException {
void testObjectDetection() throws Exception {
ObjectDetector objectDetector = new ObjectDetector();
ArrayList<DetectedObject> detectedObjects1 = objectDetector.getObjectsInImage(loadimages here in bufferedimage format));
BufferedImage testImage1 = ImageIO.read(new File("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg"));
ArrayList<DetectedObject> detectedObjects1 = objectDetector.getObjectsInImage(testImage1);
ArrayList<DetectedObject> detectedIronOres1 = objectDetector.getObjectsOfClassInList(detectedObjects1, "ironOre");
ArrayList<DetectedObject> detectedOres1 = objectDetector.getObjectsOfClassInList(detectedObjects1, "ore");
assertEquals(3, detectedIronOres1.size());
assertEquals(2, detectedOres1.size());
//assertDetectedObjectsAreEqual();
}
void assertDetectedObjectsAreEqual(DetectedObject obj1, DetectedObject obj2) {
}va
}
/**
 * Asserts that two bounding boxes agree on x, y, width and height, each within a delta of 3.
 *
 * @param bb1 the expected bounding box
 * @param bb2 the actual bounding box
 */
void assertBoundingBoxesAreEqual(Rect2d bb1, Rect2d bb2) {
    assertEquals(bb1.x, bb2.x, 3);
    assertEquals(bb1.y, bb2.y, 3);
    // Compare like with like: the previous version checked width against height twice.
    assertEquals(bb1.width, bb2.width, 3);
    assertEquals(bb1.height, bb2.height, 3);
}
}

View File

@ -0,0 +1,96 @@
import static org.junit.jupiter.api.Assertions.*;
import java.awt.AWTException;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.util.ArrayList;
import javax.imageio.ImageIO;
import org.junit.jupiter.api.Test;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.MatOfByte;
import org.opencv.core.Rect2d;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.tracking.Tracker;
import org.opencv.tracking.TrackerKCF;
import org.opencv.videoio.VideoCapture;
/**
 * Exploratory test that seeds KCF trackers from the detections in the first frame of a
 * recorded video, then prints per frame whether the tracker and the detector still agree.
 */
class ObjectTrackerTest {

    /** Upper bound on how many first-frame detections are tracked. */
    private static final int MAX_TRACKED_OBJECTS = 3;

    /**
     * Tracks up to {@link #MAX_TRACKED_OBJECTS} detections through every remaining frame of
     * the video and logs the tracking/detection outcome for the last tracked object.
     *
     * @throws Exception if the detector or the image conversion fails
     */
    @Test
    void testObjectTracking() throws Exception {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        VideoCapture video = new VideoCapture("/home/dpapp/Videos/gameplay-2018-02-24_10.01.00.mp4");
        try {
            System.out.println("loaded video...");
            ObjectDetector objectDetector = new ObjectDetector();

            Mat frame = new Mat();
            boolean frameReadSuccess = video.read(frame);
            assertTrue(frameReadSuccess);

            ArrayList<DetectedObject> detectedObjects = objectDetector.getObjectsInImage(matToBufferedImage(frame));
            System.out.println("Tracking " + detectedObjects.size() + " objects.");

            // Never index past the detections that actually exist in the first frame.
            int trackedCount = Math.min(MAX_TRACKED_OBJECTS, detectedObjects.size());
            assertTrue(trackedCount > 0, "No objects detected in the first frame; nothing to track.");

            ArrayList<Tracker> objectTrackers = new ArrayList<Tracker>();
            ArrayList<Rect2d> boundingBoxes = new ArrayList<Rect2d>();
            for (int i = 0; i < trackedCount; i++) {
                Rect2d boundingBox = detectedObjects.get(i).getBoundingRect2d();
                Tracker tracker = TrackerKCF.create();
                tracker.init(frame, boundingBox);
                boundingBoxes.add(boundingBox);
                objectTrackers.add(tracker);
            }

            while (video.read(frame)) {
                // Detection depends only on the frame, so run it once per frame, not per tracker.
                detectedObjects = objectDetector.getObjectsInImage(matToBufferedImage(frame));

                for (int i = 0; i < trackedCount; i++) {
                    Rect2d boundingBox = boundingBoxes.get(i);
                    // Update each tracker exactly once per frame; the bounding box is passed in
                    // so the tracker can write the new position into it.
                    boolean trackingSuccess = objectTrackers.get(i).update(frame, boundingBox);
                    boolean ironOreDetected = objectDetector.isObjectPresentInBoundingBoxInImage(detectedObjects, boundingBox, "ironOre");
                    boolean oreDetected = objectDetector.isObjectPresentInBoundingBoxInImage(detectedObjects, boundingBox, "ore");

                    // Only the last tracked object is logged to keep the output readable.
                    if (i == trackedCount - 1) {
                        System.out.println(trackingSuccess + ", ironOre: " + ironOreDetected + ", ore:" + oreDetected);
                    }
                }
            }
        } finally {
            // Free the native capture handle even when an assertion or the detector throws.
            video.release();
        }
    }

    /**
     * Converts an OpenCV matrix to a {@link BufferedImage} by encoding it as JPEG in memory
     * and decoding the result with ImageIO.
     *
     * @param matrix the frame to convert
     * @return the decoded image
     * @throws Exception if the encoded bytes cannot be read back
     */
    private BufferedImage matToBufferedImage(Mat matrix) throws Exception {
        MatOfByte encoded = new MatOfByte();
        Imgcodecs.imencode(".jpg", matrix, encoded);
        byte[] jpegBytes = encoded.toArray();
        return ImageIO.read(new ByteArrayInputStream(jpegBytes));
    }
}

View File

@ -3,11 +3,14 @@ import java.awt.Point;
import java.io.IOException;
import java.net.URL;
import org.opencv.core.Core;
public class main {
public static void main(String[] args) throws Exception {
System.out.println("Starting Iron Miner.");
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
IronMiner ironMiner = new IronMiner();
ironMiner.run();
/*Cursor cursor = new Cursor();

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.