Source code for romsearch.modules.rompatcher

import glob
import os
import shutil
import subprocess

import wget

import romsearch
from ..util import (
    load_yml,
    setup_logger,
    unzip_file,
    centred_string,
    left_aligned_string,
)

# Patch methods ROMPatcher.patch_rom can dispatch to (listed in its error message)
ALLOWED_PATCH_METHODS = ["xdelta"]


def find_file_by_extensions(file_exts, patch_dir):
    """Find the single file in a directory matching any of the given extensions

    Args:
        file_exts (list): File extensions to loop over (matched as ``*<ext>``)
        patch_dir (str): Patch directory

    Returns:
        str: Path to the one matching file

    Raises:
        ValueError: If no file, or more than one file, matches
    """

    # Escape the directory so glob metacharacters in it (e.g. the brackets in
    # a ROM-derived name like "Game (USA) [!]") are matched literally
    escaped_dir = glob.escape(patch_dir)

    potential_files = []
    for file_ext in file_exts:
        potential_files.extend(glob.glob(os.path.join(escaped_dir, f"*{file_ext}")))

    # Overlapping extensions can match the same file more than once; drop
    # duplicates (keeping order) so they don't look like multiple files
    potential_files = list(dict.fromkeys(potential_files))

    if not potential_files:
        raise ValueError(
            f'No files found in {patch_dir} with extensions: {", ".join(file_exts)}'
        )

    # If we have multiple potential files, then error out
    if len(potential_files) > 1:
        raise ValueError(f'Multiple files found: {", ".join(potential_files)}')

    return potential_files[0]


