import logging
import re
import os
import sys
import shlex
import fnmatch
from stat import S_IRUSR, S_IWUSR, S_IRGRP, S_IWGRP, S_IROTH, S_IWOTH, S_IXUSR, S_IXGRP, S_IXOTH
from subprocess import check_call, CalledProcessError

from pyoptix.utils import glob_recursive, find_sub_path

try:
    from configparser import ConfigParser
except ImportError:
    from ConfigParser import SafeConfigParser as ConfigParser

logger = logging.getLogger(__name__)

# Matches `#include "path"` directives. The capture stops at the first closing
# quote so that trailing text containing quotes on the same line is ignored.
_INCLUDE_PATTERN = re.compile(r'#include\s*"([^"\n]+)"')


class Compiler:
    """Class-level front end that compiles source files to PTX with nvcc.

    All state lives on the class itself and is shared by every caller:

    * ``nvcc_path`` - executable invoked to compile.
    * ``extra_compile_args`` - additional arguments appended after the
      configured ones.
    * ``output_path`` - directory the generated ``.ptx`` files are written to.
    * ``use_fast_math`` - when true, ``--use_fast_math`` is passed to nvcc.
    * ``_compile_args`` - arguments loaded from ``pyoptix.conf``.
    * ``_program_directories`` - directories passed as include paths and
      searched when resolving includes and program paths.
    """

    nvcc_path = 'nvcc'
    extra_compile_args = []
    output_path = '/tmp/pyoptix/ptx'
    use_fast_math = True
    _compile_args = []
    _program_directories = []

    @classmethod
    def add_program_directory(cls, directory):
        """Register ``directory`` as a program directory (no duplicates)."""
        if directory not in cls._program_directories:
            cls._program_directories.append(directory)

    @classmethod
    def remove_program_directory(cls, directory):
        """Unregister ``directory``; does nothing if it was not registered."""
        if directory in cls._program_directories:
            cls._program_directories.remove(directory)

    @classmethod
    def is_compile_required(cls, source_path, ptx_path):
        """Return True when ``ptx_path`` is missing or out of date.

        The PTX file is considered out of date when the source file, or any
        file it includes (see ``_has_modified_includes``), has a modification
        time newer than the PTX file.
        """
        if not os.path.isfile(ptx_path):
            return True

        ptx_mtime = os.path.getmtime(ptx_path)
        if os.path.getmtime(source_path) > ptx_mtime:
            return True

        return cls._has_modified_includes(source_path, ptx_mtime)

    @classmethod
    def _has_modified_includes(cls, file_path, modified_after, depth=4):
        """Return True if a quoted include of ``file_path`` is newer than
        ``modified_after``.

        Each ``#include "..."`` is looked up in every registered program
        directory; candidates that do not exist are skipped. Included files
        are scanned recursively, and ``depth`` bounds how many levels of
        nesting are followed.
        """
        if depth == 0:
            return False

        with open(file_path) as source_file:
            content = source_file.read()

        for included_path in _INCLUDE_PATTERN.findall(content):
            for program_directory in cls._program_directories:
                candidate = os.path.join(program_directory, included_path)
                if not os.path.exists(candidate):
                    continue
                if os.path.getmtime(candidate) > modified_after:
                    return True
                if cls._has_modified_includes(candidate, modified_after, depth=depth - 1):
                    return True

        return False

    @classmethod
    def _build_command(cls, source_path, output_ptx_path):
        """Return the nvcc invocation as an argument list.

        Building a list (instead of a string that is split afterwards) keeps
        paths containing whitespace intact.
        """
        command = [cls.nvcc_path]

        # Each configured entry may itself hold several whitespace-separated
        # tokens, so it is split with shell-like rules before being appended.
        for arg in list(cls._compile_args) + list(cls.extra_compile_args):
            command.extend(shlex.split(arg))

        command.append('-ptx')
        if cls.use_fast_math:
            command.append('--use_fast_math')

        for include_path in cls._program_directories:
            if os.path.exists(include_path):
                command.append('-I=' + include_path)

        command.append(source_path)
        command.append('-o=' + output_ptx_path)
        return command

    @classmethod
    def compile(cls, source_path, output_ptx_name=None):
        """Compile ``source_path`` to PTX if the existing output is stale.

        :param source_path: path of the source file to compile.
        :param output_ptx_name: file name of the PTX output inside
            ``output_path``; defaults to ``get_ptx_name(source_path)``.
        :return: ``(output_ptx_path, is_compiled)`` where ``is_compiled`` is
            False when the existing PTX file was already up to date.
        :raises RuntimeError: if ``output_path`` is not a directory, or if no
            PTX file exists after running the compiler.
        """
        if output_ptx_name is None:
            output_ptx_name = cls.get_ptx_name(source_path)

        if not os.path.isdir(cls.output_path):
            raise RuntimeError('Compiler.output_path is not a directory.')

        output_ptx_path = os.path.join(cls.output_path, output_ptx_name)

        if not cls.is_compile_required(source_path, output_ptx_path):
            logger.debug("No compiling required for {0}".format(source_path))
            return output_ptx_path, False

        # Remove the stale output first so that its absence afterwards
        # reliably signals a failed compilation.
        if os.path.exists(output_ptx_path):
            os.remove(output_ptx_path)

        logger.info("Compiling {0}".format(source_path))
        command = cls._build_command(source_path, output_ptx_path)
        logger.debug("Executing: {0}".format(" ".join(command)))

        try:
            check_call(command)
        except CalledProcessError as e:
            # Failure is reported below through the missing output file.
            logger.error(e)

        if not os.path.exists(output_ptx_path):
            logger.error("Could not compile {0}".format(source_path))
            raise RuntimeError("Could not compile {0}".format(source_path))

        # Make the PTX file readable and writable by user, group and others.
        os.chmod(output_ptx_path, S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP | S_IROTH | S_IWOTH)
        return output_ptx_path, True

    @classmethod
    def compile_all_directories(cls, source_extension='.cu'):
        """Compile every file ending in ``source_extension`` found by
        ``glob_recursive`` in the registered program directories."""
        for program_dir in cls._program_directories:
            for program_path in glob_recursive(program_dir, '*' + source_extension):
                cls.compile(os.path.abspath(program_path))

    @classmethod
    def clean(cls):
        """Delete every ``*.ptx`` file below ``output_path``."""
        if not os.path.exists(cls.output_path):
            return
        for dirpath, _dirnames, filenames in os.walk(cls.output_path):
            for filename in fnmatch.filter(filenames, '*.ptx'):
                os.remove(os.path.join(dirpath, filename))

    @staticmethod
    def is_ptx(file_path):
        """Return True if ``file_path`` has a ``.ptx`` extension (any case)."""
        return os.path.splitext(file_path)[1].lower() == '.ptx'

    @staticmethod
    def get_ptx_name(file_path):
        """Return a flat PTX file name derived from ``file_path``.

        Path separators are replaced by underscores and ``.ptx`` is appended.
        """
        return '%s.ptx' % file_path.replace(os.sep, '_')

    @classmethod
    def get_abs_program_path(cls, file_path):
        """Resolve ``file_path`` to an existing path.

        ``file_path`` is returned unchanged when it exists; otherwise it is
        looked up with ``find_sub_path`` in the registered program directories.

        :raises ValueError: if neither location exists.
        """
        if os.path.exists(file_path):
            return file_path

        abs_path = find_sub_path(file_path, cls._program_directories)
        if os.path.exists(abs_path):
            return abs_path

        raise ValueError('File not found')


