-
Notifications
You must be signed in to change notification settings - Fork 11
/
evaluate_half_tensor.py
251 lines (198 loc) · 10.1 KB
/
evaluate_half_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
""" Unified home for training and evaluation. Imports model and dataloader."""
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# To unpack ADNI data
import pickle
import random
# Import network
import sys
sys.path.insert(1, './model')
from network import Network
from data_loader import MRIData
import argparse
parser = argparse.ArgumentParser(description='Train and validate network.')
parser.add_argument('--disable-cuda', action='store_true', default=False,
help='Disable CUDA')
args = parser.parse_args()
args.device = None
print(args.disable_cuda)
if torch.cuda.is_available():
print("Using CUDA. : )")
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
args.device = torch.device('cuda')
else:
print("We aren't using CUDA.")
args.device = torch.device('cpu')
# torch.manual_seed(314159265368979323846264338327950288419716939937510) # for reproducibility for testing purposes. Delete during actual training.
# NOTE: don't change the seed numbers as we debug, or we might introduce user bias into the model!
torch.manual_seed(1)
random.seed(1)
# ============= Hyperparameters ===================
BATCH_SIZE = 10  # FIXME: used to be 64
# Dimensionality of the data outputted by the LSTM,
# forwarded to the final dense layer. THIS IS A GUESS CURRENTLY.
LSTM_output_size = 16
input_size = 1  # FIXME: used to be 3 # Size of the processed MRI scans fed into the CNN.
output_dimension = 2  # the number of predictions the model will make
# 2 used for binary prediction for each image.
# update the splicing used in train()
learning_rate = 0.1
training_epochs = 5
# The size of images passed, as a tuple
data_shape = (200, 200, 150)
# Other hyperparameters unlisted: the depth of the model, the kernel size, the padding, the channel restriction.

# ========== Import Data ==============
# expected format:
# training_data stores batches of MRI's and classifications like this: [batch,batch,batch] : )
# each batch should be in form
#   [Bunch of MRIs, Bunch of Classifications]
# and each 'bunch' in the batch should be grouped by patient
#   Bunch of MRIs = [Patient 1 MRIs, Patient 2 MRIs,...]
#   Bunch of Classifications = [Patient 1 classifications, Patient 2 Classifications...]
# the Classifications should be binary 0,1 probabilities in output_dimension dimensions.
# BUG FIX: pickle.load(open(...)) leaked the file descriptor; a context
# manager guarantees the file is closed even if unpickling raises.
with open("./Data/Combined_MRI_List.pkl", "rb") as _mri_list_file:
    MRI_images_list = pickle.load(_mri_list_file)
