33 std::stringstream stream;
48 stream <<
", entropy=";
52 stream <<
" posterior-stats[" << i <<
"]=";
100 KALDI_ERR <<
"All of the values key-dim, value-dim, " 101 "num-left-inputs and num-right-inputs must be defined.";
122 KALDI_ERR <<
"Config line contains invalid values: " 150 h * input_dim_per_head, input_dim_per_head),
151 c_part(memo->c, 0, out->
NumRows(),
153 out_part(*out, 0, out->
NumRows(),
154 h * output_dim_per_head, output_dim_per_head);
157 return static_cast<void*
>(memo);
169 out->
NumCols() == full_value_dim &&
177 rows_left_context = steps_left_context * io.
num_images;
204 const Memo *memo =
static_cast<const Memo*
>(memo_in);
242 entropy_vec.AddColSumMat(1.0, entropy_mat);
293 const std::string &debug_info,
301 NVTX_RANGE(
"RestrictedAttentionComponent::Backprop");
305 Memo *memo =
static_cast<Memo*
>(memo_in);
311 in_deriv != NULL &&
SameDim(in_value, *in_deriv));
321 in_value_part(in_value, 0, in_value.
NumRows(),
322 h * input_dim_per_head, input_dim_per_head),
323 c_part(c, 0, out_deriv.
NumRows(),
325 out_deriv_part(out_deriv, 0, out_deriv.
NumRows(),
326 h * output_dim_per_head, output_dim_per_head),
327 in_deriv_part(*in_deriv, 0, in_value.
NumRows(),
328 h * input_dim_per_head, input_dim_per_head);
346 out_deriv.
NumCols() == full_value_dim &&
350 SameDim(in_value, *in_deriv) &&
357 rows_left_context = steps_left_context * io.
num_images;
363 queries_deriv(*in_deriv, rows_left_context, out_deriv.
NumRows(),
371 &keys_deriv, &queries_deriv, &values_deriv);
377 std::vector<Index> *input_indexes,
378 std::vector<Index> *output_indexes)
const {
379 using namespace time_height_convolution;
380 ConvolutionComputationIo io;
382 std::vector<Index> new_input_indexes, new_output_indexes;
383 GetIndexes(*input_indexes, *output_indexes, io,
384 &new_input_indexes, &new_output_indexes);
385 input_indexes->swap(new_input_indexes);
386 output_indexes->swap(new_output_indexes);
390 const std::vector<Index> &input_indexes,
391 const std::vector<Index> &output_indexes,
418 int32 first_requested_input =
420 first_required_input =
425 last_required_input =
431 last_t_in <= last_requested_input &&
433 last_t_in >= last_required_input);
439 io->
num_t_in = 1 + (last_requested_input - first_requested_input) / t_step;
443 WriteToken(os, binary,
"<RestrictedAttentionComponent>");
456 WriteToken(os, binary,
"<NumLeftInputsRequired>");
458 WriteToken(os, binary,
"<NumRightInputsRequired>");
470 WriteToken(os, binary,
"</RestrictedAttentionComponent>");
487 ExpectToken(is, binary,
"<NumLeftInputsRequired>");
489 ExpectToken(is, binary,
"<NumRightInputsRequired>");
501 ExpectToken(is, binary,
"</RestrictedAttentionComponent>");
509 const Index &output_index,
510 std::vector<Index> *desired_indexes)
const {
514 desired_indexes->clear();
516 int32 n = output_index.
n, x = output_index.
x,
519 (*desired_indexes)[
i].n =
n;
520 (*desired_indexes)[
i].t = t;
521 (*desired_indexes)[
i].x = x;
529 const Index &output_index,
531 std::vector<Index> *used_inputs)
const {
533 Index index(output_index);
535 if (used_inputs != NULL) {
538 used_inputs->clear();
543 if (input_index_set(index)) {
545 used_inputs->push_back(index);
551 used_inputs->clear();
562 for (
int32 t = first_time_required;
563 t <= last_time_required;
566 if (!input_index_set(index))
576 const std::vector<std::pair<int32, int32> > &n_x_pairs,
578 const std::unordered_set<Index, IndexHasher> &index_set,
579 std::vector<Index> *output_indexes) {
580 output_indexes->resize(static_cast<size_t>(num_t_values) * n_x_pairs.size());
581 std::vector<Index>::iterator out_iter = output_indexes->begin();
582 for (
int32 t = t_start; t < t_start + (t_step * num_t_values); t += t_step) {
583 std::vector<std::pair<int32, int32> >::const_iterator
584 iter = n_x_pairs.begin(), end = n_x_pairs.end();
585 for (; iter != end; ++iter) {
586 out_iter->n = iter->first;
588 out_iter->x = iter->second;
589 if (index_set.count(*out_iter) == 0)
598 const std::vector<Index> &input_indexes,
599 const std::vector<Index> &output_indexes,
601 std::vector<Index> *new_input_indexes,
602 std::vector<Index> *new_output_indexes)
const {
604 std::unordered_set<Index, IndexHasher> input_set, output_set;
605 for (std::vector<Index>::const_iterator iter = input_indexes.begin();
606 iter != input_indexes.end(); ++iter)
607 input_set.insert(*iter);
608 for (std::vector<Index>::const_iterator iter = output_indexes.begin();
609 iter != output_indexes.end(); ++iter)
610 output_set.insert(*iter);
612 std::vector<std::pair<int32, int32> > n_x_pairs;
617 input_set, new_input_indexes);
619 output_set, new_output_indexes);
624 const std::vector<Index> &input_indexes,
625 const std::vector<Index> &output_indexes,
632 std::vector<Index> new_input_indexes, new_output_indexes;
634 &new_input_indexes, &new_output_indexes);
641 output_indexes == new_output_indexes);
654 std::ostream &os,
bool binary)
const {
655 WriteToken(os, binary,
"<RestrictedAttentionComponentPrecomputedIndexes>");
657 io.Write(os, binary);
658 WriteToken(os, binary,
"</RestrictedAttentionComponentPrecomputedIndexes>");
662 std::istream &is,
bool binary) {
664 "<RestrictedAttentionComponentPrecomputedIndexes>",
667 ExpectToken(is, binary,
"</RestrictedAttentionComponentPrecomputedIndexes>");
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void GetIndexes(const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo &io, std::vector< Index > *new_input_indexes, std::vector< Index > *new_output_indexes) const
const std::string WholeLine()
virtual void Add(BaseFloat alpha, const Component &other)
This virtual function, when called on an UpdatableComponent, adds the parameters of another updatable component...
void GetComputationStructure(const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, time_height_convolution::ConvolutionComputationIo *io) const
void AttentionBackward(BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &output_deriv, CuMatrixBase< BaseFloat > *keys_deriv, CuMatrixBase< BaseFloat > *queries_deriv, CuMatrixBase< BaseFloat > *values_deriv)
Performs the backward pass corresponding to 'AttentionForward', propagating the derivative back to th...
virtual int32 InputDim() const
Returns input-dimension of this component.
void Write(std::ostream &os, bool binary) const
virtual void Read(std::istream &is, bool binary)
Read function (used after we know the type of the Component); accepts input that is missing the token...
Abstract base-class for neural-net components.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
int32 GetVerboseLevel()
Get verbosity level, usually set via the command-line '--verbose=' switch.
An abstract representation of a set of Indexes.
void Write(std::ostream &Out, bool binary) const
Writes to C++ stream (option to write in binary).
int32 num_left_inputs_required_
void ApplyFloor(Real floor_val)
void BackpropOneHead(const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &c, const CuMatrixBase< BaseFloat > &out_deriv, CuMatrixBase< BaseFloat > *in_deriv) const
void AttentionForward(BaseFloat key_scale, const CuMatrixBase< BaseFloat > &keys, const CuMatrixBase< BaseFloat > &queries, const CuMatrixBase< BaseFloat > &values, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *output)
This is a higher-level interface to the attention code.
void AddMat(Real alpha, const CuMatrixBase< Real > &A, MatrixTransposeType trans=kNoTrans)
*this += alpha * A
static void CreateIndexesVector(const std::vector< std::pair< int32, int32 > > &n_x_pairs, int32 t_start, int32 t_step, int32 num_t_values, const std::unordered_set< Index, IndexHasher > &index_set, std::vector< Index > *output_indexes)
Utility function used in GetIndexes().
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
virtual bool IsComputable(const MiscComputationInfo &misc_info, const Index &output_index, const IndexSet &input_index_set, std::vector< Index > *used_inputs) const
This function only does something interesting for non-simple Components, and it exists to make it pos...
RestrictedAttentionComponent()
int32 num_right_inputs_required_
CuMatrix< double > posterior_stats_
virtual int32 OutputDim() const
Returns output-dimension of this component.
Contains component(s) related to attention models.
virtual void Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, const CuMatrixBase< BaseFloat > &out_deriv, void *memo, Component *to_update, CuMatrixBase< BaseFloat > *in_deriv) const
Backprop function; depending on which of the arguments 'to_update' and 'in_deriv' are non-NULL...
virtual PrecomputedIndexes * Copy() const
time_height_convolution::ConvolutionComputationIo io
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
void ExpectOneOrTwoTokens(std::istream &is, bool binary, const std::string &token1, const std::string &token2)
This function is like ExpectToken but for two tokens, and it will either accept token1 and then token...
virtual void Scale(BaseFloat scale)
This virtual function, when called on an UpdatableComponent, scales the parameters by "scale"; when c...
void AddDiagMatMat(Real alpha, const CuMatrixBase< Real > &M, MatrixTransposeType transM, const CuMatrixBase< Real > &N, MatrixTransposeType transN, Real beta=1.0)
Add the diagonal of a matrix product: *this = diag(M N), assuming the "trans" arguments are both kNoT...
Vector< double > entropy_stats_
bool SameDim(const MatrixBase< Real > &M, const MatrixBase< Real > &N)
virtual void InitFromConfig(ConfigLine *cfl)
Initialize, from a ConfigLine object.
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
void SetZero()
Math operations, some calling kernels.
virtual ComponentPrecomputedIndexes * PrecomputeIndexes(const MiscComputationInfo &misc_info, const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, bool need_backprop) const
This function must return NULL for simple Components.
virtual std::string Type() const
Returns a string such as "SigmoidComponent", describing the type of the object.
virtual void ZeroStats()
Components that provide an implementation of StoreStats should also provide an implementation of Zero...
virtual void GetInputIndexes(const MiscComputationInfo &misc_info, const Index &output_index, std::vector< Index > *desired_indexes) const
This function only does something interesting for non-simple Components.
This class is used for a piece of a CuMatrix.
virtual std::string Info() const
Returns some text-form information about this component, for diagnostics.
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
MatrixIndexT Dim() const
Returns the dimension of the vector.
void Scale(Real alpha)
Multiplies all elements by this constant.
virtual void * Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out) const
Propagate function.
virtual void Write(std::ostream &os, bool binary) const
Write component to stream.
virtual void Read(std::istream &os, bool binary)
virtual void ReorderIndexes(std::vector< Index > *input_indexes, std::vector< Index > *output_indexes) const
This function only does something interesting for non-simple Components.
Matrix for CUDA computing.
MatrixIndexT NumCols() const
void GetComputationIo(const std::vector< Index > &input_indexes, const std::vector< Index > &output_indexes, ConvolutionComputationIo *io)
This function takes lists of input and output indexes to a computation (e.g.
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
RestrictedAttentionComponent implements an attention model with restricted temporal context...
#define KALDI_ASSERT(cond)
virtual void Write(std::ostream &os, bool binary) const
void Read(std::istream &is, bool binary)
I/O functions.
Real * Data()
Returns a pointer to the start of the vector's data.
bool GetValue(const std::string &key, std::string *value)
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
void PropagateOneHead(const time_height_convolution::ConvolutionComputationIo &io, const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *c, CuMatrixBase< BaseFloat > *out) const
virtual void StoreStats(const CuMatrixBase< BaseFloat > &in_value, const CuMatrixBase< BaseFloat > &out_value, void *memo)
This function may store stats on average activation values, and for some component types...
MatrixIndexT NumRows() const
Dimensions.
void SetZero()
Set vector to all zeros.
void AddVec(const Real alpha, const VectorBase< OtherReal > &v)
Add vector : *this = *this + alpha * rv (with casting between floats and doubles) ...
void Read(std::istream &in, bool binary, bool add=false)
Read function using C++ streams.
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
void AddRowSumMat(Real alpha, const CuMatrixBase< Real > &mat, Real beta=1.0)
Sum the rows of the matrix, add to vector.
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
void GetNxList(const std::vector< Index > &indexes, std::vector< std::pair< int32, int32 > > *pairs)
This function outputs a unique, lexicographically sorted list of the pairs of (n, x) values that are ...