import os
import sys
import math
import logging
import argparse
import numpy as np
import pandas as pd
from collections import Counter
from typing import Dict, List, Tuple, Optional, Any, Set
import json
import xxhash  # pip install xxhash
import lief  # pip install lief
import pefile  # pip install pefile
import multiprocessing
import time  # Added for execution time measurement
import tempfile  # For temporary files/directories
import shutil  # For removing temporary directory
from tqdm import tqdm  # Import tqdm for progress bars

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Reduce tqdm's default log level spam if logging level is INFO
if logger.getEffectiveLevel() == logging.INFO:
    logging.getLogger('tqdm').setLevel(logging.WARNING)


# --- Top-level helper for entropy ---
def calculate_entropy_static(data: bytes) -> float:
    """Return the Shannon entropy of *data* in bits per byte (0.0 for empty input)."""
    if not data:
        return 0.0
    byte_counts = Counter(data)
    data_len = len(data)
    entropy = 0.0
    for count in byte_counts.values():
        probability = count / data_len
        if probability > 0:
            entropy -= probability * math.log2(probability)
    return entropy


# --- Worker function for Pass 1: Vocabulary Building ---
def _worker_extract_imports_for_vocab(file_path: str) -> Tuple[Optional[str], Optional[Set[str]], Optional[Set[str]]]:
    """Parse one PE file and collect its imported API and DLL names.

    Tries lief first and falls back to pefile. Returns
    (file_path, api_names, dll_names); the two sets are None when the file
    could not be parsed by either library.
    """
    api_names: Set[str] = set()
    dll_names: Set[str] = set()
    pe_data_obj: Any = None
    try:
        # lief.parse may either raise or return None on failure; treat both the
        # same and fall back to pefile once.
        try:
            pe_data_obj = lief.parse(file_path)
        except Exception:
            pe_data_obj = None
        if pe_data_obj is None:
            try:
                pe_data_obj = pefile.PE(file_path, fast_load=True)
            except Exception:
                return file_path, None, None

        if isinstance(pe_data_obj, lief.PE.Binary) and hasattr(pe_data_obj, 'imports'):  # lief
            for imported_lib in pe_data_obj.imports:
                dll_name_str = imported_lib.name
                dll_name = dll_name_str.lower() if dll_name_str and isinstance(dll_name_str, str) else 'unknown_dll'
                dll_names.add(dll_name)
                for imported_func in imported_lib.entries:
                    if imported_func.name and isinstance(imported_func.name, str):
                        api_names.add(imported_func.name)
        elif isinstance(pe_data_obj, pefile.PE):  # pefile
            # BUG FIX: with fast_load=True pefile does not parse the import
            # directory, so DIRECTORY_ENTRY_IMPORT never exists before an
            # explicit parse. The old guard `hasattr(..., 'DIRECTORY_ENTRY_IMPORT')`
            # on the elif therefore made this branch unreachable and pefile-parsed
            # files contributed nothing to the vocabulary. Parse first, then check.
            if hasattr(pe_data_obj, 'parse_data_directories'):
                pe_data_obj.parse_data_directories(directories=[pefile.DIRECTORY_ENTRY['IMAGE_DIRECTORY_ENTRY_IMPORT']])
            if not hasattr(pe_data_obj, 'DIRECTORY_ENTRY_IMPORT'):  # Still no imports
                return file_path, api_names, dll_names  # Return empty sets if no imports
            for entry in pe_data_obj.DIRECTORY_ENTRY_IMPORT:
                if entry.dll:
                    try:
                        dll_name_str = entry.dll.decode(errors='ignore')
                        dll_name = dll_name_str.lower()
                        dll_names.add(dll_name)
                        for imported_func in entry.imports:
                            if imported_func.name:
                                try:
                                    api_name_str = imported_func.name.decode(errors='ignore')
                                    api_names.add(api_name_str)
                                except UnicodeDecodeError:
                                    pass  # Ignore API names that can't be decoded
                    except UnicodeDecodeError:
                        dll_names.add('unknown_dll_decode_error')  # Mark DLLs that caused decoding issues

        # Filter out empty or whitespace-only names
        api_names = {api for api in api_names if api and api.strip()}
        dll_names = {dll for dll in dll_names if dll and dll.strip()}
        return file_path, api_names, dll_names
    except Exception:
        # Deliberately quiet: per-file diagnostics here are too noisy at scale.
        return file_path, None, None


# --- Globals for Pass 2 Workers (to be initialized) ---
g_api_vocab: Set[str] = set()
g_dll_vocab: Set[str] = set()
g_expected_columns_pass2: List[str] = []
g_samples_labels: Dict[str, int] = {}


def _init_pass2_worker_globals(api_vocab_arg: Set[str], dll_vocab_arg: Set[str],
                               expected_cols_arg: List[str], samples_labels_arg: Dict[str, int]):
    """Pool initializer: copy the shared read-only state into each worker process."""
    global g_api_vocab, g_dll_vocab, g_expected_columns_pass2, g_samples_labels
    g_api_vocab = api_vocab_arg
    g_dll_vocab = dll_vocab_arg
    g_expected_columns_pass2 = expected_cols_arg
    g_samples_labels = samples_labels_arg