random.shuffle(MRI_images_list)
# 70% of the data is used for training, the remainder reserved for testing.
train_size = int(0.7 * len(MRI_images_list))
# Split list
training_list = MRI_images_list[:train_size]
test_list = MRI_images_list[train_size:]
DATA_ROOT_DIR = './'
train_dataset = MRIData(DATA_ROOT_DIR, training_list)
test_dataset = MRIData(DATA_ROOT_DIR, test_list)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Aliases used by the train/test loops below.
training_data = train_loader
test_data = test_loader
# ================== Define Model =========================================
model = Network(input_size, data_shape, output_dimension).to(args.device)
loss_function = nn.CrossEntropyLoss()
# Perhaps use ADAM, if SGD doesn't give good results
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# training function
def train(model, training_data, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network exposing ``hidden`` / ``init_hidden()`` for LSTM state.
        training_data: iterable of batches, each a dict with keys
            'num_images' (valid-scan counts), 'images' (padded MRI tensor),
            and 'label' (per-patient diagnoses).
        optimizer: optimizer stepping ``model.parameters()``.
        criterion: classification loss (CrossEntropyLoss).

    Returns:
        Mean batch loss (float) over the epoch.
    """
    model.train()  # activate training mode
    epoch_loss = 0.0
    epoch_length = len(training_data)
    for i, patient_data in enumerate(training_data):
        # Progress report roughly every fifth of the epoch.
        if i % (math.floor(epoch_length / 5) + 1) == 0:
            print(f"\t\tTraining Progress:{i / len(training_data) * 100}%")
        # Clear gradients
        model.zero_grad()
        torch.cuda.empty_cache()  # clear cuda memory
        batch_loss = torch.tensor(0.0).to(args.device, dtype=torch.half)
        # clear the LSTM hidden state after each batch
        model.hidden = model.init_hidden()
        # get the MRI's and classifications for the current batch
        patient_markers = patient_data['num_images']
        patient_MRIs = patient_data["images"].to(args.device, dtype=torch.half)
        patient_classifications = patient_data["label"]
        print("Patient batch classes ", patient_classifications)
        for x in range(len(patient_MRIs)):
            try:
                # clear hidden states to give each patient a clean slate
                model.hidden = model.init_hidden()
                # keep only this patient's real scans (drop padding), one
                # channel per scan
                single_patient_MRIs = patient_MRIs[x][:patient_markers[x]].view(
                    -1, 1, data_shape[0], data_shape[1], data_shape[2])
                patient_diagnosis = patient_classifications[x]
                # one target per scan, all equal to the patient's diagnosis
                patient_endstate = torch.ones(single_patient_MRIs.size(0)) * patient_diagnosis
                patient_endstate = patient_endstate.long().to(args.device)
                out = model(single_patient_MRIs)
                if len(out.shape) == 1:
                    out = out[None, ...]  # in the case of a single input, we need padding
                print("model predictions are ", out)
                print("patient endstate is ", patient_endstate)
                loss = criterion(out, patient_endstate)
                batch_loss += loss
            except Exception as e:
                # Deliberate best-effort: a malformed patient record is
                # skipped rather than aborting the whole epoch.
                print("EXCEPTION CAUGHT:", e)
        # BUG FIX: if every patient in the batch raised, batch_loss is still
        # the grad-less zero tensor and .backward() would itself raise.
        if batch_loss.requires_grad:
            batch_loss.backward()
            optimizer.step()
        print("batch loss is", batch_loss)
        # BUG FIX: accumulate a plain float; the old `epoch_loss += batch_loss`
        # kept every batch's autograd graph alive for the whole epoch.
        epoch_loss += float(batch_loss)
    if epoch_length == 0:
        epoch_length = 0.000001  # avoid ZeroDivisionError on an empty loader
    return epoch_loss / epoch_length
def test(model, test_data, criterion):
    """Evaluate the model on test_data and return the mean loss.

    Args:
        model: network exposing ``hidden`` / ``init_hidden()`` for LSTM state.
        test_data: iterable of batches in the same dict format as training
            data ('num_images', 'images', 'label').
        criterion: classification loss (CrossEntropyLoss).

    Returns:
        Mean loss (float) over the batches that evaluated successfully.
    """
    model.eval()
    epoch_loss = 0.0
    epoch_length = len(test_data)
    # BUG FIX: evaluation previously ran with autograd enabled, so every
    # test loss retained its computation graph for no reason. no_grad()
    # also makes the stray zero_grad() below harmless by construction.
    with torch.no_grad():
        for i, patient_data in enumerate(test_data):
            # Progress report roughly every fifth of the epoch.
            if i % (math.floor(epoch_length / 5) + 1) == 0:
                print(f"\t\tTesting Progress:{i / len(test_data) * 100}%")
            torch.cuda.empty_cache()  # clear cuda memory
            # clear the LSTM hidden state after each batch
            model.hidden = model.init_hidden()
            # get the MRI's and classifications for the current batch
            patient_markers = patient_data['num_images']
            patient_MRIs = patient_data["images"].to(args.device, dtype=torch.half)
            patient_classifications = patient_data["label"]
            print("Patient batch classes ", patient_classifications)
            for x in range(len(patient_MRIs)):
                try:
                    # clear hidden states to give each patient a clean slate
                    model.hidden = model.init_hidden()
                    single_patient_MRIs = patient_MRIs[x][:patient_markers[x]].view(
                        -1, 1, data_shape[0], data_shape[1], data_shape[2])
                    patient_diagnosis = patient_classifications[x]
                    # one target per scan, all equal to the patient's diagnosis
                    patient_endstate = torch.ones(single_patient_MRIs.size(0)) * patient_diagnosis
                    patient_endstate = patient_endstate.long().to(args.device)
                    out = model(single_patient_MRIs)
                    if len(out.shape) == 1:
                        out = out[None, ...]  # in the case of a single input, we need padding
                    loss = criterion(out, patient_endstate)
                    epoch_loss += float(loss)
                    print("Current test loss ", loss)
                except Exception as e:
                    # NOTE(review): this decrements the *batch* count once per
                    # failed patient, which can over-correct the denominator
                    # when a batch holds several patients — preserved as-is.
                    epoch_length -= 1
                    print("EXCEPTION CAUGHT:", e)
    if epoch_length == 0:
        epoch_length = 0.000001  # avoid ZeroDivisionError
    return epoch_loss / epoch_length
# perform training and measure test accuracy. Save best performing model.
# NOTE(review): despite the name, best_test_accuracy tracks the *lowest test
# loss* (lower is better, hence the float('inf') start), not an accuracy.
best_test_accuracy = float('inf')
# this evaluation workflow was adapted from Ben Trevett's design on https://github.com/bentrevett/pytorch-seq2seq/blob/master/1%20-%20Sequence%20to%20Sequence%20Learning%20with%20Neural%20Networks.ipynb
for epoch in range(training_epochs):
    start_time = time.time()
    # one full pass over the training set, then one over the held-out set
    train_loss = train(model, training_data, optimizer, loss_function)
    test_loss = test(model, test_data, loss_function)
    end_time = time.time()
    # wall-clock epoch duration, split into whole minutes and seconds
    epoch_mins = math.floor((end_time-start_time)/60)
    epoch_secs = math.floor((end_time-start_time)%60)
    print(f"Hurrah! Epoch {epoch + 1}/{training_epochs} concludes. | Time: {epoch_mins}m {epoch_secs}s")
    print(f"\tTrain Loss: {train_loss:.3f}| Train Perplexity: {math.exp(train_loss):7.3f}")
    print(f"\tTest Loss: {test_loss:.3f}| Test Perplexity: {math.exp(test_loss):7.3f}")
    # checkpoint whenever the test loss improves on the best seen so far
    if test_loss<best_test_accuracy:
        print("...that was our best test accuracy yet!")
        best_test_accuracy=test_loss
        torch.save(model.state_dict(),'ad-model.pt')