# Install the contextualized topic models package, then pin transformers to
# 3.0.2: the dependency resolver pulls in transformers 3.1.0 (see the log
# below), which this CTM release does not work with, so we downgrade it.
!pip install contextualized_topic_models
!pip uninstall transformers -y
!pip install transformers==3.0.2

Requirement already satisfied: contextualized_topic_models in /usr/local/lib/python3.6/dist-packages (1.4.2)
Requirement already satisfied: torchvision==0.7.0 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (0.7.0+cu101)
Requirement already satisfied: gensim==3.8.3 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (3.8.3)
Requirement already satisfied: wheel==0.33.6 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (0.33.6)
Requirement already satisfied: pytest-runner==5.1 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (5.1)
Requirement already satisfied: pytest==4.6.5 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (4.6.5)
Requirement already satisfied: numpy==1.19.1 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (1.19.1)
Requirement already satisfied: sentence-transformers==0.3.2 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (0.3.2)
Requirement already satisfied: torch==1.6.0 in /usr/local/lib/python3.6/dist-packages (from contextualized_topic_models) (1.6.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.7.0->contextualized_topic_models) (7.0.0)
Requirement already satisfied: six>=1.5.0 in /usr/local/lib/python3.6/dist-packages (from gensim==3.8.3->contextualized_topic_models) (1.15.0)
Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python3.6/dist-packages (from gensim==3.8.3->contextualized_topic_models) (1.4.1)
Requirement already satisfied: smart-open>=1.8.1 in /usr/local/lib/python3.6/dist-packages (from gensim==3.8.3->contextualized_topic_models) (2.1.0)
Requirement already satisfied: importlib-metadata>=0.12 in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (1.7.0)
Requirement already satisfied: more-itertools>=4.0.0; python_version > "2.7" in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (8.4.0)
Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (1.4.0)
Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (1.9.0)
Requirement already satisfied: pluggy<1.0,>=0.12 in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (0.13.1)
Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (20.1.0)
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (20.4)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from pytest==4.6.5->contextualized_topic_models) (0.2.5)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from sentence-transformers==0.3.2->contextualized_topic_models) (0.22.2.post1)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from sentence-transformers==0.3.2->contextualized_topic_models) (4.41.1)
Requirement already satisfied: transformers>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from sentence-transformers==0.3.2->contextualized_topic_models) (3.1.0)
Requirement already satisfied: nltk in /usr/local/lib/python3.6/dist-packages (from sentence-transformers==0.3.2->contextualized_topic_models) (3.2.5)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch==1.6.0->contextualized_topic_models) (0.16.0)
Requirement already satisfied: boto in /usr/local/lib/python3.6/dist-packages (from smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (2.49.0)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (2.23.0)
Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (1.14.48)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.12->pytest==4.6.5->contextualized_topic_models) (3.1.0)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->pytest==4.6.5->contextualized_topic_models) (2.4.7)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sentence-transformers==0.3.2->contextualized_topic_models) (0.16.0)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers>=3.0.2->sentence-transformers==0.3.2->contextualized_topic_models) (0.0.43)
Requirement already satisfied: tokenizers==0.8.1.rc2 in /usr/local/lib/python3.6/dist-packages (from transformers>=3.0.2->sentence-transformers==0.3.2->contextualized_topic_models) (0.8.1rc2)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers>=3.0.2->sentence-transformers==0.3.2->contextualized_topic_models) (3.0.12)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers>=3.0.2->sentence-transformers==0.3.2->contextualized_topic_models) (0.7)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers>=3.0.2->sentence-transformers==0.3.2->contextualized_topic_models) (2019.12.20)
Requirement already satisfied: sentencepiece!=0.1.92 in /usr/local/lib/python3.6/dist-packages (from transformers>=3.0.2->sentence-transformers==0.3.2->contextualized_topic_models) (0.1.91)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (2020.6.20)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (3.0.4)
Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (0.10.0)
Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (0.3.3)
Requirement already satisfied: botocore<1.18.0,>=1.17.48 in /usr/local/lib/python3.6/dist-packages (from boto3->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (1.17.48)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers>=3.0.2->sentence-transformers==0.3.2->contextualized_topic_models) (7.1.2)
Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.18.0,>=1.17.48->boto3->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (0.15.2)
Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.18.0,>=1.17.48->boto3->smart-open>=1.8.1->gensim==3.8.3->contextualized_topic_models) (2.8.1)
Uninstalling transformers-3.1.0:
  Successfully uninstalled transformers-3.1.0
