{ "cells": [ { "cell_type": "markdown", "source": [ "AG News Dataset at a Glance\r\n", "===========================\r\n", "\r\n", "The **AG News** dataset is a collection of more than one million news articles gathered in 2005 by academics for experimenting with data mining and information extraction methods. The goal of this\r\n", "example is to illustrate the effectiveness of pretrained word embeddings in classifying texts. For this example, we use a slimmed-down version consisting of 120,000 news articles that are split evenly\r\n", "among four categories:\r\n", "- Sports,\r\n", "- Science/Technology,\r\n", "- World, and\r\n", "- Business.\r\n", "\r\n", "We also use only the article headlines as observations, which gives a multiclass classification task: predict the category from the headline." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Imports" ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "import collections\r\n", "import numpy as np\r\n", "import pandas as pd\r\n", "import re\r\n", "\r\n", "from argparse import Namespace" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Setting up" ], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "args = Namespace(\r\n", "    raw_dataset_csv=\"../data/ag_news/news.csv\",\r\n", "    train_proportion=0.7,\r\n", "    val_proportion=0.15,\r\n", "    test_proportion=0.15,\r\n", "    output_munged_csv=\"../data/ag_news/news_with_splits.csv\",\r\n", "    seed=1337\r\n", ")" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Read Data" ], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "# Read raw data\r\n", "news = pd.read_csv(args.raw_dataset_csv, header=0)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "news.head()" ], "outputs": [ { "output_type": "execute_result", "data": {
"text/html": [ "<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>category</th><th>title</th></tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>Business</td><td>Wall St. Bears Claw Back Into the Black (Reuters)</td></tr>\n",
"    <tr><th>1</th><td>Business</td><td>Carlyle Looks Toward Commercial Aerospace (Reu...</td></tr>\n",
"    <tr><th>2</th><td>Business</td><td>Oil and Economy Cloud Stocks' Outlook (Reuters)</td></tr>\n",
"    <tr><th>3</th><td>Business</td><td>Iraq Halts Oil Exports from Main Southern Pipe...</td></tr>\n",
"    <tr><th>4</th><td>Business</td><td>Oil prices soar to all-time record, posing new...</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " category title\n", "0 Business Wall St. Bears Claw Back Into the Black (Reuters)\n", "1 Business Carlyle Looks Toward Commercial Aerospace (Reu...\n", "2 Business Oil and Economy Cloud Stocks' Outlook (Reuters)\n", "3 Business Iraq Halts Oil Exports from Main Southern Pipe...\n", "4 Business Oil prices soar to all-time record, posing new..." ] }, "metadata": {}, "execution_count": 5 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "# Unique classes\r\n", "set(news.category)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'Business', 'Sci/Tech', 'Sports', 'World'}" ] }, "metadata": {}, "execution_count": 6 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 7, "source": [ "# Group the rows by category: maps category name -> list of row dicts\r\n", "by_category = collections.defaultdict(list)\r\n", "for _, row in news.iterrows():\r\n", "    by_category[row.category].append(row.to_dict())" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Create training, validation and test split" ], "metadata": {} }, { "cell_type": "code", "execution_count": 8, "source": [ "# Create split data, stratified by category\r\n", "final_list = []\r\n", "np.random.seed(args.seed)\r\n", "for _, item_list in sorted(by_category.items()):\r\n", "    np.random.shuffle(item_list)\r\n", "    n = len(item_list)\r\n", "    n_train = int(args.train_proportion*n)\r\n", "    n_val = int(args.val_proportion*n)\r\n", "\r\n", "    # Give each data point a split attribute\r\n", "    for item in item_list[:n_train]:\r\n", "        item['split'] = 'train'\r\n", "    for item in item_list[n_train:n_train+n_val]:\r\n", "        item['split'] = 'val'\r\n", "    # Everything left over (about test_proportion of n) is the test split\r\n", "    for item in item_list[n_train+n_val:]:\r\n", "        item['split'] = 'test'\r\n", "\r\n", "    # Add to final list\r\n", "    final_list.extend(item_list)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 9, "source": [ "# Collect the split data into a DataFrame\r\n", "final_news = pd.DataFrame(final_list)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 10, "source": [ "final_news.split.value_counts()" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "train 84000\n", "val 18000\n", "test 18000\n", "Name: split, dtype: int64" ] }, "metadata": {}, "execution_count": 10 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 11, "source": [ "# Preprocess the headlines\r\n", "def preprocess_text(text):\r\n", "    # Lowercase every word\r\n", "    text = ' '.join(word.lower() for word in text.split(\" \"))\r\n", "    # Put spaces around punctuation so each mark becomes its own token\r\n", "    text = re.sub(r\"([.,!?])\", r\" \\1 \", text)\r\n", "    # Replace every run of other non-letter characters with a single space\r\n", "    text = re.sub(r\"[^a-zA-Z.,!?]+\", r\" \", text)\r\n", "    return text\r\n", "\r\n", "final_news.title = final_news.title.apply(preprocess_text)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 12, "source": [ "final_news.head()" ], "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>category</th><th>title</th><th>split</th></tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>Business</td><td>jobs , tax cuts key issues for bush</td><td>train</td></tr>\n",
"    <tr><th>1</th><td>Business</td><td>jarden buying mr . coffee s maker</td><td>train</td></tr>\n",
"    <tr><th>2</th><td>Business</td><td>retail sales show festive fervour</td><td>train</td></tr>\n",
"    <tr><th>3</th><td>Business</td><td>intervoice s customers come calling</td><td>train</td></tr>\n",
"    <tr><th>4</th><td>Business</td><td>boeing expects air force contract</td><td>train</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " category title split\n", "0 Business jobs , tax cuts key issues for bush train\n", "1 Business jarden buying mr . coffee s maker train\n", "2 Business retail sales show festive fervour train\n", "3 Business intervoice s customers come calling train\n", "4 Business boeing expects air force contract train" ] }, "metadata": {}, "execution_count": 12 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 13, "source": [ "# Write munged data to CSV\r\n", "final_news.to_csv(args.output_munged_csv, index=False)" ], "outputs": [], "metadata": {} } ], "metadata": { "kernelspec": { "name": "python3", "display_name": "Python 3.8.10 64-bit ('cits4012': conda)" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "toc": { "colors": { "hover_highlight": "#DAA520", "running_highlight": "#FF0000", "selected_highlight": "#FFD700" }, "moveMenuLeft": true, "nav_menu": { "height": "12px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": "5", "toc_cell": false, "toc_section_display": "block", "toc_window_display": false }, "interpreter": { "hash": "d990147e05fc0cc60dd3871899a6233eb6a5324c1885ded43d013dc915f7e535" } }, "nbformat": 4, "nbformat_minor": 2 }