Skip to content

Commit b00855b

Browse files
greyson-signalcody-signal
authored andcommitted
Add support for more methods in TruncatingInputStream.
1 parent 929942d commit b00855b

4 files changed

Lines changed: 176 additions & 3 deletions

File tree

core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55

66
package org.signal.core.util
77

8+
import java.io.ByteArrayOutputStream
89
import java.io.IOException
910
import java.io.InputStream
11+
import kotlin.math.min
1012

1113
/**
1214
* Reads a 32-bit variable-length integer from the stream.
@@ -68,6 +70,29 @@ fun InputStream.readNBytesOrThrow(length: Int): ByteArray {
6870
return buffer
6971
}
7072

73+
/**
74+
* Read at most [byteLimit] bytes from the stream.
75+
*/
76+
fun InputStream.readAtMostNBytes(byteLimit: Int): ByteArray {
77+
val buffer = ByteArrayOutputStream()
78+
val readBuffer = ByteArray(4096)
79+
80+
var remaining = byteLimit
81+
while (remaining > 0) {
82+
val bytesToRead = min(remaining, readBuffer.size)
83+
val read = this.read(readBuffer, 0, bytesToRead)
84+
85+
if (read == -1) {
86+
break
87+
}
88+
89+
buffer.write(readBuffer, 0, read)
90+
remaining -= read
91+
}
92+
93+
return buffer.toByteArray()
94+
}
95+
7196
@Throws(IOException::class)
7297
fun InputStream.readLength(): Long {
7398
val buffer = ByteArray(4096)

core-util-jvm/src/main/java/org/signal/core/util/stream/TruncatingInputStream.kt

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@
55

66
package org.signal.core.util.stream
77

8+
import org.signal.core.util.readAtMostNBytes
9+
import org.signal.core.util.readFully
810
import java.io.FilterInputStream
911
import java.io.InputStream
1012
import java.lang.UnsupportedOperationException
13+
import kotlin.math.min
1114

1215
/**
1316
* An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes.
1417
*/
1518
class TruncatingInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) {
1619

1720
private var bytesRead: Long = 0
21+
private var lastMark = -1L
1822

1923
override fun read(): Int {
2024
if (bytesRead >= maxBytes) {
@@ -48,11 +52,58 @@ class TruncatingInputStream(private val wrapped: InputStream, private val maxByt
4852
return bytesRead
4953
}
5054

51-
override fun skip(n: Long): Long {
52-
throw UnsupportedOperationException()
55+
override fun skip(requestedSkipCount: Long): Long {
56+
val bytesRemaining: Long = maxBytes - bytesRead
57+
val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount)
58+
59+
return super.skip(bytesToSkip).also { bytesSkipped ->
60+
if (bytesSkipped > 0) {
61+
this.bytesRead += bytesSkipped
62+
}
63+
}
64+
}
65+
66+
override fun available(): Int {
67+
val bytesRemaining = Math.toIntExact(maxBytes - bytesRead)
68+
return min(bytesRemaining, wrapped.available())
69+
}
70+
71+
override fun markSupported(): Boolean {
72+
return wrapped.markSupported()
73+
}
74+
75+
override fun mark(readlimit: Int) {
76+
if (!markSupported()) {
77+
throw UnsupportedOperationException("Mark not supported")
78+
}
79+
80+
wrapped.mark(readlimit)
81+
lastMark = bytesRead
5382
}
5483

5584
override fun reset() {
56-
throw UnsupportedOperationException()
85+
if (!markSupported()) {
86+
throw UnsupportedOperationException("Mark not supported")
87+
}
88+
89+
if (lastMark == -1L) {
90+
throw UnsupportedOperationException("Mark not set")
91+
}
92+
93+
wrapped.reset()
94+
bytesRead = lastMark
95+
}
96+
97+
/**
98+
* If the stream has been fully read, this will return all bytes that were truncated from the stream.
99+
*
100+
* @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit.
101+
*/
102+
fun readTruncatedBytes(byteLimit: Int = -1): ByteArray {
103+
return if (byteLimit < 0) {
104+
wrapped.readFully()
105+
} else {
106+
wrapped.readAtMostNBytes(byteLimit)
107+
}
57108
}
58109
}

core-util-jvm/src/test/java/org/signal/core/util/stream/TruncatingInputStreamTest.kt

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package org.signal.core.util.stream
88
import org.junit.Assert.assertEquals
99
import org.junit.Test
1010
import org.signal.core.util.readFully
11+
import org.signal.core.util.readNBytesOrThrow
1112

1213
class TruncatingInputStreamTest {
1314

@@ -32,4 +33,84 @@ class TruncatingInputStreamTest {
3233

3334
assertEquals(75, count)
3435
}
36+
37+
@Test
38+
fun `when I skip past the maxBytes, I should get -1`() {
39+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
40+
41+
val skipCount = inputStream.skip(100)
42+
val read = inputStream.read()
43+
44+
assertEquals(75, skipCount)
45+
assertEquals(-1, read)
46+
}
47+
48+
@Test
49+
fun `when I skip, I should still truncate correctly afterwards`() {
50+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
51+
52+
val skipCount = inputStream.skip(50)
53+
val data = inputStream.readFully()
54+
55+
assertEquals(50, skipCount)
56+
assertEquals(25, data.size)
57+
}
58+
59+
@Test
60+
fun `when I skip more than maxBytes, I only skip maxBytes`() {
61+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
62+
63+
val skipCount = inputStream.skip(100)
64+
65+
assertEquals(75, skipCount)
66+
}
67+
68+
@Test
69+
fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() {
70+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
71+
inputStream.readFully()
72+
73+
val truncatedBytes = inputStream.readTruncatedBytes()
74+
assertEquals(25, truncatedBytes.size)
75+
}
76+
77+
@Test
78+
fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() {
79+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
80+
inputStream.readFully()
81+
82+
val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10)
83+
assertEquals(10, truncatedBytes.size)
84+
}
85+
86+
@Test
87+
fun `when I call available, it should respect the maxBytes`() {
88+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
89+
val available = inputStream.available()
90+
91+
assertEquals(75, available)
92+
}
93+
94+
@Test
95+
fun `when I call available after reading some bytes, it should respect the maxBytes`() {
96+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
97+
inputStream.readNBytesOrThrow(50)
98+
99+
val available = inputStream.available()
100+
101+
assertEquals(25, available)
102+
}
103+
104+
@Test
105+
fun `when I mark and reset, it should jump back to the correct position`() {
106+
val inputStream = TruncatingInputStream(ByteArray(100).inputStream(), maxBytes = 75)
107+
108+
inputStream.mark(100)
109+
inputStream.readNBytesOrThrow(10)
110+
inputStream.reset()
111+
112+
val data = inputStream.readFully()
113+
114+
assertEquals(75, data.size)
115+
}
35116
}

core-util/src/test/java/org/signal/core/util/InputStreamExtensionTests.kt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,20 @@ class InputStreamExtensionTests {
1919
assertEquals(bytes.size.toLong(), length)
2020
}
2121
}
22+
23+
@Test
24+
fun `when I call readAtMostNBytes, I only read that many bytes`() {
25+
val bytes = ByteArray(100)
26+
val inputStream = bytes.inputStream()
27+
val readBytes = inputStream.readAtMostNBytes(50)
28+
assertEquals(50, readBytes.size)
29+
}
30+
31+
@Test
32+
fun `when I call readAtMostNBytes, it will return at most the length of the stream`() {
33+
val bytes = ByteArray(100)
34+
val inputStream = bytes.inputStream()
35+
val readBytes = inputStream.readAtMostNBytes(200)
36+
assertEquals(100, readBytes.size)
37+
}
2238
}

0 commit comments

Comments
 (0)