Collecting transformers==3.0.2
  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
     |████████████████████████████████| 778kB 3.4MB/s 
Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (0.0.43)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (3.0.12)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (0.7)
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (20.4)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (2019.12.20)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (2.23.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (1.19.1)
Collecting tokenizers==0.8.1.rc1
  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
     |████████████████████████████████| 3.0MB 17.9MB/s 
Requirement already satisfied: sentencepiece!=0.1.92 in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (0.1.91)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers==3.0.2) (4.41.1)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==3.0.2) (0.16.0)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==3.0.2) (7.1.2)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==3.0.2) (1.15.0)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==3.0.2) (2.4.7)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.0.2) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.0.2) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.0.2) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.0.2) (2020.6.20)
Installing collected packages: tokenizers, transformers
  Found existing installation: tokenizers 0.8.1rc2
    Uninstalling tokenizers-0.8.1rc2:
      Successfully uninstalled tokenizers-0.8.1rc2
Successfully installed tokenizers-0.8.1rc1 transformers-3.0.2

import os
import numpy as np
import pickle
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import bert_embeddings_from_file, bert_embeddings_from_list
from contextualized_topic_models.datasets.dataset import CTMDataset
from contextualized_topic_models.utils.data_preparation import TextHandler
# Download the first 1000 Google News headlines shipped with the CTM repo,
# preview the first lines, and confirm the line count (expect 1000).
!curl -s https://raw.githubusercontent.com/MilaNLProc/contextualized-topic-models/master/contextualized_topic_models/data/gnews/GoogleNews.txt | head -n1000 > googlenews.txt
!head googlenews.txt
!cat googlenews.txt | wc -l
centrepoint winter white gala london
mourinho seek killer instinct
roundup golden globe won seduced johansson voice
travel disruption mount storm cold air sweep south florida
wes welker blame costly turnover
psalm book fetch record ny auction ktvn channel reno
surface review comparison window powered tablet pitted
scientist unreported fish trap space
nokia lumia launch
edward snowden latest leak nsa monitored online porn habit radicalizers
1000

Load The Data

# Build the bag-of-words representation from the raw text file, embed every
# line with a sentence-transformers model, and wrap both views in a CTMDataset.
file_name = "googlenews.txt"
handler = TextHandler(file_name)
handler.prepare() # create vocabulary and training data 
# One sentence embedding per line of the file (multilingual DistilUSE model).
train_bert = bert_embeddings_from_file(file_name, "distiluse-base-multilingual-cased")
training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

Train the Fully Contextualized Topic Model

# Number of topics (model components) to learn.
num_topics = 50
# Fully contextualized CTM: inference_type="contextual" means the inference
# network consumes only the sentence embeddings (not the BoW).
# NOTE(review): bert_input_size=512 presumably matches the embedding size of
# distiluse-base-multilingual-cased used above — confirm if the model changes.
ctm = CTM(input_size=len(handler.vocab), bert_input_size=512, num_epochs=100, hidden_sizes = (100, ),
            inference_type="contextual", n_components=num_topics, num_data_loader_workers=0)

