diff options
author | InigoGutierrez <inigogf.95@gmail.com> | 2023-06-12 20:16:04 +0200 |
---|---|---|
committer | InigoGutierrez <inigogf.95@gmail.com> | 2023-06-12 20:16:04 +0200 |
commit | d4a81490bf1396089eb3dac5955a3a8e4cb26e37 (patch) | |
tree | f96febc7950c2742bc36f04ab13bff56851f2388 /tests/test_neuralNetwork.py | |
parent | b08408d23186205e71dfc68634021e3236bfb45c (diff) | |
parent | 65ac3a6b050dcb88688cdc2654b1ed6693e9a160 (diff) | |
download | imago-master.tar.gz imago-master.zip |
Diffstat (limited to 'tests/test_neuralNetwork.py')
-rw-r--r-- | tests/test_neuralNetwork.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/tests/test_neuralNetwork.py b/tests/test_neuralNetwork.py new file mode 100644 index 0000000..42ba4a1 --- /dev/null +++ b/tests/test_neuralNetwork.py @@ -0,0 +1,83 @@ +"""Tests for neural network module.""" + +import os +import shutil +import unittest + +from imago.data.enums import DecisionAlgorithms +from imago.sgfParser.sgf import loadGameTree +from imago.gameLogic.gameState import GameState +from imago.engine.keras.neuralNetwork import NeuralNetwork +from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork +from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork +from imago.engine.keras.keras import Keras + +class TestNeuralNetwork(unittest.TestCase): + """Test neural network module.""" + + def testLoadBaseClass(self): + """Test error when creating model with the base NeuralNetwork class""" + + self.assertRaises(NotImplementedError, + NeuralNetwork, + "non/existing/file") + + def testNetworks(self): + """Test creation of initial model for dense neural network""" + + testModel = 'testModel' + testModelPlot = 'testModelPlot' + + games = [] + for file in [ + '../collections/minigo/matches/1.sgf', + '../collections/minigo/matches/2.sgf', + '../collections/minigo/matches/3.sgf' + ]: + games.append(loadGameTree(file)) + matches = [game.getMainLineOfPlay() for game in games] + + nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) + nn.trainModel(matches, epochs=1, verbose=0) + + game = GameState(9) + nn.pickMove(game.lastMove, game.getCurrentPlayer()) + + nn.saveModel(testModel) + self.assertTrue(os.path.isdir(testModel)) + shutil.rmtree(testModel, ignore_errors=True) + + nn.saveModel() + self.assertTrue(os.path.isdir(testModel)) + nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) + + nn.saveModelPlot(testModelPlot) + self.assertTrue(os.path.isfile(testModelPlot)) + + shutil.rmtree(testModel, ignore_errors=True) + os.remove(testModelPlot) + + nn = ConvNeuralNetwork(testModel, boardSize=9) + + def testKeras(self): + """Test keras model loading.""" + + gameState = GameState(9) + move = gameState.lastMove + + keras = Keras(move) + keras.forceNextMove("pass") + + keras = Keras(move, DecisionAlgorithms.DENSE) + keras.forceNextMove((3,3)) + + keras = Keras(move, DecisionAlgorithms.CONV) + self.assertRaises(RuntimeError, keras.forceNextMove, "wrongmove") + pickedCoords = keras.pickMove() + self.assertTrue(len(pickedCoords) == 2 or pickedCoords == "pass") + + self.assertRaises(RuntimeError, Keras, move, DecisionAlgorithms.MONTECARLO) + + +if __name__ == '__main__': + unittest.main() |