aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..0f518d0
--- /dev/null
+++ b/train.py
@@ -0,0 +1,27 @@
+#!/usr/bin/python
+
+"""Starts training a keras neural network."""
+
+import sys
+
+from imago.sgfParser.sgf import loadGameTree
+from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork
+from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork
+
+def main():
+ games = []
+ for file in sys.argv[1:]:
+ print(file)
+ games.append(loadGameTree(file))
+
+ matches = [game.getMainLineOfPlay() for game in games]
+
+ modelFile = ""
+ boardsize = 9
+ nn = DenseNeuralNetwork(modelFile, boardsize)
+ #nn = ConvNeuralNetwork(modelFile, boardsize)
+ nn.trainModel(matches)
+ nn.saveModel()
+
+if __name__ == '__main__':
+ main()