# radix_sort.bend
# Sorts a tree of u24 keys by inserting each key into a 24-level binary trie
# (one level per bit) and reading the trie back from left to right.
# Binary trie over u24 keys, built by `radix`/`merge` and read back by `to_arr`.
#   Free - no key is stored in this subtree.
#   Used - a key ends here (reached after descending one level per bit).
#   Both - inner node: `a` is the branch for a 0 bit, `b` for a 1 bit
#          (`swap` puts the path in `a` when the bit is 0; `to_arr` maps
#          `a` to prefix*2 and `b` to prefix*2+1).
# The `~` marks recursive fields for Bend's `fold`.
type MyMap:
  Free
  Used
  Both { ~a: MyMap, ~b: MyMap }
# Binary tree of values, read from left to right.
#   Null - empty tree.
#   Leaf - a single value `a`.
#   Node - concatenation of the subtrees `a` (left) and `b` (right).
# The `~` marks recursive fields for Bend's `fold`.
type Arr(t):
  Null
  Leaf { a: t }
  Node { ~a: Arr(t), ~b: Arr(t) }
# Builds an inner trie node from `a` and `b`, ordered by the selector `s`.
# When s is zero, `a` goes on the left; any non-zero s puts `a` on the right.
def swap(s: u24, a: MyMap, b: MyMap) -> MyMap:
  if s == 0:
    return MyMap/Both{ a: a, b: b }
  else:
    return MyMap/Both{ a: b, b: a }
# Returns the keys of `t` in ascending order.
# Every key is inserted into a trie and the trie is read back from the root
# (prefix 0). Equal keys land on the same `Used` node, so duplicates appear
# only once in the result.
def sort(t: Arr(u24)) -> Arr(u24):
  trie = to_map(t)
  return to_arr(0, trie)
# Converts a tree of u24 keys into a trie.
# Null becomes an empty trie, each leaf key becomes a single-path trie via
# `radix`, and the tries of sibling subtrees are unioned with `merge`.
def to_map(t: Arr(u24)) -> MyMap:
  match t:
    case Arr/Null:
      return MyMap/Free
    case Arr/Leaf:
      key = t.a
      return radix(key)
    case Arr/Node:
      left = to_map(t.a)
      right = to_map(t.b)
      return merge(left, right)
# Reads the trie `m` back into a tree of keys, in left-to-right order.
# `x` is the key prefix accumulated along the path from the root:
#   - a step into `a` appends a 0 bit (x * 2),
#   - a step into `b` appends a 1 bit (x * 2 + 1).
# A `Used` node yields a Leaf with the reconstructed key; `Free` yields Null.
def to_arr(x: u24, m: MyMap) -> Arr(u24):
  match m:
    case MyMap/Free:
      return Arr/Null
    case MyMap/Used:
      return Arr/Leaf{ a: x }
    case MyMap/Both:
      zero_prefix = x * 2
      one_prefix = x * 2 + 1
      left = to_arr(zero_prefix, m.a)
      right = to_arr(one_prefix, m.b)
      return Arr/Node{ a: left, b: right }
# Unions two tries.
#   - `Free` is the identity: merging with it returns the other side.
#   - `Used` on either side wins, so the same key inserted twice collapses
#     into one entry.
#   - Two inner nodes are merged branch by branch.
def merge(a: MyMap, b: MyMap) -> MyMap:
  match b:
    case MyMap/Free:
      return a
    case MyMap/Used:
      return MyMap/Used
    case MyMap/Both:
      match a:
        case MyMap/Free:
          return b
        case MyMap/Used:
          return MyMap/Used
        case MyMap/Both:
          return MyMap/Both{ a: merge(a.a, b.a), b: merge(a.b, b.b) }
