nnet-chain-example.h
Go to the documentation of this file.
1 // nnet3/nnet-chain-example.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_CHAIN_EXAMPLE_H_
21 #define KALDI_NNET3_NNET_CHAIN_EXAMPLE_H_
22 
23 #include "nnet3/nnet-nnet.h"
24 #include "nnet3/nnet-computation.h"
25 #include "hmm/posterior.h"
26 #include "util/table-types.h"
27 #include "nnet3/nnet-example.h"
29 #include "chain/chain-supervision.h"
30 
31 namespace kaldi {
32 namespace nnet3 {
33 
34 
35 // For regular setups we use struct 'NnetIo' as the output. For the 'chain'
36 // models, the output supervision is a little more complex as it involves a
37 // lattice and we need to do forward-backward, so we use a separate struct for
38 // it. The 'output' name means that it pertains to the output of the network,
39 // as opposed to the features which pertain to the input of the network. It
40 // actually stores the lattice-like supervision information at the output of the
41 // network (which imposes constraints on which frames each phone can be active
42 // on).
46  std::string name;
47 
60  std::vector<Index> indexes;
61 
62 
64  chain::Supervision supervision;
65 
78 
79  // Use default assignment operator
80 
82 
89  NnetChainSupervision(const std::string &name,
90  const chain::Supervision &supervision,
91  const VectorBase<BaseFloat> &deriv_weights,
92  int32 first_frame,
93  int32 frame_skip);
94 
96 
97  void Write(std::ostream &os, bool binary) const;
98 
99  void Read(std::istream &is, bool binary);
100 
101  void Swap(NnetChainSupervision *other);
102 
103  void CheckDim() const;
104 
105  bool operator == (const NnetChainSupervision &other) const;
106 };
107 
111 
115  std::vector<NnetIo> inputs;
116 
119  std::vector<NnetChainSupervision> outputs;
120 
121  void Write(std::ostream &os, bool binary) const;
122  void Read(std::istream &is, bool binary);
123 
124  void Swap(NnetChainExample *other);
125 
126  // Compresses the input features (if not compressed)
127  void Compress();
128 
130 
131  NnetChainExample(const NnetChainExample &other);
132 
133  bool operator == (const NnetChainExample &other) const {
  // Two examples are equal iff both their 'inputs' vectors and their
  // 'outputs' vectors compare equal.
134  return inputs == other.inputs && outputs == other.outputs;
135  }
136 };
137 
142  size_t operator () (const NnetChainExample &eg) const noexcept;
143  // We also provide a version of this that works from pointers; it just
  // dereferences 'eg' and forwards to the by-reference overload. 'eg' is
  // dereferenced unconditionally, so it must not be NULL.
144  size_t operator () (const NnetChainExample *eg) const noexcept {
145  return (*this)(*eg);
146  }
147 };
148 
149 
153  bool operator () (const NnetChainExample &a,
154  const NnetChainExample &b) const;
155  // We also provide a version of this that works from pointers; it just
  // dereferences 'a' and 'b' and forwards to the by-reference overload. Both
  // pointers are dereferenced unconditionally, so neither may be NULL.
156  bool operator () (const NnetChainExample *a,
157  const NnetChainExample *b) const {
158  return (*this)(*a, *b);
159  }
160 };
161 
162 
163 
// Merges a list of NnetChainExample objects ('input') into a single one,
// which is written to '*output'.
// NOTE(review): 'input' is passed as a non-const pointer, so the input
// examples may be modified by this call -- confirm against the definition.
// NOTE(review): 'compress' presumably controls whether the merged example is
// compressed (cf. NnetChainExample::Compress()) -- confirm.
172 void MergeChainExamples(bool compress,
173  std::vector<NnetChainExample> *input,
174  NnetChainExample *output);
175 
176 
177 
// Shifts the time-index t of everything in the input of 'eg' by adding
// 'frame_shift' to all 't' values; 'eg' is modified in place.
// NOTE(review): 'exclude_names' presumably lists the names of inputs that are
// left unshifted -- confirm against the definition.
192 void ShiftChainExampleTimes(int32 frame_shift,
193  const std::vector<std::string> &exclude_names,
194  NnetChainExample *eg);
195 
// Takes a NnetChainExample ('eg') and produces a ComputationRequest, which is
// written to '*computation_request'.
// NOTE(review): the effect of the four boolean flags (need_model_derivative,
// store_component_stats, use_xent_regularization, use_xent_derivative) is not
// shown in this header -- see the definition before relying on them.
210 void GetChainComputationRequest(const Nnet &nnet,
211  const NnetChainExample &eg,
212  bool need_model_derivative,
213  bool store_component_stats,
214  bool use_xent_regularization,
215  bool use_xent_derivative,
216  ComputationRequest *computation_request);
217 
218 
219 
223 
224 
229 
230 
235  public:
237  NnetChainExampleWriter *writer);
238 
239  // This function accepts an example, and if possible, writes a merged example
240  // out. The ownership of the pointer 'a' is transferred to this class when
241  // you call this function.
242  void AcceptExample(NnetChainExample *a);
243 
244  // This function announces to the class that the input has finished, so it
245  // should flush out any smaller-sized minibatches, as dictated by the config.
246  // This will be called in the destructor, but you can call it explicitly when
247  // all the input is done if you want to; it won't repeat anything if called
248  // twice. It also prints the stats.
249  void Finish();
250 
251  // Returns a suitable exit status for a program. Calls Finish() first, then
  // returns 0 if num_egs_written_ > 0 (i.e. at least one example was written)
  // and 1 otherwise.