class ROMPatcher:
    def __init__(
        self,
        platform,
        config_file=None,
        config=None,
        platform_config=None,
        logger=None,
        log_line_sep="=",
        log_line_length=100,
    ):
        """ROM Patching tool

        There are different ways to patch files based on platforms, so we need
        to keep track of a number of things here

        Args:
            platform (str): Platform name
            config_file (str, optional): path to config file. Defaults to None.
            config (dict, optional): configuration dictionary. Defaults to None.
            platform_config (dict, optional): platform configuration dictionary.
                If None, it is loaded from ``configs/platforms/<platform>.yml``
                inside the romsearch package. Defaults to None.
            logger (logging.Logger, optional): logger. Defaults to None.
            log_line_sep (str, optional): Log line separator character.
                Defaults to "=".
            log_line_length (int, optional): Line length of log. Defaults to 100

        Raises:
            ValueError: If neither config_file nor config is given, or if
                ``dirs.patch_dir`` is not set in the config
        """

        if config_file is None and config is None:
            raise ValueError("config_file or config must be specified")

        if config is None:
            config = load_yml(config_file)
        self.config = config

        if logger is None:
            log_dir = self.config.get("dirs", {}).get(
                "log_dir", os.path.join(os.getcwd(), "logs")
            )
            log_level = self.config.get("logger", {}).get("level", "info")
            logger = setup_logger(
                log_level=log_level,
                script_name="ROMPatcher",
                log_dir=log_dir,
            )
        self.logger = logger

        # Pull in directories
        self.patch_dir = self.config.get("dirs", {}).get("patch_dir", None)
        if self.patch_dir is None:
            raise ValueError("patch_dir needs to be defined in config")
        os.makedirs(self.patch_dir, exist_ok=True)

        self.platform = platform

        # Pull in platform config that we need, falling back to the copy
        # shipped inside the romsearch package
        if platform_config is None:
            mod_dir = os.path.dirname(romsearch.__file__)
            platform_config_file = os.path.join(
                mod_dir, "configs", "platforms", f"{platform}.yml"
            )
            platform_config = load_yml(platform_config_file)
        self.platform_config = platform_config

        self.log_line_sep = log_line_sep
        self.log_line_length = log_line_length

    def run(
        self,
        file,
        patch_url,
    ):
        """Run the ROMPatcher

        Copies (or unzips) ``file`` into a fresh working directory
        ``<patch_dir>/<platform>/<file name without extension>``, downloads the
        patch from ``patch_url`` there, and applies it.

        Args:
            file (str): Path to the ROM file, or to a .zip containing it
            patch_url (str): URL of the patch file (may be a .zip)

        Returns:
            str: Path to the patched ROM
        """

        filename_no_ext = os.path.splitext(os.path.basename(file))[0]
        patch_dir = str(os.path.join(self.patch_dir, self.platform, filename_no_ext))

        # Start from an empty working directory: files are later located by
        # extension, and leftovers from a previous run would cause a
        # "Multiple files found" error
        if os.path.exists(patch_dir):
            shutil.rmtree(patch_dir)
        os.makedirs(patch_dir)

        # Move and unzip file, if needed
        self.logger.info(
            centred_string(
                f"Moving {file} to {patch_dir}",
                total_length=self.log_line_length,
            )
        )

        # If we have a zip, unzip that cutie
        if file.endswith(".zip"):
            unzip_file(file, patch_dir)
        else:
            # patch_dir now exists as a directory, so the file is copied into it
            shutil.copy(file, patch_dir)

        # Find the unpatched file
        unpatched_file = self.get_unpatched_file(patch_dir=patch_dir)

        # Next up, download the patch file
        patch_file = self.download_patch_file(
            patch_url,
            patch_dir=patch_dir,
        )

        # Now we have everything we need to patch this ROM
        patched_file = self.patch_rom(
            unpatched_file=unpatched_file,
            patch_file=patch_file,
        )

        return patched_file

    def get_unpatched_file(
        self,
        patch_dir,
    ):
        """Get the unpatched file from the patch directory

        Args:
            patch_dir (str): Patch directory

        Returns:
            str: Path to the file matching the platform's ``file_exts``

        Raises:
            ValueError: If ``file_exts`` is missing or empty in the platform
                config. Errors from ``find_file_by_extensions`` propagate.
        """

        file_exts = self.platform_config.get("file_exts", [])

        # Error if we don't have file extensions defined (missing, None or empty)
        if not file_exts:
            raise ValueError(
                "File extensions need to be defined in the platform config file"
            )

        rom_file = find_file_by_extensions(
            file_exts=file_exts,
            patch_dir=patch_dir,
        )

        return rom_file

    def download_patch_file(
        self,
        patch_url,
        patch_dir,
    ):
        """Download a patch file

        The download is saved into ``patch_dir``; if it is a .zip it is
        extracted there too. The patch is then located by the platform's
        ``patch_file_exts``.

        Args:
            patch_url (str): URL to patch file
            patch_dir (str): Patch directory

        Returns:
            str: Path to the patch file

        Raises:
            ValueError: If ``patch_file_exts`` is missing or empty in the
                platform config (checked before anything is downloaded).
                Errors from ``find_file_by_extensions`` propagate.
        """

        # Validate the config first so a misconfiguration doesn't cost a download
        patch_file_exts = self.platform_config.get("patch_file_exts", [])
        if not patch_file_exts:
            raise ValueError(
                "Patch file extensions need to be defined in the platform config file"
            )

        self.logger.info(
            centred_string(
                f"Downloading patch file: {patch_url}",
                total_length=self.log_line_length,
            )
        )

        downloaded_file = wget.download(patch_url, out=patch_dir)
        if downloaded_file.endswith(".zip"):
            unzip_file(downloaded_file, patch_dir)

        # Find the patch file
        patch_file = find_file_by_extensions(
            patch_file_exts,
            patch_dir=patch_dir,
        )

        return patch_file

    def patch_rom(self, unpatched_file, patch_file):
        """Patch a ROM

        The output is written next to ``unpatched_file`` with " (ROMPatched)"
        inserted before the file extension.

        Args:
            unpatched_file (str): ROM file to patch
            patch_file (str): Patch file to patch

        Returns:
            str: Path to the patched ROM

        Raises:
            ValueError: If the platform's patch method is missing or unsupported
            RuntimeError: If the patching tool reports a failure
        """

        # Get the method we're using to patch things
        patch_method = self.platform_config.get("patch_method", None)

        # Error out if the patch method is not defined
        if patch_method is None:
            raise ValueError(
                "Patch method needs to be defined in the platform config file"
            )

        # Build an output file, adding a (ROMPatched) to the bit before the file extension
        root, ext = os.path.splitext(unpatched_file)
        patched_file = f"{root} (ROMPatched){ext}"

        if patch_method == "xdelta":
            success = self.xdelta_patch(
                unpatched_file=unpatched_file,
                patch_file=patch_file,
                out_file=patched_file,
            )
        else:
            raise ValueError(
                f"Patch method needs to be one of {', '.join(ALLOWED_PATCH_METHODS)}, not {patch_method}"
            )

        # Don't report success, or hand back a path that may not exist, when
        # the patching tool failed
        if not success:
            raise RuntimeError(f"Patching failed for {unpatched_file}")

        self.logger.info(
            centred_string(
                "Patching complete!",
                total_length=self.log_line_length,
            )
        )

        return patched_file

    @staticmethod
    def _xdelta_args(xdelta_path, unpatched_file, patch_file, out_file):
        """Build the argv list for ``xdelta -d -s <source> <patch> <output>``"""
        return [xdelta_path, "-d", "-s", unpatched_file, patch_file, out_file]

    def xdelta_patch(
        self,
        unpatched_file,
        patch_file,
        out_file,
    ):
        """Patch using xdelta

        ``rompatcher.xdelta_path`` in the user config may be a path to the
        executable or a command name found on PATH.

        Args:
            unpatched_file (str): ROM file to patch
            patch_file (str): Patch file to patch
            out_file (str): Path for output file

        Returns:
            bool: True if xdelta ran and exited with status 0, False otherwise

        Raises:
            ValueError: If the xdelta path is not configured or cannot be found
        """

        xdelta_path = self.config.get("rompatcher", {}).get("xdelta_path", None)
        if xdelta_path is None:
            raise ValueError("Path to xdelta needs to be defined in user config")

        # Accept an explicit path, or fall back to looking the name up on PATH
        if not os.path.exists(xdelta_path):
            resolved_path = shutil.which(xdelta_path)
            if resolved_path is None:
                raise ValueError("xdelta path not found")
            xdelta_path = resolved_path

        # Pass arguments as a list (no shell) so spaces, quotes, "$" etc. in
        # the tool path or in file names (the patch name comes from a
        # downloaded URL) can't break the command or inject into it
        cmd = self._xdelta_args(xdelta_path, unpatched_file, patch_file, out_file)

        self.logger.info(
            centred_string(
                "Patching file with xdelta:",
                total_length=self.log_line_length,
            )
        )
        self.logger.info(
            left_aligned_string(
                f"-> Unpatched file: {os.path.basename(unpatched_file)}",
                total_length=self.log_line_length,
            )
        )
        self.logger.info(
            left_aligned_string(
                f"-> Patch file: {os.path.basename(patch_file)}",
                total_length=self.log_line_length,
            )
        )
        self.logger.info(
            left_aligned_string(
                f"-> Output file: {os.path.basename(out_file)}",
                total_length=self.log_line_length,
            )
        )

        try:
            result = subprocess.run(cmd, check=False)
        except OSError as err:
            self.logger.error(
                left_aligned_string(
                    f"-> Could not run xdelta: {err}",
                    total_length=self.log_line_length,
                )
            )
            return False

        if result.returncode != 0:
            self.logger.error(
                left_aligned_string(
                    f"-> xdelta failed with exit code {result.returncode}",
                    total_length=self.log_line_length,
                )
            )
            return False

        return True