How to Develop an Encoder-Decoder Model with Attention for Sequence-to-Sequence Prediction in Keras
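The listing below implements a custom AttentionDecoder layer for Keras: a Bahdanau-style attention decoder built on the legacy Recurrent base class. It consumes the sequence of hidden states produced by an encoder and emits one decoded output per input timestep.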

import tensorflow as tf

from keras import backend as K

from keras import regularizers, constraints, initializers, activations

from keras.layers.recurrent import Recurrent, _time_distributed_dense
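# note: Recurrent and _time_distributed_dense exist only in older Keras
# releases (e.g. the 2.0.x series) and were removed from later versions,
# so this layer will not import under current Keras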

from keras.engine import InputSpec

 

tfPrint = lambda d, T: tf.Print(input_=T, data=[T, tf.shape(T)], message=d)
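# tfPrint is a debugging helper: evaluating the op it returns prints the
# tensor's values and shape, prefixed with the message d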

 

class AttentionDecoder(Recurrent):

 

    def __init__(self, units, output_dim,

                 activation='tanh',
                 return_probabilities=False,
                 name='AttentionDecoder',
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',

                 kernel_regularizer=None,

                 bias_regularizer=None,

                 activity_regularizer=None,

                 kernel_constraint=None,

                 bias_constraint=None,

                 **kwargs):

        """
        Implements an AttentionDecoder that takes in a sequence encoded by an
        encoder and outputs the decoded states.

        :param units: dimension of the hidden state and the attention matrices
        :param output_dim: the number of labels in the output space

        references:
            Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio.
            "Neural machine translation by jointly learning to align and translate."
            arXiv preprint arXiv:1409.0473 (2014).
        """

        self.units = units

        self.output_dim = output_dim

        self.return_probabilities = return_probabilities

        self.activation = activations.get(activation)

        self.kernel_initializer = initializers.get(kernel_initializer)

        self.recurrent_initializer = initializers.get(recurrent_initializer)

        self.bias_initializer = initializers.get(bias_initializer)

 

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
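        # note: the recurrent weights below reuse the kernel regularizer and
        # constraint, since the layer exposes no separate recurrent_* arguments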

        self.recurrent_regularizer = regularizers.get(kernel_regularizer)

        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.activity_regularizer = regularizers.get(activity_regularizer)

 

        self.kernel_constraint = constraints.get(kernel_constraint)

        self.recurrent_constraint = constraints.get(kernel_constraint)

        self.bias_constraint = constraints.get(bias_constraint)

 

        super(AttentionDecoder, self).__init__(**kwargs)

        self.name = name

        self.return_sequences = True  # must return sequences

 

    def build(self, input_shape):

        """
          See Appendix A of Bahdanau 2014, arXiv:1409.0473
          for model details that correspond to the matrices here.
        """

 

        self.batch_size, self.timesteps, self.input_dim = input_shape

 

        if self.stateful:

            super(AttentionDecoder, self).reset_states()

 

        self.states = [None, None]  # y, s

 

        """
            Matrices for creating the context vector
        """

 

        self.V_a = self.add_weight(shape=(self.units,),

                                   name='V_a',

                                   initializer=self.kernel_initializer,

                                   regularizer=self.kernel_regularizer,

                                   constraint=self.kernel_constraint)

        self.W_a = self.add_weight(shape=(self.units, self.units),

                                   name='W_a',

                                   initializer=self.kernel_initializer,

                                   regularizer=self.kernel_regularizer,

                                   constraint=self.kernel_constraint)

        self.U_a = self.add_weight(shape=(self.input_dim, self.units),

                                   name='U_a',

                                   initializer=self.kernel_initializer,

                                   regularizer=self.kernel_regularizer,

                                   constraint=self.kernel_constraint)

        self.b_a = self.add_weight(shape=(self.units,),

                                   name='b_a',

                                   initializer=self.bias_initializer,

                                   regularizer=self.bias_regularizer,

                                   constraint=self.bias_constraint)

        """
            Matrices for the r (reset) gate
        """

        self.C_r = self.add_weight(shape=(self.input_dim, self.units),

                                   name='C_r',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.U_r = self.add_weight(shape=(self.units, self.units),

                                   name='U_r',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.W_r = self.add_weight(shape=(self.output_dim, self.units),

                                   name='W_r',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.b_r = self.add_weight(shape=(self.units, ),

                                   name='b_r',

                                   initializer=self.bias_initializer,

                                   regularizer=self.bias_regularizer,

                                   constraint=self.bias_constraint)

 

        """
            Matrices for the z (update) gate
        """

        self.C_z = self.add_weight(shape=(self.input_dim, self.units),

                                   name='C_z',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.U_z = self.add_weight(shape=(self.units, self.units),

                                   name='U_z',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.W_z = self.add_weight(shape=(self.output_dim, self.units),

                                   name='W_z',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.b_z = self.add_weight(shape=(self.units, ),

                                   name='b_z',

                                   initializer=self.bias_initializer,

                                   regularizer=self.bias_regularizer,

                                   constraint=self.bias_constraint)

        """
            Matrices for the proposal
        """

        self.C_p = self.add_weight(shape=(self.input_dim, self.units),

                                   name='C_p',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.U_p = self.add_weight(shape=(self.units, self.units),

                                   name='U_p',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.W_p = self.add_weight(shape=(self.output_dim, self.units),

                                   name='W_p',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.b_p = self.add_weight(shape=(self.units, ),

                                   name='b_p',

                                   initializer=self.bias_initializer,

                                   regularizer=self.bias_regularizer,

                                   constraint=self.bias_constraint)

        """
            Matrices for making the final prediction vector
        """

        self.C_o = self.add_weight(shape=(self.input_dim, self.output_dim),

                                   name='C_o',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.U_o = self.add_weight(shape=(self.units, self.output_dim),

                                   name='U_o',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.W_o = self.add_weight(shape=(self.output_dim, self.output_dim),

                                   name='W_o',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

        self.b_o = self.add_weight(shape=(self.output_dim, ),

                                   name='b_o',

                                   initializer=self.bias_initializer,

                                   regularizer=self.bias_regularizer,

                                   constraint=self.bias_constraint)

 

        # For creating the initial state:

        self.W_s = self.add_weight(shape=(self.input_dim, self.units),

                                   name='W_s',

                                   initializer=self.recurrent_initializer,

                                   regularizer=self.recurrent_regularizer,

                                   constraint=self.recurrent_constraint)

 

        self.input_spec = [

            InputSpec(shape=(self.batch_size, self.timesteps, self.input_dim))]

        self.built = True

 

    def call(self, x):

        # store the whole sequence so we can "attend" to it at each timestep
        self.x_seq = x

        # apply a dense layer over the time dimension of the sequence
        # do it here because it doesn't depend on any previous steps and we
        # can therefore save computation time:

        self._uxpb = _time_distributed_dense(self.x_seq, self.U_a, b=self.b_a,

                                             input_dim=self.input_dim,

                                             timesteps=self.timesteps,

                                             output_dim=self.units)

 

        return super(AttentionDecoder, self).call(x)

 

    def get_initial_state(self, inputs):

        # apply the matrix on the first time step to get the initial s0.

        s0 = activations.tanh(K.dot(inputs[:, 0], self.W_s))
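        # (this mirrors Bahdanau et al., who initialise the decoder state from
        # the first encoder hidden state: s0 = tanh(W_s . h1))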

 

        # a trick borrowed from keras.layers.recurrent to initialize an
        # all-zeros tensor of shape (batchsize, output_dim):

        y0 = K.zeros_like(inputs)  # (samples, timesteps, input_dims)

        y0 = K.sum(y0, axis=(1, 2))  # (samples, )

        y0 = K.expand_dims(y0)  # (samples, 1)

        y0 = K.tile(y0, [1, self.output_dim])

 

        return [y0, s0]

 

    def step(self, x, states):

 

        ytm, stm = states

 

        # repeat the hidden state to the length of the sequence

        _stm = K.repeat(stm, self.timesteps)

 

        # now multiply the repeated hidden state by the weight matrix

        _Wxstm = K.dot(_stm, self.W_a)

 

        # calculate the attention probabilities

        # these capture how much each input timestep contributes to this output step.

        et = K.dot(activations.tanh(_Wxstm + self._uxpb),

                   K.expand_dims(self.V_a))

        at = K.exp(et)

        at_sum = K.sum(at, axis=1)

        at_sum_repeated = K.repeat(at_sum, self.timesteps)

        at /= at_sum_repeated  # tensor of shape (batchsize, timesteps, 1)
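        # together, the exponentiation and normalisation above form a softmax
        # over the timesteps: a_tj = exp(e_tj) / sum_k exp(e_tk)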

 

        # calculate the context vector

        context = K.squeeze(K.batch_dot(at, self.x_seq, axes=1), axis=1)
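        # i.e. the weighted sum c_t = sum_j a_tj * h_j over the encoded sequence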

        # ~~~> calculate new hidden state

        # first calculate the "r" gate:

 

        rt = activations.sigmoid(

            K.dot(ytm, self.W_r)

            + K.dot(stm, self.U_r)

            + K.dot(context, self.C_r)

            + self.b_r)

 

        # now calculate the "z" gate

        zt = activations.sigmoid(

            K.dot(ytm, self.W_z)

            + K.dot(stm, self.U_z)

            + K.dot(context, self.C_z)

            + self.b_z)

 

        # calculate the proposal hidden state:

        s_tp = activations.tanh(

            K.dot(ytm, self.W_p)

            + K.dot((rt * stm), self.U_p)

            + K.dot(context, self.C_p)

            + self.b_p)

 

        # new hidden state:

        st = (1 - zt) * stm + zt * s_tp

 

        yt = activations.softmax(

            K.dot(ytm, self.W_o)

            + K.dot(stm, self.U_o)

            + K.dot(context, self.C_o)

            + self.b_o)

 

        if self.return_probabilities:

            return at, [yt, st]

        else:

            return yt, [yt, st]

 

    def compute_output_shape(self, input_shape):

        """
            For Keras internal compatibility checking
        """

        if self.return_probabilities:

            return (None, self.timesteps, self.timesteps)

        else:

            return (None, self.timesteps, self.output_dim)

 

    def get_config(self):

        """
            For rebuilding models at load time.
        """
        config = {
            'output_dim': self.output_dim,
            'units': self.units,
            'return_probabilities': self.return_probabilities
        }

        base_config = super(AttentionDecoder, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))
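With the layer defined, it can be used like any other Keras layer. The sketch below is a minimal usage example, not part of the layer itself: the unit count of 150 and the variables n_timesteps_in and n_features are illustrative placeholders. Note that the encoder LSTM must set return_sequences=True so the decoder has a full sequence to attend over, and that get_config above is what allows a saved model to be reloaded, provided the custom layer is passed via custom_objects.

from keras.models import Sequential, load_model
from keras.layers import LSTM

n_timesteps_in = 5  # length of the input sequence (illustrative)
n_features = 50     # cardinality of the one-hot encoded features (illustrative)

# encoder: emit the full hidden-state sequence for the decoder to attend over
model = Sequential()
model.add(LSTM(150, input_shape=(n_timesteps_in, n_features), return_sequences=True))
model.add(AttentionDecoder(150, n_features))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])

# a custom layer must be supplied explicitly when reloading a saved model
model.save('attention_model.h5')
model = load_model('attention_model.h5',
                   custom_objects={'AttentionDecoder': AttentionDecoder})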