#!/usr/bin/env python
"""Download the samples that define the malware gym environment.

Samples are listed in ``sample_hashes.csv`` (columns used here: ``md5`` and
``sha256``) and are stored under ``samples_path`` using their SHA-256 as the
file name.  Two sources are supported:

* VirusTotal (``--virustotal``): each sample is fetched through the
  ``file/download`` API endpoint by a pool of worker processes.
* VirusShare (``--virusshare``): samples are extracted from one or more
  password-protected zip archives supplied on the command line.
"""
import argparse
import csv
import multiprocessing
import os
import queue  # noqa: F401  (no longer used below; import retained)
import zipfile

import requests

# TODO: Don't hardcode the relative path?
samples_path = "gym_malware/envs/utils/samples/"
hashes_path = "gym_malware/envs/utils/sample_hashes.csv"
vturl = "https://www.virustotal.com/vtapi/v2/file/download"

# Maximum number of download attempts per sample.
VT_MAX_TRIES = 10
# Seconds to wait for VirusTotal to connect / send data before giving up on
# one attempt.  Without a timeout a stalled connection blocks a worker forever.
VT_TIMEOUT_SECONDS = 60
# Sentinel placed on the queue to tell a worker to exit.
STOP_TOKEN = "STOP"


def get_sample_hashes():
    """Return every row of the hashes CSV as a list of dicts.

    The file at ``hashes_path`` is parsed with ``csv.DictReader``, so each
    row is keyed by the CSV header names.
    """
    # newline="" is what the csv module asks for when opening files.
    with open(hashes_path, newline="") as csvfile:
        return list(csv.DictReader(csvfile))


def vt_download_sample(sha256, sample_path, vtapikey):
    """Download one sample from VirusTotal and write it to ``sample_path``.

    The request is attempted up to ``VT_MAX_TRIES`` times.  A non-OK HTTP
    status or a ``requests`` transport error counts as a failed attempt.

    Args:
        sha256: Hash of the sample, sent as the ``hash`` query parameter.
        sample_path: Destination file path; written in binary mode.
        vtapikey: VirusTotal API key, sent as the ``apikey`` query parameter.

    Returns:
        True if the sample was downloaded and written, False if every
        attempt failed.
    """
    params = {"hash": sha256, "apikey": vtapikey}
    for _ in range(VT_MAX_TRIES):
        try:
            resp = requests.get(
                vturl, params=params, timeout=VT_TIMEOUT_SECONDS)
        except requests.RequestException:
            # Connection errors and timeouts are retried like bad statuses.
            continue
        if resp.ok:
            with open(sample_path, "wb") as ofile:
                ofile.write(resp.content)
            return True
    return False


def download_worker_function(download_queue, vtapikey):
    """Worker loop: download queued hashes until the stop token arrives.

    Args:
        download_queue: Joinable queue of SHA-256 strings, terminated by
            ``STOP_TOKEN``.  ``task_done()`` is called exactly once for every
            item taken, including the stop token.
        vtapikey: VirusTotal API key passed on to ``vt_download_sample``.

    Returns:
        True once the stop token has been consumed.
    """
    # get() blocks until an item is available; iteration ends on STOP_TOKEN.
    for sha256 in iter(download_queue.get, STOP_TOKEN):
        print("{} downloading".format(sha256))
        sample_path = os.path.join(samples_path, sha256)
        try:
            success = vt_download_sample(sha256, sample_path, vtapikey)
        except Exception as exc:
            # Process boundary: a dying worker would leave its item (and its
            # stop token) unacknowledged and hang the parent's queue.join().
            print("{} had a problem: {}".format(sha256, exc))
        else:
            if success:
                print("{} done".format(sha256))
            else:
                print("{} had a problem".format(sha256))
        finally:
            download_queue.task_done()

    # Acknowledge the stop token itself.
    download_queue.task_done()
    return True


