Comment Classification with BERT (Sequential Transfer Learning)¶
For educational purposes, consider a hypothetical online retailer launching an innovative service. This service enables users to collaboratively edit and enhance product descriptions, similar to the contributions found in wiki communities. Customers can suggest changes and comment on others’ edits, and to maintain a respectful and safe environment, the retailer requires a tool that detects toxic comments and routes them for moderation.
Project Overview¶
Objective:
- Develop a model that classifies comments as toxic (positive class) or non-toxic (negative class), using a dataset labeled for toxicity.
- The model must achieve an F1 score of at least 0.75.
Data Description:
The dataset is stored in the file /datasets/toxic_comments.csv.
- The column text contains the comment text.
- The column toxic is the target label indicating toxicity.
Approach:
- Load and prepare the data.
- Train and fine-tune a BERT classifier using sequential transfer learning.
- Evaluate the model and draw conclusions.
Preparation¶
Dependencies¶
import os
import hashlib
from dataclasses import dataclass, field
from enum import Enum
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from matplotlib import pyplot as plt
from sklearn.metrics import (
classification_report, confusion_matrix, f1_score, roc_auc_score, roc_curve
)
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm
from google.colab import drive
mount_files = True
if mount_files:
drive.mount('/content/drive')
Mounted at /content/drive
Constants¶
STL_ID = 3  # Sequential Transfer Learning subset ID (starts at 0)
SUBSETS = 4
VERSION = 2
ID_STEP = 1
SEED = 42
STOPWORDS_STL_ID = 0
TEST_SIZE = 0.3
PRODUCTION_TEST_SIZE = 0.2
MODEL_CACHE = 'drive/MyDrive/datasets/.cache'
DATASETS_PATH = 'https://code.s3.yandex.net/datasets/toxic_comments.csv'
DATASETS_PATH_LOCAL = 'drive/MyDrive/datasets/toxic_comments.csv'
NLTK_DATA_PATH = ('drive', 'MyDrive', 'datasets', 'nltk_data')
INFERENCE_ONLY = False
PRODUCTION_N_SAMPLES = 2000
PRODUCTION_N_SAMPLES_BOOTSTRAP = 800
BOOTSTRAP_N_SAMPLES = 600
BR = '\n'
f'Version {VERSION}.{ID_STEP*STL_ID}.{SUBSETS}'
'Version 2.3.4'
Hyperparameters¶
class InfraParameter(Enum):
BATCH_SIZE = 128
MAX_LENGTH = 256 + 64
EPOCHS = 8
class MetaParameter(Enum):
L_RATE_MIN = 1e-6
L_RATE = 1e-5
L_RATE_MAX = 5e-5
L_RATE_SUBSET_ID_DECREASE = -1e-6
@dataclass
class HyperParameter:
infra: InfraParameter = InfraParameter
meta: MetaParameter = MetaParameter
subset_id: int = ID_STEP * STL_ID
warm_up: tuple[int, int] = (0, 10)
    def lr(self) -> float:
        # Linearly decrease the learning rate as the subset ID grows
        adapt = self.meta.L_RATE_SUBSET_ID_DECREASE.value * self.subset_id
        if self.subset_id <= self.warm_up[0]:
            # Warm-up: the first subset instead trains at an elevated rate
            adapt = self.meta.L_RATE_SUBSET_ID_DECREASE.value * self.warm_up[1] * -1
        lr_ = max(self.meta.L_RATE.value + adapt, self.meta.L_RATE_MIN.value)
        return min(lr_, self.meta.L_RATE_MAX.value)
def epochs(self, threshold: int = 1, add: int = 1) -> int:
hp_epochs = self.infra.EPOCHS.value
return hp_epochs + add if self.subset_id < threshold else hp_epochs
def lr_development(self) -> None:
lrs = [HyperParameter(subset_id=i).lr() for i in range(SUBSETS)]
plt.figure(figsize=(12, 3))
        bar_colors = ['gray'] * SUBSETS
        bar_colors[self.subset_id] = 'green'
plt.bar(range(SUBSETS), lrs, color=bar_colors, alpha=0.4, width=0.5)
act_lr = f'{lrs[self.subset_id]:.2e}'
plt.text(self.subset_id, lrs[self.subset_id], act_lr, va='top')
plt.title('Learning Rate Development')
plt.xlabel('Subset ID')
plt.ylabel('Learning Rate')
plt.show()
Sequential Transfer Learning¶
hp = HyperParameter()
hp.lr_development()
This project develops a toxic comment classification model using Sequential Transfer Learning (STL), in which training is adapted across successive data subsets.

Sequential Transfer Learning (STL) lets the model learn iteratively on smaller, manageable subsets of data, which is especially useful when working with large datasets and accounting for potential data drift over time.

The process begins with a warm start: the first subset is trained without stop words and at an elevated (warm-up) learning rate. This allows the model to adapt to the basic structure of the data while maintaining its generalization capability. For subsequent subsets, the learning rate is gradually decreased, enabling more precise tuning and reducing the risk of overfitting. Each following subset builds on the knowledge acquired in the previous stage, creating an iterative improvement process that mimics pre-training while incorporating controlled fine-tuning at each step.

This structure suits the complexity of large datasets and offers a systematic way to adapt the model to changing data distributions (data drift), which contributes to improved performance over time. The use of subsets also simplifies gradual hyperparameter tuning, such as adjusting the learning rate, allowing more precise optimization at each stage of training.
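As a quick numeric check of the schedule plotted above, the sketch below derives the per-subset learning rates directly from the constants and the lr() method; subset 0 receives the warm-up rate, and later subsets step down by 1e-06 each.

# Inspect the learning-rate schedule produced by HyperParameter.lr()
for i in range(SUBSETS):
    print(f'subset {i}: lr = {HyperParameter(subset_id=i).lr():.1e}')
# subset 0: lr = 2.0e-05   (warm-up: L_RATE + 10 * 1e-06)
# subset 1: lr = 9.0e-06   (L_RATE - 1 * 1e-06)
# subset 2: lr = 8.0e-06
# subset 3: lr = 7.0e-06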
Auxiliary Classes¶
@dataclass
class Log:
time_start: pd.Timestamp = None
def set_time(self):
self.time_start = pd.Timestamp.now()
    def get_time_h(self, print_: bool = False) -> float:
        # total_seconds() avoids truncation for durations longer than one day
        t_ = round((pd.Timestamp.now() - self.time_start).total_seconds() / 3600, 2)
        if not print_:
            return t_
        print(f'Elapsed time: {t_} hours')
@staticmethod
def value_counts(
df_cnt: pd.DataFrame, column: str = 'target', cnt: str = 'cnt'
) -> None:
df_cnt = df_cnt[column].value_counts().to_frame()
df_cnt.columns = [cnt]
display(df_cnt)
@staticmethod
def censored(log_df: pd.DataFrame, cols: str = 'target-text', hd: int = 8) -> None:
prefix = '*** [ positive ] *** | HASH: '
log_df_, c0, c1 = log_df.copy(), cols.split('-')[0], cols.split('-')[1]
log_df_.loc[log_df_[c0] == 1, c1] = log_df_.loc[log_df_[c0] == 1, c1].apply(
lambda x: prefix + hashlib.sha256(str(x).encode()).hexdigest()
)
display(log_df_.head(hd))
@dataclass
class DataPreparation:
target: str
seed: int
size: float
is_under_sample: bool = False
schema: tuple[str, str] = ('text', 'target')
cache: dict = field(default_factory=dict)
def key_value_structure(self, dfm: pd.DataFrame, km: str, vm: str) -> pd.DataFrame:
schema = {km: self.schema[0], vm: self.schema[1]}
self.target, self.cache['target_name'] = self.schema[1], vm
return dfm[[km, vm]].copy().rename(columns=schema)
def split(self, *series: pd.Series, **conf) -> tuple[pd.Series, ...]:
conf['test_size'] = conf.get('test_size', self.size)
return train_test_split(*series, random_state=self.seed, **conf)
def sub_samples(self, dfm: pd.DataFrame, n_subsamples: int) -> list[pd.DataFrame]:
dfm = dfm.copy()
split_size = len(dfm) // n_subsamples
subsamples = [
dfm.iloc[i * split_size: (i + 1) * split_size].reset_index(drop=True)
for i in range(n_subsamples)
]
if len(dfm) % n_subsamples != 0:
remaining = dfm.iloc[n_subsamples * split_size:].reset_index(drop=True)
subsamples[-1] = pd.concat([subsamples[-1], remaining]).reset_index(
drop=True
)
return subsamples
@staticmethod
def join(x_: pd.Series, y_: pd.Series) -> pd.DataFrame:
return pd.concat([x_, y_], axis=1)
@staticmethod
def text_target(dfm: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
dfm = dfm.copy()
return dfm['text'], dfm['target']
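A minimal sketch of how sub_samples divides a frame, using a hypothetical toy frame rather than the real dataset: after integer division, any remainder rows are appended to the final subset.

toy = pd.DataFrame({'text': list('abcdefghij'), 'target': [0] * 10})
parts = DataPreparation(target='target', seed=0, size=0.3).sub_samples(toy, 3)
print([len(p) for p in parts])  # [3, 3, 4] -- the remainder goes to the last subset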
EDA (Unstructured Data)¶
df_raw = pd.read_csv(DATASETS_PATH)
df_raw.info() # no missing values
Log.censored(df_raw, 'toxic-text')
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 159292 entries, 0 to 159291
Data columns (total 3 columns):
 #   Column      Non-Null Count   Dtype
---  ------      --------------   -----
 0   Unnamed: 0  159292 non-null  int64
 1   text        159292 non-null  object
 2   toxic       159292 non-null  int64
dtypes: int64(2), object(1)
memory usage: 3.6+ MB
|   | Unnamed: 0 | text | toxic |
|---|---|---|---|
| 0 | 0 | Explanation\nWhy the edits made under my usern... | 0 |
| 1 | 1 | D'aww! He matches this background colour I'm s... | 0 |
| 2 | 2 | Hey man, I'm really not trying to edit war. It... | 0 |
| 3 | 3 | "\nMore\nI can't make any real suggestions on ... | 0 |
| 4 | 4 | You, sir, are my hero. Any chance you remember... | 0 |
| 5 | 5 | "\n\nCongratulations from me as well, use the ... | 0 |
| 6 | 6 | *** [ positive ] *** \| HASH: 6e4d3584d34a8a9e2... | 1 |
| 7 | 7 | Your vandalism to the Matt Shirvington article... | 0 |
vc_target = df_raw['toxic'].value_counts().to_frame()
vc_target['percent'] = round(vc_target / vc_target.sum() * 100, 2)
vc_target # imbalanced target (90% non-toxic)
| toxic | count | percent |
|---|---|---|
| 0 | 143106 | 89.84 |
| 1 | 16186 | 10.16 |
plt.figure(figsize=(12, 4))
for value, color in [(1, 'red'), (0, 'gray')]:
label = f'Toxic = {value}'
df_raw[df_raw['toxic'] == value]['text'].apply(lambda x: len(str(x))).plot(
kind='hist', bins=200, range=(0, 3000), alpha=0.5, label=label, color=color
)
plt.title('Distribution of Comment Length (Characters as a Token-Count Proxy)')
plt.xlabel('Character Count')
plt.ylabel('Frequency')
plt.legend()
plt.show()
Data Splitting I¶
dp = DataPreparation(target='toxic', seed=SEED, size=TEST_SIZE)
df = dp.key_value_structure(df_raw, 'text', 'toxic')
Log.censored(df)
|   | text | target |
|---|---|---|
| 0 | Explanation\nWhy the edits made under my usern... | 0 |
| 1 | D'aww! He matches this background colour I'm s... | 0 |
| 2 | Hey man, I'm really not trying to edit war. It... | 0 |
| 3 | "\nMore\nI can't make any real suggestions on ... | 0 |
| 4 | You, sir, are my hero. Any chance you remember... | 0 |
| 5 | "\n\nCongratulations from me as well, use the ... | 0 |
| 6 | *** [ positive ] *** \| HASH: 6e4d3584d34a8a9e2... | 1 |
| 7 | Your vandalism to the Matt Shirvington article... | 0 |
production: dict[str, pd.Series] = {}
df_X_temp, production['X_test'], df_y_temp, production['y_test'] = dp.split(
*dp.text_target(df), test_size=PRODUCTION_TEST_SIZE
)
# Training, Validation, Test
# ==========================
df = dp.join(df_X_temp, df_y_temp)
# Production Test
# ===============
dp.cache['production'] = production
dp.cache['df_production'] = dp.join(production['X_test'], production['y_test'])
Log.value_counts(df, cnt='count_train_val_test')
Log.censored(df, hd=5)
df.shape, production['X_test'].shape, production['y_test'].shape
| target | count_train_val_test |
|---|---|
| 0 | 114448 |
| 1 | 12985 |

|   | text | target |
|---|---|---|
| 45155 | "\nYou claimed to have ""scavenged the UN and ... | 0 |
| 60904 | "\n\n Please do not vandalize pages, as you di... | 0 |
| 92242 | "\n\n ""largest moon"" \n\nShouldn't it say la... | 0 |
| 74757 | "\n\n Isn't baking cooking? \n\nAccording to t... | 0 |
| 7198 | I am sure the judges smiled too.\n\nWhen you c... | 0 |
((127433, 2), (31859,), (31859,))
Preprocessing¶
@dataclass
class StopWordsProcessor:
nltk_path: tuple[str, ...]
stop_words: set = field(init=False)
def __post_init__(self):
nltk_data_path = os.path.join(*self.nltk_path)
os.makedirs(nltk_data_path, exist_ok=True)
nltk.data.path.append(nltk_data_path)
nltk.download('stopwords', download_dir=nltk_data_path)
nltk.download('punkt_tab', download_dir=nltk_data_path)
self.stop_words = set(stopwords.words('english'))
def remove_stop_words(self, text: str) -> str:
words = word_tokenize(text)
filtered_words = [word for word in words if word.lower() not in self.stop_words]
return ' '.join(filtered_words)
def preprocess_dataframe(self, df: pd.DataFrame, rm: bool = True) -> pd.DataFrame:
if rm:
df = df.copy()
df['text'] = df['text'].apply(self.remove_stop_words)
return df
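A quick, self-contained illustration of remove_stop_words on a hypothetical toy sentence (the notebook instantiates its own processor in the next cell). Note that NLTK's English stop list also contains negations such as "not", which can matter for toxicity phrasing.

swp = StopWordsProcessor(NLTK_DATA_PATH)  # downloads NLTK data on first use
print(swp.remove_stop_words('This is not a toxic comment'))
# -> 'toxic comment'  (the negation "not" is removed as a stop word)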
stop_words_processor = StopWordsProcessor(NLTK_DATA_PATH)
# Stop words are removed only when training the first subset (STL_ID == 0)
df = stop_words_processor.preprocess_dataframe(df, rm=STL_ID == STOPWORDS_STL_ID)
Log.censored(df)
[nltk_data] Downloading package stopwords to
[nltk_data]     drive/MyDrive/datasets/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     drive/MyDrive/datasets/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
|   | text | target |
|---|---|---|
| 45155 | "\nYou claimed to have ""scavenged the UN and ... | 0 |
| 60904 | "\n\n Please do not vandalize pages, as you di... | 0 |
| 92242 | "\n\n ""largest moon"" \n\nShouldn't it say la... | 0 |
| 74757 | "\n\n Isn't baking cooking? \n\nAccording to t... | 0 |
| 7198 | I am sure the judges smiled too.\n\nWhen you c... | 0 |
| 22361 | this is a government IP used by roughly 3,000 ... | 0 |
| 132882 | 12:13, 4 May 2012 User:217.217.197.24 | 0 |
| 154232 | gentlemen and gentlemen, ProKo has been revert... | 0 |
df_subsets = dp.sub_samples(df, SUBSETS)
df = df_subsets[hp.subset_id]
df = df.sample(frac=1, random_state=SEED).reset_index(drop=True)
Log.value_counts(df, cnt=f'count_train_val_test_with_subset_id_{hp.subset_id}')
Log.censored(df)
| target | count_train_val_test_with_subset_id_3 |
|---|---|
| 0 | 28616 |
| 1 | 3243 |

|   | text | target |
|---|---|---|
| 0 | Yes, the Mustang GT500 laptime is real \n\nThe... | 0 |
| 1 | User Sitush \n\nYou seem to be a genuine crusa... | 0 |
| 2 | RfAr notice \n\nYou are mentioned in Wikipedia... | 0 |
| 3 | In popular culture \n\nMossad is seen in many ... | 0 |
| 4 | 2010 (UTC)\n\nWelcome\n86.29.137.111 03:50, 1... | 0 |
| 5 | OK, Ngo Thanh Nhan just wrote to me after read... | 0 |
| 6 | *** [ positive ] *** \| HASH: 9565b007350481f51... | 1 |
| 7 | Are you brooding about colours nuances, violet... | 0 |
Training¶
Data Splitting II¶
test: dict[str, pd.Series] = {}
X_train_val, test['X'], y_train_val, test['y'] = dp.split(*dp.text_target(df))
print(' ' * 3, 'train/val', 'test', sep=' ')
display(('X', X_train_val.shape, test['X'].shape))
'y', y_train_val.shape, test['y'].shape
train/val test
('X', (22301,), (9558,))
('y', (22301,), (9558,))
train: dict[str, pd.Series] = {}
val: dict[str, pd.Series] = {}
train['X'], val['X'], train['y'], val['y'] = dp.split(X_train_val, y_train_val)
print(' ' * 4, 'train', ' val', sep=' ')
display(('X', train['X'].shape, val['X'].shape))
'y', train['y'].shape, val['y'].shape
train val
('X', (15610,), (6691,))
('y', (15610,), (6691,))
Tokenization and Models¶
# Subset Determination
# ====================
name_or_path = 'bert-base-uncased'
if hp.subset_id > 0:
    # Resume from the checkpoint saved after the previous subset (chained STL)
    prev_subset_id = hp.subset_id - ID_STEP
    name_or_path = f'{MODEL_CACHE}/model_v{VERSION}.{prev_subset_id}.{SUBSETS}'
# Tokenization and Load Model
# ===========================
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
name_or_path,
cache_dir=MODEL_CACHE
)
model = BertForSequenceClassification.from_pretrained(
name_or_path,
num_labels=2,
cache_dir=MODEL_CACHE
)
ml_info = f'Model "{name_or_path}" ready for SUBSET with ID: {hp.subset_id}'
print(BR, BR + ml_info, BR, sep=BR + '*' * 64)
'Total subsets: ' + str(SUBSETS)
****************************************************************
Model "drive/MyDrive/datasets/.cache/model_v2.2.4" ready for SUBSET with ID: 3
****************************************************************
'Total subsets: 4'
def tokenize_function(text: list[str], hp_: HyperParameter) -> dict[str, torch.Tensor]:
kw_args = {
'padding': 'max_length',
'truncation': True,
'max_length': hp_.infra.MAX_LENGTH.value,
'return_tensors': 'pt'
}
return tokenizer(text, **kw_args)
Encoding¶
train_encodings = tokenize_function(train['X'].tolist(), hp)
val_encodings = tokenize_function(val['X'].tolist(), hp)
test_encodings = tokenize_function(test['X'].tolist(), hp)
sequence_lengths = [
len(tokenizer(text, truncation=True)['input_ids']) for text in train['X']
]
print(f'Average sequence length: {sum(sequence_lengths) / len(sequence_lengths)}')
print(BR, 'Percentiles', '=' * 20, sep=BR)
for percentile in [80, 90, 95]:
print(f'{percentile}th: {int(np.percentile(sequence_lengths, percentile))}')
'Sequence (truncated) length: ' + str(train_encodings['input_ids'].shape[1])
Average sequence length: 90.0457399103139

Percentiles
====================
80th: 128
90th: 211
95th: 325
'Sequence (truncated) length: 320'
Dataset Setup¶
class ToxicDataset(Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return {key: val[idx] for key, val in self.encodings.items()}, self.labels[idx]
train_dataset = ToxicDataset(train_encodings, train['y'].tolist())
train_loader = DataLoader(
train_dataset, batch_size=hp.infra.BATCH_SIZE.value, shuffle=True
)
val_dataset = ToxicDataset(val_encodings, val['y'].tolist())
val_loader = DataLoader(
val_dataset, batch_size=hp.infra.BATCH_SIZE.value, shuffle=False
)
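A minimal sanity check on the loaders defined above; the expected shapes follow from BATCH_SIZE = 128 and MAX_LENGTH = 320.

batch_inputs, batch_labels = next(iter(train_loader))
print(batch_inputs['input_ids'].shape)  # torch.Size([128, 320])
print(batch_labels.shape)               # torch.Size([128])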
Training (BERT)¶
Training Functions¶
def train_one_epoch(
model: torch.nn.Module,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
loss_fn: torch.nn.Module,
device: torch.device
) -> tuple[list[int], list[int], float]:
model.train()
all_preds, all_labels = [], []
progress_bar = tqdm(train_loader, desc="Training", leave=False)
for batch in progress_bar:
inputs, labels = batch
inputs = {key: value.to(device) for key, value in inputs.items()}
labels = labels.to(device)
optimizer.zero_grad()
        # The model's built-in loss (computed when labels are passed) is not
        # used; the class-weighted loss_fn passed in is applied to the logits.
        outputs = model(**inputs, labels=labels)
        loss = loss_fn(outputs.logits, labels)
loss.backward()
optimizer.step()
progress_bar.set_postfix({'loss': loss.item()})
preds = outputs.logits.argmax(dim=-1).detach().cpu().numpy()
all_preds.extend(preds)
all_labels.extend(labels.cpu().numpy())
train_f1 = f1_score(all_labels, all_preds, average='binary')
return all_preds, all_labels, train_f1
def validate_model(
model: torch.nn.Module, val_loader: DataLoader, device: torch.device
) -> tuple[list[int], list[int], float]:
model.eval()
val_preds, val_labels = [], []
with torch.no_grad():
for batch in val_loader:
inputs, labels = batch
inputs = {key: value.to(device) for key, value in inputs.items()}
labels = labels.to(device)
outputs = model(**inputs)
preds = outputs.logits.argmax(dim=-1).detach().cpu().numpy()
val_preds.extend(preds)
val_labels.extend(labels.cpu().numpy())
val_f1 = f1_score(val_labels, val_preds, average='binary')
return val_preds, val_labels, val_f1
Device¶
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
if torch.cuda.is_available():
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
device
1
NVIDIA A100-SXM4-40GB
device(type='cuda')
Loss Function¶
class_counts = np.bincount(train['y']) # 90% non-toxic, 10% toxic
total_samples = class_counts.sum()
class_weights = torch.tensor(total_samples / class_counts, dtype=torch.float32)
loss_fn = CrossEntropyLoss(weight=class_weights).to(device)
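With the roughly 90/10 class split reported above, inverse-frequency weighting penalizes errors on the toxic class about nine times more heavily. A quick check; the printed values are approximate and depend on the exact split.

# total / class_counts with a ~[90%, 10%] split gives weights near [1.11, 9.84]
print(class_weights)  # e.g. tensor([1.11, 9.84]) -- exact values are split-dependent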
Training Configuration¶
optimizer = torch.optim.AdamW(model.parameters(), lr=hp.lr())
patience = 2 # Number of epochs to wait before stopping
f1_epoch_factor = 0.01
wait = 0
best_val_f1 = float('-inf')
best_model_state = None
log = Log()
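To make the patience criterion used below concrete: at epoch e (counting from 0), the wait counter resets only if the validation F1 beats the running best by at least f1_epoch_factor * e, an increasingly strict margin. A toy check with hypothetical values mirroring the first two epochs of the log below:

for e, (best, val) in enumerate([(float('-inf'), 0.770), (0.770, 0.754)]):
    print(e, val >= best + f1_epoch_factor * e)
# 0 True   -> wait resets
# 1 False  -> 0.754 < 0.770 + 0.01, so wait increments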
Early Stopping Training¶
log.set_time()
for epoch in range(hp.epochs()):
if INFERENCE_ONLY:
print('Set "INFERENCE_ONLY = False" for model training')
break
print(f'Epoch {epoch + 1} of {hp.epochs()}')
train_preds, train_labels, train_f1 = train_one_epoch(
model, train_loader, optimizer, loss_fn, device=device
)
print(f'Epoch {epoch + 1} Training F1 Score: {train_f1:.4f}')
_, _, val_f1 = validate_model(model, val_loader, device=device)
print(f'Epoch {epoch + 1} Validation F1 Score: {val_f1:.4f}')
    if val_f1 > best_val_f1:
        best_model_state = model.state_dict()
        print(f'Model (state) with validation F1 {val_f1} is updated.')
    # The wait counter resets only when F1 beats the best by a margin that
    # grows with the epoch (f1_epoch_factor * epoch)
    if val_f1 >= (best_val_f1 + (f1_epoch_factor * epoch)):
        wait = 0
    else:
        wait += 1
        print(f'No significant improvement in validation F1 for {wait} epoch(s).')
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
if wait >= patience:
print(f'Early stopping triggered after {epoch + 1} epochs.')
break
log.get_time_h(print_=True)
if best_model_state:
model.load_state_dict(best_model_state)
print(f'Loaded best model with Validation F1: {best_val_f1:.4f}')
else:
print('(!) - No improvement during training. Model not updated.')
Epoch 1 of 8
Epoch 1 Training F1 Score: 0.7505
Epoch 1 Validation F1 Score: 0.7700
Model (state) with validation F1 0.76996996996997 is updated.
Epoch 2 of 8
Epoch 2 Training F1 Score: 0.8171
Epoch 2 Validation F1 Score: 0.7538
No significant improvement in validation F1 for 1 epoch(s).
Epoch 3 of 8
Epoch 3 Training F1 Score: 0.8533
Epoch 3 Validation F1 Score: 0.7903
Model (state) with validation F1 0.7903123008285532 is updated.
Epoch 4 of 8
Epoch 4 Training F1 Score: 0.8926
Epoch 4 Validation F1 Score: 0.7908
Model (state) with validation F1 0.7908455181182454 is updated.
No significant improvement in validation F1 for 1 epoch(s).
Epoch 5 of 8
Epoch 5 Training F1 Score: 0.9173
Epoch 5 Validation F1 Score: 0.8088
Model (state) with validation F1 0.8088235294117647 is updated.
No significant improvement in validation F1 for 2 epoch(s).
Early stopping triggered after 5 epochs.
Elapsed time: 0.26 hours
Loaded best model with Validation F1: 0.8088
Saving the Model¶
path_pretrained = f'{MODEL_CACHE}/model_v{VERSION}.{hp.subset_id}.{SUBSETS}'
if not INFERENCE_ONLY:
model.save_pretrained(path_pretrained)
tokenizer.save_pretrained(path_pretrained)
path_pretrained
'drive/MyDrive/datasets/.cache/model_v2.3.4'
Testing¶
Testing Functions¶
def inference(
encodings: dict[str, torch.Tensor],
model: BertForSequenceClassification,
target: pd.Series,
batch_size: int = 8,
threshold: float = 0.5,
) -> tuple[list[int], list[int], list[float]]:
model.eval()
device = next(model.parameters()).device
dataset = ToxicDataset(encodings, target.tolist())
probabilities, predictions, true_labels = [], [], []
with torch.no_grad():
for batch in DataLoader(dataset, batch_size=batch_size):
inputs, labels = batch
inputs = {key: value.to(device) for key, value in inputs.items()}
labels = labels.to(device)
outputs = model(**inputs)
logits = outputs.logits
probs = torch.softmax(logits, dim=1)[:, 1].tolist()
probabilities.extend(probs)
batch_predictions = [1 if prob >= threshold else 0 for prob in probs]
predictions.extend(batch_predictions)
true_labels.extend(labels.tolist())
return predictions, true_labels, probabilities
def report(true_labels, proba, threshold: float = 0.5):
proba_pred = (np.array(proba) >= threshold).astype(int)
cm = confusion_matrix(true_labels, proba_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
print(classification_report(true_labels, proba_pred))
def production_preparation(
text: pd.Series,
target: pd.Series,
seed: int,
sw: StopWordsProcessor,
n: int = None,
rm: bool = True
) -> tuple[pd.Series, pd.Series, pd.DataFrame]:
df_production = pd.concat([text, target], axis=1)
df_production = sw.preprocess_dataframe(df_production, rm)
if n is not None:
df_production = df_production.sample(n=n, random_state=seed)
df_production = df_production.reset_index(drop=True)
return df_production['text'], df_production['target'], df_production.copy()
def production_pipeline(
text: pd.Series,
target: pd.Series,
seed: int,
hp_: HyperParameter,
sw: StopWordsProcessor,
n: int = None,
rm: bool = True,
) -> tuple[list[int], list[int], pd.DataFrame, list[float]]:
X_text, y_target, df_prod = production_preparation(text, target, seed, sw, n, rm)
predictions, true_labels, probabilities = inference(
tokenize_function(X_text.tolist(), hp_),
model,
y_target,
batch_size=hp_.infra.BATCH_SIZE.value,
)
return predictions, true_labels, df_prod, probabilities
def f1_distribution(production_pipeline, n_samples, **kwargs):
    base_seed = kwargs.get('seed', 42)
    np.random.seed(base_seed)
    f1_scores = []
    for i in range(n_samples):
        # Give each resample its own deterministic seed
        kwargs['seed'] = base_seed + i
        predictions, true_labels, _, _ = production_pipeline(**kwargs)
        f1_scores.append(f1_score(true_labels, predictions))
    lower_ci, upper_ci = np.percentile(f1_scores, 2.5), np.percentile(f1_scores, 97.5)
    return f1_scores, np.mean(f1_scores), np.std(f1_scores), lower_ci, upper_ci
Subset Testing¶
predictions_, true_labels_, probas_ = inference(
test_encodings, model, target=test['y'], batch_size=hp.infra.BATCH_SIZE.value
)
report(true_labels_, probas_)
f1_score(true_labels_, predictions_), hp.subset_id, SUBSETS
              precision    recall  f1-score   support

           0       0.99      0.97      0.98      8590
           1       0.75      0.90      0.82       968

    accuracy                           0.96      9558
   macro avg       0.87      0.93      0.90      9558
weighted avg       0.96      0.96      0.96      9558
(0.8185654008438819, 3, 4)
Pipeline-Based Predictions (Test)¶
predictions, true_labels, df_production, proba = production_pipeline(
text=dp.cache['production']['X_test'],
target=dp.cache['production']['y_test'],
seed=SEED+hp.subset_id,
hp_=hp,
n=PRODUCTION_N_SAMPLES,
sw=stop_words_processor,
)
report(true_labels, proba)
              precision    recall  f1-score   support

           0       0.98      0.97      0.97      1802
           1       0.74      0.81      0.77       198

    accuracy                           0.95      2000
   macro avg       0.86      0.89      0.87      2000
weighted avg       0.95      0.95      0.95      2000
fpr, tpr, _ = roc_curve(true_labels, proba)
roc_auc = roc_auc_score(true_labels, proba)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(alpha=0.3)
plt.show()
Bootstrap Test¶
f1_scores, mean_f1, std_f1, lower_ci, upper_ci = f1_distribution(
production_pipeline=production_pipeline,
n_samples=BOOTSTRAP_N_SAMPLES,
text=dp.cache['production']['X_test'],
target=dp.cache['production']['y_test'],
hp_=hp,
seed=SEED,
sw=stop_words_processor,
n=PRODUCTION_N_SAMPLES_BOOTSTRAP,
rm=STL_ID == STOPWORDS_STL_ID,
)
plt.hist(f1_scores, bins=30, alpha=0.7, edgecolor='black')
plt.axvline(mean_f1, color='red', linestyle='dashed', linewidth=2, label='Mean F1')
plt.axvline(
lower_ci, color='green', linestyle='dashed', linewidth=2, label='2.5th Percentile'
)
plt.axvline(
upper_ci, color='green', linestyle='dashed', linewidth=2, label='97.5th Percentile'
)
plt.title("F1 Score Distribution")
plt.xlabel("F1 Score")
plt.ylabel("Frequency")
plt.legend()
plt.show()
Conclusions¶
The project trained a toxic comment classification model using sequential transfer learning, exceeding the required F1 threshold of 0.75 (test F1 of about 0.82) and demonstrating solid generalization.

Sequential Transfer Learning (STL) is especially effective at addressing data drift because it lets the model iteratively adapt to smaller, representative data subsets, maintaining accuracy as the data distribution changes over time. Splitting the dataset into manageable parts also makes training more computationally efficient and allows hyperparameter tuning at each stage. In addition, STL facilitates pre-training on the initial subset followed by targeted learning-rate adjustments on subsequent subsets, establishing a controlled, systematic process for improving model performance.

The production pipeline integrates preprocessing, inference, and evaluation, making it well suited for real-world deployment. Its design is scalable and modular, allowing new subsets or data to be integrated for adaptability and long-term viability in dynamic environments.
Model v2.3.4¶
Upon completing training on the 4th subset in the sequential transfer learning process, the model achieved a high F1 score of 0.82, reflecting its ability to effectively balance precision and recall in toxic comment classification. The F1 score distribution, obtained through bootstrap analysis, demonstrated stability and consistency, with a mean value of about 0.82 and confidence intervals ranging from approximately 0.75 (2.5th percentile) to 0.87 (97.5th percentile). This indicates that the model's performance is not only strong but also reliable across different conditions, confirming the robustness of the sequential transfer learning approach at this stage. These results underscore the effectiveness of iterative improvements via subsets while maintaining statistical reliability.
The model also achieved a ROC-AUC of 0.96, highlighting its ability to clearly distinguish toxic from non-toxic comments, a key metric for classification tasks, and confirming its reliability. In addition, the bootstrap analysis of the F1 metric showed stable, consistent performance; the 2.5th and 97.5th percentile confidence bounds provide a quantitative measure of uncertainty, further confirming the model's robustness under varied conditions.