diff options
Diffstat (limited to 'imago/engine/keras/neuralNetwork.py')
-rw-r--r-- | imago/engine/keras/neuralNetwork.py | 63 |
1 files changed, 42 insertions, 21 deletions
diff --git a/imago/engine/keras/neuralNetwork.py b/imago/engine/keras/neuralNetwork.py index d0eb4ae..7eddb9d 100644 --- a/imago/engine/keras/neuralNetwork.py +++ b/imago/engine/keras/neuralNetwork.py @@ -81,7 +81,7 @@ class NeuralNetwork: def _boardToPlayerContext(self, board, player): """Converts the board to a 3D matrix with two representations of the board, one - marking the player's stones and the oter marking the opponent's stones.""" + marking the player's stones and the other marking the opponent's stones.""" boardRows = len(board) boardCols = len(board[0]) contextBoard = numpy.zeros((boardRows, boardCols, 2), dtype = float) @@ -95,14 +95,19 @@ class NeuralNetwork: return contextBoard def _movesToTargets(self, moves): - """Converts the moves to 2D matrices with values zero except for a one on the - played vertex.""" + """Converts the moves to 2D matrices with values zero except for a one indicating + the played move.""" targets = [] + targetsSize = self.boardSize * self.boardSize + 1 # Each vertex + 1 for pass for move in moves: if len(move.nextMoves) == 0: continue - target = numpy.zeros(self.boardSize * self.boardSize, dtype = float) - target[move.nextMoves[0].getRow() * self.boardSize + move.nextMoves[0].getCol()] = 1 + target = numpy.zeros(targetsSize, dtype = float) + nextMove = move.nextMoves[0] + if nextMove.isPass: + target[-1] = 1 + else: + target[nextMove.getRow() * self.boardSize + nextMove.getCol()] = 1 targets.append(target.tolist()) return targets @@ -110,11 +115,12 @@ class NeuralNetwork: """Uses the model's predict function to pick the highest valued vertex to play.""" predictionVector = self._predict(gameMove, player)[0] - prediction = numpy.zeros((self.boardSize, self.boardSize)) + predictionBoard = numpy.zeros((self.boardSize, self.boardSize)) for row in range(self.boardSize): for col in range(self.boardSize): - prediction[row][col] = predictionVector[row * self.boardSize + col] - self.saveHeatmap(prediction) + predictionBoard[row][col] = predictionVector[row * self.boardSize + col] + predictionPass = predictionVector[-1] + self.saveHeatmap(predictionBoard, predictionPass) # Search the highest valued vertex which is also playable playableVertices = gameMove.getPlayableVertices() @@ -123,11 +129,13 @@ class NeuralNetwork: hCol = -1 for row in range(self.boardSize): for col in range(self.boardSize): - if prediction[row][col] > highest and (row, col) in playableVertices: + if predictionBoard[row][col] > highest and (row, col) in playableVertices: hRow = row hCol = col - highest = prediction[row][col] + highest = predictionBoard[row][col] + if highest < predictionPass: + return "pass" return [hRow, hCol] def _predict(self, gameMove, player): @@ -139,29 +147,42 @@ class NeuralNetwork: batch_size = 1, verbose = 2) - def saveHeatmap(self, data): + def saveHeatmap(self, data, passChance): rows = len(data) cols = len(data[0]) - fig, ax = pyplot.subplots() - im = ax.imshow(data, cmap="YlGn") + fig, (axBoard, axPass) = pyplot.subplots(1, 2, gridspec_kw={'width_ratios': [9, 1]}) + imBoard = axBoard.imshow(data, cmap="YlGn") + axPass.imshow([[passChance]], cmap="YlGn", norm=imBoard.norm) - # Show all ticks and label them with the respective list entries - ax.set_xticks(numpy.arange(cols)) - ax.set_xticklabels(self._getLetterLabels(cols)) - ax.set_yticks(numpy.arange(rows)) - ax.set_yticklabels(numpy.arange(rows, 0, -1)) + # Tick and label the board + axBoard.set_xticks(numpy.arange(cols)) + axBoard.set_xticklabels(self._getLetterLabels(cols)) + axBoard.set_yticks(numpy.arange(rows)) + axBoard.set_yticklabels(numpy.arange(rows, 0, -1)) + + # Label the pass chance + axPass.set_xticks([0]) + axPass.set_yticks([]) + axPass.set_xticklabels(["Pass"]) # Loop over data dimensions and create text annotations. - textColorThreshold = 0.35 + textColorThreshold = data.max() / 2 for row in range(rows): for col in range(cols): textColor = ("k" if data[row, col] < textColorThreshold else "w") - ax.text(col, row, "%.2f"%(data[row, col]), + axBoard.text(col, row, "%.2f"%(data[row, col]), ha="center", va="center", color=textColor) - ax.set_title("Heat map of move likelihood") + textColor = ("k" if passChance < textColorThreshold else "w") + axPass.text(0, 0, "%.2f"%(passChance), + ha="center", va="center", color=textColor) + + pyplot.suptitle("Heat map of move likelihood") + #axBoard.set_title("Heat map of move likelihood") fig.tight_layout() + + #pyplot.show() pyplot.savefig("heatmaps/heatmap_%s_%s_%d.png" % ( self.NETWORK_ID, |