{
"cells": [
{
"cell_type": "markdown",
"source": [
"AG News Dataset at a Glance\r\n",
"===========================\r\n",
"\r\n",
"The **AG News** dataset is a collection of more than one million news articles collected in 2005 by academics for experimenting with data mining and information extraction methods. The goal of this example is to illustrate the effectiveness of pretrained word embeddings in classifying texts. For this example, we use a slimmed-down version consisting of 120,000 news articles split evenly among four categories:\r\n",
"\r\n",
"- Sports\r\n",
"- Science/Technology (`Sci/Tech` in the data)\r\n",
"- World\r\n",
"- Business\r\n",
"\r\n",
"In addition to slimming down the dataset, we use the article headlines (the `title` column) as our observations and frame a multiclass classification task: predict an article's category from its headline."
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Imports"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import collections\r\n",
"import numpy as np\r\n",
"import pandas as pd\r\n",
"import re\r\n",
"\r\n",
"from argparse import Namespace"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Setting up"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"args = Namespace(\r\n",
"    raw_dataset_csv=\"../data/ag_news/news.csv\",\r\n",
"    train_proportion=0.7,\r\n",
"    val_proportion=0.15,\r\n",
"    test_proportion=0.15,\r\n",
"    output_munged_csv=\"../data/ag_news/news_with_splits.csv\",\r\n",
"    seed=1337\r\n",
")"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Read Data"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"# Read raw data\r\n",
"news = pd.read_csv(args.raw_dataset_csv, header=0)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"news.head()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" category title\n",
"0 Business Wall St. Bears Claw Back Into the Black (Reuters)\n",
"1 Business Carlyle Looks Toward Commercial Aerospace (Reu...\n",
"2 Business Oil and Economy Cloud Stocks' Outlook (Reuters)\n",
"3 Business Iraq Halts Oil Exports from Main Southern Pipe...\n",
"4 Business Oil prices soar to all-time record, posing new..."
]
},
"metadata": {},
"execution_count": 5
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"# Unique classes\r\n",
"set(news.category)"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'Business', 'Sci/Tech', 'Sports', 'World'}"
]
},
"metadata": {},
"execution_count": 6
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"# Group the rows by category so that each class can be split separately\r\n",
"by_category = collections.defaultdict(list)\r\n",
"for _, row in news.iterrows():\r\n",
"    by_category[row.category].append(row.to_dict())"
],
"outputs": [],
"metadata": {}
},
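{
"cell_type": "markdown",
"source": [
"`iterrows()` is easy to read but slow on large frames. A roughly equivalent grouping can be sketched with pandas' `DataFrame.groupby` and `to_dict('records')`. The small `toy_news` frame below is a made-up stand-in for `news`, and `toy_by_category` is a plain dict rather than a `defaultdict`, which is enough here because every key is created by the grouping itself."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import pandas as pd\r\n",
"\r\n",
"# Made-up stand-in for the `news` DataFrame\r\n",
"toy_news = pd.DataFrame({\r\n",
"    'category': ['Business', 'Sports', 'Business', 'World'],\r\n",
"    'title': ['Oil prices rise', 'Team wins final', 'Stocks fall', 'Talks resume'],\r\n",
"})\r\n",
"\r\n",
"# One list of row dicts per category, mirroring `by_category`\r\n",
"toy_by_category = {\r\n",
"    category: group.to_dict('records')\r\n",
"    for category, group in toy_news.groupby('category')\r\n",
"}\r\n",
"\r\n",
"print({k: len(v) for k, v in toy_by_category.items()})\r\n",
"# {'Business': 2, 'Sports': 1, 'World': 1}"
],
"outputs": [],
"metadata": {}
},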
{
"cell_type": "markdown",
"source": [
"## Create training, validation, and test splits"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"# Assign every data point to a split, category by category (stratified split)\r\n",
"final_list = []\r\n",
"np.random.seed(args.seed)\r\n",
"for _, item_list in sorted(by_category.items()):\r\n",
"    np.random.shuffle(item_list)\r\n",
"    n = len(item_list)\r\n",
"    n_train = int(args.train_proportion * n)\r\n",
"    n_val = int(args.val_proportion * n)\r\n",
"    # The test split takes all remaining items, so args.test_proportion is implied\r\n",
"    # and no items are lost to int() truncation\r\n",
"\r\n",
"    # Give each data point a split attribute\r\n",
"    for item in item_list[:n_train]:\r\n",
"        item['split'] = 'train'\r\n",
"    for item in item_list[n_train:n_train + n_val]:\r\n",
"        item['split'] = 'val'\r\n",
"    for item in item_list[n_train + n_val:]:\r\n",
"        item['split'] = 'test'\r\n",
"\r\n",
"    # Add to final list\r\n",
"    final_list.extend(item_list)"
],
"outputs": [],
"metadata": {}
},
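{
"cell_type": "markdown",
"source": [
"The loop above is a stratified split: each category is shuffled and cut on its own, so the train, validation, and test splits keep the same class balance. The same idea can be sketched as a reusable function. `stratified_split`, its arguments, and the toy data are hypothetical names for illustration. Unlike the notebook, the sketch draws from a local `numpy.random.default_rng(seed)` generator rather than the global `np.random.seed`, so other code drawing random numbers cannot shift the split, and it tags copies of the row dicts instead of modifying them in place."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import collections\r\n",
"import numpy as np\r\n",
"\r\n",
"def stratified_split(by_label, train_prop, val_prop, seed):\r\n",
"    # Hypothetical helper: tag copies of the row dicts with a split, label by label\r\n",
"    rng = np.random.default_rng(seed)\r\n",
"    tagged = []\r\n",
"    for _, rows in sorted(by_label.items()):\r\n",
"        order = rng.permutation(len(rows))  # shuffled positions\r\n",
"        n_train = int(train_prop * len(rows))\r\n",
"        n_val = int(val_prop * len(rows))\r\n",
"        for rank, idx in enumerate(order):\r\n",
"            row = dict(rows[idx])  # copy, so the input is left untouched\r\n",
"            if rank < n_train:\r\n",
"                row['split'] = 'train'\r\n",
"            elif rank < n_train + n_val:\r\n",
"                row['split'] = 'val'\r\n",
"            else:\r\n",
"                row['split'] = 'test'  # the remainder goes to test\r\n",
"            tagged.append(row)\r\n",
"    return tagged\r\n",
"\r\n",
"# Toy input: 100 rows for each of two labels\r\n",
"toy_rows = {label: [{'label': label, 'id': i} for i in range(100)] for label in ['a', 'b']}\r\n",
"toy_tagged = stratified_split(toy_rows, 0.7, 0.15, seed=1337)\r\n",
"toy_counts = collections.Counter((r['label'], r['split']) for r in toy_tagged)\r\n",
"print(toy_counts[('a', 'train')], toy_counts[('a', 'val')], toy_counts[('a', 'test')])\r\n",
"# 70 15 15"
],
"outputs": [],
"metadata": {}
},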
{
"cell_type": "code",
"execution_count": 9,
"source": [
"# Collect the split data into a DataFrame\r\n",
"final_news = pd.DataFrame(final_list)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"final_news.split.value_counts()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"train 84000\n",
"val 18000\n",
"test 18000\n",
"Name: split, dtype: int64"
]
},
"metadata": {},
"execution_count": 10
}
],
"metadata": {}
},
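{
"cell_type": "markdown",
"source": [
"These totals follow directly from the configuration. With 120,000 headlines split evenly across four categories, each category has 30,000 rows. `int()` truncation gives 21,000 training and 4,500 validation rows per category, and the test split takes the remaining 4,500. The cell below checks that arithmetic; the variable names are illustrative only."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import math\r\n",
"\r\n",
"train_prop, val_prop, test_prop = 0.7, 0.15, 0.15  # values from args\r\n",
"assert math.isclose(train_prop + val_prop + test_prop, 1.0)\r\n",
"\r\n",
"n_categories = 4\r\n",
"per_category = 120_000 // n_categories  # 30,000 headlines per category\r\n",
"\r\n",
"n_train = int(train_prop * per_category)\r\n",
"n_val = int(val_prop * per_category)\r\n",
"n_test = per_category - n_train - n_val  # the remainder goes to test\r\n",
"\r\n",
"totals = {\r\n",
"    split: n_categories * n\r\n",
"    for split, n in [('train', n_train), ('val', n_val), ('test', n_test)]\r\n",
"}\r\n",
"print(totals)  # {'train': 84000, 'val': 18000, 'test': 18000}"
],
"outputs": [],
"metadata": {}
},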
{
"cell_type": "code",
"execution_count": 11,
"source": [
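"# Preprocess the headlines\r\n",
"def preprocess_text(text):\r\n",
"    # Lowercase the whole headline\r\n",
"    text = text.lower()\r\n",
"    # Surround the punctuation marks . , ! ? with spaces so they become separate tokens\r\n",
"    text = re.sub(r\"([.,!?])\", r\" \\1 \", text)\r\n",
"    # Replace every run of other characters (digits, apostrophes, brackets, whitespace) with a single space\r\n",
"    text = re.sub(r\"[^a-zA-Z.,!?]+\", r\" \", text)\r\n",
"    return text\r\n",
"\r\n",
"final_news['title'] = final_news['title'].apply(preprocess_text)"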
],
"outputs": [],
"metadata": {}
},
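{
"cell_type": "markdown",
"source": [
"To see what each step of `preprocess_text` does, the cell below traces its three transformations on the first raw headline shown by `news.head()` earlier. It restates the same regular expressions so that it runs on its own. Note the trailing space left where `(Reuters)` used to be: the function does not strip its result, although whitespace tokenization with `str.split()` ignores it."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import re\r\n",
"\r\n",
"headline = 'Wall St. Bears Claw Back Into the Black (Reuters)'\r\n",
"\r\n",
"step1 = headline.lower()                       # lowercase everything\r\n",
"step2 = re.sub(r'([.,!?])', r' \\1 ', step1)    # pad . , ! ? with spaces\r\n",
"step3 = re.sub(r'[^a-zA-Z.,!?]+', ' ', step2)  # collapse all other runs to one space\r\n",
"\r\n",
"print(repr(step3))\r\n",
"# 'wall st . bears claw back into the black reuters '\r\n",
"print(step3.split())"
],
"outputs": [],
"metadata": {}
},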
{
"cell_type": "code",
"execution_count": 12,
"source": [
"final_news.head()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" category title split\n",
"0 Business jobs , tax cuts key issues for bush train\n",
"1 Business jarden buying mr . coffee s maker train\n",
"2 Business retail sales show festive fervour train\n",
"3 Business intervoice s customers come calling train\n",
"4 Business boeing expects air force contract train"
]
},
"metadata": {},
"execution_count": 12
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"source": [
"# Write munged data to CSV\r\n",
"final_news.to_csv(args.output_munged_csv, index=False)"
],
"outputs": [],
"metadata": {}
}
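,
{
"cell_type": "markdown",
"source": [
"Downstream code can read the munged file back and select a split through the `split` column. The cell below sketches that round trip in memory with `io.StringIO`, so it does not touch `args.output_munged_csv`. `toy_final` holds a few sample rows shaped like `final_news` (only the first title comes from the output above; the others are made up). Titles containing commas, such as `jobs , tax cuts key issues for bush`, are quoted by `to_csv` and restored intact by `read_csv`."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import io\r\n",
"import pandas as pd\r\n",
"\r\n",
"# Sample rows shaped like final_news (only the first title comes from the data)\r\n",
"toy_final = pd.DataFrame({\r\n",
"    'category': ['Business', 'Sports', 'World'],\r\n",
"    'title': ['jobs , tax cuts key issues for bush', 'team wins final', 'talks resume'],\r\n",
"    'split': ['train', 'val', 'train'],\r\n",
"})\r\n",
"\r\n",
"buffer = io.StringIO()\r\n",
"toy_final.to_csv(buffer, index=False)  # same call as above, written to memory\r\n",
"buffer.seek(0)\r\n",
"\r\n",
"reloaded = pd.read_csv(buffer)\r\n",
"train_rows = reloaded[reloaded['split'] == 'train']\r\n",
"print(len(train_rows))  # 2"
],
"outputs": [],
"metadata": {}
}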
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.10 64-bit ('cits4012': conda)"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"toc": {
"colors": {
"hover_highlight": "#DAA520",
"running_highlight": "#FF0000",
"selected_highlight": "#FFD700"
},
"moveMenuLeft": true,
"nav_menu": {
"height": "12px",
"width": "252px"
},
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": "5",
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
},
"interpreter": {
"hash": "d990147e05fc0cc60dd3871899a6233eb6a5324c1885ded43d013dc915f7e535"
}
},
"nbformat": 4,
"nbformat_minor": 2
}