# Encodes the key `n` as a single-path trie.
# There is one `Both` level per bit of `n`, with `Used` at the bottom.
# At each level `swap` puts the path on the left (`a`) when the bit is 0
# and on the right (`b`) when it is 1; the other side is `Free`.
# Bits are wrapped from least to most significant, so the level just above
# `Used` is bit 0 and the root level is bit 23.
# This function handles bits 0-9 and continues in `radix2`.
def radix(n: u24) -> MyMap:
  r = MyMap/Used
  # Order matters: each swap wraps the result of the previous (lower) bit.
  r = swap(n & 1, r, MyMap/Free)
  r = swap(n & 2, r, MyMap/Free)
  r = swap(n & 4, r, MyMap/Free)
  r = swap(n & 8, r, MyMap/Free)
  r = swap(n & 16, r, MyMap/Free)
  r = swap(n & 32, r, MyMap/Free)
  r = swap(n & 64, r, MyMap/Free)
  r = swap(n & 128, r, MyMap/Free)
  r = swap(n & 256, r, MyMap/Free)
  r = swap(n & 512, r, MyMap/Free)
  return radix2(n, r)
# At the moment, very large functions must be split manually into smaller
# ones for this program to run on the GPU. A future version of Bend is
# expected to do this automatically.
#
# Second stage of `radix`: wraps `r` with the levels for bits 10-19 of `n`,
# then passes the result to `radix3`.
def radix2(n: u24, r: MyMap) -> MyMap:
  r = swap(n & 1024, r, MyMap/Free)
  r = swap(n & 2048, r, MyMap/Free)
  r = swap(n & 4096, r, MyMap/Free)
  r = swap(n & 8192, r, MyMap/Free)
  r = swap(n & 16384, r, MyMap/Free)
  r = swap(n & 32768, r, MyMap/Free)
  r = swap(n & 65536, r, MyMap/Free)
  r = swap(n & 131072, r, MyMap/Free)
  r = swap(n & 262144, r, MyMap/Free)
  r = swap(n & 524288, r, MyMap/Free)
  return radix3(n, r)
# Final stage of `radix`: wraps `r` with the levels for bits 20-23, the top
# bits of a u24. The result is the complete 24-level path for `n`, with the
# bit-23 level at the root.
def radix3(n: u24, r: MyMap) -> MyMap:
  r = swap(n & 1048576, r, MyMap/Free)
  r = swap(n & 2097152, r, MyMap/Free)
  r = swap(n & 4194304, r, MyMap/Free)
  r = swap(n & 8388608, r, MyMap/Free)
  return r
# Mirrors the tree by swapping the two children of every Node, recursively.
# The elements therefore come out in the opposite left-to-right order.
# Null and Leaf are returned unchanged.
def reverse(t: Arr(u24)) -> Arr(u24):
  match t:
    case Arr/Null:
      return Arr/Null
    case Arr/Leaf:
      return t
    case Arr/Node:
      new_left = reverse(t.b)
      new_right = reverse(t.a)
      return Arr/Node{ a: new_left, b: new_right }
# Adds up every value stored in the tree using u24 addition.
# An empty tree (Null) contributes 0.
def sum(t: Arr(u24)) -> u24:
  match t:
    case Arr/Null:
      return 0
    case Arr/Leaf:
      return t.a
    case Arr/Node:
      left_total = sum(t.a)
      right_total = sum(t.b)
      return left_total + right_total
# Builds a complete tree of depth `n`.
# Its leaves, read from left to right, are 0, 1, ..., 2^n - 1.
def gen(n: u24) -> Arr(u24):
  start = 0
  return gen_go(n, start)
# Helper for `gen`: builds a complete tree of depth `n` below the prefix `x`.
# Each level appends one bit (0 for the left child, 1 for the right child).
# At depth 0 the accumulated prefix is stored in a Leaf.
def gen_go(n: u24, x: u24) -> Arr(u24):
  if n == 0:
    return Arr/Leaf{ a: x }
  else:
    remaining = n - 1
    left = gen_go(remaining, x * 2)
    right = gen_go(remaining, x * 2 + 1)
    return Arr/Node{ a: left, b: right }
# Program entry point.
# `gen(4)` builds the keys 0..15, `reverse` turns them into 15..0, `sort`
# restores ascending order, and the sorted keys are summed.
# The original line mixed Fun-syntax application (`(sum ...)`, `gen 4`) with
# imp-syntax calls, and `Main: u24 = ...` is not a valid definition form.
# It is rewritten here as an imp-syntax `def`, keeping the `Main` name.
def Main() -> u24:
  return sum(sort(reverse(gen(4))))