bat.examples.bandits_attack_deepapi
import validators

import numpy as np
from PIL import Image

from bat.apis.deepapi import bat_deepapi_model_list
from bat.attacks.bandits_attack import BanditsAttack


def bandits_attack_deepapi():
    """Interactively run a Bandits black-box attack against a DeepAPI model.

    Prints a numbered list of the entries in ``bat_deepapi_model_list``, then
    prompts on stdin for:

    * a model index (default ``1``),
    * the DeepAPI server URL (default ``http://localhost:8080``),
    * the path of an input image.

    The clean image is classified and its prediction printed. ``BanditsAttack``
    is then run against the predicted label. The prediction for the
    adversarial image is printed, and the adversarial image is written to
    ``result.jpg`` in the current working directory.

    Any exception raised after the model list is printed is caught, and its
    message is printed instead of propagating.
    """
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        # model[0] is presumably the entry's display name.
        print(i, ':', model[0])

    try:
        n_models = len(bat_deepapi_model_list)

        # Get the model type. Always normalise to int so the `index == 1`
        # resize check below also matches an explicitly typed "1".
        # isdecimal() guarantees int() succeeds, and the lower bound
        # rejects "0", which is not on the 1-based menu.
        index = input("Please input the model index (default: 1): ")
        if not index:
            index = 1
        else:
            while not (index.isdecimal() and 1 <= int(index) <= n_models):
                index = input(f"Model [{index}] does not exist. Please try again: ")
            index = int(index)

        # Get the DeepAPI server url.
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if not deepapi_url:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file.
        image_path = input("Please input the image file: ")
        while not image_path:
            image_path = input("Please input the image file: ")
        # The context manager closes the file handle. convert() returns an
        # independent, fully loaded image, so it stays usable afterwards.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # The first model in the list is fed 32x32 inputs (presumably a
        # CIFAR-sized model; confirm against bat_deepapi_model_list).
        if index == 1:
            image = image.resize((32, 32))

        # Add a leading batch dimension: (H, W, 3) -> (1, H, W, 3).
        x = np.expand_dims(np.array(image), axis=0)

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            print("The DeepAPI model did not return a prediction; skipping the attack.")
            return

        y_label = np.argmax(y_pred)
        deepapi_model.print(y_pred)
        print('Prediction', y_label, deepapi_model.get_class_name(y_label))
        print()

        # Bandits Attack, untargeted against the clean prediction.
        bandits_attack = BanditsAttack(deepapi_model)
        x_adv = bandits_attack.attack(
            x,
            np.array([y_label]),
            epsilon=0.05,
            max_it=3000,
            online_lr=100,
            concurrency=8,
        )

        # Print result after attack
        y_adv = deepapi_model.predict(x_adv)[0]
        deepapi_model.print(y_adv)
        print('Prediction', np.argmax(y_adv), deepapi_model.get_class_name(np.argmax(y_adv)))
        print()

        # Save image. Clip before the uint8 cast so out-of-range values
        # saturate instead of wrapping around.
        adv_pixels = np.clip(x_adv[0], 0, 255).astype(np.uint8)
        Image.fromarray(adv_pixels).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except Exception as e:
        # Top-level boundary of an interactive example: report, don't crash.
        print(e)
def bandits_attack_deepapi():
    """Interactively run a Bandits black-box attack against a DeepAPI model.

    Prints a numbered list of the entries in ``bat_deepapi_model_list``, then
    prompts on stdin for:

    * a model index (default ``1``),
    * the DeepAPI server URL (default ``http://localhost:8080``),
    * the path of an input image.

    The clean image is classified and its prediction printed. ``BanditsAttack``
    is then run against the predicted label. The prediction for the
    adversarial image is printed, and the adversarial image is written to
    ``result.jpg`` in the current working directory.

    Any exception raised after the model list is printed is caught, and its
    message is printed instead of propagating.
    """
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        # model[0] is presumably the entry's display name.
        print(i, ':', model[0])

    try:
        n_models = len(bat_deepapi_model_list)

        # Get the model type. Always normalise to int so the `index == 1`
        # resize check below also matches an explicitly typed "1".
        # isdecimal() guarantees int() succeeds, and the lower bound
        # rejects "0", which is not on the 1-based menu.
        index = input("Please input the model index (default: 1): ")
        if not index:
            index = 1
        else:
            while not (index.isdecimal() and 1 <= int(index) <= n_models):
                index = input(f"Model [{index}] does not exist. Please try again: ")
            index = int(index)

        # Get the DeepAPI server url.
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if not deepapi_url:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file.
        image_path = input("Please input the image file: ")
        while not image_path:
            image_path = input("Please input the image file: ")
        # The context manager closes the file handle. convert() returns an
        # independent, fully loaded image, so it stays usable afterwards.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # The first model in the list is fed 32x32 inputs (presumably a
        # CIFAR-sized model; confirm against bat_deepapi_model_list).
        if index == 1:
            image = image.resize((32, 32))

        # Add a leading batch dimension: (H, W, 3) -> (1, H, W, 3).
        x = np.expand_dims(np.array(image), axis=0)

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            print("The DeepAPI model did not return a prediction; skipping the attack.")
            return

        y_label = np.argmax(y_pred)
        deepapi_model.print(y_pred)
        print('Prediction', y_label, deepapi_model.get_class_name(y_label))
        print()

        # Bandits Attack, untargeted against the clean prediction.
        bandits_attack = BanditsAttack(deepapi_model)
        x_adv = bandits_attack.attack(
            x,
            np.array([y_label]),
            epsilon=0.05,
            max_it=3000,
            online_lr=100,
            concurrency=8,
        )

        # Print result after attack
        y_adv = deepapi_model.predict(x_adv)[0]
        deepapi_model.print(y_adv)
        print('Prediction', np.argmax(y_adv), deepapi_model.get_class_name(np.argmax(y_adv)))
        print()

        # Save image. Clip before the uint8 cast so out-of-range values
        # saturate instead of wrapping around.
        adv_pixels = np.clip(x_adv[0], 0, 255).astype(np.uint8)
        Image.fromarray(adv_pixels).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except Exception as e:
        # Top-level boundary of an interactive example: report, don't crash.
        print(e)