Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/*
dataset/*
File renamed without changes.
83 changes: 83 additions & 0 deletions lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random

#------------------------------------------------#
# Deep Learning Model #
#------------------------------------------------#
# we require having forward, fit, predict, and predict_proba methods to interface with the
# EMGClassifier class. Everything else is extra.
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class LSTM(nn.Module):
def __init__(self, n_output, n_features, hidden_layers=32):
super().__init__()
self.lstm = nn.LSTM(input_size=n_features, hidden_size=hidden_layers, num_layers=2, batch_first=True)
self.output_layer = nn.Linear(hidden_layers, n_output)
self.softmax = nn.Softmax(dim=1)

def forward(self, x):
x, _ = self.lstm(x)
x = self.output_layer(x)
return self.softmax(x)

def fit(self, dataloader_dictionary, learning_rate=1e-3, num_epochs=100, verbose=True):
# what device should we use (GPU if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# get the optimizer and loss function ready
optimizer = optim.Adam(self.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()
self.log = {"training_loss":[],
"validation_loss": [],
"training_accuracy": [],
"validation_accuracy": []}

for epoch in range(num_epochs):
#training set
self.train()
for data, labels in dataloader_dictionary["training_dataloader"]:
optimizer.zero_grad()
data = data.to(device)
labels = labels.to(device)
output = self.forward(data)
loss = loss_function(output, labels)
loss.backward()
optimizer.step()
acc = sum(torch.argmax(output,1) == labels)/labels.shape[0]
# log it
self.log["training_loss"] += [(epoch, loss.item())]
self.log["training_accuracy"] += [(epoch, acc)]
# validation set
self.eval()
for data, labels in dataloader_dictionary["validation_dataloader"]:
data = data.to(device)
labels = labels.to(device)
output = self.forward(data)
loss = loss_function(output, labels)
acc = sum(torch.argmax(output,1) == labels)/labels.shape[0]
# log it
self.log["validation_loss"] += [(epoch, loss.item())]
self.log["validation_accuracy"] += [(epoch, acc)]
if verbose:
epoch_trloss = np.mean([i[1] for i in self.log['training_loss'] if i[0]==epoch])
epoch_tracc = np.mean([i[1] for i in self.log['training_accuracy'] if i[0]==epoch])
epoch_valoss = np.mean([i[1] for i in self.log['validation_loss'] if i[0]==epoch])
epoch_vaacc = np.mean([i[1] for i in self.log['validation_accuracy'] if i[0]==epoch])
print(f"{epoch}: trloss:{epoch_trloss:.2f} tracc:{epoch_tracc:.2f} valoss:{epoch_valoss:.2f} vaacc:{epoch_vaacc:.2f}")
self.eval()

def predict(self, x):
if type(x) != torch.Tensor:
x = torch.tensor(x, dtype=torch.float32)
y = self.forward(x)
predictions = torch.argmax(y, dim=1)
return predictions.cpu().detach().numpy()

def predict_proba(self, x):
if type(x) != torch.Tensor:
x = torch.tensor(x, dtype=torch.float32)
y = self.forward(x)
return y.cpu().detach().numpy()
77 changes: 65 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
from libemg.datasets import OneSubjectMyoDataset
from libemg.data_handler import OfflineDataHandler
from libemg.filtering import Filter
from libemg.feature_extractor import FeatureExtractor
from libemg.emg_classifier import EMGClassifier
from libemg.offline_metrics import OfflineMetrics
from deeplearningspecificcode import fix_random_seed, make_data_loader, CNN
from cnn import fix_random_seed, make_data_loader, CNN
from lstm import LSTM

import numpy as np

def main():
# make our results repeatable
fix_random_seed(seed_value=0, use_cuda=True)
# download the dataset from the internet
dataset = OneSubjectMyoDataset(save_dir='dataset/',
redownload=False)
dataset = OneSubjectMyoDataset(save_dir='dataset/')
odh = dataset.prepare_data(format=OfflineDataHandler)

# split the dataset into a train, validation, and test set
Expand Down Expand Up @@ -44,44 +45,94 @@ def main():
valid_windows, valid_metadata = valid_data.parse_windows(window_size, window_increment)
test_windows, test_metadata = test_data.parse_windows( window_size, window_increment)


# ------------------------------------ #
# Setup for the CNN #
# -------------------------------------#
#

# we can even make a dictionary of parameters that get passed into the training
# process of the deep learning model
dl_dictionary = {"learning_rate": 1e-4,
"num_epochs": 20,
"verbose": True}

#--------------------------------------#
# Now we need to interface custom code #
#--------------------------------------#
# libemg supports deep learning, but we need to prepare the dataloaders
train_dataloader = make_data_loader(train_windows, train_metadata["classes"])
valid_dataloader = make_data_loader(valid_windows, valid_metadata["classes"])



# let's make the dictionary of dataloaders
dataloader_dictionary = {"training_dataloader": train_dataloader,
cnn_dataloader_dictionary = {"training_dataloader": train_dataloader,
"validation_dataloader": valid_dataloader}
# We need to tell the libEMG EMGClassifier that we are using a custom model
model = CNN(n_output = np.unique(np.vstack(odh.classes[:])).shape[0],
n_channels = train_windows.shape[1],
n_samples = train_windows.shape[2],
n_filters = 64)
# we can even make a dictionary of parameters that get passed into the training
# process of the deep learning model
dl_dictionary = {"learning_rate": 1e-4,
"num_epochs": 50,
"verbose": True}

#--------------------------------------#
# Back to library code #
#--------------------------------------#
# Now that we've made the custom classifier object, libEMG knows how to
# interpret it when passed in the dataloader_dictionary. Everything happens behind the scenes.
classifier = EMGClassifier()
classifier.fit(model, dataloader_dictionary=dataloader_dictionary, parameters=dl_dictionary)
classifier.fit(model, dataloader_dictionary=cnn_dataloader_dictionary, parameters=dl_dictionary)
# get the classifier's predictions on the test set
preds = classifier.run(test_windows)
om = OfflineMetrics()
metrics = ['CA','AER','INS','REJ_RATE','CONF_MAT','RECALL','PREC','F1']
results = om.extract_offline_metrics(metrics, test_metadata['classes'], preds[0], null_label=2)
print('\n------------------ CNN Results ---------------')
for key in results:
print(f"{key}: {results[key]}")
print('-------------------------------------------------\n')

# and conviniently, you can access everything from the training process here
# model.log -> has training loss, accuracy, validation loss, accuracy for every batch

# ------------------------------------ #
# Setup for the LSTM #
# -------------------------------------#
fe = FeatureExtractor()

train_features = EMGClassifier()._format_data(fe.extract_feature_group('HTD', train_windows))
val_features = EMGClassifier()._format_data(fe.extract_feature_group('HTD', valid_windows))
train_dataloader = make_data_loader(train_features, train_metadata["classes"])
valid_dataloader = make_data_loader(val_features, valid_metadata["classes"])
lstm_dataloader_dictionary = {"training_dataloader": train_dataloader,
"validation_dataloader": valid_dataloader}

# We need to tell the libEMG EMGClassifier that we are using a custom model
model = LSTM(n_output = np.unique(np.vstack(odh.classes[:])).shape[0],
n_features = train_features.shape[1],
hidden_layers = 128)


#--------------------------------------#
# Back to library code #
#--------------------------------------#
# Now that we've made the custom classifier object, libEMG knows how to
# interpret it when passed in the dataloader_dictionary. Everything happens behind the scenes.
classifier = EMGClassifier()
classifier.fit(model, dataloader_dictionary=lstm_dataloader_dictionary, parameters=dl_dictionary)
# get the classifier's predictions on the test set
preds = classifier.run(EMGClassifier()._format_data(fe.extract_feature_group('HTD',test_windows)))
om = OfflineMetrics()
metrics = ['CA','AER','INS','REJ_RATE','CONF_MAT','RECALL','PREC','F1']
results = om.extract_offline_metrics(metrics, test_metadata['classes'], preds[0], null_label=2)
print('\n------------------ LSTM Results ---------------')
for key in results:
print(f"{key}: {results[key]}")
print('-------------------------------------------------\n')

# and conveniently, you can access everything from the training process here
# model.log -> has training loss, accuracy, validation loss, accuracy for every batch



# We could also train a model with bells and whistles (rejection, velocity control, majority vote):
# We just need to pass the training windows and training labels to the fit function or velocity control
Expand All @@ -100,14 +151,16 @@ def main():
n_channels = train_windows.shape[1],
n_samples = train_windows.shape[2],
n_filters = 64)
classifier.fit(model, feature_dictionary=feature_dictionary, dataloader_dictionary=dataloader_dictionary, parameters=dl_dictionary)
classifier.fit(model, feature_dictionary=feature_dictionary, dataloader_dictionary=cnn_dataloader_dictionary, parameters=dl_dictionary)
# get the classifier's predictions on the test set
preds = classifier.run(test_windows)
om = OfflineMetrics()
metrics = ['CA','AER','INS','REJ_RATE','CONF_MAT','RECALL','PREC','F1']
results = om.extract_offline_metrics(metrics, test_metadata['classes'], preds[0], null_label=2)
print('\n------------------ CNN w/ Rejection Results ---------------')
for key in results:
print(f"{key}: {results[key]}")
print('-------------------------------------------------------------\n')



Expand Down