-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.h
More file actions
105 lines (88 loc) · 2.76 KB
/
model.h
File metadata and controls
105 lines (88 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#pragma once
#include "common.h"
#include "corpus.h"
// Unscoped action codes. Their meaning is assigned by the code that consumes
// them, which is not part of this header.
enum action_t
{
	COPY = 0,
	INIT = 1,
	FREE = 2
};
class model
{
public:
model(corpus* corp) : corp(corp)
{
nUsers = corp->nUsers;
nItems = corp->nItems;
nVotes = corp->nVotes;
//leave out two for each user
test_per_user = new pair<int, long long>[nUsers];
val_per_user = new pair<int, long long>[nUsers];
for (int u = 0; u < nUsers; u++)
{
test_per_user[u] = make_pair(-1, -1); // -1 denotes empty
val_per_user[u] = make_pair(-1, -1);
}
//leave out two for each item
test_per_item = new pair<int, long long>[nItems];
val_per_item = new pair<int, long long>[nItems];
for (int i = 0; i < nItems; i++)
{
test_per_item[i] = make_pair(-1, -1); // -1 denotes empty
val_per_item[i] = make_pair(-1, -1);
}
// split into training set and valid set and test set
pos_per_user = new map<int, long long>[nUsers];
pos_per_item = new map<int, long long>[nItems];
for (int x = 0; x < nVotes; x++)
{
vote* V = corp->V.at(x);
int user = V->user;
int item = V->item;
long long voteTime = V->voteTime;
if (test_per_user[user].first == -1)
test_per_user[user] = make_pair(item, voteTime);
else if (val_per_user[user].first == -1)
val_per_user[user] = make_pair(item, voteTime);
else
pos_per_user[user][item] = voteTime;
if (test_per_item[item].first == -1)
test_per_item[item] = make_pair(user, voteTime);
else if (val_per_item[item].first == -1)
val_per_item[item] = make_pair(user, voteTime);
else
pos_per_item[item][user] = voteTime;
}
num_pos_events = 0;
for (int u = 0; u < nUsers; u++)
{
num_pos_events += pos_per_user[u].size();
}
}
~model()
{
delete[] pos_per_user;
delete[] pos_per_item;
delete[] test_per_user;
delete[] val_per_user;
}
/* Model parameters */
int NW; // Total number of parameters
double* W; // Contiguous version of all parameters
double* bestW;
/* Corpus related */
corpus* corp;
int nUsers; // Number of users
int nItems; // Number of items
int nVotes; // Number of ratings
map<int, long long>* pos_per_user;
map<int, long long>* pos_per_item;
pair<int, long long>* val_per_user;
pair<int, long long>* test_per_user;
pair<int, long long>* val_per_item;
pair<int, long long>* test_per_item;
int num_pos_events;
virtual void AUC_IR(double* AUC_val, double* AUC_test, double* std); // AUC for item ranking
virtual void AUC_AR(double* AUC_val, double* AUC_test, double* std); // AUC for audience retrieval
virtual void copyBestModel();
virtual void saveModel(const char* path);
virtual void loadModel(const char* path);
virtual string tostring();
private:
virtual double prediction(int user, int item) = 0;
};