-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgmm_analysis.py
More file actions
83 lines (69 loc) · 2.76 KB
/
gmm_analysis.py
File metadata and controls
83 lines (69 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Dict
from os import makedirs
import matplotlib.pyplot as plt
# from ogc import dimensionality_reduction as dr
from ogc.classifiers.gmm import GMM
from ogc import utilities
import numpy.typing as npt
import numpy as np
from project import TRAINING_DATA, ROOT_PATH
import logging
from project import ZNormalization as znorm_cached
from project import PCA as PCA_Cached
from pprint import pprint
# Emit DEBUG-level records both through the root configuration and through
# this module's own logger.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Destinations for generated figures and result tables, built from ROOT_PATH.
OUTPUT_PATH = ROOT_PATH + "../images/gmm_analysis/"
TABLES_OUTPUT_PATH = ROOT_PATH + "../tables/gmm_analysis/"

# Create both destinations at import time; already-existing directories are
# accepted silently.
for _output_dir in (OUTPUT_PATH, TABLES_OUTPUT_PATH):
    makedirs(_output_dir, exist_ok=True)
def GMM_callback(prior, dataset_type, mvg_param, dimred, components):
    """Evaluate one GMM configuration with 5-fold cross-validation.

    Args:
        prior: Forwarded to ``Kfold`` as its ``prior`` keyword argument.
        dataset_type: When equal to ``"Z-Norm"`` the training data is passed
            through ``ZNormalization`` first; any other value leaves the
            data untouched.
        mvg_param: Mapping that selects the model variant by key presence:
            ``"tied"`` -> ``GMM.GMMTiedCov``, ``"naive"`` -> ``GMM.GMMDiag``,
            neither -> ``GMM.GMM``. ``"tied"`` is checked first.
        dimred: When not ``None``, passed as the second argument to
            ``dr.PCA`` to reduce the training data before cross-validation.
        components: Forwarded to the selected model's constructor.

    Returns:
        Whatever ``Kfold(DTR, LTR, model, 5, prior=prior)`` returns.
    """
    # Kept function-local like the original; both names come from the same
    # module, so a single import statement is enough.
    from ogc.utilities import Kfold, ZNormalization

    DTR, LTR = TRAINING_DATA()

    # Key presence (not the associated value) decides the model variant.
    if "tied" in mvg_param:
        model = GMM.GMMTiedCov(components)
    elif "naive" in mvg_param:
        model = GMM.GMMDiag(components)
    else:
        model = GMM.GMM(components)

    # NOTE(review): normalization and PCA are applied to the whole training
    # set before the folds are formed -- confirm this is intended.
    if dataset_type == "Z-Norm":
        DTR = ZNormalization(DTR)[0]
    if dimred is not None:
        # Imported lazily: only needed when a reduction is requested.
        from ogc import dimensionality_reduction as dr
        DTR = dr.PCA(DTR, dimred)[0]

    return Kfold(DTR, LTR, model, 5, prior=prior)
def main(*, fast_run=True, use_csv=True):
    """Load cached GMM grid-search results, or recompute and save them.

    Args:
        fast_run: When true, use the reduced grid (naive and tied models
            with 5 components). Otherwise use the full grid (standard,
            naive and tied models with 1 to 4 components).
        use_csv: When true, load ``gmm_results.csv`` and
            ``gmm_results1.csv`` from ``TABLES_OUTPUT_PATH`` instead of
            running the grid search. When false, run the grid search and
            write its table to ``gmm_results.csv``.

    Returns:
        None.
    """
    # Raw strings: "\p" is not a valid escape sequence, so the non-raw form
    # triggers a warning on recent Python versions. The runtime text is
    # unchanged.
    priors = [
        (r"$\pi = 0.5$", 0.5),
        (r"$\pi = 0.1$", 0.1),
        (r"$\pi = 0.9$", 0.9),
    ]
    dataset_types = [("RAW", None), ("Z-Norm", "Z-Norm")]
    dimred = [("No PCA", None), ("PCA $(m=5)$", 5)]

    # Only the model variants and the component counts differ between the
    # fast and the full grid.
    if fast_run:
        mvg_params = [
            ("Naive GMM", {"naive": True}),
            ("Tied GMM", {"tied": True}),
        ]
        components = [("5", 5)]
    else:
        mvg_params = [
            ("Standard GMM", {}),
            ("Naive GMM", {"naive": True}),
            ("Tied GMM", {"tied": True}),
        ]
        components = [("1", 1), ("2", 2), ("3", 3), ("4", 4)]

    if use_csv:
        # NOTE(review): both tables are loaded but not used further in this
        # function; the loads are kept so missing files still surface here.
        table = utilities.load_from_csv(TABLES_OUTPUT_PATH + "gmm_results.csv")
        table1 = utilities.load_from_csv(TABLES_OUTPUT_PATH + "gmm_results1.csv")
    else:
        _, table = utilities.grid_search(
            GMM_callback, priors, dataset_types, mvg_params, dimred, components)
        np.savetxt(
            TABLES_OUTPUT_PATH + "gmm_results.csv",
            table,
            delimiter=";",
            fmt="%s",
            header=";".join(
                ["Prior", "Dataset", "MVG", "PCA", "Components", "MinDCF"]),
        )
if __name__ == "__main__":
    # Script entry point: run the analysis and report how long it took.
    import time

    began_at = time.time()
    main()
    elapsed = time.time() - began_at
    print(f"Time elapsed: {elapsed} seconds")