@@ -104,6 +104,7 @@ abstract class RetriableStream<ReqT> implements ClientStream {
104104 @ GuardedBy ("lock" )
105105 private FutureCanceller scheduledHedging ;
106106 private long nextBackoffIntervalNanos ;
107+ private Status cancellationStatus ;
107108
108109 RetriableStream (
109110 MethodDescriptor <ReqT , ?> method , Metadata headers ,
@@ -244,14 +245,16 @@ private void drain(Substream substream) {
244245 int index = 0 ;
245246 int chunk = 0x80 ;
246247 List <BufferEntry > list = null ;
248+ boolean streamStarted = false ;
247249
248250 while (true ) {
249251 State savedState ;
250252
251253 synchronized (lock ) {
252254 savedState = state ;
253- if (savedState .winningSubstream != null && savedState .winningSubstream != substream ) {
254- // committed but not me
255+ if (savedState .winningSubstream != null && savedState .winningSubstream != substream
256+ && streamStarted ) {
257+ // committed but not me, to be cancelled
255258 break ;
256259 }
257260 if (index == savedState .buffer .size ()) { // I'm drained
@@ -275,17 +278,22 @@ private void drain(Substream substream) {
275278
276279 for (BufferEntry bufferEntry : list ) {
277280 savedState = state ;
278- if (savedState .winningSubstream != null && savedState .winningSubstream != substream ) {
279- // committed but not me
281+ if (savedState .winningSubstream != null && savedState .winningSubstream != substream
282+ && streamStarted ) {
283+ // committed but not me, to be cancelled
280284 break ;
281285 }
282- if (savedState .cancelled ) {
286+ if (savedState .cancelled && streamStarted ) {
283287 checkState (
284288 savedState .winningSubstream == substream ,
285289 "substream should be CANCELLED_BECAUSE_COMMITTED already" );
290+ substream .stream .cancel (cancellationStatus );
286291 return ;
287292 }
288293 bufferEntry .runWith (substream );
294+ if (bufferEntry instanceof RetriableStream .StartEntry ) {
295+ streamStarted = true ;
296+ }
289297 }
290298 }
291299
@@ -299,6 +307,13 @@ private void drain(Substream substream) {
299307 @ Nullable
300308 abstract Status prestart ();
301309
310+ class StartEntry implements BufferEntry {
311+ @ Override
312+ public void runWith (Substream substream ) {
313+ substream .stream .start (new Sublistener (substream ));
314+ }
315+ }
316+
302317 /** Starts the first PRC attempt. */
303318 @ Override
304319 public final void start (ClientStreamListener listener ) {
@@ -311,13 +326,6 @@ public final void start(ClientStreamListener listener) {
311326 return ;
312327 }
313328
314- class StartEntry implements BufferEntry {
315- @ Override
316- public void runWith (Substream substream ) {
317- substream .stream .start (new Sublistener (substream ));
318- }
319- }
320-
321329 synchronized (lock ) {
322330 state .buffer .add (new StartEntry ());
323331 }
@@ -450,11 +458,18 @@ public final void cancel(Status reason) {
450458 return ;
451459 }
452460
453- state . winningSubstream . stream . cancel ( reason ) ;
461+ Substream winningSubstreamToCancel = null ;
454462 synchronized (lock ) {
455- // This is not required, but causes a short-circuit in the draining process.
463+ if (state .drainedSubstreams .contains (state .winningSubstream )) {
464+ winningSubstreamToCancel = state .winningSubstream ;
465+ } else { // the winningSubstream will be cancelled while draining
466+ cancellationStatus = reason ;
467+ }
456468 state = state .cancelled ();
457469 }
470+ if (winningSubstreamToCancel != null ) {
471+ winningSubstreamToCancel .stream .cancel (reason );
472+ }
458473 }
459474
460475 private void delayOrExecute (BufferEntry bufferEntry ) {
0 commit comments