bat.examples.simba_attack_deepapi

 1import validators
 2
 3import numpy as np
 4np.set_printoptions(suppress=True)
 5
 6from PIL import Image
 7
 8from bat.apis.deepapi import bat_deepapi_model_list
 9from bat.attacks import SimBA
10
def simba_attack_deepapi():
    """Interactively run a SimBA attack against a DeepAPI-hosted model.

    Lists the entries of ``bat_deepapi_model_list`` (numbered from 1), then
    prompts on stdin for a model index, a DeepAPI server URL and an image
    file. It prints the model's prediction for the image, runs ``SimBA`` on
    it, prints the prediction for the adversarial image, and saves that image
    as ``result.jpg`` in the current working directory.

    Returns ``None`` early (after printing a message) if the image cannot be
    loaded, if the model returns no prediction for the original image, or if
    the user presses Ctrl-C.
    """
    model_count = len(bat_deepapi_model_list)
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        print(i, ':', model[0])

    try:
        # Get the model type. ``index`` is always an int from here on; a typed
        # "1" used to stay a str, so the ``index == 1`` resize check missed it.
        answer = input("Please input the model index (default: 1): ")
        if len(answer) == 0:
            index = 1
        else:
            # The menu is numbered from 1, so reject 0 as well as values past
            # the end. isdecimal() (unlike isdigit()) guarantees int() succeeds.
            while not answer.isdecimal() or not 1 <= int(answer) <= model_count:
                answer = input(f"Model [{answer}] does not exist. Please try again: ")
            index = int(answer)

        # Get the DeepAPI server url
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if len(deepapi_url) == 0:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file
        try:
            file = input("Please input the image file: ")
            while len(file) == 0:
                file = input("Please input the image file: ")
            # The context manager closes the underlying file handle;
            # convert() returns an independent copy that stays usable.
            with Image.open(file) as img:
                image = img.convert('RGB')
            if index == 1:
                # NOTE(review): presumably model 1 expects 32x32 inputs — confirm
                # against bat_deepapi_model_list.
                image = image.resize((32, 32))
            # Prepend a batch axis: (1, height, width, 3) uint8 pixels.
            x = np.array([np.array(image)])
        except Exception as e:
            # Top-level CLI boundary: report the load error and stop.
            print(e)
            return

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions. Without a prediction there is no label to attack.
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            print('Failed to get a prediction from DeepAPI.')
            return
        label = np.argmax(y_pred)
        deepapi_model.print(y_pred)
        print('Prediction', label, deepapi_model.get_class_name(label))
        print()

        # SimBA Attack
        simba = SimBA(deepapi_model)
        x_adv = simba.attack(x, label, epsilon=0.05, max_it=3000, concurrency=4)

        # Print result after attack
        y_adv = deepapi_model.predict(x_adv)[0]
        if y_adv is not None:
            adv_label = np.argmax(y_adv)
            deepapi_model.print(y_adv)
            print('Prediction', adv_label, deepapi_model.get_class_name(adv_label))
            print()
        else:
            print('Failed to get a prediction for the adversarial image.')

        # Save image. Clip before the uint8 cast so out-of-range values saturate
        # instead of wrapping around. Assumes x_adv is on the same 0-255 pixel
        # scale as x — TODO confirm against SimBA.attack.
        pixels = np.clip(x_adv[0], 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except KeyboardInterrupt:
        # Ctrl-C: end the prompt line cleanly and exit quietly.
        print()
        return
def simba_attack_deepapi():
    """Interactively run a SimBA attack against a DeepAPI-hosted model.

    Lists the entries of ``bat_deepapi_model_list`` (numbered from 1), then
    prompts on stdin for a model index, a DeepAPI server URL and an image
    file. It prints the model's prediction for the image, runs ``SimBA`` on
    it, prints the prediction for the adversarial image, and saves that image
    as ``result.jpg`` in the current working directory.

    Returns ``None`` early (after printing a message) if the image cannot be
    loaded, if the model returns no prediction for the original image, or if
    the user presses Ctrl-C.
    """
    model_count = len(bat_deepapi_model_list)
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        print(i, ':', model[0])

    try:
        # Get the model type. ``index`` is always an int from here on; a typed
        # "1" used to stay a str, so the ``index == 1`` resize check missed it.
        answer = input("Please input the model index (default: 1): ")
        if len(answer) == 0:
            index = 1
        else:
            # The menu is numbered from 1, so reject 0 as well as values past
            # the end. isdecimal() (unlike isdigit()) guarantees int() succeeds.
            while not answer.isdecimal() or not 1 <= int(answer) <= model_count:
                answer = input(f"Model [{answer}] does not exist. Please try again: ")
            index = int(answer)

        # Get the DeepAPI server url
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if len(deepapi_url) == 0:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file
        try:
            file = input("Please input the image file: ")
            while len(file) == 0:
                file = input("Please input the image file: ")
            # The context manager closes the underlying file handle;
            # convert() returns an independent copy that stays usable.
            with Image.open(file) as img:
                image = img.convert('RGB')
            if index == 1:
                # NOTE(review): presumably model 1 expects 32x32 inputs — confirm
                # against bat_deepapi_model_list.
                image = image.resize((32, 32))
            # Prepend a batch axis: (1, height, width, 3) uint8 pixels.
            x = np.array([np.array(image)])
        except Exception as e:
            # Top-level CLI boundary: report the load error and stop.
            print(e)
            return

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions. Without a prediction there is no label to attack.
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            print('Failed to get a prediction from DeepAPI.')
            return
        label = np.argmax(y_pred)
        deepapi_model.print(y_pred)
        print('Prediction', label, deepapi_model.get_class_name(label))
        print()

        # SimBA Attack
        simba = SimBA(deepapi_model)
        x_adv = simba.attack(x, label, epsilon=0.05, max_it=3000, concurrency=4)

        # Print result after attack
        y_adv = deepapi_model.predict(x_adv)[0]
        if y_adv is not None:
            adv_label = np.argmax(y_adv)
            deepapi_model.print(y_adv)
            print('Prediction', adv_label, deepapi_model.get_class_name(adv_label))
            print()
        else:
            print('Failed to get a prediction for the adversarial image.')

        # Save image. Clip before the uint8 cast so out-of-range values saturate
        # instead of wrapping around. Assumes x_adv is on the same 0-255 pixel
        # scale as x — TODO confirm against SimBA.attack.
        pixels = np.clip(x_adv[0], 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except KeyboardInterrupt:
        # Ctrl-C: end the prompt line cleanly and exit quietly.
        print()
        return