-
Notifications
You must be signed in to change notification settings - Fork 121
/
lattice_attack.py
156 lines (133 loc) · 5.66 KB
/
lattice_attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import sys
from sage.all import QQ
from sage.all import ZZ
from sage.all import matrix
from sage.all import vector
# Make the directory two levels above the one containing this file importable (presumably the
# repository root), so the project-local `shared` package can be imported.
path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(os.path.abspath(__file__)))))
# Indexing sys.path[1] directly raises IndexError when sys.path has fewer than two entries
# (e.g. embedded interpreters or a caller that trimmed it), so guard the length first.
if len(sys.path) < 2 or sys.path[1] != path:
    sys.path.insert(1, path)
from shared.lattice import shortest_vectors
def attack(a, b, m, X):
    """
    Solves the hidden number problem using an attack based on the shortest vector problem.
    The hidden number problem is defined as finding y such that {xi = {aij * yj} + bi mod m}.

    The lattice basis B (one basis vector per row) is:
      - rows 0..n1-1: m on the diagonal (reduction of each equation modulo m),
      - rows n1..n1+n2-1: the coefficients a[i][j] of y_j, plus X / m on the diagonal, so the
        corresponding coordinate of a lattice vector equals y_j * X / m,
      - row n1+n2: the centred constants b_i - X // 2, plus X as an embedding marker.
    A target vector with embedding coordinate X therefore encodes (x_i - X // 2, y_j * X / m, X).

    :param a: the aij values (a list of n1 lists, each of length n2)
    :param b: the bi values
    :param m: the modulus
    :param X: a bound on the xi values
    :return: a generator generating tuples containing a list of xi values and a list of yj values
    """
    assert len(a) == len(b), "a and b lists should be of equal length."
    n1 = len(a)
    n2 = len(a[0])
    B = matrix(QQ, n1 + n2 + 1, n1 + n2 + 1)
    for i in range(n1):
        for j in range(n2):
            B[n1 + j, i] = a[i][j]
        B[i, i] = m
        # Centre the unknown xi around 0 so the target vector is as short as possible.
        B[n1 + n2, i] = b[i] - X // 2
    for j in range(n2):
        B[n1 + j, n1 + j] = X / QQ(m)
    B[n1 + n2, n1 + n2] = X
    for v in shortest_vectors(B):
        # A lattice vector and its negation are equally short, so the solution may show up with
        # embedding coordinate -X; normalise the sign instead of discarding it.
        if v[n1 + n2] == -X:
            v = -v
        if v[n1 + n2] != X:
            continue
        xs = [int(v[i] + X // 2) for i in range(n1)]
        ys = [(int(v[n1 + j] * m) // X) % m for j in range(n2)]
        if all(y != 0 for y in ys):
            yield xs, ys
def dsa_known_msb(n, h, r, s, k):
    """
    Recovers the (EC)DSA private key and nonces if the most significant nonce bits are known.

    Every signature satisfies k_i = s_i^-1 * (h_i + r_i * d) (mod n). Writing the nonce as
    k_i = msb_i * 2^u + x_i, where u is the number of unknown low bits, the unknown part
    x_i = (s_i^-1 * r_i) * d + (s_i^-1 * h_i - msb_i * 2^u) (mod n) is bounded by 2^u.
    This is a hidden number problem in d, which `attack` solves.

    :param n: the modulus
    :param h: a list containing the hashed messages
    :param r: a list containing the r values
    :param s: a list containing the s values
    :param k: a list containing the partial nonces (PartialIntegers)
    :return: a generator generating tuples containing the possible private key and a list of nonces
    """
    assert len(h) == len(r) == len(s) == len(k), "h, r, s, and k lists should be of equal length."
    coefficients = []
    constants = []
    bound = 0
    for h_i, r_i, s_i, k_i in zip(h, r, s, k):
        known_msb, _ = k_i.get_known_msb()
        scale = 2 ** k_i.get_unknown_lsb()
        s_inv = pow(s_i, -1, n)
        coefficients.append([(s_inv * r_i) % n])
        constants.append((s_inv * h_i - scale * known_msb) % n)
        if scale > bound:
            bound = scale
    for unknown_parts, key in attack(coefficients, constants, n, bound):
        # Fill each partial nonce's unknown low bits with the recovered value.
        nonces = [k_i.sub([part]) for k_i, part in zip(k, unknown_parts)]
        yield key[0], nonces
def dsa_known_lsb(n, h, r, s, k):
    """
    Recovers the (EC)DSA private key and nonces if the least significant nonce bits are known.

    Every signature satisfies k_i = s_i^-1 * (h_i + r_i * d) (mod n). Writing the nonce as
    k_i = x_i * 2^m + lsb_i, where m is the number of known low bits, the unknown high part is
    x_i = 2^-m * s_i^-1 * r_i * d + 2^-m * (s_i^-1 * h_i - lsb_i) (mod n).
    This hidden number problem in d, with x_i bounded by 2^(unknown msb count), is solved by `attack`.

    :param n: the modulus
    :param h: a list containing the hashed messages
    :param r: a list containing the r values
    :param s: a list containing the s values
    :param k: a list containing the partial nonces (PartialIntegers)
    :return: a generator generating tuples containing the possible private key and a list of nonces
    """
    assert len(h) == len(r) == len(s) == len(k), "h, r, s, and k lists should be of equal length."
    coefficients = []
    constants = []
    bound = 0
    for h_i, r_i, s_i, k_i in zip(h, r, s, k):
        known_lsb, known_bits = k_i.get_known_lsb()
        # 2^-m mod n shifts the unknown high part down to bit 0.
        inv_scale = pow(2 ** known_bits, -1, n)
        s_inv = pow(s_i, -1, n)
        coefficients.append([(inv_scale * s_inv * r_i) % n])
        constants.append((inv_scale * s_inv * h_i - inv_scale * known_lsb) % n)
        bound = max(bound, 2 ** k_i.get_unknown_msb())
    for unknown_parts, key in attack(coefficients, constants, n, bound):
        yield key[0], [k_i.sub([part]) for k_i, part in zip(k, unknown_parts)]
def dsa_known_middle(n, h1, r1, s1, k1, h2, r2, s2, k2):
    """
    Recovers the (EC)DSA private key and nonces if the middle nonce bits are known.
    This is a heuristic extension which might perform worse than the methods to solve the Extended Hidden Number Problem.
    More information: De Micheli G., Heninger N., "Recovering cryptographic keys from partial information, by example" (Section 5.2.3)

    Each nonce is modelled as k_i = a_i + x_i + 2^l * y_i, where:
      - a_i is the known middle part, shifted into position,
      - x_i holds the unknown least significant bits,
      - y_i holds the unknown most significant bits.
    Eliminating the private key from the two signature equations gives the linear relation
        f(x1, y1, x2, y2) = x1 + 2^l * y1 + t * x2 + 2^l * t * y2 + u_ = 0 (mod n).
    Its small root is found by lattice reduction: each short reduced vector is heuristically
    treated as an integer linear relation satisfied by the root, and four such relations are
    solved as a linear system.

    Both nonces must have the same bit length and the same number of unknown low/high bits
    (checked with assertions).

    :param n: the modulus
    :param h1: the first hashed message
    :param r1: the first r value
    :param s1: the first s value
    :param k1: the first partial nonce (PartialInteger)
    :param h2: the second hashed message
    :param r2: the second r value
    :param s2: the second s value
    :param k2: the second partial nonce (PartialInteger)
    :return: a tuple containing the private key, the nonce of the first signature, and the nonce of the second signature
    """
    k_bit_length = k1.bit_length
    assert k_bit_length == k2.bit_length
    lsb_unknown = k1.get_unknown_lsb()
    assert lsb_unknown == k2.get_unknown_lsb()
    msb_unknown = k1.get_unknown_msb()
    assert msb_unknown == k2.get_unknown_msb()
    # Common scale for the unknowns; presumably bounds both x_i < 2^lsb_unknown and y_i < 2^msb_unknown.
    K = 2 ** max(lsb_unknown, msb_unknown)
    # Bit position at which the unknown most significant bits start.
    l = k_bit_length - msb_unknown
    # Known middle bits, shifted past the unknown least significant bits.
    a1 = k1.get_known_middle()[0] << lsb_unknown
    a2 = k2.get_known_middle()[0] << lsb_unknown
    # From d = r1^-1 * (s1 * k1 - h1) = r2^-1 * (s2 * k2 - h2) it follows that k1 + t * k2 + u = 0 (mod n).
    t = -(pow(s1, -1, n) * s2 * r1 * pow(r2, -1, n))
    u = pow(s1, -1, n) * r1 * h2 * pow(r2, -1, n) - pow(s1, -1, n) * h1
    # Fold the known parts a1, a2 into the constant term of f.
    u_ = a1 + t * a2 + u
    # Basis rows: f's coefficients (scaled by K in the four unknown positions), n-multiples of
    # the y1, x2, y2 monomials (also scaled by K), and n for the constant term.
    B = matrix(ZZ, 5, 5)
    B[0] = vector(ZZ, [K, K * 2 ** l, K * t, K * t * 2 ** l, u_])
    B[1] = vector(ZZ, [0, K * n, 0, 0, 0])
    B[2] = vector(ZZ, [0, 0, K * n, 0, 0])
    B[3] = vector(ZZ, [0, 0, 0, K * n, 0])
    B[4] = vector(ZZ, [0, 0, 0, 0, n])
    A = matrix(ZZ, 4, 4)
    b = []
    # Take the first four reduced vectors. Every basis entry in the first four columns is a
    # multiple of K, so dividing those coordinates by K gives exact integer coefficients; the
    # last coordinate is the constant term, moved to the right-hand side.
    for row, v in enumerate(shortest_vectors(B)):
        A[row] = v[:4].apply_map(lambda x: x // K)
        b.append(-v[4])
        if row == A.nrows() - 1:
            break
    # Fails if shortest_vectors yielded fewer than four vectors.
    assert len(b) == 4
    x1, y1, x2, y2 = A.solve_right(vector(ZZ, b))
    # Sanity check: the recovered values must be a root of f modulo n.
    assert (x1 + 2 ** l * y1 + t * x2 + 2 ** l * t * y2 + u_) % n == 0
    # NOTE(review): these rebind the parameters. This assumes PartialInteger.sub fills the
    # unknown parts in (lsb, msb) order and returns the complete nonce; confirm against PartialInteger.
    k1 = k1.sub([int(x1), int(y1)])
    k2 = k2.sub([int(x2), int(y2)])
    # Both signatures must yield the same private key.
    private_key1 = (pow(r1, -1, n) * (s1 * k1 - h1)) % n
    private_key2 = (pow(r2, -1, n) * (s2 * k2 - h2)) % n
    assert private_key1 == private_key2
    return int(private_key1), int(k1), int(k2)