-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbartMachine_c_debug.java
More file actions
84 lines (73 loc) · 3 KB
/
bartMachine_c_debug.java
File metadata and controls
84 lines (73 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package bartMachine;
import java.io.Serializable;
/**
 * Debug and introspection utilities for the BART model. This portion of the code used to
 * contain many debug functions; these were removed during the tidy-up for release. What
 * remains are the tree-illustration hooks and a few query methods over the stored Gibbs
 * samples.
 *
 * @author Adam Kapelner and Justin Bleich
 */
@SuppressWarnings("serial")
public abstract class bartMachine_c_debug extends bartMachine_b_hyperparams implements Serializable{

	/** should we create illustrations of the trees and save the images to the debug directory? */
	protected boolean tree_illust = false;

	/**
	 * The hook that gets called to save the tree illustrations when the Gibbs sampler begins.
	 * Every tree of the first stored Gibbs sample (index 0) is added to a single illustration,
	 * with a likelihood of 0 recorded for each tree, and the resulting image is saved.
	 * <p>
	 * NOTE(review): unlike {@code illustrate}, this method does not check {@code tree_illust}
	 * itself; presumably callers only invoke it when illustrations are enabled — confirm.
	 */
	protected void InitTreeIllustrations() {
		bartMachineTreeNode[] initial_trees = gibbs_samples_of_bart_trees[0];
		TreeArrayIllustration tree_array_illustration = new TreeArrayIllustration(0, unique_name);
		for (bartMachineTreeNode tree : initial_trees){
			tree_array_illustration.AddTree(tree);
			// one likelihood entry per tree; the initial trees have no meaningful likelihood yet
			tree_array_illustration.addLikelihood(0);
		}
		tree_array_illustration.CreateIllustrationAndSaveImage();
	}

	/**
	 * The hook that gets called to save the tree illustrations for each Gibbs sample.
	 * The image is only created and saved when {@code tree_illust} is enabled.
	 *
	 * @param tree_array_illustration	the illustration of the current sample's trees
	 */
	protected void illustrate(TreeArrayIllustration tree_array_illustration) {
		if (tree_illust){
			tree_array_illustration.CreateIllustrationAndSaveImage();
		}
	}

	/**
	 * Get the untransformed samples of the sigsqs from the Gibbs chain.
	 *
	 * @return The vector of untransformed variances over all the Gibbs samples
	 */
	public double[] getGibbsSamplesSigsqs(){
		double[] sigsqs_to_export = new double[gibbs_samples_of_sigsq.length];
		for (int n_g = 0; n_g < gibbs_samples_of_sigsq.length; n_g++){
			sigsqs_to_export[n_g] = un_transform_sigsq(gibbs_samples_of_sigsq[n_g]);
		}
		return sigsqs_to_export;
	}

	/**
	 * Queries the depths of the <code>num_trees</code> trees between a range of Gibbs samples
	 *
	 * @param n_g_i The Gibbs sample number to start querying
	 * @param n_g_f The Gibbs sample number (+1) to stop querying
	 * @return The depths of all <code>num_trees</code> trees for each Gibbs sample specified
	 * 			(row <code>g - n_g_i</code> holds Gibbs sample <code>g</code>)
	 * @throws IllegalArgumentException if <code>n_g_f &lt; n_g_i</code>, or if the range is
	 * 			non-empty and not contained in the stored Gibbs samples
	 */
	public int[][] getDepthsForTrees(int n_g_i, int n_g_f){
		checkGibbsSampleRange(n_g_i, n_g_f);
		int[][] all_depths = new int[n_g_f - n_g_i][num_trees];
		for (int g = n_g_i; g < n_g_f; g++){
			for (int t = 0; t < num_trees; t++){
				all_depths[g - n_g_i][t] = gibbs_samples_of_bart_trees[g][t].deepestNode();
			}
		}
		return all_depths;
	}

	/**
	 * Queries the number of nodes (terminal and non-terminal) in the <code>num_trees</code> trees between a range of Gibbs samples
	 *
	 * @param n_g_i The Gibbs sample number to start querying
	 * @param n_g_f The Gibbs sample number (+1) to stop querying
	 * @return The number of nodes of all <code>num_trees</code> trees for each Gibbs sample specified
	 * 			(row <code>g - n_g_i</code> holds Gibbs sample <code>g</code>)
	 * @throws IllegalArgumentException if <code>n_g_f &lt; n_g_i</code>, or if the range is
	 * 			non-empty and not contained in the stored Gibbs samples
	 */
	public int[][] getNumNodesAndLeavesForTrees(int n_g_i, int n_g_f){
		checkGibbsSampleRange(n_g_i, n_g_f);
		int[][] all_num_nodes = new int[n_g_f - n_g_i][num_trees];
		for (int g = n_g_i; g < n_g_f; g++){
			for (int t = 0; t < num_trees; t++){
				all_num_nodes[g - n_g_i][t] = gibbs_samples_of_bart_trees[g][t].numNodesAndLeaves();
			}
		}
		return all_num_nodes;
	}

	/**
	 * Validates a half-open range <code>[n_g_i, n_g_f)</code> of Gibbs sample indices before it
	 * is used to index <code>gibbs_samples_of_bart_trees</code>, so that bad arguments produce a
	 * descriptive error rather than a bare NegativeArraySizeException or
	 * ArrayIndexOutOfBoundsException. An empty range (<code>n_g_i == n_g_f</code>) indexes
	 * nothing and is always accepted.
	 *
	 * @param n_g_i The first Gibbs sample index in the range
	 * @param n_g_f One past the last Gibbs sample index in the range
	 * @throws IllegalArgumentException if the range is reversed or, when non-empty, falls outside
	 * 			the stored Gibbs samples
	 */
	private void checkGibbsSampleRange(int n_g_i, int n_g_f){
		if (n_g_f < n_g_i){
			throw new IllegalArgumentException("Gibbs sample range end (" + n_g_f + ") is before its start (" + n_g_i + ")");
		}
		// only a non-empty range actually dereferences the stored samples
		if (n_g_f > n_g_i && (n_g_i < 0 || n_g_f > gibbs_samples_of_bart_trees.length)){
			throw new IllegalArgumentException("Gibbs sample range [" + n_g_i + ", " + n_g_f + ") is outside the stored samples [0, " + gibbs_samples_of_bart_trees.length + ")");
		}
	}
}