aboutsummaryrefslogtreecommitdiff
path: root/train.py
blob: 306d6e57376a5b440a9c27d0898f241c77f22573 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#!/usr/bin/python

"""Starts training a keras neural network."""

import sys

from imago.sgfParser.sgf import loadGameTree
from imago.engine.keras.denseNeuralNetwork import DenseNeuralNetwork
from imago.engine.keras.convNeuralNetwork import ConvNeuralNetwork

def main():
    """Train a neural network on the SGF games named on the command line.

    Every command-line argument is treated as the path of a game file. Each
    file is loaded with ``loadGameTree``, reduced to its main line of play
    with ``getMainLineOfPlay``, and the resulting list of matches is passed
    to ``DenseNeuralNetwork.trainModel``; the model is then saved with
    ``saveModel``.

    Raises:
        SystemExit: with a non-zero status (and the error message written to
            stderr) when no game files are given.
    """
    paths = sys.argv[1:]
    if not paths:
        # Passing a string to sys.exit() prints it to stderr and exits with
        # status 1, so a missing-argument error no longer looks like success
        # to the calling shell or script.
        sys.exit("Error: No game files provided. Provide some SGF files as arguments.")

    games = []
    for path in paths:
        print(path)  # echo each file name before loading it
        games.append(loadGameTree(path))

    matches = [game.getMainLineOfPlay() for game in games]

    # NOTE(review): an empty model file name is passed to the network;
    # presumably the network class picks a default or starts from scratch --
    # confirm in the network implementation.
    modelFile = ""
    boardsize = 9
    nn = DenseNeuralNetwork(modelFile, boardsize)
    # Alternative network kept from the original code, same constructor args:
    #nn = ConvNeuralNetwork(modelFile, boardsize)
    nn.trainModel(matches)
    nn.saveModel()

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()