João R.

Text Classification with Python

If you are already familiar with what text classification is, you might want to jump to this part, or get the code here.

What is Text Classification?

Document or text classification is used to classify information, that is, assign a category to a text; it can be a document, a tweet, a simple message, an email, and so on. In this article, I will show how you can classify retail products into categories. Although in this example the categories are structured in a hierarchy, to keep it simple I will consider all subcategories as top-level.

If you are looking for complex implementations of large scale hierarchical text classification, I will leave links to some really good papers and projects at the end of this post.

Getting started

Now, before you go any further, make sure you have Python 3 and virtualenv installed (virtualenv is optional, but I highly recommend it).

Let’s break down the problem into steps:

Setting up the environment

The main packages used in this project are sklearn, nltk and dataset. Because the data is large, cloning or downloading the repository might take some time; the NLTK data is also considerably large. Run the following commands to set up the project structure and download the required packages:

# Clone the repo
git clone;
cd text-classification-python;

# Create virtualenv; skip this one if you don't have virtualenv.
virtualenv venv && source venv/bin/activate;

# Install all requirements
pip install -r requirements.txt;

# Download all data that NLTK uses
python -m nltk.downloader all;

Gathering the data

The dataset that will be used was created by scraping some products from Amazon. Scraping might be fine for projects that need only a small amount of data, but it can be a really slow process: it is very easy for a server to detect a robot unless you rotate over a list of proxies, which slows the process down even more.

Using this script, I downloaded information on over 22,000 products, organized into 42 top-level categories and a total of 6,233 subcategories. See the whole category tree structure here.

Again, to keep it simple I will be using only 3 top-level categories: Automotive, Home & Kitchen and Industrial & Scientific. Including the subcategories, there are 36 categories in total.

To extract the data from the database, run the command:

# dump from db to dumps/all_products.json
datafreeze .datafreeze.yaml;

Inside the project you will also find a file called, in this file you can set the categories you want to use, the minimum number of samples per category and the depth of a category. As I said before, only 3 categories are going to be used: Home & Kitchen, Industrial & Scientific and Automotive. I did not specify the depth of the subcategories, but I did specify 50 as the minimum number of samples (in this case, products) per category. To transform the data dumped from the database into this “filtered” data, just execute the file:


The script will create a new file called products.json at the root of the project and print out the category tree structure. Change the values of the variables default_depth, min_samples and domain if you need more data.
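The filtering script itself is not reproduced in this post, but its core idea, dropping every category that has fewer than `min_samples` products, can be sketched as follows. This is an illustration under assumptions: the `(text, category)` input format, the made-up products and the `filter_by_min_samples` helper are all hypothetical; only the `data`/`target` output keys follow how the classification code later reads products.json.

```python
from collections import Counter

# Hypothetical input: made-up (product text, category) pairs, not real dump contents.
products = [
    ("12V car battery charger", "Automotive"),
    ("Wiper blade set", "Automotive"),
    ("Non-stick frying pan", "Home & Kitchen"),
    ("Digital caliper", "Industrial & Scientific"),
]


def filter_by_min_samples(products, min_samples):
    """Keep only categories with at least `min_samples` products (hypothetical helper)."""
    counts = Counter(category for _, category in products)
    kept = [(text, cat) for text, cat in products if counts[cat] >= min_samples]
    # Same layout the classifier code reads: parallel lists of texts and labels.
    return {"data": [text for text, _ in kept], "target": [cat for _, cat in kept]}


filtered = filter_by_min_samples(products, min_samples=2)
```

With `min_samples=2`, only the two Automotive products survive; the real script additionally handles category depth and the domain, which this sketch ignores.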

Extracting features from the dataset

In order to run machine learning algorithms, we need to transform the text into numerical vectors. Bag-of-words is one of the most widely used models: it assigns a numerical value (for example, a count) to each word, turning a document into a list of numbers. It can also assign a value to a sequence of adjacent words, known as an n-gram.
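To make the idea concrete, here is a minimal pure-Python sketch of bag-of-words counting with n-grams. It only lowercases and splits on whitespace, and the helper names `ngrams` and `bag_of_words` are hypothetical; in the project itself, scikit-learn's vectorizer does this work.

```python
from collections import Counter


def ngrams(tokens, n):
    """Return all n-grams of a token list as space-joined strings."""
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def bag_of_words(text, ngram_range=(1, 1)):
    """Count every n-gram whose length lies in ngram_range (inclusive)."""
    tokens = text.lower().split()
    lo, hi = ngram_range
    counts = Counter()
    for n in range(lo, hi + 1):
        counts.update(ngrams(tokens, n))
    return counts


docs = ["stainless steel knife set", "steel car jack"]
# Shared vocabulary: every unigram and bigram seen in any document.
vocab = sorted(set().union(*(bag_of_words(d, (1, 2)) for d in docs)))
# One count vector per document, aligned with the vocabulary.
vectors = [[bag_of_words(d, (1, 2))[term] for term in vocab] for d in docs]
```

Each document becomes a vector of the same length (here 11 terms), so documents of different lengths can be compared; bigrams such as "stainless steel" keep a little word-order information.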

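Scikit-learn provides a vectorizer called TfidfVectorizer, which transforms the text based on the bag-of-words/n-gram model; additionally, it computes term frequencies and weights each term using the tf-idf scheme, which raises a term's weight the more often it occurs in a document and lowers it the more documents it occurs in.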
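The sketch below spells out that weighting using only the standard library. It follows the formula scikit-learn documents for TfidfVectorizer's defaults (raw counts, smoothed idf = ln((1 + n) / (1 + df)) + 1, then L2 normalisation of each document), but tokenization is simplified to whitespace splitting and `tfidf` is a hypothetical helper, not the library's code.

```python
import math
from collections import Counter


def tfidf(docs):
    """Return (vocabulary, tf-idf vectors) for whitespace-tokenized documents."""
    tokenized = [doc.lower().split() for doc in docs]
    vocab = sorted({term for tokens in tokenized for term in tokens})
    n = len(docs)
    # Document frequency: the number of documents each term appears in.
    df = Counter(term for tokens in tokenized for term in set(tokens))
    # Smoothed idf: rarer terms get larger weights.
    idf = {t: math.log((1 + n) / (1 + df[t])) + 1 for t in vocab}

    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        row = [counts[t] * idf[t] for t in vocab]
        # L2-normalise so long and short documents are comparable.
        norm = math.sqrt(sum(x * x for x in row)) or 1.0
        vectors.append([x / norm for x in row])
    return vocab, vectors


vocab, X = tfidf(["car seat cover", "car floor mat", "kitchen knife"])
```

In the first document, "car" ends up with a lower weight than "seat", because "car" also appears in the second document and is therefore less distinctive.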

Counting term frequencies is sometimes not enough. Take the words ‘cars’ and ‘car’, for example: with tf-idf alone, they are considered different words. This problem can be solved using stemming and/or lemmatisation, and this is where NLTK comes into play.
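As a rough illustration of what stemming does, here is a deliberately naive suffix stripper. It is a toy written for this example, not the Porter algorithm; NLTK ships real stemmers such as PorterStemmer and SnowballStemmer in `nltk.stem`.

```python
def naive_stem(word):
    """Strip one common English suffix (toy rules, NOT the Porter algorithm)."""
    for suffix in ("ing", "ed", "es", "s"):
        # Keep at least three characters so short words like 'bus' survive.
        if word.endswith(suffix) and len(word) - len(suffix) >= 3:
            return word[: -len(suffix)]
    return word


words = ["car", "cars", "walked", "walking", "walks", "boxes"]
stems = [naive_stem(w) for w in words]
```

After stemming, ‘car’ and ‘cars’ map to the same feature. Suffix rules cannot handle irregular forms, though, which is one reason a dictionary-based lemmatizer is used in this project.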

NLTK offers some pretty useful tools for NLP. For this project I used it to perform lemmatisation and part-of-speech (POS) tagging.

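With lemmatisation we can group together the inflected forms of a word. For example, the words ‘walked’, ‘walks’ and ‘walking’ can be grouped into their base form, the verb ‘walk’. To find the right base form, the lemmatizer needs to know each word's part of speech: NLTK's WordNetLemmatizer treats every word as a noun unless told otherwise, so ‘walking’ would be left unchanged. That is why we need to POS tag each word as a noun, verb, adverb, and so on.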
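The sketch below shows why the tag matters. It uses the same idea as the `tags` lookup in the implementation further down (first letter of the Penn Treebank tag mapped to a WordNet POS code), but replaces WordNet with a tiny hypothetical `LEMMAS` table so it runs without NLTK data.

```python
# First letter of a Penn Treebank tag -> WordNet POS code.
# 'n', 'v', 'r' and 'a' are the values of wn.NOUN, wn.VERB, wn.ADV and wn.ADJ in NLTK.
PENN_TO_WORDNET = {"N": "n", "V": "v", "R": "r", "J": "a"}

# Hypothetical hand-written lemma table standing in for WordNet.
LEMMAS = {
    ("walked", "v"): "walk",
    ("walks", "v"): "walk",
    ("walking", "v"): "walk",
    ("walks", "n"): "walk",
}


def to_wordnet_pos(penn_tag):
    """Map a Penn tag to a WordNet POS, defaulting to noun."""
    return PENN_TO_WORDNET.get(penn_tag[:1], "n")


def lemmatize(word, penn_tag):
    """Look up the lemma for (word, POS); unknown pairs are returned unchanged."""
    return LEMMAS.get((word, to_wordnet_pos(penn_tag)), word)


tagged = [("walked", "VBD"), ("walking", "VBG"), ("walking", "NN"), ("walks", "NNS")]
lemmas = [lemmatize(w, t) for w, t in tagged]
```

‘walking’ tagged as a verb (VBG) becomes ‘walk’, while the same word tagged as a noun (NN, as in "a walking stick") is kept as is.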

It is also worth noting that some words, despite appearing frequently, make little difference to classification; in fact, they can even contribute to misclassifying a text. Words like ‘a’, ‘an’, ‘the’, ‘to’, ‘or’, etc. are known as stop-words, and they can be ignored during tokenization.
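Here is a small sketch of stop-word removal during tokenization. The regular expression is the one NLTK's `wordpunct_tokenize` is built on, and the text is lowercased first, as TfidfVectorizer does by default before calling a custom tokenizer. The short `STOP_WORDS` set is hypothetical; scikit-learn's `ENGLISH_STOP_WORDS`, used later in this post, contains a few hundred entries.

```python
import re

# Runs of word characters, or runs of characters that are neither word chars nor whitespace.
TOKEN_RE = re.compile(r"\w+|[^\w\s]+")

# Hypothetical, much shorter than a real stop-word list.
STOP_WORDS = frozenset({"a", "an", "the", "to", "or", "and", "for", "of", "with"})


def tokens_without_stop_words(text):
    """Lowercase, tokenize and drop stop-words."""
    return [t for t in TOKEN_RE.findall(text.lower()) if t not in STOP_WORDS]


tokens = tokens_without_stop_words("A socket wrench set for the car, with an adapter")
```

Punctuation tokens such as ',' are kept, because they are not in the stop-word list.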

Testing the algorithms

Now that we have all the features and labels, it is time to train the classifiers. There are a number of algorithms you can use for this type of problem, for example: Multinomial Naive Bayes, Linear SVC, SGD Classifier, K-Neighbors Classifier, Random Forest Classifier. Inside the file you can find an example using the SGDClassifier. Run it yourself using the command:


It will print out the accuracy of each category, along with the confusion matrix.
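A confusion matrix counts, for each true category (row), how many samples were predicted as each category (column), so correct predictions sit on the diagonal. The stdlib sketch below uses the same row/column convention as scikit-learn's `confusion_matrix`; the helper is hypothetical and the predictions are made up for illustration.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Rows are true labels, columns are predicted labels."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


labels = ["Automotive", "Home & Kitchen", "Industrial & Scientific"]
# Made-up labels and predictions, not results from the project.
y_true = ["Automotive", "Automotive", "Home & Kitchen", "Industrial & Scientific", "Home & Kitchen"]
y_pred = ["Automotive", "Industrial & Scientific", "Home & Kitchen", "Industrial & Scientific", "Home & Kitchen"]
cm = confusion_matrix(y_true, y_pred, labels)
```

Here the off-diagonal 1 in the first row shows one Automotive product misclassified as Industrial & Scientific.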

Here is how it is implemented: load the dataset, then initialize WordNetLemmatizer and PerceptronTagger from NLTK. As I was only interested in nouns, verbs, adverbs and adjectives, I created a lookup dict to speed up the process. Although NLTK is great, performance is not its aim, so I also used Python's LRU cache (functools.lru_cache) for both the lemmatize and tag functions.

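import json
from functools import lru_cache

from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer
from nltk.tag import PerceptronTagger

# Load data
with open('products.json', encoding='utf-8') as f:
    dataset = json.load(f)

# Initialize the lemmatizer
wnl = WordNetLemmatizer()

# Load the pretrained perceptron tagger
tagger = PerceptronTagger()

# Lookup: first letter of the Penn Treebank tag -> WordNet POS (noun, verb, adverb, adjective)
tags = {'N': wn.NOUN, 'V': wn.VERB, 'R': wn.ADV, 'J': wn.ADJ}

# Memoization of POS tagging and lemmatization
lemmatize_mem = lru_cache(maxsize=10000)(wnl.lemmatize)
tagger_mem = lru_cache(maxsize=10000)(tagger.tag)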
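Here is a self-contained sketch of the same memoization trick, with a hypothetical `slow_tag` function standing in for `tagger.tag` so it runs without NLTK. It also shows why the tokenizer defined next wraps each token in a hashable container before calling the cached tagger: `lru_cache` uses the arguments as dictionary keys, so passing a plain list would raise a `TypeError`.

```python
from functools import lru_cache

call_log = []  # records every real (uncached) call


def slow_tag(tokens):
    """Stand-in for an expensive tagger call (hypothetical; tags everything as 'NN')."""
    call_log.append(tokens)
    return [(token, "NN") for token in tokens]


# Tokens are passed as a tuple because lru_cache needs hashable arguments.
tag_mem = lru_cache(maxsize=10000)(slow_tag)

for word in ["wrench", "socket", "wrench", "wrench"]:
    tag_mem((word,))

info = tag_mem.cache_info()
```

Only the first occurrence of each word reaches `slow_tag`; repeats are served from the cache. Since product descriptions repeat the same words many times, this saves a lot of tagger and lemmatizer calls. Note that the cached result is a shared object, so callers should not mutate it.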

Next, the tokenizer function was created. It breaks the text into words and iterates over them, ignoring the stop-words and POS-tagging/lemmatising the rest. This function is called on every document in the dataset.

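from nltk.tokenize import wordpunct_tokenize
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

# POS tag each token and lemmatize it, skipping stop-words
def tokenizer(text):
    for token in wordpunct_tokenize(text):
        if token not in ENGLISH_STOP_WORDS:
            # Cached call: the tagger expects a sequence, and lru_cache needs a hashable one
            tag = tagger_mem((token,))
            # tag[0][1] is the full Penn tag (e.g. 'NNS'); its first letter selects the WordNet POS
            yield lemmatize_mem(token, tags.get(tag[0][1][0], wn.NOUN))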

Finally, the pipeline is defined: the first step is a TfidfVectorizer that uses the tokenizer function to preprocess each document, and the resulting vectors are then passed to the SGDClassifier. The classifier is trained and tested using 10-fold cross-validation, provided by the cross_val_predict function from scikit-learn.

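from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import Pipeline

# Pipeline definition
pipeline = Pipeline([
    ('vectorizer', TfidfVectorizer(
        tokenizer=tokenizer,
        ngram_range=(1, 2),
    )),
    ('classifier', SGDClassifier(
        alpha=1e-4, n_jobs=-1
    )),
])

# Cross validate using k-fold
y_pred = cross_val_predict(
    pipeline, dataset.get('data'),
    y=dataset.get('target'),
    cv=10, n_jobs=-1, verbose=20
)
# Print out precision, recall and F1 score
print(classification_report(
    dataset.get('target'), y_pred,
))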
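`cross_val_predict` returns, for every document, the prediction made by a model that never saw that document during training. The stdlib sketch below illustrates that idea with hypothetical helpers (`kfold_indices`, `out_of_fold_predict`) and a toy 1-nearest-neighbour model on numbers. It is a simplification: for classifiers, scikit-learn by default uses stratified folds when `cv` is an integer, which keeps the class proportions similar in each fold.

```python
def kfold_indices(n_samples, k):
    """Split range(n_samples) into k contiguous folds whose sizes differ by at most one."""
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def out_of_fold_predict(fit, X, y, k):
    """Predict each sample with a model trained on the other k - 1 folds.

    `fit(X_train, y_train)` must return a function that predicts one sample.
    """
    y_pred = [None] * len(X)
    for test_idx in kfold_indices(len(X), k):
        test_set = set(test_idx)
        train_idx = [i for i in range(len(X)) if i not in test_set]
        predict = fit([X[i] for i in train_idx], [y[i] for i in train_idx])
        for i in test_idx:
            y_pred[i] = predict(X[i])
    return y_pred


def fit_nearest_neighbour(X_train, y_train):
    """Toy model: predict the label of the closest training value."""
    pairs = list(zip(X_train, y_train))
    return lambda x: min(pairs, key=lambda p: abs(p[0] - x))[1]


X = [1, 2, 3, 10, 11, 12]
y = ["low", "low", "low", "high", "high", "high"]
y_pred = out_of_fold_predict(fit_nearest_neighbour, X, y, k=3)
```

Because every sample lands in exactly one test fold, `y_pred` has one honest, out-of-sample prediction per document, which is what the classification report is computed from.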

And here are the precision and recall results for each algorithm I tested (all algorithms were tested with their default parameters):

Algorithm        Precision   Recall
SGDClassifier    0.975       0.975
LinearSVC        0.972       0.971
RandomForest     0.938       0.936
MultinomialNB    0.882       0.851

Precision is the fraction of test samples classified into a category that actually belong to that category.

Recall is the fraction of test samples that actually belong to a category that were correctly classified into it.
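The two definitions translate directly into code. This stdlib sketch computes them for one category; `precision_recall` is a hypothetical helper and the labels are made up. scikit-learn's `classification_report` reports the same quantities for every category, plus averages.

```python
def precision_recall(y_true, y_pred, label):
    """Per-category precision and recall."""
    pairs = list(zip(y_true, y_pred))
    true_positives = sum(1 for t, p in pairs if t == label and p == label)
    predicted = sum(1 for _, p in pairs if p == label)  # classified into the category
    actual = sum(1 for t, _ in pairs if t == label)     # really belong to the category
    precision = true_positives / predicted if predicted else 0.0
    recall = true_positives / actual if actual else 0.0
    return precision, recall


# Made-up labels for illustration only.
y_true = ["auto", "auto", "auto", "home", "home"]
y_pred = ["auto", "auto", "home", "home", "home"]
```

Here every product predicted as "auto" really is "auto" (precision 1.0), but one of the three real "auto" products was missed (recall 2/3); for "home" the situation is reversed.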


As the category tree gets bigger and you have more and more data to classify, you cannot use a model as simple as the one above (well, you can, but its precision will be very low, not to mention the computational cost). Another important thing to notice is how you structure the categories: in Amazon's category structure, many subcategories are so ambiguous that I doubt even humans could correctly classify products into them. The full code of this post can be found here.

If you notice something wrong, or know something that could make the algorithms better, please comment below. Thanks for reading!

Further reading