nnet-chain-training.h
Go to the documentation of this file.
1 // nnet3/nnet-chain-training.h
2 
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #ifndef KALDI_NNET3_NNET_CHAIN_TRAINING_H_
21 #define KALDI_NNET3_NNET_CHAIN_TRAINING_H_
22 
23 #include "nnet3/nnet-example.h"
24 #include "nnet3/nnet-computation.h"
25 #include "nnet3/nnet-compute.h"
26 #include "nnet3/nnet-optimize.h"
28 #include "nnet3/nnet-training.h"
29 #include "chain/chain-training.h"
30 #include "chain/chain-den-graph.h"
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
37  chain::ChainTrainingOptions chain_config;
39  NnetChainTrainingOptions(): apply_deriv_weights(true) { }
40 
41  void Register(OptionsItf *opts) {  // Registers all options held by this struct with 'opts'.
42  nnet_config.Register(opts);  // forward to the nnet_config member's own Register()
43  chain_config.Register(opts);  // forward to the chain_config member's own Register()
44  opts->Register("apply-deriv-weights", &apply_deriv_weights,  // bound to the apply_deriv_weights member
45  "If true, apply the per-frame derivative weights stored with "
46  "the example");
47  }
48 };
49 
50 
56  public:
58  const fst::StdVectorFst &den_fst,
59  Nnet *nnet);
60 
61  // train on one minibatch.
62  void Train(const NnetChainExample &eg);
63 
64  // Prints out the final stats, and returns true if there was a nonzero count.
65  bool PrintTotalStats() const;
66 
68  private:
69  // The internal function for doing one step of conventional SGD training.
70  void TrainInternal(const NnetChainExample &eg,
71  const NnetComputation &computation);
72 
73  // The internal function for doing one step of backstitch training. Depending
74  // on whether is_backstitch_step1 is true, it is either the first
75  // (backward) step or the second (forward) step of backstitch.
76  void TrainInternalBackstitch(const NnetChainExample &eg,
77  const NnetComputation &computation,
78  bool is_backstitch_step1);
79 
80  void ProcessOutputs(bool is_backstitch_step2, const NnetChainExample &eg,
81  NnetComputer *computer);
82 
84 
85  chain::DenominatorGraph den_graph_;
87  Nnet *delta_nnet_; // stores the change to the parameters on each training
88  // iteration.
90 
91  // This code supports multiple output layers, even though in the
92  // normal case there will be just one output layer named "output".
93  // So we store the objective functions per output layer.
95 
96  // stats for max-change.
98 
99  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_;
100 
101  // This value is used in backstitch training when we need to ensure
102  // consistent dropout masks. It's set to a value derived from rand()
103  // when the class is initialized.
105 };
106 
107 
108 } // namespace nnet3
109 } // namespace kaldi
110 
111 #endif // KALDI_NNET3_NNET_CHAIN_TRAINING_H_
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
void Register(OptionsItf *opts)
Definition: nnet-training.h:63
const NnetChainTrainingOptions opts_
kaldi::int32 int32
fst::StdVectorFst StdVectorFst
chain::DenominatorGraph den_graph_
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
CachingOptimizingCompiler compiler_
NnetChainExample is like NnetExample, but specialized for lattice-free (chain) training.
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
chain::ChainTrainingOptions chain_config
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_
This class is for single-threaded training of neural nets using the 'chain' model.