@@ -98,7 +98,7 @@ namespace
9898
9999__global__ void otsu_sums (uint *histogram, uint *threshold_sums, unsigned long long *sums)
100100{
101- const uint32_t n_bins = 256 ;
101+ const uint n_bins = 256 ;
102102
103103 __shared__ uint shared_memory_ts[n_bins];
104104 __shared__ unsigned long long shared_memory_s[n_bins];
@@ -109,7 +109,7 @@ __global__ void otsu_sums(uint *histogram, uint *threshold_sums, unsigned long l
109109 uint threshold_sum_above = 0 ;
110110 unsigned long long sum_above = 0 ;
111111
112- if (bin_idx >= threshold)
112+ if (bin_idx > threshold)
113113 {
114114 uint value = histogram[bin_idx];
115115 threshold_sum_above = value;
@@ -129,7 +129,7 @@ __global__ void otsu_sums(uint *histogram, uint *threshold_sums, unsigned long l
129129__global__ void
130130otsu_variance (float2 *variance, uint *histogram, uint *threshold_sums, unsigned long long *sums)
131131{
132- const uint32_t n_bins = 256 ;
132+ const uint n_bins = 256 ;
133133
134134 __shared__ signed long long shared_memory_a[n_bins];
135135 __shared__ signed long long shared_memory_b[n_bins];
@@ -147,7 +147,7 @@ otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned
147147
148148 float threshold_variance_above_f32 = 0 ;
149149 float threshold_variance_below_f32 = 0 ;
150- if (bin_idx >= threshold)
150+ if (bin_idx > threshold)
151151 {
152152 float mean = (float ) sum_above / n_samples_above;
153153 float sigma = bin_idx - mean;
@@ -172,11 +172,35 @@ otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned
172172 }
173173}
174174
175+ template <uint n_thresholds>
176+ __device__ bool has_lowest_score (
177+ uint threshold, float original_score, float score, uint *shared_memory
178+ ) {
179+ // Returns true when this thread's 'original_score' equals the block-wide minimum 'score' and, if several threads
180+ // tie on that minimum, this thread's 'threshold' equals the lowest threshold among the tied threads. Ties are rare,
181+ // so we optimize for the common case: a single '__syncthreads_count' tells every thread how many threads match, and
182+ // only when more than one matches do we run a second reduction to find the minimum threshold that satisfies the
183+ // condition. Every thread of the block must call this function, because it contains a block-wide barrier.
184+ bool has_match = original_score == score;
185+ uint matches = __syncthreads_count (has_match); // barrier; returns the number of threads whose has_match is set
186+
187+ if (matches > 1 ) { // 'matches' is the same in every thread, so the whole block takes the same branch
188+ // If this thread has a match, it contributes its own threshold; otherwise it contributes 'n_thresholds', which
189+ // is assumed to be larger than any valid threshold, so it will never get picked
190+ uint min_threshold = has_match ? threshold : n_thresholds;
191+
192+ blockReduce<n_thresholds>(shared_memory, min_threshold, threshold, minimum<uint>()); // NOTE(review): assumes this leaves the block-wide minimum in 'min_threshold' for every thread, and that 'shared_memory' holds at least n_thresholds uints -- confirm against blockReduce
193+
194+ return min_threshold == threshold;
195+ } else {
196+ return has_match;
197+ }
198+ }
175199
176200__global__ void
177201otsu_score (uint *otsu_threshold, uint *threshold_sums, float2 *variance)
178202{
179- const uint32_t n_thresholds = 256 ;
203+ const uint n_thresholds = 256 ;
180204
181205 __shared__ float shared_memory[n_thresholds];
182206
@@ -190,8 +214,8 @@ otsu_score(uint *otsu_threshold, uint *threshold_sums, float2 *variance)
190214 float threshold_mean_below = (float )n_samples_below / n_samples;
191215
192216 float2 variances = variance[threshold];
193- float variance_above = variances.x / n_samples_above;
194- float variance_below = variances.y / n_samples_below;
217+ float variance_above = n_samples_above > 0 ? variances.x / n_samples_above : 0 . 0f ;
218+ float variance_below = n_samples_below > 0 ? variances.y / n_samples_below : 0 . 0f ;
195219
196220 float above = threshold_mean_above * variance_above;
197221 float below = threshold_mean_below * variance_below;
@@ -209,11 +233,11 @@ otsu_score(uint *otsu_threshold, uint *threshold_sums, float2 *variance)
209233
210234 score = shared_memory[0 ];
211235
212- // We found the minimum score, but we need to find the threshold. If we find the thread with the minimum score, we
213- // know which threshold it is
214- if (original_score == score)
236+ // We found the minimum score, but in some cases multiple threads can have the same score, so we need to find the
237+ // lowest threshold
238+ if (has_lowest_score<n_thresholds>(threshold, original_score, score, (uint *) shared_memory) )
215239 {
216- *otsu_threshold = threshold - 1 ;
240+ *otsu_threshold = threshold;
217241 }
218242}
219243
0 commit comments