{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"execution": {},
"id": "view-in-github"
},
"source": [
" "
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"# Machine Translation\n",
"\n",
"**By Neuromatch Academy**\n",
"\n",
"__Content creators:__ Juan Manuel Rodriguez, Salomey Osei\n",
"\n",
"__Production editors:__ Amita Kapoor, Spiros Chavlis"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Objective\n",
"\n",
"The main goal of this project is to train a sequence to sequence NN that transtlate a language into another language, e.g. french to english. This notebook is based on this [Pytorch tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html), but change several thing. "
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Imports\n",
"import io\n",
"import re\n",
"import math\n",
"import random\n",
"import unicodedata\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch import optim\n",
"\n",
"from tqdm.notebook import tqdm\n",
"from sklearn.utils import shuffle\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Download the data\n",
"import requests, zipfile\n",
"\n",
"zip_file_url = 'https://download.pytorch.org/tutorial/data.zip'\n",
"r = requests.get(zip_file_url)\n",
"z = zipfile.ZipFile(io.BytesIO(r.content))\n",
"z.extractall()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Go.\tVa !\n",
"Run!\tCours !\n",
"Run!\tCourez !\n",
"Wow!\tÇa alors !\n",
"Fire!\tAu feu !\n",
"Help!\tÀ l'aide !\n",
"Jump.\tSaute.\n",
"Stop!\tÇa suffit !\n",
"Stop!\tStop !\n",
"Stop!\tArrête-toi !\n"
]
}
],
"source": [
"N = 10 # print the 10 first lines\n",
"with open('data/eng-fra.txt') as f:\n",
" for i in range(N):\n",
" line = next(f).strip()\n",
" print(line)"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Representing the data\n",
"\n",
"We create a language representation defining indixes for each word. In addition to the words, our languages have three special tokens:\n",
"\n",
"* SOS: Start Of Sentence\n",
"* EOS: End Of Sentence\n",
"* PAD: Padding token used to fill inputs vectors where there are no other words."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"class Lang:\n",
" def __init__(self, name):\n",
" self.name = name\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\", 2: \"PAD\"}\n",
" self.n_words = 3 # Count SOS and EOS and PAD\n",
"\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1\n",
"\n",
"\n",
"def unicodeToAscii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )\n",
"\n",
"\n",
"def normalizeString(s):\n",
" s = unicodeToAscii(s.lower().strip())\n",
" s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
" s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\n",
" return s\n",
"\n",
"\n",
"def readLangs(lang1, lang2, reverse=False):\n",
" print(\"Reading lines...\")\n",
"\n",
" # Read the file and split into lines\n",
" lines = io.open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\\\n",
" read().strip().split('\\n')\n",
"\n",
" # Split every line into pairs and normalize\n",
" pairs = [[normalizeString(s) for s in l.split('\\t')] for l in lines]\n",
"\n",
" # Reverse pairs, make Lang instances\n",
" if reverse:\n",
" pairs = [list(reversed(p)) for p in pairs]\n",
" input_lang = Lang(lang2)\n",
" output_lang = Lang(lang1)\n",
" else:\n",
" input_lang = Lang(lang1)\n",
" output_lang = Lang(lang2)\n",
"\n",
" return input_lang, output_lang, pairs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"MAX_LENGTH = 10\n",
"\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"\n",
"def filterPair(p):\n",
" return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
" len(p[1].split(' ')) < MAX_LENGTH and \\\n",
" p[1].startswith(eng_prefixes)\n",
"\n",
"\n",
"def filterPairs(pairs):\n",
" return [pair for pair in pairs if filterPair(pair)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading lines...\n",
"Read 135842 sentence pairs\n",
"Trimmed to 10599 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"fra 4346\n",
"eng 2804\n",
"['nous sommes sans emploi .', 'we re unemployed .']\n"
]
}
],
"source": [
"def prepareData(lang1, lang2, reverse=False):\n",
" input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n",
" pairs = filterPairs(pairs)\n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
" print(\"Counting words...\")\n",
" for pair in pairs:\n",
" input_lang.addSentence(pair[0])\n",
" output_lang.addSentence(pair[1])\n",
" print(\"Counted words:\")\n",
" print(input_lang.name, input_lang.n_words)\n",
" print(output_lang.name, output_lang.n_words)\n",
" return input_lang, output_lang, pairs\n",
"\n",
"\n",
"input_lang, output_lang, pairs = prepareData('eng', 'fra', True)\n",
"print(random.choice(pairs))"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"## Language word distributions\n",
"\n",
"We can check which is the word distribution in our dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"def plot_lang(lang, top_k=100):\n",
" words = list(lang.word2count.keys())\n",
" words.sort(key=lambda w: lang.word2count[w], reverse=True)\n",
" print(words[:top_k])\n",
" count_occurences = sum(lang.word2count.values())\n",
"\n",
" accumulated = 0\n",
" counter = 0\n",
"\n",
" while accumulated < count_occurences * 0.8:\n",
" accumulated += lang.word2count[words[counter]]\n",
" counter += 1\n",
"\n",
" print(f\"The {counter * 100 / len(words)}% most common words \"\n",
" f\"account for the {accumulated * 100 / count_occurences}% of the occurrences\")\n",
" plt.bar(range(100), [lang.word2count[w] for w in words[:top_k]])\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['.', 'je', 'suis', 'est', 'vous', 'pas', 'de', 'il', 'nous', 'tu', 'etes', 'ne', 'es', 'en', 'a', 'n', 'un', 'sommes', 'elle', 'la', 'tres', 'c', 'que', 'le', 'sont', 'j', 'une', 'd', 'ai', 'pour', 'l', 'ils', 'plus', 'ce', 'des', 'me', 'vais', 'elles', 'moi', '!', 'mon', 'trop', 'train', 'fort', 'si', 'ici', 'du', 'toujours', 'toi', 'tout', 'tous', 'les', '?', 'vraiment', 'sur', 't', 'te', 'm', 'dans', 'avec', 'avoir', 'encore', 'qu', 'tom', 'votre', 'au', 'peur', 'y', 'desole', 'bien', 'ca', 'bon', 'fais', 'toutes', 'heureux', 'faire', 'etre', 'son', 'aussi', 'assez', 'lui', 'tellement', 'ma', 'mes', 'fatigue', 'par', 'et', 'fait', 'ton', 'se', 'juste', 'maintenant', 'grand', 'desolee', 'avons', 'allons', 'peu', 'deux', 'on', 'vieux']\n",
"The 4.674188349067465% most common words account for the 80.0371543427945% of the occurrences\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQaUlEQVR4nO3df4ydVZ3H8fdnqaBgpEUmDbbNTjc2GjRxYRuoYWMMdaHgxvIHGoyRxnS3fyyuaEzcsvtHsyoJJkaEZCVpaLUYA7KVLI11Jd2CMfsHlUEMApXtCGLbFDragkbjj+p3/7inu5c6A525M3Pbe9+v5OY+z3nO89zz5DT9zDnPuTOpKiRJw+3P+t0ASVL/GQaSJMNAkmQYSJIwDCRJwIJ+N2Cmzj///BodHe13MyTptPHoo4/+rKpGJjt22obB6OgoY2Nj/W6GJJ02kjw31TGniSRJhoEkyTCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CSxJCGwejGnYxu3NnvZkjSKWMow0CS9HKvGgZJtiY5nOSJrrLzkuxKsq+9L2rlSXJ7kvEkjye5uOucda3+viTrusr/KskP2zm3J8ls36Qk6ZWdzMjgK8CaE8o2AruragWwu+0DXAWsaK8NwB3QCQ9gE3ApcAmw6XiAtDp/33XeiZ8lSZpjrxoGVfVd4MgJxWuBbW17G3BNV/ld1fEwsDDJBcCVwK6qOlJVR4FdwJp27A1V9XBVFXBX17UkSfNkps8MFlfVobb9PLC4bS8B9nfVO9DKXqn8wCTlk0qyIclYkrGJiYkZNl2SdKKeHyC3n+hrFtpyMp+1uapWVtXKkZFJ/z6DJGkGZhoGL7QpHtr74VZ+EFjWVW9pK3ul8qWTlEuS5tFMw2AHcHxF0Drg/q7y69uqolXAS2066QHgiiSL2oPjK4AH2rFfJFnVVhFd33UtSdI8edU/e5nkbuDdwPlJDtBZFXQLcG+S9cBzwAda9W8BVwPjwK+BjwBU1ZEknwEeafU+XVXHH0r/A50VS68D/rO9JEnz6FXDoKo+OMWh1ZPULeCGKa6zFdg6SfkY8PZXa4ckae74DWRJkmEgSTIMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJJEj2GQ5BNJnkzyRJK7k7w2yfIke5KMJ/l6kjNb3bPa/ng7Ptp1nZta+dNJruztliRJ0zXjMEiyBPgYsLKq3g6cAVwHfA64tareDBwF1rdT1gNHW/mtrR5JLmznvQ1YA3wpyRkzbZckafp6nSZaALwuyQLgbOAQcDmwvR3fBlzTtte2fdrx1UnSyu+pqt9W1bPAOHBJj+2SJE3DjMOgqg4Cnwd+SicEXgIeBV6sqmOt2gFgSdteAuxv5x5r9d/YXT7JOS+TZEOSsSRjExMTM226JOkEvUwTLaLzU/1y4E3AOXSmeeZMVW2uqpVVtXJkZGQuP0qShkov00TvAZ6tqomq+j1wH3AZsLBNGwEsBQ627YPAMoB2/Fzg593lk5wjSZoHvYTBT4FVSc5uc/+rgaeAh4BrW511wP1te0fbpx1/sKqqlV/XVhstB1YA3+uhXZKkaVrw6lUmV1V7kmwHvg8cAx4DNgM7gXuSfLaVbWmnbAG+mmQcOEJnBRFV9WSSe+kEyTHghqr6w0zbJUmavhmHAUBVbQI2nVD8DJOsBqqq3wDvn+I6NwM399IWSdLM+Q1kSZJhIEkyDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgSaLHMEiyMMn2JD9KsjfJO5Ocl2RXkn3tfVGrmyS3JxlP8niSi7uus67V35dkXa83JUmanl5HBrcB366qtwLvAPYCG4HdVbUC2N32Aa4CVrTXBuAOgCTnAZuAS4FLgE3HA0SSND9mHAZJzgXeBWwBqKrfVdWLwFpgW6u2Dbimba8F7qqOh4GFSS4ArgR2VdWRqjoK7ALWzLRdkqTp62VksByYAL6c5LEkdyY5B1hcVYdaneeBxW17CbC/6/wDrWyq8j+RZEOSsSRjExMTPTRdktStlzBYAFwM3FFVFwG/4v+nhACoqgKqh894maraXFUrq2rlyMjIbF1WkoZeL2FwADhQVXva/nY64fBCm/6hvR9uxw8Cy7rOX9rKpiqXJM2TGYdBVT0P7E/ylla0GngK2AEcXxG0Dri/be8Arm+rilYBL7XppAeAK5Isag+Or2hlkqR5sqDH8/8R+FqSM4FngI/QCZh7k6wHngM+0Op+C7gaGAd+3epSVUeSfAZ4pNX7dFUd6bFdkqRp6CkMquoHwMpJDq2epG4BN0xxna3A1l7aIkmaOb+BLEkyDCRJhoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkpiFMEhyRpLHknyz7S9PsifJeJKvJzmzlZ/V9sfb8dGua9zUyp9OcmWvbZIkTc9sjAxuBPZ27X8OuLWq3gwcBda38vXA0VZ+a6tHkguB64C3AWuALyU5YxbaJUk6ST2FQZKlwHuBO9t+gMuB7a3KNuCatr227dOOr2711wL3VNVvq+pZYBy4pJd2SZKmp9eRwReBTwF/bPtvBF6sqmNt/wCwpG0vAfYDtOMvtfr/Vz7JOS+TZEOSsSRjExMTPTZdknTcjMMgyd8Ch6vq0Vlszyuqqs1VtbKqVo6MjMzXx0rSwFvQw7mXAe9LcjXwWuANwG3AwiQL2k//S4GDrf5BYBlwIMkC4Fzg513lx3WfI0maBzMeGVTVTVW1tKpG6TwAfrCqPgQ8BFzbqq0D7m/bO9o+7fiDVVWt/Lq22mg5sAL43kzbJUmavl5GBlP5J+CeJJ8FHgO2tPItwFeTjANH6AQIVfVkknuBp4BjwA1V9Yc5aJckaQqzEgZV9R3gO237GSZZDVRVvwHeP8X5NwM3z0ZbJEnT5zeQJUmGwejGnYxu3NnvZkhSXw19GEiSDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMHiZ0Y07Gd24s9/NkKR5ZxhIkgwDSZJhIEnCMJAkYRhIkjAMJEkYBpIkegiDJMuSPJTkqSRPJrmxlZ+XZFeSfe19UStPktuTjCd5PMnFXdda1+rvS7Ku99uSJE1HLyODY8Anq+pCYBVwQ5ILgY3A7qpaAexu+wBXASvaawNwB3TCA9gEXApcAmw6HiCSpPkx4zCoqkNV9f22/UtgL7AEWAtsa9W2Ade07bXAXdXxMLAwyQXAlcCuqjpSVUeBXcCambZLkjR9s/LMIMkocBGwB1hcVYfaoeeBxW17CbC/67QDrWyq8sk+Z0OSsSRjExMTs9F0SRKzEAZJXg98A/h4Vf2i+1hVFVC9fkbX9TZX1cqqWjkyMjJbl5WkoddTGCR5DZ0g+FpV3deKX2jTP7T3w638ILCs6/SlrWyqcknSPOllNVGALcDeqvpC16EdwPEVQeuA+7vKr2+rilYBL7XppAeAK5Isag+Or2hlkqR5sqCHcy8DPgz8MMkPWtk/A7cA9yZZDzwHfKAd+xZwNTAO/Br4CEBVHUnyGeCRVu/TVXWkh3bNiuO/yvont7z3ZduSNIhmHAZV9d9Apji8epL6BdwwxbW2Altn2hZJUm/8BrIkyTCQJBkG0+afxpQ0iAwDSZJhIEnqbWnp0OueLnLZqaTTmSMDSZJhIElymmjWOGUk6XTmyECSZBhIkpwmmhMn86U0p5IknUocGUiSDANJktNEfePqI0mnEkcGkiRHBqcCRwmS+s2RgSTJkcGpZqplqY4YJM0lRwanEf+wjqS5YhicprqDwZCQ1CvDQJJkGEiSDIOB45SRpJkwDAaYzxUknSzDYAgZDJJOZBgMualGD44qpOHil840Lf7qDGkwGQaasV5GCwaJdGoxDNQXJ/vX4I7XMzykueUzA502fI4hzR1HBjrtdY8e5nL7+L40iE6ZMEiyBrgNOAO4s6pu6XOTpD8x16ORuQ40w0xTOSXCIMkZwL8BfwMcAB5JsqOqnupvy6TBM93R0FybjwActHueC6fKM4NLgPGqeqaqfgfcA6ztc5skaWikqvrdBpJcC6ypqr9r+x8GLq2qj55QbwOwoe2+BXi6h489H/hZD+efjrzn4eA9D4eZ3POfV9XIZAdOiWmik1VVm4HNs3GtJGNVtXI2rnW68J6Hg/c8HGb7nk+VaaKDwLKu/aWtTJI0D06VMHgEWJFkeZIzgeuAHX1ukyQNjVNimqiqjiX5KPAAnaWlW6vqyTn+2FmZbjrNeM/DwXseDrN6z6fEA2RJUn+dKtNEkqQ+MgwkScMZBknWJHk6yXiSjf1uz2xLsizJQ0meSvJkkhtb+XlJdiXZ194X9butsy3JGUkeS/LNtr88yZ7W119vCxQGRpKFSbYn+VGSvUneOej9nOQT7d/1E0nuTvLaQeznJFuTHE7yRFfZpH2bjtvb/T+e5OLpft7QhUHXr764CrgQ+GCSC/vbqll3DPhkVV0IrAJuaPe4EdhdVSuA3W1/0NwI7O3a/xxwa1W9GTgKrO9Lq+bObcC3q+qtwDvo3PvA9nOSJcDHgJVV9XY6C06uYzD7+SvAmhPKpurbq4AV7bUBuGO6HzZ0YcAQ/OqLqjpUVd9v27+k8x/EEjr3ua1V2wZc058Wzo0kS4H3Ane2/QCXA9tblYG65yTnAu8CtgBU1e+q6kUGvJ/prIJ8XZIFwNnAIQawn6vqu8CRE4qn6tu1wF3V8TCwMMkF0/m8YQyDJcD+rv0DrWwgJRkFLgL2AIur6lA79DywuE/NmitfBD4F/LHtvxF4saqOtf1B6+vlwATw5TY1dmeScxjgfq6qg8DngZ/SCYGXgEcZ7H7uNlXf9vz/2jCGwdBI8nrgG8DHq+oX3ceqs6Z4YNYVJ/lb4HBVPdrvtsyjBcDFwB1VdRHwK06YEhrAfl5E56fg5cCbgHP406mUoTDbfTuMYTAUv/oiyWvoBMHXquq+VvzC8aFjez/cr/bNgcuA9yX5CZ2pv8vpzKcvbNMJMHh9fQA4UFV72v52OuEwyP38HuDZqpqoqt8D99Hp+0Hu525T9W3P/68NYxgM/K++aHPlW4C9VfWFrkM7gHVtex1w/3y3ba5U1U1VtbSqRun06YNV9SHgIeDaVm3Q7vl5YH+St7Si1cBTDHA/05keWpXk7Pbv/Pg9D2w/n2Cqvt0BXN9WFa0CXuqaTjo5VTV0L+Bq4H+AHwP/0u/2zMH9/TWd4ePjwA/a62o6c+i7gX3AfwHn9butc3T/7wa+2bb/AvgeMA78O3BWv9s3y/f6l8BY6+v/ABYNej8D/wr8CHgC+Cpw1iD2M3A3neciv6czClw/Vd8CobNK8sfAD+mstprW5/nrKCRJQzlNJEk6gWEgSTIMJEmGgSQJw0CShGEgScIwkCQB/wslKzo0kE6LnAAAAABJRU5ErkJggg==\n",
"text/plain": [
"