# Multilingual topic modeling with BERT (Contextualized Topic Models)
# Environment setup (IPython/Colab shell magics): install the
# contextualized-topic-models package and pin transformers to the
# version this notebook was written against.
!pip install contextualized_topic_models
!pip uninstall transformers -y
!pip install transformers==3.0.2
import os
import numpy as np
import pickle
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import bert_embeddings_from_file, bert_embeddings_from_list
from contextualized_topic_models.datasets.dataset import CTMDataset
from contextualized_topic_models.utils.data_preparation import TextHandler
# Download the first 1000 lines of the GoogleNews corpus shipped with the
# contextualized-topic-models repository to use as training data, then
# sanity-check the file contents and line count.
!curl -s https://raw.githubusercontent.com/MilaNLProc/contextualized-topic-models/master/contextualized_topic_models/data/gnews/GoogleNews.txt | head -n1000 > googlenews.txt
!head googlenews.txt
!cat googlenews.txt | wc -l
# Path to the training corpus (one preprocessed document per line).
file_name = "googlenews.txt"
handler = TextHandler(file_name)
handler.prepare() # create vocabulary and training data
# Contextual sentence embeddings for every line of the corpus, computed with
# a multilingual SBERT model.
train_bert = bert_embeddings_from_file(file_name, "distiluse-base-multilingual-cased")
# Combine bag-of-words representation, embeddings, and vocabulary mapping
# into the dataset format the CTM trainer expects.
training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)
num_topics = 50
# NOTE(review): bert_input_size=512 presumably matches the output dimension
# of "distiluse-base-multilingual-cased" — confirm if the encoder changes.
ctm = CTM(input_size=len(handler.vocab), bert_input_size=512, num_epochs=100, hidden_sizes = (100, ),
inference_type="contextual", n_components=num_topics, num_data_loader_workers=0)
ctm.fit(training_dataset) # run the model
ctm.get_topic_lists(5) # get the top-5 words lists
# Build a tiny held-out set from the last 5 lines of the training file.
# NOTE(review): these documents were part of training above — this is a
# smoke test of the inference pipeline, not a true held-out evaluation.
!tail -n 5 googlenews.txt > test.txt
!cat test.txt
test_handler = TextHandler("test.txt")
test_handler.prepare() # create vocabulary and training data
# generate BERT data with the same multilingual encoder used for training
testing_bert = bert_embeddings_from_file("test.txt", "distiluse-base-multilingual-cased")
testing_dataset = CTMDataset(test_handler.bow, testing_bert, test_handler.idx2token)
# The CTM inference network is stochastic, so a single call to get_thetas()
# is a noisy draw. Sample the document-topic distribution many times and
# average to get a more accurate estimate.
n_samples = 100
thetas = np.zeros((len(testing_dataset), num_topics))
for _ in range(n_samples):
    thetas += np.array(ctm.get_thetas(testing_dataset))
# Renormalize each row back to a probability distribution (each accumulated
# row sums to ~n_samples, since every sampled row sums to ~1).
thetas /= thetas.sum(axis=1, keepdims=True)
# Most probable topic index for each test document.
predicted_topics = np.argmax(thetas, axis=1).tolist()
predicted_topics
# Inspect the test document at index 1 — presumably load_text_file() returns
# the raw text lines; verify against the TextHandler API.
test_handler.load_text_file()[1]
# Top-20 words of topic 41. NOTE(review): 41 is a run-specific topic id
# picked by hand for this document — it will differ across retrainings.
ctm.get_topic_lists(20)[41]