ViennaLS
Loading...
Searching...
No Matches
lsPrune.hpp
Go to the documentation of this file.
1#pragma once
2
4
5#include <hrleSparseStarIterator.hpp>
6#include <hrleVectorType.hpp>
7
8#include <lsDomain.hpp>
9
10namespace viennals {
11
12using namespace viennacore;
13
19template <class T, int D> class Prune {
20 SmartPointer<Domain<T, D>> levelSet = nullptr;
21 bool updatePointData = true;
22 bool removeStrayZeros = false;
23
24 template <class Numeric> static bool isNegative(const Numeric a) {
25 return a <= -std::numeric_limits<Numeric>::epsilon();
26 }
27
28 template <class Numeric>
29 static bool isSignDifferent(const Numeric a, const Numeric b) {
30 return (isNegative(a) ^ isNegative(b));
31 }
32
33 template <class Numeric>
34 static bool isSignDifferentOrZero(const Numeric a, const Numeric b) {
35 if (a == 0. || b == 0.)
36 return true;
37 return isSignDifferent(a, b);
38 }
39
40 template <class hrleIterator, class Compare>
41 static bool checkNeighbourSigns(const hrleIterator &it, Compare comp) {
42 for (int i = 0; i < 2 * D; i++) {
43 if (comp(it.getCenter().getValue(), it.getNeighbor(i).getValue())) {
44 return true;
45 }
46 }
47 return false;
48 }
49
50 // small helper to check whether LS function is monotone
51 // around a zero value
52 static bool isMonotone(const T a, const T c) {
53 return a == 0. || c == 0. || isSignDifferent(a, c);
54 }
55
56public:
57 Prune() {}
58
59 Prune(SmartPointer<Domain<T, D>> passedlsDomain) : levelSet(passedlsDomain){};
60
61 void setLevelSet(SmartPointer<Domain<T, D>> passedlsDomain) {
62 levelSet = passedlsDomain;
63 }
64
67 void setUpdatePointData(bool update) { updatePointData = update; }
68
71 void setRemoveStrayZeros(bool rsz) { removeStrayZeros = rsz; }
72
76 void apply() {
77 if (levelSet == nullptr) {
78 Logger::getInstance()
79 .addWarning("No level set was passed to Prune.")
80 .print();
81 return;
82 }
83 if (levelSet->getNumberOfPoints() == 0) {
84 return;
85 }
86
87 auto &grid = levelSet->getGrid();
88 auto newlsDomain = SmartPointer<Domain<T, D>>::New(grid);
89 typename Domain<T, D>::DomainType &newDomain = newlsDomain->getDomain();
90 typename Domain<T, D>::DomainType &domain = levelSet->getDomain();
91
92 newDomain.initialize(domain.getNewSegmentation(), domain.getAllocation());
93
94 const bool updateData = updatePointData;
95 const bool removeZeros = removeStrayZeros;
96 // save how data should be transferred to new level set
97 // list of indices into the old pointData vector
98 std::vector<std::vector<unsigned>> newDataSourceIds;
99 if (updateData)
100 newDataSourceIds.resize(newDomain.getNumberOfSegments());
101
102#pragma omp parallel num_threads(newDomain.getNumberOfSegments())
103 {
104 int p = 0;
105#ifdef _OPENMP
106 p = omp_get_thread_num();
107#endif
108
109 auto &domainSegment = newDomain.getDomainSegment(p);
110
111 hrleVectorType<hrleIndexType, D> startVector =
112 (p == 0) ? grid.getMinGridPoint()
113 : newDomain.getSegmentation()[p - 1];
114
115 hrleVectorType<hrleIndexType, D> endVector =
116 (p != static_cast<int>(newDomain.getNumberOfSegments() - 1))
117 ? newDomain.getSegmentation()[p]
118 : grid.incrementIndices(grid.getMaxGridPoint());
119
120 for (hrleSparseStarIterator<typename Domain<T, D>::DomainType, 1>
121 neighborIt(domain, startVector);
122 neighborIt.getIndices() < endVector; neighborIt.next()) {
123 auto &centerIt = neighborIt.getCenter();
124 bool centerSign = isNegative(centerIt.getValue());
125 if (centerIt.isDefined()) {
126 bool keepPoint = true;
127 // if center is exact zero, always treat as unprunable
128 if (std::abs(centerIt.getValue()) != 0.) {
129 if (removeZeros) {
130 keepPoint =
131 checkNeighbourSigns(neighborIt, isSignDifferentOrZero<T>);
132 } else {
133 keepPoint = checkNeighbourSigns(neighborIt, isSignDifferent<T>);
134 }
135 }
136
137 if (removeZeros) {
138 // if the centre point is 0.0 and the level set values
139 // along each grid dimension are not monotone, it is
140 // a numerical glitch and should be removed
141 if (std::abs(centerIt.getValue()) == 0.) {
142 bool overWritePoint = false;
143 T undefVal = 0.;
144 for (int i = 0; i < D; i++) {
145 const auto &negVal = neighborIt.getNeighbor(i).getValue();
146 const auto &posVal = neighborIt.getNeighbor(D + i).getValue();
147
148 // if LS function is not monotone around the zero value,
149 // set the points value to that of the lower neighbour
150 if (!isMonotone(negVal, posVal)) {
151 overWritePoint = true;
152 undefVal =
153 std::abs(negVal) < std::abs(posVal) ? negVal : posVal;
154 break;
155 }
156 }
157 if (overWritePoint) {
158 domainSegment.insertNextDefinedPoint(neighborIt.getIndices(),
159 undefVal);
160 continue;
161 }
162 }
163 }
164
165 if (keepPoint) {
166 domainSegment.insertNextDefinedPoint(neighborIt.getIndices(),
167 centerIt.getValue());
168 if (updateData)
169 newDataSourceIds[p].push_back(centerIt.getPointId());
170 } else {
171 // TODO: it is more efficient to insertNextUndefinedRunType, since
172 // we know it already exists
173 domainSegment.insertNextUndefinedPoint(
174 neighborIt.getIndices(),
176 }
177 } else {
178 domainSegment.insertNextUndefinedPoint(
179 neighborIt.getIndices(),
181 }
182 }
183 }
184
185 // now copy old data into new level set
186 if (updateData) {
187 newlsDomain->getPointData().translateFromMultiData(
188 levelSet->getPointData(), newDataSourceIds);
189 }
190
191 // distribute evenly across segments and copy
192 newDomain.finalize();
193 newDomain.segment();
194 levelSet->deepCopy(newlsDomain);
195 levelSet->finalize(2);
196 }
197};
198
199// add all template specialisations for this class
201
202} // namespace viennals
Class containing all information about the level set, including the dimensions of the domain,...
Definition lsDomain.hpp:28
void deepCopy(const SmartPointer< Domain< T, D > > passedDomain)
copy all values of "passedDomain" to this Domain
Definition lsDomain.hpp:121
unsigned getNumberOfSegments() const
returns the number of segments, the levelset is split into. This is useful for algorithm parallelisat...
Definition lsDomain.hpp:150
hrleDomain< T, D > DomainType
Definition lsDomain.hpp:33
void finalize(int newWidth)
this function sets a new levelset width and finalizes the levelset, so it is ready for use by other a...
Definition lsDomain.hpp:114
Removes all level set points that do not have at least one oppositely signed neighbour (meaning the...
Definition lsPrune.hpp:19
void setUpdatePointData(bool update)
Set whether to update the point data stored in the LS during this algorithm. Defaults to true.
Definition lsPrune.hpp:67
void setRemoveStrayZeros(bool rsz)
Set whether to remove exact zero values between grid points with the same sign.
Definition lsPrune.hpp:71
void apply()
Removes all grid points that do not have at least one oppositely signed neighbour. Returns nothing; the pruned result is written back into the level set.
Definition lsPrune.hpp:76
Prune(SmartPointer< Domain< T, D > > passedlsDomain)
Definition lsPrune.hpp:59
void setLevelSet(SmartPointer< Domain< T, D > > passedlsDomain)
Definition lsPrune.hpp:61
Prune()
Definition lsPrune.hpp:57
#define PRECOMPILE_PRECISION_DIMENSION(className)
Definition lsPreCompileMacros.hpp:24
Definition lsAdvect.hpp:46
constexpr int D
Definition pyWrap.cpp:65
double T
Definition pyWrap.cpp:63