nnet-discriminative-training.h
Go to the documentation of this file.
1 // nnet3/nnet-discriminative-training.h
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
21 #ifndef KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_
22 #define KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_
23 
24 #include "nnet3/nnet-example.h"
25 #include "nnet3/nnet-computation.h"
26 #include "nnet3/nnet-compute.h"
27 #include "nnet3/nnet-optimize.h"
29 #include "nnet3/nnet-training.h"
31 
32 namespace kaldi {
33 namespace nnet3 {
34 
38 
40 
// Default constructor: apply_deriv_weights starts out true, i.e. the
// per-frame derivative weights stored with the example are applied unless
// the "apply-deriv-weights" option registered below is set to false.
41  NnetDiscriminativeOptions(): apply_deriv_weights(true) { }
42 
// Registers this struct's options with `opts`.  It first forwards to the
// Register() methods of the nnet_config and discriminative_config members,
// then registers the boolean "apply-deriv-weights" option, bound to
// apply_deriv_weights.
43  void Register(OptionsItf *opts) {
44  nnet_config.Register(opts);
45  discriminative_config.Register(opts);
46  opts->Register("apply-deriv-weights", &apply_deriv_weights,
47  "If true, apply the per-frame derivative weights stored with "
48  "the example.");
49  }
50 };
51 
52 // This struct is used in multiple nnet training classes for keeping
53 // track of objective function values.
54 // Also see struct AccuracyInfo, in nnet-diagnostics.h.
57 
60 
62  current_phase(0) { }
63 
64  // This function updates the stats and, if the phase has just changed,
65  // prints a message indicating progress. The phase equals
66  // minibatch_counter / minibatches_per_phase. Its only function is to
67  // control how frequently we print logging messages.
68  void UpdateStats(const std::string &output_name,
69  const std::string &criterion,
70  int32 minibatches_per_phase,
71  int32 minibatch_counter,
73 
74  // Prints stats for the current phase.
75  void PrintStatsForThisPhase(const std::string &output_name,
76  const std::string &criterion,
77  int32 minibatches_per_phase) const;
78  // Prints total stats, and returns true if total stats' weight was nonzero.
79  bool PrintTotalStats(const std::string &output_name,
80  const std::string &criterion) const;
81 };
82 
83 
88  public:
90  const TransitionModel &tmodel,
91  const VectorBase<BaseFloat> &priors,
92  Nnet *nnet);
93 
94  // train on one minibatch.
95  void Train(const NnetDiscriminativeExample &eg);
96 
97  // Prints out the final stats, and returns true if there was a nonzero count.
98  bool PrintTotalStats() const;
99 
101  private:
102  void ProcessOutputs(const NnetDiscriminativeExample &eg,
103  NnetComputer *computer);
104 
106 
109 
111 
112  Nnet *delta_nnet_; // Only used if momentum != 0.0. nnet representing
113  // accumulated parameter-change (we'd call this
114  // gradient_nnet_, but due to natural-gradient update,
115  // it's better to consider it as a delta-parameter nnet).
117 
119 
120  // This code supports multiple output layers, even though in the
121  // normal case there will be just one output layer named "output".
122  // So we store the objective functions per output layer.
123  unordered_map<std::string, DiscriminativeObjectiveFunctionInfo, StringHasher> objf_info_;
124 };
125 
126 
127 } // namespace nnet3
128 } // namespace kaldi
129 
130 #endif // KALDI_NNET3_NNET_DISCRIMINATIVE_TRAINING_H_
131 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
This class enables you to do the compilation and optimization in one call, and also ensures that if t...
void Register(OptionsItf *opts)
Definition: nnet-training.h:63
kaldi::int32 int32
virtual void Register(const std::string &name, bool *ptr, const std::string &doc)=0
The two main classes defined in this header are struct ComputationRequest, which basically defines a ...
discriminative::DiscriminativeOptions discriminative_config
This class is for single-threaded discriminative training of neural nets.
discriminative::DiscriminativeObjectiveInfo stats
unordered_map< std::string, DiscriminativeObjectiveFunctionInfo, StringHasher > objf_info_
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
discriminative::DiscriminativeObjectiveInfo stats_this_phase
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
NnetDiscriminativeExample is like NnetExample, but specialized for sequence training.