ViennaLS
Loading...
Searching...
No Matches
lsPrune.hpp
Go to the documentation of this file.
1#pragma once
2
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

#include <hrleSparseStarIterator.hpp>
#include <lsDomain.hpp>
#include <vcVectorType.hpp>
7
8namespace viennals {
9
10using namespace viennacore;
11
17template <class T, int D> class Prune {
18 SmartPointer<Domain<T, D>> levelSet = nullptr;
19 bool updatePointData = true;
20 bool removeStrayZeros = false;
21
22 template <class Numeric> static bool isNegative(const Numeric a) {
23 return a <= -std::numeric_limits<Numeric>::epsilon();
24 }
25
26 template <class Numeric>
27 static bool isSignDifferent(const Numeric a, const Numeric b) {
28 return (isNegative(a) ^ isNegative(b));
29 }
30
31 template <class Numeric>
32 static bool isSignDifferentOrZero(const Numeric a, const Numeric b) {
33 if (a == 0. || b == 0.)
34 return true;
35 return isSignDifferent(a, b);
36 }
37
38 template <class hrleIterator, class Compare>
39 static bool checkNeighbourSigns(const hrleIterator &it, Compare comp) {
40 for (int i = 0; i < 2 * D; i++) {
41 if (comp(it.getCenter().getValue(), it.getNeighbor(i).getValue())) {
42 return true;
43 }
44 }
45 return false;
46 }
47
48 // small helper to check whether LS function is monotone
49 // around a zero value
50 static bool isMonotone(const T a, const T c) {
51 return a == 0. || c == 0. || isSignDifferent(a, c);
52 }
53
54public:
55 Prune() = default;
56
57 Prune(SmartPointer<Domain<T, D>> passedlsDomain) : levelSet(passedlsDomain){};
58
59 void setLevelSet(SmartPointer<Domain<T, D>> passedlsDomain) {
60 levelSet = passedlsDomain;
61 }
62
65 void setUpdatePointData(bool update) { updatePointData = update; }
66
69 void setRemoveStrayZeros(bool rsz) { removeStrayZeros = rsz; }
70
74 void apply() {
75 if (levelSet == nullptr) {
76 Logger::getInstance()
77 .addWarning("No level set was passed to Prune.")
78 .print();
79 return;
80 }
81 if (levelSet->getNumberOfPoints() == 0) {
82 return;
83 }
84
85 auto &grid = levelSet->getGrid();
86 auto newlsDomain = SmartPointer<Domain<T, D>>::New(grid);
87 typename Domain<T, D>::DomainType &newDomain = newlsDomain->getDomain();
88 typename Domain<T, D>::DomainType &domain = levelSet->getDomain();
89
90 newDomain.initialize(domain.getNewSegmentation(), domain.getAllocation());
91
92 const bool updateData = updatePointData;
93 const bool removeZeros = removeStrayZeros;
94 // save how data should be transferred to new level set
95 // list of indices into the old pointData vector
96 std::vector<std::vector<unsigned>> newDataSourceIds;
97 if (updateData)
98 newDataSourceIds.resize(newDomain.getNumberOfSegments());
99
100#pragma omp parallel num_threads(newDomain.getNumberOfSegments())
101 {
102 int p = 0;
103#ifdef _OPENMP
104 p = omp_get_thread_num();
105#endif
106
107 auto &domainSegment = newDomain.getDomainSegment(p);
108
109 viennahrle::Index<D> const startVector =
110 (p == 0) ? grid.getMinGridPoint()
111 : newDomain.getSegmentation()[p - 1];
112
113 viennahrle::Index<D> const endVector =
114 (p != static_cast<int>(newDomain.getNumberOfSegments() - 1))
115 ? newDomain.getSegmentation()[p]
116 : grid.incrementIndices(grid.getMaxGridPoint());
117
118 for (viennahrle::SparseStarIterator<typename Domain<T, D>::DomainType, 1>
119 neighborIt(domain, startVector);
120 neighborIt.getIndices() < endVector; neighborIt.next()) {
121 auto &centerIt = neighborIt.getCenter();
122 bool centerSign = isNegative(centerIt.getValue());
123 if (centerIt.isDefined()) {
124 bool keepPoint = true;
125 // if center is exact zero, always treat as unprunable
126 if (std::abs(centerIt.getValue()) != 0.) {
127 if (removeZeros) {
128 keepPoint =
129 checkNeighbourSigns(neighborIt, isSignDifferentOrZero<T>);
130 } else {
131 keepPoint = checkNeighbourSigns(neighborIt, isSignDifferent<T>);
132 }
133 }
134
135 if (removeZeros) {
136 // if the centre point is 0.0 and the level set values
137 // along each grid dimension are not monotone, it is
138 // a numerical glitch and should be removed
139 if (std::abs(centerIt.getValue()) == 0.) {
140 bool overWritePoint = false;
141 T undefVal = 0.;
142 for (int i = 0; i < D; i++) {
143 const auto &negVal = neighborIt.getNeighbor(i).getValue();
144 const auto &posVal = neighborIt.getNeighbor(D + i).getValue();
145
146 // if LS function is not monotone around the zero value,
147 // set the points value to that of the lower neighbour
148 if (!isMonotone(negVal, posVal)) {
149 overWritePoint = true;
150 undefVal =
151 std::abs(negVal) < std::abs(posVal) ? negVal : posVal;
152 break;
153 }
154 }
155 if (overWritePoint) {
156 domainSegment.insertNextDefinedPoint(neighborIt.getIndices(),
157 undefVal);
158 continue;
159 }
160 }
161 }
162
163 if (keepPoint) {
164 domainSegment.insertNextDefinedPoint(neighborIt.getIndices(),
165 centerIt.getValue());
166 if (updateData)
167 newDataSourceIds[p].push_back(centerIt.getPointId());
168 } else {
169 // TODO: it is more efficient to insertNextUndefinedRunType, since
170 // we know it already exists
171 domainSegment.insertNextUndefinedPoint(
172 neighborIt.getIndices(),
174 }
175 } else {
176 domainSegment.insertNextUndefinedPoint(
177 neighborIt.getIndices(),
179 }
180 }
181 }
182
183 // now copy old data into new level set
184 if (updateData) {
185 newlsDomain->getPointData().translateFromMultiData(
186 levelSet->getPointData(), newDataSourceIds);
187 }
188
189 // distribute evenly across segments and copy
190 newDomain.finalize();
191 newDomain.segment();
192 levelSet->deepCopy(newlsDomain);
193 levelSet->finalize(2);
194 }
195};
196
197// add all template specialisations for this class
199
200} // namespace viennals
Class containing all information about the level set, including the dimensions of the domain,...
Definition lsDomain.hpp:27
viennahrle::Domain< T, D > DomainType
Definition lsDomain.hpp:32
static constexpr T NEG_VALUE
Definition lsDomain.hpp:52
void deepCopy(const SmartPointer< Domain< T, D > > passedDomain)
copy all values of "passedDomain" to this Domain
Definition lsDomain.hpp:119
unsigned getNumberOfSegments() const
returns the number of segments, the levelset is split into. This is useful for algorithm parallelisat...
Definition lsDomain.hpp:148
void finalize(int newWidth)
this function sets a new levelset width and finalizes the levelset, so it is ready for use by other a...
Definition lsDomain.hpp:112
static constexpr T POS_VALUE
Definition lsDomain.hpp:51
void setUpdatePointData(bool update)
Set whether to update the point data stored in the LS during this algorithm. Defaults to true.
Definition lsPrune.hpp:65
void setRemoveStrayZeros(bool rsz)
Set whether to remove exact zero values between grid points with the same sign.
Definition lsPrune.hpp:69
void apply()
Removes all grid points that do not have at least one oppositely signed neighbour; the result is written back into the level set (apply() returns nothing).
Definition lsPrune.hpp:74
Prune(SmartPointer< Domain< T, D > > passedlsDomain)
Definition lsPrune.hpp:57
void setLevelSet(SmartPointer< Domain< T, D > > passedlsDomain)
Definition lsPrune.hpp:59
Prune()=default
#define PRECOMPILE_PRECISION_DIMENSION(className)
Definition lsPreCompileMacros.hpp:24
Definition lsAdvect.hpp:36
constexpr int D
Definition pyWrap.cpp:70
double T
Definition pyWrap.cpp:68