bat.examples.simba_attack_deepapi
import validators

import numpy as np
np.set_printoptions(suppress=True)

from PIL import Image

from bat.apis.deepapi import bat_deepapi_model_list
from bat.attacks import SimBA


def simba_attack_deepapi():
    """Interactively run a SimBA attack against a model served by DeepAPI.

    Lists the entries of ``bat_deepapi_model_list``, then prompts for a model
    index, the DeepAPI server URL and an image file. It prints the model's
    prediction on the clean image, runs ``SimBA.attack`` with the predicted
    label, prints the prediction on the adversarial image, and writes the
    adversarial image to ``result.jpg`` in the current working directory.

    Errors while reading the image are printed and end the session early.
    Ctrl-C (KeyboardInterrupt) at any point exits quietly.
    """
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        print(i, ':', model[0])

    try:
        # Get the model type. The index is an ``int`` on every path, so the
        # ``index == 1`` check below behaves the same whether the user
        # accepted the default or typed "1".
        index = input("Please input the model index (default: 1): ")
        if len(index) == 0:
            index = 1
        else:
            # isdecimal() (unlike isdigit()) only accepts strings int() can
            # parse. The membership test rejects anything that is not a real
            # key (e.g. "0") instead of failing later with a KeyError.
            while not index.isdecimal() or int(index) not in bat_deepapi_model_list:
                index = input(f"Model [{index}] does not exist. Please try again: ")
            index = int(index)

        # Get the DeepAPI server url
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if len(deepapi_url) == 0:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file
        try:
            file = input("Please input the image file: ")
            while len(file) == 0:
                file = input("Please input the image file: ")
            # The context manager closes the underlying file once the pixel
            # data has been copied out by convert().
            with Image.open(file) as img:
                image = img.convert('RGB')
            if index == 1:
                # Model 1 is fed 32x32 input; presumably a CIFAR-10-style
                # classifier -- NOTE(review): confirm against the model list.
                image = image.resize((32, 32))
            # Add a leading batch dimension: (H, W, 3) -> (1, H, W, 3).
            x = np.array(image)[np.newaxis, ...]
        except Exception as e:
            # Top-level CLI boundary: report the problem and stop.
            print(e)
            return

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            # NOTE(review): presumably predict() yields None when the request
            # fails; tell the user instead of exiting silently.
            print("No prediction was returned by the DeepAPI server.")
            return

        label = np.argmax(y_pred)
        deepapi_model.print(y_pred)
        print('Prediction', label, deepapi_model.get_class_name(label))
        print()

        # SimBA Attack, given the clean image and its predicted label.
        simba = SimBA(deepapi_model)
        x_adv = simba.attack(x, label, epsilon=0.05, max_it=3000, concurrency=4)

        # Print result after attack
        y_adv = deepapi_model.predict(x_adv)[0]
        if y_adv is None:
            print("No prediction was returned by the DeepAPI server.")
        else:
            adv_label = np.argmax(y_adv)
            deepapi_model.print(y_adv)
            print('Prediction', adv_label, deepapi_model.get_class_name(adv_label))
            print()

        # Save image. Clip before the uint8 cast so out-of-range values
        # saturate instead of wrapping around (e.g. 256 -> 0).
        # NOTE: JPEG is lossy even at quality=100, so the saved file may not
        # reproduce the exact perturbation.
        adv_pixels = np.clip(x_adv[0], 0, 255).astype(np.uint8)
        Image.fromarray(adv_pixels).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except KeyboardInterrupt:
        print()
        return
def simba_attack_deepapi():
def simba_attack_deepapi():
    """Interactively run a SimBA attack against a model served by DeepAPI.

    Lists the entries of ``bat_deepapi_model_list``, then prompts for a model
    index, the DeepAPI server URL and an image file. It prints the model's
    prediction on the clean image, runs ``SimBA.attack`` with the predicted
    label, prints the prediction on the adversarial image, and writes the
    adversarial image to ``result.jpg`` in the current working directory.

    Errors while reading the image are printed and end the session early.
    Ctrl-C (KeyboardInterrupt) at any point exits quietly.
    """
    for i, (_, model) in enumerate(bat_deepapi_model_list.items(), start=1):
        print(i, ':', model[0])

    try:
        # Get the model type. The index is an ``int`` on every path, so the
        # ``index == 1`` check below behaves the same whether the user
        # accepted the default or typed "1".
        index = input("Please input the model index (default: 1): ")
        if len(index) == 0:
            index = 1
        else:
            # isdecimal() (unlike isdigit()) only accepts strings int() can
            # parse. The membership test rejects anything that is not a real
            # key (e.g. "0") instead of failing later with a KeyError.
            while not index.isdecimal() or int(index) not in bat_deepapi_model_list:
                index = input(f"Model [{index}] does not exist. Please try again: ")
            index = int(index)

        # Get the DeepAPI server url
        deepapi_url = input("Please input the DeepAPI URL (default: http://localhost:8080): ")
        if len(deepapi_url) == 0:
            deepapi_url = 'http://localhost:8080'
        else:
            while not validators.url(deepapi_url):
                deepapi_url = input("Invalid URL. Please try again: ")

        # Get the image file
        try:
            file = input("Please input the image file: ")
            while len(file) == 0:
                file = input("Please input the image file: ")
            # The context manager closes the underlying file once the pixel
            # data has been copied out by convert().
            with Image.open(file) as img:
                image = img.convert('RGB')
            if index == 1:
                # Model 1 is fed 32x32 input; presumably a CIFAR-10-style
                # classifier -- NOTE(review): confirm against the model list.
                image = image.resize((32, 32))
            # Add a leading batch dimension: (H, W, 3) -> (1, H, W, 3).
            x = np.array(image)[np.newaxis, ...]
        except Exception as e:
            # Top-level CLI boundary: report the problem and stop.
            print(e)
            return

        # DeepAPI Model
        deepapi_model = bat_deepapi_model_list[index][1](deepapi_url)

        # Make predictions
        y_pred = deepapi_model.predict(x)[0]
        if y_pred is None:
            # NOTE(review): presumably predict() yields None when the request
            # fails; tell the user instead of exiting silently.
            print("No prediction was returned by the DeepAPI server.")
            return

        label = np.argmax(y_pred)
        deepapi_model.print(y_pred)
        print('Prediction', label, deepapi_model.get_class_name(label))
        print()

        # SimBA Attack, given the clean image and its predicted label.
        simba = SimBA(deepapi_model)
        x_adv = simba.attack(x, label, epsilon=0.05, max_it=3000, concurrency=4)

        # Print result after attack
        y_adv = deepapi_model.predict(x_adv)[0]
        if y_adv is None:
            print("No prediction was returned by the DeepAPI server.")
        else:
            adv_label = np.argmax(y_adv)
            deepapi_model.print(y_adv)
            print('Prediction', adv_label, deepapi_model.get_class_name(adv_label))
            print()

        # Save image. Clip before the uint8 cast so out-of-range values
        # saturate instead of wrapping around (e.g. 256 -> 0).
        # NOTE: JPEG is lossy even at quality=100, so the saved file may not
        # reproduce the exact perturbation.
        adv_pixels = np.clip(x_adv[0], 0, 255).astype(np.uint8)
        Image.fromarray(adv_pixels).save('result.jpg', subsampling=0, quality=100)
        print("The adversarial image is saved as result.jpg")

    except KeyboardInterrupt:
        print()
        return