252  int32 ExitStatus() { Finish(); return (num_egs_written_ > 0 ? 0 : 1); }
253 
// The destructor calls Finish(), so the caller does not have to call Finish()
// explicitly before the object is destroyed.
254  ~ChainExampleMerger() { Finish(); };
255  private:
256  // Called by Finish() and AcceptExample(). Merges, updates the stats, and
257  // writes. The 'egs' is non-const only because the egs are temporarily
258  // changed inside MergeChainExamples. The pointer 'egs' is still owned
259  // by the caller.
260  void WriteMinibatch(std::vector<NnetChainExample> *egs);
261 
262  bool finished_;
265  NnetChainExampleWriter *writer_;
267 
268  // Note: the "key" into the egs is the first element of the vector.
269  typedef unordered_map<NnetChainExample*,
270  std::vector<NnetChainExample*>,
273 MapType eg_to_egs_;
274 };
275 
276 
277 
278 } // namespace nnet3
279 } // namespace kaldi
280 
281 #endif // KALDI_NNET3_NNET_CHAIN_EXAMPLE_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
Vector< BaseFloat > deriv_weights
This is a vector of per-frame weights, required to be between 0 and 1, that is applied to the derivat...
This class is responsible for storing, and displaying in log messages, statistics about how examples ...
int32 GetChainNnetExampleSize(const NnetChainExample &a)
This function returns the 'size' of a chain example as defined for purposes of merging egs...
chain::Supervision supervision
The supervision object, containing the FST.
std::vector< NnetIo > inputs
'inputs' contains the input to the network -- normally it has just one element called "input"...
void MergeChainExamples(bool compress, std::vector< NnetChainExample > *input, NnetChainExample *output)
This function merges a list of NnetChainExample objects into a single one -- intended to be used when ...
void ShiftChainExampleTimes(int32 frame_shift, const std::vector< std::string > &exclude_names, NnetChainExample *eg)
Shifts the time-index t of everything in the input of "eg" by adding "frame_shift" to all "t" values -- b...
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
void Write(std::ostream &os, bool binary) const
TableWriter< KaldiObjectHolder< NnetChainExample > > NnetChainExampleWriter
void Swap(NnetChainSupervision *other)
std::string name
the name of the output in the neural net; in simple setups it will just be "output".
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
This hashing object hashes just the structural aspects of the NnetExample without looking at the valu...
std::vector< NnetChainSupervision > outputs
'outputs' contains the chain output supervision.
void Read(std::istream &is, bool binary)
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
unordered_map< NnetChainExample *, std::vector< NnetChainExample * >, NnetChainExampleStructureHasher, NnetChainExampleStructureCompare > MapType
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
SequentialTableReader< KaldiObjectHolder< NnetChainExample > > SequentialNnetChainExampleReader
RandomAccessTableReader< KaldiObjectHolder< NnetChainExample > > RandomAccessNnetChainExampleReader
const ExampleMergingConfig & config_
NnetChainExampleWriter * writer_
bool operator==(const NnetChainSupervision &other) const
A class representing a vector.
Definition: kaldi-vector.h:406
This comparator object compares just the structural aspects of the NnetChainExample without looking a...
std::vector< Index > indexes
The indexes that the output corresponds to.
This class is responsible for arranging examples in groups that have the same structure (i...
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void GetChainComputationRequest(const Nnet &nnet, const NnetChainExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetChainExample and produces a ComputationRequest.