aboutsummaryrefslogtreecommitdiff
path: root/tests/test_neuralNetwork.py
blob: 42ba4a109a1e2800e8e8939db4f7171a1b1a05cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Tests for neural network module."""

import os
import shutil
import unittest

from imago.data.enums import DecisionAlgorithms
from imago.sgfParser.sgf import loadGameTree
from imago.gameLogic.gameState import GameState
from imago.engine.keras.neuralNetwork import NeuralNetwork
from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork
from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork
from imago.engine.keras.keras import Keras

class TestNeuralNetwork(unittest.TestCase):
    """Test neural network module."""

    @staticmethod
    def _removeFile(path):
        """Delete the file at path, doing nothing if it does not exist."""
        try:
            os.remove(path)
        except FileNotFoundError:
            # Best-effort cleanup: the file was already removed by the test.
            pass

    def testLoadBaseClass(self):
        """Test error when creating model with the base NeuralNetwork class"""

        # Instantiating the base class directly is expected to raise
        # NotImplementedError.
        self.assertRaises(NotImplementedError,
                          NeuralNetwork,
                          "non/existing/file")

    def testNetworks(self):
        """Test training, move picking, saving and plotting of networks.

        A dense network is trained for one epoch on three SGF matches, asked
        to pick a move, saved (both to an explicit path and to its default
        path), re-created from the saved path and plotted. Finally a
        convolutional network is created on the same, by then removed, path.
        """

        testModel = 'testModel'
        testModelPlot = 'testModelPlot'

        # Safety net: if any step below fails, the artifacts written to the
        # working directory are still removed once the test finishes, so they
        # cannot leak into later runs.
        self.addCleanup(shutil.rmtree, testModel, ignore_errors=True)
        self.addCleanup(self._removeFile, testModelPlot)

        # NOTE: these paths are relative to the current working directory.
        sgfFiles = [
            '../collections/minigo/matches/1.sgf',
            '../collections/minigo/matches/2.sgf',
            '../collections/minigo/matches/3.sgf'
        ]
        games = [loadGameTree(sgfFile) for sgfFile in sgfFiles]
        matches = [game.getMainLineOfPlay() for game in games]

        nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9)
        nn.trainModel(matches, epochs=1, verbose=0)

        game = GameState(9)
        nn.pickMove(game.lastMove, game.getCurrentPlayer())

        # Saving to an explicit path must create the model directory.
        nn.saveModel(testModel)
        self.assertTrue(os.path.isdir(testModel))
        # Remove it so the next check proves saveModel() re-creates it.
        shutil.rmtree(testModel, ignore_errors=True)

        # Saving without an argument must write to the path given at
        # construction time.
        nn.saveModel()
        self.assertTrue(os.path.isdir(testModel))
        # Build a network on a path that now holds a saved model
        # (presumably this loads it; verify against DenseNeuralNetwork).
        nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9)

        nn.saveModelPlot(testModelPlot)
        self.assertTrue(os.path.isfile(testModelPlot))

        # Remove the artifacts before creating the convolutional network so
        # it is not built on top of the saved dense model.
        shutil.rmtree(testModel, ignore_errors=True)
        os.remove(testModelPlot)

        ConvNeuralNetwork(testModel, boardSize=9)

    def testKeras(self):
        """Test the Keras engine with each decision algorithm.

        Checks forcing moves with the default, DENSE and CONV algorithms,
        that an invalid forced move raises RuntimeError, that a picked move
        is either "pass" or a two-element value, and that requesting the
        MONTECARLO algorithm raises RuntimeError.
        """

        gameState = GameState(9)
        move = gameState.lastMove

        keras = Keras(move)
        keras.forceNextMove("pass")

        keras = Keras(move, DecisionAlgorithms.DENSE)
        keras.forceNextMove((3, 3))

        keras = Keras(move, DecisionAlgorithms.CONV)
        self.assertRaises(RuntimeError, keras.forceNextMove, "wrongmove")
        pickedCoords = keras.pickMove()
        self.assertTrue(len(pickedCoords) == 2 or pickedCoords == "pass")

        self.assertRaises(RuntimeError,
                          Keras,
                          move,
                          DecisionAlgorithms.MONTECARLO)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()