forked from NeuralNetworkVerification/Marabou
-
Notifications
You must be signed in to change notification settings - Fork 0
/
testmarabouonl1.py
70 lines (56 loc) · 2.24 KB
/
testmarabouonl1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from fileinput import filename
from maraboupy import Marabou
from test_utils import *
from maraboupy import MarabouCore
def _add_input_bounds(network, input_vars, img, delta):
    """Pin the first input copy to ``img`` and box the second copy within +/- ``delta``."""
    # The input vector holds two flattened copies of the image back to back:
    # pixel (h, w) of the fixed copy is at ``width * h + w`` and the matching
    # pixel of the perturbed copy is ``height * width`` entries later.
    # NOTE(review): assumes the ONNX model flattens each copy row-major with
    # the same height/width as ``img`` (the original hard-coded 28x28).
    height, width = img.shape[0], img.shape[1]
    offset = height * width
    for h in range(height):
        for w in range(width):
            pixel = img[h][w]
            idx = width * h + w
            network.setLowerBound(input_vars[idx], pixel)
            network.setUpperBound(input_vars[idx], pixel)
            network.setLowerBound(input_vars[offset + idx], pixel - delta)
            network.setUpperBound(input_vars[offset + idx], pixel + delta)


def _add_label_disjunctions(network, output_vars, lab, b, num_classes):
    """Add ``(out[0] >= b) OR (out[lab+1] <= out[c+1])`` for every class ``c != lab``."""
    true_idx = lab + 1  # output 0 is the L1-norm output; class c is at index c + 1
    l1eq = MarabouCore.Equation(MarabouCore.Equation.GE)
    l1eq.addAddend(1, output_vars[0])
    l1eq.setScalar(b)  # out[0] >= b (non-strict)
    # NOTE(review): one disjunction is added per competing class; if the solver
    # conjoins them, the query is "out[0] >= b OR every other class scores at
    # least as high as the true label" -- confirm this is the intended property.
    for i in range(1, num_classes + 1):
        if i == true_idx:
            continue
        # out[true] - out[i] <= 0  <=>  class i scores at least as high as the label.
        ineq = MarabouCore.Equation(MarabouCore.Equation.LE)
        ineq.addAddend(1, output_vars[true_idx])
        ineq.addAddend(-1, output_vars[i])
        ineq.setScalar(0)
        network.addDisjunctionConstraint([[l1eq], [ineq]])


def check(filename, num_imgs, delta=0.03, b=1, num_classes=10):
    """Run the L1 / L-infinity Marabou query on ``num_imgs`` images and report results.

    For each image from ``get_image()`` the ONNX model is re-read, so the
    constraints of the previous query do not carry over. The first input copy
    is pinned to the image and the second copy is bounded to an L-infinity
    box of radius ``delta``. One disjunction per competing class is then
    added before solving.

    Args:
        filename: path to the combined ONNX model.
        num_imgs: number of images (queries) to run.
        delta: L-infinity radius applied to the second input copy.
        b: threshold for the ``out[0] >= b`` disjunct.
        num_classes: number of class scores, stored at output indices
            ``1 .. num_classes``.

    Prints progress and, finally, how many queries came back ``"unsat"``.
    """
    # Solver options are identical for every query, so build them once.
    options = Marabou.createOptions(verbosity=0)
    good = 0
    for _ in range(num_imgs):
        print("getting network..")
        network = Marabou.read_onnx(filename)  # fresh copy == cleared constraints
        input_vars = network.inputVars[0][0]
        output_vars = network.outputVars[0]
        print("got network!")

        print("adding infinity norms..")
        img, lab = get_image()
        _add_input_bounds(network, input_vars, img, delta)
        print("added infinity norms!")

        print("adding other bounds..", flush=True)
        _add_label_disjunctions(network, output_vars, lab, b, num_classes)
        print("adding all bounds!", flush=True)

        print("checking sat..")
        result = network.solve(options=options)
        # The first element of the solve result is compared against the string "unsat".
        if result[0] == "unsat":
            good += 1
    print("verified correctly: {}/{}".format(good, num_imgs))
if __name__ == "__main__":
    # Smoke run: one image, zero L-infinity radius and a zero L1 threshold.
    model_path, image_count = "l1model_combined_v2.onnx", 1
    check(model_path, image_count, delta=0.0, b=0.0)