Skip to content

Commit 9a748a0

Browse files
LinZongibauersachs
andauthored
Fix DoH initial request using recommended nanoTime calculation
Co-authored-by: Ingo Bauersachs <[email protected]>
1 parent 8444680 commit 9a748a0

2 files changed

Lines changed: 92 additions & 10 deletions

File tree

src/main/java/org/xbill/DNS/DohResolver.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ public final class DohResolver implements Resolver {
174174
USE_HTTP_CLIENT = initSuccess;
175175
}
176176

177+
// package-visible for testing
178+
long getNanoTime() {
179+
return System.nanoTime();
180+
}
181+
177182
/**
178183
* Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
179184
*
@@ -315,7 +320,7 @@ public CompletionStage<Message> sendAsync(Message query, Executor executor) {
315320
private CompletionStage<Message> sendAsync8(final Message query, Executor executor) {
316321
byte[] queryBytes = prepareQuery(query).toWire();
317322
String url = getUrl(queryBytes);
318-
long startTime = System.nanoTime();
323+
long startTime = getNanoTime();
319324
return maxConcurrentRequests
320325
.acquire(timeout)
321326
.handleAsync(
@@ -363,7 +368,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
363368
((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory);
364369
}
365370

366-
Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
371+
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
367372
conn.setConnectTimeout((int) remainingTimeout.toMillis());
368373
conn.setReadTimeout((int) remainingTimeout.toMillis());
369374
conn.setRequestMethod(usePost ? "POST" : "GET");
@@ -389,7 +394,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
389394
int offset = 0;
390395
while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) {
391396
offset += r;
392-
remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
397+
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
393398
if (remainingTimeout.isNegative()) {
394399
throw new SocketTimeoutException();
395400
}
@@ -403,7 +408,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
403408
byte[] buffer = new byte[4096];
404409
int r;
405410
while ((r = is.read(buffer, 0, buffer.length)) > 0) {
406-
remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
411+
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
407412
if (remainingTimeout.isNegative()) {
408413
throw new SocketTimeoutException();
409414
}
@@ -432,7 +437,7 @@ private void discardStream(InputStream es) throws IOException {
432437
}
433438

434439
private CompletionStage<Message> sendAsync11(final Message query, Executor executor) {
435-
long startTime = System.nanoTime();
440+
long startTime = getNanoTime();
436441
byte[] queryBytes = prepareQuery(query).toWire();
437442
String url = getUrl(queryBytes);
438443

@@ -454,7 +459,7 @@ private CompletionStage<Message> sendAsync11(final Message query, Executor execu
454459
// check if this request needs to be done synchronously because of HttpClient's stupidity to
455460
// not use the connection pool for HTTP/2 until one connection is successfully established,
456461
// which could lead to hundreds of connections (and threads with the default executor)
457-
Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
462+
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
458463
return initialRequestLock
459464
.acquire(remainingTimeout)
460465
.handle(
@@ -476,14 +481,13 @@ private CompletionStage<Message> sendAsync11WithInitialRequestPermit(
476481
Object requestBuilder,
477482
Permit initialRequestPermit) {
478483
long lastRequestTime = lastRequest.get();
479-
boolean isInitialRequest =
480-
(lastRequestTime < System.nanoTime() - idleConnectionTimeout.toNanos());
484+
boolean isInitialRequest = idleConnectionTimeout.toNanos() > getNanoTime() - lastRequestTime;
481485
if (!isInitialRequest) {
482486
initialRequestPermit.release();
483487
}
484488

485489
// check if we already exceeded the query timeout while checking the initial connection
486-
Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
490+
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
487491
if (remainingTimeout.isNegative()) {
488492
if (isInitialRequest) {
489493
initialRequestPermit.release();
@@ -525,7 +529,7 @@ private CompletionStage<Message> sendAsync11WithConcurrentRequestPermit(
525529
boolean isInitialRequest,
526530
Permit maxConcurrentRequestPermit) {
527531
// check if the stream lock acquisition took too long
528-
Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
532+
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
529533
if (remainingTimeout.isNegative()) {
530534
if (isInitialRequest) {
531535
initialRequestPermit.release();

src/test/java/org/xbill/DNS/DohResolverTest.java

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import static org.junit.jupiter.api.Assertions.assertEquals;
55
import static org.junit.jupiter.api.Assertions.assertTrue;
6+
import static org.mockito.Mockito.doAnswer;
7+
import static org.mockito.Mockito.spy;
68

79
import io.netty.handler.codec.http.HttpHeaderNames;
810
import io.vertx.core.Future;
@@ -20,13 +22,20 @@
2022
import java.time.Duration;
2123
import java.util.Base64;
2224
import java.util.Collections;
25+
import java.util.concurrent.CompletionStage;
26+
import java.util.concurrent.TimeUnit;
2327
import java.util.concurrent.TimeoutException;
28+
import java.util.concurrent.atomic.AtomicBoolean;
2429
import java.util.concurrent.atomic.AtomicInteger;
30+
import java.util.concurrent.atomic.AtomicLong;
2531
import org.junit.jupiter.api.BeforeEach;
2632
import org.junit.jupiter.api.Test;
33+
import org.junit.jupiter.api.condition.EnabledForJreRange;
34+
import org.junit.jupiter.api.condition.JRE;
2735
import org.junit.jupiter.api.extension.ExtendWith;
2836
import org.junit.jupiter.params.ParameterizedTest;
2937
import org.junit.jupiter.params.provider.ValueSource;
38+
import org.mockito.stubbing.Answer;
3039

3140
@ExtendWith(VertxExtension.class)
3241
class DohResolverTest {
@@ -202,6 +211,75 @@ private Future<HttpServer> setupResolverWithServer(
202211
.onSuccess(server -> resolver.setUriTemplate("http://localhost:" + server.actualPort()));
203212
}
204213

214+
@EnabledForJreRange(
215+
min = JRE.JAVA_9,
216+
disabledReason = "Java 8 implementation doesn't have the initial request guard")
217+
@Test
218+
void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
219+
Vertx vertx, VertxTestContext context) {
220+
AtomicLong startNanos = new AtomicLong(System.nanoTime());
221+
resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2)));
222+
resolver.setTimeout(Duration.ofSeconds(1));
223+
// Simulate a nanoTime value that is lower than the idle timeout
224+
doAnswer((Answer<Long>) invocationOnMock -> System.nanoTime() - startNanos.get())
225+
.when(resolver)
226+
.getNanoTime();
227+
228+
// Just add a 100ms delay before responding to the 1st call
229+
// to simulate a 'concurrent doh request' for the 2nd call,
230+
// then let the fake dns server respond to the 2nd call ASAP.
231+
allRequestsUseTimeout = false;
232+
233+
// idleConnectionTimeout = 2s, lastRequest = 0L
234+
// Ensure idleConnectionTimeout < System.nanoTime() - lastRequest (3s)
235+
236+
// Timeline:
237+
// |<-------- 100ms -------->|
238+
// ↑ ↑
239+
// 1st call sent response of 1st call
240+
// |20ms|<------ 80ms ------>|<------ few millis ------->|
241+
// ↑ wait until 1st call ↑ ↑
242+
// 2nd call begin 2nd call sent response of 2nd call
243+
244+
AtomicBoolean firstCallCompleted = new AtomicBoolean(false);
245+
246+
setupResolverWithServer(Duration.ofMillis(100L), 200, 2, vertx, context)
247+
.onSuccess(
248+
server -> {
249+
// First call
250+
CompletionStage<Message> firstCall = resolver.sendAsync(qm);
251+
// Ensure second call was made after first call and uses a different query
252+
startNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(20));
253+
CompletionStage<Message> secondCall = resolver.sendAsync(Message.newQuery(qr));
254+
255+
Future.fromCompletionStage(firstCall)
256+
.onComplete(
257+
context.succeeding(
258+
result ->
259+
context.verify(
260+
() -> {
261+
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
262+
assertEquals(0, result.getHeader().getID());
263+
assertEquals(queryName, result.getQuestion().getName());
264+
firstCallCompleted.set(true);
265+
})));
266+
267+
Future.fromCompletionStage(secondCall)
268+
.onComplete(
269+
context.succeeding(
270+
result ->
271+
context.verify(
272+
() -> {
273+
assertTrue(firstCallCompleted.get());
274+
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
275+
assertEquals(0, result.getHeader().getID());
276+
assertEquals(queryName, result.getQuestion().getName());
277+
// Complete context after the 2nd call was completed.
278+
context.completeNow();
279+
})));
280+
});
281+
}
282+
205283
private Future<HttpServer> setupServer(
206284
Message expectedDnsRequest,
207285
Message dnsResponse,

0 commit comments

Comments
 (0)