-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMoE.py
More file actions
63 lines (45 loc) · 2.43 KB
/
MoE.py
File metadata and controls
63 lines (45 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
"""
Created on Thu May 25 19:54:23 2023
@author: oowoyele
"""
import torch
class MoE:
    """Mixture-of-experts trainer helper built on a list of expert networks.

    Each expert in ``fcn_list`` is expected to expose ``output`` (its current
    predictions), ``y`` (the targets) and a ``pred()`` method that refreshes
    ``output``.  The class computes per-sample gating weights ``alpha`` from
    the experts' squared errors and uses them to weight each expert's loss.

    NOTE(review): the code concatenates per-expert squared errors along
    axis 1 and slices one column of ``alpha`` per expert, so it presumably
    assumes ``output``/``y`` have shape (n_samples, 1) -- confirm with MLP().
    """

    def __init__(self, fcn_list, kappa=5):
        """Store the experts and the separation strength.

        Args:
            fcn_list: list of expert network objects (created using the
                MLP() class, per the original author's note).
            kappa: controls how strongly the experts are separated; larger
                values make ``alpha`` concentrate on the lowest-error expert.
        """
        self.num_experts = len(fcn_list)
        self.kappa = kappa
        self.fcn_list = fcn_list
        self.alpha = None  # gating weights, set by compute_alpha()
        self.wmse = None  # per-expert weighted losses, set by compute_weighted_mse()

    def compute_SE(self, fcn):
        """Return the element-wise squared error of ``fcn``'s current output."""
        return (fcn.output - fcn.y) ** 2

    def compute_MSE(self, fcn):
        """Return the mean squared error of ``fcn``'s current output."""
        return torch.mean(self.compute_SE(fcn))

    def compute_weighted_MSE(self, fcn, alpha, update_y=True):
        """Return the mean of ``alpha``-weighted squared errors for one expert.

        Args:
            fcn: the expert whose loss is computed.
            alpha: per-sample weights for this expert.  If ``None``, the
                full ``self.alpha`` matrix is used (the legacy behaviour).
            update_y: if true, call ``fcn.pred()`` first so the loss uses
                predictions from the latest weights.
        """
        if update_y:
            fcn.pred()
        # The original ignored ``alpha`` and always used ``self.alpha``; since
        # each row of ``self.alpha`` sums to one, that reduced every expert's
        # loss to plain MSE / num_experts and disabled the gating entirely.
        weights = self.alpha if alpha is None else alpha
        return torch.mean(weights * self.compute_SE(fcn))

    def _winning_experts(self):
        """Return, for every sample, the index of the expert with largest alpha."""
        if self.alpha is None:
            raise RuntimeError(
                "alpha has not been computed yet; call compute_alpha() first"
            )
        return torch.argmax(self.alpha, dim=1)

    def get_num_winning_points(self):
        """Return a list with the number of points "won" by each expert."""
        winners = self._winning_experts()
        # Count on the tensor's own device; avoids the .numpy() round-trip,
        # which fails for tensors that do not live on the CPU.
        return [int((winners == iexp).sum()) for iexp in range(self.num_experts)]

    def get_winning_points_inds(self):
        """Return a list with the sample indices "won" by each expert."""
        winners = self._winning_experts()
        return [torch.where(winners == iexp)[0] for iexp in range(self.num_experts)]

    def compute_alpha(self):
        """Compute, store and return the gating weights ``alpha``.

        Squared errors of all experts are concatenated along axis 1, each row
        is scaled by its largest entry, and a softmax of ``-kappa * scaled``
        is taken across experts.  Runs under ``torch.no_grad()`` so ``alpha``
        carries no gradient.
        """
        with torch.no_grad():
            errors_mat = torch.cat(
                [self.compute_SE(fcn) for fcn in self.fcn_list], dim=1
            )
            row_max = torch.amax(errors_mat, dim=1, keepdim=True)
            # A row maximum of zero means every expert is exact on that
            # sample; dividing by it would give 0/0 = NaN.  Dividing by one
            # instead keeps the scaled errors at zero (uniform weights).
            safe_max = torch.where(row_max > 0, row_max, torch.ones_like(row_max))
            errors_mat_norm = errors_mat / safe_max
            # softmax == exp(.) / sum(exp(.)) but does not underflow to 0/0
            # when kappa is large.
            self.alpha = torch.softmax(-self.kappa * errors_mat_norm, dim=1)
        return self.alpha

    def compute_weighted_mse(self, update_y=True):
        """Recompute ``alpha`` and return the list of per-expert weighted losses.

        ``alpha`` is computed from the experts' outputs as they are on entry;
        when ``update_y`` is true each expert's prediction is then refreshed
        before its loss is evaluated.  The list is also stored in ``self.wmse``.
        """
        self.compute_alpha()
        self.wmse = [
            self.compute_weighted_MSE(fcn, self.alpha[:, iexp:iexp + 1], update_y)
            for iexp, fcn in enumerate(self.fcn_list)
        ]
        return self.wmse