https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/149215
# Patch metadata (preserved from the original mailing-list post):
#   From 5a8f74a2c56d6052bf1db29fe3c16950c50c3987 Mon Sep 17 00:00:00 2001
#   From: svkeerthy <venkatakeer...@google.com>
#   Date: Wed, 16 Jul 2025 22:45:36 +0000
#   Subject: [PATCH] triplet-ext-script
#
#   llvm/docs/CommandGuide/llvm-ir2vec.rst            |   3 +
#   llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py  | 291 ++++++++++++++++++
#   2 files changed, 294 insertions(+)
#   create mode 100644 llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py
#
# Accompanying documentation change in llvm/docs/CommandGuide/llvm-ir2vec.rst
# (index 56ece4f509f6e..e39a663e3be5a, hunk @@ -50,6 +50,9 @@), inserted after
# the OpenKE data-format reference and before "Triplet Generation Mode":
#
#   See `llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py` for more details on how
#   these two modes are used to generate the triplets and entity mappings.
#
# New file: llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py
# (index 0000000000000..0858d10ce0138)
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""IR2Vec Triplet Generator

Generates IR2Vec triplets by applying random optimization levels to LLVM IR files
and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
files: entity2id.txt, relation2id.txt, and train2id.txt.

Usage:
    python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
