/// Shape of a probe, stored as the pair `(batch_count, seq_length)`.
///
/// Both elements are `usize`. The alias is `pub(super)`, so it is visible
/// only within the parent module.
pub(super) type Shape = (usize, usize);
The probe shape is the tuple `(batch_count, seq_length)`.
(batch_count, seq_length)