Training Infrastructure of a Text-to-Text Generation System
Explore the design of training infrastructure for text-to-text generation systems. Learn key requirements, select appropriate models like Llama 3.2 3B, and understand data acquisition, preprocessing, and distributed training. This lesson covers computational resource estimation, training processes, and model evaluation to build efficient conversational AI.
Text-to-text LLMs are a subset of language models. Unlike their predecessors, which were primarily designed for text generation or translation tasks, conversational LLMs are specifically trained to engage in interactive dialogue. They can understand user input and generate human-like responses, making them ideal for applications like chatbots, virtual assistants, and interactive storytelling.
These are the brains behind those friendly AI assistants you interact with on websites or your smartphone. They are designed to understand your needs (even if you phrase them in a roundabout way) and provide helpful, informative, and often entertaining responses.
Let’s see how we can design our own conversational AI. The first step is defining the requirements to guide the design process.
Requirements
Building the backend for a robust conversational AI system requires careful consideration of both functional and nonfunctional requirements.
Functional requirements
Natural language understanding: The system must decipher the meaning behind user input, including identifying intent (the purpose or goal behind a user’s query, e.g., asking for information or making a request), entities (specific pieces of information extracted from the input, such as names, locations, or dates), and sentiment (the emotional tone or attitude conveyed in the input, ranging from positive to negative or neutral; recognizing sentiment enables the system to tailor responses appropriately). Imagine asking your AI assistant, “What’s the weather like in London tomorrow?” The system needs to understand that you’re asking about the weather (intent), that “London” is the location (entity), and “tomorrow” is the time (entity).
We can also look at an example of sentiment in a query. For instance, if the user says, “I’m so excited about the sunny weather tomorrow in London!” the system should extract the following (a sketch of the parsed structure follows this list):
Intent: The user is expressing enthusiasm about the weather.
Entities: London is the location, and tomorrow is the time.
Sentiment: The user’s sentiment is positive, as shown by their excitement.
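To make this concrete, here is a minimal sketch of the structured output an NLU component might produce for that query. The schema and field names are illustrative, not a specific library’s API:

```python
from dataclasses import dataclass, field

@dataclass
class NLUResult:
    """Illustrative container for what the NLU step extracts."""
    intent: str
    entities: dict = field(default_factory=dict)
    sentiment: str = "neutral"

# What the system should extract from the example query above.
parsed = NLUResult(
    intent="express_enthusiasm_about_weather",
    entities={"location": "London", "time": "tomorrow"},
    sentiment="positive",
)
```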
Dialogue management: The system must effectively manage conversations by retaining relevant information from previous interactions (context retention: the ability to store and recall relevant details from earlier in the conversation, such as user preferences, prior topics discussed, or incomplete tasks, to provide coherent and personalized responses) and maintaining an awareness of the conversation’s progress (state management: tracking the current state of the dialogue, including the conversation’s flow, user intents, and unresolved queries, to ensure logical progression and appropriate responses). This includes keeping track of user preferences, remembering recent topics, and understanding when to revisit or conclude a topic based on the conversation’s flow.

Natural language generation: Once the system (LLM) understands the user’s input and the conversation context, it needs to generate an accurate and relevant response to the query.
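As a rough illustration of context retention and state management together, a dialogue manager might track state in a structure like the following (a simplified sketch, not a production design):

```python
from dataclasses import dataclass, field

@dataclass
class DialogueState:
    """Simplified dialogue state: what the system remembers between turns."""
    history: list = field(default_factory=list)           # prior (user, bot) turns
    user_preferences: dict = field(default_factory=dict)  # e.g., preferred units
    open_topics: list = field(default_factory=list)       # unresolved queries

    def add_turn(self, user_msg: str, bot_msg: str) -> None:
        self.history.append((user_msg, bot_msg))

state = DialogueState(user_preferences={"units": "celsius"})
state.add_turn("What's the weather in London tomorrow?",
               "Tomorrow in London: sunny, 21°C.")
```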
Personalization: The system should also be capable of tailoring responses based on user preferences and historical interaction.
Modern conversational bots now include the ability to tailor their responses to each user. For example, we can tell Gemini to remember that our name is ABC, and it will remember that whenever we chat. We will see how LLMs can maintain memory in the next lesson.
Nonfunctional requirements
Low latency: The system should be optimized to minimize latency and provide a seamless conversational experience.
Note: There can be trade-offs between latency and accuracy. For instance, achieving faster responses might mean sacrificing some degree of accuracy, as complex computations or larger models may require more processing time. Balancing latency and accuracy is essential, especially in applications where real-time interaction is critical, yet the accuracy of information remains important.
Scalability: As the user base grows, the system needs to handle the increased demand without compromising performance. This means efficiently processing a large volume of requests concurrently.
Availability: The text generation model should be accessible and operational whenever users need it. This means minimizing downtime and ensuring consistent uptime.
Reliability: The model should consistently give dependable and accurate responses.
Security: Protecting user data and ensuring privacy is paramount. User data typically includes personally identifiable information (PII) and, importantly, the user’s inputs to the system. Strong security measures must be implemented to safeguard this sensitive information.
Additionally, preventing unauthorized access to or misuse of this data is critical.
Note: User inputs may also be used as training data for the model. Transparency about such practices is critical to maintaining user trust and complying with ethical and legal standards.
With our requirements decided, we can now discuss how to pick a GenAI model that can fulfill our system’s needs.
Model selection
Building a conversational AI requires careful selection of the base language model, balancing capabilities with efficiency and cost-effectiveness. For this design, we’ll use the Llama 3.2 3B model, a 3-billion-parameter LLM optimized for handling natural language inputs.
Let’s understand the reasons behind choosing this model for our use case:
Open source: Llama models provide flexibility and control. We can go into the model’s architecture, fine-tune it extensively, and adapt it precisely to our conversational needs without restrictions. The open-source nature of the Llama model provides the freedom to experiment with different training techniques and modify the model architecture if needed.
Smaller size, greater efficiency: The 3B parameter size balances capability and efficiency. It’s significantly smaller than its 11B and 90B counterparts, making it more manageable for training and deployment, especially when resources might be limited. This translates to faster training, reduced computational costs, faster inference times, and easier deployment (since the model size is smaller, it takes less memory to store it).
Accuracy: Despite its smaller size, Llama 3.2 3B demonstrates impressive accuracy on various language tasks, including dialogue generation and comprehension. It exhibits a good understanding of conversational nuances and can generate human-like responses.
We’ll be training the Llama model (the 3B variant) from scratch on our own dataset.
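For reference, here is one way to instantiate the 3B architecture with randomly initialized weights using the Hugging Face transformers library. This is a sketch: the meta-llama/Llama-3.2-3B repository is gated, so it assumes you have accepted Meta’s license and authenticated with the Hub.

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Gated repo: requires accepting Meta's license on the Hugging Face Hub.
model_id = "meta-llama/Llama-3.2-3B"

# Load only the architecture definition (randomly initialized weights),
# since we plan to train from scratch rather than fine-tune.
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.1f}B parameters")
```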
Now that we have selected our model, we can talk about training this model from scratch.
The training process
Let’s examine the process of training a text generation model. When working with large datasets, there are inherent challenges that need attention. In conversational datasets, data quality issues (such as incomplete sentences, slang, or inconsistencies in dialogue formatting) can affect model performance and reliability. Ensuring high data quality before training is, therefore, a critical first step.
Data acquisition and preparation
High-quality, relevant data is essential for training a conversational AI that understands language, generates meaningful responses, and engages in natural-sounding dialogue. There are several steps we usually take in data acquisition and preparation:
Collect or fetch the dataset.
Preprocess the dataset:
Remove irrelevant or offensive content.
Ensure consistency by formatting or converting the data.
Store the processed dataset.
Let’s look at each of these steps in detail.
Source of training data
We can choose from multiple sources. For this application, we can choose a publicly available dataset that includes human conversations and real-world knowledge. Let’s assume this dataset is about 50 GB with 200 million rows of text.
Dataset format
For text generation, we require the dataset to be in the following format:
| Prompt/Query | Response (Label) |
|---|---|
| What is the capital of Switzerland? | The capital of Switzerland is Bern. |
| What is the capital of Japan? | The capital of Japan is Tokyo. |
| Is Earth a star? | No, Earth is not a star. It is a planet located in the Sun’s habitable zone. |
| ... | ... |
Note: Many off-the-shelf datasets will come in formats that may not be suitable for training immediately. We may need to remove or merge some columns (features) to prepare the dataset for training.
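For example, a raw dataset with extra metadata columns could be reshaped into the two-column format above with a few lines of pandas. The file and column names here are hypothetical:

```python
import pandas as pd

# Hypothetical raw file with more columns than we need for training.
raw = pd.read_csv("conversations_raw.csv")

dataset = (
    raw[["question_text", "answer_text"]]  # keep only the relevant columns
    .rename(columns={"question_text": "prompt", "answer_text": "response"})
    .dropna()                              # drop incomplete prompt/response pairs
)
dataset.to_parquet("conversations_clean.parquet")
```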
Data preprocessing
Here are some considerations we will need to make and aspects we will need to review:
Topics: Evaluate the distribution of topics within the dataset to ensure balanced representation. For instance, identify whether specific topics—like hobbies, current events, or social issues—appear more frequently than others. This can help prevent the model from biasing toward certain topics and ensure it responds well across various subjects.
Demographics: Analyze the language used for potential biases regarding sex, age, ethnicity, or socioeconomic status. Watch for patterns such as portraying certain groups in stereotypical ways or excluding them altogether.
Sentiment and emotion: Examine the distribution of sentiment and emotions expressed in the dialogues. Is there an imbalance toward positive or negative sentiment? Are certain emotions expressed more frequently by specific groups?
Note: These are just some aspects we need to consider when handling text data. They might change depending on your use case and data. Others include rare word substitution (replacing rare or unknown words with placeholders or more common words), translating multilingual data, and removing noise (eliminating extraneous characters, such as HTML tags, emojis (if irrelevant), or unnecessary punctuation).
This often requires manual reviews. We can train a separate sentiment analysis model to streamline this process, but expert review may still be necessary. If we do find discrepancies, here are a few techniques to mitigate bias:
Generate synthetic data: Create new text data that addresses underrepresented topics or demographics, ensuring a more balanced and inclusive dataset.
Downsample overrepresented groups: If certain groups are significantly overrepresented, consider downsampling their data to create a more balanced distribution (see the sketch after this list).
Remove offensive content: Filter out data containing hate speech, profanity, or other offensive language that could perpetuate harmful biases.
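As one example of the downsampling technique, the sketch below caps every topic group at the size of the smallest one. It assumes a "topic" column produced by an upstream classifier:

```python
import pandas as pd

df = pd.read_parquet("conversations_clean.parquet")

# Cap every topic group at the size of the smallest group.
min_count = df["topic"].value_counts().min()
balanced = (
    df.groupby("topic", group_keys=False)
      .apply(lambda g: g.sample(n=min_count, random_state=42))
)
```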
With this, we can move on to processing the debiased data.
Before we can move on to training, we need to process the data to ensure consistency (a short code sketch follows this list). This will involve:
Formatting: Ensuring consistent formatting and removing any unnecessary annotations.
Handling special characters and cases: Standardizing punctuation and converting text to lowercase for consistency.
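A minimal normalization function might look like this. The exact rules (e.g., whether to lowercase) depend on the tokenizer and use case:

```python
import re

def normalize(text: str) -> str:
    """Basic cleanup: strip HTML remnants, standardize quotes and
    whitespace, and lowercase the text for consistency."""
    text = re.sub(r"<[^>]+>", " ", text)              # remove HTML tags
    text = text.replace("“", '"').replace("”", '"')   # standardize double quotes
    text = text.replace("’", "'")                     # standardize apostrophes
    text = re.sub(r"\s+", " ", text).strip()          # collapse whitespace
    return text.lower()

print(normalize("<p>What’s  the capital of   Switzerland?</p>"))
# -> "what's the capital of switzerland?"
```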
Then, we need to store this data in some database for efficient retrieval, as the training process will likely take a substantial amount of time, as discussed later.
Transformers like GPT are usually pretrained on massive datasets before being adapted to specific applications. What is the difference between pretraining and training in this context?
Database selection
We’ll use a vector database to handle efficient storage and retrieval of conversational data embeddings. After preprocessing, our data will be converted into vector embeddings, each representing a conversation segment.
Factors influencing the choice of vector database include retrieval speed for a responsive conversational experience and robust indexing to optimize search and retrieval of relevant conversations. These considerations ensure our database choice aligns with the demands of real-time conversational AI applications.
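To illustrate the retrieval side, here is a tiny sketch using the FAISS library as a stand-in for a managed vector database. The 768-dimension embedding size and the random vectors are placeholders for real encoder output:

```python
import faiss
import numpy as np

dim = 768                                  # assumed embedding dimension
index = faiss.IndexFlatIP(dim)             # exact inner-product similarity search

# Placeholder embeddings standing in for encoded conversation segments.
segments = np.random.rand(1000, dim).astype("float32")
index.add(segments)

query = np.random.rand(1, dim).astype("float32")
scores, ids = index.search(query, k=5)     # 5 most similar segments
```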
To perform all of this preprocessing, we need to estimate the time it will take to complete this task and determine our resource requirements.
Time estimation for data gathering and processing
Let’s use some approximate values for our dataset.

Assuming:

- Dataset size: 50 GB, or about 200 million rows of text
- Preprocessing speed: roughly 2.1 ms per row on a single machine (an illustrative figure covering cleaning, formatting, and embedding)

We get:

Total processing time = 200,000,000 rows × 2.1 ms/row = 420,000 seconds

So the total time is:

420,000 seconds ÷ 3,600 ≈ 117 hours
So, one machine would take about 117 hours to process all of this data at the speeds we assumed. To speed this up, we can divide the data across multiple machines and process it in parallel. For example, 16 machines would take only about 7.3 hours to process the 50 GB dataset.
Note: In this case, splitting this data onto the 16 processing servers will also take time, but since we are calculating a rough estimate, we can use these numbers.
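As a quick sanity check, here is the same arithmetic in code, using the assumed per-row processing time:

```python
rows = 200_000_000       # dataset rows
ms_per_row = 2.1         # assumed preprocessing time per row
machines = 16

hours_single = rows * ms_per_row / 1000 / 3600
print(f"1 machine:   {hours_single:.0f} hours")                    # ~117 hours
print(f"{machines} machines: {hours_single / machines:.1f} hours") # ~7.3 hours
```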
Is there a way to further reduce the time it takes to process the data?
With our training data ready and stored, we can now discuss using this data to train the model.
Model training and testing
To train our model on the dataset selected, we first need to estimate its training time by performing some back-of-the-envelope calculations. Let’s say we are using NVIDIA A100 80GB GPUs. These graphics processing units (GPUs) deliver up to 156 TFLOPS (trillion floating-point operations per second) using the TF32 format, which we’ll take as our rate of operations.
Note: The RAM on these GPUs matters a lot for storing the model and the data used for training. If we cannot store these directly on the GPU, we may need to connect a database to the GPU to hold the data. However, this will also introduce latency into the system. The data is around 50 GB, and the model is around 12 GB (discussed later), so we can store them in a single A100 GPU if needed.
When training a model, selecting the right hyperparameters is crucial:
The learning rate is the step size at which a model updates its weights during training; it controls how quickly the model learns, impacting convergence and stability.
Batch size is the number of training samples processed in a single iteration. It affects memory usage and the stability of gradient updates.
Setting the number of epochs—the total passes through the training data—dictates the duration and depth of training.
Typically, training continues until performance on a validation set stops improving, a practice known as early stopping.
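To make these choices concrete, a starting configuration might be captured as follows. The values are illustrative defaults, not tuned recommendations:

```python
# Illustrative starting values; real values come from tuning experiments.
hyperparams = {
    "learning_rate": 3e-4,         # step size for weight updates
    "batch_size": 64,              # samples per gradient update
    "num_epochs": 1000,            # full passes over the training data
    "early_stopping_patience": 5,  # epochs without val improvement before stopping
}
```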
To estimate the time taken for training, we can use the formula created earlier:

Training time (seconds) = (Number of parameters × 6 FLOPS per parameter × Number of epochs × Number of data entries × C) / (GPU throughput × Number of GPUs)

During training, text-to-text models usually generate the entire output in one time step. So, we set the scaling variable C (the number of output time steps) to 1.
If using one GPU and training for 1,000 epochs, training the model will take ~267 days.
This formula lets you experiment by plugging in different values yourself.
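As a quick sketch, here is the formula as a Python function; the numbers match the worked example in the next section:

```python
def training_time_days(params, epochs, rows, c=1, tflops_per_gpu=156, gpus=1):
    """Back-of-the-envelope training time from the formula above."""
    total_tflops = params * 6 * epochs * rows * c / 1e12  # total work, in TFLOPs
    seconds = total_tflops / (tflops_per_gpu * gpus)
    return seconds / 86_400                               # seconds per day

print(training_time_days(3e9, 1000, 200e6))               # ~267 days on 1 GPU
print(training_time_days(3e9, 1000, 200e6, gpus=32))      # ~8.35 days on 32 GPUs
```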
Distributing the training load
As we can see, using only one server for a modest training run of 1,000 epochs results in an infeasible training time, and we would want to train the model for far more epochs. This is where we can utilize distributed machine learning.
Since we already have the data stored in our vector database, we can split up that data using data parallelism techniques to train concurrently. Here is what that will look like:
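A minimal data-parallel setup with PyTorch’s DistributedDataParallel might look like the sketch below. It assumes `model` and `train_dataset` are defined elsewhere and that the script is launched with torchrun on each training server:

```python
import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# Minimal data-parallelism sketch. Assumes `model` and `train_dataset` are
# defined elsewhere and the script is launched with torchrun on each server.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun

model = model.to(local_rank)                 # one GPU per process
model = DDP(model, device_ids=[local_rank])  # gradients synced across GPUs

# DistributedSampler gives each GPU a disjoint shard of the dataset.
sampler = DistributedSampler(train_dataset)
loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)
```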
The time estimate when training for 1,000 epochs using 32 training servers is given below:
| Quantity | Value | Unit |
|---|---|---|
| Number of parameters | 3 | billion |
| FLOPS per parameter update | 6 | FLOPS/parameter |
| Number of epochs | 1,000 | epochs |
| Size of data | 200 | million entries |
| Time steps (C) | 1 | steps |
| FLOPS for training | 3,600,000,000 | TFLOPS |
| Rate of operations | 156 | TFLOPS |
| Number of GPUs | 32 | NVIDIA A100 GPUs |
| Training time | 721,153.85 | seconds |
| Training time | 8.35 | days |
Try changing these values (such as the number of GPUs or epochs) in the calculator sketch above to see the effect on the training time.
After training, the most essential step is to evaluate the model to ensure its output meets the criteria for our system. If the model is not up to the mark, the later steps, e.g., system design, deployment, and inference optimization, will be unproductive.
Model evaluation
To evaluate our text generation model during and after training, we calculate metrics like perplexity, BLEU, ROUGE, and loss on the validation set. Steady decreases in validation perplexity and loss generally indicate learning progress, while stagnation or an increase can signal overfitting.
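Perplexity, for instance, is just the exponential of the average cross-entropy loss per token, so it falls out of the validation loss directly:

```python
import math

val_loss = 2.1                     # assumed average validation loss (nats/token)
perplexity = math.exp(val_loss)    # perplexity = e^(cross-entropy loss)
print(f"Validation perplexity: {perplexity:.1f}")  # ~8.2
```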
Once the model is trained, we can run a final evaluation on a held-out test set and save the model weights.
The image below summarizes the complete training process from data acquisition to training and saving the model weights.
Now, let’s look at how we can design the deployment infrastructure for this model to serve users.