Neural Networks
-
class AlphaZero.network.main.Network(game_config, num_gpu=1, train_config=<path to AlphaZero/config/reinforce.yaml>, load_pretrained=False, data_format='NHWC', cluster=<default ClusterSpec>, job='main')
This class defines the network structure and its operations.
Parameters: - game_config – the rules and size of the game
- train_config – path to the YAML file that defines the network size and the training configuration.
- num_gpu – the number of GPUs used for computation.
- load_pretrained – whether to load the pre-trained model
- data_format – input format, either “NCHW” or “NHWC”. “NCHW” achieves higher performance on GPU, but it’s not compatible with CPU.
- cluster, job – the cluster specification and the job name of this process, used for distributed training.
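The two data_format options hold the same values in a different axis order. As a minimal sketch that uses plain NumPy rather than the Network class, a batch can be converted between the two layouts with a transpose; the batch size, the 17 input planes, and the 19 × 19 board are illustrative assumptions.

```python
import numpy as np

def nhwc_to_nchw(batch):
    """Move the channel axis from last position (NHWC) to second position (NCHW)."""
    return np.transpose(batch, (0, 3, 1, 2))

def nchw_to_nhwc(batch):
    """Move the channel axis from second position (NCHW) back to last position (NHWC)."""
    return np.transpose(batch, (0, 2, 3, 1))

# Illustrative sizes only: 4 positions, 17 input planes, a 19 x 19 board.
nhwc = np.zeros((4, 19, 19, 17), dtype=np.float32)
nchw = nhwc_to_nchw(nhwc)
print(nchw.shape)  # (4, 17, 19, 19)
```

A transpose only reorders axes, so converting and converting back returns the original array unchanged.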
-
update(data)
Update the model parameters.
Parameters: data – tuple (state, action, result). state is a numpy array of shape [None, filters, board_height, board_width]; action is a numpy array of shape [None, flat_move_output]; result is a numpy array of shape [None].
Returns: the average loss over the minibatch.
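A minimal sketch of packing samples into the (state, action, result) tuple that update expects. The helper make_minibatch, the sizes (17 planes, 19 × 19 board, 362 moves), and the one-hot encoding of action are hypothetical; the real sizes come from game_config. The commented last line assumes net is an already-constructed Network.

```python
import numpy as np

# Hypothetical sizes for 19 x 19 Go; the real values come from game_config.
FILTERS, HEIGHT, WIDTH = 17, 19, 19
FLAT_MOVE_OUTPUT = HEIGHT * WIDTH + 1  # board points plus a pass move (assumption)

def make_minibatch(states, move_indices, results):
    """Pack samples into a (state, action, result) tuple (hypothetical helper).

    states: arrays shaped [FILTERS, HEIGHT, WIDTH]
    move_indices: the move played in each position, as a flat index
    results: game outcomes, e.g. +1 for a win and -1 for a loss
    """
    state = np.stack(states).astype(np.float32)
    action = np.zeros((len(move_indices), FLAT_MOVE_OUTPUT), dtype=np.float32)
    action[np.arange(len(move_indices)), move_indices] = 1.0  # one-hot move targets
    result = np.asarray(results, dtype=np.float32)
    # Check the shapes documented for update().
    assert state.shape[1:] == (FILTERS, HEIGHT, WIDTH)
    assert result.shape == (state.shape[0],)
    return state, action, result

batch = make_minibatch(
    [np.zeros((FILTERS, HEIGHT, WIDTH)) for _ in range(2)],
    move_indices=[0, 361],
    results=[1.0, -1.0],
)
# loss = net.update(batch)  # net: an already-constructed Network (assumption)
```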
-
response(data)
Predict the move probabilities and the expected outcome for the given states.
Parameters: data – a one-element tuple (state,). state is a numpy array of shape [None, filters, board_height, board_width].
Returns: a tuple (R_p, R_v). R_p is the probability distribution over moves, a numpy array of shape [None, 362] (361 board points plus pass for 19 × 19 Go). R_v is the expected value of each state, a numpy array of shape [None].
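R_p is a flat vector per state, so turning it into a board move requires an index convention. The sketch below assumes, without confirmation from this page, that indices 0 to 360 are board points in row-major order and index 361 is pass; best_moves is a hypothetical helper, and the commented response call assumes an already-constructed Network named net.

```python
import numpy as np

BOARD_SIZE = 19  # assumption: 19 x 19 Go, so R_p has 19 * 19 + 1 = 362 columns

def best_moves(r_p, legal_mask=None):
    """Pick the most probable move in each row of R_p (hypothetical helper).

    Assumes indices 0..360 are board points in row-major order and 361 is pass.
    legal_mask, if given, is a boolean array shaped like r_p; False entries are skipped.
    Returns a list whose items are (row, col) tuples or the string "pass".
    """
    probs = np.asarray(r_p, dtype=np.float64)
    if legal_mask is not None:
        probs = np.where(legal_mask, probs, -np.inf)  # illegal moves can never win argmax
    moves = []
    for idx in np.argmax(probs, axis=1):
        if idx == BOARD_SIZE * BOARD_SIZE:
            moves.append("pass")
        else:
            moves.append(divmod(int(idx), BOARD_SIZE))  # flat index -> (row, col)
    return moves

# r_p, r_v = net.response((state,))  # net: an already-constructed Network (assumption)
r_p = np.zeros((2, 362))
r_p[0, 20] = 1.0   # row 1, column 1
r_p[1, 361] = 1.0  # pass
print(best_moves(r_p))  # [(1, 1), 'pass']
```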
-
evaluate(data)
Compute the loss and prediction metrics on supervised data.
Parameters: data – tuple (state, action, result). state is a numpy array of shape [None, filters, board_height, board_width]; action is a numpy array of shape [None, flat_move_output]; result is a numpy array of shape [None].
Returns: a tuple (loss, acc, mse). loss is the average loss over the minibatch, acc is the accuracy of the predicted move position, and mse is the mean squared error of the predicted game outcome.
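One plausible reading of the acc and mse values returned by evaluate, written as NumPy helpers: acc as the fraction of positions whose highest-probability move matches the labelled move, and mse as the mean squared difference between predicted values and game results. The helper names and this exact definition of acc are assumptions, not taken from the implementation.

```python
import numpy as np

def move_accuracy(pred_p, action):
    """Fraction of rows whose most probable predicted move equals the labelled move."""
    return float(np.mean(np.argmax(pred_p, axis=1) == np.argmax(action, axis=1)))

def outcome_mse(pred_v, result):
    """Mean squared error between predicted values and game outcomes."""
    pred_v = np.asarray(pred_v, dtype=np.float64)
    result = np.asarray(result, dtype=np.float64)
    return float(np.mean((pred_v - result) ** 2))

action = np.eye(3)[[0, 2]]            # labelled moves: index 0, then index 2
pred_p = np.array([[0.7, 0.2, 0.1],
                   [0.5, 0.3, 0.2]])  # first row correct, second row wrong
print(move_accuracy(pred_p, action))           # 0.5
print(outcome_mse([0.5, -1.0], [1.0, -1.0]))   # 0.125
```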
-
get_global_step()
Return the current global training step.
-
save(filename)
Save the model.
Parameters: filename – prefix of the saved file. The global step is appended to it to form the final name (filename + global_step).
-
load(filename)
Load the model.
Parameters: filename – the name of the saved file.
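Because save appends the global step to the prefix while load takes the full saved name, a caller resuming training needs to find the most recent name. A minimal sketch, assuming the names come from a directory listing with any file extensions already removed and that the step follows the prefix either directly or after a "-"; latest_checkpoint is a hypothetical helper.

```python
import re

def latest_checkpoint(names, prefix):
    """Return the saved name with the highest global step, or None (hypothetical helper).

    Assumes each name is the prefix followed by the step, optionally separated by "-".
    """
    pattern = re.compile(re.escape(prefix) + r"-?(\d+)$")
    best_name, best_step = None, -1
    for name in names:
        match = pattern.match(name)
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name

saved = ["model-100", "model-2000", "model-300", "other-9999"]
print(latest_checkpoint(saved, "model"))  # model-2000
```

Comparing the parsed integers, rather than sorting the strings, avoids "model-300" sorting after "model-2000".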
-
class AlphaZero.network.model.Model(game_config, train_config, data_format='NHWC')
Neural network for AlphaGo Zero, following the architecture described in “Mastering the game of Go without human knowledge” (Silver et al., 2017).
Parameters: - game_config – the rules and size of the game
- train_config – defines the network size and the training configuration.
- data_format – input format, either “NCHW” or “NHWC”.
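For reference, the cited paper trains the network f_θ, which maps a position s to move probabilities p and a value v, by minimising the loss below, where π are the search probabilities, z is the game outcome, and c controls L2 weight regularisation. Whether this Model uses exactly this objective, and how it sets c, are assumptions to check against train_config.

```latex
(\mathbf{p}, v) = f_\theta(s), \qquad
l = (z - v)^2 - \boldsymbol{\pi}^{\top} \log \mathbf{p} + c \lVert \theta \rVert^2
```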