Elman Net: best settings for structure-related hyperparameters#
Abstract#
The goal of this experiment is to improve Elman Net performance over the baseline experiment.
We increased d_emb, d_hid and n_lyr and recorded what happened.
We found the following:

- Increasing d_emb from 100 to 200 makes training loss and perplexity lower.
- For both d_emb = 100 and d_emb = 200, increasing n_lyr from 2 to 3 (or 4) makes training loss and perplexity lower.
- Overfitting was observed.
- \(100\%\) accuracy on the training set is achievable.
- Performance on the validation set is very poor. This might be the limit of Elman Net.
Environment setup#
We ran experiments on an Nvidia RTX 2070S.
The CUDA version is 11.4 and the CUDA driver version is 470.129.06.
Experiment setup#
We changed the values of d_emb, d_hid and n_lyr and recorded training loss and perplexity.
The hyperparameters and their values are listed in the table below; this gives \(3 \times 3 \times 3 = 27\) configurations in total.
These value ranges should be compared with the baseline experiment.

| Name | Values |
|---|---|
| d_emb | \(\set{100, 150, 200}\) |
| d_hid | \(\set{100, 150, 200}\) |
| n_lyr | \(\set{2, 3, 4}\) |
Tokenizer settings#
We used lmp.script.train_tknzr to train a whitespace tokenizer WsTknzr.
Compared to the baseline settings, using a whitespace tokenizer makes the vocabulary size larger.
The script was executed as follows:
python -m lmp.script.train_tknzr whitespace \
--dset_name demo \
--exp_name demo_tknzr \
--is_uncased \
--max_vocab -1 \
--min_count 0 \
--ver train
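To check the effect on vocabulary size, the trained tokenizer can be loaded back with lmp.util.tknzr.load (the same helper the accuracy script later on this page uses). The snippet below is a minimal sketch; the vocab_size attribute is our assumption about the tokenizer interface.

import lmp.util.tknzr

# Load the whitespace tokenizer trained above.
tknzr = lmp.util.tknzr.load(exp_name='demo_tknzr')
# vocab_size is an assumed attribute; compare this number against the baseline tokenizer.
print(tknzr.vocab_size)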
Model training settings#
We trained the Elman Net language model ElmanNet with different model-structure hyperparameters.
We used lmp.script.train_model to train language models.
The script was executed as follows:
python -m lmp.script.train_model Elman-Net \
--batch_size 32 \
--beta1 0.9 \
  --beta2 0.999 \
--ckpt_step 500 \
--d_emb D_EMB \
--d_hid D_HID \
--dset_name demo \
--eps 1e-8 \
--exp_name EXP_NAME \
--init_lower -0.1 \
--init_upper 0.1 \
--label_smoothing 0.0 \
--log_step 100 \
--lr 1e-3 \
--max_norm 1 \
--max_seq_len 35 \
--n_lyr N_LYR \
--p_emb 0.0 \
--p_hid 0.0 \
--seed 42 \
--stride 35 \
--tknzr_exp_name demo_tknzr \
--total_step 40000 \
--ver train \
--warmup_step 10000 \
--weight_decay 0.0
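In the command above, D_EMB, D_HID, N_LYR and EXP_NAME are placeholders that change from run to run. The driver loop below is a sketch of how the 27 runs could be launched (our own illustration, not part of lmp); it reuses the experiment-name convention that the accuracy script later on this page expects.

import itertools
import subprocess

# Flags shared by every run; they mirror the command shown above.
FIXED_FLAGS = {
  'batch_size': 32, 'beta1': 0.9, 'beta2': 0.999, 'ckpt_step': 500,
  'dset_name': 'demo', 'eps': 1e-8, 'init_lower': -0.1, 'init_upper': 0.1,
  'label_smoothing': 0.0, 'log_step': 100, 'lr': 1e-3, 'max_norm': 1,
  'max_seq_len': 35, 'p_emb': 0.0, 'p_hid': 0.0, 'seed': 42, 'stride': 35,
  'tknzr_exp_name': 'demo_tknzr', 'total_step': 40000, 'ver': 'train',
  'warmup_step': 10000, 'weight_decay': 0.0,
}

for d_emb, d_hid, n_lyr in itertools.product([100, 150, 200], [100, 150, 200], [2, 3, 4]):
  # The experiment name follows the convention assumed by the accuracy script below.
  flags = dict(FIXED_FLAGS, d_emb=d_emb, d_hid=d_hid, n_lyr=n_lyr,
               exp_name=f'demo-d_emb-{d_emb}-d_hid-{d_hid}-n_lyr-{n_lyr}')
  cmd = ['python', '-m', 'lmp.script.train_model', 'Elman-Net']
  for name, value in flags.items():
    cmd += [f'--{name}', str(value)]
  subprocess.run(cmd, check=True)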
Model evaluation settings#
We evaluated language models using lmp.script.eval_dset_ppl. The script was executed as follows:
python -m lmp.script.eval_dset_ppl demo \
--batch_size 512 \
--exp_name EXP_NAME \
--first_ckpt 0 \
--last_ckpt -1 \
--seed 42 \
--ver VER
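For reference, the perplexity reported by lmp.script.eval_dset_ppl is, by the standard definition (stated here for convenience rather than taken from the lmp documentation), the exponential of the average per-token negative log-likelihood
\(\mathrm{PPL} = \exp\left(-\dfrac{1}{N}\sum_{i=1}^{N} \log P(x_i \mid x_{<i})\right)\),
so a lower value means the model assigns higher probability to the evaluated text.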
Experiment results#
All results were logged to TensorBoard. You can launch TensorBoard with the following command:
pipenv run tensorboard
Training loss#
| d_emb | d_hid | n_lyr | 5k steps | 10k steps | 15k steps | 20k steps | 25k steps | 30k steps | 35k steps | 40k steps |
|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 100 | 2 | 1.043 | 0.9594 | 0.9187 | 0.8927 | 0.8647 | 0.8515 | 0.8371 | 0.8321 |
| 100 | 100 | 3 | 1.027 | 0.9519 | 0.9051 | 0.8775 | 0.855 | 0.8369 | 0.8175 | 0.8122 |
| 100 | 100 | 4 | 1.04 | 0.9851 | 0.9294 | 0.8947 | 0.8628 | 0.8543 | 0.8294 | 0.8223 |
| 100 | 150 | 2 | 1.036 | 0.96 | 0.9166 | 0.8774 | 0.8613 | 0.8378 | 0.8246 | 0.8189 |
| 100 | 150 | 3 | 1.017 | 0.9633 | 0.9202 | 0.9002 | 0.8678 | 0.8449 | 0.8257 | 0.8192 |
| 100 | 150 | 4 | 1.009 | 0.9833 | 0.9239 | 0.9004 | 0.8686 | 0.8287 | 0.816 | 0.81 |
| 100 | 200 | 2 | 1.026 | 0.9754 | 0.9341 | 0.8995 | 0.8743 | 0.8446 | 0.8331 | 0.8258 |
| 100 | 200 | 3 | 1.013 | 0.9676 | 0.9332 | 0.8963 | 0.8673 | 0.8452 | 0.8219 | 0.8163 |
| 100 | 200 | 4 | 1.019 | 0.9735 | 0.9311 | 0.8999 | 0.8698 | 0.843 | 0.8156 | 0.8088 |
| 150 | 100 | 2 | 1.032 | 0.947 | 0.9044 | 0.8719 | 0.8492 | 0.8284 | 0.8197 | 0.8127 |
| 150 | 100 | 3 | 1.027 | 0.9455 | 0.9033 | 0.876 | 0.8455 | 0.8224 | 0.815 | 0.8076 |
| 150 | 100 | 4 | 1.024 | 0.9553 | 0.9059 | 0.8767 | 0.8479 | 0.8153 | 0.8065 | 0.8009 |
| 150 | 150 | 2 | 1.008 | 0.9533 | 0.9095 | 0.8718 | 0.8398 | 0.8122 | 0.8026 | 0.797 |
| 150 | 150 | 3 | 1.006 | 0.9699 | 0.9125 | 0.8878 | 0.8527 | 0.82 | 0.8107 | 0.8046 |
| 150 | 150 | 4 | 1.01 | 0.9586 | 0.9154 | 0.8907 | 0.8576 | 0.8227 | 0.8057 | 0.7997 |
| 150 | 200 | 2 | 1.007 | 0.9572 | 0.9104 | 0.8758 | 0.8471 | 0.8183 | 0.8059 | 0.7998 |
| 150 | 200 | 3 | 1.012 | 0.965 | 0.9186 | 0.8866 | 0.8576 | 0.8296 | 0.8089 | 0.8023 |
| 150 | 200 | 4 | 1.01 | 0.975 | 0.9313 | 0.8979 | 0.8621 | 0.8305 | 0.808 | 0.801 |
| 200 | 100 | 2 | 1.014 | 0.9473 | 0.9065 | 0.8677 | 0.8453 | 0.8197 | 0.8095 | 0.8027 |
| 200 | 100 | 3 | 1.008 | 0.9393 | 0.8942 | 0.8656 | 0.8279 | 0.806 | 0.797 | 0.791 |
| 200 | 100 | 4 | 1.016 | 0.9672 | 0.9139 | 0.8786 | 0.85 | 0.8422 | 0.8063 | 0.7986 |
| 200 | 150 | 2 | 1.004 | 0.9612 | 0.9108 | 0.8885 | 0.844 | 0.8245 | 0.8047 | 0.799 |
| 200 | 150 | 3 | 0.9939 | 0.9445 | 0.8991 | 0.8701 | 0.8436 | 0.833 | 0.7979 | 0.7921 |
| 200 | 150 | 4 | 0.9971 | 0.9465 | 0.9113 | 0.88 | 0.8414 | 0.8129 | 0.7983 | 0.7892 |
| 200 | 200 | 2 | 0.9984 | 0.9661 | 0.9085 | 0.878 | 0.851 | 0.814 | 0.8032 | 0.7958 |
| 200 | 200 | 3 | 1.003 | 0.9727 | 0.9111 | 0.8805 | 0.8546 | 0.8162 | 0.8022 | 0.7956 |
| 200 | 200 | 4 | 0.9909 | 0.9617 | 0.9188 | 0.8797 | 0.8519 | 0.818 | 0.7969 | 0.7904 |
Observation 1: Increasing d_emb from 100 to 150 in general makes training loss smaller.#
By fixing d_hid and n_lyr, we can compare training loss for d_emb = 100 and d_emb = 150.
Most comparisons (\(\dfrac{67}{72}\)) show that training loss is smaller when increasing d_emb from 100 to 150.
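The comparison counts quoted in these observations can be reproduced mechanically. The sketch below is our own illustration (not part of lmp): it assumes the training-loss table above has been loaded into a pandas DataFrame df with columns d_emb, d_hid, n_lyr, step and loss, and counts how often the loss is smaller after raising one hyperparameter with everything else fixed.

import pandas as pd

def count_smaller(df: pd.DataFrame, col: str, low: int, high: int, metric: str = 'loss'):
  """Count (configuration, step) pairs where `metric` is smaller at `col == high` than at `col == low`."""
  keys = [c for c in ('d_emb', 'd_hid', 'n_lyr', 'step') if c != col]
  lo = df[df[col] == low].set_index(keys)[metric]
  hi = df[df[col] == high].set_index(keys)[metric]
  both = lo.to_frame('lo').join(hi.to_frame('hi'), how='inner')
  return int((both['hi'] < both['lo']).sum()), len(both)

# For the training-loss table above, count_smaller(df, 'd_emb', 100, 150) should return (67, 72).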
Observation 2: Increasing d_emb from 150 to 200 in general makes training loss smaller.#
By fixing d_hid and n_lyr, we can compare training loss for d_emb = 150 and d_emb = 200.
Most comparisons (\(\dfrac{52}{72}\)) show that training loss is smaller when increasing d_emb from 150 to 200.
Observation 3: Increasing d_hid from 100 to 150 in general makes training loss smaller.#
By fixing d_emb and n_lyr, we can compare training loss for d_hid = 100 and d_hid = 150.
A little more than half of the comparisons (\(\dfrac{39}{72}\)) show that training loss is smaller when increasing d_hid from 100 to 150.
Observation 4: Increasing d_hid from 150 to 200 in general makes training loss larger.#
By fixing d_emb and n_lyr, we can compare training loss for d_hid = 150 and d_hid = 200.
A little more than half of the comparisons (\(\dfrac{43}{72}\)) show that training loss is larger when increasing d_hid from 150 to 200.
Observation 5: When d_emb = 100, increasing n_lyr from 2 to 3 in general makes training loss smaller.#
By fixing d_emb = 100 and d_hid, we can compare training loss for n_lyr = 2 and n_lyr = 3.
Most comparisons (\(\dfrac{17}{24}\)) show that training loss is smaller when increasing n_lyr from 2 to 3.
Observation 6: When d_emb = 100, increasing n_lyr from 2 to 4 in general makes training loss smaller.#
By fixing d_emb = 100 and d_hid, we can compare training loss for n_lyr = 2 and n_lyr = 4.
A little more than half of the comparisons (\(\dfrac{15}{24}\)) show that training loss is smaller when increasing n_lyr from 2 to 4.
Observation 7: When d_emb = 150, increasing n_lyr from 2 to 3 in general makes training loss larger.#
By fixing d_emb = 150 and d_hid, we can compare training loss for n_lyr = 2 and n_lyr = 3.
A little more than half of the comparisons (\(\dfrac{16}{24}\)) show that training loss is larger when increasing n_lyr from 2 to 3.
Observation 8: When d_emb = 150, increasing n_lyr from 2 to 4 in general makes training loss larger.#
By fixing d_emb = 150 and d_hid, we can compare training loss for n_lyr = 2 and n_lyr = 4.
Most comparisons (\(\dfrac{19}{24}\)) show that training loss is larger when increasing n_lyr from 2 to 4.
Observation 9: When d_emb = 200, increasing n_lyr from 2 to 3 in general makes training loss smaller.#
By fixing d_emb = 200 and d_hid, we can compare training loss for n_lyr = 2 and n_lyr = 3.
Most comparisons (\(\dfrac{17}{24}\)) show that training loss is smaller when increasing n_lyr from 2 to 3.
Observation 10: When d_emb = 200, increasing n_lyr from 2 to 4 in general makes training loss smaller.#
By fixing d_emb = 200 and d_hid, we can compare training loss for n_lyr = 2 and n_lyr = 4.
A little more than half of the comparisons (\(\dfrac{14}{24}\)) show that training loss is smaller when increasing n_lyr from 2 to 4.
Observation 11: Minimum loss is achieved when d_emb = 200, d_hid = 150 and n_lyr = 4.#
Observation 12: Training loss is still decreasing in all configurations.#
All comparisons (\(\dfrac{189}{189}\)) show that training loss is still decreasing regardless of which configuration is used. This suggests that further training may be required.
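Observation 12 checks, for each configuration, the seven comparisons between consecutive checkpoints. A tiny helper along these lines (our own illustration) expresses that check, where losses is the list of eight recorded values for one configuration.

def is_still_decreasing(losses):
  # True when every checkpoint improves on the previous one (7 comparisons for 8 checkpoints).
  return all(later < earlier for earlier, later in zip(losses, losses[1:]))

# Example: the d_emb = 200, d_hid = 150, n_lyr = 4 row of the training-loss table above.
print(is_still_decreasing([0.9971, 0.9465, 0.9113, 0.88, 0.8414, 0.8129, 0.7983, 0.7892]))  # True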
Perplexity#
| d_emb | d_hid | n_lyr | 5k train | 5k valid | 5k test | 10k train | 10k valid | 10k test | 15k train | 15k valid | 15k test | 20k train | 20k valid | 20k test | 25k train | 25k valid | 25k test | 30k train | 30k valid | 30k test | 35k train | 35k valid | 35k test | 40k train | 40k valid | 40k test |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 100 | 2 | 2.588 | 4.489 | 2.986 | 2.396 | 6.753 | 2.755 | 2.315 | 12.3 | 2.673 | 2.27 | 21.63 | 2.652 | 2.203 | 26.53 | 2.573 | 2.178 | 29.93 | 2.547 | 2.149 | 30.92 | 2.509 | 2.142 | 30.5 | 2.499 |
| 100 | 100 | 3 | 2.57 | 6.25 | 2.909 | 2.362 | 17.83 | 2.792 | 2.3 | 27.96 | 2.689 | 2.224 | 40.18 | 2.626 | 2.191 | 44.71 | 2.528 | 2.131 | 56.2 | 2.586 | 2.114 | 58.28 | 2.556 | 2.106 | 59.4 | 2.545 |
| 100 | 100 | 4 | 2.579 | 4.701 | 2.925 | 2.421 | 23.84 | 2.847 | 2.32 | 68.85 | 2.609 | 2.278 | 119.4 | 2.615 | 2.247 | 154.6 | 2.63 | 2.17 | 156.5 | 2.494 | 2.137 | 168.6 | 2.438 | 2.127 | 175.2 | 2.453 |
| 100 | 150 | 2 | 2.588 | 4.999 | 2.974 | 2.403 | 11.97 | 2.715 | 2.328 | 19.11 | 2.729 | 2.244 | 24.6 | 2.615 | 2.184 | 29.94 | 2.552 | 2.164 | 33.04 | 2.562 | 2.126 | 34.04 | 2.52 | 2.118 | 34.64 | 2.523 |
| 100 | 150 | 3 | 2.538 | 4.23 | 2.878 | 2.438 | 11.23 | 2.808 | 2.309 | 19.04 | 2.625 | 2.26 | 26.82 | 2.583 | 2.201 | 32.99 | 2.579 | 2.166 | 38.65 | 2.55 | 2.127 | 39.76 | 2.483 | 2.119 | 40.07 | 2.469 |
| 100 | 150 | 4 | 2.518 | 4.412 | 2.838 | 2.436 | 13.16 | 2.817 | 2.328 | 30.12 | 2.736 | 2.29 | 46.5 | 2.611 | 2.205 | 48.3 | 2.548 | 2.129 | 52.22 | 2.429 | 2.109 | 59.41 | 2.409 | 2.101 | 59.05 | 2.413 |
| 100 | 200 | 2 | 2.545 | 4.805 | 2.873 | 2.464 | 15.89 | 2.841 | 2.342 | 30.28 | 2.726 | 2.277 | 39.29 | 2.681 | 2.227 | 46.19 | 2.616 | 2.162 | 48.54 | 2.569 | 2.141 | 48.05 | 2.51 | 2.133 | 49.23 | 2.504 |
| 100 | 200 | 3 | 2.512 | 5.707 | 2.881 | 2.405 | 20.45 | 2.761 | 2.331 | 40.46 | 2.695 | 2.271 | 55.97 | 2.656 | 2.221 | 58.88 | 2.547 | 2.167 | 68.22 | 2.519 | 2.12 | 68.44 | 2.458 | 2.111 | 68.52 | 2.455 |
| 100 | 200 | 4 | 2.555 | 6.489 | 3.034 | 2.402 | 27.98 | 2.809 | 2.319 | 35.38 | 2.663 | 2.262 | 43.32 | 2.601 | 2.207 | 51.82 | 2.581 | 2.157 | 56.78 | 2.516 | 2.108 | 61.49 | 2.479 | 2.099 | 62.23 | 2.462 |
| 150 | 100 | 2 | 2.558 | 5.168 | 2.926 | 2.354 | 14.35 | 2.727 | 2.287 | 23.78 | 2.659 | 2.215 | 31.73 | 2.629 | 2.176 | 33.97 | 2.574 | 2.132 | 36.96 | 2.495 | 2.115 | 40.21 | 2.504 | 2.108 | 40.35 | 2.482 |
| 150 | 100 | 3 | 2.542 | 6.571 | 2.919 | 2.354 | 15.73 | 2.702 | 2.274 | 22.72 | 2.559 | 2.222 | 28.45 | 2.586 | 2.17 | 35.1 | 2.484 | 2.122 | 40.48 | 2.48 | 2.106 | 44.3 | 2.485 | 2.098 | 45.63 | 2.467 |
| 150 | 100 | 4 | 2.547 | 10.76 | 3.055 | 2.365 | 15.5 | 2.741 | 2.266 | 35.47 | 2.647 | 2.216 | 56.28 | 2.539 | 2.176 | 71.85 | 2.51 | 2.109 | 79.58 | 2.44 | 2.091 | 88.16 | 2.438 | 2.084 | 90.33 | 2.422 |
| 150 | 150 | 2 | 2.514 | 7.944 | 2.923 | 2.361 | 23.62 | 2.732 | 2.272 | 39.04 | 2.676 | 2.21 | 50.69 | 2.561 | 2.151 | 60.86 | 2.52 | 2.1 | 71.3 | 2.481 | 2.083 | 72.28 | 2.455 | 2.077 | 73.39 | 2.452 |
| 150 | 150 | 3 | 2.494 | 8.508 | 2.865 | 2.43 | 38.41 | 2.779 | 2.297 | 61.11 | 2.605 | 2.257 | 90.4 | 2.625 | 2.173 | 115.7 | 2.51 | 2.114 | 135.6 | 2.462 | 2.097 | 148.8 | 2.452 | 2.09 | 147.4 | 2.438 |
| 150 | 150 | 4 | 2.504 | 7.715 | 2.829 | 2.382 | 33.2 | 2.814 | 2.327 | 56.41 | 2.693 | 2.245 | 74.8 | 2.602 | 2.19 | 88.55 | 2.555 | 2.122 | 98.17 | 2.474 | 2.089 | 108.8 | 2.448 | 2.081 | 109.2 | 2.433 |
| 150 | 200 | 2 | 2.505 | 5.688 | 2.822 | 2.405 | 39.71 | 2.796 | 2.27 | 71.41 | 2.618 | 2.221 | 80.56 | 2.576 | 2.166 | 99.65 | 2.561 | 2.113 | 109.2 | 2.482 | 2.088 | 114.6 | 2.453 | 2.081 | 114 | 2.446 |
| 150 | 200 | 3 | 2.535 | 6.452 | 2.912 | 2.446 | 63.95 | 2.809 | 2.307 | 163.4 | 2.657 | 2.244 | 220.2 | 2.579 | 2.18 | 230.6 | 2.539 | 2.128 | 279 | 2.501 | 2.094 | 291.9 | 2.454 | 2.086 | 301 | 2.445 |
| 150 | 200 | 4 | 2.477 | 7.073 | 2.822 | 2.445 | 30.17 | 2.816 | 2.32 | 43.03 | 2.732 | 2.278 | 53.86 | 2.608 | 2.208 | 67.19 | 2.546 | 2.132 | 76.35 | 2.501 | 2.092 | 78.57 | 2.455 | 2.084 | 80.15 | 2.444 |
| 200 | 100 | 2 | 2.518 | 6.878 | 2.853 | 2.368 | 41.77 | 2.817 | 2.266 | 124.7 | 2.659 | 2.2 | 233.4 | 2.602 | 2.153 | 331.7 | 2.537 | 2.112 | 450.7 | 2.478 | 2.095 | 544 | 2.516 | 2.089 | 558.5 | 2.497 |
| 200 | 100 | 3 | 2.507 | 9.783 | 2.864 | 2.344 | 24.58 | 2.717 | 2.266 | 38.58 | 2.698 | 2.193 | 44.55 | 2.582 | 2.13 | 55.65 | 2.542 | 2.088 | 59.09 | 2.472 | 2.07 | 61.16 | 2.459 | 2.064 | 62.02 | 2.467 |
| 200 | 100 | 4 | 2.516 | 8.239 | 2.857 | 2.405 | 20.88 | 2.77 | 2.299 | 29.06 | 2.668 | 2.234 | 41.72 | 2.574 | 2.197 | 51.4 | 2.562 | 2.175 | 59.59 | 2.575 | 2.088 | 64.57 | 2.455 | 2.08 | 67.06 | 2.444 |
| 200 | 150 | 2 | 2.52 | 5.719 | 2.851 | 2.402 | 24.45 | 2.805 | 2.28 | 50.64 | 2.638 | 2.241 | 84.59 | 2.645 | 2.164 | 107.5 | 2.571 | 2.122 | 116.8 | 2.517 | 2.087 | 122 | 2.461 | 2.08 | 126.3 | 2.46 |
| 200 | 150 | 3 | 2.468 | 7.356 | 2.898 | 2.393 | 18.42 | 2.763 | 2.28 | 27.93 | 2.663 | 2.218 | 37.08 | 2.565 | 2.147 | 46.77 | 2.546 | 2.122 | 49.58 | 2.495 | 2.073 | 52.52 | 2.45 | 2.067 | 52.9 | 2.443 |
| 200 | 150 | 4 | 2.48 | 7.631 | 2.849 | 2.374 | 21.66 | 2.639 | 2.273 | 45.17 | 2.623 | 2.214 | 58.63 | 2.587 | 2.136 | 68.66 | 2.501 | 2.129 | 87.26 | 2.519 | 2.069 | 89.91 | 2.436 | 2.062 | 89.39 | 2.429 |
| 200 | 200 | 2 | 2.485 | 6.539 | 2.872 | 2.379 | 35.74 | 2.747 | 2.281 | 61.56 | 2.705 | 2.231 | 73.16 | 2.565 | 2.169 | 81.68 | 2.572 | 2.102 | 89.24 | 2.49 | 2.083 | 92.18 | 2.481 | 2.075 | 92.33 | 2.47 |
| 200 | 200 | 3 | 2.487 | 8.765 | 2.862 | 2.379 | 26.74 | 2.678 | 2.287 | 48.8 | 2.638 | 2.227 | 57.39 | 2.613 | 2.19 | 71.3 | 2.561 | 2.112 | 82.03 | 2.535 | 2.08 | 85.65 | 2.458 | 2.073 | 87.17 | 2.459 |
| 200 | 200 | 4 | 2.452 | 7.022 | 2.802 | 2.379 | 42.21 | 2.695 | 2.324 | 75.96 | 2.685 | 2.223 | 85.98 | 2.566 | 2.176 | 98.35 | 2.563 | 2.111 | 110.2 | 2.526 | 2.07 | 116.7 | 2.466 | 2.063 | 120.3 | 2.465 |
Observation 1: Increasing d_emb from 100 to 150 in general makes perplexity smaller.#
By fixing d_hid and n_lyr, we can compare perplexity for d_emb = 100 and d_emb = 150.
Most comparisons (\(\dfrac{138}{216}\)) show that perplexity is smaller when increasing d_emb from 100 to 150.
Observation 2: Increasing d_emb from 150 to 200 in general makes perplexity smaller.#
By fixing d_hid and n_lyr, we can compare perplexity for d_emb = 150 and d_emb = 200.
More than half of the comparisons (\(\dfrac{125}{216}\)) show that perplexity is smaller when increasing d_emb from 150 to 200.
Observation 3: Increasing d_hid from 100 to 150 in general makes perplexity smaller.#
By fixing d_emb and n_lyr, we can compare perplexity for d_hid = 100 and d_hid = 150.
A little more than half of the comparisons (\(\dfrac{114}{216}\)) show that perplexity is smaller when increasing d_hid from 100 to 150.
Observation 4: Increasing d_hid from 150 to 200 in general makes perplexity larger.#
By fixing d_emb and n_lyr, we can compare perplexity for d_hid = 150 and d_hid = 200.
Most comparisons (\(\dfrac{144}{216}\)) show that perplexity is larger when increasing d_hid from 150 to 200.
Observation 5: When d_emb = 100 and d_hid = 100, increasing n_lyr from 2 to 3 in general makes perplexity larger.#
By fixing d_emb = 100 and d_hid = 100, we can compare perplexity for n_lyr = 2 and n_lyr = 3.
A little more than half of the comparisons (\(\dfrac{13}{24}\)) show that perplexity is larger when increasing n_lyr from 2 to 3.
Observation 6: When d_emb = 100 and d_hid = 150, increasing n_lyr from 2 to 3 in general makes perplexity larger.#
By fixing d_emb = 100 and d_hid = 150, we can compare perplexity for n_lyr = 2 and n_lyr = 3.
A little more than half of the comparisons (\(\dfrac{13}{24}\)) show that perplexity is larger when increasing n_lyr from 2 to 3.
Observation 7: When d_emb = 100 and d_hid = 200, increasing n_lyr from 2 to 3 doesn't show a clear trend in perplexity.#
By fixing d_emb = 100 and d_hid = 200, we can compare perplexity for n_lyr = 2 and n_lyr = 3.
Slightly less than half of the comparisons (\(\dfrac{10}{24}\)) show that perplexity is smaller when increasing n_lyr from 2 to 3.
Observation 8: When d_emb = 100 and d_hid = 100, increasing n_lyr from 2 to 4 in general makes perplexity larger.#
By fixing d_emb = 100 and d_hid = 100, we can compare perplexity for n_lyr = 2 and n_lyr = 4.
A little more than half of the comparisons (\(\dfrac{14}{24}\)) show that perplexity is larger when increasing n_lyr from 2 to 4.
Observation 9: When d_emb = 100 and d_hid = 150, increasing n_lyr from 2 to 4 doesn't show a clear trend in perplexity.#
By fixing d_emb = 100 and d_hid = 150, we can compare perplexity for n_lyr = 2 and n_lyr = 4.
Half of the comparisons (\(\dfrac{12}{24}\)) show that perplexity is larger when increasing n_lyr from 2 to 4.
Observation 10: When d_emb = 100 and d_hid = 200, increasing n_lyr from 2 to 4 doesn't show a clear trend in perplexity.#
By fixing d_emb = 100 and d_hid = 200, we can compare perplexity for n_lyr = 2 and n_lyr = 4.
Slightly less than half of the comparisons (\(\dfrac{10}{24}\)) show that perplexity is smaller when increasing n_lyr from 2 to 4.
Observation 11: When d_emb = 150, increasing n_lyr from 2 to 4 in general makes perplexity larger.#
By fixing d_emb = 150 and d_hid, we can compare perplexity for n_lyr = 2 and n_lyr = 4.
More than half of the comparisons (\(\dfrac{43}{72}\)) show that perplexity is larger when increasing n_lyr from 2 to 4.
Observation 12: When d_emb = 200, increasing n_lyr from 2 to 3 in general makes perplexity smaller.#
By fixing d_emb = 200 and d_hid, we can compare perplexity for n_lyr = 2 and n_lyr = 3.
Most comparisons (\(\dfrac{58}{72}\)) show that perplexity is smaller when increasing n_lyr from 2 to 3.
Observation 13: When d_emb = 200, increasing n_lyr from 2 to 4 in general makes perplexity smaller.#
By fixing d_emb = 200 and d_hid, we can compare perplexity for n_lyr = 2 and n_lyr = 4.
Most comparisons (\(\dfrac{46}{72}\)) show that perplexity is smaller when increasing n_lyr from 2 to 4.
Observation 14: Overfitting seems to happen.#
On the test set, most comparisons (\(\dfrac{170}{189}\)) show that perplexity is still decreasing.
On the validation set, however, most comparisons (\(\dfrac{183}{189}\)) show that perplexity is increasing.
Most of the increases in validation perplexity occur at the 10k or 15k step.
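The point where overfitting sets in can be located per configuration by taking the checkpoint with the lowest validation perplexity. The sketch below is our own illustration and assumes the validation perplexities above have been loaded into a pandas DataFrame df_valid with columns d_emb, d_hid, n_lyr, step and ppl.

import pandas as pd

def overfit_start(df_valid: pd.DataFrame) -> pd.DataFrame:
  # For each configuration, keep the row with the lowest validation perplexity;
  # perplexity rises after that step, which marks the onset of overfitting.
  best = df_valid.groupby(['d_emb', 'd_hid', 'n_lyr'])['ppl'].idxmin()
  return df_valid.loc[best, ['d_emb', 'd_hid', 'n_lyr', 'step']]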
Observation 15: Minimum perplexity on the training set is achieved at the 40k step when d_emb = 200, d_hid = 150 and n_lyr = 4.#

- On the training set, the minimum perplexity \(2.062\) is achieved at the 40k step when d_emb = 200, d_hid = 150 and n_lyr = 4.
- On the validation set, the minimum perplexity \(4.23\) is achieved at the 5k step when d_emb = 100, d_hid = 150 and n_lyr = 3.
- On the test set, the minimum perplexity \(2.409\) is achieved at the 35k step when d_emb = 100, d_hid = 150 and n_lyr = 4.
Observation 16: Only when d_emb = 200 and d_hid = 150 is perplexity lower than \(2.1\).#
In the accuracy experiments below, we see that accuracy can reach \(100\%\) only when perplexity is lower than \(2.1\).
Accuracy#
We used the following script to calculate accuracy on the demo dataset:
import re
import torch
import lmp.dset
import lmp.infer
import lmp.model
import lmp.script
import lmp.tknzr
import lmp.util.model
import lmp.util.tknzr

device = torch.device('cuda')
tknzr = lmp.util.tknzr.load(exp_name='demo_tknzr')

for d_emb in [100, 150, 200]:
  for d_hid in [100, 150, 200]:
    for n_lyr in [2, 3, 4]:
      for ckpt in [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]:
        for ver in lmp.dset.DemoDset.vers:
          dset = lmp.dset.DemoDset(ver=ver)
          exp_name = f'demo-d_emb-{d_emb}-d_hid-{d_hid}-n_lyr-{n_lyr}'
          model = lmp.util.model.load(exp_name=exp_name, ckpt=ckpt).to(device)
          infer = lmp.infer.Top1Infer(max_seq_len=35)
          correct = 0
          for spl in dset:
            # Split each sample into the prompt and the expected answer.
            match = re.match(r'If you add (\d+) to (\d+) you get (\d+) .', spl)
            input = f'If you add {match.group(1)} to {match.group(2)} you get '
            # Greedily decode the continuation and count exact matches.
            output = infer.gen(model=model, tknzr=tknzr, txt=input)
            if input + output == spl:
              correct += 1
          print(f'{exp_name}, ckpt: {ckpt}, ver: {ver}, acc: {correct / len(dset) * 100 :.2f}%')
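For reference, the regular expression above assumes every demo sample has the form shown below; a prediction counts as correct only when the generated continuation reproduces the sample exactly. The sample string here is made up for illustration.

import re

spl = 'If you add 1 to 2 you get 3 .'  # hypothetical sample in the demo-dataset format
match = re.match(r'If you add (\d+) to (\d+) you get (\d+) .', spl)
# The prompt keeps everything up to the answer; the model must generate the rest.
prompt = f'If you add {match.group(1)} to {match.group(2)} you get '
answer = spl[len(prompt):]  # '3 .'
print(prompt, '|', answer)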
| d_emb | d_hid | n_lyr | 5k train | 5k valid | 5k test | 10k train | 10k valid | 10k test | 15k train | 15k valid | 15k test | 20k train | 20k valid | 20k test | 25k train | 25k valid | 25k test | 30k train | 30k valid | 30k test | 35k train | 35k valid | 35k test | 40k train | 40k valid | 40k test |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 100 | 2 | 23.45 | 8.1 | 18 | 31.39 | 9.07 | 22 | 45.19 | 7.21 | 24 | 54.08 | 5.58 | 41 | 81.23 | 7.31 | 56 | 85.6 | 6.3 | 65 | 98.22 | 7.56 | 84 | 98.46 | 8.1 | 88 |
| 100 | 100 | 3 | 11.43 | 3.92 | 8 | 35.15 | 5.82 | 17 | 39.45 | 6.32 | 33 | 70.16 | 7.64 | 53 | 79.66 | 8.02 | 78 | 98.83 | 7.45 | 83 | 99.7 | 8.48 | 92 | 99.6 | 8.4 | 92 |
| 100 | 100 | 4 | 20.44 | 8.44 | 17 | 23.9 | 3.21 | 10 | 40.1 | 5.8 | 40 | 46.75 | 4 | 31 | 55.11 | 4.75 | 54 | 89.17 | 5.74 | 72 | 98.06 | 6.24 | 85 | 99.81 | 6.91 | 94 |
| 100 | 150 | 2 | 13.35 | 8.38 | 8 | 31.86 | 6.3 | 24 | 35.41 | 4.89 | 22 | 64.14 | 6.71 | 51 | 88.30 | 6.51 | 66 | 88.38 | 5.33 | 61 | 99.35 | 6.22 | 88 | 99.6 | 6.26 | 88 |
| 100 | 150 | 3 | 17.47 | 11.54 | 15 | 21.88 | 4.91 | 20 | 47.07 | 4.14 | 26 | 56.14 | 2.85 | 29 | 76.53 | 3.92 | 54 | 88.34 | 3.41 | 64 | 99.07 | 3.84 | 87 | 99.58 | 4.04 | 88 |
| 100 | 150 | 4 | 19.62 | 8.59 | 13 | 18.81 | 2.28 | 7 | 34.53 | 2.89 | 18 | 44.65 | 3.8 | 38 | 69.98 | 3.52 | 49 | 99.13 | 4.06 | 82 | 99.9 | 4.34 | 92 | 99.92 | 4.42 | 95 |
| 100 | 200 | 2 | 26.38 | 10.16 | 12 | 20.42 | 4.1 | 13 | 38.59 | 3.07 | 26 | 54.28 | 3.72 | 29 | 67.47 | 2.95 | 52 | 93.89 | 3.39 | 71 | 96.16 | 3.43 | 80 | 97.82 | 3.62 | 85 |
| 100 | 200 | 3 | 26.71 | 7.05 | 17 | 27.03 | 3.78 | 22 | 38.14 | 3.68 | 30 | 49.29 | 2.79 | 28 | 68.4 | 2.69 | 54 | 85.6 | 2.63 | 56 | 99.78 | 3.21 | 86 | 99.78 | 2.91 | 86 |
| 100 | 200 | 4 | 12.59 | 3.49 | 3 | 28.65 | 2.59 | 14 | 43.94 | 2.69 | 27 | 57.37 | 1.52 | 37 | 73.15 | 2.22 | 51 | 90.38 | 2.32 | 67 | 99.88 | 2.4 | 77 | 99.9 | 2.38 | 79 |
| 150 | 100 | 2 | 23.01 | 7.25 | 16 | 40.83 | 4.99 | 27 | 48.55 | 4.36 | 34 | 71.92 | 4.97 | 48 | 85.23 | 5.43 | 51 | 96.3 | 7.47 | 80 | 98.87 | 6.51 | 81 | 99.29 | 6.99 | 86 |
| 150 | 100 | 3 | 23.8 | 5.03 | 14 | 36.12 | 6.2 | 21 | 51.52 | 6.89 | 31 | 60.91 | 5.94 | 55 | 82.65 | 5.94 | 62 | 98.87 | 6.81 | 85 | 99.54 | 6.85 | 90 | 99.6 | 7.07 | 89 |
| 150 | 100 | 4 | 22.65 | 3.52 | 14 | 34.57 | 5.25 | 27 | 54.52 | 4.16 | 36 | 64.22 | 4.24 | 46 | 73.21 | 4.89 | 58 | 99.39 | 5.58 | 90 | 99.74 | 5.27 | 88 | 99.8 | 5.47 | 93 |
| 150 | 150 | 2 | 20.95 | 5.35 | 13 | 33.9 | 5.52 | 33 | 46.65 | 4.65 | 34 | 67.92 | 3.9 | 42 | 86.22 | 3.13 | 65 | 99.6 | 3.03 | 87 | 99.8 | 2.93 | 89 | 99.8 | 3.05 | 89 |
| 150 | 150 | 3 | 22.79 | 6.93 | 18 | 23.07 | 2.79 | 20 | 45.31 | 4.24 | 34 | 51.37 | 3.66 | 33 | 80.06 | 4.08 | 68 | 99.25 | 4.12 | 84 | 99.92 | 4.26 | 92 | 99.92 | 4.46 | 94 |
| 150 | 150 | 4 | 20.4 | 7.17 | 9 | 27.52 | 2.34 | 13 | 36.87 | 2.1 | 28 | 53.62 | 1.8 | 27 | 68.89 | 2.48 | 51 | 95.9 | 2.4 | 74 | 99.76 | 2.57 | 88 | 99.84 | 2.73 | 91 |
| 150 | 200 | 2 | 20.08 | 7.88 | 12 | 27.58 | 3.11 | 18 | 52.02 | 2.67 | 33 | 63.86 | 2.81 | 44 | 84.28 | 2.02 | 56 | 96.91 | 2.69 | 72 | 99.86 | 2.63 | 87 | 99.88 | 2.63 | 90 |
| 150 | 200 | 3 | 14.57 | 8.59 | 16 | 20.67 | 3.29 | 15 | 41.39 | 2.4 | 34 | 55.39 | 1.92 | 44 | 79.19 | 2.67 | 60 | 95.88 | 2.2 | 71 | 99.94 | 2.44 | 87 | 99.96 | 2.46 | 86 |
| 150 | 200 | 4 | 24.85 | 6.2 | 14 | 19.6 | 2.57 | 13 | 38.53 | 2.59 | 32 | 46.06 | 1.86 | 39 | 67.37 | 2.44 | 48 | 93.33 | 2.2 | 73 | 99.94 | 2 | 80 | 99.96 | 2.04 | 82 |
| 200 | 100 | 2 | 22.32 | 7.74 | 19 | 31.09 | 3.94 | 25 | 51.68 | 4.57 | 35 | 73.17 | 4.51 | 38 | 86.22 | 6.46 | 78 | 98.77 | 6.32 | 84 | 98.95 | 6.42 | 85 | 99.27 | 6.61 | 90 |
| 200 | 100 | 3 | 22.57 | 2.79 | 16 | 38.44 | 4.06 | 20 | 46.83 | 2.69 | 24 | 70.57 | 4.3 | 53 | 90.38 | 4.44 | 68 | 99.05 | 4.65 | 81 | 99.84 | 4.65 | 88 | 99.88 | 5.09 | 89 |
| 200 | 100 | 4 | 19.72 | 4.24 | 17 | 25.11 | 8.44 | 23 | 41.8 | 6.12 | 44 | 55.47 | 4.97 | 44 | 69.39 | 5.01 | 51 | 77.11 | 5.35 | 57 | 9.47 | 6.67 | 89 | 99.66 | 6.81 | 90 |
| 200 | 150 | 2 | 16.1 | 5.82 | 13 | 23.23 | 4.53 | 19 | 45.58 | 3.64 | 32 | 52.97 | 2.57 | 39 | 79.35 | 2.61 | 56 | 94.75 | 2.73 | 71 | 99.62 | 3.66 | 88 | 99.8 | 3.74 | 89 |
| 200 | 150 | 3 | 25.43 | 4.85 | 12 | 25.37 | 2.81 | 18 | 45.62 | 2.97 | 27 | 59.9 | 2.38 | 39 | 85.76 | 3.05 | 58 | 89.72 | 3.41 | 72 | 99.94 | 3.8 | 89 | 99.96 | 3.8 | 92 |
| 200 | 150 | 4 | 21.94 | 6.28 | 14 | 28.69 | 3.52 | 27 | 50.32 | 2.63 | 33 | 61.86 | 2.61 | 44 | 89.31 | 2.38 | 63 | 85.09 | 2.32 | 70 | 100 | 2.67 | 87 | 100 | 2.95 | 92 |
| 200 | 200 | 2 | 23.64 | 7.07 | 10 | 28.63 | 4.2 | 22 | 47.43 | 1.78 | 30 | 58.77 | 2.69 | 50 | 78.48 | 2.93 | 59 | 97.58 | 2.57 | 72 | 9.21 | 2.48 | 83 | 99.54 | 2.53 | 85 |
| 200 | 200 | 3 | 20.75 | 4.69 | 13 | 27.88 | 3.15 | 24 | 46.28 | 2.61 | 32 | 59.17 | 1.66 | 41 | 72.51 | 2.04 | 56 | 95.52 | 2.1 | 67 | 99.98 | 2.12 | 79 | 99.98 | 2.04 | 82 |
| 200 | 200 | 4 | 27.17 | 7.88 | 22 | 29.72 | 2.4 | 26 | 37.29 | 1.68 | 24 | 60.87 | 1.64 | 47 | 75.03 | 1.33 | 49 | 94.12 | 1.52 | 72 | 99.96 | 1.7 | 76 | 99.96 | 1.64 | 83 |
Observation 1: \(100\%\) training accuracy is achieved.#
\(100\%\) accuracy is achieved with d_emb = 200, d_hid = 150 and n_lyr = 4 at the 35k and 40k steps.
Observation 2: Models do not generalize.#
Accuracy on the validation set never exceeds \(12\%\). This might be a problem with the dataset design.
Future work#
Validation set performance does not improve as the Elman Net becomes bigger and deeper. Since we can achieve \(100\%\) accuracy on the training set, the optimization process seems to be fine. Thus we conclude that the Elman Net itself might be the cause of the poor generalization. We should consider changing models.