Skip to content

Commit 630edea

Browse files
committed
add keras.dataset.cifar10.
1 parent 4359d35 commit 630edea

File tree

9 files changed

+415
-15
lines changed

9 files changed

+415
-15
lines changed

README.md

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
![logo](docs/assets/tf.net.logo.png)
22

3-
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
3+
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. TensorFlow.NET has built-in Keras high-level interface and is released as an independent package [TensorFlow.Keras](https://www.nuget.org/packages/TensorFlow.Keras/).
44

55
[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)
66
[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net)
@@ -14,28 +14,31 @@
1414

1515
![tensors_flowing](docs/assets/tensors_flowing.gif)
1616

17-
### Why TensorFlow.NET ?
17+
### Why TensorFlow in .NET/ C# ?
1818

1919
`SciSharp STACK`'s mission is to bring popular data science technology into the .NET world and to provide .NET developers with a powerful Machine Learning tool set without reinventing the wheel. Since the APIs are kept as similar as possible you can immediately adapt any existing Tensorflow code in C# with a zero learning curve. Take a look at a comparison picture and see how comfortably a Tensorflow/Python script translates into a C# program with TensorFlow.NET.
2020

2121
![pythn vs csharp](docs/assets/syntax-comparision.png)
2222

2323
SciSharp's philosophy allows a large number of machine learning code written in Python to be quickly migrated to .NET, enabling .NET developers to use cutting edge machine learning models and access a vast number of Tensorflow resources which would not be possible without this project.
2424

25-
In comparison to other projects, like for instance TensorFlowSharp which only provide Tensorflow's low-level C++ API and can only run models that were built using Python, Tensorflow.NET also implements Tensorflow's high level API where all the magic happens. This computation graph building layer is still under active development. Once it is completely implemented you can build new Machine Learning models in C#.
25+
In comparison to other projects, like for instance [TensorFlowSharp](https://www.nuget.org/packages/TensorFlowSharp/) which only provide Tensorflow's low-level C++ API and can only run models that were built using Python, Tensorflow.NET also implements Tensorflow's high level API where all the magic happens. This computation graph building layer is still under active development. Once it is completely implemented you can build new Machine Learning models in C#.
2626

2727
### How to use
2828

29-
| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
30-
| ----------- | ------------- | -------------- | ------------- |
31-
| tf.net 0.20 | | x | x |
32-
| tf.net 0.15 | x | x | |
33-
| tf.net 0.14 | x | | |
29+
| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
30+
| ------------------------- | ------------- | -------------- | ------------- |
31+
| tf.net 0.30, tf.keras 0.1 | | | x |
32+
| tf.net 0.20 | | x | x |
33+
| tf.net 0.15 | x | x | |
34+
| tf.net 0.14 | x | | |
3435

3536
Install TF.NET and TensorFlow binary through NuGet.
3637
```sh
3738
### install tensorflow C# binding
3839
PM> Install-Package TensorFlow.NET
40+
### install keras for tensorflow
41+
PM> Install-Package TensorFlow.Keras
3942

4043
### Install tensorflow binary
4144
### For CPU version
@@ -45,13 +48,14 @@ PM> Install-Package SciSharp.TensorFlow.Redist
4548
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
4649
```
4750

48-
Import TF.NET in your project.
51+
Import TF.NET and Keras API in your project.
4952

5053
```cs
5154
using static Tensorflow.Binding;
55+
using static Tensorflow.KerasApi;
5256
```
5357

54-
Linear Regression:
58+
Linear Regression in `Eager` mode:
5559

5660
```c#
5761
// Parameters
@@ -92,6 +96,52 @@ foreach (var step in range(1, training_steps + 1))
9296

9397
Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube).
9498

99+
Toy version of `ResNet` in `Keras` functional API:
100+
101+
```csharp
102+
// input layer
103+
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
104+
105+
// convolutional layer
106+
var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs);
107+
x = layers.Conv2D(64, 3, activation: "relu").Apply(x);
108+
var block_1_output = layers.MaxPooling2D(3).Apply(x);
109+
110+
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output);
111+
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
112+
var block_2_output = layers.add(x, block_1_output);
113+
114+
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output);
115+
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
116+
var block_3_output = layers.add(x, block_2_output);
117+
118+
x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output);
119+
x = layers.GlobalAveragePooling2D().Apply(x);
120+
x = layers.Dense(256, activation: "relu").Apply(x);
121+
x = layers.Dropout(0.5f).Apply(x);
122+
123+
// output layer
124+
var outputs = layers.Dense(10).Apply(x);
125+
126+
// build keras model
127+
model = keras.Model(inputs, outputs, name: "toy_resnet");
128+
model.summary();
129+
130+
// compile keras model in tensorflow static graph
131+
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
132+
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
133+
metrics: new[] { "acc" });
134+
135+
// prepare dataset
136+
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
137+
138+
// training
139+
model.fit(x_train[new Slice(0, 1000)], y_train[new Slice(0, 1000)],
140+
batch_size: 64,
141+
epochs: 10,
142+
validation_split: 0.2f);
143+
```
144+
95145
Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html).
96146

97147
There are many examples reside at [TensorFlow.NET Examples](https://github.com/SciSharp/TensorFlow.NET-Examples).

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,5 +167,21 @@ public NDArray eval_in_eager_or_function(Tensor outputs)
167167

168168
public class _DummyEagerGraph
169169
{ }
170+
171+
/// <summary>
172+
/// Categorical crossentropy between an output tensor and a target tensor.
173+
/// </summary>
174+
/// <param name="target"></param>
175+
/// <param name="output"></param>
176+
/// <param name="from_logits"></param>
177+
/// <param name="axis"></param>
178+
/// <returns></returns>
179+
public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
180+
{
181+
if (from_logits)
182+
return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis);
183+
184+
throw new NotImplementedException("");
185+
}
170186
}
171187
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Text;
6+
using static Tensorflow.Binding;
7+
using Tensorflow.Keras.Utils;
8+
9+
namespace Tensorflow.Keras.Datasets
10+
{
11+
public class Cifar10
12+
{
13+
string origin_folder = "https://www.cs.toronto.edu/~kriz/";
14+
string file_name = "cifar-10-python.tar.gz";
15+
string dest_folder = "cifar-10-batches";
16+
17+
/// <summary>
18+
/// Loads [CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
19+
/// </summary>
20+
/// <returns></returns>
21+
public DatasetPass load_data()
22+
{
23+
var dst = Download();
24+
25+
var data_list = new List<Tensor>();
26+
var label_list = new List<Tensor>();
27+
28+
foreach (var i in range(1, 6))
29+
{
30+
var fpath = Path.Combine(dst, $"data_batch_{i}");
31+
var (data, labels) = load_batch(fpath);
32+
data_list.Add(data);
33+
label_list.Add(labels);
34+
}
35+
36+
var x_train_tensor = tf.concat(data_list, 0);
37+
var y_train_tensor = tf.concat(label_list, 0);
38+
var y_train = np.array(y_train_tensor.BufferToArray()).reshape(y_train_tensor.shape);
39+
40+
// test data
41+
var fpath_test = Path.Combine(dst, "test_batch");
42+
var (x_test, y_test) = load_batch(fpath_test);
43+
44+
// channels_last
45+
x_train_tensor = tf.transpose(x_train_tensor, new[] { 0, 2, 3, 1 });
46+
var x_train = np.array(x_train_tensor.BufferToArray()).reshape(x_train_tensor.shape);
47+
48+
var x_test_tensor = tf.transpose(x_test, new[] { 0, 2, 3, 1 });
49+
x_test = np.array(x_test_tensor.BufferToArray()).reshape(x_test_tensor.shape);
50+
51+
return new DatasetPass
52+
{
53+
Train = (x_train, y_train),
54+
Test = (x_test, y_test)
55+
};
56+
}
57+
58+
(NDArray, NDArray) load_batch(string fpath, string label_key = "labels")
59+
{
60+
var pickle = File.ReadAllBytes(fpath);
61+
// read description
62+
var start_pos = 7;
63+
var desc = read_description(ref start_pos, pickle);
64+
var labels = read_labels(ref start_pos, pickle);
65+
var data = read_data(ref start_pos, pickle);
66+
67+
return (data.Item2, labels.Item2);
68+
}
69+
70+
(string, string) read_description(ref int start_pos, byte[] pickle)
71+
{
72+
var key_length = pickle[start_pos];
73+
start_pos++;
74+
var span = new Span<byte>(pickle, start_pos, key_length);
75+
var key = Encoding.ASCII.GetString(span.ToArray());
76+
start_pos += key_length + 3;
77+
78+
var value_length = pickle[start_pos];
79+
start_pos++;
80+
var value = Encoding.ASCII.GetString(new Span<byte>(pickle, start_pos, value_length).ToArray());
81+
start_pos += value_length;
82+
start_pos += 3;
83+
84+
return (key, value);
85+
}
86+
87+
(string, NDArray) read_labels(ref int start_pos, byte[] pickle)
88+
{
89+
byte[] value = new byte[10000];
90+
91+
var key_length = pickle[start_pos];
92+
start_pos++;
93+
var span = new Span<byte>(pickle, start_pos, key_length);
94+
var key = Encoding.ASCII.GetString(span.ToArray());
95+
start_pos += key_length + 6;
96+
97+
var value_length = 10000;
98+
for (int i = 0; i < value_length; i++)
99+
{
100+
if (i > 0 && i % 1000 == 0)
101+
start_pos += 2;
102+
value[i] = pickle[start_pos + 1];
103+
start_pos += 2;
104+
}
105+
start_pos += 2;
106+
107+
return (key, np.array(value));
108+
}
109+
110+
(string, NDArray) read_data(ref int start_pos, byte[] pickle)
111+
{
112+
var key_length = pickle[start_pos];
113+
start_pos++;
114+
var span = new Span<byte>(pickle, start_pos, key_length);
115+
var key = Encoding.ASCII.GetString(span.ToArray());
116+
start_pos += key_length + 133;
117+
var value_length = 3072 * 10000;
118+
var value = new Span<byte>(pickle, start_pos, value_length).ToArray();
119+
start_pos += value_length;
120+
121+
return (key, np.array(value).reshape(10000, 3, 32, 32));
122+
}
123+
124+
string Download()
125+
{
126+
var dst = Path.Combine(Path.GetTempPath(), dest_folder);
127+
Directory.CreateDirectory(dst);
128+
129+
Web.Download(origin_folder + file_name, dst, file_name);
130+
Compress.ExtractTGZ(Path.Combine(Path.GetTempPath(), file_name), dst);
131+
132+
return Path.Combine(dst, "cifar-10-batches-py");
133+
}
134+
}
135+
}

src/TensorFlowNET.Keras/Datasets/KerasDataset.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ namespace Tensorflow.Keras.Datasets
1919
public class KerasDataset
2020
{
2121
public Mnist mnist { get; } = new Mnist();
22+
public Cifar10 cifar10 { get; } = new Cifar10();
2223
}
2324
}

src/TensorFlowNET.Keras/Datasets/MNIST.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717
using NumSharp;
1818
using System;
1919
using System.IO;
20-
using System.Net;
20+
using Tensorflow.Keras.Utils;
2121

2222
namespace Tensorflow.Keras.Datasets
2323
{
@@ -65,8 +65,7 @@ string Download()
6565
return fileSaveTo;
6666
}
6767

68-
using var wc = new WebClient();
69-
wc.DownloadFileTaskAsync(origin_folder + file_name, fileSaveTo).Wait();
68+
Web.Download(origin_folder + file_name, Path.GetTempPath(), file_name);
7069

7170
return fileSaveTo;
7271
}

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66
<LangVersion>8.0</LangVersion>
77
<RootNamespace>Tensorflow.Keras</RootNamespace>
88
<Platforms>AnyCPU;x64</Platforms>
9-
<Version>0.1.0</Version>
9+
<Version>0.2.0</Version>
1010
<Authors>Haiping Chen</Authors>
1111
<Product>Keras for .NET</Product>
1212
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
1313
<PackageId>TensorFlow.Keras</PackageId>
1414
<PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl>
1515
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
1616
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
17-
<PackageReleaseNotes>Keras for .NET is a C# version of Keras ported from the python version.</PackageReleaseNotes>
17+
<PackageReleaseNotes>Keras for .NET is a C# version of Keras ported from the python version.
18+
19+
* Support CIFAR-10 dataset in keras.datasets.
20+
* Support Conv2D functional API.</PackageReleaseNotes>
1821
<Description>Keras for .NET
1922

2023
Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent &amp; simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear &amp; actionable error messages.</Description>
@@ -27,11 +30,17 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
2730

2831
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
2932
<DefineConstants>DEBUG;TRACE</DefineConstants>
33+
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
34+
</PropertyGroup>
35+
36+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
37+
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
3038
</PropertyGroup>
3139

3240
<ItemGroup>
3341
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
3442
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
43+
<PackageReference Include="SharpZipLib" Version="1.3.1" />
3544
</ItemGroup>
3645

3746
<ItemGroup>

0 commit comments

Comments
 (0)