Skip to content

Commit 0d47f5b

Browse files
authored
xds: WRRPicker must not access unsynchronized data in ChildLbState
There was no point to using subchannels as keys to subchannelToReportListenerMap, as the listener is per-child. That meant the keys would be guaranteed to be known ahead-of-time and the unsynchronized getOrCreateOrcaListener() during picking was unnecessary. The picker still stores ChildLbStates to make sure that updating weights uses the correct children, but the picker itself no longer references ChildLbStates except in the constructor. That means weight calculation is moved into the LB policy, as child.getWeight() is unsynchronized, and the picker no longer needs a reference to helper.
1 parent 0d2ad89 commit 0d47f5b

File tree

2 files changed

+67
-67
lines changed

2 files changed

+67
-67
lines changed

xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java

Lines changed: 66 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@
4444
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
4545
import io.grpc.xds.orca.OrcaPerRequestUtil;
4646
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
47+
import java.util.ArrayList;
4748
import java.util.Collection;
48-
import java.util.HashMap;
4949
import java.util.HashSet;
5050
import java.util.List;
51-
import java.util.Map;
5251
import java.util.Random;
5352
import java.util.Set;
5453
import java.util.concurrent.ScheduledExecutorService;
@@ -233,9 +232,44 @@ protected void updateOverallBalancingState() {
233232
}
234233

235234
private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
236-
return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
237-
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(),
238-
locality);
235+
WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
236+
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
237+
updateWeight(picker);
238+
return picker;
239+
}
240+
241+
private void updateWeight(WeightedRoundRobinPicker picker) {
242+
Helper helper = getHelper();
243+
float[] newWeights = new float[picker.children.size()];
244+
AtomicInteger staleEndpoints = new AtomicInteger();
245+
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
246+
for (int i = 0; i < picker.children.size(); i++) {
247+
double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints,
248+
notYetUsableEndpoints);
249+
helper.getMetricRecorder()
250+
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
251+
ImmutableList.of(helper.getChannelTarget()),
252+
ImmutableList.of(locality));
253+
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
254+
}
255+
256+
if (staleEndpoints.get() > 0) {
257+
helper.getMetricRecorder()
258+
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
259+
ImmutableList.of(helper.getChannelTarget()),
260+
ImmutableList.of(locality));
261+
}
262+
if (notYetUsableEndpoints.get() > 0) {
263+
helper.getMetricRecorder()
264+
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
265+
ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
266+
}
267+
boolean weightsEffective = picker.updateWeight(newWeights);
268+
if (!weightsEffective) {
269+
helper.getMetricRecorder()
270+
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
271+
ImmutableList.of(locality));
272+
}
239273
}
240274

