aboutsummaryrefslogtreecommitdiff
path: root/imago/engine/keras/denseNeuralNetwork.py
blob: 6a350f771e2e1a5398558eb426b922aa4bcdb6ed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Dense neural network."""

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam

from imago.engine.keras.neuralNetwork import NeuralNetwork

class DenseNeuralNetwork(NeuralNetwork):
    """Keras network built from Dense (fully connected) layers.

    The model takes an input of shape ``(boardSize, boardSize, 2)`` and ends in
    a softmax over ``boardSize * boardSize + 1`` classes.
    """

    NETWORK_ID = "denseNeuralNetwork"
    DEFAULT_MODEL_FILE = "models/imagoDenseKerasModel.h5"

    def _initModel(self, boardSize=NeuralNetwork.DEF_BOARD_SIZE):
        """Build, print a summary of, and compile the dense Keras model.

        Args:
            boardSize: side length of the square board the model is built
                for. It determines both the input shape and the number of
                output classes.

        Returns:
            The compiled ``Sequential`` model (Adam optimizer with learning
            rate 0.0001, categorical cross-entropy loss, accuracy metric).
        """
        # One output per board vertex plus one extra class. This is derived
        # from boardSize instead of being hard-coded to 82 (= 9 * 9 + 1),
        # which only matched a 9x9 board while input_shape followed boardSize.
        # NOTE(review): the extra class is presumably "pass" -- confirm against
        # how NeuralNetwork encodes the training labels.
        outputUnits = boardSize * boardSize + 1

        model = Sequential([
            # On an input of rank > 2, Dense operates on the last axis only, so
            # these two layers map the 2 features at each vertex to 81 features
            # with weights shared across vertices:
            # (boardSize, boardSize, 2) -> (boardSize, boardSize, 81).
            Dense(
                units=81,
                activation="relu",
                input_shape=(boardSize, boardSize, 2)
            ),
            Dense(
                units=81,
                activation="relu"
            ),
            # Collapse the per-vertex features into one vector so the final
            # layer sees the whole board at once.
            Flatten(),
            Dense(
                units=outputUnits,
                activation="softmax"
            ),
        ])

        model.summary()

        model.compile(
            optimizer=Adam(learning_rate=0.0001),
            loss="categorical_crossentropy",
            metrics=["accuracy"]
        )

        return model