2828import com .google .common .annotations .VisibleForTesting ;
2929import com .google .common .base .MoreObjects ;
3030import com .google .common .base .Preconditions ;
31+ import com .google .common .util .concurrent .Futures ;
3132import com .google .common .util .concurrent .ListenableFuture ;
3233import com .google .common .util .concurrent .SettableFuture ;
3334import io .grpc .Attributes ;
4647import io .grpc .InternalServerInterceptors ;
4748import io .grpc .Metadata ;
4849import io .grpc .ServerCall ;
50+ import io .grpc .ServerCallExecutorSupplier ;
4951import io .grpc .ServerCallHandler ;
5052import io .grpc .ServerInterceptor ;
5153import io .grpc .ServerMethodDefinition ;
@@ -125,6 +127,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
125127 private final InternalChannelz channelz ;
126128 private final CallTracer serverCallTracer ;
127129 private final Deadline .Ticker ticker ;
130+ private final ServerCallExecutorSupplier executorSupplier ;
128131
129132 /**
130133 * Construct a server.
@@ -159,6 +162,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
159162 this .serverCallTracer = builder .callTracerFactory .create ();
160163 this .ticker = checkNotNull (builder .ticker , "ticker" );
161164 channelz .addServer (this );
165+ this .executorSupplier = builder .executorSupplier ;
162166 }
163167
164168 /**
@@ -469,11 +473,11 @@ private void streamCreatedInternal(
469473 final Executor wrappedExecutor ;
470474 // This is a performance optimization that avoids the synchronization and queuing overhead
471475 // that comes with SerializingExecutor.
472- if (executor == directExecutor ()) {
476+ if (executorSupplier != null || executor != directExecutor ()) {
477+ wrappedExecutor = new SerializingExecutor (executor );
478+ } else {
473479 wrappedExecutor = new SerializeReentrantCallsDirectExecutor ();
474480 stream .optimizeForDirectExecutor ();
475- } else {
476- wrappedExecutor = new SerializingExecutor (executor );
477481 }
478482
479483 if (headers .containsKey (MESSAGE_ENCODING_KEY )) {
@@ -499,52 +503,120 @@ private void streamCreatedInternal(
499503
500504 final JumpToApplicationThreadServerStreamListener jumpListener
501505 = new JumpToApplicationThreadServerStreamListener (
502- wrappedExecutor , executor , stream , context , tag );
506+ wrappedExecutor , executor , stream , context , tag );
503507 stream .setListener (jumpListener );
504- // Run in wrappedExecutor so jumpListener.setListener() is called before any callbacks
505- // are delivered, including any errors. Callbacks can still be triggered, but they will be
506- // queued.
507-
508- final class StreamCreated extends ContextRunnable {
509- StreamCreated () {
508+ final SettableFuture <ServerCallParameters <?,?>> future = SettableFuture .create ();
509+ // Run in serializing executor so jumpListener.setListener() is called before any callbacks
510+ // are delivered, including any errors. MethodLookup() and HandleServerCall() are proactively
511+ // queued before any callbacks are queued at serializing executor.
512+ // MethodLookup() runs on the default executor.
513+ // When executorSupplier is enabled, MethodLookup() may set/change the executor in the
514+ // SerializingExecutor before it finishes running.
515+ // Then HandleServerCall() and callbacks would switch to the executorSupplier executor.
516+ // Otherwise, they all run on the default executor.
517+
518+ final class MethodLookup extends ContextRunnable {
519+ MethodLookup () {
510520 super (context );
511521 }
512522
513523 @ Override
514524 public void runInContext () {
515- PerfMark .startTask ("ServerTransportListener$StreamCreated .startCall" , tag );
525+ PerfMark .startTask ("ServerTransportListener$MethodLookup .startCall" , tag );
516526 PerfMark .linkIn (link );
517527 try {
518528 runInternal ();
519529 } finally {
520- PerfMark .stopTask ("ServerTransportListener$StreamCreated .startCall" , tag );
530+ PerfMark .stopTask ("ServerTransportListener$MethodLookup .startCall" , tag );
521531 }
522532 }
523533
524534 private void runInternal () {
525- ServerStreamListener listener = NOOP_LISTENER ;
535+ ServerMethodDefinition <?, ?> wrapMethod ;
536+ ServerCallParameters <?, ?> callParams ;
526537 try {
527538 ServerMethodDefinition <?, ?> method = registry .lookupMethod (methodName );
528539 if (method == null ) {
529540 method = fallbackRegistry .lookupMethod (methodName , stream .getAuthority ());
530541 }
531542 if (method == null ) {
532543 Status status = Status .UNIMPLEMENTED .withDescription (
533- "Method not found: " + methodName );
544+ "Method not found: " + methodName );
534545 // TODO(zhangkun83): this error may be recorded by the tracer, and if it's kept in
535546 // memory as a map whose key is the method name, this would allow a misbehaving
536547 // client to blow up the server in-memory stats storage by sending large number of
537548 // distinct unimplemented method
538549 // names. (https://github.com/grpc/grpc-java/issues/2285)
550+ jumpListener .setListener (NOOP_LISTENER );
539551 stream .close (status , new Metadata ());
540552 context .cancel (null );
553+ future .cancel (false );
541554 return ;
542555 }
543- listener = startCall (stream , methodName , method , headers , context , statsTraceCtx , tag );
556+ wrapMethod = wrapMethod (stream , method , statsTraceCtx );
557+ callParams = maySwitchExecutor (wrapMethod , stream , headers , context , tag );
558+ future .set (callParams );
544559 } catch (Throwable t ) {
560+ jumpListener .setListener (NOOP_LISTENER );
545561 stream .close (Status .fromThrowable (t ), new Metadata ());
546562 context .cancel (null );
563+ future .cancel (false );
547564 throw t ;
565+ }
566+ }
567+
568+ private <ReqT , RespT > ServerCallParameters <ReqT , RespT > maySwitchExecutor (
569+ final ServerMethodDefinition <ReqT , RespT > methodDef ,
570+ final ServerStream stream ,
571+ final Metadata headers ,
572+ final Context .CancellableContext context ,
573+ final Tag tag ) {
574+ final ServerCallImpl <ReqT , RespT > call = new ServerCallImpl <>(
575+ stream ,
576+ methodDef .getMethodDescriptor (),
577+ headers ,
578+ context ,
579+ decompressorRegistry ,
580+ compressorRegistry ,
581+ serverCallTracer ,
582+ tag );
583+ if (executorSupplier != null ) {
584+ Executor switchingExecutor = executorSupplier .getExecutor (call , headers );
585+ if (switchingExecutor != null ) {
586+ ((SerializingExecutor )wrappedExecutor ).setExecutor (switchingExecutor );
587+ }
588+ }
589+ return new ServerCallParameters <>(call , methodDef .getServerCallHandler ());
590+ }
591+ }
592+
593+ final class HandleServerCall extends ContextRunnable {
594+ HandleServerCall () {
595+ super (context );
596+ }
597+
598+ @ Override
599+ public void runInContext () {
600+ PerfMark .startTask ("ServerTransportListener$HandleServerCall.startCall" , tag );
601+ PerfMark .linkIn (link );
602+ try {
603+ runInternal ();
604+ } finally {
605+ PerfMark .stopTask ("ServerTransportListener$HandleServerCall.startCall" , tag );
606+ }
607+ }
608+
609+ private void runInternal () {
610+ ServerStreamListener listener = NOOP_LISTENER ;
611+ if (future .isCancelled ()) {
612+ return ;
613+ }
614+ try {
615+ listener = startWrappedCall (methodName , Futures .getDone (future ), headers );
616+ } catch (Throwable ex ) {
617+ stream .close (Status .fromThrowable (ex ), new Metadata ());
618+ context .cancel (null );
619+ throw new IllegalStateException (ex );
548620 } finally {
549621 jumpListener .setListener (listener );
550622 }
@@ -568,7 +640,8 @@ public void cancelled(Context context) {
568640 }
569641 }
570642
571- wrappedExecutor .execute (new StreamCreated ());
643+ wrappedExecutor .execute (new MethodLookup ());
644+ wrappedExecutor .execute (new HandleServerCall ());
572645 }
573646
574647 private Context .CancellableContext createContext (
@@ -593,9 +666,8 @@ private Context.CancellableContext createContext(
593666 }
594667
595668 /** Never returns {@code null}. */
596- private <ReqT , RespT > ServerStreamListener startCall (ServerStream stream , String fullMethodName ,
597- ServerMethodDefinition <ReqT , RespT > methodDef , Metadata headers ,
598- Context .CancellableContext context , StatsTraceContext statsTraceCtx , Tag tag ) {
669+ private <ReqT , RespT > ServerMethodDefinition <?,?> wrapMethod (ServerStream stream ,
670+ ServerMethodDefinition <ReqT , RespT > methodDef , StatsTraceContext statsTraceCtx ) {
599671 // TODO(ejona86): should we update fullMethodName to have the canonical path of the method?
600672 statsTraceCtx .serverCallStarted (
601673 new ServerCallInfoImpl <>(
@@ -609,34 +681,31 @@ private <ReqT, RespT> ServerStreamListener startCall(ServerStream stream, String
609681 ServerMethodDefinition <ReqT , RespT > interceptedDef = methodDef .withServerCallHandler (handler );
610682 ServerMethodDefinition <?, ?> wMethodDef = binlog == null
611683 ? interceptedDef : binlog .wrapMethodDefinition (interceptedDef );
612- return startWrappedCall (fullMethodName , wMethodDef , stream , headers , context , tag );
684+ return wMethodDef ;
685+ }
686+
687+ private final class ServerCallParameters <ReqT , RespT > {
688+ ServerCallImpl <ReqT , RespT > call ;
689+ ServerCallHandler <ReqT , RespT > callHandler ;
690+
691+ public ServerCallParameters (ServerCallImpl <ReqT , RespT > call ,
692+ ServerCallHandler <ReqT , RespT > callHandler ) {
693+ this .call = call ;
694+ this .callHandler = callHandler ;
695+ }
613696 }
614697
615698 private <WReqT , WRespT > ServerStreamListener startWrappedCall (
616699 String fullMethodName ,
617- ServerMethodDefinition <WReqT , WRespT > methodDef ,
618- ServerStream stream ,
619- Metadata headers ,
620- Context .CancellableContext context ,
621- Tag tag ) {
622-
623- ServerCallImpl <WReqT , WRespT > call = new ServerCallImpl <>(
624- stream ,
625- methodDef .getMethodDescriptor (),
626- headers ,
627- context ,
628- decompressorRegistry ,
629- compressorRegistry ,
630- serverCallTracer ,
631- tag );
632-
633- ServerCall .Listener <WReqT > listener =
634- methodDef .getServerCallHandler ().startCall (call , headers );
635- if (listener == null ) {
700+ ServerCallParameters <WReqT , WRespT > params ,
701+ Metadata headers ) {
702+ ServerCall .Listener <WReqT > callListener =
703+ params .callHandler .startCall (params .call , headers );
704+ if (callListener == null ) {
636705 throw new NullPointerException (
637- "startCall() returned a null listener for method " + fullMethodName );
706+ "startCall() returned a null listener for method " + fullMethodName );
638707 }
639- return call .newServerStreamListener (listener );
708+ return params . call .newServerStreamListener (callListener );
640709 }
641710 }
642711
0 commit comments