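We'll be using aitextgen to train a small GPT-2 model on the tiny Shakespeare dataset, resuming from a saved checkpoint on Google Drive if one exists.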

!pip install aitextgen

Requirement already satisfied: aitextgen in /usr/local/lib/python3.7/dist-packages (0.5.2)
Requirement already satisfied: pytorch-lightning>=1.3.1 in /usr/local/lib/python3.7/dist-packages (from aitextgen) (1.3.4)
Requirement already satisfied: fire>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from aitextgen) (0.4.0)
Requirement already satisfied: transformers>=4.5.1 in /usr/local/lib/python3.7/dist-packages (from aitextgen) (4.6.1)
Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from aitextgen) (1.8.1+cu101)

Import modules and mount Google Drive

from aitextgen import aitextgen
from aitextgen.colab import mount_gdrive, copy_file_from_gdrive
from aitextgen.TokenDataset import TokenDataset, merge_datasets
from aitextgen.utils import build_gpt2_config
from aitextgen.tokenizers import train_tokenizer

mount_gdrive()
!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt > input.txt
!head input.txt
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  9002k      0 --:--:-- --:--:-- --:--:-- 9002k
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:

Train tokenizer

file_name = "input.txt"
project_name = "project_name"

# copy_file_from_gdrive(file_name)
train_tokenizer(file_name);
INFO:aitextgen.tokenizers:Saving aitextgen-vocab.json and aitextgen-merges.txt to the current directory. You will need both files to build the GPT2Tokenizer.

Training the model for 10,000 steps should take about 30 minutes on a GPU.

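model = None
config = None

# Try to restore a previous run from Google Drive; files that are missing are skipped.
restored = set()
for name in ["pytorch_model.bin", "config.json", "aitextgen-vocab.json", "aitextgen-merges.txt"]:
    try:
        copy_file_from_gdrive(name, project_name)
        restored.add(name)
    except FileNotFoundError:
        pass

# Resume from the checkpoint only if both the weights and the config were restored.
if {"pytorch_model.bin", "config.json"} <= restored:
    model = "pytorch_model.bin"
    config = "config.json"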

config = config or build_gpt2_config(
    vocab_size=5000, max_length=200, dropout=0.0, n_embd=256, n_layer=8, n_head=8
)

ai = aitextgen(
    vocab_file="aitextgen-vocab.json",
    merges_file="aitextgen-merges.txt",
    config=config,
    model=model,
    to_gpu=True
)
INFO:aitextgen:Constructing GPT-2 model from provided config.
INFO:aitextgen:Using a custom tokenizer.
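
To get a feel for how small this model is, the config values above can be turned into a rough parameter count. This is only a back-of-the-envelope sketch: `gpt2_param_count` is a hypothetical helper, not part of aitextgen, and it assumes a standard GPT-2 layout with tied input/output embeddings and `max_length` used as the number of positions.

```python
def gpt2_param_count(vocab_size, max_length, n_embd, n_layer):
    """Approximate trainable parameters of a GPT-2 style decoder.

    Assumes tied input/output embeddings and a 4x wide MLP, as in standard GPT-2.
    """
    token_emb = vocab_size * n_embd   # token embedding table
    pos_emb = max_length * n_embd     # learned position embeddings
    # Per block: attention (4*d^2 + 4*d), MLP (8*d^2 + 5*d), two LayerNorms (4*d).
    per_block = 12 * n_embd ** 2 + 13 * n_embd
    final_ln = 2 * n_embd             # final LayerNorm scale and bias
    return token_emb + pos_emb + n_layer * per_block + final_ln


total = gpt2_param_count(vocab_size=5000, max_length=200, n_embd=256, n_layer=8)
print(f"{total:,} parameters")
```

Under these assumptions the model has roughly 7.6 million parameters, compared with about 124 million for the smallest released GPT-2. `n_head` does not appear in the formula because it only changes how the attention width is split, not the number of weights.
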
ai.train(
    file_name,
    line_by_line=False,
    num_steps=10000,
    generate_every=1000,
    save_every=500,
    learning_rate=1e-4,
    batch_size=128,
    save_gdrive=True,
    run_id=project_name
)
INFO:aitextgen.TokenDataset:Encoding 40,000 sets of tokens from input.txt.
GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: False, using: 0 TPU cores
INFO:lightning:TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]
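
The training arguments above imply a fixed budget of checkpoints and samples, which can be worked out with plain arithmetic. The sketch below only restates the values passed to `ai.train`; `block_size = 200` is an assumption that each training example spans `max_length` tokens.

```python
num_steps = 10_000
batch_size = 128
save_every = 500
generate_every = 1_000
block_size = 200  # assumption: one training example spans max_length tokens

checkpoints = num_steps // save_every         # how often the model is saved
sample_rounds = num_steps // generate_every   # how often sample text is generated
examples_seen = num_steps * batch_size        # training examples processed in total
tokens_seen = examples_seen * block_size      # token positions processed in total

print(f"{checkpoints} checkpoints, {sample_rounds} rounds of sample text")
print(f"{examples_seen:,} examples, about {tokens_seen:,} token positions")
```

Since the input file is only about 1 MB, the model passes over the same text many times during these 10,000 steps.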

Generating examples

ai.generate(
    n=5,
    batch_size=5,
    prompt="Speak:",
    temperature=1.0,
    top_p=0.9,
)
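
The `temperature` and `top_p` arguments control how each next token is drawn. The following is a minimal pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering. It is meant to illustrate the idea only and is not the code aitextgen or transformers actually runs; `sample_top_p` and `toy_logits` are hypothetical names.

```python
import math
import random


def sample_top_p(logits, temperature=1.0, top_p=0.9, rng=random):
    """Sample one index from `logits` with temperature scaling and top-p filtering."""
    # Temperature rescales the logits; values below 1.0 sharpen the distribution.
    scaled = [x / temperature for x in logits]
    # Softmax, with the maximum subtracted for numerical stability.
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Keep the most probable tokens until their cumulative probability reaches top_p.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Draw from the kept tokens in proportion to their probabilities.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # hypothetical scores for a 5-token vocabulary
picks = [sample_top_p(toy_logits, temperature=1.0, top_p=0.9) for _ in range(1000)]
print(sorted(set(picks)))
```

With these toy scores, the three most likely tokens already cover more than 90% of the probability mass, so the two least likely tokens are never sampled. Lowering `top_p` or `temperature` makes the output more conservative; raising them makes it more varied.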