Source code for rag_skeleton.generation

# src/rag_skeleton/generation.py
import torch
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_huggingface import HuggingFaceEndpoint

class TextGenerator:
    """
    Generates text responses using a specified LLM model.

    The model is not loaded at construction time; call `load_model()` to
    populate `self.llm` (and, in "local" mode, `self.model` and
    `self.tokenizer`).
    """

    def __init__(self, model_name="meta-llama/Llama-3.2-3B-Instruct", device=None, load_mode="local", api_token=None):
        """
        Initializes the TextGenerator with a model name, device setting, and load mode.

        Parameters:
        - model_name: str, HuggingFace model ID.
        - device: int or str, device to run the model on (e.g., "cuda" for GPU,
          -1 or "cpu" for CPU). When None, "cuda" is chosen if
          `torch.cuda.is_available()` reports True, otherwise "cpu".
        - load_mode: str, whether to load the model locally or from the Hugging Face
          API ("local" or "api"). Any value other than "api" is treated as local.
        - api_token: str, Hugging Face API token, required if load_mode is "api".
        """
        self.model_name = model_name
        self.load_mode = load_mode
        self.api_token = api_token
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        # Defined up front so `self.llm` can be inspected before load_model()
        # runs without raising AttributeError.
        self.llm = None

    @staticmethod
    def _normalize_device(device):
        """Map a negative integer device (the "-1 means CPU" convention) to "cpu"."""
        # bool is a subclass of int; exclude it so False/True pass through untouched.
        if isinstance(device, int) and not isinstance(device, bool) and device < 0:
            return "cpu"
        return device

    @staticmethod
    def _stop_token_ids(tokenizer):
        """Collect the distinct, valid token ids at which generation should stop."""
        stop_ids = []
        if tokenizer.eos_token_id is not None:
            stop_ids.append(tokenizer.eos_token_id)

        # "<|eot_id|>" is only present in some vocabularies. For a token that
        # is not in the vocabulary the lookup yields None or the unknown-token
        # id, neither of which may be used as a stop id.
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in stop_ids:
            stop_ids.append(eot_id)
        return stop_ids

    def load_model(self):
        """
        Loads the LLM model and tokenizer and sets up the text generation pipeline.

        If `load_mode` is set to "api", it uses the Hugging Face API to load the model
        and requires an API token. Otherwise it loads the model and tokenizer locally
        from the Hugging Face repository. In both cases the resulting LangChain
        wrapper is stored in `self.llm`.

        Raises:
            ValueError: If `load_mode` is "api" and `api_token` is not provided.

        Configurations:
        - For both "api" and "local" modes, `temperature`, `do_sample`,
          `repetition_penalty`, `return_full_text` and `max_new_tokens` are set to
          control text generation behavior.
        - For the local model, `eos_token_id` is set to the tokenizer's EOS id plus
          the id of "<|eot_id|>" when the tokenizer knows that token.
        """
        if self.load_mode == "api":
            if not self.api_token:
                raise ValueError("Hugging Face API token is required for 'api' load mode.")

            # Use HuggingFaceEndpoint for loading the model via API
            self.llm = HuggingFaceEndpoint(
                endpoint_url=f"https://api-inference.huggingface.co/models/{self.model_name}",
                huggingfacehub_api_token=self.api_token,
                task="text-generation",
                temperature=0.3,
                do_sample=True,
                repetition_penalty=1.1,
                return_full_text=False,
                max_new_tokens=1000
            )
            return

        # A raw -1 cannot be handed to `.to()`; translate it to "cpu" here
        # while leaving `self.device` exactly as the caller supplied it.
        device = self._normalize_device(self.device)

        # Load the model and tokenizer locally
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Set pad_token_id to eos_token_id to avoid warnings
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs = dict(
            temperature=0.3,
            do_sample=True,
            repetition_penalty=1.1,
            return_full_text=False,
            max_new_tokens=1000,
        )

        # Define terminators for stopping generation at end of response.
        # Only pass them when at least one valid id exists.
        terminators = self._stop_token_ids(self.tokenizer)
        if terminators:
            generation_kwargs["eos_token_id"] = terminators

        # Set up the text generation pipeline
        text_generation_pipeline = pipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            task="text-generation",
            device=device,
            **generation_kwargs
        )

        # Wrap the pipeline for LangChain compatibility
        self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)