# --- SINGLE FILE Feature Extraction Worker ---
def _extract_single_file_features(file_path: str) -> Optional[Dict[str, Any]]:
    """Extract the full feature row (dict) for a single PE file.

    Relies on the worker globals (g_expected_columns_pass2, g_api_vocab,
    g_dll_vocab, g_samples_labels) being initialized via
    _init_pass2_worker_globals. Always returns a dict, possibly only
    partially populated when the file cannot be parsed.
    """
    start_time_file = time.time()
    # Initialize features dict with all expected columns to ensure consistency
    # even if some are not populated.
    features: Dict[str, Any] = {col: 0 for col in g_expected_columns_pass2}
    features['file_path'] = file_path  # Ensure file_path is always present
    file_id = os.path.basename(file_path)
    features['sample_label'] = g_samples_labels.get(file_id, 0)  # Default to 0 if not found
    try:
        with open(file_path, 'rb') as f:
            file_data = f.read()

        # Parse with lief, falling back to pefile. lief.parse can return None
        # without raising an exception, so both failure modes funnel into the
        # pefile fallback.
        pe_data_obj: Any = None
        try:
            pe_data_obj = lief.parse(file_path)
        except Exception:
            pe_data_obj = None
        if pe_data_obj is None:
            try:
                pe_data_obj = pefile.PE(file_path, fast_load=True)
            except Exception as e_pefile_outer:
                logger.debug(f"Feature worker failed to parse {file_path} with both lief and pefile: {e_pefile_outer}")
                # Populate minimal features for unparseable files
                features['execution_time_s'] = time.time() - start_time_file
                features['file_size'] = len(file_data) if file_data else 0
                features['file_size_kb'] = features['file_size'] / 1024.0 if file_data else 0.0
                return features  # Return partially filled features

        # 1. File-level features
        features['xxhash64'] = xxhash.xxh64(file_data).hexdigest()
        features['file_size'] = len(file_data)
        features['file_size_kb'] = features['file_size'] / 1024.0
        features['file_entropy'] = calculate_entropy_static(file_data)
        features['size_tiny'] = 1 if features['file_size'] < 10240 else 0
        features['size_small'] = 1 if 10240 <= features['file_size'] < 102400 else 0
        features['size_medium'] = 1 if 102400 <= features['file_size'] < 1048576 else 0
        features['size_large'] = 1 if 1048576 <= features['file_size'] < 10485760 else 0
        features['size_huge'] = 1 if features['file_size'] >= 10485760 else 0
        features['entropy_low'] = 1 if features['file_entropy'] < 4.0 else 0
        features['entropy_medium'] = 1 if 4.0 <= features['file_entropy'] < 7.0 else 0
        features['entropy_high'] = 1 if features['file_entropy'] >= 7.0 else 0
        features['entropy_packed'] = 1 if features['file_entropy'] > 7.5 else 0  # Common heuristic for packed files
        file_ext = os.path.splitext(file_path)[1].lower()
        features['is_exe_ext'] = 1 if file_ext == '.exe' else 0
        features['is_dll_ext'] = 1 if file_ext == '.dll' else 0
        features['is_scr_ext'] = 1 if file_ext == '.scr' else 0
        features['is_com_ext'] = 1 if file_ext == '.com' else 0

        # 2. Header features
        try:
            if isinstance(pe_data_obj, lief.PE.Binary):  # LIEF parsed object
                if hasattr(pe_data_obj, 'optional_header'):
                    oh = pe_data_obj.optional_header
                    features.update({
                        'image_base': oh.imagebase,
                        'size_of_image': oh.sizeof_image,
                        'size_of_headers': oh.sizeof_headers,
                        'checksum': oh.checksum,
                        # Enum to value where lief returns an enum
                        'subsystem': oh.subsystem.value if hasattr(oh.subsystem, 'value') else oh.subsystem,
                        'dll_characteristics': oh.dll_characteristics,
                        'size_of_stack_reserve': oh.sizeof_stack_reserve,
                        'size_of_stack_commit': oh.sizeof_stack_commit,
                        'size_of_heap_reserve': oh.sizeof_heap_reserve,
                        'size_of_heap_commit': oh.sizeof_heap_commit,
                        'loader_flags': oh.loader_flags,
                        'number_of_rva_and_sizes': oh.numberof_rva_and_size})
                if hasattr(pe_data_obj, 'header'):
                    header = pe_data_obj.header
                    features.update({
                        'machine': header.machine.value if hasattr(header.machine, 'value') else header.machine,
                        'number_of_sections_hdr': header.numberof_sections,
                        'time_date_stamp': header.time_date_stamps,
                        'pointer_to_symbol_table': header.pointerto_symbol_table,
                        'number_of_symbols': header.numberof_symbols,
                        'size_of_optional_header': header.sizeof_optional_header,
                        'characteristics': header.characteristics})
            elif isinstance(pe_data_obj, pefile.PE):  # PEFile parsed object
                if hasattr(pe_data_obj, 'OPTIONAL_HEADER'):
                    oh = pe_data_obj.OPTIONAL_HEADER
                    features.update({
                        'image_base': oh.ImageBase,
                        'size_of_image': oh.SizeOfImage,
                        'size_of_headers': oh.SizeOfHeaders,
                        'checksum': oh.CheckSum,
                        'subsystem': oh.Subsystem,
                        'dll_characteristics': oh.DllCharacteristics,
                        'size_of_stack_reserve': oh.SizeOfStackReserve,
                        'size_of_stack_commit': oh.SizeOfStackCommit,
                        'size_of_heap_reserve': oh.SizeOfHeapReserve,
                        'size_of_heap_commit': oh.SizeOfHeapCommit,
                        'loader_flags': oh.LoaderFlags,
                        'number_of_rva_and_sizes': oh.NumberOfRvaAndSizes})
                if hasattr(pe_data_obj, 'FILE_HEADER'):
                    fh = pe_data_obj.FILE_HEADER
                    features.update({
                        'machine': fh.Machine,
                        'number_of_sections_hdr': fh.NumberOfSections,
                        'time_date_stamp': fh.TimeDateStamp,
                        'pointer_to_symbol_table': fh.PointerToSymbolTable,
                        'number_of_symbols': fh.NumberOfSymbols,
                        'size_of_optional_header': fh.SizeOfOptionalHeader,
                        'characteristics': fh.Characteristics})
        except Exception:  # Catch any error during header extraction
            logger.debug(f"Error extracting header features for {file_path}", exc_info=False)

        # Derived header features (keys default to 0 if header extraction failed)
        features['is_32bit_machine'] = 1 if features.get('machine', 0) == 0x014c else 0  # IMAGE_FILE_MACHINE_I386
        features['is_64bit_machine'] = 1 if features.get('machine', 0) == 0x8664 else 0  # IMAGE_FILE_MACHINE_AMD64
        features['is_dll_char'] = 1 if features.get('characteristics', 0) & 0x2000 else 0  # IMAGE_FILE_DLL
        features['is_executable_char'] = 1 if features.get('characteristics', 0) & 0x0002 else 0  # IMAGE_FILE_EXECUTABLE_IMAGE
        features['is_console_app_subsystem'] = 1 if features.get('subsystem', 0) == 3 else 0  # IMAGE_SUBSYSTEM_WINDOWS_CUI
        features['is_gui_app_subsystem'] = 1 if features.get('subsystem', 0) == 2 else 0  # IMAGE_SUBSYSTEM_WINDOWS_GUI

        # 3. Section features
        sections_info_list = []
        try:
            if isinstance(pe_data_obj, lief.PE.Binary) and hasattr(pe_data_obj, 'sections'):
                for section in pe_data_obj.sections:
                    s_info = {'name': section.name,
                              'size': section.size,
                              'virtual_size': section.virtual_size,
                              'raw_size': section.sizeof_raw_data,
                              'entropy': calculate_entropy_static(bytes(section.content)),
                              'writable': bool(section.characteristics & lief.PE.SECTION_CHARACTERISTICS.MEM_WRITE),
                              'executable': bool(section.characteristics & lief.PE.SECTION_CHARACTERISTICS.MEM_EXECUTE)}
                    sections_info_list.append(s_info)
            elif isinstance(pe_data_obj, pefile.PE) and hasattr(pe_data_obj, 'sections'):
                for section in pe_data_obj.sections:
                    s_info = {'name': section.Name.decode(errors='ignore').rstrip('\x00'),
                              'size': section.SizeOfRawData,
                              'virtual_size': section.Misc_VirtualSize,
                              'raw_size': section.SizeOfRawData,
                              'entropy': calculate_entropy_static(section.get_data()),
                              'writable': bool(section.Characteristics & 0x80000000),  # IMAGE_SCN_MEM_WRITE
                              'executable': bool(section.Characteristics & 0x20000000)}  # IMAGE_SCN_MEM_EXECUTE
                    sections_info_list.append(s_info)
            if sections_info_list:
                features['num_sections'] = len(sections_info_list)
                features['total_section_size'] = sum(s['raw_size'] for s in sections_info_list)
                features['total_virtual_size'] = sum(s['virtual_size'] for s in sections_info_list)
                entropies = [s['entropy'] for s in sections_info_list if s['entropy'] is not None and s['entropy'] > 0]
                if entropies:
                    features.update({'avg_section_entropy': np.mean(entropies),
                                     'max_section_entropy': np.max(entropies),
                                     'min_section_entropy': np.min(entropies),
                                     'std_section_entropy': np.std(entropies)})
                else:  # Ensure these keys exist even if no valid entropies
                    features.update({'avg_section_entropy': 0.0, 'max_section_entropy': 0.0,
                                     'min_section_entropy': 0.0, 'std_section_entropy': 0.0})
                features['num_writable_sections'] = sum(1 for s in sections_info_list if s['writable'])
                features['num_executable_sections'] = sum(1 for s in sections_info_list if s['executable'])
                features['num_writable_executable_sections'] = sum(
                    1 for s in sections_info_list if s['writable'] and s['executable'])
                section_names_upper = [s['name'].upper() for s in sections_info_list]
                standard_sections_set = {'.TEXT', '.DATA', '.RDATA', '.BSS', '.IDATA', '.EDATA',
                                         '.PDATA', '.RSRC', '.RELOC', 'CODE', 'DATA'}
                # Count standard sections more robustly
                features['num_standard_sections'] = sum(
                    1 for name in section_names_upper
                    if name in standard_sections_set or name.startswith(('.TEXT', '.DATA')))
                features['num_custom_sections'] = features.get('num_sections', 0) - features.get('num_standard_sections', 0)
                suspicious_names_set = {'UPX', 'ASPACK', 'FSG', 'RLPACK', 'MEW', 'MPRESS',
                                        'NSIS', 'THEMIDA', 'VMPROTECT', 'ENIGMA'}
                features['has_suspicious_section_names'] = int(
                    any(any(susp_name in s_name for susp_name in suspicious_names_set)
                        for s_name in section_names_upper))
                if features.get('total_section_size', 0) > 0:  # Use .get for safety
                    features['virtual_to_raw_size_ratio'] = features.get('total_virtual_size', 0) / features['total_section_size']
                else:
                    features['virtual_to_raw_size_ratio'] = 0.0  # Avoid division by zero
        except Exception:
            logger.debug(f"Error extracting section features for {file_path}", exc_info=False)

        # 4. Import features
        current_file_apis: Dict[str, int] = {}  # api name -> occurrence count
        current_file_dlls: Set[str] = set()
        try:
            if isinstance(pe_data_obj, lief.PE.Binary) and hasattr(pe_data_obj, 'imports'):
                for imported_lib in pe_data_obj.imports:
                    dll_name_str = imported_lib.name
                    dll_name = dll_name_str.lower() if dll_name_str and isinstance(dll_name_str, str) else 'unknown_dll'
                    current_file_dlls.add(dll_name)
                    for imported_func in imported_lib.entries:
                        if imported_func.name and isinstance(imported_func.name, str):
                            current_file_apis[imported_func.name] = current_file_apis.get(imported_func.name, 0) + 1
            elif isinstance(pe_data_obj, pefile.PE):
                # BUG FIX: the object was created with fast_load=True, so the
                # import directory is not parsed yet and DIRECTORY_ENTRY_IMPORT
                # does not exist. The old elif guard checked for the attribute
                # before parsing, making this branch unreachable. (The previous
                # `hasattr(pe_data_obj, 'parsed_data_directories')` check was
                # also bogus: pefile defines no such attribute.)
                if not hasattr(pe_data_obj, 'DIRECTORY_ENTRY_IMPORT'):
                    pe_data_obj.parse_data_directories(directories=[pefile.DIRECTORY_ENTRY['IMAGE_DIRECTORY_ENTRY_IMPORT']])
                if hasattr(pe_data_obj, 'DIRECTORY_ENTRY_IMPORT'):
                    for entry in pe_data_obj.DIRECTORY_ENTRY_IMPORT:
                        if entry.dll:
                            try:
                                dll_name_str = entry.dll.decode(errors='ignore')
                                dll_name = dll_name_str.lower()
                                current_file_dlls.add(dll_name)
                                for imported_func in entry.imports:
                                    if imported_func.name:
                                        try:
                                            api_name_str = imported_func.name.decode(errors='ignore')
                                            current_file_apis[api_name_str] = current_file_apis.get(api_name_str, 0) + 1
                                        except UnicodeDecodeError:
                                            pass
                            except UnicodeDecodeError:
                                pass
        except Exception:
            logger.debug(f"Error extracting import features for {file_path}", exc_info=False)

        # Populate binary features based on vocabulary (set by _init_pass2_worker_globals)
        for api_from_vocab in g_api_vocab:
            if api_from_vocab in current_file_apis:
                features[f'api_{api_from_vocab}'] = 1
        for dll_from_vocab in g_dll_vocab:
            if dll_from_vocab in current_file_dlls:
                features[f'dll_{dll_from_vocab}'] = 1

        # Aggregate import counts
        features['total_imported_apis_count'] = sum(current_file_apis.values())
        features['unique_imported_apis_count'] = len(current_file_apis)
        features['total_imported_dlls_count'] = len(current_file_dlls)
        if features['total_imported_dlls_count'] > 0:
            features['avg_api_imports_per_dll'] = features.get('total_imported_apis_count', 0) / features['total_imported_dlls_count']
        else:
            features['avg_api_imports_per_dll'] = 0.0
        features['apis_in_vocab_count'] = sum(1 for api in current_file_apis if api in g_api_vocab)
        features['dlls_in_vocab_count'] = sum(1 for dll in current_file_dlls if dll in g_dll_vocab)
        if features['unique_imported_apis_count'] > 0:
            features['api_vocab_coverage'] = features.get('apis_in_vocab_count', 0) / features['unique_imported_apis_count']
        else:
            features['api_vocab_coverage'] = 0.0
        if features['total_imported_dlls_count'] > 0:
            features['dll_vocab_coverage'] = features.get('dlls_in_vocab_count', 0) / features['total_imported_dlls_count']
        else:
            features['dll_vocab_coverage'] = 0.0

        features['execution_time_s'] = time.time() - start_time_file
        return features
    except Exception as e_outer:  # Catch-all for any other unhandled error in this worker
        logger.error(f"Worker unhandled error processing {file_path}: {e_outer}", exc_info=False)
        features['execution_time_s'] = time.time() - start_time_file  # Record time even on error
        try:  # Attempt to get file size if it exists, even on error
            if os.path.exists(file_path):
                features['file_size'] = os.path.getsize(file_path)
                features['file_size_kb'] = features['file_size'] / 1024.0
        except Exception:
            pass  # Ignore if this also fails
        return features  # Return partially filled features
