Skip to content

Commit 288bf87

Browse files
Merge pull request #62 from mayurdb/knnEfficient_takeOrdered
Added a KNNEfficient implementation with takeOrdered
2 parents acf3388 + 6a76b93 commit 288bf87

File tree

12 files changed

+350
-51
lines changed

12 files changed

+350
-51
lines changed

src/main/scala/com/spark3d/geometryObjects/Shape3D.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,9 @@ object Shape3D extends Serializable {
101101
def hasCenterCloseTo(p: Point3D, epsilon: Double): Boolean = {
102102
center.distanceTo(p) <= epsilon
103103
}
104+
105+
def getHash(): Int = {
106+
(center.getCoordinate.mkString("/") + getEnvelope.toString).hashCode
107+
}
104108
}
105109
}

src/main/scala/com/spark3d/spatial3DRDD/Shape3DRDD.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,11 @@ abstract class Shape3DRDD[T<:Shape3D] extends Serializable {
122122
if (maxItemsPerBox > Int.MaxValue) {
123123
throw new AssertionError(
124124
"""
125-
The max mumber of elements per partition have become greater than Int limit.
126-
Consider increasing number of partitions.
125+
The max number of elements per partition have become greater than Int limit.
126+
Consider increasing the number of partitions.
127127
""")
128128
}
129-
val octree = new Octree(getDataEnvelope, 0, maxItemsPerBox, maxLevels)
129+
val octree = new Octree(getDataEnvelope, 0, null, maxItemsPerBox, maxLevels)
130130
val partitioning = OctreePartitioning.apply(samples, octree)
131131
val grids = partitioning.getGrids
132132
new OctreePartitioner(octree, grids)

src/main/scala/com/spark3d/spatialOperator/SpatialQuery.scala

Lines changed: 96 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,108 @@ package com.spark3d.spatialOperator
1919
import com.spark3d.geometryObjects.Shape3D.Shape3D
2020
import com.spark3d.utils.GeometryObjectComparator
2121
import org.apache.spark.rdd.RDD
22+
import com.spark3d.spatialPartitioning._
2223

23-
import scala.collection.mutable.PriorityQueue
24-
24+
import scala.collection.mutable
25+
import scala.collection.mutable.{HashSet, ListBuffer, PriorityQueue}
26+
import scala.reflect.ClassTag
27+
import scala.util.control.Breaks._
2528

2629
object SpatialQuery {
2730

28-
def KNN[A <: Shape3D, B <:Shape3D](queryObject: A, rdd: RDD[B], k: Int): List[B] = {
29-
30-
val pq: PriorityQueue[B] = PriorityQueue.empty[B](new GeometryObjectComparator[B](queryObject.center))
31-
32-
val itr = rdd.toLocalIterator
33-
34-
while (itr.hasNext) {
35-
val currentElement = itr.next
36-
if (pq.size < k) {
37-
pq.enqueue(currentElement)
38-
} else {
39-
val currentEleDist = currentElement.center.distanceTo(queryObject.center)
40-
// TODO make use of pq.max
41-
val maxElement = pq.dequeue
42-
val maxEleDist = maxElement.center.distanceTo(queryObject.center)
43-
if (currentEleDist < maxEleDist) {
44-
pq.enqueue(currentElement)
45-
} else {
46-
pq.enqueue(maxElement)
31+
/**
32+
* Finds the K nearest neighbors of the query object. The naive implementation here searches
33+
* through all the the objects in the RDD to get the KNN. The nearness of the objects here
34+
* is decided on the basis of the distance between their centers.
35+
*
36+
* @param queryObject object to which the knn are to be found
37+
* @param rdd RDD of a Shape3D (Shape3DRDD)
38+
* @param k number of nearest neighbors are to be found
39+
* @return knn
40+
*/
41+
def KNN[A <: Shape3D: ClassTag, B <:Shape3D: ClassTag](queryObject: A, rdd: RDD[B], k: Int): List[B] = {
42+
val knn = rdd.takeOrdered(k)(new GeometryObjectComparator[B](queryObject.center))
43+
knn.toList
44+
}
45+
46+
/**
47+
* Much more efficient implementation of the KNN query above. First we seek the partitions in
48+
* which the query object belongs and we will look for the knn only in those partitions. After
49+
* this if the limit k is not satisfied, we keep looking similarly in the neighbors of the
50+
* containing partitions.
51+
*
52+
* @param queryObject object to which the knn are to be found
53+
* @param rdd RDD of a Shape3D (Shape3DRDD)
54+
* @param k number of nearest neighbors are to be found
55+
* @return knn
56+
*/
57+
def KNNEfficient[A <: Shape3D: ClassTag, B <:Shape3D: ClassTag](queryObject: A, rdd: RDD[B], k: Int): List[B] = {
58+
59+
val partitioner = rdd.partitioner.get.asInstanceOf[SpatialPartitioner]
60+
val containingPartitions = partitioner.getPartitionNodes(queryObject)
61+
val containingPartitionsIndex = containingPartitions.map(x => x._1)
62+
val matchedContainingSubRDD = rdd.mapPartitionsWithIndex(
63+
(index, iter) => {
64+
if (containingPartitionsIndex.contains(index)) iter else Iterator.empty
65+
}
66+
)
67+
68+
val knn_1 = matchedContainingSubRDD.takeOrdered(k)(new GeometryObjectComparator[B](queryObject.center))
69+
70+
if (knn_1.size >= k) {
71+
return knn_1.toList
72+
}
73+
74+
val visitedPartitions = new HashSet[Int]
75+
visitedPartitions ++= containingPartitionsIndex
76+
77+
val neighborPartitions = partitioner.getNeighborNodes(queryObject)
78+
.filter(x => !visitedPartitions.contains(x._1)).to[ListBuffer]
79+
val neighborPartitionsIndex = neighborPartitions.map(x => x._1)
80+
81+
val matchedNeighborSubRDD = rdd.mapPartitionsWithIndex(
82+
(index, iter) => {
83+
if (neighborPartitionsIndex.contains(index)) iter else Iterator.empty
84+
}
85+
)
86+
87+
val knn_2 = matchedNeighborSubRDD.takeOrdered(k-knn_1.size)(new GeometryObjectComparator[B](queryObject.center))
88+
89+
var knn_f = knn_1 ++ knn_2
90+
if (knn_f.size >= k) {
91+
return knn_f.toList
92+
}
93+
94+
visitedPartitions ++= neighborPartitionsIndex
95+
96+
breakable {
97+
for (neighborPartition <- neighborPartitions) {
98+
val secondaryNeighborPartitions = partitioner.getSecondaryNeighborNodes(neighborPartition._2, neighborPartition._1)
99+
.filter(x => !visitedPartitions.contains(x._1))
100+
val secondaryNeighborPartitionsIndex = secondaryNeighborPartitions.map(x => x._1)
101+
102+
val matchedSecondaryNeighborSubRDD = rdd.mapPartitionsWithIndex(
103+
(index, iter) => {
104+
if (secondaryNeighborPartitionsIndex.contains(index))
105+
iter
106+
else
107+
Iterator.empty
108+
}
109+
)
110+
111+
112+
val knn_t = matchedNeighborSubRDD.takeOrdered(k-knn_f.size)(new GeometryObjectComparator[B](queryObject.center))
113+
114+
knn_f = knn_f ++ knn_t
115+
116+
if (knn_f.size >= k) {
117+
break
47118
}
119+
120+
visitedPartitions ++= secondaryNeighborPartitionsIndex
121+
neighborPartitions ++= secondaryNeighborPartitions
48122
}
49123
}
50-
pq.toList.sortWith(_.center.distanceTo(queryObject.center) < _.center.distanceTo(queryObject.center))
124+
knn_f.toList
51125
}
52126
}

src/main/scala/com/spark3d/spatialPartitioning/Octree.scala

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import scala.collection.mutable.Queue
4444
class Octree(
4545
val box: BoxEnvelope,
4646
val level: Int,
47+
val parentNode: Octree = null,
4748
val maxItemsPerNode: Int = 5,
4849
val maxLevel: Int = 10)
4950
extends Serializable {
@@ -77,56 +78,56 @@ class Octree(
7778
box.minX, (box.maxX - box.minX) / 2,
7879
box.minY, (box.maxY - box.minY) / 2,
7980
box.minZ, (box.maxZ - box.minZ) / 2),
80-
level + 1, maxItemsPerNode, maxLevel)
81+
level + 1, this, maxItemsPerNode, maxLevel)
8182

8283
children(CHILD_L_SE) = new Octree(
8384
BoxEnvelope.apply(
8485
(box.maxX - box.minX) / 2, box.maxX,
8586
box.minY, (box.maxY - box.minY) / 2,
8687
box.minZ, (box.maxZ - box.minZ) / 2),
87-
level + 1, maxItemsPerNode, maxLevel)
88+
level + 1, this, maxItemsPerNode, maxLevel)
8889

8990
children(CHILD_L_NW) = new Octree(
9091
BoxEnvelope.apply(
9192
box.minX, (box.maxX - box.minX) / 2,
9293
(box.maxY - box.minY) / 2, box.maxY,
9394
box.minZ, (box.maxZ - box.minZ) / 2),
94-
level + 1, maxItemsPerNode, maxLevel)
95+
level + 1, this, maxItemsPerNode, maxLevel)
9596

9697
children(CHILD_L_NE) = new Octree(
9798
BoxEnvelope.apply(
9899
(box.maxX - box.minX) / 2, box.maxX,
99100
(box.maxY - box.minY) / 2, box.maxY,
100101
box.minZ, (box.maxZ - box.minZ) / 2),
101-
level + 1, maxItemsPerNode, maxLevel)
102+
level + 1, this, maxItemsPerNode, maxLevel)
102103

103104
children(CHILD_U_SW) = new Octree(
104105
BoxEnvelope.apply(
105106
box.minX, (box.maxX - box.minX) / 2,
106107
box.minY, (box.maxY - box.minY) / 2,
107108
(box.maxZ - box.minZ) / 2, box.maxZ),
108-
level + 1, maxItemsPerNode, maxLevel)
109+
level + 1, this, maxItemsPerNode, maxLevel)
109110

