-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTreeArrayIllustration.java
More file actions
138 lines (122 loc) · 4.26 KB
/
TreeArrayIllustration.java
File metadata and controls
138 lines (122 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package bartMachine;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashMap;
import javax.imageio.ImageIO;
/**
 * Draws a collection of trees side by side on one canvas and saves the
 * result as a PNG file in the working directory. It is only used as a
 * debugging feature.
 *
 * @author Adam Kapelner
 */
public class TreeArrayIllustration {

    /** the sample number; passed to every tree illustration as "num_iteration" and used in the file name */
    private int sample_num;
    /** the trees to draw, left to right in the order they were added */
    private transient ArrayList<bartMachineTreeNode> trees;
    /** the likelihoods recorded via {@link #addLikelihood(double)}; they are stored but not drawn */
    private transient ArrayList<Double> likelihoods;
    /** an identifier that is embedded in the name of the saved file */
    private String unique_name;

    /** a number format that shows at most one fraction digit */
    public static NumberFormat one_digit_format = NumberFormat.getInstance();
    static {
        one_digit_format.setMaximumFractionDigits(1);
    }

    /**
     * Creates an empty illustration; trees are supplied later via {@link #AddTree}.
     *
     * @param sample_num    the sample number of the trees that will be added
     * @param unique_name   an identifier that becomes part of the output file name
     */
    public TreeArrayIllustration(int sample_num, String unique_name) {
        this.sample_num = sample_num;
        this.unique_name = unique_name;
        trees = new ArrayList<bartMachineTreeNode>();
        likelihoods = new ArrayList<Double>();
    }

    /**
     * Appends a tree to the illustration.
     *
     * @param tree  the tree to draw to the right of all previously added trees
     */
    public void AddTree(bartMachineTreeNode tree) {
        trees.add(tree);
    }

    /**
     * Records a likelihood value.
     *
     * @param lik   the likelihood to record
     */
    public void addLikelihood(double lik) {
        likelihoods.add(lik);
    }

    /**
     * Renders every added tree, stitches the renderings together left to right
     * on one canvas and writes that canvas to disk. If no trees have been added
     * nothing is drawn and no file is written.
     */
    public synchronized void CreateIllustrationAndSaveImage() {
        int m = trees.size();
        if (m == 0) {
            //a BufferedImage cannot have a zero width or height, so there is nothing to save
            return;
        }
        //first pull out all the tree images
        int w = 0;
        int h = 0;
        ArrayList<BufferedImage> canvases = new ArrayList<BufferedImage>(m);
        for (int t = 0; t < m; t++) {
            HashMap<String, String> info = new HashMap<String, String>();
            info.put("tree_num", "" + (t + 1));
            info.put("num_iteration", "" + sample_num);
            //note: the likelihoods are not placed into the info map
            BufferedImage canvas = new TreeIllustration(trees.get(t), info).getCanvas();
            w += canvas.getWidth(); //aggregate the widths
            h = Math.max(h, canvas.getHeight()); //get the maximum height
            canvases.add(canvas);
        }
        //now paint them one after the other onto the master canvas
        BufferedImage master_canvas = new BufferedImage(w, h, BufferedImage.TYPE_BYTE_BINARY);
        //obtain one graphics context for all trees and release it when done
        java.awt.Graphics g = master_canvas.getGraphics();
        try {
            int sliding_width = 0;
            for (BufferedImage canvas : canvases) {
                g.drawImage(canvas, sliding_width, 0, null);
                sliding_width += canvas.getWidth();
            }
        } finally {
            g.dispose();
        }
        saveImageFile(master_canvas);
    }

    /**
     * Writes the image as "BART_[unique_name]_iter_[sample_num].png" where the
     * sample number is zero-padded to five digits. A failure is reported on
     * standard error and is not propagated to the caller.
     *
     * @param image     the image to save
     */
    private void saveImageFile(BufferedImage image) {
        String title = "BART_" + unique_name + "_iter_" + LeadingZeroes(sample_num, 5);
        try {
            //ImageIO.write signals "no appropriate writer" by returning false, not by throwing
            if (!ImageIO.write(image, "PNG", new File(title + ".png"))) {
                System.err.println("can't save " + title + ": no PNG writer available");
            }
        } catch (IOException e) {
            System.err.println("can't save " + title + ": " + e.getMessage());
        }
    }

    /**
     * Prepends zeroes to an unsigned number string until it shows the requested number of digits.
     *
     * @param text              the number as text, without a sign
     * @param current_digits    how many digits of <code>text</code> count towards the requested width
     * @param num_digits        the requested minimum number of digits
     * @return                  <code>text</code>, zero-padded on the left if it was too short
     */
    private static String zeroPad(String text, int current_digits, int num_digits) {
        //compare first instead of subtracting so that extreme values of num_digits cannot overflow
        if (num_digits <= current_digits) {
            return text;
        }
        int num_zeroes = num_digits - current_digits;
        StringBuilder padded = new StringBuilder(num_zeroes + text.length());
        for (int i = 0; i < num_zeroes; i++) {
            padded.append('0');
        }
        return padded.append(text).toString();
    }

    /**
     * Zero-pads the integer part of a number to a minimum number of digits.
     * The number itself is rendered by {@link String#valueOf(double)}, hence
     * the fraction is always shown (e.g. 5 padded to 3 digits gives "005.0").
     * A minus sign is placed in front of the padding and does not count as a
     * digit. NaN, the infinities and magnitudes of 10^7 and above (which are
     * rendered in scientific notation) are returned without padding.
     *
     * @param num           the number to pad
     * @param num_digits    the minimum number of digits before the decimal point
     * @return              the padded number as a string
     */
    public static String LeadingZeroes(double num, int num_digits) {
        if (Double.isNaN(num) || Double.isInfinite(num)) {
            return String.valueOf(num);
        }
        double magnitude = Math.abs(num);
        if (magnitude >= 10000000) {
            return String.valueOf(num);
        }
        int integer_digits = 1;
        for (double limit = 10; magnitude >= limit; limit *= 10) {
            integer_digits++;
        }
        String padded = zeroPad(String.valueOf(magnitude), integer_digits, num_digits);
        //Double.compare also recognizes negative zero, which "num < 0" would miss
        return Double.compare(num, 0.0) < 0 ? "-" + padded : padded;
    }

    /**
     * Zero-pads a number to a minimum number of digits (e.g. 42 padded to
     * 5 digits gives "00042"). A minus sign is placed in front of the padding
     * and does not count as a digit. Numbers that already have enough digits
     * are returned unchanged.
     *
     * @param num           the number to pad
     * @param num_digits    the minimum number of digits
     * @return              the padded number as a string
     */
    public static String LeadingZeroes(int num, int num_digits) {
        if (num < 0) {
            //widen before negating so that Integer.MIN_VALUE does not overflow
            String magnitude = String.valueOf(-(long) num);
            return "-" + zeroPad(magnitude, magnitude.length(), num_digits);
        }
        String digits = String.valueOf(num);
        return zeroPad(digits, digits.length(), num_digits);
    }
}