"""
Class for correcting text using a pretrained grammar synthesis model.

- models are available here: https://hf.co/models?other=grammar%20synthesis

requirements for this snippet:
    pip install -U transformers accelerate

NOTE: if you want to use 8-bit to fit the model on a smaller GPU, you need bitsandbytes:
    pip install -U transformers accelerate bitsandbytes
"""
import warnings

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class GrammarSynthesizer:
    """
    Correct text with a pretrained grammar synthesis (seq2seq) model.

    Models are available here: https://hf.co/models?other=grammar%20synthesis

    # Example usage with the XL
    corrector = GrammarSynthesizer("pszemraj/flan-t5-xl-grammar-synthesis")
    raw_text = 'sky is blu.'
    results = corrector(raw_text, num_beams=2)
    print(results)
    """

    # Token count above which _prepare_inputs warns (inputs are NOT truncated).
    DEFAULT_MAX_INPUT_LENGTH = 384
    # Default `max_length` passed to model.generate().
    DEFAULT_MAX_LENGTH = 128
    # Default beam width passed to model.generate().
    DEFAULT_NUM_BEAMS = 4

    def __init__(
        self,
        model_name_or_path: str,
        should_compile: bool = True,
        load_in_8bit: bool = False,
    ):
        """
        Initializes the GrammarSynthesizer.

        Args:
            model_name_or_path: The name or path of the pretrained model.
            should_compile: If True, tries to compile the model for faster inference.
            load_in_8bit: If True, loads model in 8-bit precision (lower memory usage).
                requires bitsandbytes
        """
        self.model_name_or_path = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # load_in_8bit must be forwarded: the helper needs it to configure loading.
        self.model = self._load_and_compile_model(
            model_name_or_path, should_compile, load_in_8bit
        )

    def _load_and_compile_model(
        self, model_name_or_path: str, should_compile: bool, load_in_8bit: bool = False
    ):
        """
        Load the model and optionally compile it with torch.compile.

        Side effect: sets ``self.compiled_model`` to whether compilation was
        requested and did not raise.

        Args:
            model_name_or_path: The name or path of the pretrained model.
            should_compile: If True, tries to compile the model for faster inference.
            load_in_8bit: If True, loads model in 8-bit precision (lower memory usage).
                requires bitsandbytes

        Returns:
            The loaded and potentially compiled model.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name_or_path, load_in_8bit=load_in_8bit, device_map="auto"
        )
        if should_compile:
            try:
                # NOTE(review): torch.compile wraps the module; calls to
                # `.generate()` on the wrapper may be forwarded to the original,
                # uncompiled module, so the speed-up is not guaranteed — confirm.
                model = torch.compile(model)
            except Exception as e:
                # Best-effort: fall back to the eager model instead of failing.
                warnings.warn(
                    f"Unable to compile model for faster inference. Reason: {e}"
                )
                should_compile = False
        self.compiled_model = should_compile
        return model

    def _prepare_inputs(self, input_text: str) -> torch.Tensor:
        """
        Tokenize the input text and move it to the model's device.

        Warns (but does not truncate) when the token count exceeds
        DEFAULT_MAX_INPUT_LENGTH.

        Args:
            input_text: The input text to prepare.

        Returns:
            A tensor of token ids with a leading batch dimension of 1.
        """
        inputs = self.tokenizer.encode(input_text, return_tensors="pt").to(
            self.model.device
        )
        # `inputs` is shaped (1, seq_len); len() would return the batch size (1),
        # so the token count must be read from the last dimension.
        if inputs.shape[-1] > self.DEFAULT_MAX_INPUT_LENGTH:
            warnings.warn(
                "Input is longer than model training data. Unexpected behavior may occur. "
                "Consider batch-processing smaller chunks."
            )
        return inputs

    def generate_text(
        self,
        input_text: str,
        max_length: int = DEFAULT_MAX_LENGTH,
        num_beams: int = DEFAULT_NUM_BEAMS,
        **kwargs,
    ) -> str:
        """
        Generate corrected text from the input.

        Args:
            input_text: The input text to generate from.
            max_length: The maximum length of the generated text.
            num_beams: The number of beams for beam search.
            **kwargs: Extra keyword arguments forwarded to ``model.generate``.

        Returns:
            The decoded first generated sequence (special tokens removed), or
            ``input_text`` unchanged if it is shorter than 2 characters.
        """
        if len(input_text) < 2:
            warnings.warn(
                f"input text is too short to correct, returning:\t{input_text}"
            )
            return input_text
        inputs = self._prepare_inputs(input_text)
        outputs = self.model.generate(
            inputs, max_length=max_length, num_beams=num_beams, **kwargs
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def __call__(self, input_text: str, **kwargs) -> str:
        """Shorthand for :meth:`generate_text`."""
        return self.generate_text(input_text, **kwargs)