
LSTM with Keras functional API (2)

Train the model – Part 1

After defining the model (see LSTM with Keras functional API (1)), we can create an instance of the MyModel class, compile it, and fit it to the dataset, just as we would with a Sequential model:

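# assumes TensorFlow 2.x; MyModel and the arrays x, y come from part (1)
import tensorflow as tf

# build an instance of the MyModel class
model = MyModel()

# compile the instance: attach the optimizer, loss, and metrics
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# fit the model to the dataset
model.fit(x, y, epochs=100, batch_size=4)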
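For readers who want to run the snippet on its own, here is a self-contained sketch. The MyModel class below is a hypothetical stand-in (one LSTM layer followed by a sigmoid Dense layer), not the definition from part (1), and the x, y arrays are random toy data invented for illustration. It assumes TensorFlow 2.x with NumPy installed.

```python
import numpy as np
import tensorflow as tf


class MyModel(tf.keras.Model):
    """Hypothetical stand-in for the MyModel class from part (1)."""

    def __init__(self):
        super().__init__()
        self.lstm = tf.keras.layers.LSTM(10)  # 10 units: an arbitrary choice
        self.out = tf.keras.layers.Dense(1, activation='sigmoid')  # one probability per sequence

    def call(self, inputs):
        return self.out(self.lstm(inputs))


# Toy data (invented): 20 sequences of 5 time steps with 1 feature each,
# labelled 1 when the sequence sum exceeds 2.5.
rng = np.random.default_rng(0)
x = rng.random((20, 5, 1)).astype('float32')
y = (x.sum(axis=(1, 2)) > 2.5).astype('float32').reshape(-1, 1)

model = MyModel()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Only 2 epochs here so the sketch runs quickly; the snippet above uses 100.
history = model.fit(x, y, epochs=2, batch_size=4, verbose=0)
```

The returned History object records one value per epoch, so history.history['loss'] has as many entries as the epochs argument.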
           

compile: Once we have defined our network, we must compile it. In Keras, compilation configures the model for training by attaching the optimizer, the loss function, and the metrics. It is also an efficiency step: depending on the backend, the simple sequence of layers we defined can be transformed into an efficient series of matrix transformations executed on the GPU or CPU.

Compilation requires several parameters tailored to training your network. The most common are:

  • optimizer: 'sgd', 'adam', 'rmsprop', …
  • loss: 'mean_squared_error' ('mse'), 'binary_crossentropy', 'categorical_crossentropy', …
  • metrics: ['accuracy'], …
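The string identifiers above name standard formulas. As a rough pure-Python sketch (not Keras' own code, which operates on tensors), the losses and a binary accuracy metric can be written as follows. The clipping constant EPSILON = 1e-7 is an assumption chosen to avoid log(0), and binary_accuracy assumes a 0.5 threshold on predicted probabilities.

```python
import math

EPSILON = 1e-7  # assumption: clip probabilities away from 0 and 1 before taking logs


def mean_squared_error(y_true, y_pred):
    """'mse': average of squared differences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def binary_crossentropy(y_true, y_pred):
    """Average negative log-likelihood of 0/1 targets under predicted probabilities."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, EPSILON), 1 - EPSILON)  # avoid log(0)
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(y_true)


def categorical_crossentropy(y_true, y_pred):
    """Average over samples of -sum_k t_k * log(p_k) for one-hot targets."""
    total = 0.0
    for t_row, p_row in zip(y_true, y_pred):
        total -= sum(t * math.log(min(max(p, EPSILON), 1.0)) for t, p in zip(t_row, p_row))
    return total / len(y_true)


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions that fall on the correct side of the threshold."""
    hits = sum(1 for t, p in zip(y_true, y_pred) if (p > threshold) == (t == 1))
    return hits / len(y_true)
```

For targets [1, 0, 1, 0] and predictions [0.9, 0.2, 0.6, 0.4], mean_squared_error gives 0.0925, binary_crossentropy about 0.3375, and binary_accuracy 1.0, because every prediction lands on the correct side of 0.5.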

fit: Once the network is compiled, it can be fit, which means adapting its weights on a training dataset. Fitting requires the training data to be specified: both the input patterns x and the output patterns y. The network is then trained with the backpropagation algorithm, and its weights are updated according to the optimization algorithm and loss function specified at compilation.
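To make "adapting the weights" concrete, here is a minimal pure-Python sketch, not how Keras works internally. It assumes a single sigmoid neuron with weight w and bias b (hypothetical names), binary cross-entropy loss, and plain full-batch gradient descent with a learning rate of 0.5 in place of Adam. With only one layer, backpropagation reduces to applying the chain rule once; Keras applies the same idea automatically through every layer, including through time for an LSTM.

```python
import math


def sigmoid(z):
    """Logistic function, squashing any real z into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; eps keeps log() away from 0."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(y_true)


def train(xs, ys, epochs=200, lr=0.5):
    """Full-batch gradient descent on one neuron p = sigmoid(w*x + b)."""
    w, b = 0.0, 0.0
    losses = []
    n = len(xs)
    for _ in range(epochs):
        # forward pass over the whole dataset
        preds = [sigmoid(w * x + b) for x in xs]
        losses.append(bce(ys, preds))
        # backward pass: for sigmoid + BCE, dLoss/dz = p - y (chain rule)
        grad_w = sum((p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        grad_b = sum(p - y for p, y in zip(preds, ys)) / n
        # update step (plain gradient descent, not Adam)
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b, losses
```

On toy inputs xs = [-2, -1, 1, 2] with labels [0, 0, 1, 1], the recorded loss starts at ln 2 (about 0.693, since every prediction is 0.5 when w = b = 0) and falls every epoch as w grows.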

Backpropagation requires the network to be trained for a specified number of epochs, where one epoch is one pass over the full training dataset. Each epoch is partitioned into groups of input-output patterns called batches. The batch size defines the number of patterns the network is exposed to before the weights are updated within an epoch.

  • epochs: LSTMs may be trained for tens, hundreds, or thousands of epochs.
  • batch_size: in mini-batch gradient descent, common values are 32, 64, and 128, chosen to trade off computational efficiency against the rate of weight updates.
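The relationship between epochs, batches, and weight updates can be sketched in plain Python. make_batches and updates_per_run are hypothetical helpers written for illustration. They mimic two behaviors of Keras' fit: the data is reshuffled before each epoch (shuffle=True is the default), and a final, smaller batch is kept when the dataset size is not a multiple of batch_size.

```python
import math
import random


def make_batches(n_samples, batch_size, shuffle=True, seed=None):
    """Split sample indices 0..n_samples-1 into batches for one epoch.

    The last batch may be smaller than batch_size.
    """
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new order each epoch in practice
    return [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]


def updates_per_run(n_samples, batch_size, epochs):
    """One weight update per batch, so updates = epochs * ceil(n / batch_size)."""
    return epochs * math.ceil(n_samples / batch_size)
```

For instance, a hypothetical dataset of 10 samples with batch_size=4 gives batches of 4, 4, and 2 patterns, so epochs=100 means 300 weight updates. With 1000 samples, batch sizes of 32, 64, and 128 give 32, 16, and 8 updates per epoch: larger batches mean fewer, but individually more expensive, updates.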
