bat.examples.square_attack_deepapi

 1import validators
 2
 3import numpy as np
 4from PIL import Image
 5
 6from bat.apis.deepapi import bat_deepapi_model_list
 7from bat.attacks.square_attack import SquareAttack
 8
 9def dense_to_onehot(y, n_classes):
10    y_onehot = np.zeros([len(y), n_classes], dtype=bool)
11    y_onehot[np.arange(len(y)), y] = True
12    return y_onehot
13
def square_attack_deepapi():
    """Interactively run SquareAttack against a remote DeepAPI model.

    Lists the entries of ``bat_deepapi_model_list`` and prompts on stdin for
    the model index (default 1), the DeepAPI server URL (default
    ``http://localhost:8080``) and a local image file. It then sends the image
    to the selected model and prints the model's top-1 prediction. That
    prediction becomes the label passed to ``SquareAttack``. The adversarial
    image is written to ``result.jpg`` in the current working directory.

    Returns:
        None. Any exception raised along the way is printed together with its
        type name, and the function returns instead of propagating it.
    """
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        print(i, ':', model[0])

    try:
        # Get the model type. Keep `index` an int on every path so the
        # `index == 1` resize check below also fires for a typed-in "1".
        index = input("Please input the model index (default: 1): ")
        if len(index) == 0:
            index = 1
        else:
            # Re-prompt until the input is a decimal number that is actually a
            # key of the model table, so the lookup below cannot fail (an
            # upper-bound-only check would accept "0").
            while not index.isdecimal() or int(index) not in bat_deepapi_model_list:
                index = input(f"Model [{index}] does not exist. Please try again: ")
            index = int(index)

        # Get the DeepAPI server url
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if len(deepapi_url) == 0:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file
        file = input("Please input the image file: ")
        while len(file) == 0:
            file = input("Please input the image file: ")
        # The `with` block closes the handle opened by Image.open; convert()
        # returns an independent copy that stays usable afterwards.
        with Image.open(file) as img:
            image = img.convert('RGB')

        # NOTE(review): model 1 presumably expects 32x32 inputs — confirm.
        if index == 1:
            image = image.resize((32, 32))

        # Add a leading batch axis: (H, W, 3) -> (1, H, W, 3).
        x = np.expand_dims(np.array(image), axis=0)

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            # np.argmax / len below cannot work without a prediction.
            print("No prediction was returned by the DeepAPI model.")
            return

        deepapi_model.print(y_pred)
        label = np.argmax(y_pred)
        print('Prediction', label, deepapi_model.get_class_name(label))
        print()

        # One-hot encode the model's own top-1 prediction as the attack label.
        y_target_onehot = dense_to_onehot(np.array([label]), n_classes=len(y_pred))

        square_attack = SquareAttack(deepapi_model)
        # NOTE(review): the positional False presumably selects the untargeted
        # mode of SquareAttack.attack — confirm against its signature.
        x_adv, _ = square_attack.attack(x, y_target_onehot, False, epsilon=0.05, max_it=3000, concurrency=8)

        # Round and clip before casting. A bare astype(np.uint8) truncates
        # fractional values and gives wrapped/undefined results outside [0, 255].
        adv_image = np.clip(np.rint(x_adv[0]), 0, 255).astype(np.uint8)
        # NOTE(review): JPEG is lossy even at quality=100, so part of the
        # perturbation may not survive saving; PNG would preserve it exactly.
        Image.fromarray(adv_image).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except Exception as e:
        # Top-level CLI boundary: report the failure instead of a traceback.
        # The type name matters, e.g. a KeyError's message alone is just the key.
        print(f"{type(e).__name__}: {e}")
def dense_to_onehot(y, n_classes):
    """Convert integer class labels into a boolean one-hot matrix.

    Args:
        y: Sequence or array of integer class indices, one per sample.
        n_classes: Number of columns (classes) in the output matrix.

    Returns:
        ``np.ndarray`` of dtype ``bool`` with shape ``(len(y), n_classes)``;
        row ``i`` is all ``False`` except for column ``y[i]``. Labels are used
        as NumPy indices, so a label ``>= n_classes`` raises ``IndexError``.
    """
    num_samples = len(y)
    onehot = np.zeros((num_samples, n_classes), dtype=bool)
    # Pair each row index with its label and flip exactly that cell.
    rows = np.arange(num_samples)
    onehot[rows, y] = True
    return onehot
def square_attack_deepapi():
    """Interactively run SquareAttack against a remote DeepAPI model.

    Lists the entries of ``bat_deepapi_model_list`` and prompts on stdin for
    the model index (default 1), the DeepAPI server URL (default
    ``http://localhost:8080``) and a local image file. It then sends the image
    to the selected model and prints the model's top-1 prediction. That
    prediction becomes the label passed to ``SquareAttack``. The adversarial
    image is written to ``result.jpg`` in the current working directory.

    Returns:
        None. Any exception raised along the way is printed together with its
        type name, and the function returns instead of propagating it.
    """
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        print(i, ':', model[0])

    try:
        # Get the model type. Keep `index` an int on every path so the
        # `index == 1` resize check below also fires for a typed-in "1".
        index = input("Please input the model index (default: 1): ")
        if len(index) == 0:
            index = 1
        else:
            # Re-prompt until the input is a decimal number that is actually a
            # key of the model table, so the lookup below cannot fail (an
            # upper-bound-only check would accept "0").
            while not index.isdecimal() or int(index) not in bat_deepapi_model_list:
                index = input(f"Model [{index}] does not exist. Please try again: ")
            index = int(index)

        # Get the DeepAPI server url
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if len(deepapi_url) == 0:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file
        file = input("Please input the image file: ")
        while len(file) == 0:
            file = input("Please input the image file: ")
        # The `with` block closes the handle opened by Image.open; convert()
        # returns an independent copy that stays usable afterwards.
        with Image.open(file) as img:
            image = img.convert('RGB')

        # NOTE(review): model 1 presumably expects 32x32 inputs — confirm.
        if index == 1:
            image = image.resize((32, 32))

        # Add a leading batch axis: (H, W, 3) -> (1, H, W, 3).
        x = np.expand_dims(np.array(image), axis=0)

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            # np.argmax / len below cannot work without a prediction.
            print("No prediction was returned by the DeepAPI model.")
            return

        deepapi_model.print(y_pred)
        label = np.argmax(y_pred)
        print('Prediction', label, deepapi_model.get_class_name(label))
        print()

        # One-hot encode the model's own top-1 prediction as the attack label.
        y_target_onehot = dense_to_onehot(np.array([label]), n_classes=len(y_pred))

        square_attack = SquareAttack(deepapi_model)
        # NOTE(review): the positional False presumably selects the untargeted
        # mode of SquareAttack.attack — confirm against its signature.
        x_adv, _ = square_attack.attack(x, y_target_onehot, False, epsilon=0.05, max_it=3000, concurrency=8)

        # Round and clip before casting. A bare astype(np.uint8) truncates
        # fractional values and gives wrapped/undefined results outside [0, 255].
        adv_image = np.clip(np.rint(x_adv[0]), 0, 255).astype(np.uint8)
        # NOTE(review): JPEG is lossy even at quality=100, so part of the
        # perturbation may not survive saving; PNG would preserve it exactly.
        Image.fromarray(adv_image).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except Exception as e:
        # Top-level CLI boundary: report the failure instead of a traceback.
        # The type name matters, e.g. a KeyError's message alone is just the key.
        print(f"{type(e).__name__}: {e}")