241275
private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) {
@@ -345,7 +379,7 @@ private final class UpdateWeightTask implements Runnable {
345379
@Override
346380
public void run() {
347381
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
348-
((WeightedRoundRobinPicker) currentPicker).updateWeight();
382+
updateWeight((WeightedRoundRobinPicker) currentPicker);
349383
}
350384
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
351385
TimeUnit.NANOSECONDS, timeService);
@@ -415,110 +449,76 @@ public void shutdown() {
415449

416450
@VisibleForTesting
417451
static final class WeightedRoundRobinPicker extends SubchannelPicker {
418-
private final List<ChildLbState> children;
419-
private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
420-
new HashMap<>();
452+
// Parallel lists (column-based storage instead of normal row-based storage of List<Struct>).
453+
// The ith element of children corresponds to the ith element of pickers, listeners, and even
454+
// updateWeight(float[]).
455+
private final List<ChildLbState> children; // May only be accessed from sync context
456+
private final List<SubchannelPicker> pickers;
457+
private final List<OrcaPerRequestReportListener> reportListeners;
421458
private final boolean enableOobLoadReport;
422459
private final float errorUtilizationPenalty;
423460
private final AtomicInteger sequence;
424461
private final int hashCode;
425-
private final LoadBalancer.Helper helper;
426-
private final String locality;
427462
private volatile StaticStrideScheduler scheduler;
428463

429464
WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
430-
float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper,
431-
String locality) {
465+
float errorUtilizationPenalty, AtomicInteger sequence) {
432466
checkNotNull(children, "children");
433467
Preconditions.checkArgument(!children.isEmpty(), "empty child list");
434468
this.children = children;
469+
List<SubchannelPicker> pickers = new ArrayList<>(children.size());
470+
List<OrcaPerRequestReportListener> reportListeners = new ArrayList<>(children.size());
435471
for (ChildLbState child : children) {
436472
WeightedChildLbState wChild = (WeightedChildLbState) child;
437-
for (WrrSubchannel subchannel : wChild.subchannels) {
438-
this.subchannelToReportListenerMap
439-
.put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
440-
}
473+
pickers.add(wChild.getCurrentPicker());
474+
reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
441475
}
476+
this.pickers = pickers;
477+
this.reportListeners = reportListeners;
442478
this.enableOobLoadReport = enableOobLoadReport;
443479
this.errorUtilizationPenalty = errorUtilizationPenalty;
444480
this.sequence = checkNotNull(sequence, "sequence");
445-
this.helper = helper;
446-
this.locality = checkNotNull(locality, "locality");
447481

448-
// For equality we treat children as a set; use hash code as defined by Set
482+
// For equality we treat pickers as a set; use hash code as defined by Set
449483
int sum = 0;
450-
for (ChildLbState child : children) {
451-
sum += child.hashCode();
484+
for (SubchannelPicker picker : pickers) {
485+
sum += picker.hashCode();
452486
}
453487
this.hashCode = sum
454488
^ Boolean.hashCode(enableOobLoadReport)
455489
^ Float.hashCode(errorUtilizationPenalty);
456-
457-
updateWeight();
458490
}
459491

460492
@Override
461493
public PickResult pickSubchannel(PickSubchannelArgs args) {
462-
ChildLbState childLbState = children.get(scheduler.pick());
463-
WeightedChildLbState wChild = (WeightedChildLbState) childLbState;
464-
PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);
494+
int pick = scheduler.pick();
495+
PickResult pickResult = pickers.get(pick).pickSubchannel(args);
465496
Subchannel subchannel = pickResult.getSubchannel();
466497
if (subchannel == null) {
467498
return pickResult;
468499
}
469500
if (!enableOobLoadReport) {
470501
return PickResult.withSubchannel(subchannel,
471502
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
472-
subchannelToReportListenerMap.getOrDefault(subchannel,
473-
wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));
503+
reportListeners.get(pick)));
474504
} else {
475505
return PickResult.withSubchannel(subchannel);
476506
}
477507
}
478508

479-
private void updateWeight() {
480-
float[] newWeights = new float[children.size()];
481-
AtomicInteger staleEndpoints = new AtomicInteger();
482-
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
483-
for (int i = 0; i < children.size(); i++) {
484-
double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,
485-
notYetUsableEndpoints);
486-
// TODO: add locality label once available
487-
helper.getMetricRecorder()
488-
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
489-
ImmutableList.of(helper.getChannelTarget()),
490-
ImmutableList.of(locality));
491-
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
492-
}
493-
if (staleEndpoints.get() > 0) {
494-
// TODO: add locality label once available
495-
helper.getMetricRecorder()
496-
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
497-
ImmutableList.of(helper.getChannelTarget()),
498-
ImmutableList.of(locality));
499-
}
500-
if (notYetUsableEndpoints.get() > 0) {
501-
// TODO: add locality label once available
502-
helper.getMetricRecorder()
503-
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
504-
ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
505-
}
506-
509+
/** Returns {@code true} if weights are different than round_robin. */
510+
private boolean updateWeight(float[] newWeights) {
507511
this.scheduler = new StaticStrideScheduler(newWeights, sequence);
508-
if (this.scheduler.usesRoundRobin()) {
509-
// TODO: locality label once available
510-
helper.getMetricRecorder()
511-
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
512-
ImmutableList.of(locality));
513-
}
512+
return !this.scheduler.usesRoundRobin();
514513
}
515514

516515
@Override
517516
public String toString() {
518517
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
519518
.add("enableOobLoadReport", enableOobLoadReport)
520519
.add("errorUtilizationPenalty", errorUtilizationPenalty)
521-
.add("list", children).toString();
520+
.add("pickers", pickers)
521+
.toString();
522522
}
523523

524524
@VisibleForTesting
@@ -545,8 +545,8 @@ public boolean equals(Object o) {
545545
&& sequence == other.sequence
546546
&& enableOobLoadReport == other.enableOobLoadReport
547547
&& Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
548-
&& children.size() == other.children.size()
549-
&& new HashSet<>(children).containsAll(other.children);
548+
&& pickers.size() == other.pickers.size()
549+
&& new HashSet<>(pickers).containsAll(other.pickers);
550550
}
551551
}
552552

xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ public void wrrLifeCycle() {
244244
String weightedPickerStr = weightedPicker.toString();
245245
assertThat(weightedPickerStr).contains("enableOobLoadReport=false");
246246
assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0");
247-
assertThat(weightedPickerStr).contains("list=");
247+
assertThat(weightedPickerStr).contains("pickers=");
248248

249249
WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
250250
WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);

0 commit comments

Comments
 (0)