diff --git a/src/main/java/com/github/dockerjava/netty/InvocationBuilder.java b/src/main/java/com/github/dockerjava/netty/InvocationBuilder.java index d76918bcb..26b950e4e 100644 --- a/src/main/java/com/github/dockerjava/netty/InvocationBuilder.java +++ b/src/main/java/com/github/dockerjava/netty/InvocationBuilder.java @@ -27,6 +27,7 @@ import java.io.InputStream; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CountDownLatch; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -75,6 +76,57 @@ public void onNext(Void object) { } } + /** + * Implementation of {@link ResultCallback} with the single result event expected. + */ + public static class AsyncResultCallback + extends ResultCallbackTemplate, A_RES_T> { + + private A_RES_T result = null; + + private final CountDownLatch resultReady = new CountDownLatch(1); + + @Override + public void onNext(A_RES_T object) { + onResult(object); + } + + private void onResult(A_RES_T object) { + if (resultReady.getCount() == 0) { + throw new IllegalStateException("Result has already been set"); + } + + try { + result = object; + } finally { + resultReady.countDown(); + } + } + + @Override + public void close() throws IOException { + try { + super.close(); + } finally { + resultReady.countDown(); + } + } + + /** + * Blocks until {@link ResultCallback#onNext(Object)} was called for the first time + */ + @SuppressWarnings("unchecked") + public A_RES_T awaitResult() { + try { + resultReady.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + getFirstError(); + return result; + } + } + private ChannelProvider channelProvider; private String resource; @@ -203,7 +255,7 @@ public InputStream post(final Object entity) { Channel channel = getChannel(); - ResponseCallback callback = new ResponseCallback(); + AsyncResultCallback callback = new AsyncResultCallback<>(); HttpResponseHandler responseHandler = new HttpResponseHandler(requestProvider, callback); HttpResponseStreamHandler streamHandler = new HttpResponseStreamHandler(callback); @@ -454,7 +506,7 @@ public InputStream get() { Channel channel = getChannel(); - ResponseCallback resultCallback = new ResponseCallback(); + AsyncResultCallback resultCallback = new AsyncResultCallback<>(); HttpResponseHandler responseHandler = new HttpResponseHandler(requestProvider, resultCallback); diff --git a/src/main/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandler.java b/src/main/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandler.java index 706228d04..596334640 100644 --- a/src/main/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandler.java +++ b/src/main/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandler.java @@ -6,9 +6,6 @@ import java.io.IOException; import java.io.InputStream; -import java.util.concurrent.LinkedTransferQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import com.github.dockerjava.api.async.ResultCallback; @@ -19,45 +16,87 @@ */ public class HttpResponseStreamHandler extends SimpleChannelInboundHandler { - private HttpResponseInputStream stream = new HttpResponseInputStream(); + private ResultCallback resultCallback; + + private final HttpResponseInputStream stream = new HttpResponseInputStream(); public HttpResponseStreamHandler(ResultCallback resultCallback) { - resultCallback.onNext(stream); + this.resultCallback = resultCallback; } @Override protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + invokeCallbackOnFirstRead(); + stream.write(msg.copy()); } + private void invokeCallbackOnFirstRead() { + if (resultCallback != null) { + resultCallback.onNext(stream); + resultCallback = null; + } + } + @Override - public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { - stream.close(); - super.channelReadComplete(ctx); + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + stream.writeComplete(); + + super.channelInactive(ctx); } public static class HttpResponseInputStream extends InputStream { - private AtomicBoolean closed = new AtomicBoolean(false); + private boolean writeCompleted = false; - private LinkedTransferQueue queue = new LinkedTransferQueue(); + private boolean closed = false; private ByteBuf current = null; - public void write(ByteBuf byteBuf) { - queue.put(byteBuf); + private final Object lock = new Object(); + + public void write(ByteBuf byteBuf) throws InterruptedException { + synchronized (lock) { + if (closed) { + return; + } + while (current != null) { + lock.wait(); + + if (closed) { + return; + } + } + current = byteBuf; + + lock.notifyAll(); + } + } + + public void writeComplete() { + synchronized (lock) { + writeCompleted = true; + + lock.notifyAll(); + } } @Override public void close() throws IOException { - closed.set(true); - super.close(); + synchronized (lock) { + closed = true; + releaseCurrent(); + + lock.notifyAll(); + } } @Override public int available() throws IOException { - poll(); - return readableBytes(); + synchronized (lock) { + poll(0); + return readableBytes(); + } } private int readableBytes() { @@ -66,34 +105,72 @@ private int readableBytes() { } else { return 0; } - } @Override public int read() throws IOException { + byte[] b = new byte[1]; + int n = read(b, 0, 1); + return n != -1 ? b[0] : -1; + } - poll(); + @Override + public int read(byte[] b, int off, int len) throws IOException { + synchronized (lock) { + off = poll(off); - if (readableBytes() == 0) { - if (closed.get()) { + if (current == null) { return -1; + } else { + int availableBytes = Math.min(len, current.readableBytes() - off); + current.readBytes(b, off, availableBytes); + return availableBytes; } } + } - if (current != null && current.readableBytes() > 0) { - return current.readByte() & 0xff; - } else { - return read(); + private int poll(int off) throws IOException { + synchronized (lock) { + while (readableBytes() <= off) { + try { + if (closed) { + throw new IOException("Stream closed"); + } + + off -= releaseCurrent(); + if (writeCompleted) { + return off; + } + while (current == null) { + lock.wait(); + + if (closed) { + throw new IOException("Stream closed"); + } + if (writeCompleted && current == null) { + return off; + } + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + return off; } } - private void poll() { - if (readableBytes() == 0) { - try { - current = queue.poll(50, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - throw new RuntimeException(e); + private int releaseCurrent() { + synchronized (lock) { + if (current != null) { + int n = current.readableBytes(); + current.release(); + current = null; + + lock.notifyAll(); + + return n; } + return 0; } } } diff --git a/src/test/java/com/github/dockerjava/netty/exec/SaveImageCmdExecTest.java b/src/test/java/com/github/dockerjava/netty/exec/SaveImageCmdExecTest.java index 0527b793f..a2fab38a7 100644 --- a/src/test/java/com/github/dockerjava/netty/exec/SaveImageCmdExecTest.java +++ b/src/test/java/com/github/dockerjava/netty/exec/SaveImageCmdExecTest.java @@ -47,8 +47,9 @@ public void afterMethod(ITestResult result) { @Test public void saveImage() throws Exception { - InputStream image = IOUtils.toBufferedInputStream(dockerClient.saveImageCmd("busybox").exec()); - assertThat(image.available(), greaterThan(0)); + try (InputStream image = dockerClient.saveImageCmd("busybox").exec()) { + assertThat(image.available(), greaterThan(0)); + } } diff --git a/src/test/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandlerTest.java b/src/test/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandlerTest.java index 6652f3eba..eea9ddc0f 100644 --- a/src/test/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandlerTest.java +++ b/src/test/java/com/github/dockerjava/netty/handler/HttpResponseStreamHandlerTest.java @@ -1,13 +1,21 @@ package com.github.dockerjava.netty.handler; +import static com.github.dockerjava.netty.handler.HttpResponseStreamHandler.HttpResponseInputStream; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; -import java.io.InputStream; - import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; + +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + import org.apache.commons.io.IOUtils; import org.mockito.Mockito; import org.testng.annotations.Test; @@ -25,9 +33,81 @@ public void testNoBytesSkipped() throws Exception { ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class); ByteBuf buffer = generateByteBuf(); streamHandler.channelRead0(ctx, buffer); - streamHandler.channelReadComplete(ctx); + streamHandler.channelInactive(ctx); + + try (InputStream inputStream = callback.getInputStream()) { + assertTrue(IOUtils.contentEquals(inputStream, new ByteBufInputStream(buffer))); + } + } + + @Test + public void testReadByteByByte() throws Exception { + ResultCallbackTest callback = new ResultCallbackTest(); + HttpResponseStreamHandler streamHandler = new HttpResponseStreamHandler(callback); + ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class); + ByteBuf buffer = generateByteBuf(); + streamHandler.channelRead0(ctx, buffer); + streamHandler.channelInactive(ctx); + + try (InputStream inputStream = callback.getInputStream()) { + for (int i = 0; i < buffer.readableBytes(); i++) { + int b = inputStream.read(); + assertEquals(b, buffer.getByte(i)); + } + assertTrue(inputStream.read() == -1); + } + } + + @Test + public void testCloseResponseStreamBeforeWrite() throws Exception { + HttpResponseInputStream inputStream = new HttpResponseInputStream(); + ByteBuf buffer = generateByteBuf(); + + inputStream.write(buffer); + inputStream.close(); + inputStream.write(buffer); + } + + @Test + public void testCloseResponseStreamOnWrite() throws Exception { + final HttpResponseInputStream inputStream = new HttpResponseInputStream(); + + final ByteBuf buffer = generateByteBuf(); + + final CountDownLatch firstWrite = new CountDownLatch(1); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future submit = executor.submit(new Runnable() { + @Override + public void run() { + try { + inputStream.write(buffer); + firstWrite.countDown(); + inputStream.write(buffer); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + }); + + firstWrite.await(); + assertTrue(inputStream.available() > 0); + + // second write should have started + Thread.sleep(500L); + inputStream.close(); + + submit.get(); + } + + @Test(expectedExceptions = IOException.class) + public void testReadClosedResponseStream() throws Exception { + HttpResponseInputStream inputStream = new HttpResponseInputStream(); + ByteBuf buffer = generateByteBuf(); - assertTrue(IOUtils.contentEquals(callback.getInputStream(), new ByteBufInputStream(buffer))); + inputStream.write(buffer); + inputStream.close(); + inputStream.read(); } private ByteBuf generateByteBuf() { @@ -46,7 +126,7 @@ public void onNext(InputStream stream) { this.stream = stream; } - public InputStream getInputStream() { + private InputStream getInputStream() { return stream; } }