"""Tests for neural network module.""" import os import shutil import unittest from imago.data.enums import DecisionAlgorithms from imago.sgfParser.sgf import loadGameTree from imago.gameLogic.gameState import GameState from imago.engine.keras.neuralNetwork import NeuralNetwork from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork from imago.engine.keras.keras import Keras class TestNeuralNetwork(unittest.TestCase): """Test neural network module.""" def testLoadBaseClass(self): """Test error when creating model with the base NeuralNetwork class""" self.assertRaises(NotImplementedError, NeuralNetwork, "non/existing/file") def testNetworks(self): """Test creation of initial model for dense neural network""" testModel = 'testModel' testModelPlot = 'testModelPlot' games = [] for file in [ '../collections/minigo/matches/1.sgf', '../collections/minigo/matches/2.sgf', '../collections/minigo/matches/3.sgf' ]: games.append(loadGameTree(file)) matches = [game.getMainLineOfPlay() for game in games] nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) nn.trainModel(matches, epochs=1, verbose=0) game = GameState(9) nn.pickMove(game.lastMove, game.getCurrentPlayer()) nn.saveModel(testModel) self.assertTrue(os.path.isdir(testModel)) shutil.rmtree(testModel, ignore_errors=True) nn.saveModel() self.assertTrue(os.path.isdir(testModel)) nn = DenseNeuralNetwork(modelPath=testModel, boardSize=9) nn.saveModelPlot(testModelPlot) self.assertTrue(os.path.isfile(testModelPlot)) shutil.rmtree(testModel, ignore_errors=True) os.remove(testModelPlot) nn = ConvNeuralNetwork(testModel, boardSize=9) def testKeras(self): """Test keras model loading.""" gameState = GameState(9) move = gameState.lastMove keras = Keras(move) keras.forceNextMove("pass") keras = Keras(move, DecisionAlgorithms.DENSE) keras.forceNextMove((3,3)) keras = Keras(move, DecisionAlgorithms.CONV) self.assertRaises(RuntimeError, keras.forceNextMove, "wrongmove") pickedCoords = keras.pickMove() self.assertTrue(len(pickedCoords) == 2 or pickedCoords == "pass") self.assertRaises(RuntimeError, Keras, move, DecisionAlgorithms.MONTECARLO) if __name__ == '__main__': unittest.main()