# 不多說,直接上源碼吧
# Author: Jintao Huang
# Email: [email protected]
# Date:
# Ref: https://www.bilibili.com/video/BV16D4y1d7d1
"""排序筆記總結
排序方法 時間複雜度(最好/最壞/平均) 空間複雜度 穩定性 特點
插入排序: 適用於幾乎有序的資料
1. 直接插入 O(N) O(N^2) O(N^2) O(1) 穩定 (前有序,後無序)
2. 折半插入 O(NLogN) O(N^2) O(N^2) O(1) 穩定 每次插入: 搜尋O(LogN), 移動O(N). (前有序,後無序)
3. 希爾 取決於增量序列(Shell原始增量最壞O(N^2)) O(1) 不穩定 按增量分組,對每個子序列進行直插排序,增量逐步縮小到1
交換排序 每趟排序確定一個元素的最終位置
1. 冒泡 O(N) O(N^2) O(N^2) O(1) 穩定 從前往後冒泡; 最好O(N)需在某一趟無交換時提前結束. (後有序,前無序)
2. 快速 O(NLogN) O(N^2) O(NLogN) O(LogN)-O(N) 不穩定
選擇排序 每趟排序確定一個元素的最終位置
1. 簡單選擇 O(N^2) O(N^2) O(N^2) O(1) 不穩定 (前有序,後無序)
2. 堆 O(NLogN) O(NLogN) O(NLogN) O(1) 不穩定 非遞減用大根堆, 每次把堆頂(最大值)換到末尾. (後有序,前無序)
建堆、刪除(下濾). 插入(上濾)
二路歸併排序 O(NLogN) O(NLogN) O(NLogN) O(N) 穩定
複雜度與初始狀態相關: 插排 * 2(直接, 折半), 冒泡, 快速
前有序,後無序: 直接插入, 折半插入; 簡單選擇
後有序,前無序: 冒泡, 堆
"""
from typing import Any, Callable, List
import time
from copy import copy
import random
def get_runtime(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, print the elapsed seconds, return its result.

    Timing uses ``time.perf_counter()``. It is monotonic and has the highest
    resolution available. ``time.time()`` measures wall-clock time: it can
    jump when the system clock is adjusted, and on some platforms it is
    coarse, so very short calls get reported as 0.0.

    :param func: callable to time
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :return: whatever ``func`` returns
    """
    start = time.perf_counter()
    result = func(*args, **kwargs)
    print(time.perf_counter() - start)
    return result
def shuffle_arr(arr: List[Any]) -> List[Any]:
    """Return a randomly shuffled shallow copy of ``arr``.

    The caller's list is left untouched. The permutation is drawn from the
    global ``random`` state, so calling ``random.seed`` first makes it
    reproducible.
    """
    shuffled = copy(arr)
    random.shuffle(shuffled)  # shuffles the copy in place, never the input
    return shuffled
def test_sort(func):
    """Benchmark a copying sort ``func`` on 5000 shuffled ints (seed 0).

    Times ``func(data)`` and ``func(data, key=lambda x: x % 10)`` on the same
    shuffled input. Each runtime is printed, followed by the first ten items
    of each result.
    """
    random.seed(0)
    data = shuffle_arr(list(range(5000)))
    plain = get_runtime(func, data)
    keyed = get_runtime(func, data, key=lambda x: x % 10)
    print(plain[:10])
    print(keyed[:10])
def test_sort_std(func):
    """Benchmark an in-place sort ``func`` on 5000 shuffled ints (seed 0).

    Prints the runtime of ``func(data)``, then the first ten items of the
    list that ``func`` sorted in place.
    """
    random.seed(0)
    data = shuffle_arr(list(range(5000)))
    get_runtime(func, data)
    print(data[:10])
# Baseline: run the built-in sorted() through the same benchmark harness.
print("sorted")
test_sort(sorted)
# Recorded output of one run (timings in seconds; they vary by machine and
# a 0.0 reading means the call finished within one tick of the timer used):
# sorted
# 0.0
# 0.0009980201721191406
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [3550, 1500, 3310, 530, 2970, 490, 850, 650, 3690, 3480]
# The last line is the key=x % 10 run: the first ten items all have key 0 and
# keep their shuffled input order, because sorted() is stable.
def bubble(arr: List[Any], lo: int, hi: int) -> bool:
    """Bubble the largest element of arr[lo:hi] up to arr[hi - 1], in place.

    Makes hi - 1 - lo comparisons. Ot(N) Os(1).

    :return: True if at least one swap happened. False means arr[lo:hi] was
        already in non-decreasing order, which lets callers stop early.
    """
    swapped = False
    for i in range(lo, hi - 1):
        # Swap only when strictly greater: equal neighbours keep their order (stable).
        if arr[i] > arr[i + 1]:
            arr[i], arr[i + 1] = arr[i + 1], arr[i]
            swapped = True
    return swapped


def bubble_sort(arr: List[Any], *, key: Callable[[Any], Any] = None) -> List[Any]:
    """Bubble sort returning a new list. Stable. Bubbles front to back.

    Ot: O(N) best (already sorted: one pass, then stop), O(N^2) worst/average.
    Os: O(N) for the copy.

    :param arr: input sequence; not modified
    :param key: optional one-argument key function. Items are decorated once
        as (key(x), index, x): the key is computed only once per item, and the
        unique index means ties never fall through to comparing items (stable).
    :return: a new sorted list
    """
    items = [(key(x), i, x) for i, x in enumerate(arr)] if key is not None else copy(arr)
    # Pass n bubbles the maximum of items[:n] to items[n - 1]; n runs len .. 2.
    for n in range(len(items), 1, -1):
        if not bubble(items, 0, n):
            break  # a pass without swaps means everything is already sorted
    return [item[2] for item in items] if key is not None else items


def bubble_sort_std(arr: List[Any]) -> None:
    """Bubble sort ``arr`` in place. Stable. Ot O(N) best / O(N^2) worst, Os(1)."""
    for n in range(len(arr), 1, -1):  # pass n bubbles the max of arr[:n] to arr[n - 1]
        if not bubble(arr, 0, n):
            break  # no swaps in this pass: arr is sorted
# Bubble sort demo: the copying version (plain, then key=x % 10), followed by
# the in-place version, all on the same 5000 shuffled ints.
print("bubble_sort")
test_sort(bubble_sort)
print("-----------------------")
test_sort_std(bubble_sort_std)
# Recorded output of one run (timings in seconds; they vary by machine):
# bubble_sort
# 1.6645495891571045
# 1.8520421981811523
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [3550, 1500, 3310, 530, 2970, 490, 850, 650, 3690, 3480]
# -----------------------
# 1.6525754928588867
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
https://www.bilibili.com/video/BV16D4y1d7d1
tim排序:用于python的sorted()和list.sort()中
是 二分插入排序(二分查找+插入) 與 歸併排序 的混合. Ot 最好O(N) 最壞O(NLogN), Os O(N)
基于一個事實:
現實生活中,大多數真實的資料集中,已經有很多元素是排好序的
術語:
run(分區):一段已經有序的連續資料: 非遞減, 或嚴格遞減(只有嚴格遞減的run才會被反轉,這樣不會打亂相等元素的順序,保持穩定)
1. 元素個數 < 64_python(或32_java),直接對整個序列使用二分插入排序
2. 元素個數 >= 64,使用完整的tim排序(切分run + 合併)
(tim排序自己内部做的判斷)
tim排序的步驟:
1. 先遍歷全表,查找非遞減/嚴格遞減的區間(run),
嚴格遞減的部分反轉得到遞增; 長度不足minrun的run用二分插入排序擴充到minrun
2. run依次入棧,按棧頂run長度的規則(經典Timsort: A > B + C 且 B > C)決定何時合併,使合併大致平衡,並用galloping模式加速合併,以維持合併效率
"""
def tim_sort():
    """Placeholder for a Timsort implementation; not implemented yet.

    Takes no arguments, does nothing and returns None. CPython's built-in
    ``sorted()`` / ``list.sort()`` are Timsort-based and can be used instead.
    """
    return None
def _min(arr: List[Any], lo: int = 0, hi: int = None) -> int:
"""傳回最小元素的索引"""
hi = len(arr) if hi is None else hi
min_idx = lo
for i in range(lo, hi):
if arr[min_idx] > arr[i]:
min_idx = i
return min_idx
def select_sort(arr: List[Any], *, key: Callable[[Any], Any] = None) -> List[Any]:
"""選擇排序 not stable. Ot(N^2) Os(N)
:param arr: const
:param key: func
:return:
"""
arr = [(key(x), x) for x in arr] if key is not None else copy(arr)
for i in range(len(arr) - 1): # 最後一輪不需要
min_idx = _min(arr, i)
arr[i], arr[min_idx] = arr[min_idx], arr[i]
return [item[1] for item in arr] if key is not None else arr
def select_sort_std(arr: List[Any]) -> None:
"""選擇排序 not stable. Ot(N^2) Os(1)"""
for i in range(len(arr) - 1): # 最後一輪不需要
min_idx = _min(arr, i)
arr[i], arr[min_idx] = arr[min_idx], arr[i]
# Selection sort demo: the copying version (plain, then key=x % 10), followed
# by the in-place version, all on the same 5000 shuffled ints.
print("select_sort")
test_sort(select_sort)
print("-----------------------")
test_sort_std(select_sort_std)
# Recorded output of one run (timings in seconds; they vary by machine).
# The keyed line [0, 10, 20, ...] differs from the stable sorts' output: that
# run decorated items as (key(x), x), so ties on the key were broken by
# comparing x itself rather than by keeping input order.
# select_sort not stable
# select_sort
# 0.6043496131896973
# 0.9325358867645264
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
# -----------------------
# 0.6562752723693848
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def sect_left(arr: List[Any], x: Any, lo: int = 0, hi: int = None) -> int:
    """Linear-scan counterpart of bisect_left. Ot(N) Os(1).

    Returns the first index i in [lo, hi) with x <= arr[i]. That is the
    leftmost slot where x can be inserted into sorted arr[lo:hi] while
    keeping it sorted. Returns hi when every element in the range is < x.
    """
    if hi is None:
        hi = len(arr)
    return next((i for i in range(lo, hi) if x <= arr[i]), hi)
def bisect_left(arr: List[Any], x: Any, lo: int = 0, hi: int = None) -> int:
    """Binary-search counterpart of sect_left. Ot(LogN) Os(1).

    Returns the leftmost index where x can be inserted into sorted
    arr[lo:hi] while keeping it sorted.
    """
    if hi is None:
        hi = len(arr)
    # Half-open window [lo, hi): everything before lo is < x and everything
    # from hi on is >= x. Shrink the window until it is empty.
    while hi > lo:
        mid = lo + (hi - lo) // 2
        if not (arr[mid] < x):
            hi = mid
        else:
            lo = mid + 1
    return lo
def _bisect_left(arr: List[Any], x: Any, lo: int = 0, hi: int = None) -> int:
"""另一種寫法,但是可讀性不好. Ot(LogN) Os(1)"""
hi = len(arr) if hi is None else hi
hi -= 1 # 範圍是閉區間
while lo <= hi:
mid = (lo + hi) // 2
if arr[mid] < x:
lo = mid + 1
else:
hi = mid - 1
return lo
# a = [1, 2, 3]
# print(sect_left(a, 2)) # 1
# print(sect_left(a, 0)) # 0
# print(sect_left(a, 3)) # 2
# print(sect_left(a, 4)) # 3
# print(sect_left(a, 4, 1)) # 3
# print(sect_left(a, 4, 1, 2)) # 2
# a = []
# print(sect_left(a, 2)) # 0
# print("---------------------------")
# #
# a = [1, 2, 3]
# print(bisect_left(a, 2)) # 1
# print(bisect_left(a, 0)) # 0
# print(bisect_left(a, 3)) # 2
# print(bisect_left(a, 4)) # 3
# print(bisect_left(a, 4, 1)) # 3
# print(bisect_left(a, 4, 1, 2)) # 2
# a = []
# print(bisect_left(a, 2)) # 0
# print("---------------------------")
# import bisect
#
# a = [1, 2, 3]
# print(bisect.bisect_left(a, 2)) # 1
# print(bisect.bisect_left(a, 0)) # 0
# print(bisect.bisect_left(a, 3)) # 2
# print(bisect.bisect_left(a, 4)) # 3
# print(bisect.bisect_left(a, 4, 1)) # 3
# print(bisect.bisect_left(a, 4, 1, 2)) # 2
# a = []
# print(bisect.bisect_left(a, 2)) # 0
# print("---------------------------")
def sect_right(arr: List[Any], x: Any, lo: int = 0, hi: int = None) -> int:
    """Linear-scan counterpart of bisect_right. Ot(N) Os(1).

    Scans sorted arr[lo:hi] from the right and returns the slot just after
    the last element <= x, i.e. the rightmost position where x can be
    inserted while keeping order. Returns lo when every element is > x.
    """
    end = len(arr) if hi is None else hi
    slot = end
    while slot > lo:
        if arr[slot - 1] <= x:
            return slot
        slot -= 1
    return lo
def bisect_right(arr: List[Any], x: Any, lo: int = 0, hi: int = None) -> int:
    """Binary-search counterpart of sect_right. Ot(LogN) Os(1).

    Returns the rightmost index where x can be inserted into sorted
    arr[lo:hi] while keeping it sorted, i.e. just after any elements equal to x.
    """
    if hi is None:
        hi = len(arr)
    # Half-open window [lo, hi): before lo everything is <= x, from hi on > x.
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if not (x < arr[mid]):
            lo = mid + 1
        else:
            hi = mid
    return lo
# a = [1, 2, 3]
# print(sect_right(a, 2)) # 2
# print(sect_right(a, 0)) # 0
# print(sect_right(a, 3)) # 3
# print(sect_right(a, 4)) # 3
# print(sect_right(a, 4, 1)) # 3
# print(sect_right(a, 4, 1, 2)) # 2
# a = []
# print(sect_right(a, 2)) # 0
# print("---------------------------")
# a = [1, 2, 3]
# print(bisect_right(a, 2)) # 2
# print(bisect_right(a, 0)) # 0
# print(bisect_right(a, 3)) # 3
# print(bisect_right(a, 4)) # 3
# print(bisect_right(a, 4, 1)) # 3
# print(bisect_right(a, 4, 1, 2)) # 2
# a = []
# print(bisect_right(a, 2)) # 0
# print("---------------------------")
# import bisect
#
# a = [1, 2, 3]
# print(bisect.bisect_right(a, 2)) # 2
# print(bisect.bisect_right(a, 0)) # 0
# print(bisect.bisect_right(a, 3)) # 3
# print(bisect.bisect_right(a, 4)) # 3
# print(bisect.bisect_right(a, 4, 1)) # 3
# print(bisect.bisect_right(a, 4, 1, 2)) # 2
# a = []
# print(bisect.bisect_right(a, 2)) # 0
# print("---------------------------")
def insert_sort(arr: List[Any], *, key: Callable[[Any], Any] = None) -> List[Any]:
    """Straight insertion sort (linear search for the slot). Stable. Ot(N^2) Os(N).

    Builds a new sorted list and inserts each input item after every equal
    item already in it (via sect_right), so equal items keep input order.

    :param arr: input sequence; not modified
    :param key: optional one-argument key function. Items are decorated once
        as (key(x), index, x), so the key is computed once per item and the
        items themselves are never compared.
    :return: a new sorted list
    """
    decorated = key is not None
    items = [(key(x), i, x) for i, x in enumerate(arr)] if decorated else copy(arr)
    out = []
    for item in items:
        pos = sect_right(out, item)
        out.insert(pos, item)
    if decorated:
        return [entry[2] for entry in out]
    return out
def insert_sort_bi(arr: List[Any], *, key: Callable[[Any], Any] = None) -> List[Any]:
    """Binary insertion sort. Stable. Ot(N^2) Os(N).

    Like insert_sort, but the slot is found with bisect_right (O(LogN)
    comparisons). Each list insert still shifts up to N elements.

    :param arr: input sequence; not modified
    :param key: optional one-argument key function. Items are decorated once
        as (key(x), index, x), so the key is computed once per item and the
        items themselves are never compared.
    :return: a new sorted list
    """
    decorated = key is not None
    items = [(key(x), i, x) for i, x in enumerate(arr)] if decorated else copy(arr)
    out = []
    for item in items:
        pos = bisect_right(out, item)  # after any equal items -> stable
        out.insert(pos, item)
    if decorated:
        return [entry[2] for entry in out]
    return out
def insert_sort_std(arr: List[Any]) -> None:
    """Straight insertion sort of ``arr`` in place. Stable. Ot(N^2) Os(1).

    For each i, finds the slot for arr[i] in the sorted prefix arr[:i] by a
    linear scan from the right (sect_right, so it lands after equal items).
    It then shifts arr[slot:i] one step to the right and drops the item in.
    """
    for i in range(1, len(arr)):
        current = arr[i]
        slot = sect_right(arr, current, 0, i)
        j = i
        while j > slot:  # shift the larger elements one step right
            arr[j] = arr[j - 1]
            j -= 1
        arr[slot] = current
def insert_sort_bi_std(arr: List[Any]) -> None:
    """Binary insertion sort of ``arr`` in place. Stable. Ot(N^2) Os(1).

    Uses bisect_right to find the slot for arr[i] in the sorted prefix
    arr[:i]. The shift of arr[slot:i] is still linear.
    """
    for i in range(1, len(arr)):
        current = arr[i]
        slot = bisect_right(arr, current, 0, i)
        for j in range(i, slot, -1):  # move arr[slot:i] one step to the right
            arr[j] = arr[j - 1]
        arr[slot] = current
# Insertion sort demo, four runs on the same 5000 shuffled ints:
# insert_sort, insert_sort_bi (copying, plain and keyed), then
# insert_sort_std, insert_sort_bi_std (in place).
print("insert_sort")
test_sort(insert_sort)
print("-----------------------")
test_sort(insert_sort_bi)
print("-----------------------")
test_sort_std(insert_sort_std)
print("-----------------------")
test_sort_std(insert_sort_bi_std)
# Recorded output of one run (timings in seconds; they vary by machine).
# The bisect-based variants need far fewer comparisons to find each slot.
# insert_sort
# 0.2563166618347168
# 0.32114648818969727
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [3550, 1500, 3310, 530, 2970, 490, 850, 650, 3690, 3480]
# -----------------------
# 0.008975505828857422
# 0.01197052001953125
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [3550, 1500, 3310, 530, 2970, 490, 850, 650, 3690, 3480]
# -----------------------
# 0.628321647644043
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# -----------------------
# 0.3849670886993408
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def partition(arr: List[Any], lo: int, hi: int) -> int:
    """Partition arr[lo:hi] in place around its first element. Ot(N) Os(1).

    Hole-filling scheme: the pivot value is lifted out of arr[lo], leaving a
    "hole". The right and left scans then take turns moving an out-of-place
    element into the hole. Afterwards, with p the returned index, every item
    of arr[lo:p] is <= arr[p] and every item of arr[p + 1:hi] is >= arr[p].

    :return: p, the final index of the pivot
    """
    value = arr[lo]  # pivot value: the first element of the range
    hi -= 1  # work on the closed range [lo, hi]
    while lo < hi:
        # From the right, find an element < pivot and move it into the hole at lo.
        while lo < hi and arr[hi] >= value:
            hi -= 1
        arr[lo] = arr[hi]
        # From the left, find an element > pivot and move it into the hole at hi.
        while lo < hi and arr[lo] <= value:
            lo += 1
        arr[hi] = arr[lo]
    arr[lo] = value  # lo == hi: the last hole is the pivot's final slot
    return lo


def _quick_sort(arr: List[Any], lo: int, hi: int) -> None:
    """Quick sort arr[lo:hi] in place (not stable), first element as pivot.

    Recurses only into the smaller side of each partition and loops on the
    larger side. This keeps the recursion depth O(LogN) even for maximally
    unbalanced splits (already-sorted input, many equal elements). Those
    splits would otherwise nest about N calls and exceed Python's default
    recursion limit for a few thousand elements.
    """
    while hi - lo > 1:  # ranges with <= 1 element are already sorted
        p = partition(arr, lo, hi)
        if p - lo < hi - (p + 1):
            _quick_sort(arr, lo, p)
            lo = p + 1
        else:
            _quick_sort(arr, p + 1, hi)
            hi = p


def quick_sort(arr: List[Any], *, key: Callable[[Any], Any] = None) -> List[Any]:
    """Quick sort with the first element as pivot; returns a new list.

    Ot: O(NLogN) average, O(N^2) worst (e.g. already-sorted input).
    Os: O(N) for the copy plus O(LogN) recursion.

    Not stable in general. With ``key``, items are decorated once as
    (key(x), index, x). Only keys and the unique indices are ever compared,
    so items need not be orderable (matching ``sorted(key=...)``), and items
    with equal keys come out in input order.

    :param arr: input sequence; not modified
    :param key: optional one-argument function extracting the sort key
    :return: a new sorted list
    """
    items = [(key(x), i, x) for i, x in enumerate(arr)] if key is not None else copy(arr)
    _quick_sort(items, 0, len(items))
    return [item[2] for item in items] if key is not None else items


def quick_sort_std(arr: List[Any]) -> None:
    """Quick sort ``arr`` in place. Not stable. Ot O(NLogN) avg / O(N^2) worst, Os(LogN)."""
    _quick_sort(arr, 0, len(arr))


def mid_partition(arr: List[Any], lo: int, hi: int) -> int:
    """Partition arr[lo:hi] around its middle element. Ot(N) Os(1).

    Swaps the middle element to the front and reuses partition(). This avoids
    the worst case a first-element pivot hits on already-sorted input.
    """
    mid = (lo + hi) // 2
    arr[lo], arr[mid] = arr[mid], arr[lo]
    return partition(arr, lo, hi)


def _quick_sort2(arr: List[Any], lo: int, hi: int) -> None:
    """Quick sort arr[lo:hi] in place (not stable), middle element as pivot.

    Like _quick_sort, this recurses into the smaller side only, so the
    recursion depth stays O(LogN). That matters when many equal elements
    still produce maximally unbalanced splits.
    """
    while hi - lo > 1:  # ranges with <= 1 element are already sorted
        p = mid_partition(arr, lo, hi)
        if p - lo < hi - (p + 1):
            _quick_sort2(arr, lo, p)
            lo = p + 1
        else:
            _quick_sort2(arr, p + 1, hi)
            hi = p


def quick_sort2(arr: List[Any], *, key: Callable[[Any], Any] = None) -> List[Any]:
    """Quick sort with the middle element as pivot; returns a new list.

    Ot: O(NLogN) average, O(N^2) worst. Os: O(N) for the copy plus O(LogN) recursion.
    Not stable in general. ``key`` is handled as in quick_sort, via the
    (key(x), index, x) decoration.

    :param arr: input sequence; not modified
    :param key: optional one-argument function extracting the sort key
    :return: a new sorted list
    """
    items = [(key(x), i, x) for i, x in enumerate(arr)] if key is not None else copy(arr)
    _quick_sort2(items, 0, len(items))
    return [item[2] for item in items] if key is not None else items


def quick_sort_std2(arr: List[Any]) -> None:
    """Quick sort ``arr`` in place, middle pivot. Not stable. Ot O(NLogN) avg, Os(LogN)."""
    _quick_sort2(arr, 0, len(arr))
# Check quick_sort_std2 (middle-element pivot) on 10000 elements in
# descending order, then again on the now-sorted list. Prints both runtimes
# and the first ten items. Note: this rebinds the module-level name `a`.
a = list(range(10000))
a.reverse()
get_runtime(quick_sort_std2, a)
get_runtime(quick_sort_std2, a)
print(a[:10])
# Quick sort demo (first-element pivot): the copying version (plain, then
# key=x % 10), followed by the in-place version.
print("quick_sort")
test_sort(quick_sort)
print("-----------------------")
test_sort_std(quick_sort_std)
# Recorded output of one run (timings in seconds; they vary by machine).
# The keyed line [0, 10, 20, ...] comes from decorating items as (key(x), x):
# ties on the key were broken by comparing x itself, not by input order.
# quick_sort
# 0.008977651596069336
# 0.010937929153442383
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
# -----------------------
# 0.008971929550170898
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def merge(arr: List[Any], lo: int, mid: int, hi: int) -> None:
    """Merge the sorted runs arr[lo:mid] and arr[mid:hi] in place. Ot(N) Os(N).

    Only the left run is buffered (mid - lo extra slots). Stable: on ties the
    element from the left run is written first.

    :param arr: list holding both runs; modified in place
    :param lo: start of the left run
    :param mid: split point, i.e. end of the left run / start of the right run
    :param hi: end (exclusive) of the right run
    """
    left = arr[lo:mid]  # slicing already copies; no extra copy() needed
    n_left = len(left)
    i, j, k = lo, mid, 0  # write index in arr, read index in right run, read index in left
    while k < n_left and j < hi:
        if left[k] <= arr[j]:  # "<=" takes the left element on ties -> stable
            arr[i] = left[k]
            k += 1
        else:  # arr[j] is strictly smaller
            arr[i] = arr[j]
            j += 1
        i += 1
    # Copy back whatever remains of the buffered left run (here i + n_left - k == hi).
    # If the left run ran out first, i == j and arr[j:hi] is already in place.
    arr[i:i + n_left - k] = left[k:]
def merge2(arr: List[Any], lo: int, mid: int, hi: int) -> None:
    """Compact single-loop variant of merge. Ot(N) Os(N).

    Merges the sorted runs arr[lo:mid] and arr[mid:hi] in place. Stable: on
    ties the element from the left run is written first.
    """
    left = arr[lo:mid]  # buffered copy of the left run (slicing already copies)
    n_left = len(left)
    i, j, k = lo, mid, 0  # write index in arr, read index in right run, read index in left
    # Stop as soon as the left run is used up. At that point i == j, so the
    # rest of the right run already sits in its final place.
    while k < n_left:
        # Take from the left run if the right run is exhausted, or on "<=" (ties -> stable).
        if j >= hi or left[k] <= arr[j]:
            arr[i] = left[k]
            k += 1
        else:
            arr[i] = arr[j]
            j += 1
        i += 1
def _merge_sort(arr: List[Any], lo: int, hi: int) -> None:
    """Top-down merge sort of arr[lo:hi] in place. Stable.

    Splits the range at its midpoint, sorts both halves recursively, then
    merges them with merge() (merge2() would work equally well).
    """
    if hi - lo > 1:  # ranges of length 0 or 1 are already sorted
        mid = (lo + hi) // 2
        _merge_sort(arr, lo, mid)
        _merge_sort(arr, mid, hi)
        merge(arr, lo, mid, hi)
def merge_sort(arr: List[Any], *, key: Callable[[Any], Any] = None) -> List[Any]:
    """Merge sort returning a new list. Stable. Ot(NLogN) Os(N).

    :param arr: input sequence; not modified
    :param key: optional one-argument key function. Items are decorated once
        as (key(x), index, x), so keys are computed once and ties never fall
        through to comparing the items themselves.
    :return: a new sorted list
    """
    if key is None:
        work = copy(arr)
        _merge_sort(work, 0, len(work))
        return work
    work = [(key(x), i, x) for i, x in enumerate(arr)]
    _merge_sort(work, 0, len(work))
    return [entry[2] for entry in work]
def merge_sort_std(arr: List[Any]) -> None:
    """Merge sort ``arr`` in place. Stable. Ot(NLogN), Os(N) for the merge buffer."""
    length = len(arr)
    _merge_sort(arr, 0, length)
# Merge sort demo: the copying version (plain, then key=x % 10), followed by
# the in-place version, all on the same 5000 shuffled ints.
print("merge_sort")
test_sort(merge_sort)
print("-----------------------")
test_sort_std(merge_sort_std)
# Recorded output of one run (timings in seconds; they vary by machine).
# The keyed line keeps the shuffled input order of equal keys (stable).
# merge_sort
# 0.014959573745727539
# 0.017950773239135742
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [3550, 1500, 3310, 530, 2970, 490, 850, 650, 3690, 3480]
# -----------------------
# 0.015954256057739258
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]