bat.examples.bandits_attack_deepapi

 1import validators
 2
 3import numpy as np
 4from PIL import Image
 5
 6from bat.apis.deepapi import bat_deepapi_model_list
 7from bat.attacks.bandits_attack import BanditsAttack
 8
 9def bandits_attack_deepapi():
10
11    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
12        print(i, ':', model[0])
13
14    try:
15        # Get the model type
16        index = input(f"Please input the model index (default: 1): ")
17        if len(index) == 0:
18            index = 1
19        else:
20            while not index.isdigit() or int(index) > len(bat_deepapi_model_list):
21                index = input(f"Model [{index}] does not exist. Please try again: ")
22
23        # Get the DeepAPI server url
24        deepapi_url = input(f"Please input the DeepAPI URL (default: http://localhost:8080): ")
25        if len(deepapi_url) == 0:
26            deepapi_url = 'http://localhost:8080'
27        else:
28            while not validators.url(deepapi_url):
29                deepapi_url = input(f"Invalid URL. Please try again: ")
30
31        # Get the image file
32        file = input(f"Please input the image file: ")
33        while len(file) == 0:
34            file = input(f"Please input the image file: ")
35        image = Image.open(file).convert('RGB')
36
37        if index == 1:
38            image = image.resize((32, 32))
39
40        x = np.array(image)
41        x = np.array([x])
42
43        # DeepAPI Model
44        deepapi_model = bat_deepapi_model_list[int(index)][1](deepapi_url)
45
46        # Make predictions
47        y_pred = deepapi_model.predict(x)[0]
48
49        if y_pred is not None:
50            deepapi_model.print(y_pred)
51            print('Prediction', np.argmax(y_pred), deepapi_model.get_class_name(np.argmax(y_pred)))
52            print()
53
54        # Bandits Attack
55        bandits_attack = BanditsAttack(deepapi_model)
56
57        x_adv = bandits_attack.attack(x, np.array([np.argmax(y_pred)]), epsilon = 0.05, max_it=3000, online_lr=100, concurrency=8)
58
59        # Print result after attack
60        y_adv = deepapi_model.predict(x_adv)[0]
61        deepapi_model.print(y_adv)
62        print('Prediction', np.argmax(y_adv), deepapi_model.get_class_name(np.argmax(y_adv)))
63        print()
64
65        # Save image
66        Image.fromarray((x_adv[0]).astype(np.uint8)).save('result.jpg', subsampling=0, quality=100)
67        print("The adversarial image is saved as result.jpg")
68
69    except Exception as e:
70        print(e)
71        return
def bandits_attack_deepapi():
def bandits_attack_deepapi():
    """Interactively run a Bandits black-box attack against a DeepAPI model.

    Lists the entries of ``bat_deepapi_model_list``, then prompts on stdin for
    a model index (default 1), a DeepAPI server URL (default
    ``http://localhost:8080``) and an image file. The image is classified by
    the selected model, ``BanditsAttack`` is run against the predicted label,
    the prediction on the adversarial image is printed, and the adversarial
    image is saved as ``result.jpg`` in the current working directory.

    Returns:
        None. Any exception raised inside the interactive section is printed
        to stdout instead of being propagated.
    """
    n_models = len(bat_deepapi_model_list)
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        print(i, ':', model[0])

    try:
        # Get the model type. The index is always converted to an int so the
        # ``index == 1`` check below behaves the same whether the user
        # accepted the default or typed "1" explicitly.
        index = input("Please input the model index (default: 1): ")
        if not index:
            index = 1
        else:
            # isdecimal() (unlike isdigit()) only accepts strings int() can
            # parse; the range check also rejects 0, which the 1-based
            # listing above never offers.
            while not index.isdecimal() or not 1 <= int(index) <= n_models:
                index = input(f"Model [{index}] does not exist. Please try again: ")
            index = int(index)

        # Get the DeepAPI server url
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if not deepapi_url:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file
        file = input("Please input the image file: ")
        while not file:
            file = input("Please input the image file: ")
        # Close the underlying file handle once the RGB copy has been made.
        with Image.open(file) as img:
            image = img.convert('RGB')

        # NOTE(review): model 1 is assumed to take 32x32 inputs — confirm
        # against the first entry of bat_deepapi_model_list.
        if index == 1:
            image = image.resize((32, 32))

        # Wrap the single image in a batch of one.
        x = np.array([np.array(image)])

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions; without a prediction there is no label to attack.
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            print('No prediction was returned by the DeepAPI model.')
            return
        label = np.argmax(y_pred)
        deepapi_model.print(y_pred)
        print('Prediction', label, deepapi_model.get_class_name(label))
        print()

        # Bandits Attack against the predicted label
        bandits_attack = BanditsAttack(deepapi_model)
        x_adv = bandits_attack.attack(
            x,
            np.array([label]),
            epsilon=0.05,
            max_it=3000,
            online_lr=100,
            concurrency=8,
        )

        # Print result after attack
        y_adv = deepapi_model.predict(x_adv)[0]
        if y_adv is None:
            print('No prediction was returned by the DeepAPI model.')
        else:
            adv_label = np.argmax(y_adv)
            deepapi_model.print(y_adv)
            print('Prediction', adv_label, deepapi_model.get_class_name(adv_label))
            print()

        # Save image
        Image.fromarray(x_adv[0].astype(np.uint8)).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except Exception as e:
        # Top-level boundary of an interactive example: report and stop.
        print(e)