NnetDiscriminativeComputeObjf Class Reference

This class computes objective-function values in nnet3 discriminative training, for diagnostics.

#include <nnet-discriminative-diagnostics.h>

Public Member Functions

 NnetDiscriminativeComputeObjf (const NnetComputeProbOptions &nnet_config, const discriminative::DiscriminativeOptions &discriminative_config, const TransitionModel &tmodel, const VectorBase< BaseFloat > &priors, const Nnet &nnet)
 
void Reset ()
 
void Compute (const NnetDiscriminativeExample &eg)
 
bool PrintTotalStats () const
 
const discriminative::DiscriminativeObjectiveInfo * GetObjective (const std::string &output_name) const
 
const Nnet & GetDeriv () const
 
 ~NnetDiscriminativeComputeObjf ()
 

Private Member Functions

void ProcessOutputs (const NnetDiscriminativeExample &eg, NnetComputer *computer)
 

Private Attributes

NnetComputeProbOptions nnet_config_
 
discriminative::DiscriminativeOptions discriminative_config_
 
const TransitionModel & tmodel_
 
CuVector< BaseFloat > log_priors_
 
const Nnet & nnet_
 
CachingOptimizingCompiler compiler_
 
Nnet * deriv_nnet_
 
int32 num_minibatches_processed_
 
unordered_map< std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher > objf_info_
 

Detailed Description

This class computes objective-function values in nnet3 discriminative training, for diagnostics.

It also supports computing model derivatives.

Definition at line 39 of file nnet-discriminative-diagnostics.h.

Constructor & Destructor Documentation

◆ NnetDiscriminativeComputeObjf()

NnetDiscriminativeComputeObjf ( const NnetComputeProbOptions & nnet_config,
const discriminative::DiscriminativeOptions & discriminative_config,
const TransitionModel & tmodel,
const VectorBase< BaseFloat > & priors,
const Nnet & nnet
)

Definition at line 28 of file nnet-discriminative-diagnostics.cc.

References NnetComputeProbOptions::compute_deriv, NnetDiscriminativeComputeObjf::deriv_nnet_, NnetDiscriminativeComputeObjf::log_priors_, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::nnet_config_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().

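NnetDiscriminativeComputeObjf::NnetDiscriminativeComputeObjf(
    const NnetComputeProbOptions &nnet_config,
    const discriminative::DiscriminativeOptions &discriminative_config,
    const TransitionModel &tmodel,
    const VectorBase<BaseFloat> &priors,
    const Nnet &nnet):
    nnet_config_(nnet_config),
    discriminative_config_(discriminative_config),
    tmodel_(tmodel),
    log_priors_(priors),
    nnet_(nnet),
    compiler_(...),  // initializer arguments are not shown in this listing
    deriv_nnet_(NULL),
    num_minibatches_processed_(0) {
  log_priors_.ApplyLog();  // priors are stored in the log domain
  if (nnet_config_.compute_deriv) {
    deriv_nnet_ = new Nnet(nnet_);
    ScaleNnet(0.0, deriv_nnet_);
    SetNnetAsGradient(deriv_nnet_);  // force simple update
  }
}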
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void SetNnetAsGradient(Nnet *nnet)
Sets nnet as gradient by Setting is_gradient_ to true and learning_rate_ to 1 for each UpdatableCompo...
Definition: nnet-utils.cc:292
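
The way the constructor prepares the derivative holder can be illustrated outside Kaldi. The following is a minimal sketch, not Kaldi code: `ToyNet`, `ScaleToyNet`, `SetToyNetAsGradient` and `MakeDerivHolder` are hypothetical stand-ins for `Nnet`, `ScaleNnet()`, `SetNnetAsGradient()` and the relevant part of the constructor body, and `std::unique_ptr` replaces the raw pointer only to keep the sketch leak-free.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for kaldi::nnet3::Nnet (an assumption, not the real class).
struct ToyNet {
  std::vector<float> params;
  bool is_gradient = false;
  float learning_rate = 0.01f;
};

// Plays the role of ScaleNnet(): multiplies every parameter by `scale`.
inline void ScaleToyNet(float scale, ToyNet *net) {
  for (float &p : net->params) p *= scale;
}

// Plays the role of SetNnetAsGradient(): marks the net as a gradient
// and sets its learning rate to 1.
inline void SetToyNetAsGradient(ToyNet *net) {
  net->is_gradient = true;
  net->learning_rate = 1.0f;
}

// Returns a zeroed copy of `model` with the same shape when derivatives
// are requested, and a null pointer otherwise.
inline std::unique_ptr<ToyNet> MakeDerivHolder(const ToyNet &model,
                                               bool compute_deriv) {
  if (!compute_deriv) return nullptr;
  std::unique_ptr<ToyNet> deriv(new ToyNet(model));
  ScaleToyNet(0.0f, deriv.get());
  SetToyNetAsGradient(deriv.get());
  return deriv;
}
```

Copying the model and then scaling it by zero gives an accumulator with the same structure as the model, which is presumably why the constructor takes this route rather than building an empty network.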

◆ ~NnetDiscriminativeComputeObjf()

~NnetDiscriminativeComputeObjf ( )

Definition at line 56 of file nnet-discriminative-diagnostics.cc.

References NnetDiscriminativeComputeObjf::deriv_nnet_.

NnetDiscriminativeComputeObjf::~NnetDiscriminativeComputeObjf() {
  delete deriv_nnet_;  // delete does nothing if pointer is NULL.
}

Member Function Documentation

◆ Compute()

void Compute ( const NnetDiscriminativeExample & eg )

Definition at line 69 of file nnet-discriminative-diagnostics.cc.

References NnetComputer::AcceptInputs(), CachingOptimizingCompiler::Compile(), NnetDiscriminativeComputeObjf::compiler_, NnetComputeProbOptions::compute_config, NnetComputeProbOptions::compute_deriv, NnetDiscriminativeComputeObjf::deriv_nnet_, NnetDiscriminativeComputeObjf::discriminative_config_, kaldi::nnet3::GetDiscriminativeComputationRequest(), NnetDiscriminativeExample::inputs, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::nnet_config_, NnetDiscriminativeComputeObjf::ProcessOutputs(), NnetComputer::Run(), and DiscriminativeOptions::xent_regularize.

void NnetDiscriminativeComputeObjf::Compute(const NnetDiscriminativeExample &eg) {
  bool need_model_derivative = nnet_config_.compute_deriv,
      store_component_stats = false;
  bool use_xent_regularization = (discriminative_config_.xent_regularize != 0.0),
      use_xent_derivative = false;

  ComputationRequest request;
  GetDiscriminativeComputationRequest(nnet_, eg,
                                      need_model_derivative,
                                      store_component_stats,
                                      use_xent_regularization, use_xent_derivative,
                                      &request);
  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
  NnetComputer computer(nnet_config_.compute_config, *computation,
                        nnet_, deriv_nnet_);
  // give the inputs to the computer object.
  computer.AcceptInputs(nnet_, eg.inputs);
  computer.Run();  // forward pass
  this->ProcessOutputs(eg, &computer);
  if (nnet_config_.compute_deriv)
    computer.Run();  // backward pass, only when derivatives were requested
}
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
void GetDiscriminativeComputationRequest(const Nnet &nnet, const NnetDiscriminativeExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetDiscriminativeExample and produces a ComputationRequest.

◆ GetDeriv()

const Nnet & GetDeriv ( ) const

Definition at line 50 of file nnet-discriminative-diagnostics.cc.

References NnetDiscriminativeComputeObjf::deriv_nnet_, and KALDI_ERR.

const Nnet &NnetDiscriminativeComputeObjf::GetDeriv() const {
  if (deriv_nnet_ == NULL)
    KALDI_ERR << "GetDeriv() called when no derivatives were requested.";
  return *deriv_nnet_;
}
#define KALDI_ERR
Definition: kaldi-error.h:147
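
The guard in GetDeriv() follows a common "checked accessor" pattern. The sketch below is a standalone illustration under stated assumptions: `ToyDerivOwner` is a hypothetical type, an `int` stands in for the derivative `Nnet`, and `std::runtime_error` stands in for whatever `KALDI_ERR` raises in a given Kaldi build.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical owner of an optional derivative object.
struct ToyDerivOwner {
  const int *deriv = nullptr;  // stand-in for Nnet *deriv_nnet_

  // Returns the derivative, or fails loudly if none was requested.
  const int &GetDeriv() const {
    if (deriv == nullptr)
      throw std::runtime_error(
          "GetDeriv() called when no derivatives were requested.");
    return *deriv;
  }
};
```

Callers that did not set `compute_deriv` in the options should therefore not call GetDeriv() at all, since the accessor has no "empty" return value.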

◆ GetObjective()

const discriminative::DiscriminativeObjectiveInfo * GetObjective ( const std::string &  output_name) const

Definition at line 197 of file nnet-discriminative-diagnostics.cc.

References NnetDiscriminativeComputeObjf::objf_info_.

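const discriminative::DiscriminativeObjectiveInfo*
NnetDiscriminativeComputeObjf::GetObjective(const std::string &output_name) const {
  unordered_map<std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher>::const_iterator
      iter = objf_info_.find(output_name);
  if (iter != objf_info_.end())
    return &(iter->second);
  else
    return NULL;  // no statistics were accumulated for this output name
}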
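
GetObjective() uses a find-or-NULL lookup, so callers must check the result before dereferencing it. A minimal standalone sketch of the same pattern follows; `ToyObjectiveInfo`, `ToyObjfMap` and `FindObjective` are hypothetical names, and the struct is a reduced stand-in for `discriminative::DiscriminativeObjectiveInfo`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical, reduced stand-in for DiscriminativeObjectiveInfo.
struct ToyObjectiveInfo {
  double tot_t_weighted = 0.0;
  double tot_objf = 0.0;
};

typedef std::unordered_map<std::string, ToyObjectiveInfo> ToyObjfMap;

// Returns a pointer to the stats for `output_name`, or nullptr if that
// output has no entry. The pointer refers to an element owned by the map.
inline const ToyObjectiveInfo *FindObjective(const ToyObjfMap &objf_info,
                                             const std::string &output_name) {
  ToyObjfMap::const_iterator iter = objf_info.find(output_name);
  return iter != objf_info.end() ? &(iter->second) : nullptr;
}
```

Because the returned pointer refers to storage inside the map, it should be treated as invalid once the map is cleared, which in the class happens in Reset().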

◆ PrintTotalStats()

bool PrintTotalStats ( ) const

Definition at line 157 of file nnet-discriminative-diagnostics.cc.

References DiscriminativeOptions::criterion, NnetDiscriminativeComputeObjf::discriminative_config_, Nnet::GetNodeIndex(), KALDI_ASSERT, KALDI_LOG, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::objf_info_, DiscriminativeObjectiveInfo::PrintAll(), DiscriminativeObjectiveInfo::tot_l2_term, DiscriminativeObjectiveInfo::tot_t_weighted, and DiscriminativeObjectiveInfo::TotalObjf().

bool NnetDiscriminativeComputeObjf::PrintTotalStats() const {
  bool ans = false;
  unordered_map<std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher>::const_iterator
      iter, end;
  iter = objf_info_.begin();
  end = objf_info_.end();
  for (; iter != end; ++iter) {
    const std::string &name = iter->first;
    int32 node_index = nnet_.GetNodeIndex(name);
    KALDI_ASSERT(node_index >= 0);
    const discriminative::DiscriminativeObjectiveInfo &info = iter->second;
    BaseFloat tot_weight = info.tot_t_weighted;
    BaseFloat tot_objective = info.TotalObjf(
        discriminative_config_.criterion);

    info.PrintAll(discriminative_config_.criterion);

    if (info.tot_l2_term == 0.0) {
      KALDI_LOG << "Overall " << discriminative_config_.criterion
                << " objective for '"
                << name << "' is "
                << (tot_objective / tot_weight)
                << " per frame, "
                << "over " << tot_weight << " frames.";
    } else {
      KALDI_LOG << "Overall " << discriminative_config_.criterion
                << " objective for '"
                << name << "' is "
                << (tot_objective / tot_weight)
                << " + " << (info.tot_l2_term / tot_weight)
                << " per frame, "
                << "over " << tot_weight << " frames.";
    }

    if (tot_weight > 0)
      ans = true;  // at least one output saw some data
  }
  return ans;
}
kaldi::int32 int32
float BaseFloat
Definition: kaldi-types.h:29
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
#define KALDI_LOG
Definition: kaldi-error.h:153
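
The reporting logic in PrintTotalStats() reduces to two small computations: a per-frame objective (total objective divided by total frame weight) and a flag that is true once any output has positive weight. The sketch below shows both in isolation; `ToyStats`, `PerFrameObjf` and `AnyDataSeen` are hypothetical names, and the struct keeps only the two fields needed here.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical, reduced stand-in for DiscriminativeObjectiveInfo.
struct ToyStats {
  double tot_t_weighted = 0.0;  // total (weighted) number of frames
  double tot_objf = 0.0;        // total objective over those frames
};

// Per-frame objective, as printed in the log message. Like the listing,
// this divides without a guard, so it is not finite when no frames were seen.
inline double PerFrameObjf(const ToyStats &s) {
  return s.tot_objf / s.tot_t_weighted;
}

// Mirrors the return value: true if at least one output has positive weight.
inline bool AnyDataSeen(const std::unordered_map<std::string, ToyStats> &stats) {
  bool ans = false;
  for (const auto &kv : stats)
    if (kv.second.tot_t_weighted > 0) ans = true;
  return ans;
}
```

For example, a total objective of -50 accumulated over 200 frames is reported as -0.25 per frame. A `false` return value signals that nothing useful was accumulated, which a calling program may want to treat as a failure.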

◆ ProcessOutputs()

void ProcessOutputs ( const NnetDiscriminativeExample & eg,
NnetComputer * computer
)
private

Definition at line 92 of file nnet-discriminative-diagnostics.cc.

References NnetComputer::AcceptInput(), NnetComputeProbOptions::compute_deriv, kaldi::discriminative::ComputeDiscriminativeObjfAndDeriv(), NnetDiscriminativeComputeObjf::discriminative_config_, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, NnetDiscriminativeComputeObjf::log_priors_, NnetDiscriminativeSupervision::name, NnetDiscriminativeComputeObjf::nnet_, NnetDiscriminativeComputeObjf::nnet_config_, NnetDiscriminativeComputeObjf::num_minibatches_processed_, NnetDiscriminativeComputeObjf::objf_info_, NnetDiscriminativeExample::outputs, CuMatrix< Real >::Resize(), NnetDiscriminativeSupervision::supervision, NnetDiscriminativeComputeObjf::tmodel_, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t_weighted, kaldi::TraceMatMat(), and DiscriminativeOptions::xent_regularize.

Referenced by NnetDiscriminativeComputeObjf::Compute().

void NnetDiscriminativeComputeObjf::ProcessOutputs(
    const NnetDiscriminativeExample &eg,
    NnetComputer *computer) {
  // There will normally be just one output here, named 'output',
  // but the code is more general than this.
  std::vector<NnetDiscriminativeSupervision>::const_iterator iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    const NnetDiscriminativeSupervision &sup = *iter;
    int32 node_index = nnet_.GetNodeIndex(sup.name);
    if (node_index < 0 ||
        !nnet_.IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;

    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);

    bool use_xent = (discriminative_config_.xent_regularize != 0.0);
    std::string xent_name = sup.name + "-xent";  // typically "output-xent".
    CuMatrix<BaseFloat> nnet_output_deriv, xent_deriv;

    if (nnet_config_.compute_deriv)
      nnet_output_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                               kUndefined);

    if (use_xent)
      xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                        kUndefined);

    if (objf_info_.count(sup.name) == 0)
      objf_info_.insert(std::make_pair(sup.name,
          discriminative::DiscriminativeObjectiveInfo(discriminative_config_)));

    discriminative::DiscriminativeObjectiveInfo *stats = &(objf_info_[sup.name]);

    discriminative::ComputeDiscriminativeObjfAndDeriv(
        discriminative_config_, tmodel_, log_priors_,
        sup.supervision, nnet_output,
        stats,
        (nnet_config_.compute_deriv ? &nnet_output_deriv : NULL),
        (use_xent ? &xent_deriv : NULL));

    if (nnet_config_.compute_deriv)
      computer->AcceptInput(sup.name, &nnet_output_deriv);

    if (use_xent) {
      if (objf_info_.count(xent_name) == 0)
        objf_info_.insert(std::make_pair(xent_name,
            discriminative::DiscriminativeObjectiveInfo(discriminative_config_)));
      discriminative::DiscriminativeObjectiveInfo &xent_stats = objf_info_[xent_name];

      // this block computes the cross-entropy objective.
      const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(xent_name);
      // at this point, xent_deriv is posteriors derived from the numerator
      // computation.  note, xent_deriv has a factor of 'supervision.weight',
      // but so does tot_weight.
      BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
      xent_stats.tot_t_weighted += stats->tot_t_weighted;
      xent_stats.tot_objf += xent_objf;
    }

    num_minibatches_processed_++;
  }
}
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
void ComputeDiscriminativeObjfAndDeriv(const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
This function does forward-backward on the numerator and denominator lattices and computes derivates ...
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
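
The cross-entropy objective above is computed as `TraceMatMat(xent_output, xent_deriv, kTrans)`, that is tr(A Bᵀ), which for two equally sized matrices equals the sum of their elementwise products. The sketch below spells that sum out with plain `std::vector`s; `ToyMatrix` and `TraceMatMatTrans` are hypothetical names, and rows are assumed to index frames.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix: one inner vector per frame.
typedef std::vector<std::vector<float> > ToyMatrix;

// Computes tr(A * B^T) for equally sized A and B, which equals
// the sum over all (t, j) of A(t, j) * B(t, j).
inline float TraceMatMatTrans(const ToyMatrix &a, const ToyMatrix &b) {
  assert(a.size() == b.size());
  float sum = 0.0f;
  for (std::size_t t = 0; t < a.size(); ++t) {
    assert(a[t].size() == b[t].size());
    for (std::size_t j = 0; j < a[t].size(); ++j)
      sum += a[t][j] * b[t][j];
  }
  return sum;
}
```

With network outputs {{-1, -2}, {-3, -4}} and posteriors {{1, 0}, {0.5, 0.5}}, the result is -1 + 0 - 1.5 - 2 = -4.5. If the cross-entropy output holds log-probabilities (an assumption here, not stated in the listing), this is the posterior-weighted log-likelihood summed over frames.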

◆ Reset()

void Reset ( )

Definition at line 60 of file nnet-discriminative-diagnostics.cc.

References NnetDiscriminativeComputeObjf::deriv_nnet_, NnetDiscriminativeComputeObjf::num_minibatches_processed_, NnetDiscriminativeComputeObjf::objf_info_, kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::SetNnetAsGradient().

void NnetDiscriminativeComputeObjf::Reset() {
  num_minibatches_processed_ = 0;
  objf_info_.clear();
  if (deriv_nnet_) {
    ScaleNnet(0.0, deriv_nnet_);
    SetNnetAsGradient(deriv_nnet_);
  }
}

Member Data Documentation

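◆ compiler_

CachingOptimizingCompiler compiler_
private

◆ deriv_nnet_

Nnet* deriv_nnet_
private

◆ discriminative_config_

discriminative::DiscriminativeOptions discriminative_config_
private

◆ log_priors_

CuVector<BaseFloat> log_priors_
private

◆ nnet_

const Nnet& nnet_
private

◆ nnet_config_

NnetComputeProbOptions nnet_config_
private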

◆ num_minibatches_processed_

int32 num_minibatches_processed_
private

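◆ objf_info_

unordered_map<std::string, discriminative::DiscriminativeObjectiveInfo, StringHasher> objf_info_
private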

◆ tmodel_

const TransitionModel& tmodel_
private

The documentation for this class was generated from the following files:

nnet-discriminative-diagnostics.h
nnet-discriminative-diagnostics.cc