aboutsummaryrefslogtreecommitdiff
path: root/tests/test_neuralNetwork.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_neuralNetwork.py')
-rw-r--r--tests/test_neuralNetwork.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/tests/test_neuralNetwork.py b/tests/test_neuralNetwork.py
index dfcbd7a..42ba4a1 100644
--- a/tests/test_neuralNetwork.py
+++ b/tests/test_neuralNetwork.py
@@ -1,8 +1,16 @@
"""Tests for neural network module."""
+import os
+import shutil
import unittest
+from imago.data.enums import DecisionAlgorithms
+from imago.sgfParser.sgf import loadGameTree
+from imago.gameLogic.gameState import GameState
from imago.engine.keras.neuralNetwork import NeuralNetwork
+from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork
+from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork
+from imago.engine.keras.keras import Keras
class TestNeuralNetwork(unittest.TestCase):
"""Test neural network module."""
@@ -13,3 +21,63 @@ class TestNeuralNetwork(unittest.TestCase):
self.assertRaises(NotImplementedError,
NeuralNetwork,
"non/existing/file")
+
+ def testNetworks(self):
+ """Test creation of initial model for dense neural network"""
+
+ testModel = 'testModel'
+ testModelPlot = 'testModelPlot'
+
+ games = []
+ for file in [
+ '../collections/minigo/matches/1.sgf',
+ '../collections/minigo/matches/2.sgf',
+ '../collections/minigo/matches/3.sgf'
+ ]:
+ games.append(loadGameTree(file))
+ matches = [game.getMainLineOfPlay() for game in games]
+
+ nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9)
+ nn.trainModel(matches, epochs=1, verbose=0)
+
+ game = GameState(9)
+ nn.pickMove(game.lastMove, game.getCurrentPlayer())
+
+ nn.saveModel(testModel)
+ self.assertTrue(os.path.isdir(testModel))
+ shutil.rmtree(testModel, ignore_errors=True)
+
+ nn.saveModel()
+ self.assertTrue(os.path.isdir(testModel))
+ nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9)
+
+ nn.saveModelPlot(testModelPlot)
+ self.assertTrue(os.path.isfile(testModelPlot))
+
+ shutil.rmtree(testModel, ignore_errors=True)
+ os.remove(testModelPlot)
+
+ nn = ConvNeuralNetwork(testModel, boardSize=9)
+
+ def testKeras(self):
+ """Test keras model loading."""
+
+ gameState = GameState(9)
+ move = gameState.lastMove
+
+ keras = Keras(move)
+ keras.forceNextMove("pass")
+
+ keras = Keras(move, DecisionAlgorithms.DENSE)
+ keras.forceNextMove((3,3))
+
+ keras = Keras(move, DecisionAlgorithms.CONV)
+ self.assertRaises(RuntimeError, keras.forceNextMove, "wrongmove")
+ pickedCoords = keras.pickMove()
+ self.assertTrue(len(pickedCoords) == 2 or pickedCoords == "pass")
+
+ self.assertRaises(RuntimeError, Keras, move, DecisionAlgorithms.MONTECARLO)
+
+
+if __name__ == '__main__':
+ unittest.main()