4. AG News Dataset at a Glance
The AG News dataset is a collection of more than one million news articles, collected by academics in 2005 for experimenting with data mining and information extraction methods. The goal of this example is to illustrate how effective pretrained word embeddings are for text classification. For this example, we use a slimmed-down version of 120,000 news articles split evenly among four categories:
Sports,
Science/Technology,
World, and
Business.
In addition to slimming down the dataset, we use the article headlines as our observations and frame a multiclass classification task: predicting an article's category from its headline.
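Each headline carries exactly one of the four labels, so a classifier needs an integer index for each category. The sketch below shows one possible mapping built from the category strings used in the CSV ('Business', 'Sci/Tech', 'Sports', 'World'). The names `category_to_index`, `index_to_category`, and `encode_label` are hypothetical and serve only as illustration.

```python
# Hypothetical label mapping for the four AG News category strings.
# Sorting first makes the assignment deterministic across runs.
categories = sorted({"Business", "Sci/Tech", "Sports", "World"})
category_to_index = {name: idx for idx, name in enumerate(categories)}
index_to_category = {idx: name for name, idx in category_to_index.items()}


def encode_label(category: str) -> int:
    """Map a category string to its integer class index (0 to 3)."""
    return category_to_index[category]


print(category_to_index)
# {'Business': 0, 'Sci/Tech': 1, 'Sports': 2, 'World': 3}
```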
4.1. Imports
import collections
import numpy as np
import pandas as pd
import re
from argparse import Namespace
4.2. Setting up
args = Namespace(
    raw_dataset_csv="../data/ag_news/news.csv",
    train_proportion=0.7,
    val_proportion=0.15,
    test_proportion=0.15,
    output_munged_csv="../data/ag_news/news_with_splits.csv",
    seed=1337
)
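The split loop slices each category using only `train_proportion` and `val_proportion`, and the test split receives whatever rows remain, so `test_proportion` takes effect only implicitly. If you want the three settings to agree, a small sanity check could look like the sketch below. `check_proportions` is a hypothetical helper, and `example_args` is a stand-in for the notebook's `args` (in the notebook you would call `check_proportions(args)`). It uses `math.isclose` because a floating-point sum such as 0.7 + 0.15 + 0.15 is not guaranteed to equal 1.0 exactly.

```python
import math
from argparse import Namespace


def check_proportions(args):
    """Raise ValueError unless the three split proportions sum to 1."""
    total = args.train_proportion + args.val_proportion + args.test_proportion
    # isclose tolerates floating-point rounding in the sum
    if not math.isclose(total, 1.0):
        raise ValueError(f"split proportions sum to {total}, expected 1.0")


# Hypothetical stand-in with the same proportions as the notebook's args
example_args = Namespace(train_proportion=0.7, val_proportion=0.15, test_proportion=0.15)
check_proportions(example_args)  # passes silently
```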
4.3. Read Data
# Read raw data
news = pd.read_csv(args.raw_dataset_csv, header=0)
news.head()
| | category | title |
|---|---|---|
| 0 | Business | Wall St. Bears Claw Back Into the Black (Reuters) |
| 1 | Business | Carlyle Looks Toward Commercial Aerospace (Reu... |
| 2 | Business | Oil and Economy Cloud Stocks' Outlook (Reuters) |
| 3 | Business | Iraq Halts Oil Exports from Main Southern Pipe... |
| 4 | Business | Oil prices soar to all-time record, posing new... |
# Unique classes
set(news.category)
{'Business', 'Sci/Tech', 'Sports', 'World'}
# Group all rows by category
# Map each category to a list of row dicts
by_category = collections.defaultdict(list)
for _, row in news.iterrows():
    by_category[row.category].append(row.to_dict())
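`DataFrame.iterrows()` builds a pandas Series for every row, which tends to be slow on 120,000 rows. The sketch below builds the same category-to-list-of-row-dicts layout with `groupby` and `to_dict(orient="records")`. It runs on a tiny hypothetical frame (`toy`) instead of `news`, and it stores the result in `grouped` so that it does not overwrite `by_category`. Unlike the `defaultdict` loop, `groupby` sorts the keys, but the split step iterates over `sorted(by_category.items())` anyway.

```python
import pandas as pd

# Hypothetical toy rows standing in for the real `news` DataFrame
toy = pd.DataFrame({
    "category": ["Business", "Sports", "Business", "World"],
    "title": ["headline a", "headline b", "headline c", "headline d"],
})

# One entry per category; each value is a list of row dicts, in original row order
grouped = {
    category: group.to_dict(orient="records")
    for category, group in toy.groupby("category")
}

print({category: len(rows) for category, rows in grouped.items()})
# {'Business': 2, 'Sports': 1, 'World': 1}
```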
4.4. Create training, validation and test splits
# Create split data
final_list = []
np.random.seed(args.seed)
for _, item_list in sorted(by_category.items()):
    np.random.shuffle(item_list)
    n = len(item_list)
    n_train = int(args.train_proportion * n)
    n_val = int(args.val_proportion * n)
    # The test split takes every remaining row, so args.test_proportion
    # is implied by the other two proportions rather than used directly
    # Give each data point a split attribute
    for item in item_list[:n_train]:
        item['split'] = 'train'
    for item in item_list[n_train:n_train + n_val]:
        item['split'] = 'val'
    for item in item_list[n_train + n_val:]:
        item['split'] = 'test'
    # Add to final list
    final_list.extend(item_list)
# Collect the split data into a DataFrame and count rows per split
final_news = pd.DataFrame(final_list)
final_news.split.value_counts()
train 84000
val 18000
test 18000
Name: split, dtype: int64
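These counts follow from the arithmetic of the split loop. With 120,000 headlines split evenly over four categories, each category has 30,000 rows, and `int()` truncates each proportion's share. The hypothetical helper `split_sizes` below mirrors that slicing so the numbers can be checked by hand. It also includes a small uneven case where the test split absorbs the rows lost to truncation.

```python
def split_sizes(n, train_proportion=0.7, val_proportion=0.15):
    """Return (n_train, n_val, n_test) the way the split loop slices one category."""
    n_train = int(train_proportion * n)
    n_val = int(val_proportion * n)
    # The loop's test slice is item_list[n_train + n_val:], so the test
    # split receives every row left over after int() truncation
    n_test = n - n_train - n_val
    return n_train, n_val, n_test


per_category = split_sizes(30_000)               # 120,000 headlines / 4 categories
totals = tuple(4 * size for size in per_category)
print(per_category)    # (21000, 4500, 4500)
print(totals)          # (84000, 18000, 18000)
print(split_sizes(7))  # (4, 1, 2): test gets 2 rows, not int(0.15 * 7) == 1
```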
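# Preprocess the headlines: lowercase them, pad . , ! ? with spaces,
# and replace every other run of non-letter characters with one space
def preprocess_text(text):
    text = ' '.join(word.lower() for word in text.split(" "))
    text = re.sub(r"([.,!?])", r" \1 ", text)
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)
    return text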
final_news.title = final_news.title.apply(preprocess_text)
final_news.head()
| | category | title | split |
|---|---|---|---|
| 0 | Business | jobs , tax cuts key issues for bush | train |
| 1 | Business | jarden buying mr . coffee s maker | train |
| 2 | Business | retail sales show festive fervour | train |
| 3 | Business | intervoice s customers come calling | train |
| 4 | Business | boeing expects air force contract | train |
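To see what each regular expression in `preprocess_text` contributes, the sketch below repeats its three steps and returns every intermediate string. `preprocess_steps` is a hypothetical helper. The first input is illustrative, chosen so that its final string matches the row with index 1 in the output above. Apostrophes, digits, and symbols such as `$` all collapse to a single space, which is how "coffee's" becomes "coffee s" and how a trailing number leaves a trailing space.

```python
import re


def preprocess_steps(text):
    """Return (lowered, spaced, cleaned): the string after each step of preprocess_text."""
    # Step 1: lowercase word by word
    lowered = ' '.join(word.lower() for word in text.split(" "))
    # Step 2: surround . , ! ? with spaces so they become separate tokens
    spaced = re.sub(r"([.,!?])", r" \1 ", lowered)
    # Step 3: replace each run of other characters (spaces, digits,
    # apostrophes, brackets, symbols) with a single space
    cleaned = re.sub(r"[^a-zA-Z.,!?]+", r" ", spaced)
    return lowered, spaced, cleaned


for step in preprocess_steps("Jarden Buying Mr. Coffee's Maker"):
    print(repr(step))
# "jarden buying mr. coffee's maker"
# "jarden buying mr .  coffee's maker"
# 'jarden buying mr . coffee s maker'

print(repr(preprocess_steps("Oil prices hit $50")[-1]))
# 'oil prices hit '
```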
# Write munged data to CSV
final_news.to_csv(args.output_munged_csv, index=False)
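As a quick check, the munged file can be read back and filtered on its `split` column. The sketch below runs a tiny hypothetical frame (`toy_news`) through an in-memory `io.StringIO` buffer instead of `args.output_munged_csv`, so it runs without the dataset. Passing `index=False` keeps the DataFrame index from being written as an extra unnamed column. Headlines that contain commas are quoted by the CSV writer and come back unchanged.

```python
import io

import pandas as pd

# Hypothetical toy frame with the same columns as final_news
toy_news = pd.DataFrame({
    "category": ["Business", "Sports", "World"],
    "title": ["jobs , tax cuts key issues for bush", "headline b", "headline c"],
    "split": ["train", "val", "test"],
})

# Round-trip through an in-memory buffer instead of a file on disk
buffer = io.StringIO()
toy_news.to_csv(buffer, index=False)  # no extra unnamed index column
buffer.seek(0)
reloaded = pd.read_csv(buffer)

# Select one split by filtering on the split column
train_df = reloaded[reloaded.split == "train"]
print(list(reloaded.columns))  # ['category', 'title', 'split']
print(len(train_df))           # 1
```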