Refactored object detection, added new classifier class (ores)

This commit is contained in:
davpapp 2018-02-23 09:26:15 -05:00
parent 33761e183a
commit e3ee4e8878
15 changed files with 165 additions and 85 deletions

View File

@ -125,6 +125,12 @@ public class Cursor {
return randomizedGoalPoint; // Return the point we moved to in case we need precise movement afterwards
}
public Point moveCursorToCoordinatesWithRandomness(Point goalPoint, int xTolerance, int yTolerance) throws Exception {
Point randomizedGoalPoint = randomizePoint(goalPoint, xTolerance, yTolerance);
moveCursorToCoordinates(randomizedGoalPoint);
return randomizedGoalPoint; // Return the point we moved to in case we need precise movement afterwards
}
public void moveCursorToCoordinates(Point goalPoint) throws Exception {
Point startingPoint = getCurrentCursorPoint();
int distanceToMoveCursor = getDistanceBetweenPoints(startingPoint, goalPoint);

49
src/DetectedObject.java Normal file
View File

@ -0,0 +1,49 @@
import java.awt.Point;
import java.awt.Rectangle;
public class DetectedObject {
private Rectangle boundingBox;
private float detectionScore;
private String detectionClass;
public DetectedObject(float detectionScore, float detectionClass, float[] detectionBox) {
this.boundingBox = initializeBoundingBox(detectionBox);
this.detectionScore = detectionScore;
this.detectionClass = initializeLabel(detectionClass);
}
private Rectangle initializeBoundingBox(float[] detectionBox) {
int offset_x = (int) (detectionBox[1] * Constants.GAME_WINDOW_WIDTH);
int offset_y = (int) (detectionBox[0] * Constants.GAME_WINDOW_HEIGHT);
int width = (int) (detectionBox[3] * Constants.GAME_WINDOW_WIDTH) - offset_x;
int height = (int) (detectionBox[2] * Constants.GAME_WINDOW_HEIGHT) - offset_y;
//System.out.println(detectionBox[0] + ", " + detectionBox[1] + ", " + detectionBox[2] + ", " + detectionBox[3]);
return new Rectangle(offset_x, offset_y, width, height);
}
private String initializeLabel(float detectionClass) {
// TODO: actually load these from a file
String[] labels = {"NA", "ironOre", "ore"};
return labels[(int) detectionClass];
}
public String getDetectionClass() {
return detectionClass;
}
public Rectangle getBoundingBox() {
return boundingBox;
}
public Point getCenterForClicking() {
return new Point(boundingBox.x + boundingBox.width / 2 + Constants.GAME_WINDOW_OFFSET_X, boundingBox.y + boundingBox.height / 2 + Constants.GAME_WINDOW_OFFSET_Y);
}
public void display() {
System.out.println(detectionClass + " with score " + detectionScore + " at (" + getCenterForClicking().x + "," + getCenterForClicking().y + ")");
}
}

View File

@ -12,6 +12,7 @@ public class ImageCollector {
public String screenshotOutputDirectory;
public Rectangle gameWindowRectangle;
public Robot robot;
/*
* Methods needed:
@ -23,13 +24,14 @@ public class ImageCollector {
* detect last file name
*/
public ImageCollector(String screenshotOutputDirectory) {
public ImageCollector(String screenshotOutputDirectory) throws AWTException {
initializeGameWindowRectangle();
this.screenshotOutputDirectory = screenshotOutputDirectory;
this.robot = new Robot();
}
private void initializeGameWindowRectangle() {
this.gameWindowRectangle = new Rectangle(103, 85, 510, 330);
this.gameWindowRectangle = new Rectangle(Constants.GAME_WINDOW_OFFSET_X, Constants.GAME_WINDOW_OFFSET_Y, Constants.GAME_WINDOW_WIDTH, Constants.GAME_WINDOW_HEIGHT);
}
public void collectImages(String itemName) throws IOException, InterruptedException, AWTException {
@ -66,7 +68,6 @@ public class ImageCollector {
}
private void captureAndSaveGameWindow(String itemName, int fileCounter) throws IOException, InterruptedException, AWTException {
Robot robot = new Robot();
BufferedImage imageCaptured = robot.createScreenCapture(gameWindowRectangle);
String fileName = getFileName(itemName, fileCounter);
ImageIO.write(imageCaptured, "jpg", new File(fileName));
@ -85,8 +86,8 @@ public class ImageCollector {
public static void main(String[] args) throws Exception
{
ImageCollector imageCollector = new ImageCollector("/home/dpapp/Desktop/RunescapeAI/TensorFlow/IronOre/");
//imageCollector.collectImages("ironOre");
imageCollector.generateInventoryImages();
ImageCollector imageCollector = new ImageCollector("/home/dpapp/Desktop/RunescapeAI/TensorFlow/Ores/Images/");
imageCollector.collectImages("ore");
//imageCollector.generateInventoryImages();
}
}

View File

@ -9,11 +9,13 @@ import java.util.ArrayList;
import javax.imageio.ImageIO;
import org.opencv.core.Rect2d;
public class IronMiner {
public static final int IRON_ORE_MINING_TIME_MILLISECONDS = 2738;
public static final int MAXIMUM_DISTANCE_TO_WALK_TO_IRON_ORE = 400;
public static final Point GAME_WINDOW_CENTER = new Point(510 / 2, 330 / 2);
public static final Point GAME_WINDOW_CENTER = new Point(Constants.GAME_WINDOW_WIDTH / 2, Constants.GAME_WINDOW_HEIGHT / 2);
Cursor cursor;
CursorTask cursorTask;
@ -35,13 +37,15 @@ public class IronMiner {
public void run() throws Exception {
while (true) {
//Thread.sleep(250);
String filename = "/home/dpapp/Desktop/RunescapeAI/temp/screenshot.jpg";
BufferedImage image = captureScreenshotGameWindow();
ImageIO.write(image, "jpg", new File(filename));
mineClosestIronOre(filename);
objectDetector.update();
ArrayList<DetectedObject> ironOres = objectDetector.getRecognizedObjectsOfClassFromImage("ironOre");
ArrayList<DetectedObject> ores = objectDetector.getRecognizedObjectsOfClassFromImage("ore");
System.out.println(ironOres.size() + " ironOres, " + ores.size() + " ores.");
/*for (DetectedObject ironOre : ironOres) {
ironOre.display();
}*/
mineClosestIronOre(ironOres, ores);
dropInventoryIfFull();
}
}
@ -53,47 +57,56 @@ public class IronMiner {
}
}
private void mineClosestIronOre(String filename) throws Exception {
private void mineClosestIronOre(ArrayList<DetectedObject> ironOres, ArrayList<DetectedObject> ores) throws Exception {
DetectedObject closestIronOre = getClosestObjectToCharacter(ironOres);
if (closestIronOre != null) {
cursor.moveAndLeftClickAtCoordinatesWithRandomness(closestIronOre.getCenterForClicking(), 10, 10);
Thread.sleep(84, 219);
DetectedObject closestOre = getClosestObjectToCharacter(ores);
if (closestOre != null) {
cursor.moveCursorToCoordinatesWithRandomness(closestOre.getCenterForClicking(), 10, 10);
}
Thread.sleep(randomizer.nextGaussianWithinRange(IRON_ORE_MINING_TIME_MILLISECONDS - 250, IRON_ORE_MINING_TIME_MILLISECONDS + -50));
}
//Thread.sleep(randomizer.nextGaussianWithinRange(150, 350));
//cursor.moveCursorToCoordinates(goalPoint);
}
/*private void mineClosestIronOre(String filename) throws Exception {
Point ironOreLocation = getClosestIronOre(filename);
/*if (ironOreLocation == null) {
Thread.sleep(1000);
}*/
if (ironOreLocation != null) {
System.out.println("Mineable iron at (" + (ironOreLocation.x + 103) + "," + (ironOreLocation.y + 85) + ")");
Point actualIronOreLocation = new Point(ironOreLocation.x + 103, ironOreLocation.y + 85);
Rect2d trackerBoundingBox = new Rec2d();
//Rectangle trackerBoundingBox = new Rectangle(ironOreLocation.x - 10, ironOreLocation.x + 10, ironOreLocation.y - 10, ironOreLocation.y + 10);
//tracker.init(image, trackerBoundingBox);
cursor.moveAndLeftClickAtCoordinatesWithRandomness(actualIronOreLocation, 12, 12);
Thread.sleep(randomizer.nextGaussianWithinRange(IRON_ORE_MINING_TIME_MILLISECONDS - 350, IRON_ORE_MINING_TIME_MILLISECONDS + 150));
}
}
private Point getClosestIronOre(String filename) throws IOException {
ArrayList<Point> ironOreLocations = objectDetector.getIronOreLocationsFromImage(filename);
System.out.println(ironOreLocations.size());
int closestDistanceToIronOreFromCharacter = Integer.MAX_VALUE;
Point closestIronOreToCharacter = null;
for (Point ironOreLocation : ironOreLocations) {
int distanceToIronOreFromCharacter = getDistanceBetweenPoints(GAME_WINDOW_CENTER, ironOreLocation);
if (distanceToIronOreFromCharacter < closestDistanceToIronOreFromCharacter) {
closestDistanceToIronOreFromCharacter = distanceToIronOreFromCharacter;
closestIronOreToCharacter = new Point(ironOreLocation.x, ironOreLocation.y);
}*/
private DetectedObject getClosestObjectToCharacter(ArrayList<DetectedObject> detectedObjects) {
int closestDistanceToCharacter = Integer.MAX_VALUE;
DetectedObject closestObjectToCharacter = null;
for (DetectedObject detectedObject : detectedObjects) {
int objectDistanceToCharacter = getDistanceBetweenPoints(GAME_WINDOW_CENTER, detectedObject.getCenterForClicking());
if (objectDistanceToCharacter < closestDistanceToCharacter) {
closestDistanceToCharacter = objectDistanceToCharacter;
closestObjectToCharacter = detectedObject;
}
}
if (closestIronOreToCharacter != null && closestDistanceToIronOreFromCharacter < MAXIMUM_DISTANCE_TO_WALK_TO_IRON_ORE) {
return closestIronOreToCharacter;
if (closestObjectToCharacter != null && closestDistanceToCharacter < MAXIMUM_DISTANCE_TO_WALK_TO_IRON_ORE) {
return closestObjectToCharacter;
}
return null;
}
public int getDistanceBetweenPoints(Point startingPoint, Point goalPoint) {
return (int) (Math.hypot(goalPoint.x - startingPoint.x, goalPoint.y - startingPoint.y));
}
private BufferedImage captureScreenshotGameWindow() throws IOException {
Rectangle area = new Rectangle(103, 85, 510, 330);
return robot.createScreenCapture(area);
}
}

View File

@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import java.awt.AWTException;
import java.awt.Point;
import java.awt.Rectangle;
import java.awt.Robot;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
@ -24,6 +27,7 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.imageio.ImageIO;
@ -38,16 +42,28 @@ import org.tensorflow.types.UInt8;
public class ObjectDetector {
SavedModelBundle model;
ArrayList<DetectedObject> detectedObjects;
Robot robot;
public ObjectDetector() {
model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/", "serve");
public ObjectDetector() throws AWTException {
this.model = SavedModelBundle.load("/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_22948/saved_model/", "serve");
this.detectedObjects = new ArrayList<DetectedObject>();
this.robot = new Robot();
}
public ArrayList<Point> getIronOreLocationsFromImage(String filename) throws IOException {
public void update() throws Exception {
// TODO: eliminate IO and pass BufferedImage directly.
String fileName = "/home/dpapp/Desktop/RunescapeAI/temp/screenshot.jpg";
BufferedImage image = captureScreenshotGameWindow();
ImageIO.write(image, "jpg", new File(fileName));
this.detectedObjects = getRecognizedObjectsFromImage(fileName);
}
private ArrayList<DetectedObject> getRecognizedObjectsFromImage(String fileName) throws Exception {
List<Tensor<?>> outputs = null;
ArrayList<Point> ironOreLocations = new ArrayList<Point>();
try (Tensor<UInt8> input = makeImageTensor(filename)) {
ArrayList<DetectedObject> detectedObjectsInImage = new ArrayList<DetectedObject>();
try (Tensor<UInt8> input = makeImageTensor(fileName)) {
outputs =
model
.session()
@ -62,47 +78,30 @@ public class ObjectDetector {
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
Tensor<Float> classesT = outputs.get(1).expect(Float.class);
Tensor<Float> boxesT = outputs.get(2).expect(Float.class)) {
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
// This can be verified by looking at scoresT.shape() etc.
int maxObjects = (int) scoresT.shape()[1];
float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
// Print all objects whose score is at least 0.5.
boolean foundSomething = false;
for (int i = 0; i < scores.length; ++i) {
if (scores[i] < 0.75) {
continue;
if (scores[i] > 0.80) {
detectedObjectsInImage.add(new DetectedObject(scores[i], classes[i], boxes[i]));
}
foundSomething = true;
//System.out.printf("\tFound %-20s (score: %.4f)\n", "ironOre", scores[i]);
//System.out.println("X:" + 510 * boxes[i][1] + ", Y:" + 330 * boxes[i][0] + ", width:" + 510 * boxes[i][3] + ", height:" + 330 * boxes[i][2]);
ironOreLocations.add(getCenterPointFromBox(boxes[i]));
}
if (!foundSomething) {
System.out.println("No objects detected with a high enough score.");
}
}
return ironOreLocations;
return detectedObjectsInImage;
}
private Point getCenterPointFromBox(float[] box) {
int x = (int) (510 * (box[1] + box[3]) / 2);
int y = (int) (330 * (box[0] + box[2]) / 2);
return new Point(x, y);
public ArrayList<DetectedObject> getRecognizedObjectsOfClassFromImage(String objectClass) {
ArrayList<DetectedObject> detectedObjectsOfType = new ArrayList<DetectedObject>();
for (DetectedObject detectedObject : this.detectedObjects) {
if (detectedObject.getDetectionClass().equals(objectClass)) {
detectedObjectsOfType.add(detectedObject);
}
}
return detectedObjectsOfType;
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
private static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
BufferedImage img = ImageIO.read(new File(filename));
@ -111,8 +110,6 @@ public class ObjectDetector {
String.format(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust"));
}
//System.out.println("Image is of type RGB? " + (img.getType() == BufferedImage.TYPE_INT_RGB));
//System.out.println("Image is of type RGB? " + (img.getType() == BufferedImage.TYPE_3BYTE_BGR));
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
@ -122,4 +119,17 @@ public class ObjectDetector {
long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
}
private static void bgr2rgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
private BufferedImage captureScreenshotGameWindow() throws IOException {
Rectangle area = new Rectangle(103, 85, 510, 330);
return robot.createScreenCapture(area);
}
}

View File

@ -7,4 +7,6 @@ public class Paths {
public static final String CURSOR_COORDINATES_FILE_PATH = "/home/dpapp/GhostMouse/coordinates.txt";
public static final String TENSORFLOW_MODEL_DIRECTORY = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/";
public static final String TENSORFLOW_TRAINING_IMAGE_OUTPUT_DIRECTORY = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/results/checkpoint_23826/saved_model/";
public static final String TENSORFLOW_MODEL_LABELS_FILE_PATH = "/home/dpapp/tensorflow-1.5.0/models/raccoon_dataset/training/labels.pbtxt";
}

View File

@ -33,9 +33,8 @@ public class cascadeTrainingImageCollector {
public void captureWindowEveryNMilliseconds(int n) throws InterruptedException, IOException {
for (int i = 0; i < 1000; i++) {
captureScreenshotGameWindow(i);
System.out.println(i);
//System.out.println("Created image: " + getImageName(i));
Thread.sleep(n * 1000);
System.out.println("Created image: " + getImageName(i));
Thread.sleep(n);
}
}
@ -63,7 +62,7 @@ public class cascadeTrainingImageCollector {
return imageOutputDirectory + "screenshot" + counter + ".jpg";
}
private void resizeImagesInDirectory() throws IOException {
/*private void resizeImagesInDirectory() throws IOException {
File folder = new File("/home/dpapp/Desktop/RunescapeAIPics/CascadeTraining/Testing/");
File[] listOfFiles = folder.listFiles();
@ -73,21 +72,21 @@ public class cascadeTrainingImageCollector {
System.out.println("Cropped " + listOfFiles[i].getName());
}
}
}
}*/
private void resizeImage(File imageFile, int counter) throws IOException {
/*private void resizeImage(File imageFile, int counter) throws IOException {
BufferedImage screenshot = ImageIO.read(imageFile);
//Rectangle resizeRectangle = new Rectangle(103, 85, 510, 330);
BufferedImage resizedImage = screenshot.getSubimage(103, 85, 510, 330);
ImageIO.write(resizedImage, "jpg", new File(getImageName(counter)));
}
}*/
public static void main(String[] args) throws Exception
{
System.out.println("Starting image collection...");
cascadeTrainingImageCollector imageCollector = new cascadeTrainingImageCollector("/home/dpapp/Desktop/RunescapeAIPics/CascadeTraining/CoalNegative/");
cascadeTrainingImageCollector imageCollector = new cascadeTrainingImageCollector(Paths.TENSORFLOW_TRAINING_IMAGE_OUTPUT_DIRECTORY);
//imageCollector.resizeImagesInDirectory();
imageCollector.captureWindowEveryNMilliseconds(5);;
imageCollector.captureWindowEveryNMilliseconds(2000);
//cascadeTrainingImageCollector imageCollector = new cascadeTrainingImageCollector("/home/dpapp/Desktop/RunescapeAIPics/CascadeTraining/Testing/");
//imageCollector.captureWindowEveryNMilliseconds(1);
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.