Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 171 additions & 7 deletions Common/MathUtils/include/MathUtils/fit.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <algorithm>
#include <vector>
#include <array>
#include <thread>

#include "Rtypes.h"
#include "TLinearFitter.h"
Expand Down Expand Up @@ -69,9 +70,9 @@ TFitResultPtr fit(const size_t nBins, const T* arr, const T xMin, const T xMax,
// create an empty TFitResult
std::shared_ptr<TFitResult> tfr(new TFitResult());
// create the fitter from an empty fit result
//std::shared_ptr<ROOT::Fit::Fitter> fitter(new ROOT::Fit::Fitter(std::static_pointer_cast<ROOT::Fit::FitResult>(tfr) ) );
// std::shared_ptr<ROOT::Fit::Fitter> fitter(new ROOT::Fit::Fitter(std::static_pointer_cast<ROOT::Fit::FitResult>(tfr) ) );
ROOT::Fit::Fitter fitter(tfr);
//ROOT::Fit::FitConfig & fitConfig = fitter->Config();
// ROOT::Fit::FitConfig & fitConfig = fitter->Config();

const double binWidth = double(xMax - xMin) / double(nBins);

Expand Down Expand Up @@ -225,8 +226,8 @@ bool medmadGaus(size_t nBins, const T* arr, const T xMin, const T xMax, std::arr
/// -1: only one point has been used for the calculation - center of gravity was used for calculation
/// -4: invalid result!!
///
//template <typename T>
//Double_t fitGaus(const size_t nBins, const T *arr, const T xMin, const T xMax, std::vector<T>& param);
// template <typename T>
// Double_t fitGaus(const size_t nBins, const T *arr, const T xMin, const T xMax, std::vector<T>& param);
template <typename T>
Double_t fitGaus(const size_t nBins, const T* arr, const T xMin, const T xMax, std::vector<T>& param)
{
Expand Down Expand Up @@ -301,7 +302,7 @@ Double_t fitGaus(const size_t nBins, const T* arr, const T xMin, const T xMax, s
Double_t chi2 = 0;
if (npoints >= 3) {
if (npoints == 3) {
//analytic calculation of the parameters for three points
// analytic calculation of the parameters for three points
A.Invert();
TMatrixD res(1, 3);
res.Mult(A, b);
Expand Down Expand Up @@ -334,7 +335,7 @@ Double_t fitGaus(const size_t nBins, const T* arr, const T xMin, const T xMax, s
}

if (npoints == 2) {
//use center of gravity for 2 points
// use center of gravity for 2 points
meanCOG /= sumCOG;
rms2COG /= sumCOG;
param[0] = max;
Expand Down Expand Up @@ -524,7 +525,7 @@ R median(std::vector<T> v)
auto n = v.size() / 2;
nth_element(v.begin(), v.begin() + n, v.end());
auto med = R{v[n]};
if (!(v.size() & 1)) { //If the set size is even
if (!(v.size() & 1)) { // If the set size is even
auto max_it = max_element(v.begin(), v.begin() + n);
med = R{(*max_it + med) / 2.0};
}
Expand Down Expand Up @@ -788,6 +789,169 @@ T MAD2Sigma(int np, T* y)
return median * 1.4826; // convert to Gaussian sigma
}

/// \brief finds the indices of the entries which bracket a given timestamp
/// \return pair {left, right} of indices into timestamps:
///         {0, 0} if timestamp is smaller than or equal to the first entry,
///         {size - 1, size - 1} if timestamp is larger than or equal to the last entry,
///         {idx - 1, idx} otherwise, where idx is the first entry which is not smaller than timestamp.
///         std::nullopt is returned (and a warning is logged) if timestamps is empty
/// \param timestamps vector of timestamps (has to be sorted in ascending order for the binary search to be valid)
/// \param timestamp the timestamp to find the closest timestamps for
template <typename DataTimeType, typename DataTime>
std::optional<std::pair<size_t, size_t>> findClosestIndices(const std::vector<DataTimeType>& timestamps, DataTime timestamp)
{
  if (timestamps.empty()) {
    LOGP(warning, "Timestamp vector is empty!");
    return std::nullopt;
  }

  // timestamps outside of the covered range are clamped to the first/last entry
  if (timestamp <= timestamps.front()) {
    return std::pair<size_t, size_t>{0, 0};
  }
  const size_t last = timestamps.size() - 1;
  if (timestamp >= timestamps.back()) {
    return std::pair<size_t, size_t>{last, last};
  }

  // at this point front() < timestamp < back(): for sorted input the first entry
  // which is not smaller than timestamp has an index in [1, size - 1], so idx - 1 is valid
  const auto it = std::lower_bound(timestamps.begin(), timestamps.end(), timestamp);
  const size_t idx = std::distance(timestamps.begin(), it);
  return std::pair<size_t, size_t>{idx - 1, idx};
}

/// Output container of a rolling-statistics calculation.
/// The sizing constructor gives all member vectors the same length, one entry per evaluated point.
struct RollingStats {
  RollingStats() = default;

  /// allocates nValues value-initialized (zero) entries in every member vector
  RollingStats(const int nValues) : median(nValues), std(nValues), nPoints(nValues), closestDistanceL(nValues), closestDistanceR(nValues)
  {
  }

  std::vector<float> median;           ///< rolling median of the data
  std::vector<float> std;              ///< rolling standard deviation of the data
  std::vector<int> nPoints;            ///< number of data points which entered the calculation
  std::vector<float> closestDistanceL; ///< distance to the closest data point on the left side
  std::vector<float> closestDistanceR; ///< distance to the closest data point on the right side

  ClassDefNV(RollingStats, 1);
};

/// \brief calculates the rolling statistics of the input data
///
/// For every entry t of times the data points with a time in the half-open window [t - deltaMax, t + deltaMax)
/// are selected. If at least minPoints points are found, their median, standard deviation and number are stored.
/// Otherwise the median is replaced by the inverse-distance weighted average of the points in the index range
/// [idxRight - nClosestPoints, idxRight + nClosestPoints) (clipped to the data), where idxRight is the right
/// index returned by findClosestIndices; nPoints and std are left untouched for such entries.
/// The distances to the closest data point on the left and on the right are stored for every entry.
///
/// \return returns the rolling statistics (one entry per entry of times). If the input is not sorted, is empty
///         or timeData and data differ in size an error is logged and the untouched (zero-initialized) output is returned
/// \param timeData times of the input data (has to be sorted)
/// \param data values of the input data (vector-like type: reserve(), capacity(), insert() and clear() are used)
/// \param times times for which to calculate the rolling statistics
/// \param deltaMax time range for which the rolling statistics is calculated
/// \param mNthreads number of threads to use for the calculation (values below 1 are treated as 1, at most one thread per entry of times is started)
/// \param minPoints minimum number of points to use for the calculation of the statistics - otherwise use nearest nClosestPoints points weighted with distance
/// \param nClosestPoints number of closest points in case of number of points in given range is smaller than minPoints
template <typename DataTimeType, typename DataType, typename DataTime>
RollingStats getRollingStatistics(const DataTimeType& timeData, const DataType& data, const DataTime& times, const double deltaMax, const int mNthreads, const size_t minPoints = 4, const size_t nClosestPoints = 4)
{
  // output statistics
  const size_t vecSize = times.size();
  RollingStats stats(vecSize);

  if (!std::is_sorted(timeData.begin(), timeData.end())) {
    LOGP(error, "Input data is NOT sorted!");
    return stats;
  }

  if (timeData.empty()) {
    LOGP(error, "Input data is empty!");
    return stats;
  }

  const size_t dataSize = data.size();
  const size_t timeDataSize = timeData.size();
  if (timeDataSize != dataSize) {
    LOGP(error, "Input data has different sizes {}!={}", timeDataSize, dataSize);
    return stats;
  }

  // nothing to calculate
  if (vecSize == 0) {
    return stats;
  }

  // a non-positive thread count would either skip the calculation silently (0) or throw when
  // allocating the thread vector (<0): always use at least one worker and never more workers than points
  const size_t nWorkers = std::min(vecSize, static_cast<size_t>(std::max(mNthreads, 1)));

  // each worker processes the points iThread, iThread + nWorkers, ... and therefore
  // writes to distinct elements of the output vectors
  auto myThread = [&](const size_t iThread) {
    // data in given time window for median calculation (reused for all points of this worker)
    DataType window;
    for (size_t i = iThread; i < vecSize; i += nWorkers) {
      const double timeI = times[i];

      // lower index: first data point with time >= timeI - deltaMax
      const double timeStampLower = timeI - deltaMax;
      const auto lower = std::lower_bound(timeData.begin(), timeData.end(), timeStampLower);
      size_t idxStart = std::distance(timeData.begin(), lower);

      // upper index: first data point with time >= timeI + deltaMax (exclusive end of the window)
      const double timeStampUpper = timeI + deltaMax;
      const auto upper = std::lower_bound(timeData.begin(), timeData.end(), timeStampUpper);
      size_t idxEnd = std::distance(timeData.begin(), upper);

      // closest data point
      if (auto idxClosest = findClosestIndices(timeData, timeI)) {
        auto [idxLeft, idxRight] = *idxClosest;
        const auto closestL = std::abs(timeData[idxLeft] - timeI);
        const auto closestR = std::abs(timeData[idxRight] - timeI);
        stats.closestDistanceL[i] = closestL;
        stats.closestDistanceR[i] = closestR;

        // if not enough points are in the range use the n closest points - n from the left and n from the right
        const size_t reqSize = idxEnd - idxStart;
        if (reqSize < minPoints) {
          // calculate weighted average
          idxStart = (idxRight > nClosestPoints) ? (idxRight - nClosestPoints) : 0;
          idxEnd = std::min(data.size(), idxRight + nClosestPoints);
          // epsilon avoids a division by zero for a data point located exactly at timeI
          constexpr float epsilon = 1e-6f;
          double weightedSum = 0.0;
          double weightTotal = 0.0;
          for (size_t j = idxStart; j < idxEnd; ++j) {
            const double dist = std::abs(timeI - timeData[j]);
            const double weight = 1.0 / (dist + epsilon);
            weightedSum += weight * data[j];
            weightTotal += weight;
          }
          stats.median[i] = (weightTotal > 0.) ? (weightedSum / weightTotal) : 0.0f;
        } else {
          // calculate statistics
          stats.nPoints[i] = reqSize;

          // only reachable for an empty window located behind the last data point (minPoints == 0)
          if (idxStart >= data.size()) {
            stats.median[i] = data.back();
            continue;
          }

          if (reqSize <= 1) {
            stats.median[i] = data[idxStart];
            continue;
          }

          // calculate median
          window.clear();
          if (reqSize > window.capacity()) {
            window.reserve(static_cast<size_t>(reqSize * 1.5));
          }
          window.insert(window.end(), data.begin() + idxStart, data.begin() + idxEnd);
          const size_t middle = window.size() / 2;
          const auto midIt = window.begin() + middle;
          std::nth_element(window.begin(), midIt, window.end());
          if (window.size() % 2 == 1) {
            stats.median[i] = *midIt;
          } else {
            // nth_element only guarantees that the elements in front of midIt are not larger than *midIt,
            // they are NOT sorted: the lower middle value is the maximum of this lower part.
            // The window holds at least two points here, i.e. the lower part is never empty
            const auto lowerMid = *std::max_element(window.begin(), midIt);
            stats.median[i] = (lowerMid + *midIt) / 2.0;
          }

          // calculate the stdev (the window is overwritten with the deviations from the mean)
          const float mean = std::accumulate(window.begin(), window.end(), 0.0f) / window.size();
          std::transform(window.begin(), window.end(), window.begin(), [mean](const float val) { return val - mean; });
          const float sqsum = std::inner_product(window.begin(), window.end(), window.begin(), 0.0f);
          const float stdev = std::sqrt(sqsum / window.size());
          stats.std[i] = stdev;
        }
      }
    }
  };

  std::vector<std::thread> threads;
  threads.reserve(nWorkers);
  for (size_t iThread = 0; iThread < nWorkers; ++iThread) {
    threads.emplace_back(myThread, iThread);
  }

  for (auto& th : threads) {
    th.join();
  }
  return stats;
}

} // namespace math_utils
} // namespace o2
#endif
2 changes: 2 additions & 0 deletions Common/MathUtils/src/MathUtilsLinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,6 @@
#pragma link C++ class o2::math_utils::Legendre1DPolynominal + ;
#pragma link C++ class o2::math_utils::Legendre2DPolynominal + ;

#pragma link C++ class o2::math_utils::RollingStats + ;

#endif
3 changes: 2 additions & 1 deletion DataFormats/Detectors/TPC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ o2_add_library(
O2::CommonDataFormat
O2::Headers
O2::DataSampling
O2::Algorithm)
O2::Algorithm
ROOT::Minuit)

o2_target_root_dictionary(
DataFormatsTPC
Expand Down
Loading