-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
67 lines (61 loc) · 2.59 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn.functional as F
from torch_geometric.nn import HANConv
import xgboost as xgb
from sklearn.metrics import accuracy_score, classification_report
# HAN Model Implementation
class HAN(torch.nn.Module):
    """Two stacked ``HANConv`` layers with an ELU activation in between.

    Args:
        in_channels: Input size handed to the first ``HANConv`` layer.
        out_channels: Output size of both layers; also the input size of
            the second layer.
        metadata: Graph metadata forwarded unchanged to both layers.
        heads1: Number of attention heads for the first layer.
        dropout1: Dropout value for the first layer.
        heads2: Number of attention heads for the second layer.
        dropout2: Dropout value for the second layer.
    """

    def __init__(self, in_channels, out_channels, metadata, heads1, dropout1, heads2, dropout2):
        super().__init__()
        # The attribute names are part of the saved state-dict keys; keep them.
        self.conv1 = HANConv(
            in_channels,
            out_channels,
            metadata=metadata,
            heads=heads1,
            dropout=dropout1,
        )
        self.conv2 = HANConv(
            out_channels,
            out_channels,
            metadata=metadata,
            heads=heads2,
            dropout=dropout2,
        )

    def forward(self, x_dict, edge_index_dict):
        """Run both layers and return the dictionary produced by the second.

        Args:
            x_dict: Mapping passed to ``HANConv`` (presumably node type ->
                feature tensor; confirm against the caller).
            edge_index_dict: Mapping passed to both ``HANConv`` layers
                alongside the features.

        Returns:
            The dictionary returned by the second ``HANConv`` layer.
        """
        hidden = self.conv1(x_dict, edge_index_dict)
        # NOTE(review): if HANConv can yield None for some key, F.elu would
        # raise here -- confirm every node type receives messages.
        activated = {}
        for node_type, features in hidden.items():
            activated[node_type] = F.elu(features)
        return self.conv2(activated, edge_index_dict)
def train_xgboost_model(train_embeddings, train_labels, test_embeddings, test_labels, n_estimators, learning_rate, max_depth, subsample, colsample_bytree, early_stopping_rounds):
    """Fit a binary XGBoost classifier and score it on the test split.

    Args:
        train_embeddings: Training features passed to ``fit``.
        train_labels: Training targets passed to ``fit``.
        test_embeddings: Features used for evaluation and prediction.
        test_labels: Targets the predictions are compared against.
        n_estimators: Forwarded to ``XGBClassifier``.
        learning_rate: Forwarded to ``XGBClassifier``.
        max_depth: Forwarded to ``XGBClassifier``.
        subsample: Forwarded to ``XGBClassifier``.
        colsample_bytree: Forwarded to ``XGBClassifier``.
        early_stopping_rounds: Forwarded to ``XGBClassifier``.

    Returns:
        A ``(accuracy, report, model)`` tuple where ``report`` is the
        dictionary form of ``classification_report`` and ``model`` is the
        fitted classifier.
    """
    hyperparams = {
        "objective": 'binary:logistic',
        "use_label_encoder": False,
        "n_estimators": n_estimators,
        "learning_rate": learning_rate,
        "max_depth": max_depth,
        "subsample": subsample,
        "colsample_bytree": colsample_bytree,
        "early_stopping_rounds": early_stopping_rounds,
    }
    classifier = xgb.XGBClassifier(**hyperparams)

    # Both splits are monitored during fitting.
    # NOTE(review): the test split is part of eval_set while early stopping is
    # enabled -- confirm that using it for stopping is intended.
    monitored_sets = [
        (train_embeddings, train_labels),
        (test_embeddings, test_labels),
    ]
    classifier.fit(
        train_embeddings,
        train_labels,
        eval_set=monitored_sets,
        verbose=True,
    )

    predictions = classifier.predict(test_embeddings)
    accuracy = accuracy_score(test_labels, predictions)
    report = classification_report(test_labels, predictions, output_dict=True)
    return accuracy, report, classifier
def load_xgboost_model(
    *,
    n_estimators=50000,
    learning_rate=0.1,
    max_depth=6,
    subsample=1,
    colsample_bytree=1,
    eval_metric='logloss',
    early_stopping_rounds=150,
):
    """Create an unfitted XGBoost classifier.

    Despite the name, nothing is read from disk: a new ``XGBClassifier`` is
    constructed. Calling the function with no arguments yields the same
    configuration as before; every hyperparameter can now be overridden by
    keyword.

    Args:
        n_estimators: Forwarded to ``XGBClassifier``.
        learning_rate: Forwarded to ``XGBClassifier``.
        max_depth: Forwarded to ``XGBClassifier``.
        subsample: Forwarded to ``XGBClassifier``.
        colsample_bytree: Forwarded to ``XGBClassifier``.
        eval_metric: Forwarded to ``XGBClassifier``.
        early_stopping_rounds: Forwarded to ``XGBClassifier``.

    Returns:
        The newly constructed, unfitted ``XGBClassifier``.
    """
    return xgb.XGBClassifier(
        n_estimators=n_estimators,
        learning_rate=learning_rate,
        max_depth=max_depth,
        subsample=subsample,
        colsample_bytree=colsample_bytree,
        use_label_encoder=False,
        eval_metric=eval_metric,
        early_stopping_rounds=early_stopping_rounds,
    )
def run_xgboost_inference(xgb_model, test_embeddings, test_labels):
    """Score an already-fitted XGBoost model on the given data.

    Args:
        xgb_model: Object whose ``predict`` method is called on the features.
        test_embeddings: Features to predict on.
        test_labels: Targets the predictions are compared against.

    Returns:
        A ``(accuracy, report)`` tuple; ``report`` is the default output of
        ``classification_report`` (no ``output_dict`` is requested here).
    """
    predicted = xgb_model.predict(test_embeddings)
    return (
        accuracy_score(test_labels, predicted),
        classification_report(test_labels, predicted),
    )
def load_model(model, model_path, device):
    """Restore saved weights into ``model`` and prepare it for inference.

    Args:
        model: Module instance that receives the stored state dict.
        model_path: Location of the file written by ``torch.save``.
        device: Device used both as ``map_location`` and as the model's
            final placement.

    Returns:
        The same ``model`` object, moved to ``device`` and switched to
        evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    saved_state = torch.load(model_path, map_location=device)
    model.load_state_dict(saved_state)
    model.to(device).eval()
    return model