ViennaLS
Loading...
Searching...
No Matches
lsCompareNarrowBand.hpp
Go to the documentation of this file.
1#pragma once
2
#include <hrleDenseCellIterator.hpp>
#include <lsDomain.hpp>
#include <lsExpand.hpp>
#include <lsMesh.hpp>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <unordered_map>
#include <vector>
10
11namespace viennals {
12
13using namespace viennacore;
14
19template <class T, int D = 2> class CompareNarrowBand {
20 using hrleIndexType = viennahrle::IndexType;
21 SmartPointer<Domain<T, D>> levelSetTarget = nullptr;
22 SmartPointer<Domain<T, D>> levelSetSample = nullptr;
23 viennahrle::Index<D> minIndex, maxIndex;
24
25 // Variables for x and y range restrictions
26 T xRangeMin = std::numeric_limits<T>::lowest();
27 T xRangeMax = std::numeric_limits<T>::max();
28 T yRangeMin = std::numeric_limits<T>::lowest();
29 T yRangeMax = std::numeric_limits<T>::max();
30 bool useXRange = false;
31 bool useYRange = false;
32
33 // Fields to store the calculation results
34 T sumSquaredDifferences = 0.0;
35 T sumDifferences = 0.0;
36 unsigned numPoints = 0;
37
38 // Add mesh output capability
39 SmartPointer<Mesh<T>> outputMesh = nullptr;
40 bool outputMeshSquaredDifferences = true;
41
42 bool checkAndCalculateBounds() {
43 if (levelSetTarget == nullptr || levelSetSample == nullptr) {
44 Logger::getInstance()
45 .addWarning("Missing level set in CompareNarrowBand.")
46 .print();
47 return false;
48 }
49
50 // Check if the grids are compatible
51 const auto &gridTarget = levelSetTarget->getGrid();
52 const auto &gridSample = levelSetSample->getGrid();
53
54 if (gridTarget.getGridDelta() != gridSample.getGridDelta()) {
55 Logger::getInstance()
56 .addWarning("Grid delta mismatch in CompareNarrowBand. The grid "
57 "deltas of the two level sets must be equal.")
58 .print();
59 return false;
60 }
61
62 // Check if the x extents of both level sets are equal
63 const auto &domainTarget = levelSetTarget->getDomain();
64 const auto &domainSample = levelSetSample->getDomain();
65
66 // hrleIndexType targetMinX = gridTarget.isNegBoundaryInfinite(0)
67 // ? domainTarget.getMinRunBreak(0)
68 // : gridTarget.getMinIndex(0);
69 // hrleIndexType targetMaxX = gridTarget.isPosBoundaryInfinite(0)
70 // ? domainTarget.getMaxRunBreak(0)
71 // : gridTarget.getMaxIndex(0);
72 // hrleIndexType sampleMinX = gridSample.isNegBoundaryInfinite(0)
73 // ? domainSample.getMinRunBreak(0)
74 // : gridSample.getMinIndex(0);
75 // hrleIndexType sampleMaxX = gridSample.isPosBoundaryInfinite(0)
76 // ? domainSample.getMaxRunBreak(0)
77 // : gridSample.getMaxIndex(0);
78
79 // if (targetMinX != sampleMinX || targetMaxX != sampleMaxX) {
80 // Logger::getInstance()
81 // .addWarning("X extent mismatch in CompareNarrowBand. The x extents
82 // "
83 // "of both level sets must be equal.")
84 // .print();
85 // return false;
86 // }
87
88 // Expand the sample level set using lsExpand to a default width of 5
89 if (levelSetSample->getLevelSetWidth() < 5) {
90 Logger::getInstance()
91 .addWarning("Sample level set width is insufficient. Expanding it to "
92 "a width of 5.")
93 .print();
94 Expand<T, D>(levelSetSample, 5).apply();
95 }
96
97 // Check if target level set width is sufficient
98 if (levelSetTarget->getLevelSetWidth() <
99 levelSetSample->getLevelSetWidth() + 50) {
100 Logger::getInstance()
101 .addWarning(
102 "Target level set width is insufficient. It must exceed sample "
103 "width by least 50. \n"
104 " CORRECTION: The expansion was performed. \n"
105 "ALTERNATIVE: Alternatively, please expand the target yourself "
106 "using lsExpand before passing it to this function. \n")
107 .print();
108 Expand<T, D>(levelSetTarget, levelSetSample->getLevelSetWidth() + 50)
109 .apply();
110 }
111
112 // Initialize min and max indices
113 for (unsigned i = 0; i < D; ++i) {
114 minIndex[i] = std::numeric_limits<hrleIndexType>::max();
115 maxIndex[i] = std::numeric_limits<hrleIndexType>::lowest();
116 }
117
118 // Calculate actual bounds
119 for (unsigned i = 0; i < D; ++i) {
120 minIndex[i] = std::min({minIndex[i],
121 (gridTarget.isNegBoundaryInfinite(i))
122 ? domainTarget.getMinRunBreak(i)
123 : gridTarget.getMinIndex(i),
124 (gridSample.isNegBoundaryInfinite(i))
125 ? domainSample.getMinRunBreak(i)
126 : gridSample.getMinIndex(i)});
127
128 maxIndex[i] = std::max({maxIndex[i],
129 (gridTarget.isPosBoundaryInfinite(i))
130 ? domainTarget.getMaxRunBreak(i)
131 : gridTarget.getMaxIndex(i),
132 (gridSample.isPosBoundaryInfinite(i))
133 ? domainSample.getMaxRunBreak(i)
134 : gridSample.getMaxIndex(i)});
135 }
136
137 return true;
138 }
139
140public:
142 static_assert(
143 D == 2 &&
144 "CompareNarrowBand is currently only implemented for 2D level sets.");
145 }
146
147 CompareNarrowBand(SmartPointer<Domain<T, D>> passedLevelSetTarget,
148 SmartPointer<Domain<T, D>> passedlevelSetSample)
149 : levelSetTarget(passedLevelSetTarget),
150 levelSetSample(passedlevelSetSample) {
151 static_assert(
152 D == 2 &&
153 "CompareNarrowBand is currently only implemented for 2D level sets.");
154 }
155
156 void setLevelSetTarget(SmartPointer<Domain<T, D>> passedLevelSet) {
157 levelSetTarget = passedLevelSet;
158 }
159
160 void setLevelSetSample(SmartPointer<Domain<T, D>> passedLevelSet) {
161 levelSetSample = passedLevelSet;
162 }
163
165 void setXRange(T minXRange, T maxXRange) {
166 xRangeMin = minXRange;
167 xRangeMax = maxXRange;
168 useXRange = true;
169 }
170
172 void setYRange(T minYRange, T maxYRange) {
173 yRangeMin = minYRange;
174 yRangeMax = maxYRange;
175 useYRange = true;
176 }
177
179 void clearXRange() {
180 useXRange = false;
181 xRangeMin = std::numeric_limits<T>::lowest();
182 xRangeMax = std::numeric_limits<T>::max();
183 }
184
186 void clearYRange() {
187 useYRange = false;
188 yRangeMin = std::numeric_limits<T>::lowest();
189 yRangeMax = std::numeric_limits<T>::max();
190 }
191
193 void setOutputMesh(SmartPointer<Mesh<T>> passedMesh,
194 bool outputMeshSquaredDiffs = true) {
195 outputMesh = passedMesh;
196 outputMeshSquaredDifferences = outputMeshSquaredDiffs;
197 }
198
202 outputMeshSquaredDifferences = value;
203 }
204
206 void apply() {
207 // Perform compatibility checks and calculate bounds
208 if (!checkAndCalculateBounds()) {
209 // If checks fail, return NaN
210 sumSquaredDifferences = std::numeric_limits<T>::quiet_NaN();
211 numPoints = 0;
212 return;
213 }
214
215 const auto &gridTarget = levelSetTarget->getGrid();
216 double gridDelta = gridTarget.getGridDelta();
217
218 // Set up iterators for both level sets
219 viennahrle::ConstDenseCellIterator<typename Domain<T, D>::DomainType>
220 itSample(levelSetSample->getDomain(), minIndex);
221 viennahrle::ConstDenseCellIterator<typename Domain<T, D>::DomainType>
222 itTarget(levelSetTarget->getDomain(), minIndex);
223
224 sumSquaredDifferences = 0.0;
225 numPoints = 0;
226
227 // Prepare mesh output if needed
228 std::unordered_map<viennahrle::Index<D>, size_t,
229 typename viennahrle::Index<D>::hash>
230 pointIdMapping;
231 std::vector<T> differenceValues;
232 size_t currentPointId = 0;
233
234 const bool generateMesh = outputMesh != nullptr;
235 if (generateMesh) {
236 outputMesh->clear();
237
238 // Initialize mesh extent
239 for (unsigned i = 0; i < D; ++i) {
240 outputMesh->minimumExtent[i] = std::numeric_limits<T>::max();
241 outputMesh->maximumExtent[i] = std::numeric_limits<T>::lowest();
242 }
243 }
244
245 // Iterate through the domain defined by the bounding box
246 for (; itSample.getIndices() < maxIndex; itSample.next()) {
247 // Check if current point is within specified x and y ranges
248 T xCoord = itSample.getIndices()[0] * gridDelta;
249 T yCoord = (D > 1) ? itSample.getIndices()[1] * gridDelta : 0;
250
251 // Skip if outside the specified x-range
252 if (useXRange && (xCoord < xRangeMin || xCoord > xRangeMax)) {
253 continue;
254 }
255
256 // Skip if outside the specified y-range (only check in 2D and 3D)
257 if (D > 1 && useYRange && (yCoord < yRangeMin || yCoord > yRangeMax)) {
258 continue;
259 }
260
261 // Move the second iterator to the same position
262 itTarget.goToIndicesSequential(itSample.getIndices());
263
264 // Get values at current position
265 T valueTarget = 0.0;
266 T valueSample = 0.0;
267
268 // Calculate average value at cell center
269 for (int i = 0; i < (1 << D); ++i) {
270 valueSample += itSample.getCorner(i).getValue();
271 valueTarget += itTarget.getCorner(i).getValue();
272 }
273 valueTarget /= (1 << D);
274 valueSample /= (1 << D);
275
276 // Check for infinite or extreme values that might cause numerical issues
277 if (std::isinf(valueTarget) || std::isinf(valueSample) ||
278 std::abs(valueTarget) > 1000 || std::abs(valueSample) > 1000) {
279 continue;
280 }
281
282 // Calculate difference and add to sum
283 T diff = std::abs(valueTarget - valueSample) * gridDelta;
284 T diffSquared = diff * diff;
285 sumSquaredDifferences += diffSquared;
286 sumDifferences += diff;
287 numPoints++;
288
289 // Store difference in mesh if required
290 if (generateMesh) {
291 std::array<unsigned, 1 << D> voxel;
292 bool addVoxel = true;
293 // TODO: possibly remove this addVoxel check
294 // Insert all points of voxel into pointList
295 for (unsigned i = 0; i < (1 << D); ++i) {
296 viennahrle::Index<D> index;
297 for (unsigned j = 0; j < D; ++j) {
298 index[j] =
299 itSample.getIndices(j) + itSample.getCorner(i).getOffset()[j];
300 if (index[j] > maxIndex[j]) {
301 addVoxel = false;
302 break;
303 }
304 }
305 if (addVoxel) {
306 auto pointIdValue = std::make_pair(index, currentPointId);
307 auto pointIdPair = pointIdMapping.insert(pointIdValue);
308 voxel[i] = pointIdPair.first->second;
309 if (pointIdPair.second) {
310 ++currentPointId;
311 }
312 } else {
313 break;
314 }
315 }
316
317 if (addVoxel) {
318 if constexpr (D == 3) {
319 std::array<unsigned, 8> hexa{voxel[0], voxel[1], voxel[3],
320 voxel[2], voxel[4], voxel[5],
321 voxel[7], voxel[6]};
322 outputMesh->hexas.push_back(hexa);
323 } else if constexpr (D == 2) {
324 std::array<unsigned, 4> quad{voxel[0], voxel[1], voxel[3],
325 voxel[2]};
326 outputMesh->tetras.push_back(quad);
327 }
328
329 // Add difference value to cell data depending on whether squared
330 // differences are requested
331 differenceValues.push_back(outputMeshSquaredDifferences ? diffSquared
332 : diff);
333 }
334 }
335 }
336
337 // Finalize mesh output
338 if (generateMesh && !pointIdMapping.empty()) {
339 outputMesh->nodes.resize(pointIdMapping.size());
340 for (auto it = pointIdMapping.begin(); it != pointIdMapping.end(); ++it) {
341 std::array<T, 3> coords{};
342 for (unsigned i = 0; i < D; ++i) {
343 coords[i] = gridDelta * it->first[i];
344
345 if (coords[i] < outputMesh->minimumExtent[i]) {
346 outputMesh->minimumExtent[i] = coords[i];
347 } else if (coords[i] > outputMesh->maximumExtent[i]) {
348 outputMesh->maximumExtent[i] = coords[i];
349 }
350 }
351 outputMesh->nodes[it->second] = coords;
352 }
353
354 assert(differenceValues.size() ==
355 outputMesh->template getElements<1 << D>().size());
356 outputMesh->cellData.insertNextScalarData(std::move(differenceValues),
357 outputMeshSquaredDifferences
358 ? "Squared differences"
359 : "Absolute differences");
360 }
361 }
362
364 T getSumSquaredDifferences() const { return sumSquaredDifferences; }
365
366 // Return the sum of differences calculated by apply().
367 T getSumDifferences() const { return sumDifferences; }
368
370 unsigned getNumPoints() const { return numPoints; }
371
373 T getRMSE() const {
374 return (numPoints > 0) ? std::sqrt(sumSquaredDifferences / numPoints)
375 : std::numeric_limits<T>::infinity();
376 }
377};
378
379// Add template specializations for this class
380PRECOMPILE_PRECISION_DIMENSION(CompareNarrowBand)
381
382} // namespace viennals
void setYRange(T minYRange, T maxYRange)
Set the y-coordinate range to restrict the comparison area.
Definition lsCompareNarrowBand.hpp:172
void setOutputMeshSquaredDifferences(bool value)
Set whether to output squared differences (true) or absolute differences (false)
Definition lsCompareNarrowBand.hpp:201
void setXRange(T minXRange, T maxXRange)
Set the x-coordinate range to restrict the comparison area.
Definition lsCompareNarrowBand.hpp:165
void setOutputMesh(SmartPointer< Mesh< T > > passedMesh, bool outputMeshSquaredDiffs=true)
Set the output mesh where difference values will be stored.
Definition lsCompareNarrowBand.hpp:193
void setLevelSetTarget(SmartPointer< Domain< T, D > > passedLevelSet)
Definition lsCompareNarrowBand.hpp:156
T getRMSE() const
Calculate the root mean square error from previously computed values.
Definition lsCompareNarrowBand.hpp:373
void setLevelSetSample(SmartPointer< Domain< T, D > > passedLevelSet)
Definition lsCompareNarrowBand.hpp:160
unsigned getNumPoints() const
Return the number of points used in the comparison.
Definition lsCompareNarrowBand.hpp:370
void clearYRange()
Clear the y-range restriction.
Definition lsCompareNarrowBand.hpp:186
T getSumSquaredDifferences() const
Return the sum of squared differences calculated by apply().
Definition lsCompareNarrowBand.hpp:364
CompareNarrowBand(SmartPointer< Domain< T, D > > passedLevelSetTarget, SmartPointer< Domain< T, D > > passedlevelSetSample)
Definition lsCompareNarrowBand.hpp:147
void apply()
Apply the comparison and calculate the sum of squared differences.
Definition lsCompareNarrowBand.hpp:206
void clearXRange()
Clear the x-range restriction.
Definition lsCompareNarrowBand.hpp:179
T getSumDifferences() const
Definition lsCompareNarrowBand.hpp:367
CompareNarrowBand()
Definition lsCompareNarrowBand.hpp:141
Class containing all information about the level set, including the dimensions of the domain,...
Definition lsDomain.hpp:27
Expands the levelSet to the specified number of layers. The largest value in the levelset is thus wid...
Definition lsExpand.hpp:17
void apply()
Apply the expansion to the specified width.
Definition lsExpand.hpp:44
This class holds an explicit mesh, which is always given in 3 dimensions. If it describes a 2D mesh,...
Definition lsMesh.hpp:21
#define PRECOMPILE_PRECISION_DIMENSION(className)
Definition lsPreCompileMacros.hpp:24
Definition lsAdvect.hpp:36
constexpr int D
Definition pyWrap.cpp:71
double T
Definition pyWrap.cpp:69