-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathSegmentTree.java
More file actions
81 lines (70 loc) · 3.06 KB
/
SegmentTree.java
File metadata and controls
81 lines (70 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package org.psjava;
import org.psjava.ds.tree.BinaryTreeByArray;
import org.psjava.util.Assertion;
import java.util.List;
import java.util.function.BinaryOperator;
/**
 * Segment tree over a fixed-size sequence. It supports range queries combined by a user-supplied
 * merger and single-position replacement updates.
 *
 * <p>This class is only for simple replacement updating (no lazy range updates). It is backed by
 * {@link BinaryTreeByArray} for speed.
 *
 * <p>All ranges are half-open, {@code [start, end)}, with indices in {@code [0, size)}. Partial
 * results are always combined as {@code merger.apply(leftPart, rightPart)} in index order. The
 * merger therefore must be associative, but it need not be commutative.
 *
 * @param <T> element type
 */
public class SegmentTree<T> {
    private final BinaryOperator<T> merger;
    private final BinaryTreeByArray<T> tree;
    private final int size;

    /**
     * Builds a segment tree holding the elements of {@code initialData}.
     *
     * @param initialData initial values; the tree's size is fixed to {@code initialData.size()}
     * @param merger associative operator used to combine adjacent ranges
     */
    public SegmentTree(final List<T> initialData, final BinaryOperator<T> merger) {
        this.merger = merger;
        size = initialData.size();
        tree = new BinaryTreeByArray<>();
        // An empty input produces no root; query/update then reject every index via their range checks.
        if (!initialData.isEmpty()) {
            int root = tree.createRoot(initialData.get(0));
            construct(root, initialData, 0, size);
        }
    }

    /** Recursively fills the subtree rooted at {@code node}, which covers {@code [start, end)}. */
    private void construct(int node, List<T> initialData, int start, int end) {
        if (end - start == 1) {
            tree.setValue(node, initialData.get(start));
        } else {
            // Children are created with a placeholder value that the recursive calls overwrite.
            T placeholder = initialData.get(start);
            int mid = midpoint(start, end);
            int left = tree.putChild(node, false, placeholder);
            int right = tree.putChild(node, true, placeholder);
            construct(left, initialData, start, mid);
            construct(right, initialData, mid, end);
            tree.setValue(node, merger.apply(tree.getValue(left), tree.getValue(right)));
        }
    }

    /**
     * Returns the merge of the elements at indices {@code [start, end)}, combined left to right.
     *
     * @param start inclusive start index; must satisfy {@code 0 <= start}
     * @param end exclusive end index; must satisfy {@code start < end <= size}
     * @return merged value of the range
     */
    public T query(int start, int end) {
        // Checking start >= 0 explicitly makes a negative start fail here instead of deep in the recursion.
        Assertion.ensure(0 <= start && start < end && end <= size, () -> "invalid range start=" + start + ", end=" + end);
        return queryRecursively(tree.getRootPointer(), 0, size, start, end);
    }

    /**
     * Answers a query for {@code [start, end)}, which must lie inside the node's range
     * {@code [nodeStart, nodeEnd)}.
     */
    private T queryRecursively(int node, int nodeStart, int nodeEnd, int start, int end) {
        if (nodeStart == start && nodeEnd == end) {
            return tree.getValue(node);
        }
        int mid = midpoint(nodeStart, nodeEnd);
        if (end <= mid) {
            return queryRecursively(tree.getLeft(node), nodeStart, mid, start, end);
        } else if (mid <= start) {
            return queryRecursively(tree.getRight(node), mid, nodeEnd, start, end);
        } else {
            // The range straddles mid: split it and merge left part before right part.
            T leftPart = queryRecursively(tree.getLeft(node), nodeStart, mid, start, mid);
            T rightPart = queryRecursively(tree.getRight(node), mid, nodeEnd, mid, end);
            return merger.apply(leftPart, rightPart);
        }
    }

    /**
     * Replaces the element at {@code position} with {@code value} and refreshes its ancestors.
     *
     * @param position index to replace; must satisfy {@code 0 <= position < size}
     * @param value new value
     */
    public void update(int position, T value) {
        // Without the lower bound, a negative position would silently overwrite index 0.
        Assertion.ensure(0 <= position && position < size, () -> "invalid position=" + position);
        updateRecursively(tree.getRootPointer(), 0, size, position, value);
    }

    /** Descends to the leaf for {@code position}, sets it, and recomputes merged values on the way back up. */
    private void updateRecursively(int node, int nodeStart, int nodeEnd, int position, T value) {
        if (nodeEnd - nodeStart == 1) {
            tree.setValue(node, value);
        } else {
            int left = tree.getLeft(node);
            int right = tree.getRight(node);
            int mid = midpoint(nodeStart, nodeEnd);
            if (position < mid) {
                updateRecursively(left, nodeStart, mid, position, value);
            } else {
                updateRecursively(right, mid, nodeEnd, position, value);
            }
            tree.setValue(node, merger.apply(tree.getValue(left), tree.getValue(right)));
        }
    }

    /**
     * Overflow-safe floor midpoint of two non-negative indices. The unsigned shift stays correct
     * even when {@code lo + hi} exceeds {@code Integer.MAX_VALUE}.
     */
    private static int midpoint(int lo, int hi) {
        return (lo + hi) >>> 1;
    }
}