-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbartMachine_a_base.java
More file actions
91 lines (75 loc) · 3.2 KB
/
bartMachine_a_base.java
File metadata and controls
91 lines (75 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package bartMachine;
import java.io.Serializable;
/**
 * The base class for any BART implementation. Contains
 * mostly instance variables and settings for the algorithm.
 *
 * @author Adam Kapelner and Justin Bleich
 */
@SuppressWarnings("serial")
public abstract class bartMachine_a_base extends Classifier implements Serializable {

	/** all Gibbs samples for burn-in and post burn-in where each entry is a vector of pointers to the <code>num_trees</code> trees in the sum-of-trees model */
	protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees;
	/** Gibbs samples post burn-in where each entry is a vector of pointers to the <code>num_trees</code> trees in the sum-of-trees model */
	protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees_after_burn_in;
	/** Gibbs samples for burn-in and post burn-in of the variances */
	protected double[] gibbs_samples_of_sigsq;
	/** Gibbs samples for post burn-in of the variances */
	protected double[] gibbs_samples_of_sigsq_after_burn_in;
	/** a record of whether each Gibbs sample accepted or rejected the MH step within each of the <code>num_trees</code> trees */
	protected boolean[][] accept_reject_mh;
	/** a record of the proposal of each Gibbs sample within each of the <code>num_trees</code> trees: G, P or C for "grow", "prune", "change" */
	protected char[][] accept_reject_mh_steps;
	/** the number of trees in our sum-of-trees model */
	protected int num_trees;
	/** how many Gibbs samples we burn-in for */
	protected int num_gibbs_burn_in;
	/** how many total Gibbs samples in a BART model creation */
	protected int num_gibbs_total_iterations;
	/** the number of the thread being used to run this Gibbs sampler */
	protected int threadNum;
	/** how many CPU cores to use during model creation */
	protected int num_cores;
	/**
	 * whether or not we use the memory cache feature
	 *
	 * @see Section 3.1 of Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
	 */
	protected boolean mem_cache_for_speed;
	/**
	 * whether to flush (discard) the indices stored in tree nodes in order to save RAM.
	 * NOTE(review): keeping the indices (flag false) is presumably what makes
	 * weight computations possible; confirm against the code that reads this flag.
	 */
	protected boolean flush_indices_to_save_ram;
	/** should we print stuff out to screen? */
	protected boolean verbose = true;

	/**
	 * Remove unnecessary data from the Gibbs chain to conserve RAM by calling
	 * {@link #FlushDataForSample(bartMachineTreeNode[])} on every stored Gibbs sample.
	 * <p>
	 * Does nothing if no Gibbs samples have been recorded yet. Sample slots that have
	 * not been filled (null) are skipped rather than causing a {@code NullPointerException}.
	 */
	protected void FlushData() {
		// the sample array may not be allocated yet (e.g. flushing before / without sampling)
		if (gibbs_samples_of_bart_trees == null) {
			return;
		}
		for (bartMachineTreeNode[] bart_trees : gibbs_samples_of_bart_trees) {
			FlushDataForSample(bart_trees);
		}
	}

	/**
	 * Remove unnecessary data from an individual Gibbs sample by calling
	 * {@code flushNodeData()} on each of its trees.
	 *
	 * @param bart_trees	the trees of one Gibbs sample; a null array or null tree
	 * 						entries (e.g. slots never populated) are silently skipped
	 */
	protected void FlushDataForSample(bartMachineTreeNode[] bart_trees) {
		if (bart_trees == null) {
			return;
		}
		for (bartMachineTreeNode tree : bart_trees) {
			if (tree != null) {
				tree.flushNodeData();
			}
		}
	}

	/** Must be implemented, but does nothing */
	public void StopBuilding(){}

	/**
	 * Set the number of the thread running this Gibbs sampler.
	 *
	 * @param threadNum	the thread number
	 */
	public void setThreadNum(int threadNum) {
		this.threadNum = threadNum;
	}

	/**
	 * Set whether output should be printed to the screen.
	 *
	 * @param verbose	true to print output
	 */
	public void setVerbose(boolean verbose){
		this.verbose = verbose;
	}

	/**
	 * Set how many CPU cores to use during model creation.
	 *
	 * @param num_cores	the number of cores / threads
	 */
	public void setTotalNumThreads(int num_cores) {
		this.num_cores = num_cores;
	}

	/**
	 * Set whether the memory cache feature is used.
	 *
	 * @param mem_cache_for_speed	true to enable the memory cache
	 */
	public void setMemCacheForSpeed(boolean mem_cache_for_speed){
		this.mem_cache_for_speed = mem_cache_for_speed;
	}

	/**
	 * Set whether indices stored in tree nodes should be flushed to save RAM.
	 *
	 * @param flush_indices_to_save_ram	true to flush the indices
	 */
	public void setFlushIndicesToSaveRAM(boolean flush_indices_to_save_ram) {
		this.flush_indices_to_save_ram = flush_indices_to_save_ram;
	}

	/**
	 * Set the number of trees in the sum-of-trees model.
	 *
	 * @param m	the number of trees (stored in {@code num_trees})
	 */
	public void setNumTrees(int m){
		this.num_trees = m;
	}
}