"""

import argparse
import logging
import os
import random
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Iterable, List, Set, Tuple

# Configuration
OPT_LEVELS = ["O0", "O1", "O2", "O3", "Os", "Oz"]
DEFAULT_MAX_WORKERS = 100

logger = logging.getLogger(__name__)


class TripletResult:
    """Result from processing a single LLVM IR file.

    Attributes:
        triplets: Unique triplet lines collected for the file across all the
            optimization levels that were applied to it.
        max_relation: Largest ``MAX_RELATION`` value reported for the file
            (1 when nothing was reported).
    """

    __slots__ = ["triplets", "max_relation"]

    def __init__(self, triplets: Set[str], max_relation: int):
        self.triplets = triplets
        self.max_relation = max_relation


class IR2VecTripletGenerator:
    """Main class for generating IR2Vec triplets.

    For every input file a random subset of ``OPT_LEVELS`` is chosen, the file
    is run through ``opt`` at each chosen level, and the result is fed to
    ``llvm-ir2vec --mode=triplets``. The union of all triplets is written to
    ``train2id.txt`` together with ``entity2id.txt`` and ``relation2id.txt``.
    """

    def __init__(
        self,
        llvm_build_dir: Path,
        num_optimizations: int,
        output_dir: Path,
        max_workers: int = DEFAULT_MAX_WORKERS,
    ):
        """Store the configuration and validate it.

        Args:
            llvm_build_dir: LLVM build directory containing ``bin/opt`` and
                ``bin/llvm-ir2vec``.
            num_optimizations: How many distinct optimization levels to apply
                to each input file (1 to ``len(OPT_LEVELS)``).
            output_dir: Directory for the generated files; created if missing.
            max_workers: Size of the thread pool used to process files.

        Raises:
            FileNotFoundError: If the build directory or a required tool is
                missing or not executable.
            ValueError: If ``num_optimizations`` is out of range.
        """
        # Coerce to Path so plain string arguments work as well; the
        # validation below relies on Path methods.
        self.llvm_build_dir = Path(llvm_build_dir)
        self.num_optimizations = num_optimizations
        self.output_dir = Path(output_dir)
        self.max_workers = max_workers

        # Tool paths
        self.opt_binary = os.path.join(self.llvm_build_dir, "bin", "opt")
        self.ir2vec_binary = os.path.join(self.llvm_build_dir, "bin", "llvm-ir2vec")

        self._validate_setup()

    def _validate_setup(self):
        """Validate that all required tools and paths exist.

        Also creates the output directory (including parents) when all checks
        pass.

        Raises:
            FileNotFoundError: If the build directory, ``opt`` or
                ``llvm-ir2vec`` is missing or not executable.
            ValueError: If ``num_optimizations`` is not within
                1..``len(OPT_LEVELS)``.
        """
        if not self.llvm_build_dir.exists():
            raise FileNotFoundError(
                f"LLVM build directory not found: {self.llvm_build_dir}"
            )

        if not os.path.isfile(self.opt_binary) or not os.access(
            self.opt_binary, os.X_OK
        ):
            raise FileNotFoundError(
                f"opt binary not found or not executable: {self.opt_binary}"
            )

        if not os.path.isfile(self.ir2vec_binary) or not os.access(
            self.ir2vec_binary, os.X_OK
        ):
            raise FileNotFoundError(
                f"llvm-ir2vec binary not found or not executable: {self.ir2vec_binary}"
            )

        if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
            raise ValueError(
                f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
            )

        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _select_optimization_levels(self) -> List[str]:
        """Select unique random optimization levels."""
        return random.sample(OPT_LEVELS, self.num_optimizations)

    def _process_single_file(self, input_file: Path) -> TripletResult:
        """Process a single LLVM IR file with multiple optimization levels.

        Failures for an individual optimization level are logged as warnings
        and skipped (best effort), so one bad level does not discard the
        triplets obtained from the others.
        """
        all_triplets = set()
        max_relation = 1
        opt_levels = self._select_optimization_levels()

        for opt_level in opt_levels:
            try:
                triplets, file_max_relation = self._run_pipeline(input_file, opt_level)
                if triplets:
                    all_triplets.update(triplets)
                    max_relation = max(max_relation, file_max_relation)
                    logger.debug(
                        f"Generated {len(triplets)} triplets for {input_file} with {opt_level}"
                    )
            except Exception as e:
                logger.warning(f"Error processing {input_file} with {opt_level}: {e}")

        return TripletResult(all_triplets, max_relation)

    def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int]:
        """Run the ``opt | llvm-ir2vec`` pipeline for one file and level.

        The two tools are started directly with argument lists (no shell), so
        file names containing quotes, ``$`` or other shell metacharacters are
        passed through verbatim. The output of ``opt`` is handed to
        ``llvm-ir2vec`` on standard input.

        Returns:
            ``(triplets, max_relation)`` as produced by
            ``_parse_triplet_output``; ``(set(), 1)`` if either tool exits
            with a non-zero status.
        """
        opt_cmd = [str(self.opt_binary), f"-{opt_level}", str(input_file), "-o", "-"]
        ir2vec_cmd = [str(self.ir2vec_binary), "--mode=triplets", "-", "-o", "-"]

        try:
            # opt's stdout is kept as bytes: it is tool-to-tool data that is
            # forwarded unchanged, never interpreted as text here.
            optimized = subprocess.run(opt_cmd, capture_output=True, check=True)
            result = subprocess.run(
                ir2vec_cmd, input=optimized.stdout, capture_output=True, check=True
            )
        except subprocess.CalledProcessError as e:
            # Best effort: a failing tool yields no triplets for this level,
            # but leave a trace for --verbose runs instead of failing silently.
            stderr = (e.stderr or b"").decode("utf-8", errors="replace").strip()
            logger.debug(
                "Pipeline failed for %s with %s (exit code %s): %s",
                input_file,
                opt_level,
                e.returncode,
                stderr,
            )
            return set(), 1

        return self._parse_triplet_output(
            result.stdout.decode("utf-8", errors="replace")
        )

    def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
        """Parse triplet output and extract max relation.

        Args:
            output: Text printed by ``llvm-ir2vec --mode=triplets``. An
                optional first line ``MAX_RELATION=<n>`` is treated as
                metadata; every other non-blank line is one triplet.

        Returns:
            ``(triplets, max_relation)`` where ``triplets`` is the set of
            unique non-blank lines and ``max_relation`` defaults to 1 when no
            metadata line is present.

        Raises:
            ValueError: If the ``MAX_RELATION=`` value is not an integer.
        """
        # Strip each line and drop blanks so empty lines or CRLF endings never
        # end up in the training data as bogus triplets.
        lines = [line.strip() for line in output.splitlines()]
        lines = [line for line in lines if line]
        if not lines:
            return set(), 1

        max_relation = 1

        # Extract max relation from metadata line
        if lines[0].startswith("MAX_RELATION="):
            max_relation = int(lines[0].split("=", 1)[1])
            lines = lines[1:]

        # Remove duplicate triplets by converting to a set
        return set(lines), max_relation

    def generate_triplets(self, file_list: Path) -> None:
        """Main method to generate triplets from a list of LLVM IR files.

        Args:
            file_list: Text file with one LLVM IR file path per line.

        Raises:
            ValueError: If the list contains no existing files.
        """
        input_files = self._read_file_list(file_list)
        logger.info(
            f"Processing {len(input_files)} files with {self.num_optimizations} "
            f"optimization levels using {self.max_workers} workers"
        )

        all_triplets = set()
        global_max_relation = 1

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {
                executor.submit(self._process_single_file, file): file
                for file in input_files
            }

            for future in as_completed(future_to_file):
                try:
                    result = future.result()
                    all_triplets.update(result.triplets)
                    global_max_relation = max(global_max_relation, result.max_relation)
                except Exception as e:
                    file_path = future_to_file[future]
                    logger.error(f"Error processing {file_path}: {e}")

        self._generate_output_files(all_triplets, global_max_relation)
        logger.info("Processing completed successfully")

    def _read_file_list(self, file_list: Path) -> List[Path]:
        """Read and validate the list of input files.

        Blank lines are ignored; paths that do not exist are skipped with a
        warning.

        Raises:
            ValueError: If no listed path exists.
        """
        input_files = []
        with open(file_list, "r") as f:
            for line_num, line in enumerate(f, 1):
                if line := line.strip():
                    file_path = Path(line)
                    if file_path.exists():
                        input_files.append(file_path)
                    else:
                        logger.warning(f"File not found (line {line_num}): {file_path}")

        if not input_files:
            raise ValueError("No valid input files found")
        return input_files

    def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> None:
        """Generate the final output files.

        Writes train2id.txt, entity2id.txt and relation2id.txt into
        ``self.output_dir``.
        """
        logger.info(f"Generating output files with {len(all_triplets)} unique triplets")

        # Write all output files -- train2id.txt, entity2id.txt, relation2id.txt
        train2id_file = os.path.join(self.output_dir, "train2id.txt")
        entity2id_file = os.path.join(self.output_dir, "entity2id.txt")
        relation2id_file = os.path.join(self.output_dir, "relation2id.txt")

        self._write_train2id(train2id_file, all_triplets)
        self._generate_entity2id(entity2id_file)
        self._generate_relation2id(relation2id_file, max_relation)

    def _write_train2id(self, output_file: Path, triplets: Iterable[str]) -> None:
        """Write train2id.txt: a count line followed by one triplet per line.

        Triplets are written in sorted order so that repeated runs over the
        same data produce byte-identical files (set iteration order varies
        between interpreter runs).
        """
        ordered = sorted(triplets)
        with open(output_file, "w") as f:
            f.write(f"{len(ordered)}\n")
            f.writelines(f"{triplet}\n" for triplet in ordered)

    def _generate_entity2id(self, output_file: Path) -> None:
        """Generate entity2id.txt using llvm-ir2vec.

        Raises:
            subprocess.CalledProcessError: If llvm-ir2vec exits with a
                non-zero status.
        """
        subprocess.run(
            [str(self.ir2vec_binary), "--mode=entities", "-o", str(output_file)],
            check=True,
            capture_output=True,
        )

    def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
        """Generate relation2id.txt from max relation.

        The file starts with the number of relations, followed by ``Type``
        (id 0), ``Next`` (id 1) and ``Arg0`` .. ``ArgN`` for the remaining ids
        up to and including ``max_relation``.
        """
        max_relation = max(max_relation, 1)  # At least Type and Next relations
        num_relations = max_relation + 1

        with open(output_file, "w") as f:
            f.write(f"{num_relations}\n")
            f.write("Type\t0\n")
            f.write("Next\t1\n")
            f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))


def main():
    """Main entry point.

    Parses the command line, configures logging and runs the generator.
    Exits with status 1 if any error is raised.
    """
    parser = argparse.ArgumentParser(
        description="Generate IR2Vec triplets from LLVM IR files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "llvm_build_dir", type=Path, help="Path to LLVM build directory"
    )
    parser.add_argument(
        "num_optimizations",
        type=int,
        # Derive the upper bound from OPT_LEVELS so the help text cannot go
        # stale if levels are added or removed.
        help=f"Number of optimization levels to apply (1-{len(OPT_LEVELS)})",
    )
    parser.add_argument(
        "ll_file_list",
        type=Path,
        help="File containing list of LLVM IR files to process",
    )
    parser.add_argument(
        "output_dir", type=Path, help="Output directory for generated files"
    )
    parser.add_argument(
        "-j",
        "--max-workers",
        type=int,
        default=DEFAULT_MAX_WORKERS,
        help=f"Maximum number of parallel workers (default: {DEFAULT_MAX_WORKERS})",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Enable debug logging"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Suppress all output except errors"
    )

    args = parser.parse_args()

    # Configure logging; --quiet takes precedence when both flags are given.
    level = (
        logging.ERROR
        if args.quiet
        else (logging.DEBUG if args.verbose else logging.INFO)
    )
    logging.basicConfig(
        level=level,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    try:
        generator = IR2VecTripletGenerator(
            args.llvm_build_dir,
            args.num_optimizations,
            args.output_dir,
            args.max_workers,
        )
        generator.generate_triplets(args.ll_file_list)
    except Exception as e:
        logger.error(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()

# Mailing-list footer (preserved from the original post):
#   llvm-branch-commits mailing list
#   llvm-branch-commits@lists.llvm.org
#   https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits