diff options
Diffstat (limited to 'tests/test_neuralNetwork.py')
-rw-r--r-- | tests/test_neuralNetwork.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/tests/test_neuralNetwork.py b/tests/test_neuralNetwork.py new file mode 100644 index 0000000..42ba4a1 --- /dev/null +++ b/tests/test_neuralNetwork.py @@ -0,0 +1,83 @@ +"""Tests for neural network module.""" + +import os +import shutil +import unittest + +from imago.data.enums import DecisionAlgorithms +from imago.sgfParser.sgf import loadGameTree +from imago.gameLogic.gameState import GameState +from imago.engine.keras.neuralNetwork import NeuralNetwork +from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork +from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork +from imago.engine.keras.keras import Keras + +class TestNeuralNetwork(unittest.TestCase): + """Test neural network module.""" + + def testLoadBaseClass(self): + """Test error when creating model with the base NeuralNetwork class""" + + self.assertRaises(NotImplementedError, + NeuralNetwork, + "non/existing/file") + + def testNetworks(self): + """Test creation of initial model for dense neural network""" + + testModel = 'testModel' + testModelPlot = 'testModelPlot' + + games = [] + for file in [ + '../collections/minigo/matches/1.sgf', + '../collections/minigo/matches/2.sgf', + '../collections/minigo/matches/3.sgf' + ]: + games.append(loadGameTree(file)) + matches = [game.getMainLineOfPlay() for game in games] + + nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) + nn.trainModel(matches, epochs=1, verbose=0) + + game = GameState(9) + nn.pickMove(game.lastMove, game.getCurrentPlayer()) + + nn.saveModel(testModel) + self.assertTrue(os.path.isdir(testModel)) + shutil.rmtree(testModel, ignore_errors=True) + + nn.saveModel() + self.assertTrue(os.path.isdir(testModel)) + nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) + + nn.saveModelPlot(testModelPlot) + self.assertTrue(os.path.isfile(testModelPlot)) + + shutil.rmtree(testModel, ignore_errors=True) + os.remove(testModelPlot) + + nn = ConvNeuralNetwork(testModel, boardSize=9) + + def testKeras(self): + """Test keras model loading.""" + + gameState = GameState(9) + move = gameState.lastMove + + keras = Keras(move) + keras.forceNextMove("pass") + + keras = Keras(move, DecisionAlgorithms.DENSE) + keras.forceNextMove((3,3)) + + keras = Keras(move, DecisionAlgorithms.CONV) + self.assertRaises(RuntimeError, keras.forceNextMove, "wrongmove") + pickedCoords = keras.pickMove() + self.assertTrue(len(pickedCoords) == 2 or pickedCoords == "pass") + + self.assertRaises(RuntimeError, Keras, move, DecisionAlgorithms.MONTECARLO) + + +if __name__ == '__main__': + unittest.main() |