ctm.fit(training_dataset) # run the model
ctm.get_topic_lists(5) # get the top-5 words lists
[['kim', 'west', 'kanye', 'kardashian', 'bound'],
 ['day', 'thanksgiving', 'parade', 'macy', 'packer'],
 ['patriot', 'bronco', 'pat', 'packer', 'loss'],
 ['xbox', 'microsoft', 'p', 'game', 'console'],
 ['government', 'political', 'thai', 'party', 'protest'],
 ['oldboy', 'brolin', 'josh', 'lee', 'spike'],
 ['google', 'chrome', 'search', 'extension', 'voice'],
 ['johansson', 'globe', 'golden', 'scarlett', 'ineligible'],
 ['star', 'dancing', 'amber', 'riley', 'win'],
 ['police', 'guilty', 'watkins', 'case', 'lostprophets'],
 ['san', 'andreas', 'gta', 'mobile', 'android'],
 ['flat', 'future', 'record', 'level', 'p'],
 ['thanksgiving', 'day', 'parade', 'thanksgivukkah', 'holiday'],
 ['jos', 'wearhouse', 'men', 'bank', 'baldwin'],
 ['prince', 'william', 'swift', 'jovi', 'bon'],
 ['porn', 'nsa', 'habit', 'radicalizers', 'spying'],
 ['pope', 'church', 'putin', 'issue', 'coalition'],
 ['report', 'benghazi', 'security', 'baldwin', 'alec'],
 ['china', 'zone', 'flight', 'airspace', 'disputed'],
 ['storm', 'parade', 'macy', 'balloon', 'travel'],
 ['bank', 'men', 'palestinian', 'jos', 'wearhouse'],
 ['review', 'homefront', 'frozen', 'inch', 'oldboy'],
 ['bronco', 'packer', 'seahawks', 'rodgers', 'patriot'],
 ['frozen', 'heart', 'review', 'homefront', 'detroit'],
 ['hiv', 'meningitis', 'flu', 'greece', 'health'],
 ['black', 'friday', 'nativity', 'deal', 'monday'],
 ['aarushi', 'hiv', 'killing', 'teen', 'murder'],
 ['west', 'kanye', 'kim', 'seth', 'bound'],
 ['cb', 'seahawks', 'dallas', 'chelsea', 'browner'],
 ['hp', 'revenue', 'raise', 'week', 'shopping'],
 ['lumia', 'nokia', 'price', 'power', 'uk'],
 ['typhoon', 'philippine', 'haiyan', 'climate', 'gain'],
 ['african', 'france', 'central', 'republic', 'troop'],
 ['parade', 'macy', 'carlos', 'beltran', 'york'],
 ['kim', 'kardashian', 'video', 'west', 'bound'],
 ['hewitt', 'love', 'star', 'jennifer', 'dancing'],
 ['swift', 'william', 'taylor', 'prince', 'jovi'],
 ['launch', 'microsoft', 'chrome', 'google', 'search'],
 ['pakistan', 'army', 'chief', 'sharif', 'pm'],
 ['air', 'china', 'zone', 'sea', 'disputed'],
 ['west', 'kanye', 'bound', 'kim', 'video'],
 ['ison', 'comet', 'raptor', 'sun', 'bonobo'],
 ['irs', 'google', 'tax', 'group', 'glass'],
 ['net', 'review', 'preview', 'disney', 'movie'],
 ['nokia', 'lumia', 'tablet', 'window', 'moto'],
 ['three', 'seahawks', 'year', 'burning', 'officer'],
 ['report', 'burning', 'officer', 'storm', 'truck'],
 ['girl', 'baby', 'guilty', 'lostprophets', 'hewitt'],
 ['black', 'friday', 'sale', 'deal', 'monday'],
 ['heart', 'woman', 'pill', 'frozen', 'crisis']]
# Hold out the last 5 headlines of the corpus as a tiny test set and show them.
!tail -n 5 googlenews.txt > test.txt
!cat test.txt
ray whitney return will dallas star huge boost offensively
s relied intermediary probe spacex sept upper stage
nokia lumia tablet kill surface
lakers net preview
neighbor helped save girl imprisoned year speaks
# Prepare the held-out documents the same way as the training data: build the
# bag-of-words, embed each line with the same SBERT model, wrap in a dataset.
test_handler = TextHandler("test.txt")
test_handler.prepare() # create vocabulary and training data

# generate BERT data (same embedding model as used for training)
testing_bert = bert_embeddings_from_file("test.txt", "distiluse-base-multilingual-cased")
testing_dataset = CTMDataset(test_handler.bow, testing_bert, test_handler.idx2token)
# We sample the document-topic distribution many times and average the draws
# to get a more stable estimate of theta (a single stochastic draw is noisy).
thetas = np.zeros((len(testing_dataset), num_topics))
for _ in range(100):  # 100 draws, accumulated; unused index replaced by `_`
    thetas += np.array(ctm.get_thetas(testing_dataset))

# Normalize each row to a proper probability distribution (vectorized —
# replaces the per-document Python loop), then take the most probable topic
# for every test document.
thetas /= thetas.sum(axis=1, keepdims=True)
predicted_topics = np.argmax(thetas, axis=1).tolist()

# document-topic distribution , list of the topic predicted for each testing document
# thetas, 
predicted_topics 
[22, 41, 44, 23, 47]
test_handler.load_text_file()[1] # raw text of the 2nd held-out document
's relied intermediary probe spacex sept upper stage\n'
ctm.get_topic_lists(20)[41] # top-20 words of topic 41, predicted for that document
['ison',
 'comet',
 'raptor',
 'sun',
 'bonobo',
 'dna',
 'flying',
 'trouble',
 'stereo',
 'seahorse',
 'researcher',
 'preview',
 'spacecraft',
 'century',
 'jellyfish',
 'testing',
 'minute',
 'net',
 'spectacular',
 'congo']