@@ -749,6 +749,91 @@ public void request(int numMessages) {
749749 inOrder .verify (mockStream2 , never ()).writeMessage (any (InputStream .class ));
750750 }
751751
752+ @ Test
753+ public void cancelWhileDraining () {
754+ ArgumentCaptor <ClientStreamListener > sublistenerCaptor1 =
755+ ArgumentCaptor .forClass (ClientStreamListener .class );
756+ ClientStream mockStream1 = mock (ClientStream .class );
757+ ClientStream mockStream2 =
758+ mock (
759+ ClientStream .class ,
760+ delegatesTo (
761+ new NoopClientStream () {
762+ @ Override
763+ public void request (int numMessages ) {
764+ retriableStream .cancel (
765+ Status .CANCELLED .withDescription ("cancelled while requesting" ));
766+ }
767+ }));
768+
769+ InOrder inOrder = inOrder (retriableStreamRecorder , mockStream1 , mockStream2 );
770+ doReturn (mockStream1 ).when (retriableStreamRecorder ).newSubstream (0 );
771+ retriableStream .start (masterListener );
772+ inOrder .verify (mockStream1 ).start (sublistenerCaptor1 .capture ());
773+ retriableStream .request (3 );
774+ inOrder .verify (mockStream1 ).request (3 );
775+
776+ // retry
777+ doReturn (mockStream2 ).when (retriableStreamRecorder ).newSubstream (1 );
778+ sublistenerCaptor1 .getValue ().closed (
779+ Status .fromCode (RETRIABLE_STATUS_CODE_1 ), PROCESSED , new Metadata ());
780+ fakeClock .forwardTime ((long ) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM ), TimeUnit .SECONDS );
781+
782+ inOrder .verify (mockStream2 ).start (any (ClientStreamListener .class ));
783+ inOrder .verify (mockStream2 ).request (3 );
784+ inOrder .verify (retriableStreamRecorder ).postCommit ();
785+ ArgumentCaptor <Status > statusCaptor = ArgumentCaptor .forClass (Status .class );
786+ inOrder .verify (mockStream2 ).cancel (statusCaptor .capture ());
787+ assertThat (statusCaptor .getValue ().getCode ()).isEqualTo (Code .CANCELLED );
788+ assertThat (statusCaptor .getValue ().getDescription ())
789+ .isEqualTo ("Stream thrown away because RetriableStream committed" );
790+ verify (masterListener ).closed (
791+ statusCaptor .capture (), any (RpcProgress .class ), any (Metadata .class ));
792+ assertThat (statusCaptor .getValue ().getCode ()).isEqualTo (Code .CANCELLED );
793+ assertThat (statusCaptor .getValue ().getDescription ()).isEqualTo ("cancelled while requesting" );
794+ }
795+
796+ @ Test
797+ public void cancelWhileRetryStart () {
798+ ArgumentCaptor <ClientStreamListener > sublistenerCaptor1 =
799+ ArgumentCaptor .forClass (ClientStreamListener .class );
800+ ClientStream mockStream1 = mock (ClientStream .class );
801+ ClientStream mockStream2 =
802+ mock (
803+ ClientStream .class ,
804+ delegatesTo (
805+ new NoopClientStream () {
806+ @ Override
807+ public void start (ClientStreamListener listener ) {
808+ retriableStream .cancel (
809+ Status .CANCELLED .withDescription ("cancelled while retry start" ));
810+ }
811+ }));
812+
813+ InOrder inOrder = inOrder (retriableStreamRecorder , mockStream1 , mockStream2 );
814+ doReturn (mockStream1 ).when (retriableStreamRecorder ).newSubstream (0 );
815+ retriableStream .start (masterListener );
816+ inOrder .verify (mockStream1 ).start (sublistenerCaptor1 .capture ());
817+
818+ // retry
819+ doReturn (mockStream2 ).when (retriableStreamRecorder ).newSubstream (1 );
820+ sublistenerCaptor1 .getValue ().closed (
821+ Status .fromCode (RETRIABLE_STATUS_CODE_1 ), PROCESSED , new Metadata ());
822+ fakeClock .forwardTime ((long ) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM ), TimeUnit .SECONDS );
823+
824+ inOrder .verify (mockStream2 ).start (any (ClientStreamListener .class ));
825+ inOrder .verify (retriableStreamRecorder ).postCommit ();
826+ ArgumentCaptor <Status > statusCaptor = ArgumentCaptor .forClass (Status .class );
827+ inOrder .verify (mockStream2 ).cancel (statusCaptor .capture ());
828+ assertThat (statusCaptor .getValue ().getCode ()).isEqualTo (Code .CANCELLED );
829+ assertThat (statusCaptor .getValue ().getDescription ())
830+ .isEqualTo ("Stream thrown away because RetriableStream committed" );
831+ verify (masterListener ).closed (
832+ statusCaptor .capture (), any (RpcProgress .class ), any (Metadata .class ));
833+ assertThat (statusCaptor .getValue ().getCode ()).isEqualTo (Code .CANCELLED );
834+ assertThat (statusCaptor .getValue ().getDescription ()).isEqualTo ("cancelled while retry start" );
835+ }
836+
752837 @ Test
753838 public void operationsAfterImmediateCommit () {
754839 ArgumentCaptor <ClientStreamListener > sublistenerCaptor1 =
@@ -916,6 +1001,47 @@ public void start(ClientStreamListener listener) {
9161001 verify (mockStream3 ).request (1 );
9171002 }
9181003
1004+ @ Test
1005+ public void commitAndCancelWhileDraining () {
1006+ ClientStream mockStream1 = mock (ClientStream .class );
1007+ ClientStream mockStream2 =
1008+ mock (
1009+ ClientStream .class ,
1010+ delegatesTo (
1011+ new NoopClientStream () {
1012+ @ Override
1013+ public void start (ClientStreamListener listener ) {
1014+ // commit while draining
1015+ listener .headersRead (new Metadata ());
1016+ // cancel while draining
1017+ retriableStream .cancel (
1018+ Status .CANCELLED .withDescription ("cancelled while drained" ));
1019+ }
1020+ }));
1021+
1022+ when (retriableStreamRecorder .newSubstream (anyInt ()))
1023+ .thenReturn (mockStream1 , mockStream2 );
1024+
1025+ retriableStream .start (masterListener );
1026+
1027+ ArgumentCaptor <ClientStreamListener > sublistenerCaptor1 =
1028+ ArgumentCaptor .forClass (ClientStreamListener .class );
1029+ verify (mockStream1 ).start (sublistenerCaptor1 .capture ());
1030+
1031+ ClientStreamListener listener1 = sublistenerCaptor1 .getValue ();
1032+
1033+ // retry
1034+ listener1 .closed (
1035+ Status .fromCode (RETRIABLE_STATUS_CODE_1 ), PROCESSED , new Metadata ());
1036+ fakeClock .forwardTime ((long ) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM ), TimeUnit .SECONDS );
1037+
1038+ verify (mockStream2 ).start (any (ClientStreamListener .class ));
1039+ verify (retriableStreamRecorder ).postCommit ();
1040+ ArgumentCaptor <Status > statusCaptor = ArgumentCaptor .forClass (Status .class );
1041+ verify (mockStream2 ).cancel (statusCaptor .capture ());
1042+ assertThat (statusCaptor .getValue ().getCode ()).isEqualTo (Code .CANCELLED );
1043+ assertThat (statusCaptor .getValue ().getDescription ()).isEqualTo ("cancelled while drained" );
1044+ }
9191045
9201046 @ Test
9211047 public void perRpcBufferLimitExceeded () {
0 commit comments