diff options
author | InigoGutierrez <inigogf.95@gmail.com> | 2023-06-12 19:43:40 +0200 |
---|---|---|
committer | InigoGutierrez <inigogf.95@gmail.com> | 2023-06-12 19:43:40 +0200 |
commit | 65ac3a6b050dcb88688cdc2654b1ed6693e9a160 (patch) | |
tree | 19797a3d1a2f897628d0413482117c27c9cfe6b9 /tests/test_neuralNetwork.py | |
parent | a005228a986b17732ae7cccbedde450533cfe1f1 (diff) | |
download | imago-65ac3a6b050dcb88688cdc2654b1ed6693e9a160.tar.gz imago-65ac3a6b050dcb88688cdc2654b1ed6693e9a160.zip |
Submitted version.
Diffstat (limited to 'tests/test_neuralNetwork.py')
-rw-r--r-- | tests/test_neuralNetwork.py | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/tests/test_neuralNetwork.py b/tests/test_neuralNetwork.py index dfcbd7a..42ba4a1 100644 --- a/tests/test_neuralNetwork.py +++ b/tests/test_neuralNetwork.py @@ -1,8 +1,16 @@ """Tests for neural network module.""" +import os +import shutil import unittest +from imago.data.enums import DecisionAlgorithms +from imago.sgfParser.sgf import loadGameTree +from imago.gameLogic.gameState import GameState from imago.engine.keras.neuralNetwork import NeuralNetwork +from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork +from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork +from imago.engine.keras.keras import Keras class TestNeuralNetwork(unittest.TestCase): """Test neural network module.""" @@ -13,3 +21,63 @@ class TestNeuralNetwork(unittest.TestCase): self.assertRaises(NotImplementedError, NeuralNetwork, "non/existing/file") + + def testNetworks(self): + """Test creation of initial model for dense neural network""" + + testModel = 'testModel' + testModelPlot = 'testModelPlot' + + games = [] + for file in [ + '../collections/minigo/matches/1.sgf', + '../collections/minigo/matches/2.sgf', + '../collections/minigo/matches/3.sgf' + ]: + games.append(loadGameTree(file)) + matches = [game.getMainLineOfPlay() for game in games] + + nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) + nn.trainModel(matches, epochs=1, verbose=0) + + game = GameState(9) + nn.pickMove(game.lastMove, game.getCurrentPlayer()) + + nn.saveModel(testModel) + self.assertTrue(os.path.isdir(testModel)) + shutil.rmtree(testModel, ignore_errors=True) + + nn.saveModel() + self.assertTrue(os.path.isdir(testModel)) + nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) + + nn.saveModelPlot(testModelPlot) + self.assertTrue(os.path.isfile(testModelPlot)) + + shutil.rmtree(testModel, ignore_errors=True) + os.remove(testModelPlot) + + nn = ConvNeuralNetwork(testModel, boardSize=9) + + def testKeras(self): + """Test keras model loading.""" + + gameState = GameState(9) + move = gameState.lastMove + + keras = Keras(move) + keras.forceNextMove("pass") + + keras = Keras(move, DecisionAlgorithms.DENSE) + keras.forceNextMove((3,3)) + + keras = Keras(move, DecisionAlgorithms.CONV) + self.assertRaises(RuntimeError, keras.forceNextMove, "wrongmove") + pickedCoords = keras.pickMove() + self.assertTrue(len(pickedCoords) == 2 or pickedCoords == "pass") + + self.assertRaises(RuntimeError, Keras, move, DecisionAlgorithms.MONTECARLO) + + +if __name__ == '__main__': + unittest.main() |