scala最大堆最小堆,通过堆取TopN

使用最小堆取最大的N个元素

Posted by Woods on March 23, 2018

用scala实现一个最大堆或最小堆

并且维护一个长度为N的最小堆取,用来取TopN

定义一个HeapSort class
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
  * 堆排序:采用大顶堆,根节点大于两层元素,该类需要满足具备三个方法: 1.由正常数组变换为符合堆结构的方法
  * 2.向堆中插入数据。3.从堆中删除数据。
  * 重要特性:1.根节点索引值为i,根节点元素满足Key[i]>=Key[2i+1]&&key>=key[2i+2]
  * 2.非叶子节点数量<=总节点数的/2
  *
  * ascending = true 维护一个最小堆  反之最大堆
  * @param ascending 
  */
class HeapSort(var ascending:Boolean = true) {
  /**
    * 初始化堆
    *
    * @param unSortHeap 任意数组
    * @return 符合堆结构的数组
    */
  def generateHeap(unSortHeap: Array[Int]): Array[Int] = {
    val num = unSortHeap.length
    if (num <= 1) return unSortHeap
    var tempUnSortHeap = unSortHeap


    for (i <- num / 2 - 1 until num) {
      tempUnSortHeap = sort(tempUnSortHeap, i)
    }
    tempUnSortHeap
  }

  /**
    * 排序特定子树
    *
    * @param heap  未排序堆
    * @param index 当前排序索引值
    * @return 某子树已经排序完
    */
  def sort(heap: Array[Int], index: Int): Array[Int] = {
    //父节点
    var childIndex = index

    var parentIndex = (childIndex - 1) / 2
    var temp = heap(childIndex)
    while (parentIndex >= 0 && childIndex != 0) {
      val flag = heap(parentIndex) < temp
      if( (!this.ascending && flag) || (this.ascending && !flag)) {
        //swap
        heap(childIndex) = heap(parentIndex)
        heap(parentIndex) = temp
      }

      childIndex = parentIndex
      parentIndex = (childIndex - 1) / 2
      temp = heap(childIndex)
    }
    heap
  }

  /**
    * 从堆中删除数据
    *
    * @param sortedHeap 已排序的队列
    * @param index      索引值
    */
  def deleteFromHeap(sortedHeap: Array[Int], index: Int): Array[Int] = {
    val newArray = sortedHeap.toBuffer
    newArray.remove(index)
    generateHeap(newArray.toArray)
  }

  /**
    * 从堆顶删除元素
    * 用于维持一个固定大小的堆,如果是最小堆,可以取最大的N个元素
    * @param sortedHeap 已排序的队列
    */
  def deleteHead(sortedHeap: Array[Int]): Array[Int] = {
    val newArray = sortedHeap.toBuffer
    newArray.remove(0)
    generateHeap(newArray.toArray)
  }



  /**
    * 向堆中插入数据
    *
    * @param sortedHeap 已排序的队列
    * @param newElement 添加到堆中的元素
    */
  def insertToHeap(sortedHeap: Array[Int], newElement: Int): Array[Int] = {
    val newList = sortedHeap :+ newElement
    sort(newList, newList.length - 1)
  }

  def printlnList(sortedHeap: Array[Int]): Unit = {
    sortedHeap.foreach(item => print(s"$item "))
    println("sorted!")
    //    for (k <- sortedHeap.indices) {
    //      println(s"($k->${sortedHeap(k)})")
    //    }
  }
}
在同一文件中定义一个伴生对象
1
2
3
4
5
object HeapSort {
  def apply(ascending:Boolean = true):HeapSort = {
    new HeapSort(ascending)
  }
}
使用1 堆的插入删除
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def main(args: Array[String]): Unit = {
    //ascending = true  最小堆   false 最大堆
    val ascending = false //最大堆
    val list = Array[Int](1, 4, 3, 5, 6, 8, 2, 3, 89, 4, 34, 50)
    val headsort = HeapSort(ascending)
    var sortedList = headsort.generateHeap(list)
    headsort.printlnList(sortedList) //输出1
    //添加新元素
    headsort.printlnList(headsort.insertToHeap(sortedList, 100)) //输出2

    //全排序
    def sortAll = {
      while (sortedList.length > 0) {
        print(s"${sortedList.head} ") 
        sortedList = headsort.generateHeap(sortedList.tail)
        //    printlnList(sortedList)
      }
    }
    //删除
    headsort.printlnList(headsort.deleteFromHeap(sortedList, 3)) //输出3
    //全排序,按大到小输出元素
    sortAll //输出4
  }

结果

1
2
3
4
89 34 50 5 8 2 1 3 4 4 6 3 sorted!
100 34 89 5 8 50 1 3 4 4 6 3 2 sorted!
89 34 50 8 6 1 3 4 4 2 3 sorted!
89 50 34 8 6 5 4 4 3 3 2 1 
使用2 使用最小堆取最大的N个元素TopN
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def main(args: Array[String]): Unit = {
    val heapSort = HeapSort(true)
    val list = Array(9,3,5,1,2,6,11,8)

    val num = 5

    var tempUnSortHeap = list.take(1)

    //取最大的5个数  用最小堆
    for(i <- 1 until list.length) {
      val ele = list(i)
      if (tempUnSortHeap.length >= num && ele > tempUnSortHeap.head) { //如果数据个数等于num的时候,如果新元素小于堆顶,即最小值,什么也不操作,否则删除堆顶,插入新元素,并重新排序
        tempUnSortHeap = heapSort.deleteHead(tempUnSortHeap)  //删除堆顶
        tempUnSortHeap = heapSort.insertToHeap(tempUnSortHeap, ele) //插入新元素并重新排序
      }else if(tempUnSortHeap.length < num ) //如果数据长度不足num , 直接插入元素
        tempUnSortHeap = heapSort.insertToHeap(tempUnSortHeap, ele)
    }
    heapSort.printlnList(tempUnSortHeap)

  }

结果

1
5 6 9 11 8 sorted!

结果为数组中最大的5个元素

如果需要排序请参考使用1