FluxArchitectures: DA-RNN

By: Sören Dobberschütz

Re-posted from: http://sdobber.github.io/FA_DARNN/

The next model in the FluxArchitectures repository is the “Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction”, based on the paper by Qin et al., 2017. It claims better performance than the previously implemented LSTNet. It also has an additional advantage: an attention mechanism automatically tries to identify the important parts of the time series, so the user does not have to optimize extra parameters for this.

Model Architecture

The neural network has a rather complex structure. It is built on an encoder-decoder architecture and consists of two units: an input attention mechanism and a temporal attention mechanism.

  • The input attention mechanism feeds the input data to an LSTM encoder. Additional network layers use the encoder's hidden and cell state to estimate the importance of the different input features at each time step. Only the encoder's hidden states are passed on to the subsequent calculations.

  • The temporal attention mechanism combines the hidden states of the encoder with the hidden state of a second LSTM network, the decoder. Additional network layers estimate the importance of the encoder hidden states at the different time steps.

  • Linear layers combine the outputs of the different layers into the final time series prediction.
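Both attention mechanisms follow the same pattern. A small network produces a score for each candidate, where a candidate is an input feature in the encoder and a time step in the decoder. A softmax turns the scores into weights that sum to one, and these weights scale or combine the candidates. The following plain-Julia sketch shows that pattern. The helper `softmax_cols` is hypothetical and stands in for `Flux.softmax`, which also normalizes over the first dimension by default.

```julia
# Hypothetical helper: numerically stable softmax over the first dimension,
# standing in for Flux.softmax (which also defaults to dims=1).
function softmax_cols(a::AbstractMatrix)
    e = exp.(a .- maximum(a; dims=1))
    return e ./ sum(e; dims=1)
end

# 3 candidates (rows) for 2 samples (columns)
scores = [1.0 0.0;
          2.0 0.0;
          3.0 0.0]
candidates = [10.0 10.0;
              20.0 20.0;
              30.0 30.0]

weights = softmax_cols(scores)                      # each column sums to one
combined = vec(sum(weights .* candidates; dims=1))  # weighted sum per sample
```

In the second sample, all scores are equal, so the weights are uniform and the result is the plain average of 20. In the first sample, the higher scores pull the result towards the later candidates.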

Our implementation follows an existing PyTorch implementation of the model. We start by creating a struct to hold all the necessary elements:

using Flux, NNlib  # Flux for the layers, NNlib for batched_mul

mutable struct DARNNCell{A, B, C, D, E, F, W, X, Y, Z}
    # Encoder part
    encoder_lstm::A
    encoder_attn::B
    # Decoder part
    decoder_lstm::C
    decoder_attn::D
    decoder_fc::E
    decoder_fc_final::F
    # Index for original data etc
    encodersize::W
    decodersize::X
    orig_idx::Y
    poollength::Z
end

In addition to the layers needed for the DA-RNN network, we store some metadata required for the calculations. This includes the sizes of the encoder and decoder networks and the index orig_idx, which specifies where the original time series can be found in the input data. It also includes the number of pooled time steps in the input data (corresponding to T in the following picture).

The constructor initializes all layers with their correct size:

function DARNN(inp::Integer, encodersize::Integer, decodersize::Integer, poollength::Integer, orig_idx::Integer)
    # Encoder part
    encoder_lstm = LSTM(inp, encodersize)
    encoder_attn = Chain(Dense(2*encodersize + poollength, poollength),
                         a -> tanh.(a),
                         Dense(poollength, 1))
    # Decoder part
    decoder_lstm = LSTM(1, decodersize)
    decoder_attn = Chain(Dense(2*decodersize + encodersize, encodersize),
                         a -> tanh.(a),
                         Dense(encodersize, 1))
    decoder_fc = Dense(encodersize + 1, 1)
    decoder_fc_final = Dense(decodersize + encodersize, 1)

    return DARNNCell(encoder_lstm, encoder_attn, decoder_lstm, decoder_attn, decoder_fc,
                     decoder_fc_final, encodersize, decodersize, orig_idx, poollength)
end
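The layer sizes in the constructor follow from what each Dense layer receives as input. The following sketch lists them using a hypothetical helper, `darnn_input_sizes`, which is not part of the package. The example values encodersize = 10, decodersize = 8 and poollength = 20 are chosen only for illustration.

```julia
# Hypothetical helper, not part of FluxArchitectures: input sizes of the Dense
# layers built in DARNN, derived from what each layer receives.
function darnn_input_sizes(encodersize::Integer, decodersize::Integer, poollength::Integer)
    return (
        # Eq. (8): [encoder hidden; encoder cell; one feature over all pooled steps]
        encoder_attn = 2 * encodersize + poollength,
        # Eq. (12): [decoder hidden; decoder cell; one encoder hidden state]
        decoder_attn = 2 * decodersize + encodersize,
        # Eq. (15): [context vector; value of the original series]
        decoder_fc = encodersize + 1,
        # Eq. (22): [final decoder hidden state; final context vector]
        decoder_fc_final = decodersize + encodersize,
    )
end

sizes = darnn_input_sizes(10, 8, 20)
```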

Encoder Network

Figure: Model structure of the encoder.

Image from Qin et al., “Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction”, arXiv, 2017.

We use the same input data format as for the previous LSTNet model, i.e. “Number of input features x Number of pooled timesteps x 1 x Number of data points”. Before feeding the data to the encoder, we drop the singleton dimension: input_data = dropdims(x; dims=3).
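The following plain-Julia sketch shows the effect of dropping the singleton dimension and how time steps and the original series are then indexed. The sizes are hypothetical. So is `orig_idx = 1`, which assumes the original series is stored as the first feature.

```julia
# Hypothetical sizes: 5 input features, 20 pooled time steps, 64 data points
nfeatures, poollength, datapoints = 5, 20, 64
orig_idx = 1  # hypothetical: the original series is stored as the first feature

# Layout used by the model: features x pooled timesteps x 1 x data points
x = rand(Float32, nfeatures, poollength, 1, datapoints)

# Drop the singleton third dimension before feeding the encoder
input_data = dropdims(x; dims=3)             # features x timesteps x data points

t = 3
inputs_at_t = input_data[:, t, :]            # all features at step t: features x data points
original_at_t = input_data[orig_idx, t, :]   # original series at step t: one value per data point
```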

The encoder loops over the pooled timesteps and rescales the input data at each step. It extracts the hidden state and cell state of the encoder LSTM layer, concatenates them with the input data and feeds the result to the attention network. A softmax function then yields the scaling of the input data for timestep t, and the scaled data is fed to the LSTM network. In the following code, we indicate the equation numbers from the paper cited in the introduction.

function darnn_encoder(m::DARNNCell, input_data)
    # Collects the encoder hidden states: encodersize x poollength x datapoints
    input_encoded = Flux.Zygote.Buffer(input_data, m.encodersize, m.poollength, size(input_data,3))

    for t in 1:m.poollength
        hidden = m.encoder_lstm.state[1]
        cell = m.encoder_lstm.state[2]

        # Eq. (8)
        x = cat(repeat(hidden, inner=(1,1,size(input_data,1))),
                repeat(cell, inner=(1,1,size(input_data,1))),
                permutedims(input_data,[2,3,1]), dims=1) |>  # (2*encodersize + poollength) x datapoints x features
            a -> reshape(a, (:, size(input_data,1)*size(input_data,3))) |>  # (2*encodersize + poollength) x (datapoints * features)
            m.encoder_attn  # 1 x (datapoints * features), data point index varies fastest

        # Eq. (9): reshape to datapoints x features, transpose, softmax over the features
        attn_weights = Flux.softmax(permutedims(reshape(x, (size(input_data,3), size(input_data,1)))))  # features x datapoints
        # Eq. (10)
        weighted_input = attn_weights .* input_data[:,t,:]  # features x datapoints
        # Eq. (11)
        _ = m.encoder_lstm(weighted_input)

        input_encoded[:,t,:] = Flux.unsqueeze(m.encoder_lstm.state[1],2)  # encodersize x 1 x datapoints
    end
    return copy(input_encoded)
end

Zygote does not support mutating arrays, so input_encoded is created as a Zygote.Buffer structure. This allows the assignments inside the loop while keeping the code trainable by Flux. At the end, we return copy(input_encoded).
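Stripped of the batching over data points, one step of the encoder loop boils down to Eqs. (8) to (10) for a single data point. The following plain-Julia sketch is not the package code. The random matrices `W1` and `v` are hypothetical stand-ins for the two Dense layers of `encoder_attn`, biases are omitted, and the sizes are arbitrary.

```julia
using LinearAlgebra  # for dot

# Sketch of Eqs. (8)-(10) for ONE data point. W1 and v are hypothetical
# stand-ins for the two Dense layers of encoder_attn (biases omitted).
encodersize, poollength, nfeatures = 4, 6, 3

h = randn(encodersize)             # encoder LSTM hidden state
s = randn(encodersize)             # encoder LSTM cell state
X = randn(nfeatures, poollength)   # input features x pooled time steps
t = 2                              # current pooled time step

W1 = randn(poollength, 2encodersize + poollength)  # first layer, followed by tanh
v = randn(poollength)                                # second layer, one output

# Eq. (8): one score per feature k from [hidden; cell; k-th series over all time steps]
scores = [dot(v, tanh.(W1 * vcat(h, s, X[k, :]))) for k in 1:nfeatures]

# Eq. (9): softmax over the features
alpha = exp.(scores .- maximum(scores))
alpha ./= sum(alpha)

# Eq. (10): rescale the features at step t; this vector is what enters the LSTM in Eq. (11)
weighted_input = alpha .* X[:, t]
```

The batched code computes the same scores for all data points and features at once, which is why it repeats the hidden and cell states along a new feature dimension before concatenating.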

Decoder Network

Figure: Model structure of the decoder.

Image from Qin et al., “Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction”, arXiv, 2017.

The decoder operates on input_encoded from the encoder, i.e. the collection of hidden states of the encoder LSTM network. It also loops over the pooled timesteps. At each step, it calculates “attention weights” that identify the relevant encoder hidden states, and a “context vector” as the weighted sum of these hidden states.
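The attention weights and the context vector can be illustrated with plain arrays. In this sketch, the scores are random, hypothetical stand-ins for the output of `decoder_attn`. The explicit loop spells out the batched matrix-vector product that the `NNlib.batched_mul` call in the decoder computes in one go.

```julia
# Sketch of Eqs. (13)-(14) for a batch of data points. The scores are random
# stand-ins for the output of decoder_attn (hypothetical values).
encodersize, poollength, datapoints = 4, 6, 3

input_encoded = randn(encodersize, poollength, datapoints)  # encoder hidden states
scores = randn(poollength, datapoints)                      # one score per time step

# Eq. (13): softmax over the time steps, separately for each data point
beta = exp.(scores .- maximum(scores; dims=1))
beta ./= sum(beta; dims=1)

# Eq. (14): context vector = weighted sum of the encoder hidden states,
# one matrix-vector product per data point
context = zeros(encodersize, datapoints)
for d in 1:datapoints
    context[:, d] = input_encoded[:, :, d] * beta[:, d]
end
```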

function darnn_decoder(m::DARNNCell, input_encoded, input_data)
    local context  # keeps the context of the last time step visible after the loop
    for t in 1:m.poollength
        # Extract hidden state and cell state from decoder
        hidden = m.decoder_lstm.state[1]
        cell = m.decoder_lstm.state[2]

        # Eq. (12) - (13)
        x = cat(permutedims(repeat(hidden, inner=(1,1,m.poollength)), [1,3,2]),
                permutedims(repeat(cell, inner=(1,1,m.poollength)), [1,3,2]),
                input_encoded, dims=1) |>  # (2*decodersize + encodersize) x poollength x datapoints
            a -> reshape(a, (2*m.decodersize + m.encodersize,:)) |>  # (2*decodersize + encodersize) x (poollength * datapoints)
            m.decoder_attn |>  # 1 x (poollength * datapoints)
            a -> Flux.softmax(reshape(a, (m.poollength,:)))  # poollength x datapoints

        # Eq. (14)
        context = dropdims(NNlib.batched_mul(input_encoded, Flux.unsqueeze(x,2)), dims=2)  # encodersize x datapoints
        # Eq. (15): combine the context with the original series at step t
        ỹ = m.decoder_fc(cat(context, input_data[m.orig_idx,t,:]', dims=1))  # 1 x datapoints
        # Eq. (16): update the decoder LSTM state
        _ = m.decoder_lstm(ỹ)
    end
    return context
end

The decoder returns context, the context vector of the last timestep.

Final Output

The final model output is obtained by feeding the encoder output to the decoder and applying the final Dense layer to the concatenation of the decoder hidden state and the context vector:

function (m::DARNNCell)(x)
	# Initialization code missing...

	input_data = dropdims(x; dims=3)
	input_encoded = darnn_encoder(m, input_data)
	context = darnn_decoder(m, input_encoded, input_data)
	# Eq. (22)
	return m.decoder_fc_final( cat(m.decoder_lstm.state[1], context, dims=1))
end
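The following plain-Julia sketch checks the shapes in this last step. The sizes are hypothetical, and `W` and `b` are random stand-ins for the weights and bias of `decoder_fc_final`.

```julia
# Sketch of the final output step with hypothetical sizes; W and b are random
# stand-ins for the weights and bias of decoder_fc_final.
decodersize, encodersize, datapoints = 5, 4, 3

decoder_hidden = randn(decodersize, datapoints)  # decoder LSTM hidden state after the loop
context = randn(encodersize, datapoints)         # context vector of the last time step

W = randn(1, decodersize + encodersize)
b = randn()

# Concatenate along the first (feature) dimension, as cat(...; dims=1) does
combined = vcat(decoder_hidden, context)  # (decodersize + encodersize) x datapoints

# Affine map: one prediction per data point
prediction = W * combined .+ b            # 1 x datapoints
```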

Helper Functions

To tell Flux which parameters to train and how to reset the model, we define

Flux.trainable(m::DARNNCell) = (m.encoder_lstm, m.encoder_attn, m.decoder_lstm,
    m.decoder_attn, m.decoder_fc, m.decoder_fc_final)
Flux.reset!(m::DARNNCell) = Flux.reset!.((m.encoder_lstm, m.decoder_lstm))

When the DA-RNN network is reset, the hidden states of the LSTM layers no longer have one column per data point, which the concatenations in the encoder and decoder require. To initialize them, we manually feed input data of the right size to those layers:

function darnn_init(m::DARNNCell,x)
	m.encoder_lstm(x[:,1,1,:])
	m.decoder_lstm(x[m.orig_idx,1,1,:]')
	return nothing
end
Flux.Zygote.@nograd darnn_init
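The shape problem that darnn_init avoids can be reproduced with plain arrays. This sketch uses hypothetical sizes and concatenates only the hidden state, since the cell state behaves the same way. It models the state right after a reset as a single column. This is an assumption: the exact initial state shape depends on the Flux version, but in either case it lacks one column per data point.

```julia
# Hypothetical sizes
encodersize, poollength, features, datapoints = 4, 6, 3, 8

# Input data permuted as in Eq. (8): poollength x datapoints x features
permuted_input = randn(poollength, datapoints, features)

# Hidden state with one column per data point: the concatenation works
state_ok = randn(encodersize, datapoints)
x_ok = cat(repeat(state_ok, inner=(1, 1, features)), permuted_input; dims=1)

# A state with a single column (assumed post-reset shape) repeats to
# encodersize x 1 x features, which clashes with the datapoints dimension
state_reset = zeros(encodersize, 1)
mismatch_detected = try
    cat(repeat(state_reset, inner=(1, 1, features)), permuted_input; dims=1)
    false
catch err
    err isa DimensionMismatch
end
```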