Binary file added __pycache__/test.cpython-311.pyc
Binary file added __pycache__/utils.cpython-311.pyc
Binary file added data/__pycache__/__init__.cpython-311.pyc
Binary file added data/__pycache__/dataset.cpython-311.pyc
Binary file added data/__pycache__/musdb.cpython-311.pyc
Binary file added data/__pycache__/utils.cpython-311.pyc
Binary file added model/__pycache__/__init__.cpython-311.pyc
Binary file added model/__pycache__/conv.cpython-311.pyc
Binary file added model/__pycache__/crop.cpython-311.pyc
Binary file added model/__pycache__/resample.cpython-311.pyc
Binary file added model/__pycache__/utils.cpython-311.pyc
Binary file added model/__pycache__/waveunet.cpython-311.pyc
model/utils.py (15 changes: 10 additions & 5 deletions)

@@ -13,13 +13,13 @@ def save_model(model, optimizer, state, path):
     }, path)
 
 
-def load_model(model, optimizer, path, cuda):
+def load_model(model, optimizer, path, device):
+
     if isinstance(model, torch.nn.DataParallel):
         model = model.module # load state dict of wrapped module
-    if cuda:
-        checkpoint = torch.load(path)
-    else:
-        checkpoint = torch.load(path, map_location='cpu')
+
+    checkpoint = torch.load(path, map_location=device)
+
     try:
         model.load_state_dict(checkpoint['model_state_dict'])
     except:
@@ -32,8 +32,13 @@ def load_model(model, optimizer, path, cuda):
                 k = k[len(prefix):]
             model_state_dict_fixed[k] = v
         model.load_state_dict(model_state_dict_fixed)
+
+    # Ensure the model is moved to the correct device
+    model.to(device)
+
     if optimizer is not None:
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
     if 'state' in checkpoint:
         state = checkpoint['state']
     else:
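For readers unfamiliar with the idiom: `map_location=device` makes `torch.load` remap every stored tensor onto the target device while deserializing, which is why the old cuda/cpu branch collapses into a single call. A minimal round-trip sketch, using a stand-in `nn.Linear` instead of the repository's Waveunet and mirroring the checkpoint keys that `load_model` reads:

```python
import torch
import torch.nn as nn

device = torch.device('cpu')  # or torch.device('cuda') / torch.device('mps')

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

# Write a checkpoint with the same keys that load_model() expects.
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'state': {'step': 0}}, 'checkpoint.pt')

# map_location remaps stored tensors onto `device` during deserialization;
# the explicit .to(device) afterwards covers the module's own parameters.
checkpoint = torch.load('checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(checkpoint['state'])  # {'step': 0}
```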
model/waveunet.py (8 changes: 5 additions & 3 deletions)

@@ -100,9 +100,10 @@ def get_input_size(self, output_size):
         return curr_size
 
 class Waveunet(nn.Module):
-    def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
+    def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2, device=None):
         super(Waveunet, self).__init__()
 
+        self.device = device if device is not None else torch.device('cpu')
         self.num_levels = len(num_channels)
         self.strides = strides
         self.kernel_size = kernel_size
@@ -111,7 +112,7 @@ def __init__(...)
         self.depth = depth
         self.instruments = instruments
         self.separate = separate
-
+        self.to(self.device)
         # Only odd filter kernels allowed
         assert(kernel_size % 2 == 1)
 
@@ -195,9 +196,10 @@ def forward_module(self, x, module):
         :param module: Network module to be used for prediction
         :return: Source estimates
         '''
+        x = x.to(self.device)
         shortcuts = []
         out = x
-
+        #print(x.shape)
         # DOWNSAMPLING BLOCKS
         for block in module.downsampling_blocks:
             out, short = block(out)
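The device pattern introduced here (keep the target device on the module, move parameters in `__init__`, and normalize incoming tensors in `forward`) can be shown in isolation. A sketch with a hypothetical `ToyNet` that is not part of the repository:

```python
import torch
import torch.nn as nn

class ToyNet(nn.Module):
    def __init__(self, device=None):
        super().__init__()
        self.device = device if device is not None else torch.device('cpu')
        self.proj = nn.Linear(8, 8)
        self.to(self.device)  # moves only parameters that exist at this point

    def forward(self, x):
        x = x.to(self.device)  # tolerate inputs arriving on another device
        return self.proj(x)

net = ToyNet(torch.device('cpu'))
out = net(torch.randn(2, 8))
```

One caveat: `self.to(self.device)` only affects parameters already constructed when it runs. In this diff it executes before the up- and downsampling blocks are built, so the later `.to(device)` call in train.py is what actually places the full network on the target device.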
train.py (28 changes: 15 additions & 13 deletions)

@@ -22,19 +22,21 @@
 
 def main(args):
     #torch.backends.cudnn.benchmark=True # This makes dilated conv much faster for CuDNN 7.5
-
+    device = "cuda" if args.cuda else "cpu"
+    if args.mps:
+        device = torch.device('mps')
     # MODEL
     num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
                    [args.features*2**i for i in range(0, args.levels)]
     target_outputs = int(args.output_size * args.sr)
     model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size,
-                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
-                     conv_type=args.conv_type, res=args.res, separate=args.separate)
+                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
+                     conv_type=args.conv_type, res=args.res, separate=args.separate, device=device).to(device)
 
-    if args.cuda:
+    if args.cuda or args.mps:
         model = model_utils.DataParallel(model)
-        print("move model to gpu")
-        model.cuda()
+        print(f"move model to {device}")
+        model.to(device)
 
     print('model: ', model)
     print('parameter count: ', str(sum(p.numel() for p in model.parameters())))
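The flag handling above takes `--cuda` and `--mps` at face value. A slightly more defensive variant, sketched here with a hypothetical `resolve_device` helper that is not part of this PR, would fall back to CPU when the requested backend is unavailable:

```python
import torch

def resolve_device(cuda: bool, mps: bool) -> torch.device:
    """Pick a device from CLI flags, falling back to CPU if the backend is absent."""
    if cuda and torch.cuda.is_available():
        return torch.device('cuda')
    if mps and torch.backends.mps.is_available():  # MPS backend requires PyTorch >= 1.12
        return torch.device('mps')
    return torch.device('cpu')
```

As written, the PR leaves `device` as the plain string "cuda" or "cpu" unless `--mps` is set; that is harmless in practice because `.to()` and `map_location` accept strings and `torch.device` objects alike.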
@@ -75,7 +77,7 @@ def main(args):
     # LOAD MODEL CHECKPOINT IF DESIRED
     if args.load_model is not None:
         print("Continuing training full model from checkpoint " + str(args.load_model))
-        state = model_utils.load_model(model, optimizer, args.load_model, args.cuda)
+        state = model_utils.load_model(model, optimizer, args.load_model, device)
 
     print('TRAINING START')
     while state["worse_epochs"] < args.patience:
@@ -85,10 +87,9 @@ def main(args):
         with tqdm(total=len(train_data) // args.batch_size) as pbar:
             np.random.seed()
             for example_num, (x, targets) in enumerate(dataloader):
-                if args.cuda:
-                    x = x.cuda()
-                    for k in list(targets.keys()):
-                        targets[k] = targets[k].cuda()
+                x = x.to(device)
+                for k in list(targets.keys()):
+                    targets[k] = targets[k].to(device)
 
                 t = time.time()
 
@@ -139,13 +140,12 @@ def main(args):
             print("Saving model...")
             model_utils.save_model(model, optimizer, state, checkpoint_path)
 
-
     #### TESTING ####
     # Test loss
     print("TESTING")
 
     # Load best model based on validation loss
-    state = model_utils.load_model(model, None, state["best_checkpoint"], args.cuda)
+    state = model_utils.load_model(model, None, state["best_checkpoint"], device)
     test_loss = validate(args, model, criterion, test_data)
     print("TEST FINISHED: LOSS: " + str(test_loss))
     writer.add_scalar("test_loss", test_loss, state["step"])
@@ -176,6 +176,8 @@ def main(args):
                         help="List of instruments to separate (default: \"bass drums other vocals\")")
     parser.add_argument('--cuda', action='store_true',
                         help='Use CUDA (default: False)')
+    parser.add_argument('--mps', action='store_true',
+                        help='Use MPS on M1 Mac (default: False)')
     parser.add_argument('--num_workers', type=int, default=1,
                         help='Number of data loader worker threads (default: 1)')
     parser.add_argument('--features', type=int, default=32,
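The training-loop change (moving `x` and every target unconditionally instead of only under `args.cuda`) generalizes to any batch layout. A small helper capturing the same idea, with a hypothetical `move_batch` name chosen for illustration:

```python
import torch

def move_batch(x, targets, device):
    # .to(device) is effectively a no-op when a tensor already lives on the
    # target device, so this is safe for CPU, CUDA and MPS alike.
    x = x.to(device)
    targets = {k: v.to(device) for k, v in targets.items()}
    return x, targets

x, targets = move_batch(torch.randn(1, 2, 16),
                        {'vocals': torch.randn(1, 2, 16)},
                        torch.device('cpu'))
```

With the new flag in place, training on Apple silicon would presumably be launched as `python train.py --mps` alongside the usual dataset arguments.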