Skip to content

Commit 15241d4

Browse files
committed
wip: added sliced dataset to pytorch functionality
1 parent 4ece9b6 commit 15241d4

File tree

5 files changed

+44
-45
lines changed

5 files changed

+44
-45
lines changed

sdk/diffgram/core/diffgram_dataset_iterator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from PIL import Image, ImageDraw
22
from imageio import imread
3-
3+
import numpy as np
44

55
class DiffgramDatasetIterator:
66

@@ -42,7 +42,7 @@ def get_image_data(self, diffgram_file):
4242
raise Exception('Pytorch datasets only support images. Please provide only file_ids from images')
4343

4444
def get_file_instances(self, diffgram_file):
45-
if diffgram_file['type'] not in ['image', 'frame']:
45+
if diffgram_file.type not in ['image', 'frame']:
4646
raise NotImplementedError('File type "{}" is not supported yet'.format(diffgram_file['type']))
4747

4848
image = self.get_image_data(diffgram_file)

sdk/diffgram/core/directory.py

Lines changed: 19 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -96,17 +96,19 @@ def all_file_ids(self):
9696
page_num = 1
9797
result = []
9898
while page_num is not None:
99-
diffgram_files = self.list_files(limit = 1000, page_num = page_num, file_view_mode = 'ids_only')
99+
diffgram_ids = self.list_files(limit = 1000, page_num = page_num, file_view_mode = 'ids_only')
100100
page_num = self.file_list_metadata['next_page']
101-
result = result + diffgram_files
101+
result = result + diffgram_ids
102102
return result
103103