def use_virustotal(args):
    """
    Use Virustotal to download the environment malware

    Starts ``args.nconcurrent`` worker processes, feeds them every
    ``sha256`` from the hashes CSV through a bounded queue, then sends one
    stop token per worker and waits for the queue and the workers to finish.
    """
    os.makedirs(samples_path, exist_ok=True)

    with multiprocessing.Manager() as m:
        download_queue = m.JoinableQueue(args.nconcurrent)

        workers = [
            multiprocessing.Process(
                target=download_worker_function,
                args=(download_queue, args.vtapikey))
            for _ in range(args.nconcurrent)
        ]
        for w in workers:
            w.start()

        for row in get_sample_hashes():
            download_queue.put(row["sha256"])

        # Exactly one stop token per started worker.
        for _ in workers:
            download_queue.put(STOP_TOKEN)

        download_queue.join()
        for w in workers:
            w.join()


def use_virusshare(args):
    """
    Use VirusShare zip files as the source for the environment malware

    Every archive in ``args.zipfile`` is scanned; members whose MD5 appears
    in the hashes CSV are decrypted with ``args.zipfilepassword`` and written
    under ``samples_path`` using their SHA-256 as the file name.
    """
    pwd = bytes(args.zipfilepassword, "ascii")
    md5_to_sha256 = {d["md5"]: d["sha256"] for d in get_sample_hashes()}

    os.makedirs(samples_path, exist_ok=True)

    for path in args.zipfile:
        with zipfile.ZipFile(path) as z:
            for member in z.namelist():
                # The MD5 is taken from the second "_"-separated field of the
                # member name (presumably names look like VirusShare_<md5>).
                parts = member.split("_")
                if len(parts) < 2:
                    continue
                sha256 = md5_to_sha256.get(parts[1])
                if sha256 is None:
                    continue
                # Store next to the VirusTotal downloads rather than in the
                # current working directory.
                sample_path = os.path.join(samples_path, sha256)
                with open(sample_path, "wb") as ofile:
                    ofile.write(z.read(member, pwd))
                print("Extracted {}".format(sha256))


def main(argv=None):
    """Parse command-line options and run the selected download source(s).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    All option validation happens before any download or extraction starts;
    invalid combinations exit through ``parser.error``.
    """
    prog = "download_samples"
    descr = "Download the samples that define the malware gym environment"
    parser = argparse.ArgumentParser(prog=prog, description=descr)
    parser.add_argument(
        "--virustotal",
        default=False,
        action="store_true",
        help="Use Virustotal to download malware samples")
    parser.add_argument(
        "--vtapikey", type=str, default=None, help="Virustotal API key")
    parser.add_argument(
        "--nconcurrent",
        type=int,
        default=6,
        help="Maximum concurrent downloads from Virustotal")
    parser.add_argument(
        "--virusshare",
        default=False,
        action="store_true",
        help="Use malware samples from VirusShare torrents")
    parser.add_argument(
        "--zipfile",
        type=str,
        nargs="+",
        help="The path of VirusShare zipfile 290 or 291")
    parser.add_argument(
        "--zipfilepassword",
        type=str,
        default=None,
        help="Password for the VirusShare zipfiles 290 or 291")
    args = parser.parse_args(argv)

    if not args.virustotal and not args.virusshare:
        parser.error("Must use either Virustotal or VirusShare")

    if args.virusshare:
        # args.zipfile is None when --zipfile is omitted, so test truthiness
        # instead of calling len() on it.
        if not args.zipfile:
            parser.error(
                "Must provide the paths for one or more Virusshare zip files")
        if args.zipfilepassword is None:
            parser.error("Must enter a password for the VirusShare zip files")

    if args.virustotal:
        if args.vtapikey is None:
            parser.error("Must enter a VirusTotal API key")
        if args.nconcurrent < 1:
            # With no workers nothing would ever drain the download queue.
            parser.error("--nconcurrent must be at least 1")

    if args.virusshare:
        use_virusshare(args)
    if args.virustotal:
        use_virustotal(args)


if __name__ == '__main__':
    main()