# Load optional settings from pyoptix.conf, looked up next to the Python
# executable first and in /etc otherwise. Any failure keeps the defaults.
try:
    config_path = os.path.join(os.path.dirname(sys.executable), 'pyoptix.conf')
    if not os.path.exists(config_path):
        config_path = '/etc/pyoptix.conf'

    config = ConfigParser()
    config.read(config_path)
    nvcc_path = config.get('pyoptix', 'nvcc_path')
    compile_args = config.get('pyoptix', 'compile_args')

    if nvcc_path is not None:
        Compiler.nvcc_path = nvcc_path
    if compile_args is not None:
        # Entries in compile_args are separated by os.pathsep.
        Compiler._compile_args = compile_args.split(os.pathsep)
except Exception as e:
    logger.warning("Could not load pyoptix.conf: %s", e)

# Create the output directory on first import and open it up to user, group
# and others.
if not os.path.exists(Compiler.output_path):
    try:
        os.makedirs(Compiler.output_path)
    except OSError:
        # Another process may have created the directory between the check
        # above and makedirs; only fail if it is still not there.
        if not os.path.isdir(Compiler.output_path):
            raise
    else:
        os.chmod(Compiler.output_path,
                 S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IWGRP | S_IXGRP | S_IROTH | S_IWOTH | S_IXOTH)