Neural Networks

class AlphaZero.network.main.Network(game_config, num_gpu=1, train_config='AlphaZero/config/reinforce.yaml', load_pretrained=False, data_format='NHWC', cluster=ClusterSpec(), job='main')

This class defines the network structure and its operations.

Parameters:
  • game_config – the rules and size of the game.
  • train_config – the network size and the model-training configuration.
  • num_gpu – the number of GPUs used for computation.
  • load_pretrained – whether to load the pre-trained model.
  • data_format – input format, either “NCHW” or “NHWC”. “NCHW” is faster on GPU but is not supported on CPU.
  • cluster, job – the cluster specification and job name used for distributed training.
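To make the data_format choice concrete, the sketch below computes where the element for batch item n, feature plane c, and board point (h, w) sits in a flat buffer under each layout, and reorders a channel-first buffer into channel-last order. It is illustrative plain Python that does not call this package; the helper names are hypothetical. Inside a TensorFlow graph the same conversion is a single tf.transpose(x, [0, 2, 3, 1]). Note that the method docstrings below give state the shape [None, filters, board_height, board_width], which is channel-first order.

```python
def offset_nchw(n, c, h, w, C, H, W):
    """Flat offset of logical element (n, c, h, w) in a contiguous NCHW buffer."""
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n, c, h, w, C, H, W):
    """Flat offset of the same logical element in a contiguous NHWC buffer."""
    return ((n * H + h) * W + w) * C + c


def nchw_to_nhwc(flat, N, C, H, W):
    """Reorder a flat NCHW buffer into NHWC order (a pure-Python transpose)."""
    out = [None] * (N * C * H * W)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    src = offset_nchw(n, c, h, w, C, H, W)
                    dst = offset_nhwc(n, c, h, w, C, H, W)
                    out[dst] = flat[src]
    return out
```

In NHWC all feature planes for one board point are adjacent in memory, whereas in NCHW each plane is stored contiguously. Which layout is faster depends on the convolution kernels available on the device.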
update(data)

Update the model parameters.

Parameters: data – a tuple (state, action, result), where state is a numpy array of shape [None, filters, board_height, board_width], action is a numpy array of shape [None, flat_move_output], and result is a numpy array of shape [None].
Returns: The average loss over the minibatch.
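The docstring only states that update returns the average minibatch loss. Assuming the model optimizes the loss from the AlphaGo Zero paper, l = (z − v)² − πᵀ log p + c‖θ‖², and assuming action plays the role of the target distribution π and result the role of the outcome z, the data-dependent part of that loss could be computed as below. This is a plain-Python sketch with a hypothetical function name, not the package's implementation; the weight-decay term is left out because it depends on the network weights rather than on the minibatch.

```python
import math


def alphazero_loss(policy_probs, values, target_pis, outcomes, eps=1e-12):
    """Mean over the minibatch of (z - v)^2 - pi . log(p).

    policy_probs: predicted move distributions p, one list per position
    values:       predicted values v
    target_pis:   target move distributions pi (e.g. MCTS visit frequencies)
    outcomes:     game results z
    The L2 term c * ||theta||^2 is omitted (it depends on the weights).
    """
    total = 0.0
    for p, v, pi, z in zip(policy_probs, values, target_pis, outcomes):
        value_loss = (z - v) ** 2
        # Cross-entropy between target and prediction; eps guards against log(0).
        policy_loss = -sum(t * math.log(q + eps) for t, q in zip(pi, p))
        total += value_loss + policy_loss
    return total / len(values)
```

For example, a batch with one uncertain, wrong-valued position and one perfectly predicted position averages to roughly (1 + ln 2) / 2.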
response(data)

Predict the move probabilities and the expected outcome for the given states.

Parameters: data – a tuple (state,), where state is a numpy array of shape [None, filters, board_height, board_width].
Returns: A tuple (R_p, R_v). R_p is the probability distribution over actions, a numpy array of shape [None, 362] (19 × 19 board points plus pass). R_v is the expected value of the current state, a numpy array of shape [None].
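One row of R_p has 362 = 19 × 19 + 1 entries. The sketch below turns such a row into a move, assuming (hypothetically; the docstring does not specify the ordering) that indices 0 to 360 enumerate board points in row-major order and index 361 is pass. Greedy selection is for illustration only: in AlphaGo Zero these probabilities serve as priors for Monte Carlo tree search rather than being played directly. The helper names are hypothetical.

```python
def index_to_move(index, board_size=19):
    """Map a flat policy index to (row, col), or None for pass.

    Assumes row-major ordering with the pass move last (hypothetical).
    """
    if index == board_size * board_size:
        return None
    return divmod(index, board_size)


def pick_move(probs, board_size=19, legal=None):
    """Greedy choice from one row of R_p, optionally restricted to legal indices."""
    candidates = range(len(probs)) if legal is None else legal
    best = max(candidates, key=lambda i: probs[i])
    return index_to_move(best, board_size)
```

Passing a legal list lets the caller mask out occupied points or suicide moves, which the raw network output does not exclude.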
evaluate(data)

Compute the loss and prediction metrics on supervised data.

Parameters: data – a tuple (state, action, result), where state is a numpy array of shape [None, filters, board_height, board_width], action is a numpy array of shape [None, flat_move_output], and result is a numpy array of shape [None].
Returns: A tuple (loss, acc, mse). loss is the average loss over the minibatch, acc is the accuracy of the predicted move position, and mse is the mean squared error of the predicted game outcome.
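The acc and mse metrics can be sketched in plain Python as follows. This assumes acc counts a position as correct when the highest-probability predicted move equals the highest-probability target move, which is a common definition but is not confirmed by the docstring; the function names are hypothetical.

```python
def move_accuracy(pred_probs, target_probs):
    """Fraction of positions whose top predicted move matches the top target move."""
    hits = 0
    for p, t in zip(pred_probs, target_probs):
        pred_move = max(range(len(p)), key=p.__getitem__)
        target_move = max(range(len(t)), key=t.__getitem__)
        hits += pred_move == target_move
    return hits / len(target_probs)


def outcome_mse(values, outcomes):
    """Mean squared error between predicted values and actual game results."""
    return sum((v - z) ** 2 for v, z in zip(values, outcomes)) / len(outcomes)
```

With one correct and one incorrect top move, move_accuracy returns 0.5.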
get_global_step()

Return the current value of the global step counter.

save(filename)

Save the model.

Parameters: filename – prefix of the saved file. The final file name is filename followed by the current global step.
load(filename)

Load the model.

Parameters: filename – the name of the saved file.
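Because save appends the global step to the prefix while load expects the full saved name, a caller has to recover that name. The hypothetical helpers below sketch this from a list of file names. They assume the step is appended as decimal digits, optionally after a hyphen (TensorFlow's checkpoint saver, for instance, inserts one), and they ignore any file extension after the step. Steps are compared numerically so that step 2000 is newer than step 300.

```python
import re


def checkpoint_name(prefix, global_step):
    """Name produced by save(prefix) as the docstring describes it: prefix + step."""
    return f"{prefix}{global_step}"


def latest_checkpoint(names, prefix):
    """Return the prefix+step name with the highest step, or None if none match.

    Hypothetical helper: accepts an optional hyphen before the step and
    drops any extension that follows it.
    """
    pattern = re.compile(re.escape(prefix) + r"-?(\d+)")
    best = None
    for name in names:
        m = pattern.match(name)
        if m and (best is None or int(m.group(1)) > best[0]):
            best = (int(m.group(1)), m.group(0))
    return None if best is None else best[1]
```

The returned name could then be passed to load, provided it matches whatever naming the actual save implementation produces.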
class AlphaZero.network.model.Model(game_config, train_config, data_format='NHWC')

Neural network for AlphaGo Zero, as described in “Mastering the game of Go without human knowledge”.

Parameters:
  • game_config – the rules and size of the game.
  • train_config – defines the size of the network and configurations in model training.
  • data_format – input format, either “NCHW” or “NHWC”.
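Since train_config determines the network size, it can help to estimate how the parameter count scales. The sketch follows the architecture described in the AlphaGo Zero paper: a 3 × 3 input convolution, a tower of residual blocks with two 3 × 3 convolutions each, a policy head (1 × 1 convolution with 2 filters, then a fully connected layer to board_size² + 1 outputs), and a value head (1 × 1 convolution with 1 filter, a 256-unit hidden layer, then a scalar output). The bookkeeping is an assumption, not taken from this package: convolutions have no bias, each batch-norm layer has 2 trainable parameters per channel, and fully connected layers have biases. Which keys in the training config set filters and blocks is not shown here, and the function name is hypothetical.

```python
def count_params(board_size=19, in_planes=17, filters=256, blocks=19):
    """Approximate trainable-parameter count of an AlphaGo Zero style network."""

    def conv_bn(k, c_in, c_out):
        # k x k convolution (no bias) followed by batch norm (gamma and beta)
        return k * k * c_in * c_out + 2 * c_out

    def dense(n_in, n_out):
        # fully connected layer with bias
        return n_in * n_out + n_out

    points = board_size * board_size
    total = conv_bn(3, in_planes, filters)               # input convolutional block
    total += blocks * 2 * conv_bn(3, filters, filters)   # residual tower
    total += conv_bn(1, filters, 2) + dense(2 * points, points + 1)   # policy head
    total += conv_bn(1, filters, 1) + dense(points, 256) + dense(256, 1)  # value head
    return total
```

Each additional residual block adds 2 × (9 × filters² + 2 × filters) parameters, so the tower dominates the total for the paper's 19- and 39-block configurations.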