diff --git a/src/java/org/apache/poi/poifs/nio/ByteArrayBackedDataSource.java b/src/java/org/apache/poi/poifs/nio/ByteArrayBackedDataSource.java index 8fbb3ce10..4e368994b 100644 --- a/src/java/org/apache/poi/poifs/nio/ByteArrayBackedDataSource.java +++ b/src/java/org/apache/poi/poifs/nio/ByteArrayBackedDataSource.java @@ -32,13 +32,15 @@ public class ByteArrayBackedDataSource extends DataSource { } public void read(ByteBuffer dst, long position) { - if(position + dst.capacity() > size) { + if(position >= size) { throw new IndexOutOfBoundsException( "Unable to read " + dst.capacity() + " bytes from " + position + " in stream of length " + size ); } - dst.put(buffer, (int)position, dst.capacity()); + + int toRead = (int)Math.min(dst.capacity(), size - position); + dst.put(buffer, (int)position, toRead); } public void write(ByteBuffer src, long position) { diff --git a/src/java/org/apache/poi/poifs/nio/FileBackedDataSource.java b/src/java/org/apache/poi/poifs/nio/FileBackedDataSource.java index 7f5e8e635..19c6a3030 100644 --- a/src/java/org/apache/poi/poifs/nio/FileBackedDataSource.java +++ b/src/java/org/apache/poi/poifs/nio/FileBackedDataSource.java @@ -17,32 +17,53 @@ package org.apache.poi.poifs.nio; +import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; +import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import org.apache.poi.util.IOUtils; + /** * A POIFS {@link DataSource} backed by a File */ public class FileBackedDataSource extends DataSource { - private FileChannel file; - public FileBackedDataSource(FileChannel file) { - this.file = file; + private FileChannel channel; + + public FileBackedDataSource(File file) throws FileNotFoundException { + if(!file.exists()) { + throw new FileNotFoundException(file.toString()); + } + this.channel = (new RandomAccessFile(file, "r")).getChannel(); + } + public FileBackedDataSource(FileChannel channel) { + this.channel = channel; } public void read(ByteBuffer dst, long position) throws IOException { - file.read(dst, position); + if(position >= size()) { + throw new IllegalArgumentException("Position " + position + " past the end of the file"); + } + + channel.position(position); + int worked = IOUtils.readFully(channel, dst); + + if(worked == -1) { + throw new IllegalArgumentException("Position " + position + " past the end of the file"); + } } public void write(ByteBuffer src, long position) throws IOException { - file.write(src, position); + channel.write(src, position); } public long size() throws IOException { - return file.size(); + return channel.size(); } public void close() throws IOException { - file.close(); + channel.close(); } } diff --git a/src/testcases/org/apache/poi/poifs/nio/TestDataSource.java b/src/testcases/org/apache/poi/poifs/nio/TestDataSource.java index df039ee1d..6eb0098cb 100644 --- a/src/testcases/org/apache/poi/poifs/nio/TestDataSource.java +++ b/src/testcases/org/apache/poi/poifs/nio/TestDataSource.java @@ -19,7 +19,10 @@ package org.apache.poi.poifs.nio; -import java.io.IOException; +import java.io.File; +import java.nio.ByteBuffer; + +import org.apache.poi.POIDataSamples; import junit.framework.TestCase; @@ -28,11 +31,141 @@ import junit.framework.TestCase; */ public class TestDataSource extends TestCase { - public void testFile() throws IOException { - // TODO + private static POIDataSamples data = POIDataSamples.getPOIFSInstance(); + + public void testFile() throws Exception { + File f = data.getFile("Notes.ole2"); + + FileBackedDataSource ds = new FileBackedDataSource(f); + assertEquals(8192, ds.size()); + + // Start of file + ByteBuffer bs = ByteBuffer.allocate(4); + ds.read(bs, 0); + assertEquals(4, bs.capacity()); + assertEquals(4, bs.position()); + assertEquals(0xd0-256, bs.get(0)); + assertEquals(0xcf-256, bs.get(1)); + assertEquals(0x11-000, bs.get(2)); + assertEquals(0xe0-256, bs.get(3)); + + // Mid way through + bs = ByteBuffer.allocate(8); + ds.read(bs, 0x400); + assertEquals(8, bs.capacity()); + assertEquals(8, bs.position()); + assertEquals((byte)'R', bs.get(0)); + assertEquals(0, bs.get(1)); + assertEquals((byte)'o', bs.get(2)); + assertEquals(0, bs.get(3)); + assertEquals((byte)'o', bs.get(4)); + assertEquals(0, bs.get(5)); + assertEquals((byte)'t', bs.get(6)); + assertEquals(0, bs.get(7)); + + // Can go to the end, but not past it + bs.clear(); + ds.read(bs, 8190); + assertEquals(2, bs.position()); + + // Can't go off the end + try { + bs.clear(); + ds.read(bs, 8192); + fail("Shouldn't be able to read off the end of the file"); + } catch(IllegalArgumentException e) {} } - public void testByteArray() throws IOException { - // TODO + public void testByteArray() throws Exception { + byte[] data = new byte[256]; + byte b; + for(int i=0; i