nnet-discriminative-diagnostics.cc
Go to the documentation of this file.
1 // nnet3/nnet-discriminative-diagnostics.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 // Copyright 2014-2015 Vimal Manohar
5 
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABILITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
20 
22 #include "nnet3/nnet-utils.h"
24 
25 namespace kaldi {
26 namespace nnet3 {
27 
29  const NnetComputeProbOptions &nnet_config,
30  const discriminative::DiscriminativeOptions &discriminative_config,
31  const TransitionModel &tmodel,
32  const VectorBase<BaseFloat> &priors,
33  const Nnet &nnet):
34  nnet_config_(nnet_config),
35  discriminative_config_(discriminative_config),
36  tmodel_(tmodel),
37  log_priors_(priors),
38  nnet_(nnet),
39  compiler_(nnet, nnet_config_.optimize_config),
40  deriv_nnet_(NULL),
41  num_minibatches_processed_(0) {
42  log_priors_.ApplyLog();
44  deriv_nnet_ = new Nnet(nnet_);
45  ScaleNnet(0.0, deriv_nnet_);
46  SetNnetAsGradient(deriv_nnet_); // force simple update
47  }
48 }
49 
51  if (deriv_nnet_ == NULL)
52  KALDI_ERR << "GetDeriv() called when no derivatives were requested.";
53  return *deriv_nnet_;
54 }
55 
57  delete deriv_nnet_; // delete does nothing if pointer is NULL.
58 }
59 
62  objf_info_.clear();
63  if (deriv_nnet_) {
64  ScaleNnet(0.0, deriv_nnet_);
66  }
67 }
68 
70  bool need_model_derivative = nnet_config_.compute_deriv,
71  store_component_stats = false;
72  bool use_xent_regularization = (discriminative_config_.xent_regularize != 0.0),
73  use_xent_derivative = false;
74 
75  ComputationRequest request;
77  need_model_derivative,
78  store_component_stats,
79  use_xent_regularization, use_xent_derivative,
80  &request);
81  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
82  NnetComputer computer(nnet_config_.compute_config, *computation,
84  // give the inputs to the computer object.
85  computer.AcceptInputs(nnet_, eg.inputs);
86  computer.Run();
87  this->ProcessOutputs(eg, &computer);
89  computer.Run();
90 }
91 
93  const NnetDiscriminativeExample &eg,
94  NnetComputer *computer) {
95  // There will normally be just one output here, named 'output',
96  // but the code is more general than this.
97  std::vector<NnetDiscriminativeSupervision>::const_iterator iter = eg.outputs.begin(),
98  end = eg.outputs.end();
99  for (; iter != end; ++iter) {
100  const NnetDiscriminativeSupervision &sup = *iter;
101  int32 node_index = nnet_.GetNodeIndex(sup.name);
102  if (node_index < 0 ||
103  !nnet_.IsOutputNode(node_index))
104  KALDI_ERR << "Network has no output named " << sup.name;
105 
106  const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
107 
108  bool use_xent = (discriminative_config_.xent_regularize != 0.0);
109  std::string xent_name = sup.name + "-xent"; // typically "output-xent".
110  CuMatrix<BaseFloat> nnet_output_deriv, xent_deriv;
111 
113  nnet_output_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
114  kUndefined);
115 
116  if (use_xent)
117  xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
118  kUndefined);
119 
120  if (objf_info_.count(sup.name) == 0)
121  objf_info_.insert(std::make_pair(sup.name,
123 
125 
128  sup.supervision, nnet_output,
129  stats,
131  &nnet_output_deriv : NULL),
132  (use_xent ? &xent_deriv : NULL));
133 
135  computer->AcceptInput(sup.name, &nnet_output_deriv);
136 
137  if (use_xent) {
138  if (objf_info_.count(xent_name) == 0)
139  objf_info_.insert(std::make_pair(xent_name,
142 
143  // this block computes the cross-entropy objective.
144  const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(xent_name);
145  // at this point, xent_deriv is posteriors derived from the numerator
146  // computation. note, xent_deriv has a factor of 'supervision.weight',
147  // but so does tot_weight.
148  BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
149  xent_stats.tot_t_weighted += stats->tot_t_weighted;
150  xent_stats.tot_objf += xent_objf;
151  }
152 
154  }
155 }
156 
158  bool ans = false;
159  unordered_map<std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher>::const_iterator
160  iter, end;
161  iter = objf_info_.begin();
162  end = objf_info_.end();
163  for (; iter != end; ++iter) {
164  const std::string &name = iter->first;
165  int32 node_index = nnet_.GetNodeIndex(name);
166  KALDI_ASSERT(node_index >= 0);
167  const discriminative::DiscriminativeObjectiveInfo &info = iter->second;
168  BaseFloat tot_weight = info.tot_t_weighted;
169  BaseFloat tot_objective = info.TotalObjf(
171 
173 
174  if (info.tot_l2_term == 0.0) {
176  << " objective for '"
177  << name << "' is "
178  << (tot_objective / tot_weight)
179  << " per frame, "
180  << "over " << tot_weight << " frames.";
181  } else {
183  << " objective for '"
184  << name << "' is "
185  << (tot_objective / tot_weight)
186  << " + " << (info.tot_l2_term / tot_weight)
187  << " per frame, "
188  << "over " << tot_weight << " frames.";
189  }
190 
191  if (tot_weight > 0)
192  ans = true;
193  }
194  return ans;
195 }
196 
198  const std::string &output_name) const {
199  unordered_map<std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher>::const_iterator
200  iter = objf_info_.find(output_name);
201  if (iter != objf_info_.end())
202  return &(iter->second);
203  else
204  return NULL;
205 }
206 
207 } // namespace nnet3
208 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
unordered_map< std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher > objf_info_
void PrintAll(const std::string &criterion) const
kaldi::int32 int32
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
This file contains some miscellaneous functions dealing with class Nnet.
void SetNnetAsGradient(Nnet *nnet)
Sets nnet as gradient by Setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableCompo...
Definition: nnet-utils.cc:292
const discriminative::DiscriminativeObjectiveInfo * GetObjective(const std::string &output_name) const
void AcceptInput(const std::string &node_name, CuMatrix< BaseFloat > *input)
e.g.
double TotalObjf(const std::string &criterion) const
const CuMatrixBase< BaseFloat > & GetOutput(const std::string &node_name)
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
void AcceptInputs(const Nnet &nnet, const std::vector< NnetIo > &io)
This convenience function calls AcceptInput() in turn on all the inputs in the training example...
NnetDiscriminativeComputeObjf(const NnetComputeProbOptions &nnet_config, const discriminative::DiscriminativeOptions &discriminative_config, const TransitionModel &tmodel, const VectorBase< BaseFloat > &priors, const Nnet &nnet)
void ComputeDiscriminativeObjfAndDeriv(const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
This function does forward-backward on the numerator and denominator lattices and computes derivatives ...
#define KALDI_ERR
Definition: kaldi-error.h:147
discriminative::DiscriminativeOptions discriminative_config_
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
Matrix for CUDA computing.
Definition: matrix-common.h:69
class NnetComputer is responsible for executing the computation described in the "computation" object...
Definition: nnet-compute.h:59
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
std::vector< NnetIo > inputs
'inputs' contains the input to the network -- normally it has just one element called "input"...
std::vector< NnetDiscriminativeSupervision > outputs
'outputs' contains the sequence output supervision.
void Compute(const NnetDiscriminativeExample &eg)
void GetDiscriminativeComputationRequest(const Nnet &nnet, const NnetDiscriminativeExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetDiscriminativeExample and produces a ComputationRequest.
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
discriminative::DiscriminativeSupervision supervision
#define KALDI_LOG
Definition: kaldi-error.h:153
void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type=kSetZero, MatrixStrideType stride_type=kDefaultStride)
Allocate the memory.
Definition: cu-matrix.cc:50
NnetDiscriminativeExample is like NnetExample, but specialized for sequence training.
void ProcessOutputs(const NnetDiscriminativeExample &eg, NnetComputer *computer)
void Run()
This does either the forward or backward computation, depending when it is called (in a typical compu...