diff --git a/src/DetectedObject.java b/src/DetectedObject.java index fc334c3..6797d0f 100644 --- a/src/DetectedObject.java +++ b/src/DetectedObject.java @@ -7,7 +7,7 @@ import org.opencv.core.Rect2d; public class DetectedObject { - private Rectangle boundingBox; + private Rect2d boundingBox; private float detectionScore; private String detectionClass; @@ -17,14 +17,13 @@ public class DetectedObject { this.detectionClass = initializeLabel(detectionClass); } - private Rectangle initializeBoundingBox(float[] detectionBox) { + // TODO: migrate this all to a Rect2d data type + private Rect2d initializeBoundingBox(float[] detectionBox) { int offset_x = (int) (detectionBox[1] * Constants.GAME_WINDOW_WIDTH); int offset_y = (int) (detectionBox[0] * Constants.GAME_WINDOW_HEIGHT); int width = (int) (detectionBox[3] * Constants.GAME_WINDOW_WIDTH) - offset_x; int height = (int) (detectionBox[2] * Constants.GAME_WINDOW_HEIGHT) - offset_y; - - //System.out.println(detectionBox[0] + ", " + detectionBox[1] + ", " + detectionBox[2] + ", " + detectionBox[3]); - return new Rectangle(offset_x, offset_y, width, height); + return new Rect2d(offset_x, offset_y, width, height); } private String initializeLabel(float detectionClass) { @@ -37,16 +36,12 @@ public class DetectedObject { return detectionClass; } - public Rectangle getBoundingRectangle() { - return boundingBox; - } - public Rect2d getBoundingRect2d() { - return new Rect2d(boundingBox.x, boundingBox.y, boundingBox.x + boundingBox.width, boundingBox.y + boundingBox.height); + return new Rect2d(boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height); } public Point getCenterForClicking() { - return new Point(boundingBox.x + boundingBox.width / 2 + Constants.GAME_WINDOW_OFFSET_X, boundingBox.y + boundingBox.height / 2 + Constants.GAME_WINDOW_OFFSET_Y); + return new Point((int) (boundingBox.x + boundingBox.width / 2 + Constants.GAME_WINDOW_OFFSET_X), (int) (boundingBox.y + boundingBox.height / 2 + Constants.GAME_WINDOW_OFFSET_Y)); } public void display() { diff --git a/src/IronMiner.java b/src/IronMiner.java index f9cd964..f21b98a 100644 --- a/src/IronMiner.java +++ b/src/IronMiner.java @@ -1,4 +1,5 @@ import java.awt.AWTException; +import java.awt.Graphics2D; import java.awt.Point; import java.awt.Rectangle; import java.awt.Robot; @@ -10,6 +11,8 @@ import java.util.ArrayList; import javax.imageio.ImageIO; +import org.opencv.core.Core; +import org.opencv.core.CvType; import org.opencv.core.Mat; import org.opencv.core.Rect2d; import org.opencv.tracking.Tracker; @@ -30,29 +33,25 @@ public class IronMiner { public IronMiner() throws AWTException, IOException { - //cursor = new Cursor(); - //cursorTask = new CursorTask(); - //inventory = new Inventory(); + cursor = new Cursor(); + cursorTask = new CursorTask(); + inventory = new Inventory(); objectDetector = new ObjectDetector(); robot = new Robot(); randomizer = new Randomizer(); } public void run() throws Exception { - int count = 0; - long mineStartTime = System.currentTimeMillis(); - - while (System.currentTimeMillis() - 60000 < mineStartTime) { - count++; + while (true) { BufferedImage screenCapture = objectDetector.captureScreenshotGameWindow(); ArrayList detectedObjects = objectDetector.getObjectsInImage(screenCapture); - //ArrayList ironOres = objectDetector.getObjectsOfClassInList(detectedObjects, "ironOre"); - System.out.println("Count: " + count); - System.out.println(detectedObjects.size()); - /*DetectedObject closestIronOre = getClosestObjectToCharacter(ironOres); + ArrayList ironOres = objectDetector.getObjectsOfClassInList(detectedObjects, "ironOre"); + + DetectedObject closestIronOre = getClosestObjectToCharacter(ironOres); if (closestIronOre != null) { - //Tracker objectTracker = TrackerKCF.create(); - //Rect2d boundingBox = closestIronOre.getBoundingRect2d(); + System.out.println("Found iron ore! Starting tracking!"); + Tracker objectTracker = TrackerKCF.create(); + Rect2d boundingBox = closestIronOre.getBoundingRect2d(); objectTracker.init(getMatFromBufferedImage(screenCapture), boundingBox); cursor.moveAndLeftClickAtCoordinatesWithRandomness(closestIronOre.getCenterForClicking(), 10, 10); @@ -61,17 +60,26 @@ public class IronMiner { int maxTimeToMine = randomizer.nextGaussianWithinRange(3500, 5000); // track until either we lose the object or too much time passes - while ((System.currentTimeMillis() - mineStartTime) < maxTimeToMine) { + int lostTrackCounter = 0; + while (((System.currentTimeMillis() - mineStartTime) < maxTimeToMine) && lostTrackCounter < 3) { + screenCapture = objectDetector.captureScreenshotGameWindow(); + detectedObjects = objectDetector.getObjectsInImage(screenCapture); + boolean ok = objectTracker.update(getMatFromBufferedImage(screenCapture), boundingBox); - if (!ok || !objectDetector.isObjectPresentInBoundingBoxInImage(screenCapture, boundingBox, "ironOre")) { - System.out.println("Lost track! Finding new ore."); - break; + if (!ok || !objectDetector.isObjectPresentInBoundingBoxInImage(detectedObjects, boundingBox, "ironOre")) { + System.out.println("Lost track for + " + lostTrackCounter + "! Finding new ore soon."); + lostTrackCounter++; } + else if (ok) { + lostTrackCounter = 0; + System.out.println("Tracking at " + boundingBox.x + ", " + boundingBox.y + ", " + boundingBox.width + ", " + boundingBox.height); + } + } } - dropInventoryIfFull();*/ + dropInventoryIfFull(); } } @@ -108,13 +116,32 @@ public class IronMiner { return null; } - public Mat getMatFromBufferedImage(BufferedImage image) { - Mat matImage = new Mat(); - byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData(); + private Mat getMatFromBufferedImage(BufferedImage image) { + BufferedImage formattedImage = convertBufferedImage(image, BufferedImage.TYPE_3BYTE_BGR); + byte[] data = ((DataBufferByte) formattedImage.getData().getDataBuffer()).getData(); + bgr2rgb(data); + Mat matImage = new Mat(formattedImage.getWidth(), formattedImage.getHeight(), CvType.CV_8UC3); + byte[] pixels = ((DataBufferByte) formattedImage.getRaster().getDataBuffer()).getData(); matImage.put(0, 0, pixels); return matImage; } + private static BufferedImage convertBufferedImage(BufferedImage sourceImage, int bufferedImageType) { + BufferedImage image = new BufferedImage(sourceImage.getWidth(), sourceImage.getHeight(), bufferedImageType); + Graphics2D g2d = image.createGraphics(); + g2d.drawImage(sourceImage, 0, 0, null); + g2d.dispose(); + return image; + } + + private static void bgr2rgb(byte[] data) { + for (int i = 0; i < data.length; i += 3) { + byte tmp = data[i]; + data[i] = data[i + 2]; + data[i + 2] = tmp; + } + } + public int getDistanceBetweenPoints(Point startingPoint, Point goalPoint) { return (int) (Math.hypot(goalPoint.x - startingPoint.x, goalPoint.y - startingPoint.y)); } diff --git a/src/ObjectDetector.java b/src/ObjectDetector.java index dc02825..eb70c94 100644 --- a/src/ObjectDetector.java +++ b/src/ObjectDetector.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Map; import javax.imageio.ImageIO; +import org.opencv.core.Core; import org.opencv.core.Rect2d; import org.tensorflow.SavedModelBundle; import org.tensorflow.Tensor; @@ -48,7 +49,7 @@ public class ObjectDetector { Robot robot; public ObjectDetector() throws AWTException { - this.model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_22948/saved_model/", "serve"); + this.model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_56749/saved_model/", "serve"); this.robot = new Robot(); } @@ -94,10 +95,26 @@ public class ObjectDetector { return detectedObjectsInImage; } - public boolean isObjectPresentInBoundingBoxInImage(BufferedImage image, Rect2d boundingBox, String objectClass) throws Exception { + /*public boolean isObjectPresentInBoundingBoxInImage(BufferedImage image, Rect2d boundingBox, String objectClass) throws Exception { BufferedImage subImage = image.getSubimage((int) boundingBox.x, (int) boundingBox.y, (int) boundingBox.width, (int) boundingBox.height); ArrayList detectedObjectsInSubImage = getObjectsInImage(subImage); return (getObjectsOfClassInList(detectedObjectsInSubImage, objectClass).size() != 0); + }*/ + + public boolean isObjectPresentInBoundingBoxInImage(ArrayList detectedObjects, Rect2d boundingBox, String objectClass) throws Exception { + for (DetectedObject detectedObject : detectedObjects) { + if (detectedObject.getDetectionClass().equals(objectClass)) { + //System.out.println(("Required bounding box: " + (int) boundingBox.x + ", " + (int) boundingBox.y + ", " + (int) boundingBox.width + ", " + (int) boundingBox.height)); + //System.out.println(("Detected bounding box: " + (int) detectedObject.getBoundingRect2d().x + ", " + (int) detectedObject.getBoundingRect2d().y + ", " + (int) detectedObject.getBoundingRect2d().width + ", " + (int) detectedObject.getBoundingRect2d().height) + "\n"); + if ((Math.abs(detectedObject.getBoundingRect2d().x - boundingBox.x) < 10) && + (Math.abs(detectedObject.getBoundingRect2d().y - boundingBox.y) < 10) && + (Math.abs(detectedObject.getBoundingRect2d().width - boundingBox.width) < 10) && + (Math.abs(detectedObject.getBoundingRect2d().height - boundingBox.height) < 10)) { + return true; + } + } + } + return false; } public ArrayList getObjectsOfClassInList(ArrayList detectedObjects, String objectClass) { @@ -111,17 +128,11 @@ public class ObjectDetector { } private static Tensor makeImageTensor(BufferedImage image) throws IOException { - /*if (image.getType() != BufferedImage.TYPE_3BYTE_BGR) { - throw new IOException( - String.format( - "Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust")); - }*/ - BufferedImage formattedImage = convertBufferedImage(image, BufferedImage.TYPE_3BYTE_BGR); byte[] data = ((DataBufferByte) formattedImage.getData().getDataBuffer()).getData(); bgr2rgb(data); - // ImageIO.read seems to produce BGR-encoded images, but the model expects RGB. + // BufferedImage and ImageIO.read() seems to produce BGR-encoded images, but the model expects RGB. final long BATCH_SIZE = 1; final long CHANNELS = 3; long[] shape = new long[] {BATCH_SIZE, formattedImage.getHeight(), formattedImage.getWidth(), CHANNELS}; diff --git a/src/ObjectDetectorTest.java b/src/ObjectDetectorTest.java index fb2e595..c03493a 100644 --- a/src/ObjectDetectorTest.java +++ b/src/ObjectDetectorTest.java @@ -1,15 +1,39 @@ import static org.junit.jupiter.api.Assertions.*; import java.awt.AWTException; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; + +import javax.imageio.ImageIO; import org.junit.jupiter.api.Test; +import org.opencv.core.Rect2d; class ObjectDetectorTest { @Test - void testObjectDetection() throws AWTException { + void testObjectDetection() throws Exception { ObjectDetector objectDetector = new ObjectDetector(); - ArrayList detectedObjects1 = objectDetector.getObjectsInImage(loadimages here in bufferedimage format)); + BufferedImage testImage1 = ImageIO.read(new File("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/test_images/ironOre_test_9.jpg")); + ArrayList detectedObjects1 = objectDetector.getObjectsInImage(testImage1); + ArrayList detectedIronOres1 = objectDetector.getObjectsOfClassInList(detectedObjects1, "ironOre"); + ArrayList detectedOres1 = objectDetector.getObjectsOfClassInList(detectedObjects1, "ore"); + + assertEquals(3, detectedIronOres1.size()); + assertEquals(2, detectedOres1.size()); + //assertDetectedObjectsAreEqual(); + } + + void assertDetectedObjectsAreEqual(DetectedObject obj1, DetectedObject obj2) { - }va + } + + void assertBoundingBoxesAreEqual(Rect2d bb1, Rect2d bb2) { + assertEquals(bb1.x, bb2.x, 3); + assertEquals(bb1.y, bb2.y, 3); + assertEquals(bb1.width, bb2.height, 3); + assertEquals(bb1.width, bb2.height, 3); + } } diff --git a/src/ObjectTrackerTest.java b/src/ObjectTrackerTest.java new file mode 100644 index 0000000..4ecd897 --- /dev/null +++ b/src/ObjectTrackerTest.java @@ -0,0 +1,96 @@ +import static org.junit.jupiter.api.Assertions.*; + +import java.awt.AWTException; +import java.awt.image.BufferedImage; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.util.ArrayList; + +import javax.imageio.ImageIO; + +import org.junit.jupiter.api.Test; +import org.opencv.core.Core; +import org.opencv.core.Mat; +import org.opencv.core.MatOfByte; +import org.opencv.core.Rect2d; +import org.opencv.imgcodecs.Imgcodecs; +import org.opencv.tracking.Tracker; +import org.opencv.tracking.TrackerKCF; +import org.opencv.videoio.VideoCapture; + +class ObjectTrackerTest { + + @Test + void testObjectTracking() throws Exception { + System.loadLibrary(Core.NATIVE_LIBRARY_NAME); + VideoCapture video = new VideoCapture("/home/dpapp/Videos/gameplay-2018-02-24_10.01.00.mp4"); + System.out.println("loaded video..."); + + ObjectDetector objectDetector = new ObjectDetector(); + + Mat frame = new Mat(); + boolean frameReadSuccess = video.read(frame); + assertTrue(frameReadSuccess); + + ArrayList detectedObjects = objectDetector.getObjectsInImage(Mat2BufferedImage(frame)); + System.out.println("Tracking " + detectedObjects.size() + " objects."); + ArrayList objectTrackers = new ArrayList(); + ArrayList boundingBoxes = new ArrayList(); + for (int i = 0; i < 3; i++) { + boundingBoxes.add(detectedObjects.get(i).getBoundingRect2d()); + objectTrackers.add(TrackerKCF.create()); + objectTrackers.get(i).init(frame, boundingBoxes.get(i)); + } + //System.out.println("bounding box: " + (int) boundingBoxes.get(0).x + ", " + (int) boundingBoxes.get(0).y + ", " + (int) boundingBoxes.get(0).width + ", " + (int) boundingBoxes.get(0).height); + + int counter = 0; + while (video.read(frame)) { + for (int i = 0; i < 3; i++) { + objectTrackers.get(i).update(frame, boundingBoxes.get(i)); + boolean trackingSuccess = objectTrackers.get(i).update(frame, boundingBoxes.get(i)); + + detectedObjects = objectDetector.getObjectsInImage(Mat2BufferedImage(frame)); + //System.out.println(detectedObjects.size()); + //System.out.println((int) boundingBoxes.get(i).x + ", " + (int) boundingBoxes.get(i).y + ", " + (int) boundingBoxes.get(i).width + ", " + (int) boundingBoxes.get(i).height); + //BufferedImage subImage = screencapture.getSubimage((int) boundingBoxes.get(i).x- 10, (int) boundingBoxes.get(i).y - 10, (int) boundingBoxes.get(i).width + 20, (int) boundingBoxes.get(i).height + 20); + + boolean ironOreDetected = objectDetector.isObjectPresentInBoundingBoxInImage(detectedObjects, boundingBoxes.get(i), "ironOre"); + boolean oreDetected = objectDetector.isObjectPresentInBoundingBoxInImage(detectedObjects, boundingBoxes.get(i), "ore"); + + //ImageIO.write(screencapture, "jpg", new File("/home/dpapp/Videos/frames/frame_" + counter + ".jpg")); + //ImageIO.write(subImage, "jpg", new File("/home/dpapp/Videos/sub_frames/frame_" + counter + "_sub.jpg")); + //System.out.println("wrote file..."); + if (i == 2) { + System.out.println(trackingSuccess + ", ironOre: " + ironOreDetected + ", ore:" + oreDetected); + } + } + counter++; + } + + /*Tracker objectTracker = TrackerKCF.create(); + Rect2d boundingBox = new Rect2d(405, 177, 38, 38); + + int counter = 0; + objectTracker.init(frame, boundingBox); + while (video.read(frame) && counter < 200) { + boolean trackingSuccess = objectTracker.update(frame, boundingBox); + BufferedImage screencapture = Mat2BufferedImage(frame); + boolean ironOreDetected = objectDetector.isObjectPresentInBoundingBoxInImage(screencapture, boundingBox, "ironOre", counter); + + ImageIO.write(screencapture, "jpg", new File("/home/dpapp/Videos/frames/frame_" + counter + ".jpg")); + + System.out.println(trackingSuccess + ", ironOre: " + ironOreDetected); + counter++; + }*/ + } + + private BufferedImage Mat2BufferedImage(Mat matrix)throws Exception { + MatOfByte mob=new MatOfByte(); + Imgcodecs.imencode(".jpg", matrix, mob); + byte ba[]=mob.toArray(); + + BufferedImage bi=ImageIO.read(new ByteArrayInputStream(ba)); + return bi; + } + +} diff --git a/src/main.java b/src/main.java index b90ccbc..e3d1fc2 100644 --- a/src/main.java +++ b/src/main.java @@ -3,11 +3,14 @@ import java.awt.Point; import java.io.IOException; import java.net.URL; +import org.opencv.core.Core; + public class main { public static void main(String[] args) throws Exception { System.out.println("Starting Iron Miner."); + System.loadLibrary(Core.NATIVE_LIBRARY_NAME); IronMiner ironMiner = new IronMiner(); ironMiner.run(); /*Cursor cursor = new Cursor(); diff --git a/target/classes/DetectedObject.class b/target/classes/DetectedObject.class index b4dce75..036eb10 100644 Binary files a/target/classes/DetectedObject.class and b/target/classes/DetectedObject.class differ diff --git a/target/classes/IronMiner.class b/target/classes/IronMiner.class index d626dc6..722072e 100644 Binary files a/target/classes/IronMiner.class and b/target/classes/IronMiner.class differ diff --git a/target/classes/ObjectDetector.class b/target/classes/ObjectDetector.class index 49cbb2c..92603e3 100644 Binary files a/target/classes/ObjectDetector.class and b/target/classes/ObjectDetector.class differ diff --git a/target/classes/ObjectDetectorTest.class b/target/classes/ObjectDetectorTest.class index c7ee397..91e94c2 100644 Binary files a/target/classes/ObjectDetectorTest.class and b/target/classes/ObjectDetectorTest.class differ diff --git a/target/classes/ObjectTrackerTest.class b/target/classes/ObjectTrackerTest.class new file mode 100644 index 0000000..759d443 Binary files /dev/null and b/target/classes/ObjectTrackerTest.class differ diff --git a/target/classes/main.class b/target/classes/main.class index c0eaf81..8388c35 100644 Binary files a/target/classes/main.class and b/target/classes/main.class differ