110111
children(CHILD_U_SE) = new Octree(
111112
BoxEnvelope.apply(
112113
(box.maxX - box.minX) / 2, box.maxX,
113114
box.minY, (box.maxY - box.minY) / 2,
114115
(box.maxZ - box.minZ) / 2, box.maxZ),
115-
level + 1, maxItemsPerNode, maxLevel)
116+
level + 1, this, maxItemsPerNode, maxLevel)
116117

117118
children(CHILD_U_NW) = new Octree(
118119
BoxEnvelope.apply(
119120
box.minX, (box.maxX - box.minX) / 2,
120121
(box.maxY - box.minY) / 2, box.maxY,
121122
(box.maxZ - box.minZ) / 2, box.maxZ),
122-
level + 1, maxItemsPerNode, maxLevel)
123+
level + 1, this, maxItemsPerNode, maxLevel)
123124

124125
children(CHILD_U_NE) = new Octree(
125126
BoxEnvelope.apply(
126127
(box.maxX - box.minX) / 2, box.maxX,
127128
(box.maxY - box.minY) / 2, box.maxY,
128129
(box.maxZ - box.minZ) / 2, box.maxZ),
129-
level + 1, maxItemsPerNode, maxLevel)
130+
level + 1, this, maxItemsPerNode, maxLevel)
130131

