4. AG News Dataset at a Glance

The AG News dataset is a collection of more than one million news articles gathered in 2005 by academics for experimenting with data mining and information extraction methods. The goal of this example is to illustrate the effectiveness of pretrained word embeddings in classifying texts. Here we use a slimmed-down version consisting of 120,000 news articles split evenly among four categories:

  • Sports,

  • Science/Technology,

  • World, and

  • Business.

In addition to slimming down the dataset, we focus on the article headlines as our observations and create the multiclass classification task of predicting the category given the headline.

4.1. Imports

import collections                 # defaultdict for grouping rows by category
import numpy as np                 # seeded shuffling of the per-category rows
import pandas as pd                # CSV I/O and DataFrame manipulation
import re                          # regular expressions for text cleanup

from argparse import Namespace     # lightweight container for the settings below

4.2. Setting Up

args = Namespace(
    raw_dataset_csv="../data/ag_news/news.csv",
    train_proportion=0.7,
    val_proportion=0.15,
    test_proportion=0.15,
    output_munged_csv="../data/ag_news/news_with_splits.csv",
    seed=1337
)
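
Since the three proportions are meant to cover the entire dataset, a quick sanity check (a minimal sketch, not part of the original walkthrough) can catch configuration typos before any splitting happens:

# Illustrative sanity check: the split proportions should sum to 1.0
assert abs(args.train_proportion + args.val_proportion + args.test_proportion - 1.0) < 1e-9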

4.3. Read Data

# Read raw data
news = pd.read_csv(args.raw_dataset_csv, header=0)
news.head()
   category  title
0  Business  Wall St. Bears Claw Back Into the Black (Reuters)
1  Business  Carlyle Looks Toward Commercial Aerospace (Reu...
2  Business  Oil and Economy Cloud Stocks' Outlook (Reuters)
3  Business  Iraq Halts Oil Exports from Main Southern Pipe...
4  Business  Oil prices soar to all-time record, posing new...
# Unique classes
set(news.category)
{'Business', 'Sci/Tech', 'Sports', 'World'}
# Group the raw data by category (so the train/val/test splits can be stratified)
by_category = collections.defaultdict(list)
for _, row in news.iterrows():
    by_category[row.category].append(row.to_dict())
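
Because the 120,000 headlines are split evenly among the four categories, each bucket in by_category should hold 30,000 rows; a quick check (illustrative, not part of the original code) makes that balance explicit:

# Illustrative check: each category should contain 30,000 headlines
for category, item_list in sorted(by_category.items()):
    print(category, len(item_list))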

4.4. Create Training, Validation, and Test Splits

# Create split data
final_list = []
np.random.seed(args.seed)
for _, item_list in sorted(by_category.items()):
    np.random.shuffle(item_list)
    n = len(item_list)
    n_train = int(args.train_proportion*n)
    n_val = int(args.val_proportion*n)
    n_test = int(args.test_proportion*n)  # not used directly; the test split is the remainder
    
    # Give each data point a split attribute
    for item in item_list[:n_train]:
        item['split'] = 'train'
    for item in item_list[n_train:n_train+n_val]:
        item['split'] = 'val'
    for item in item_list[n_train+n_val:]:
        item['split'] = 'test'  
    
    # Add to final list
    final_list.extend(item_list)
# Collect the split data into a DataFrame
final_news = pd.DataFrame(final_list)
final_news.split.value_counts()
train    84000
val      18000
test     18000
Name: split, dtype: int64
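
Because the shuffling and slicing happen per category, the 70/15/15 split is stratified: each category contributes proportionally to every split. A cross-tabulation (a quick check sketched here, not part of the original walkthrough) makes this visible:

# Illustrative check: the split is stratified by category
pd.crosstab(final_news.category, final_news.split)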
# Preprocess the headlines
def preprocess_text(text):
    # Lowercase each whitespace-separated token
    text = ' '.join(word.lower() for word in text.split(" "))
    # Pad punctuation with spaces so it becomes its own token
    text = re.sub(r"([.,!?])", r" \1 ", text)
    # Collapse any run of characters other than letters and .,!? into a single space
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)
    return text
    
final_news.title = final_news.title.apply(preprocess_text)
final_news.head()
   category  title                                split
0  Business  jobs , tax cuts key issues for bush  train
1  Business  jarden buying mr . coffee s maker    train
2  Business  retail sales show festive fervour    train
3  Business  intervoice s customers come calling  train
4  Business  boeing expects air force contract    train
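
To make the effect of preprocess_text concrete, here is what it produces for one of the raw headlines seen earlier (the output string is an illustration traced by hand: punctuation is padded with spaces, and any remaining non-letter characters collapse to a single space):

# Illustrative example of the preprocessing on a raw headline
preprocess_text("Wall St. Bears Claw Back Into the Black (Reuters)")
# -> 'wall st . bears claw back into the black reuters '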
# Write munged data to CSV
final_news.to_csv(args.output_munged_csv, index=False)
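
As a final check (sketched here, reusing the path from args above), the munged CSV can be read back to confirm that all 120,000 rows and the three columns survived the round trip:

# Illustrative check: re-read the munged CSV and inspect its shape
reloaded = pd.read_csv(args.output_munged_csv)
print(reloaded.shape)            # expected: (120000, 3)
print(list(reloaded.columns))    # expected: ['category', 'title', 'split']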