# --- Worker function for processing a CHUNK of files and writing a temp CSV ---
def _worker_process_file_chunk_to_temp_csv(
        file_chunk: List[str],
        temp_dir: str,
        worker_id: int  # 0-indexed for tqdm position
) -> Optional[str]:
    """Extract features for every file in *file_chunk* and write one temp CSV.

    Returns the path of the temporary CSV on success, None when the chunk
    produced no rows or the CSV could not be written. Requires the worker
    globals (g_api_vocab, etc.) set by _init_pass2_worker_globals.
    """
    chunk_rows: List[Dict[str, Any]] = []
    # Each worker owns one progress bar, stacked by worker_id.
    with tqdm(total=len(file_chunk), desc=f"Worker {worker_id:02d}", position=worker_id,
              leave=False, ascii=True,
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar:
        for path in file_chunk:
            row = _extract_single_file_features(path)
            # A dict is always returned, even for unparseable files.
            if row:
                chunk_rows.append(row)
            pbar.update(1)

    if not chunk_rows:
        # Nothing usable came out of this chunk.
        return None

    try:
        chunk_df = pd.DataFrame(chunk_rows)
        # Align to the canonical column set so every temp CSV has an identical
        # schema; missing cells become 0.
        chunk_df = chunk_df.reindex(columns=g_expected_columns_pass2).fillna(0)
    except Exception as e:
        logger.error(f"Worker {worker_id} error creating DataFrame for chunk: {e}")
        return None

    temp_csv_path = os.path.join(temp_dir, f"worker_{worker_id:02d}_features.csv")
    try:
        chunk_df.to_csv(temp_csv_path, index=False, header=True)  # Each temp CSV has a header
        return temp_csv_path
    except Exception as e:
        logger.error(f"Worker {worker_id} failed to write temp CSV {temp_csv_path}: {e}")
        return None
class PEFeatureExtractor:
    """Two-pass, multi-process PE feature extractor.

    Pass 1 (build_vocabulary_from_files) aggregates imported API/DLL names and
    selects a mixed top/bottom-frequency vocabulary. Pass 2
    (extract_features_to_csv) fans file chunks out to worker processes, each of
    which writes a temporary CSV that is merged into the final output.
    """

    def __init__(self, max_api_features: int = 15000, max_dll_features: int = 5000,
                 num_cores: Optional[int] = None,
                 samples_labels: Optional[Dict[str, int]] = None):
        self.max_api_features = max_api_features
        self.max_dll_features = max_dll_features
        # Fall back to the machine's core count when unset or non-positive.
        self.num_cores = num_cores if num_cores is not None and num_cores > 0 else (os.cpu_count() or 1)
        self.api_vocab: Set[str] = set()
        self.dll_vocab: Set[str] = set()
        self.api_frequency = Counter()  # Full frequencies of every API seen
        self.dll_frequency = Counter()  # Full frequencies of every DLL seen
        self._expected_feature_columns: List[str] = []
        self.samples_labels: Dict[str, int] = samples_labels if samples_labels is not None else {}

    @staticmethod
    def calculate_entropy(data: bytes) -> float:
        """Delegate to the module-level Shannon-entropy helper."""
        return calculate_entropy_static(data)

    def _build_expected_feature_columns(self):
        """Defines the canonical list of feature columns AFTER vocabulary is set."""
        base_features = [
            'file_path', 'xxhash64', 'sample_label', 'file_size', 'file_size_kb', 'file_entropy',
            'size_tiny', 'size_small', 'size_medium', 'size_large', 'size_huge',
            'entropy_low', 'entropy_medium', 'entropy_high', 'entropy_packed',
            'is_exe_ext', 'is_dll_ext', 'is_scr_ext', 'is_com_ext',
            'image_base', 'size_of_image', 'size_of_headers', 'checksum', 'subsystem',
            'dll_characteristics', 'size_of_stack_reserve', 'size_of_stack_commit',
            'size_of_heap_reserve', 'size_of_heap_commit', 'loader_flags', 'number_of_rva_and_sizes',
            'machine', 'number_of_sections_hdr', 'time_date_stamp', 'pointer_to_symbol_table',
            'number_of_symbols', 'size_of_optional_header', 'characteristics',
            'is_32bit_machine', 'is_64bit_machine', 'is_dll_char', 'is_executable_char',
            'is_console_app_subsystem', 'is_gui_app_subsystem',
            'num_sections', 'total_section_size', 'total_virtual_size',
            'avg_section_entropy', 'max_section_entropy', 'min_section_entropy', 'std_section_entropy',
            'num_writable_sections', 'num_executable_sections', 'num_writable_executable_sections',
            'num_standard_sections', 'num_custom_sections', 'has_suspicious_section_names',
            'virtual_to_raw_size_ratio',
            'total_imported_apis_count', 'unique_imported_apis_count', 'total_imported_dlls_count',
            'avg_api_imports_per_dll', 'apis_in_vocab_count', 'dlls_in_vocab_count',
            'api_vocab_coverage', 'dll_vocab_coverage', 'execution_time_s'
        ]
        # Vocab-derived one-hot columns, sorted for a deterministic layout.
        api_cols = sorted([f'api_{api}' for api in self.api_vocab])
        dll_cols = sorted([f'dll_{dll}' for dll in self.dll_vocab])
        # Identity/bookkeeping columns come first; then remaining base features.
        ordered_cols = ['file_path', 'xxhash64', 'sample_label', 'file_size', 'file_size_kb', 'execution_time_s']
        remaining_base = [f for f in base_features if f not in ordered_cols]
        combined_cols = ordered_cols + remaining_base + api_cols + dll_cols
        # dict.fromkeys keeps first occurrence, giving an order-preserving dedupe.
        self._expected_feature_columns = list(dict.fromkeys(combined_cols))

    def build_vocabulary_from_files(self, file_paths: List[str]) -> None:
        """Pass 1: scan *file_paths* in parallel and select API/DLL vocabularies."""
        logger.info(f"Starting vocabulary building: {len(file_paths)} files, {self.num_cores} cores.")
        if not file_paths:
            logger.warning("No files for vocab. Vocab will be empty.")
            self.api_vocab = set()
            self.dll_vocab = set()
            self.api_frequency = Counter()
            self.dll_frequency = Counter()
            self._build_expected_feature_columns()  # Build with empty vocabs
            return

        pass_start_time = time.time()
        aggregated_apis: Counter = Counter()
        aggregated_dlls: Counter = Counter()
        processed_count = 0
        failed_count = 0
        with multiprocessing.Pool(processes=self.num_cores) as pool:
            results_iter = pool.imap_unordered(_worker_extract_imports_for_vocab, file_paths)
            with tqdm(total=len(file_paths), desc="Vocab Building", ascii=True,
                      bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
                for _file_path, apis, dlls in results_iter:
                    if apis is not None and dlls is not None:
                        aggregated_apis.update(apis)
                        aggregated_dlls.update(dlls)
                        processed_count += 1
                    else:
                        failed_count += 1
                    pbar.update(1)

        # Keep the complete (unfiltered-by-size) frequency tables on the instance.
        self.api_frequency = Counter({str(k): v for k, v in aggregated_apis.items() if k and str(k).strip()})
        self.dll_frequency = Counter({str(k): v for k, v in aggregated_dlls.items() if k and str(k).strip()})

        def pick_mixed_vocab(counts: Counter, target_size: int) -> Set[str]:
            # Selects everything when it fits, otherwise half from the most
            # frequent end and half from the least frequent end (50/50 split).
            total_unique = len(counts)
            if total_unique == 0:
                return set()
            if total_unique <= target_size:
                return {str(name) for name, _cnt in counts.items() if name and str(name).strip()}
            num_top = target_size // 2
            num_bottom = target_size - num_top  # Handles an odd target size
            ranked = counts.most_common()  # (name, count) pairs, most common first
            top = {str(name) for name, _cnt in ranked[:num_top] if name and str(name).strip()}
            bottom = {str(name) for name, _cnt in ranked[-num_bottom:] if name and str(name).strip()}
            return top | bottom

        # --- API / DLL Vocabulary Selection (50% top, 50% bottom) ---
        self.api_vocab = pick_mixed_vocab(aggregated_apis, self.max_api_features)
        self.dll_vocab = pick_mixed_vocab(aggregated_dlls, self.max_dll_features)

        # Rebuild the canonical column list from the freshly selected vocabularies.
        self._build_expected_feature_columns()
        logger.info(
            f"Vocab building completed. Took {time.time() - pass_start_time:.2f}s.\n"
            f" - Files processed for vocab: {processed_count}, Failed: {failed_count}\n"
            f" - Total unique APIs found: {len(self.api_frequency)}, Selected for API vocab: {len(self.api_vocab)} (target mix for: {self.max_api_features})\n"
            f" - Total unique DLLs found: {len(self.dll_frequency)}, Selected for DLL vocab: {len(self.dll_vocab)} (target mix for: {self.max_dll_features})"
        )

    def extract_features_to_csv(self, file_paths_to_process: List[str], output_csv_file: str) -> pd.DataFrame:
        """Pass 2: extract features for all files in parallel and write one CSV.

        Returns the merged DataFrame on success, or an empty (column-only)
        DataFrame when nothing could be extracted.
        """
        if not file_paths_to_process:
            logger.warning("No files provided for feature extraction.")
            return pd.DataFrame(columns=self._expected_feature_columns or [])
        if not self.api_vocab and not self.dll_vocab:
            logger.warning("API and DLL vocabularies are empty. API/DLL related features will be zero.")
        # Columns may not exist yet when a vocabulary was loaded from disk.
        if not self._expected_feature_columns:
            logger.info("Expected feature columns not built (e.g. vocab loaded). Building now.")
            self._build_expected_feature_columns()
            if not self._expected_feature_columns:  # Still empty after build attempt
                logger.error("FATAL: _expected_feature_columns is empty even after build attempt. Cannot proceed.")
                pd.DataFrame(columns=['file_path']).to_csv(output_csv_file, index=False, header=True)
                return pd.DataFrame()

        logger.info(f"Starting feature extraction for {len(file_paths_to_process)} files. Output will be '{output_csv_file}'.")
        temp_dir = tempfile.mkdtemp(prefix="pe_feat_chunks_")
        # Never spawn more workers than there are files.
        actual_num_workers = min(self.num_cores, len(file_paths_to_process))
        if actual_num_workers == 0:  # Defensive; the empty-input check above should catch this
            logger.warning("No files to process after worker adjustment (should not happen if files exist).")
            shutil.rmtree(temp_dir, ignore_errors=True)
            return pd.DataFrame(columns=self._expected_feature_columns or [])

        chunk_size = math.ceil(len(file_paths_to_process) / actual_num_workers)
        file_chunks = [file_paths_to_process[i:i + chunk_size]
                       for i in range(0, len(file_paths_to_process), chunk_size)]
        # (chunk, temp_dir, worker_id) triples; empty chunks are dropped.
        worker_args = [(chunk, temp_dir, idx) for idx, chunk in enumerate(file_chunks) if chunk]
        init_args_pass2 = (self.api_vocab, self.dll_vocab, self._expected_feature_columns, self.samples_labels)

        temp_csv_paths: List[str] = []
        all_workers_completed_successfully = True
        try:
            logger.info(f"Launching {len(worker_args)} worker(s) for feature extraction...")
            with multiprocessing.Pool(processes=actual_num_workers,
                                      initializer=_init_pass2_worker_globals,
                                      initargs=init_args_pass2) as pool:
                # starmap blocks until every chunk has been processed.
                results = pool.starmap(_worker_process_file_chunk_to_temp_csv, worker_args)

            # Push the cursor past the stacked per-worker tqdm bars (TTY only).
            if sys.stderr.isatty():
                sys.stderr.write("\n" * (actual_num_workers + 2))
                sys.stderr.flush()

            for temp_path_result in results:
                if temp_path_result and os.path.exists(temp_path_result) and os.path.getsize(temp_path_result) > 0:
                    temp_csv_paths.append(temp_path_result)
                elif temp_path_result is None:  # Worker reported an error or produced no rows
                    all_workers_completed_successfully = False
            logger.info(f"All worker processes completed. {len(temp_csv_paths)} temporary CSVs collected for merge.")
            if not all_workers_completed_successfully and len(worker_args) > len(temp_csv_paths):
                logger.warning(f"One or more workers ({len(worker_args) - len(temp_csv_paths)}) may not have completed successfully or produced output files.")

            if temp_csv_paths:
                logger.info(f"Concatenating {len(temp_csv_paths)} chunk CSVs...")
                all_chunk_dfs = []
                for p in tqdm(temp_csv_paths, desc="Merging temp CSVs", ascii=True,
                              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}'):
                    try:
                        all_chunk_dfs.append(pd.read_csv(p))
                    except pd.errors.EmptyDataError:
                        logger.warning(f"Temporary CSV {p} is empty, skipping.")
                    except Exception as e:
                        logger.error(f"Error reading temporary CSV {p}: {e}")
                if all_chunk_dfs:
                    final_df = pd.concat(all_chunk_dfs, ignore_index=True)
                    # Final reindex enforces canonical column order/presence.
                    final_df = final_df.reindex(columns=self._expected_feature_columns).fillna(0)
                    final_df.to_csv(output_csv_file, index=False, header=True)
                    logger.info(f"Successfully wrote {len(final_df)} total entries to {output_csv_file}")
                    return final_df
                else:
                    logger.warning("No data in temporary CSVs after attempting to read. Output CSV will likely be empty or contain only headers.")
                    pd.DataFrame(columns=self._expected_feature_columns).to_csv(output_csv_file, index=False, header=True)
            else:
                logger.warning("No valid temporary CSVs were produced by workers. Output CSV will be empty or contain only headers.")
                pd.DataFrame(columns=self._expected_feature_columns).to_csv(output_csv_file, index=False, header=True)
        except Exception as e:
            logger.error(f"Major error during feature extraction pool or concatenation: {e}", exc_info=True)
            # Leave behind at least a header-only CSV on catastrophic failure.
            pd.DataFrame(columns=self._expected_feature_columns or ['file_path']).to_csv(output_csv_file, index=False, header=True)
            return pd.DataFrame(columns=self._expected_feature_columns or [])
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)  # Clean up temp dir regardless of outcome
        return pd.DataFrame(columns=self._expected_feature_columns or [])  # Default return

    def save_vocabulary(self, vocab_file: str) -> None:
        """Persist vocabularies, full frequency tables and config to JSON."""
        valid_api_vocab = sorted([str(api) for api in self.api_vocab if api and str(api).strip()])
        valid_dll_vocab = sorted([str(dll) for dll in self.dll_vocab if dll and str(dll).strip()])
        # Serialize the complete frequency counters kept on the instance.
        api_freq_serializable = {str(k): v for k, v in sorted(self.api_frequency.items()) if k and str(k).strip()}
        dll_freq_serializable = {str(k): v for k, v in sorted(self.dll_frequency.items()) if k and str(k).strip()}
        vocab_data = {
            'api_vocab': valid_api_vocab,  # The selected vocabulary
            'dll_vocab': valid_dll_vocab,  # The selected vocabulary
            'api_frequency': api_freq_serializable,  # Full frequencies of all found items
            'dll_frequency': dll_freq_serializable,  # Full frequencies of all found items
            'max_api_features': self.max_api_features,  # Target config
            'max_dll_features': self.max_dll_features,  # Target config
            '_expected_feature_columns': self._expected_feature_columns  # Current feature columns
        }
        try:
            with open(vocab_file, 'w') as f:
                json.dump(vocab_data, f, indent=2)
            logger.info(f"Vocabulary saved to {vocab_file}")
        except Exception as e:
            logger.error(f"Failed to save vocabulary to {vocab_file}: {e}")

    def load_vocabulary(self, vocab_file: str) -> bool:
        """Load vocabularies and frequencies from JSON; returns True on success.

        On failure the instance is reset to empty vocabularies.
        """
        try:
            with open(vocab_file, 'r') as f:
                vocab_data = json.load(f)
            # The selected vocabularies
            self.api_vocab = {str(api) for api in vocab_data.get('api_vocab', []) if api and str(api).strip()}
            self.dll_vocab = {str(dll) for dll in vocab_data.get('dll_vocab', []) if dll and str(dll).strip()}
            # Full frequency tables
            self.api_frequency = Counter({str(k): v for k, v in vocab_data.get('api_frequency', {}).items() if k and str(k).strip()})
            self.dll_frequency = Counter({str(k): v for k, v in vocab_data.get('dll_frequency', {}).items() if k and str(k).strip()})
            # Config parameters (fall back to current values when absent)
            self.max_api_features = vocab_data.get('max_api_features', self.max_api_features)
            self.max_dll_features = vocab_data.get('max_dll_features', self.max_dll_features)
            # Always rebuild the column list from the loaded vocabs rather than
            # trusting the stored one — safer for forward compatibility.
            self._build_expected_feature_columns()
            logger.info(f"Vocabulary loaded from {vocab_file}")
            logger.info(f" - API features in loaded vocab: {len(self.api_vocab)}")
            logger.info(f" - DLL features in loaded vocab: {len(self.dll_vocab)}")
            logger.info(f" - Total expected columns after load: {len(self._expected_feature_columns)}")
            return True
        except FileNotFoundError:
            logger.error(f"Vocabulary file {vocab_file} not found.")
            return False
        except Exception as e:
            logger.error(f"Failed to load vocabulary from {vocab_file}: {e}")
            # Reset to defaults on failed load
            self.api_vocab, self.dll_vocab = set(), set()
            self.api_frequency, self.dll_frequency = Counter(), Counter()
            self._build_expected_feature_columns()  # Rebuild with empty vocabs
            return False
logger.info(f"Vocabulary loaded from {vocab_file}") logger.info(f" - API features in loaded vocab: {len(self.api_vocab)}") logger.info(f" - DLL features in loaded vocab: {len(self.dll_vocab)}") logger.info(f" - Total expected columns after load: {len(self._expected_feature_columns)}") return True except FileNotFoundError: logger.error(f"Vocabulary file {vocab_file} not found.") return False except Exception as e: logger.error(f"Failed to load vocabulary from {vocab_file}: {e}") # Reset to defaults on failed load self.api_vocab, self.dll_vocab = set(), set() self.api_frequency, self.dll_frequency = Counter(), Counter() self._build_expected_feature_columns() # Rebuild with empty vocabs return False def main(): parser = argparse.ArgumentParser(description='Extract PE features (Overwrites output, parallel per-process temp CSVs, 50/50 vocab strategy)') parser.add_argument('input_path', help='PE file or directory containing PE files') parser.add_argument('-o', '--output', help='Output CSV file (default: comprehensive_analysis.csv)', default='comprehensive_analysis.csv') parser.add_argument('samples_csv', help='Path to the CSV file containing sample IDs (filenames) and labels (in a "list" column). Example: id,list -> file1.exe,Blacklist') parser.add_argument('-r', '--recursive', action='store_true', help='Process directories recursively') parser.add_argument('--max-apis', type=int, default=15000, help='Target number of API features in vocab (default: 15000)') # Adjusted default from 50k to 15k as example parser.add_argument('--max-dlls', type=int, default=1000, help='Target number of DLL features in vocab (default: 1000)') # Adjusted default from 5k to 1k as example parser.add_argument('--save-vocab', type=str, help='Save vocabulary to a specific JSON file. If not set, auto-saves to .vocab.json') parser.add_argument('--load-vocab', type=str, help='Load vocabulary from a specific JSON file. 
Overrides auto-load from .vocab.json') parser.add_argument('--cores', type=int, default=None, help='Number of CPU cores to use (default: all available, min 1)') parser.add_argument('-v', '--verbose', action='store_true', help='Enable verbose logging (includes LIEF/PEFile debug messages)') args = parser.parse_args() if args.verbose: logging.getLogger().setLevel(logging.DEBUG) # Keep LIEF/PEFile logs less verbose unless truly debugging them, even in verbose mode logging.getLogger('lief').setLevel(logging.INFO) logging.getLogger('pefile').setLevel(logging.INFO) logging.getLogger('tqdm').setLevel(logging.INFO) # Show tqdm info if verbose else: lief.logging.disable() # Disable LIEF's own logging system logging.getLogger('lief').setLevel(logging.ERROR) # Set LIEF's Python logger to ERROR logging.getLogger('pefile').setLevel(logging.ERROR) # Set PEFile's Python logger to ERROR logging.getLogger('tqdm').setLevel(logging.ERROR) # Suppress tqdm info logs if not verbose num_cores_to_use = args.cores if args.cores is not None and args.cores > 0 else (os.cpu_count() or 1) logger.info(f"Using {num_cores_to_use} CPU cores for parallel processing.") # --- Load Samples CSV for Labels --- samples_labels_map: Dict[str, int] = {} if args.samples_csv: try: samples_df = pd.read_csv(args.samples_csv) if "id" not in samples_df.columns or "list" not in samples_df.columns: logger.error(f"Samples CSV '{args.samples_csv}' must contain 'id' and 'list' columns.") sys.exit(1) samples_df['id'] = samples_df['id'].astype(str) # Ensure filename is string for matching for _, row in samples_df.iterrows(): sample_id = row['id'] # This is expected to be the filename label_list_value = row['list'] samples_labels_map[sample_id] = 1 if label_list_value == "Blacklist" else 0 # Binary label logger.info(f"Successfully loaded {len(samples_labels_map)} labels from '{args.samples_csv}'.") except FileNotFoundError: logger.error(f"Samples CSV file not found: {args.samples_csv}") sys.exit(1) except Exception as e: 
logger.error(f"Error reading or processing samples CSV '{args.samples_csv}': {e}") sys.exit(1) else: logger.warning("No samples_csv provided. 'sample_label' feature will be 0 for all files.") extractor = PEFeatureExtractor( max_api_features=args.max_apis, max_dll_features=args.max_dlls, num_cores=num_cores_to_use, samples_labels=samples_labels_map ) # --- File Collection --- input_file_paths_discovered = [] if os.path.isfile(args.input_path): input_file_paths_discovered = [args.input_path] elif os.path.isdir(args.input_path): if args.recursive: for root, _, files_in_dir in os.walk(args.input_path): for file_name in files_in_dir: input_file_paths_discovered.append(os.path.join(root, file_name)) else: # Not recursive, just top-level directory for file_name in os.listdir(args.input_path): full_path = os.path.join(args.input_path, file_name) if os.path.isfile(full_path): # Only add if it's a file input_file_paths_discovered.append(full_path) else: logger.error(f"Input path {args.input_path} is not a valid file or directory") sys.exit(1) # Filter for unique, existing, non-empty files unique_abs_file_paths = sorted(list(set(os.path.abspath(fp) for fp in input_file_paths_discovered))) valid_input_files = [] for fp in unique_abs_file_paths: try: if os.path.exists(fp) and os.path.isfile(fp) and os.path.getsize(fp) > 0: valid_input_files.append(fp) # else: logger.debug(f"Skipping {fp}: Not a valid, non-empty file or inaccessible.") # Can be noisy except OSError as e: # Handles permission errors for os.path.getsize etc. 
logger.debug(f"Skipping {fp} due to OSError: {e}") pass if not valid_input_files: logger.error("No valid, existing, non-empty PE files found to process after filtering.") sys.exit(1) logger.info(f"Found {len(valid_input_files)} unique, existing, and non-empty files to consider for processing.") output_csv_abspath = os.path.abspath(args.output) # Script always overwrites or creates new output CSV logger.info(f"Output CSV '{output_csv_abspath}' will be created or overwritten.") # --- Vocabulary Management --- vocab_loaded_successfully = False associated_vocab_file = output_csv_abspath + ".vocab.json" # Default vocab file associated with output CSV # Initialize expected columns once, even if empty. build_vocabulary_from_files will call it again. # load_vocabulary will also call it. extractor._build_expected_feature_columns() if args.load_vocab: # User explicitly specified a vocab file to load logger.info(f"Attempting to load user-specified vocabulary from: {args.load_vocab}") if extractor.load_vocabulary(args.load_vocab): vocab_loaded_successfully = True else: logger.warning(f"Failed to load from {args.load_vocab}. Will attempt to build new vocabulary if needed.") elif os.path.exists(associated_vocab_file): # Attempt to auto-load associated vocab file logger.info(f"Attempting to load associated vocabulary from: {associated_vocab_file}") if extractor.load_vocabulary(associated_vocab_file): vocab_loaded_successfully = True else: logger.warning(f"Failed to load associated_vocab_file '{associated_vocab_file}'. 
Will build new vocabulary.") if not vocab_loaded_successfully: logger.info("Building new vocabulary from all valid input files...") extractor.build_vocabulary_from_files(valid_input_files) # Build from all files # Auto-save the newly built vocabulary if it contains data if extractor.api_vocab or extractor.dll_vocab: logger.info(f"Auto-saving newly built vocabulary to: {associated_vocab_file}") extractor.save_vocabulary(associated_vocab_file) else: logger.info("Vocabulary previously loaded. Skipping vocabulary building pass.") # Ensure columns are correctly built even if vocab loaded, as per current class definition # load_vocabulary already calls _build_expected_feature_columns # extractor._build_expected_feature_columns() # Not strictly needed here if load_vocabulary does it. if not extractor.api_vocab and not extractor.dll_vocab: logger.warning("Vocabulary is empty. API/DLL related features will be all zeros.") if not extractor._expected_feature_columns: logger.error("FATAL: Expected feature columns are not set even after vocabulary phase. Cannot proceed.") sys.exit(1) elif 'sample_label' not in extractor._expected_feature_columns: # Critical check logger.error("FATAL: 'sample_label' is missing from expected feature columns. 
Check _build_expected_feature_columns().") sys.exit(1) # --- Feature Extraction --- overall_start_time = time.time() df_this_session = extractor.extract_features_to_csv( valid_input_files, # Process all valid files output_csv_abspath ) overall_end_time = time.time() total_script_execution_time = overall_end_time - overall_start_time # --- Save Vocabulary (if specified by --save-vocab or if it's different from associated one and has content) --- custom_save_vocab_path = args.save_vocab # Save if there's content and either a custom path is given or the associated one needs update if extractor.api_vocab or extractor.dll_vocab or extractor._expected_feature_columns: if custom_save_vocab_path: logger.info(f"Saving vocabulary (and expected columns) to user-specified path: {custom_save_vocab_path}") extractor.save_vocabulary(custom_save_vocab_path) # If no custom path, the associated one was already saved if newly built. # If a custom path was used and it's different from associated, maybe update associated too? # For simplicity, if custom_save_vocab, only save there. Associated is handled during build. # The auto-save during build handles the associated_vocab_file. # This section is mainly for explicit --save-vocab. 
elif not (extractor.api_vocab or extractor.dll_vocab or extractor._expected_feature_columns): logger.warning("Vocabulary and expected columns are empty, skipping final save_vocabulary.") # --- Final Summary --- print(f"\n--- Feature Extraction Run Summary ---") # Use print for final summary to ensure visibility if not df_this_session.empty: print(f"Files processed in this session: {len(df_this_session)}") print(f"Total features per file: {len(df_this_session.columns)}") if 'execution_time_s' in df_this_session.columns: total_feature_extraction_time_sum = df_this_session['execution_time_s'].sum() avg_time_per_file = df_this_session['execution_time_s'].mean() print(f"Sum of individual file processing times (from feature vectors): {total_feature_extraction_time_sum:.2f} seconds") print(f"Average processing time per file (from feature vectors): {avg_time_per_file:.4f} seconds") memory_mb = df_this_session.memory_usage(deep=True).sum() / (1024 * 1024) print(f"Final DataFrame memory usage (this session's data): {memory_mb:.2f} MB") if 'sample_label' in df_this_session.columns: label_counts = df_this_session['sample_label'].value_counts().to_dict() print(f"Sample labels in this session's output: {label_counts}") else: print("No new files were processed in this session, or no features extracted.") if os.path.exists(output_csv_abspath) and os.path.getsize(output_csv_abspath) > 0 : print(f"Output CSV: {output_csv_abspath}") # Optionally, re-read to confirm total count if needed, but df_this_session represents this run's output else: print(f"Output CSV {output_csv_abspath} was not created or is empty.") print(f"Total script execution time (this session): {total_script_execution_time:.2f} seconds") if __name__ == '__main__': multiprocessing.freeze_support() # Good practice for Windows main()