104104
def slice(self, query):
105105
from diffgram.core.sliced_directory import SlicedDirectory
106-
result = self.list_files(
106+
# Get the first page to validate syntax.
107+
self.list_files(
107108
limit = 25,
108109
page_num = 1,
109-
file_view_mode = 'ids_only'
110+
file_view_mode = 'ids_only',
111+
query = query,
110112
)
111113
sliced_dataset = SlicedDirectory(
112114
client = self.client,
@@ -120,7 +122,6 @@ def to_pytorch(self, transform = None):
120122
Transforms the file list inside the dataset into a pytorch dataset.
121123
:return:
122124
"""
123-
from diffgram.core.sliced_directory import SlicedDirectory
124125
file_id_list = self.all_file_ids()
125126
pytorch_dataset = DiffgramPytorchDataset(
126127
project = self.client,
@@ -211,7 +212,6 @@ def list_files(
211212
else:
212213
logging.info("Using Default Dataset ID " + str(self.client.directory_id))
213214
directory_id = self.client.directory_id
214-
#print("directory_id", directory_id)
215215

216216
metadata = {'metadata' :
217217
{
@@ -222,7 +222,8 @@ def list_files(
222222
'media_type': "All",
223223
'page': page_num,
224224
'file_view_mode': file_view_mode,
225-
'search_term': search_term
225+
'search_term': search_term,
226+
'query': query
226227
}
227228
}
228229

@@ -245,14 +246,17 @@ def list_files(
245246
self.file_list_metadata = data.get('metadata')
246247
# TODO would like this to perhaps be a separate function
247248
# ie part of File_Constructor perhaps
248-
file_list = []
249-
for file_json in file_list_json:
250-
file = File.new(
251-
client = self.client,
252-
file_json = file_json)
253-
file_list.append(file)
254-
255-
return file_list
249+
if file_view_mode == 'ids_only':
250+
return file_list_json
251+
else:
252+
file_list = []
253+
for file_json in file_list_json:
254+
file = File.new(
255+
client = self.client,
256+
file_json = file_json)
257+
file_list.append(file)
258+
259+
return file_list
256260

257261

258262
def get(self,

sdk/diffgram/core/sliced_directory.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,21 @@
11
from diffgram.core.directory import Directory
22
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
33

4+
45
class SlicedDirectory(Directory):
56

67
def __init__(self, client, original_directory: Directory, query: str):
78
self.original_directory = original_directory
89
self.query = query
910
self.client = client
11+
# Share the same ID from the original directory as this is just an in-memory construct for better semantics.
12+
self.id = original_directory.id
1013

1114
def all_file_ids(self):
1215
page_num = 1
1316
result = []
1417
while page_num is not None:
18+
print('slcied query', self.query)
1519
diffgram_files = self.list_files(limit = 1000,
1620
page_num = page_num,
1721
file_view_mode = 'ids_only',
@@ -20,7 +24,6 @@ def all_file_ids(self):
2024
result = result + diffgram_files
2125
return result
2226

23-
2427
def to_pytorch(self, transform = None):
2528
"""
2629
Transforms the file list inside the dataset into a pytorch dataset.
@@ -34,4 +37,3 @@ def to_pytorch(self, transform = None):
3437

3538
)
3639
return pytorch_dataset
37-

sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py

Lines changed: 4 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,5 @@
1-
import os
2-
3-
import numpy as np
4-
import scipy as sp
5-
1+
from torch.utils.data import Dataset, DataLoader
2+
import torch as torch # type: ignore
63
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
74

85

@@ -15,20 +12,12 @@ def __init__(self, project, diffgram_file_id_list = None, transform = None):
1512
:param diffgram_file_id_list (list): An arbitrary number of file IDs from Diffgram.
1613
:param transform (callable, optional): Optional transforms to be applied on a sample
1714
"""
18-
super(DiffgramDatasetIterator, self).__init__(project, diffgram_file_id_list)
19-
global torch, Dataset, DataLoader
20-
try:
21-
import torch as torch # type: ignore
22-
from torch.utils.data import Dataset, DataLoader
23-
except ModuleNotFoundError:
24-
raise ModuleNotFoundError(
25-
"'torch' module should be installed to convert the Dataset into pytorch format"
26-
)
15+
super(DiffgramPytorchDataset, self).__init__(project, diffgram_file_id_list)
16+
2717
self.diffgram_file_id_list = diffgram_file_id_list
2818

2919
self.project = project
3020
self.transform = transform
31-
self.__validate_file_ids()
3221

3322
def __len__(self):
3423
return len(self.diffgram_file_id_list)

sdk/diffgram/tensorflow_diffgram/pytorch_test.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -18,22 +18,26 @@
1818

1919

2020
# Draw
21-
import matplotlib.pyplot as plt
22-
from PIL import Image, ImageDraw
23-
img = Image.new("L", [diffgram_dataset[0]['diffgram_file'].image['width'], diffgram_dataset[0]['diffgram_file'].image['height']], 0)
24-
mask1 = diffgram_dataset[0]['polygon_mask_list'][0]
25-
mask2 = diffgram_dataset[0]['polygon_mask_list'][1]
26-
plt.figure()
27-
plt.subplot(1,2,1)
28-
# plt.imshow(img, 'gray', interpolation='none')
29-
plt.imshow(mask1, 'jet', interpolation='none', alpha=0.7)
30-
plt.imshow(mask2, 'Oranges', interpolation='none', alpha=0.7)
31-
plt.show()
21+
def display_masks():
22+
import matplotlib.pyplot as plt
23+
from PIL import Image, ImageDraw
24+
img = Image.new("L", [diffgram_dataset[0]['diffgram_file'].image['width'],
25+
diffgram_dataset[0]['diffgram_file'].image['height']], 0)
26+
mask1 = diffgram_dataset[0]['polygon_mask_list'][0]
27+
mask2 = diffgram_dataset[0]['polygon_mask_list'][1]
28+
plt.figure()
29+
plt.subplot(1, 2, 1)
30+
# plt.imshow(img, 'gray', interpolation='none')
31+
plt.imshow(mask1, 'jet', interpolation = 'none', alpha = 0.7)
32+
plt.imshow(mask2, 'Oranges', interpolation = 'none', alpha = 0.7)
33+
plt.show()
3234

3335

3436
# Dataset Example
3537

3638
dataset = project.directory.get('Default')
3739

40+
pytorch_dataset = dataset.to_pytorch()
41+
3842
sliced_dataset = dataset.slice(query = 'labels.sheep > 0 or labels.sofa > 0')
3943

0 commit comments

Comments
 (0)