aboutsummaryrefslogtreecommitdiff
path: root/tests/test_neuralNetwork.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_neuralNetwork.py')
-rw-r--r--tests/test_neuralNetwork.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/tests/test_neuralNetwork.py b/tests/test_neuralNetwork.py
new file mode 100644
index 0000000..42ba4a1
--- /dev/null
+++ b/tests/test_neuralNetwork.py
@@ -0,0 +1,83 @@
+"""Tests for neural network module."""
+
+import os
+import shutil
+import unittest
+
+from imago.data.enums import DecisionAlgorithms
+from imago.sgfParser.sgf import loadGameTree
+from imago.gameLogic.gameState import GameState
+from imago.engine.keras.neuralNetwork import NeuralNetwork
+from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork
+from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork
+from imago.engine.keras.keras import Keras
+
+class TestNeuralNetwork(unittest.TestCase):
+ """Test neural network module."""
+
+ def testLoadBaseClass(self):
+ """Test error when creating model with the base NeuralNetwork class"""
+
+ self.assertRaises(NotImplementedError,
+ NeuralNetwork,
+ "non/existing/file")
+
+ def testNetworks(self):
+ """Test creation of initial model for dense neural network"""
+
+ testModel = 'testModel'
+ testModelPlot = 'testModelPlot'
+
+ games = []
+ for file in [
+ '../collections/minigo/matches/1.sgf',
+ '../collections/minigo/matches/2.sgf',
+ '../collections/minigo/matches/3.sgf'
+ ]:
+ games.append(loadGameTree(file))
+ matches = [game.getMainLineOfPlay() for game in games]
+
+ nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9)
+ nn.trainModel(matches, epochs=1, verbose=0)
+
+ game = GameState(9)
+ nn.pickMove(game.lastMove, game.getCurrentPlayer())
+
+ nn.saveModel(testModel)
+ self.assertTrue(os.path.isdir(testModel))
+ shutil.rmtree(testModel, ignore_errors=True)
+
+ nn.saveModel()
+ self.assertTrue(os.path.isdir(testModel))
+ nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9)
+
+ nn.saveModelPlot(testModelPlot)
+ self.assertTrue(os.path.isfile(testModelPlot))
+
+ shutil.rmtree(testModel, ignore_errors=True)
+ os.remove(testModelPlot)
+
+ nn = ConvNeuralNetwork(testModel, boardSize=9)
+
+ def testKeras(self):
+ """Test keras model loading."""
+
+ gameState = GameState(9)
+ move = gameState.lastMove
+
+ keras = Keras(move)
+ keras.forceNextMove("pass")
+
+ keras = Keras(move, DecisionAlgorithms.DENSE)
+ keras.forceNextMove((3,3))
+
+ keras = Keras(move, DecisionAlgorithms.CONV)
+ self.assertRaises(RuntimeError, keras.forceNextMove, "wrongmove")
+ pickedCoords = keras.pickMove()
+ self.assertTrue(len(pickedCoords) == 2 or pickedCoords == "pass")
+
+ self.assertRaises(RuntimeError, Keras, move, DecisionAlgorithms.MONTECARLO)
+
+
+if __name__ == '__main__':
+ unittest.main()