131132
}
132133

@@ -254,9 +255,9 @@ class Octree(
254255
* @param obj input object for which the search is to be performed
255256
* @param data a ListBuffer in which the desired data should be placed when the funct() == true
256257
*/
257-
private def dfsTraverse(func: (Octree, BoxEnvelope) => Boolean, obj: BoxEnvelope, data: ListBuffer[BoxEnvelope]): Unit = {
258+
private def dfsTraverse(func: (Octree, BoxEnvelope) => Boolean, obj: BoxEnvelope, data: ListBuffer[Octree]): Unit = {
258259
if (func(this, obj)) {
259-
data += box
260+
data += this
260261
}
261262

262263
if (!isLeaf) {
@@ -360,22 +361,57 @@ class Octree(
360361
}
361362

362363
/**
363-
* get all the leaf nodes, which intersect, contain or are contained
364+
* get all the containing Envelopes of the leaf nodes, which intersect, contain or are contained
364365
* by the input BoxEnvelope
365366
*
366367
* @param obj Input object to be checked for the match
367-
* @return list of leafNodes which match the conditions
368+
* @return list of Envelopes of the leafNodes which match the conditions
368369
*/
369-
def getMatchedLeaves(obj: BoxEnvelope): ListBuffer[BoxEnvelope] = {
370+
def getMatchedLeafBoxes(obj: BoxEnvelope): ListBuffer[BoxEnvelope] = {
371+
372+
val matchedLeaves = getMatchedLeaves(obj)
373+
matchedLeaves.map(x => x.box)
374+
}
370375

371-
val matchedLeaves = new ListBuffer[BoxEnvelope]
376+
/**
377+
* get all the containing Envelopes of the leaf nodes, which intersect, contain or are contained
378+
* by the input BoxEnvelope
379+
*
380+
* @param obj Input object to be checked for the match
381+
* @return list of leafNodes which match the conditions
382+
*/
383+
def getMatchedLeaves(obj: BoxEnvelope): ListBuffer[Octree] = {
384+
val matchedLeaves = new ListBuffer[Octree]
372385
val traverseFunct: (Octree, BoxEnvelope) => Boolean = {
373386
(node, obj) => node.isLeaf && (node.box.intersects(obj) ||
374-
node.box.contains(obj) ||
375-
obj.contains(node.box))
387+
node.box.contains(obj) ||
388+
obj.contains(node.box))
376389
}
377390

378391
dfsTraverse(traverseFunct, obj, matchedLeaves)
379392
matchedLeaves
380393
}
394+
395+
/**
396+
* Get the neighbors of this node. Neighbors here are leaf sibling or leaf descendants of the
397+
* siblings.
398+
*
399+
* @param queryNode the box of the the input node to avoid passing same node as neighbor
400+
* @return list of lead neghbors and their index/partition ID's
401+
*/
402+
def getLeafNeighbors(queryNode: BoxEnvelope): List[Tuple2[Int, BoxEnvelope]] = {
403+
val leafNeighbors = new ListBuffer[Tuple2[Int, BoxEnvelope]]
404+
if (parentNode != null){
405+
for (neighbor <- parentNode.children) {
406+
if (!neighbor.box.isEqual(queryNode)) {
407+
if (neighbor.isLeaf) {
408+
leafNeighbors += new Tuple2(neighbor.box.indexID, neighbor.box)
409+
} else {
410+
leafNeighbors ++= neighbor.children(0).getLeafNeighbors(queryNode)
411+
}
412+
}
413+
}
414+
}
415+
leafNeighbors.toList.distinct
416+
}
381417
}

src/main/scala/com/spark3d/spatialPartitioning/OctreePartitioner.scala

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,65 @@ class OctreePartitioner (octree: Octree, grids : List[BoxEnvelope]) extends Spat
4646

4747
val result = HashSet.empty[Tuple2[Int, T]]
4848
var matchedPartitions = new ListBuffer[BoxEnvelope]
49-
matchedPartitions ++= octree.getMatchedLeaves(spatialObject.getEnvelope)
49+
matchedPartitions ++= octree.getMatchedLeafBoxes(spatialObject.getEnvelope)
5050
for(partition <- matchedPartitions) {
5151
result += new Tuple2(partition.indexID, spatialObject)
5252
}
5353
result.toIterator
5454
}
55+
56+
/**
57+
* Gets the partitions which contain the input object.
58+
*
59+
* @param spatialObject input object for which the containment is to be found
60+
* @return list of Tuple of containing partitions and their index/partition ID's
61+
*/
62+
override def getPartitionNodes[T <: Shape3D](spatialObject: T): List[Tuple2[Int, Shape3D]] = {
63+
64+
var partitionNodes = new ListBuffer[Shape3D]
65+
partitionNodes ++= octree.getMatchedLeafBoxes(spatialObject.getEnvelope)
66+
var partitionNodesIDs = partitionNodes.map(x => new Tuple2(x.getEnvelope.indexID, x))
67+
partitionNodesIDs.toList
68+
}
69+
70+
/**
71+
* Gets the partitions which are the neighbors of the partitions which contain the input object.
72+
*
73+
* @param spatialObject input object for which the neighbors are to be found
74+
* @return list of Tuple of neighbor partitions and their index/partition ID's
75+
*/
76+
override def getNeighborNodes[T <: Shape3D](spatialObject: T): List[Tuple2[Int, Shape3D]] = {
77+
val neighborNodes = new ListBuffer[Tuple2[Int, Shape3D]]
78+
val partitionNodes = octree.getMatchedLeaves(spatialObject.getEnvelope)
79+
for (partitionNode <- partitionNodes) {
80+
neighborNodes ++= partitionNode.getLeafNeighbors(partitionNode.box.getEnvelope)
81+
}
82+
neighborNodes.toList
83+
}
84+
85+
/**
86+
* Gets the partitions which are the neighbors to the input partition. Useful when getting
87+
* secondary neighbors (neighbors to neighbor) of the queryObject.
88+
*
89+
* @param containingNode The boundary of the Node for which neighbors are to be found.
90+
* @param containingNodeID The index/partition ID of the containingNode
91+
* @return list of Tuple of secondary neighbor partitions and their index/partition IDs
92+
*/
93+
override def getSecondaryNeighborNodes[T <: Shape3D](containingNode: T, containingNodeID: Int): List[Tuple2[Int, Shape3D]] = {
94+
val secondaryNeighborNodes = new ListBuffer[Tuple2[Int, Shape3D]]
95+
// get the bounding box
96+
val box = containingNode.getEnvelope
97+
// reduce the bounding box slightly to avoid getting all the neighbor nodes as the containing nodes
98+
val searchBox = BoxEnvelope.apply(box.minX+0.0001, box.maxX-0.0001,
99+
box.minY+0.0001, box.maxY-0.0001,
100+
box.minZ+0.0001, box.maxZ-0.0001)
101+
val partitionNodes = octree.getMatchedLeaves(searchBox.getEnvelope)
102+
// ideally partitionNodes should be of size 1 as the input containingNode is nothing but the
103+
// boundary of a node in the tree.
104+
for (partitionNode <- partitionNodes) {
105+
secondaryNeighborNodes ++= partitionNode.getLeafNeighbors(partitionNode.box.getEnvelope)
106+
}
107+
secondaryNeighborNodes.toList
108+
}
109+
55110
}

0 commit comments

Comments
 (0)