about summary refs log tree commit diff stats
path: root/vendor/quadtree/Quadtree.h
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/quadtree/Quadtree.h')
-rw-r--r--vendor/quadtree/Quadtree.h315
1 files changed, 315 insertions, 0 deletions
diff --git a/vendor/quadtree/Quadtree.h b/vendor/quadtree/Quadtree.h new file mode 100644 index 0000000..06097c5 --- /dev/null +++ b/vendor/quadtree/Quadtree.h
@@ -0,0 +1,315 @@
1#pragma once
2
3#include <cassert>
4#include <algorithm>
5#include <array>
6#include <memory>
7#include <type_traits>
8#include <vector>
9#include "Box.h"
10
11namespace quadtree
12{
13
14template<typename T, typename GetBox, typename Equal = std::equal_to<T>, typename Float = float>
15class Quadtree
16{
17 static_assert(std::is_convertible_v<std::invoke_result_t<GetBox, const T&>, Box<Float>>,
18 "GetBox must be a callable of signature Box<Float>(const T&)");
19 static_assert(std::is_convertible_v<std::invoke_result_t<Equal, const T&, const T&>, bool>,
20 "Equal must be a callable of signature bool(const T&, const T&)");
21 static_assert(std::is_arithmetic_v<Float>);
22
23public:
24 Quadtree(const Box<Float>& box, const GetBox& getBox = GetBox(),
25 const Equal& equal = Equal()) :
26 mBox(box), mRoot(std::make_unique<Node>()), mGetBox(getBox), mEqual(equal)
27 {
28
29 }
30
31 void add(const T& value)
32 {
33 add(mRoot.get(), 0, mBox, value);
34 }
35
36 void remove(const T& value)
37 {
38 remove(mRoot.get(), mBox, value);
39 }
40
41 std::vector<T> query(const Box<Float>& box) const
42 {
43 auto values = std::vector<T>();
44 query(mRoot.get(), mBox, box, values);
45 return values;
46 }
47
48 std::vector<std::pair<T, T>> findAllIntersections() const
49 {
50 auto intersections = std::vector<std::pair<T, T>>();
51 findAllIntersections(mRoot.get(), intersections);
52 return intersections;
53 }
54
55 Box<Float> getBox() const
56 {
57 return mBox;
58 }
59
60private:
61 static constexpr auto Threshold = std::size_t(16);
62 static constexpr auto MaxDepth = std::size_t(8);
63
64 struct Node
65 {
66 std::array<std::unique_ptr<Node>, 4> children;
67 std::vector<T> values;
68 };
69
70 Box<Float> mBox;
71 std::unique_ptr<Node> mRoot;
72 GetBox mGetBox;
73 Equal mEqual;
74
75 bool isLeaf(const Node* node) const
76 {
77 return !static_cast<bool>(node->children[0]);
78 }
79
80 Box<Float> computeBox(const Box<Float>& box, int i) const
81 {
82 auto origin = box.getTopLeft();
83 auto childSize = box.getSize() / static_cast<Float>(2);
84 switch (i)
85 {
86 // North West
87 case 0:
88 return Box<Float>(origin, childSize);
89 // Norst East
90 case 1:
91 return Box<Float>(Vector2<Float>(origin.x + childSize.x, origin.y), childSize);
92 // South West
93 case 2:
94 return Box<Float>(Vector2<Float>(origin.x, origin.y + childSize.y), childSize);
95 // South East
96 case 3:
97 return Box<Float>(origin + childSize, childSize);
98 default:
99 assert(false && "Invalid child index");
100 return Box<Float>();
101 }
102 }
103
104 int getQuadrant(const Box<Float>& nodeBox, const Box<Float>& valueBox) const
105 {
106 auto center = nodeBox.getCenter();
107 // West
108 if (valueBox.getRight() < center.x)
109 {
110 // North West
111 if (valueBox.getBottom() < center.y)
112 return 0;
113 // South West
114 else if (valueBox.top >= center.y)
115 return 2;
116 // Not contained in any quadrant
117 else
118 return -1;
119 }
120 // East
121 else if (valueBox.left >= center.x)
122 {
123 // North East
124 if (valueBox.getBottom() < center.y)
125 return 1;
126 // South East
127 else if (valueBox.top >= center.y)
128 return 3;
129 // Not contained in any quadrant
130 else
131 return -1;
132 }
133 // Not contained in any quadrant
134 else
135 return -1;
136 }
137
138 void add(Node* node, std::size_t depth, const Box<Float>& box, const T& value)
139 {
140 assert(node != nullptr);
141 assert(box.contains(mGetBox(value)));
142 if (isLeaf(node))
143 {
144 // Insert the value in this node if possible
145 if (depth >= MaxDepth || node->values.size() < Threshold)
146 node->values.push_back(value);
147 // Otherwise, we split and we try again
148 else
149 {
150 split(node, box);
151 add(node, depth, box, value);
152 }
153 }
154 else
155 {
156 auto i = getQuadrant(box, mGetBox(value));
157 // Add the value in a child if the value is entirely contained in it
158 if (i != -1)
159 add(node->children[static_cast<std::size_t>(i)].get(), depth + 1, computeBox(box, i), value);
160 // Otherwise, we add the value in the current node
161 else
162 node->values.push_back(value);
163 }
164 }
165
166 void split(Node* node, const Box<Float>& box)
167 {
168 assert(node != nullptr);
169 assert(isLeaf(node) && "Only leaves can be split");
170 // Create children
171 for (auto& child : node->children)
172 child = std::make_unique<Node>();
173 // Assign values to children
174 auto newValues = std::vector<T>(); // New values for this node
175 for (const auto& value : node->values)
176 {
177 auto i = getQuadrant(box, mGetBox(value));
178 if (i != -1)
179 node->children[static_cast<std::size_t>(i)]->values.push_back(value);
180 else
181 newValues.push_back(value);
182 }
183 node->values = std::move(newValues);
184 }
185
186 bool remove(Node* node, const Box<Float>& box, const T& value)
187 {
188 assert(node != nullptr);
189 assert(box.contains(mGetBox(value)));
190 if (isLeaf(node))
191 {
192 // Remove the value from node
193 removeValue(node, value);
194 return true;
195 }
196 else
197 {
198 // Remove the value in a child if the value is entirely contained in it
199 auto i = getQuadrant(box, mGetBox(value));
200 if (i != -1)
201 {
202 if (remove(node->children[static_cast<std::size_t>(i)].get(), computeBox(box, i), value))
203 return tryMerge(node);
204 }
205 // Otherwise, we remove the value from the current node
206 else
207 removeValue(node, value);
208 return false;
209 }
210 }
211
212 void removeValue(Node* node, const T& value)
213 {
214 // Find the value in node->values
215 auto it = std::find_if(std::begin(node->values), std::end(node->values),
216 [this, &value](const auto& rhs){ return mEqual(value, rhs); });
217 assert(it != std::end(node->values) && "Trying to remove a value that is not present in the node");
218 // Swap with the last element and pop back
219 *it = std::move(node->values.back());
220 node->values.pop_back();
221 }
222
223 bool tryMerge(Node* node)
224 {
225 assert(node != nullptr);
226 assert(!isLeaf(node) && "Only interior nodes can be merged");
227 auto nbValues = node->values.size();
228 for (const auto& child : node->children)
229 {
230 if (!isLeaf(child.get()))
231 return false;
232 nbValues += child->values.size();
233 }
234 if (nbValues <= Threshold)
235 {
236 node->values.reserve(nbValues);
237 // Merge the values of all the children
238 for (const auto& child : node->children)
239 {
240 for (const auto& value : child->values)
241 node->values.push_back(value);
242 }
243 // Remove the children
244 for (auto& child : node->children)
245 child.reset();
246 return true;
247 }
248 else
249 return false;
250 }
251
252 void query(Node* node, const Box<Float>& box, const Box<Float>& queryBox, std::vector<T>& values) const
253 {
254 assert(node != nullptr);
255 assert(queryBox.intersects(box));
256 for (const auto& value : node->values)
257 {
258 if (queryBox.intersects(mGetBox(value)))
259 values.push_back(value);
260 }
261 if (!isLeaf(node))
262 {
263 for (auto i = std::size_t(0); i < node->children.size(); ++i)
264 {
265 auto childBox = computeBox(box, static_cast<int>(i));
266 if (queryBox.intersects(childBox))
267 query(node->children[i].get(), childBox, queryBox, values);
268 }
269 }
270 }
271
272 void findAllIntersections(Node* node, std::vector<std::pair<T, T>>& intersections) const
273 {
274 // Find intersections between values stored in this node
275 // Make sure to not report the same intersection twice
276 for (auto i = std::size_t(0); i < node->values.size(); ++i)
277 {
278 for (auto j = std::size_t(0); j < i; ++j)
279 {
280 if (mGetBox(node->values[i]).intersects(mGetBox(node->values[j])))
281 intersections.emplace_back(node->values[i], node->values[j]);
282 }
283 }
284 if (!isLeaf(node))
285 {
286 // Values in this node can intersect values in descendants
287 for (const auto& child : node->children)
288 {
289 for (const auto& value : node->values)
290 findIntersectionsInDescendants(child.get(), value, intersections);
291 }
292 // Find intersections in children
293 for (const auto& child : node->children)
294 findAllIntersections(child.get(), intersections);
295 }
296 }
297
298 void findIntersectionsInDescendants(Node* node, const T& value, std::vector<std::pair<T, T>>& intersections) const
299 {
300 // Test against the values stored in this node
301 for (const auto& other : node->values)
302 {
303 if (mGetBox(value).intersects(mGetBox(other)))
304 intersections.emplace_back(value, other);
305 }
306 // Test against values stored into descendants of this node
307 if (!isLeaf(node))
308 {
309 for (const auto& child : node->children)
310 findIntersectionsInDescendants(child.get(), value, intersections);
311 }
312 }
313};
314
315}