-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexample.py
More file actions
68 lines (54 loc) · 1.94 KB
/
example.py
File metadata and controls
68 lines (54 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from classes.rnn import RNN
import numpy as np

# Load the training corpus, lower-cased with newlines flattened to spaces
# so the character vocabulary stays small. Use a context manager so the
# file handle is closed (the original `open(...).read()` leaked it).
with open('anna.txt', 'r') as corpus_file:
    data = corpus_file.read().lower().replace('\n', ' ')

chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print('data has %d characters, %d unique.' % (data_size, vocab_size))

# Bidirectional char <-> index maps used for one-hot encoding/decoding.
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}

# Hyperparameters.
seq_length = 128      # characters per training chunk (truncated BPTT window)
learning_rate = 1e-3

# Three stacked LSTM layers of 256 hidden units each; layer spec and
# output sizing are interpreted by classes.rnn.RNN.
rnn = RNN(
    [
        {'type': 'lstm', 'hidden_size': 256},
        {'type': 'lstm', 'hidden_size': 256},
        {'type': 'lstm', 'hidden_size': 256},
        # {'type': 'lstm', 'hidden_size': vocab_size, 'dropout': 0.5, 'bi': True, 'u_type': 'adagrad'}
    ],
    vocab_size, learning_rate)
print(rnn.archi)
print('with seq_length {}'.format(seq_length))

# n: training iteration counter; p: current read position in the corpus.
n, p = 0, 0
def sample(seed_ix, n):
    """Sample a sequence of n character indices from the model.

    Starts from the one-hot encoded seed character ``seed_ix``, then
    repeatedly feeds the previously sampled character back into the
    network, drawing each next character from the softmax of the output.

    Uses the module-level ``rnn`` and ``vocab_size``.

    Returns a list of n sampled character indices.
    """
    x = np.zeros((1, vocab_size, 1))
    x[0][seed_ix][0] = 1
    ixes = []
    # NOTE(review): presumably copies the training hidden state into the
    # prediction state so sampling continues from the current context —
    # confirm against classes.rnn.
    rnn.reset_h_predict_to_h()
    for t in range(n):
        h, y = rnn.predict(x)
        # Numerically stable softmax: subtracting max(y) before exp
        # prevents overflow for large logits and leaves the resulting
        # distribution unchanged.
        e = np.exp(y - np.max(y))
        p = e / np.sum(e)
        ix = np.random.choice(range(vocab_size), p=p.ravel())
        # Feed the sampled character back in as the next input.
        x = np.zeros((1, vocab_size, 1))
        x[0][ix][0] = 1
        ixes.append(ix)
    return ixes
# Initial smoothed loss corresponds to a uniform prediction over the vocab.
smooth_loss = -np.log(1.0 / vocab_size)  # loss at iteration 0
while True:
    # Wrap back to the start of the corpus when the next window would run
    # past the end; reset hidden state since sequence continuity is broken.
    if p + seq_length + 1 >= len(data) or n == 0:
        p = 0
        rnn.reset_h()
    # Inputs are chars [p, p+seq_length); targets are the same window
    # shifted one character ahead.
    inputs = [char_to_ix[ch] for ch in data[p:p + seq_length]]
    targets = [char_to_ix[ch] for ch in data[p + 1:p + seq_length + 1]]
    # One-hot encode, shape (seq_length, vocab_size, 1).
    x, y = np.zeros((seq_length, vocab_size, 1), np.float32), np.zeros((seq_length, vocab_size, 1), np.float32)
    x[range(len(x)), inputs] = 1
    # FIX: was `y[:, targets] = 1`, which set every target column at every
    # timestep instead of one-hot per row — mirror the construction of x.
    # (y is currently unused: rnn.epoch consumes the raw `targets` indices.)
    y[range(len(y)), targets] = 1
    if n % 1000 == 0:
        # Periodically sample 300 characters to eyeball training progress.
        sample_ix = sample(inputs[0], 300)
        txt = ''.join(ix_to_char[ix] for ix in sample_ix)
        print('----\n %s \n----' % (txt,))
    loss = rnn.epoch(x, targets)
    # Exponential moving average of the loss for smoother logging.
    smooth_loss = smooth_loss * 0.999 + loss * 0.001
    if n % 100 == 0:
        print('iter %d, loss: %f' % (n, smooth_loss))
    p += seq_length  # advance the corpus read position
    n += 1