aboutsummaryrefslogtreecommitdiff
path: root/imago/engine/keras/neuralNetwork.py
diff options
context:
space:
mode:
Diffstat (limited to 'imago/engine/keras/neuralNetwork.py')
-rw-r--r--imago/engine/keras/neuralNetwork.py63
1 files changed, 42 insertions, 21 deletions
diff --git a/imago/engine/keras/neuralNetwork.py b/imago/engine/keras/neuralNetwork.py
index d0eb4ae..7eddb9d 100644
--- a/imago/engine/keras/neuralNetwork.py
+++ b/imago/engine/keras/neuralNetwork.py
@@ -81,7 +81,7 @@ class NeuralNetwork:
def _boardToPlayerContext(self, board, player):
"""Converts the board to a 3D matrix with two representations of the board, one
- marking the player's stones and the oter marking the opponent's stones."""
+ marking the player's stones and the other marking the opponent's stones."""
boardRows = len(board)
boardCols = len(board[0])
contextBoard = numpy.zeros((boardRows, boardCols, 2), dtype = float)
@@ -95,14 +95,19 @@ class NeuralNetwork:
return contextBoard
def _movesToTargets(self, moves):
- """Converts the moves to 2D matrices with values zero except for a one on the
- played vertex."""
+ """Converts the moves to 2D matrices with values zero except for a one indicating
+ the played move."""
targets = []
+ targetsSize = self.boardSize * self.boardSize + 1 # Each vertex + 1 for pass
for move in moves:
if len(move.nextMoves) == 0:
continue
- target = numpy.zeros(self.boardSize * self.boardSize, dtype = float)
- target[move.nextMoves[0].getRow() * self.boardSize + move.nextMoves[0].getCol()] = 1
+ target = numpy.zeros(targetsSize, dtype = float)
+ nextMove = move.nextMoves[0]
+ if nextMove.isPass:
+ target[-1] = 1
+ else:
+ target[nextMove.getRow() * self.boardSize + nextMove.getCol()] = 1
targets.append(target.tolist())
return targets
@@ -110,11 +115,12 @@ class NeuralNetwork:
"""Uses the model's predict function to pick the highest valued vertex to play."""
predictionVector = self._predict(gameMove, player)[0]
- prediction = numpy.zeros((self.boardSize, self.boardSize))
+ predictionBoard = numpy.zeros((self.boardSize, self.boardSize))
for row in range(self.boardSize):
for col in range(self.boardSize):
- prediction[row][col] = predictionVector[row * self.boardSize + col]
- self.saveHeatmap(prediction)
+ predictionBoard[row][col] = predictionVector[row * self.boardSize + col]
+ predictionPass = predictionVector[-1]
+ self.saveHeatmap(predictionBoard, predictionPass)
# Search the highest valued vertex which is also playable
playableVertices = gameMove.getPlayableVertices()
@@ -123,11 +129,13 @@ class NeuralNetwork:
hCol = -1
for row in range(self.boardSize):
for col in range(self.boardSize):
- if prediction[row][col] > highest and (row, col) in playableVertices:
+ if predictionBoard[row][col] > highest and (row, col) in playableVertices:
hRow = row
hCol = col
- highest = prediction[row][col]
+ highest = predictionBoard[row][col]
+ if highest < predictionPass:
+ return "pass"
return [hRow, hCol]
def _predict(self, gameMove, player):
@@ -139,29 +147,42 @@ class NeuralNetwork:
batch_size = 1,
verbose = 2)
- def saveHeatmap(self, data):
+ def saveHeatmap(self, data, passChance):
rows = len(data)
cols = len(data[0])
- fig, ax = pyplot.subplots()
- im = ax.imshow(data, cmap="YlGn")
+ fig, (axBoard, axPass) = pyplot.subplots(1, 2, gridspec_kw={'width_ratios': [9, 1]})
+ imBoard = axBoard.imshow(data, cmap="YlGn")
+ axPass.imshow([[passChance]], cmap="YlGn", norm=imBoard.norm)
- # Show all ticks and label them with the respective list entries
- ax.set_xticks(numpy.arange(cols))
- ax.set_xticklabels(self._getLetterLabels(cols))
- ax.set_yticks(numpy.arange(rows))
- ax.set_yticklabels(numpy.arange(rows, 0, -1))
+ # Tick and label the board
+ axBoard.set_xticks(numpy.arange(cols))
+ axBoard.set_xticklabels(self._getLetterLabels(cols))
+ axBoard.set_yticks(numpy.arange(rows))
+ axBoard.set_yticklabels(numpy.arange(rows, 0, -1))
+
+ # Label the pass chance
+ axPass.set_xticks([0])
+ axPass.set_yticks([])
+ axPass.set_xticklabels(["Pass"])
# Loop over data dimensions and create text annotations.
- textColorThreshold = 0.35
+ textColorThreshold = data.max() / 2
for row in range(rows):
for col in range(cols):
textColor = ("k" if data[row, col] < textColorThreshold else "w")
- ax.text(col, row, "%.2f"%(data[row, col]),
+ axBoard.text(col, row, "%.2f"%(data[row, col]),
ha="center", va="center", color=textColor)
- ax.set_title("Heat map of move likelihood")
+ textColor = ("k" if passChance < textColorThreshold else "w")
+ axPass.text(0, 0, "%.2f"%(passChance),
+ ha="center", va="center", color=textColor)
+
+ pyplot.suptitle("Heat map of move likelihood")
+ #axBoard.set_title("Heat map of move likelihood")
fig.tight_layout()
+
+ #pyplot.show()
pyplot.savefig("heatmaps/heatmap_%s_%s_%d.png" %
(
self.NETWORK_ID,