nnet-discriminative-example.h
Go to the documentation of this file.
1 // nnet3/nnet-discriminative-example.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_NNET3_NNET_DISCRIMINATIVE_EXAMPLE_H_
22 #define KALDI_NNET3_NNET_DISCRIMINATIVE_EXAMPLE_H_
23 
24 #include "nnet3/nnet-nnet.h"
25 #include "nnet3/nnet-computation.h"
26 #include "util/table-types.h"
28 #include "nnet3/nnet-example.h"
30 #include "hmm/posterior.h"
31 #include "hmm/transition-model.h"
32 
33 namespace kaldi {
34 namespace nnet3 {
35 
36 // Glossary: mmi = Maximum Mutual Information,
37 // mpfe = Minimum Phone Frame Error
38 // smbr = State-level Minimum Bayes Risk
39 
40 // This file relates to the creation of examples for discriminative training
41 
43  // the name of the output in the neural net; in simple setups it
44  // will just be "output".
45  std::string name;
46 
47  // The indexes that the output corresponds to. The size of this vector will
48  // be equal to supervision.num_sequences * supervision.frames_per_sequence.
49  // Be careful about the order of these indexes-- it is a little confusing.
50  // The indexes in the 'index' vector are ordered as: (frame 0 of each sequence);
51  // (frame 1 of each sequence); and so on. But in the 'supervision' object,
52  // the lattice contains (sequence 0; sequence 1; ...). So reordering is needed.
53  // This is done to make the code similar to that for the 'chain' model.
54  std::vector<Index> indexes;
55 
56  // The supervision object, containing the numerator and denominator
57  // lattices.
59 
60  // This is a vector of per-frame weights, required to be between 0 and 1,
61  // that is applied to the derivative during training (but not during model
62  // combination, where the derivatives need to agree with the computed objf
63  // values for the optimization code to work). The reason for this is to more
64  // exactly handle edge effects and to ensure that no frames are
65  // 'double-counted'. The order of this vector corresponds to the order of
66  // the 'indexes' (i.e. all the first frames, then all the second frames,
67  // etc.)
68  // If this vector is empty it means we're not applying per-frame weights,
69  // so it's equivalent to a vector of all ones. This vector is written
70  // to disk compactly as unsigned char.
72 
73  // Use default assignment operator
75 
76  // Initialize the object from an object of type discriminative::Supervision,
77  // and some extra information.
78  // Note: you probably want to set 'name' to "output".
79  // 'first_frame' will often be zero but you can choose (just make it
80  // consistent with how you numbered your inputs), and 'frame_skip' would be 1
81  // in a vanilla setup, but 3 in the case of 'chain' models
82  NnetDiscriminativeSupervision(const std::string &name,
84  const VectorBase<BaseFloat> &deriv_weights,
85  int32 first_frame,
86  int32 frame_skip);
87 
89 
90  void Write(std::ostream &os, bool binary) const;
91 
92  void Read(std::istream &is, bool binary);
93 
95 
96  void CheckDim() const;
97 
98  bool operator == (const NnetDiscriminativeSupervision &other) const;
99 };
100 
104 
108  std::vector<NnetIo> inputs;
109 
112  std::vector<NnetDiscriminativeSupervision> outputs;
113 
114  void Write(std::ostream &os, bool binary) const;
115 
116  void Read(std::istream &is, bool binary);
117 
118  void Swap(NnetDiscriminativeExample *other);
119 
120  // Compresses the input features (if not compressed)
121  void Compress();
122 
124 
126 
127  bool operator == (const NnetDiscriminativeExample &other) const {
128  return inputs == other.inputs && outputs == other.outputs;
129  }
130 };
131 
132 
137  size_t operator () (const NnetDiscriminativeExample &eg) const noexcept ;
138  // We also provide a version of this that works from pointers.
139  size_t operator () (const NnetDiscriminativeExample *eg) const noexcept {
140  return (*this)(*eg);
141  }
142 };
143 
144 
148  bool operator () (const NnetDiscriminativeExample &a,
149  const NnetDiscriminativeExample &b) const;
150  // We also provide a version of this that works from pointers.
151  bool operator () (const NnetDiscriminativeExample *a,
152  const NnetDiscriminativeExample *b) const {
153  return (*this)(*a, *b);
154  }
155 };
156 
157 
170  std::vector<NnetDiscriminativeExample> *input,
171  bool compress,
172  NnetDiscriminativeExample *output);
173 
174 // called from MergeDiscriminativeExamples, this function merges the Supervision
175 // objects into one. Requires (and checks) that they all have the same name.
176 void MergeSupervision(
177  const std::vector<const NnetDiscriminativeSupervision*> &inputs,
179 
180 
195 void ShiftDiscriminativeExampleTimes(int32 frame_shift,
196  const std::vector<std::string> &exclude_names,
198 
207  const NnetDiscriminativeExample &eg,
208  bool need_model_derivative,
209  bool store_component_stats,
210  bool use_xent_regularization,
211  bool use_xent_derivative,
212  ComputationRequest *computation_request);
213 
217 
218 
223 
224 
229  public:
231  NnetDiscriminativeExampleWriter *writer);
232 
233  // This function accepts an example, and if possible, writes a merged example
234  // out. The ownership of the pointer 'a' is transferred to this class when
235  // you call this function.
236  void AcceptExample(NnetDiscriminativeExample *a);
237 
238  // This function announces to the class that the input has finished, so it
239  // should flush out any smaller-sized minibatches, as dictated by the config.
240  // This will be called in the destructor, but you can call it explicitly when
241  // all the input is done if you want to; it won't repeat anything if called
242  // twice. It also prints the stats.
243  void Finish();
244 
245  // returns a suitable exit status for a program.
246  int32 ExitStatus() { Finish(); return (num_egs_written_ > 0 ? 0 : 1); }
247 
249  private:
250  // called by Finish() and AcceptExample(). Merges, updates the stats, and
251  // writes. The 'egs' is non-const only because the egs are temporarily
252  // changed inside MergeDiscriminativeExamples. The pointer 'egs' is still owned
253  // by the caller.
254  void WriteMinibatch(std::vector<NnetDiscriminativeExample> *egs);
255 
256  bool finished_;
259  NnetDiscriminativeExampleWriter *writer_;
261 
262  // Note: the "key" into the egs is the first element of the vector.
263  typedef unordered_map<NnetDiscriminativeExample*,
264  std::vector<NnetDiscriminativeExample*>,
267  MapType eg_to_egs_;
268 };
269 
270 
271 } // namespace nnet3
272 } // namespace kaldi
273 
274 #endif // KALDI_NNET3_NNET_DISCRIMINATIVE_EXAMPLE_H_
void ShiftDiscriminativeExampleTimes(int32 frame_shift, const std::vector< std::string > &exclude_names, NnetDiscriminativeExample *eg)
Shifts the time-index t of everything in the input of "eg" by adding "frame_shift" to all "t" values -- b...
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This class is responsible for storing, and displaying in log messages, statistics about how examples ...
This class is responsible for arranging examples in groups that have the same structure (i...
static void MergeSupervision(const std::vector< const NnetChainSupervision *> &inputs, NnetChainSupervision *output)
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
This hashing object hashes just the structural aspects of the NnetExample without looking at the valu...
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
int32 GetDiscriminativeNnetExampleSize(const NnetDiscriminativeExample &a)
This function returns the 'size' of a discriminative example as defined for purposes of merging egs...
void Write(std::ostream &os, bool binary) const
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
void Swap(NnetDiscriminativeSupervision *other)
TableWriter< KaldiObjectHolder< NnetDiscriminativeExample > > NnetDiscriminativeExampleWriter
A class representing a vector.
Definition: kaldi-vector.h:406
std::vector< NnetIo > inputs
'inputs' contains the input to the network -- normally it has just one element called "input"...
std::vector< NnetDiscriminativeSupervision > outputs
'outputs' contains the sequence output supervision.
void MergeDiscriminativeExamples(bool compress, std::vector< NnetDiscriminativeExample > *input, NnetDiscriminativeExample *output)
void GetDiscriminativeComputationRequest(const Nnet &nnet, const NnetDiscriminativeExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetDiscriminativeExample and produces a ComputationRequest.
bool operator==(const NnetDiscriminativeSupervision &other) const
RandomAccessTableReader< KaldiObjectHolder< NnetDiscriminativeExample > > RandomAccessNnetDiscriminativeExampleReader
unordered_map< NnetDiscriminativeExample *, std::vector< NnetDiscriminativeExample * >, NnetDiscriminativeExampleStructureHasher, NnetDiscriminativeExampleStructureCompare > MapType
This comparator object compares just the structural aspects of the NnetDiscriminativeExample without ...
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
discriminative::DiscriminativeSupervision supervision
NnetDiscriminativeExample is like NnetExample, but specialized for sequence training.
SequentialTableReader< KaldiObjectHolder< NnetDiscriminativeExample > > SequentialNnetDiscriminativeExampleReader