nnet-descriptor.h
Go to the documentation of this file.
1 // nnet3/nnet-descriptor.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_DESCRIPTOR_H_
21 #define KALDI_NNET3_NNET_DESCRIPTOR_H_
22 
23 #include "base/kaldi-common.h"
24 #include "util/kaldi-io.h"
25 #include "matrix/matrix-lib.h"
26 #include "nnet3/nnet-common.h"
28 
29 #include <iostream>
30 #include <sstream>
31 #include <vector>
32 #include <map>
33 
34 
35 namespace kaldi {
36 namespace nnet3 {
37 
// A ForwardingDescriptor describes how data is copied from another
// NetworkNode (or from several).  This is the abstract interface: among the
// members visible here, everything except Modulus() and the destructor is
// pure virtual and must be supplied by a concrete subclass.
87 class ForwardingDescriptor {
96  public:
97  // Given an Index that's requested at the output of this descriptor, maps it
98  // to a (node_index, Index) pair that says where we are to get the data from.
99  //
100  virtual Cindex MapToInput(const Index &output) const = 0;
101 
102  // Return the feature dimension.
103  virtual int32 Dim(const Nnet &nnet) const = 0;
104 
  // Returns a pointer to a copy of this descriptor.
  // NOTE(review): presumably newly allocated and owned by the caller --
  // confirm against the implementations.
105  virtual ForwardingDescriptor *Copy() const = 0;
106 
  // For use in things like clockwork RNNs, in connection with shifting the
  // time of the inputs and outputs.  The default implementation returns 1.
  // NOTE(review): the full contract is not shown in this listing -- confirm
  // against the complete header before relying on it.
113  virtual int32 Modulus() const { return 1; }
114 
115  // Write to string that will be one line of a config-file-like format. The
116  // opposite of Parse.
117  virtual void WriteConfig(std::ostream &os,
118  const std::vector<std::string> &node_names) const = 0;
119 
  // Appends to "node_indexes" all the node indexes
121  // that this descriptor may access.
122  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const = 0;
123 
  // Returns the scale on the node-index 'node_index' when it appears in
  // expressions inside this descriptor.
  // NOTE(review): behaviour for a node that this descriptor does not reference
  // is not shown here -- confirm against the implementations.
131  virtual BaseFloat GetScaleForNode(int32 node_index) const = 0;
132 
  // Virtual so that subclasses can be destroyed through a base-class pointer.
133  virtual ~ForwardingDescriptor() { }
  // NOTE(review): the contents of the private section are not shown in this
  // listing -- consult the full header.
135  private:
137 };
138 
145  public:
146  virtual Cindex MapToInput(const Index &index) const;
147  virtual int32 Dim(const Nnet &nnet) const;
148  virtual ForwardingDescriptor *Copy() const;
149  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
150  virtual BaseFloat GetScaleForNode(int32 node_index) const;
151 
152  // Write to string that will be one line of a config-file-like format. The
153  // opposite of Parse.
154  // written form is just the node-name of src_node_.
155  virtual void WriteConfig(std::ostream &os,
156  const std::vector<std::string> &node_names) const;
157 
159  BaseFloat scale = 1.0):
160  src_node_(src_node), scale_(scale) {
161  KALDI_ASSERT(src_node >= 0);
162  }
164  private:
165  int32 src_node_; // index of the source NetworkNode.
166  BaseFloat scale_; // Scale of the node in the expression; this will be 1.0
167  // unless you used a Scale(...) expression in your
168  // Descriptor.
169 };
170 
176  public:
177  virtual Cindex MapToInput(const Index &ind) const;
178  virtual int32 Dim(const Nnet &nnet) const { return src_->Dim(nnet); }
179  virtual ForwardingDescriptor *Copy() const;
180 
181  // written form is: Offset(<src-written-form>, t-offset [, x-offset])
182  virtual void WriteConfig(std::ostream &os,
183  const std::vector<std::string> &node_names) const;
184 
185  virtual int32 Modulus() const { return src_->Modulus(); }
186 
187  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
188  virtual BaseFloat GetScaleForNode(int32 node_index) const;
189 
190  // takes ownership of src.
192  Index offset): src_(src), offset_(offset) { }
193 
194  virtual ~OffsetForwardingDescriptor() { delete src_; }
195 
196 
197  // this function is not in the shared interface. it's used
198  // in class ModelCollapser.
199  const ForwardingDescriptor &Src() const { return *src_; }
200  private:
201  ForwardingDescriptor *src_; // Owned here.
202  Index offset_; // The index-offset to be added to the index.
203 };
204 
211  public:
212  virtual Cindex MapToInput(const Index &ind) const;
213  virtual int32 Dim(const Nnet &nnet) const { return src_[0]->Dim(nnet); }
214  virtual ForwardingDescriptor *Copy() const;
215  // Written form is "Switch(<written-form-of-src1>, <written-form-of-src2>, ... )"
216  virtual void WriteConfig(std::ostream &os,
217  const std::vector<std::string> &node_names) const;
218 
219  virtual int32 Modulus() const;
220 
222  // Appends to "node_indexes" all the node indexes that this descriptor may access.
223  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
224  virtual BaseFloat GetScaleForNode(int32 node_index) const;
225 
226  // takes ownership of items in src.
227  SwitchingForwardingDescriptor(std::vector<ForwardingDescriptor*> &src):
228  src_(src) { }
230  private:
231  // Pointers are owned here.
232  std::vector<ForwardingDescriptor*> src_;
233 };
234 
235 
236 
243  public:
244  virtual Cindex MapToInput(const Index &ind) const;
245  virtual int32 Dim(const Nnet &nnet) const { return src_->Dim(nnet); }
246  virtual ForwardingDescriptor *Copy() const;
247  // Written form is "Round(<written-form-of-src>, <t_modulus>)"
248  virtual void WriteConfig(std::ostream &os,
249  const std::vector<std::string> &node_names) const;
250 
251  virtual int32 Modulus() const { return t_modulus_; }
252 
254  // Appends to "node_indexes" all the node indexes that this descriptor may access.
255  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
256  virtual BaseFloat GetScaleForNode(int32 node_index) const;
257 
258  // takes ownership of src.
260  int32 t_modulus):
261  src_(src), t_modulus_(t_modulus) { }
262 
263  virtual ~RoundingForwardingDescriptor() { delete src_; }
264  private:
267 };
268 
274  public:
275  enum VariableName { kN = 0, kT = 1, kX = 2};
276 
277  virtual Cindex MapToInput(const Index &ind) const;
278  virtual int32 Dim(const Nnet &nnet) const { return src_->Dim(nnet); }
279  virtual ForwardingDescriptor *Copy() const;
280  // Written form is "ReplaceIndex(<written-form-of-src>, <variable-name>, <value>)"
281  // where <variable-name> is either "t" or "x".
282  virtual void WriteConfig(std::ostream &os,
283  const std::vector<std::string> &node_names) const;
284 
286  // Appends to "node_indexes" all the node indexes that this descriptor may access.
287  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
288  virtual BaseFloat GetScaleForNode(int32 node_index) const;
289 
290  // takes ownership of src.
292  VariableName variable_name,
293  int32 value):
294  src_(src), variable_name_(variable_name), value_(value) { }
295 
296  virtual ~ReplaceIndexForwardingDescriptor() { delete src_; }
297  private:
301 };
302 
303 
305 class CindexSet;
306 
316  public:
321  virtual void GetDependencies(const Index &ind,
322  std::vector<Cindex> *dependencies) const = 0;
323 
343  virtual bool IsComputable(const Index &ind,
344  const CindexSet &cindex_set,
345  std::vector<Cindex> *used_inputs) const = 0;
346 
347  virtual int32 Dim(const Nnet &nnet) const = 0;
348 
349  virtual SumDescriptor *Copy() const = 0;
350 
351  virtual ~SumDescriptor() { }
352 
355  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const = 0;
356 
367  virtual BaseFloat GetScaleForNode(int32 node_index) const = 0;
368 
369  // see Modulus function of ForwardingDescriptor for explanation.
370  virtual int32 Modulus() const = 0;
371 
374  virtual void WriteConfig(std::ostream &os,
375  const std::vector<std::string> &node_names) const = 0;
376 };
377 
384  public:
385  virtual void GetDependencies(const Index &ind,
386  std::vector<Cindex> *dependencies) const;
387  virtual bool IsComputable(const Index &ind,
388  const CindexSet &cindex_set,
389  std::vector<Cindex> *used_inputs) const {
  // The wrapped term is optional, so this always reports "computable": the
  // `|| true` discards whatever src_->IsComputable() returns.  The call is
  // still made (it is the left operand, so it is always evaluated) --
  // presumably so that src_ can record in *used_inputs the inputs it would
  // use; confirm against the SumDescriptor::IsComputable contract.
390  return src_->IsComputable(ind, cindex_set, used_inputs) || true;
391  }
392 
393  virtual int32 Dim(const Nnet &nnet) const;
394 
395  // This function appends to "node_indexes" a list (not necessarily sorted or
396  // unique) of all the node indexes that this descriptor may forward data from.
397  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
398  virtual BaseFloat GetScaleForNode(int32 node_index) const;
399  virtual int32 Modulus() const { return src_->Modulus(); }
402  virtual void WriteConfig(std::ostream &os,
403  const std::vector<std::string> &node_names) const;
404  virtual SumDescriptor *Copy() const;
405 
407  virtual ~OptionalSumDescriptor() { delete src_; }
408  private:
410 };
411 
417  public:
418  virtual void GetDependencies(const Index &ind,
419  std::vector<Cindex> *dependencies) const;
420  virtual bool IsComputable(const Index &ind,
421  const CindexSet &cindex_set,
422  std::vector<Cindex> *used_inputs) const;
423  virtual int32 Dim(const Nnet &nnet) const;
424 
425  virtual BaseFloat GetScaleForNode(int32 node_index) const;
426 
427  // This function appends to "node_indexes" a list (not necessarily sorted or
428  // unique) of all the node indexes that this descriptor may forward data from.
429  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
430  virtual int32 Modulus() const { return src_->Modulus(); }
433  virtual void WriteConfig(std::ostream &os,
434  const std::vector<std::string> &node_names) const;
435  virtual SumDescriptor *Copy() const;
436 
438  virtual ~SimpleSumDescriptor() { delete src_; }
439 
440  // this function is not in the shared interface. it's used
441  // in class ModelCollapser.
442  const ForwardingDescriptor &Src() const { return *src_; }
443  private:
445 };
446 
447 
456  public:
457  virtual void GetDependencies(const Index &ind,
458  std::vector<Cindex> *dependencies) const { }
459  virtual bool IsComputable(const Index &ind,
460  const CindexSet &cindex_set,
461  std::vector<Cindex> *used_inputs) const {
462  return true;
463  }
464  virtual int32 Dim(const Nnet &nnet) const { return dim_; }
465  virtual BaseFloat GetScaleForNode(int32 node_index) const;
466 
467  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const { }
468  virtual int32 Modulus() const { return 1; }
471  virtual void WriteConfig(std::ostream &os,
472  const std::vector<std::string> &node_names) const;
473  virtual SumDescriptor *Copy() const;
474 
477  private:
480 };
481 
488  public:
489  enum Operation {
490  kSumOperation, // A + B
491  kFailoverOperation, // A if defined, else B.
492  };
493  virtual void GetDependencies(const Index &ind,
494  std::vector<Cindex> *dependencies) const;
495  virtual bool IsComputable(const Index &ind,
496  const CindexSet &cindex_set,
497  std::vector<Cindex> *used_inputs) const;
498  virtual int32 Dim(const Nnet &nnet) const;
499  virtual BaseFloat GetScaleForNode(int32 node_index) const;
500 
501  // This function appends to "node_indexes" a list (not necessarily sorted or
502  // unique) of all the node indexes that this descriptor may forward data from.
503  virtual void GetNodeDependencies(std::vector<int32> *node_indexes) const;
504  virtual int32 Modulus() const;
508  virtual void WriteConfig(std::ostream &os,
509  const std::vector<std::string> &node_names) const;
510  virtual SumDescriptor *Copy() const;
512  op_(op), src1_(src1), src2_(src2) {}
513  virtual ~BinarySumDescriptor() { delete src1_; delete src2_; }
514  private:
518 };
519 
520 
521 // A Descriptor concatenates over its parts, so its feature-dimension will
522 // be the sum of the feature-dimensions of its parts. In a valid Descriptor,
523 // "parts" will be nonempty. Each part may be (in general) a summation, but
524 // usually a summation with just one term.
525 class Descriptor {
526  public:
  // Feature dimension of this Descriptor; per the class comment above, the
  // sum of the feature-dimensions of its parts.
527  int32 Dim(const Nnet &nnet) const;
528 
529  // The Parse method is used for reading a config-file-style representation.
530  // Internally this uses class GeneralDescriptor to read and normalize the
531  // input. Assumes the input has already been tokenized into an array of
532  // strings by DescriptorTokenize(); it moves the begin-pointer "next_token" to
533  // account for each token that it consumes. Prints warning and returns false on
534  // error (including if there was junk after the last token). The input tokens
535  // should be terminated with a token that says "end of input".
536  bool Parse(const std::vector<std::string> &node_names,
537  const std::string **next_token);
538 
539  // Write in config-file format.
540  // if parts_.size() == 1, written form is just "<written-form-of-part0>"
541  // otherwise, written form is "Append(<written-form-of-part0>, <written-form-of-part1>, ... )".
542  void WriteConfig(std::ostream &os,
543  const std::vector<std::string> &node_names) const;
544 
  // NOTE(review): presumably appends to *used_inputs the Cindexes that this
  // Descriptor may need in order to compute the output at "index" -- the
  // original description is not present in this listing; confirm against the
  // implementation.
565  void GetDependencies(const Index &index,
566  std::vector<Cindex> *used_inputs) const;
567 
  // NOTE(review): presumably reports whether the output at "ind" can be
  // computed given the Cindexes in "cindex_set", optionally reporting the
  // inputs used via "used_inputs" -- the original description is not present
  // in this listing; confirm against the implementation.
571  bool IsComputable(const Index &ind,
572  const CindexSet &cindex_set,
573  std::vector<Cindex> *used_inputs) const;
574 
575  // This function outputs to "node_indexes" a list (not necessarily sorted or
576  // unique) of all the node indexes that this descriptor may forward data from.
577  void GetNodeDependencies(std::vector<int32> *node_indexes) const;
578 
579  // see Modulus function of ForwardingDescriptor for explanation.
580  int32 Modulus() const;
581 
  // Returns the number of parts that are concatenated over.
583  int32 NumParts() const { return parts_.size(); }
  // Returns a reference to the n'th part.
  // NOTE(review): presumably requires 0 <= n < NumParts() -- confirm.
585  const SumDescriptor &Part(int32 n) const;
586 
  // Copy constructor; implemented by delegating to operator= below.
589  Descriptor(const Descriptor &other) { *this = other; }
  // Assignment operator.
  // NOTE(review): parts_ holds owned pointers, so this presumably releases the
  // current parts and copies those of "other" -- confirm in the .cc file.
591  Descriptor &operator = (const Descriptor &other);
  // Takes ownership of pointers in "parts".
593  Descriptor(const std::vector<SumDescriptor*> &parts):
594  parts_(parts) { }
  // Destructor; cleanup is delegated to Destroy().
596  ~Descriptor() { Destroy(); }
597  private:
  // Cleanup helper called from the destructor.
  // NOTE(review): presumably deletes the owned elements of parts_ -- confirm.
598  void Destroy();
599  // the elements of parts_ are owned here.
600  std::vector<SumDescriptor*> parts_;
601 };
602 
603 
610  enum DescriptorType { kAppend, kSum, kFailover, kIfDefined, kOffset, kSwitch,
611  kRound, kReplaceIndex, kScale, kConst, kNodeName };
612 
613  // The Parse method is used for reading a config-file-style representation.
614  // Assumes the input has already been tokenized into an array of strings, and
615  // it moves the begin-pointer "next_token" to account for each token that it
616  // consumes. Calls KALDI_ERR on error. The list of tokens should be
617  // terminated with a string saying "end of input". Does not check that all
618  // the input has been consumed-- the caller should do that [check that
619  // **next_token == "end of input" after calling.]
620  static GeneralDescriptor *Parse(const std::vector<std::string> &node_names,
621  const std::string **next_token);
622 
623  explicit GeneralDescriptor(DescriptorType t, int32 value1 = -1,
624  int32 value2 = -1, BaseFloat alpha = 0.0):
625  descriptor_type_(t), value1_(value1), value2_(value2),
626  alpha_(alpha) { }
627 
628 
629  ~GeneralDescriptor() { DeletePointers(&descriptors_); }
630 
631  GeneralDescriptor *GetNormalizedDescriptor() const;
632 
633  Descriptor *ConvertToDescriptor();
634 
635  // prints in text form-- this is really only used for debug.
636  void Print(const std::vector<std::string> &node_names,
637  std::ostream &os);
638 
639  private:
641 
643 
644  // value1_ is only relevant if:
645  // (a) descriptor_type_ == kReplaceIndex (value1_ is 1 for t, 2 for x)
646  // (b) descriptor_type_ == kNodeName (value1_ is the index of the node)
647  // (c) descriptor_type_ == kOffset (value1_ is the t offset).
648  // (d) descriptor_type_ == kConst (value1_ is the dimension and alpha_
649  // is the value).
651  // value2_ is only relevant if
652  // (a) descriptor_type == kReplaceIndex (value2_ is the value
653  // we replace the index with).
654  // (b) descriptor_type_ == kOffset (value2_ is the x offset)
656 
657  // alpha is only relevant if
658  // (a) descriptor_type == kScale, and this will be the scaling factor.
659  // (b) descriptor_type == kConst; this is the value, and value1_ is set to the
660  // dimension.
662 
663  // For any descriptor types that take args of type kDescriptor, a list of those
664  // args. Pointers owned here.
665  std::vector<GeneralDescriptor*> descriptors_;
666 
667  // parses an Append() or Sum() or Switch() expression after the "Append(" or
668  // "Sum(" or "Switch(" has been read.
669  void ParseAppendOrSumOrSwitch(const std::vector<std::string> &node_names,
670  const std::string **next_token);
671  // parse an IfDefined() expression after the IfDefined( has already been
672  // read.
673  void ParseIfDefined(const std::vector<std::string> &node_names,
674  const std::string **next_token);
675  // ... and so on.
676  void ParseOffset(const std::vector<std::string> &node_names,
677  const std::string **next_token);
678  void ParseSwitch(const std::vector<std::string> &node_names,
679  const std::string **next_token);
680  void ParseFailover(const std::vector<std::string> &node_names,
681  const std::string **next_token);
682  void ParseRound(const std::vector<std::string> &node_names,
683  const std::string **next_token);
684  void ParseScale(const std::vector<std::string> &node_names,
685  const std::string **next_token);
686  void ParseConst(const std::vector<std::string> &node_names,
687  const std::string **next_token);
688  void ParseReplaceIndex(const std::vector<std::string> &node_names,
689  const std::string **next_token);
690 
691 
692 
693  // Used inside NormalizeAppend(). Returns the number of terms there
694  // would be in a single consolidated Append() expression, and asserts that in
695  // whichever branch of any other expressions we take, the number of terms is
696  // the same.
697  int32 NumAppendTerms() const;
698  // Used inside NormalizeAppend(). Gets one of the appended terms from this
699  // descriptor, with 0 <= term < NumAppendTerms(). Answer is newly allocated.
700  GeneralDescriptor *GetAppendTerm(int32 term) const;
701 
702 
703  // Normalizes w.r.t. Append expressions by moving Append() to the outside.
704  // Called only at the top level.
705  GeneralDescriptor *NormalizeAppend() const;
706 
707  // This call does all other types of normalization except for normalizing
708  // Append() expressions (which is assumed to have been done already). Returns
709  // true if anything was changed.
710  static bool Normalize(GeneralDescriptor *ptr);
711 
712  SumDescriptor *ConvertToSumDescriptor() const;
713  ForwardingDescriptor *ConvertToForwardingDescriptor() const;
714 
715 };
716 
717 
718 
719 
720 } // namespace nnet3
721 } // namespace kaldi
722 
723 #endif
KALDI_DISALLOW_COPY_AND_ASSIGN(ForwardingDescriptor)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
virtual void WriteConfig(std::ostream &os, const std::vector< std::string > &node_names) const =0
void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
OffsetForwardingDescriptor(ForwardingDescriptor *src, Index offset)
This is the case of class SumDescriptor, in which we contain just one term, and that term is optional...
virtual int32 Dim(const Nnet &nnet) const =0
RoundingForwardingDescriptor(ForwardingDescriptor *src, int32 t_modulus)
virtual void GetNodeDependencies(std::vector< int32 > *node_indexes) const =0
This function appends to "node_indexes" all the node indexes.
This class is only used when parsing Descriptors.
std::vector< GeneralDescriptor * > descriptors_
This is an alternative base-case of SumDescriptor (an alternative to SimpleSumDescriptor) which repre...
virtual bool IsComputable(const Index &ind, const CindexSet &cindex_set, std::vector< Cindex > *used_inputs) const
This function exists to enable us to manage optional dependencies, i.e.
SimpleForwardingDescriptor is the base-case of ForwardingDescriptor, consisting of a source node in t...
virtual Cindex MapToInput(const Index &output) const =0
virtual bool IsComputable(const Index &ind, const CindexSet &cindex_set, std::vector< Cindex > *used_inputs) const
This function exists to enable us to manage optional dependencies, i.e.
std::vector< SumDescriptor * > parts_
BinarySumDescriptor can represent either A + B, or (A if defined, else B).
virtual int32 Dim(const Nnet &nnet) const
kaldi::int32 int32
ReplaceIndexForwardingDescriptor(ForwardingDescriptor *src, VariableName variable_name, int32 value)
virtual void GetDependencies(const Index &ind, std::vector< Cindex > *dependencies) const
Given an Index at the output of this Descriptor, append to "dependencies" a list of Cindexes that des...
For use in clockwork RNNs and the like, this forwarding-descriptor rounds the time-index t down to th...
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
Definition: nnet-common.h:44
virtual int32 Dim(const Nnet &nnet) const
std::pair< int32, Index > Cindex
Definition: nnet-common.h:115
virtual void GetNodeDependencies(std::vector< int32 > *node_indexes) const
This function appends to "node_indexes" a list (not necessarily sorted or unique) of all the node ind...
This is an abstract base-class.
virtual int32 Dim(const Nnet &nnet) const
const ForwardingDescriptor & Src() const
std::vector< ForwardingDescriptor * > src_
Chooses from different inputs based on the time index modulo (the number of ForwardingDescriptors...
virtual int32 Dim(const Nnet &nnet) const
virtual int32 Modulus() const
This function is for use in things like clockwork RNNs, where shifting the time of the inputs and out...
struct rnnlm::@11::@12 n
OptionalSumDescriptor(SumDescriptor *src)
This is the normal base-case of SumDescriptor which just wraps a ForwardingDescriptor.
virtual int32 Modulus() const
This function is for use in things like clockwork RNNs, where shifting the time of the inputs and out...
virtual int32 Modulus() const
This function is for use in things like clockwork RNNs, where shifting the time of the inputs and out...
const ForwardingDescriptor & Src() const
A ForwardingDescriptor describes how we copy data from another NetworkNode, or from multiple other Ne...
virtual ForwardingDescriptor * Copy() const =0
Offsets in 't' and 'x' values of other ForwardingDescriptors.
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
This ForwardingDescriptor modifies the indexes (n, t, x) by replacing one of them (normally t) with a...
SimpleSumDescriptor(ForwardingDescriptor *src)
void Print(const Fst< Arc > &fst, std::string message)
Descriptor(const Descriptor &other)
Copy constructor.
BinarySumDescriptor(Operation op, SumDescriptor *src1, SumDescriptor *src2)
SimpleForwardingDescriptor(int32 src_node, BaseFloat scale=1.0)
int32 NumParts() const
Returns the number of parts that are concatenated over.
Descriptor(const std::vector< SumDescriptor *> &parts)
Takes ownership of pointers in "parts".
virtual int32 Dim(const Nnet &nnet) const
SwitchingForwardingDescriptor(std::vector< ForwardingDescriptor *> &src)
virtual BaseFloat GetScaleForNode(int32 node_index) const =0
This function returns the scale on the node-index 'node_index' when it appears in expressions inside ...