NnetDiscriminativeTrainer Class Reference

This class is for single-threaded discriminative training of neural nets. More...

#include <nnet-discriminative-training.h>

Collaboration diagram for NnetDiscriminativeTrainer:

Public Member Functions

 NnetDiscriminativeTrainer (const NnetDiscriminativeOptions &config, const TransitionModel &tmodel, const VectorBase< BaseFloat > &priors, Nnet *nnet)
 
void Train (const NnetDiscriminativeExample &eg)
 
bool PrintTotalStats () const
 
 ~NnetDiscriminativeTrainer ()
 

Private Member Functions

void ProcessOutputs (const NnetDiscriminativeExample &eg, NnetComputer *computer)
 

Private Attributes

const NnetDiscriminativeOptions opts_
 
const TransitionModel & tmodel_
 
CuVector< BaseFloat > log_priors_
 
Nnet * nnet_
 
Nnet * delta_nnet_
 
CachingOptimizingCompiler compiler_
 
int32 num_minibatches_processed_
 
unordered_map< std::string, DiscriminativeObjectiveFunctionInfo, StringHasher > objf_info_
 

Detailed Description

This class is for single-threaded discriminative training of neural nets.

Definition at line 87 of file nnet-discriminative-training.h.

Constructor & Destructor Documentation

◆ NnetDiscriminativeTrainer()

NnetDiscriminativeTrainer ( const NnetDiscriminativeOptions &  config,
const TransitionModel &  tmodel,
const VectorBase< BaseFloat > &  priors,
Nnet *  nnet 
)

Definition at line 27 of file nnet-discriminative-training.cc.

References NnetDiscriminativeTrainer::compiler_, Nnet::Copy(), NnetDiscriminativeTrainer::delta_nnet_, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, NnetDiscriminativeTrainer::log_priors_, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetDiscriminativeTrainer::nnet_, NnetDiscriminativeOptions::nnet_config, Input::Open(), NnetTrainerOptions::read_cache, CachingOptimizingCompiler::ReadCache(), kaldi::nnet3::ScaleNnet(), Input::Stream(), NnetTrainerOptions::zero_component_stats, and kaldi::nnet3::ZeroComponentStats().

31  :
32  opts_(opts), tmodel_(tmodel), log_priors_(priors),
33  nnet_(nnet),
36  if (opts.nnet_config.zero_component_stats)
37  ZeroComponentStats(nnet);
38  if (opts.nnet_config.momentum == 0.0 &&
39  opts.nnet_config.max_param_change == 0.0) {
40  delta_nnet_= NULL;
41  } else {
42  KALDI_ASSERT(opts.nnet_config.momentum >= 0.0 &&
43  opts.nnet_config.max_param_change >= 0.0);
44  delta_nnet_ = nnet_->Copy();
45  ScaleNnet(0.0, delta_nnet_);
46  }
47  if (opts.nnet_config.read_cache != "") {
48  bool binary;
49  Input ki;
50  if (ki.Open(opts.nnet_config.read_cache, &binary)) {
51  compiler_.ReadCache(ki.Stream(), binary);
52  KALDI_LOG << "Read computation cache from "
53  << opts.nnet_config.read_cache;
54  } else {
55  KALDI_WARN << "Could not open cached computation. "
56  "Probably this is the first training iteration.";
57  }
58  }
59  log_priors_.ApplyLog();
60 }
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
NnetOptimizeOptions optimize_config
Definition: nnet-training.h:48
Nnet * Copy() const
Definition: nnet-nnet.h:246
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
Definition: nnet-utils.cc:269
#define KALDI_WARN
Definition: kaldi-error.h:150
void ReadCache(std::istream &is, bool binary)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_LOG
Definition: kaldi-error.h:153

◆ ~NnetDiscriminativeTrainer()

Member Function Documentation

◆ PrintTotalStats()

bool PrintTotalStats ( ) const

Definition at line 191 of file nnet-discriminative-training.cc.

References DiscriminativeOptions::criterion, NnetDiscriminativeOptions::discriminative_config, NnetDiscriminativeTrainer::objf_info_, NnetDiscriminativeTrainer::opts_, and DiscriminativeObjectiveFunctionInfo::PrintTotalStats().

Referenced by main().

191  {
192  unordered_map<std::string, DiscriminativeObjectiveFunctionInfo,
193  StringHasher>::const_iterator
194  iter = objf_info_.begin(),
195  end = objf_info_.end();
196  bool ans = false;
197  for (; iter != end; ++iter) {
198  const std::string &name = iter->first;
199  const DiscriminativeObjectiveFunctionInfo &info = iter->second;
200  bool ret = info.PrintTotalStats(name, opts_.discriminative_config.criterion);
201  ans = ans || ret;
202  }
203 
204  return ans;
205 }
discriminative::DiscriminativeOptions discriminative_config
unordered_map< std::string, DiscriminativeObjectiveFunctionInfo, StringHasher > objf_info_

◆ ProcessOutputs()

void ProcessOutputs ( const NnetDiscriminativeExample &  eg,
NnetComputer *  computer 
)
private

Definition at line 109 of file nnet-discriminative-training.cc.

References NnetComputer::AcceptInput(), NnetDiscriminativeOptions::apply_deriv_weights, kaldi::discriminative::ComputeDiscriminativeObjfAndDeriv(), DiscriminativeOptions::criterion, NnetDiscriminativeSupervision::deriv_weights, NnetDiscriminativeOptions::discriminative_config, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, NnetDiscriminativeTrainer::log_priors_, CuMatrixBase< Real >::MulRowsVec(), NnetDiscriminativeSupervision::name, NnetDiscriminativeTrainer::nnet_, NnetDiscriminativeOptions::nnet_config, NnetDiscriminativeTrainer::num_minibatches_processed_, NnetDiscriminativeTrainer::objf_info_, NnetDiscriminativeTrainer::opts_, NnetDiscriminativeExample::outputs, NnetTrainerOptions::print_interval, CuMatrix< Real >::Resize(), CuMatrixBase< Real >::Scale(), NnetDiscriminativeSupervision::supervision, NnetDiscriminativeTrainer::tmodel_, DiscriminativeObjectiveInfo::tot_objf, DiscriminativeObjectiveInfo::tot_t_weighted, kaldi::TraceMatMat(), and DiscriminativeOptions::xent_regularize.

Referenced by NnetDiscriminativeTrainer::Train().

110  {
111  // normally the eg will have just one output named 'output', but
112  // we don't assume this.
113  std::vector<NnetDiscriminativeSupervision>::const_iterator iter = eg.outputs.begin(),
114  end = eg.outputs.end();
115  for (; iter != end; ++iter) {
116  const NnetDiscriminativeSupervision &sup = *iter;
117  int32 node_index = nnet_->GetNodeIndex(sup.name);
118  if (node_index < 0 ||
119  !nnet_->IsOutputNode(node_index))
120  KALDI_ERR << "Network has no output named " << sup.name;
121 
122  const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
123 
124  CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
125  nnet_output.NumCols(),
126  kUndefined);
127 
128  bool use_xent = (opts_.discriminative_config.xent_regularize != 0.0);
129  std::string xent_name = sup.name + "-xent"; // typically "output-xent".
130  CuMatrix<BaseFloat> xent_deriv;
131  if (use_xent)
132  xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
133  kUndefined);
134 
135  discriminative::DiscriminativeObjectiveInfo stats(opts_.discriminative_config);
136 
137  if (objf_info_.count(sup.name) == 0) {
138  objf_info_[sup.name].stats.Configure(opts_.discriminative_config);
139  objf_info_[sup.name].stats.Reset();
140  }
141 
144  sup.supervision, nnet_output,
145  &stats,
146  &nnet_output_deriv,
147  (use_xent ? &xent_deriv : NULL));
148 
149  if (use_xent) {
150  // this block computes the cross-entropy objective.
151  const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(xent_name);
152  // at this point, xent_deriv is posteriors derived from the numerator
153  // computation. note, xent_objf has a factor of '.supervision.weight'
154  BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
155  if (xent_objf != xent_objf) {
156  BaseFloat default_objf = -10;
157  xent_objf = default_objf;
158  }
159 
160  discriminative::DiscriminativeObjectiveInfo xent_stats;
161  xent_stats.tot_t_weighted = stats.tot_t_weighted;
162  xent_stats.tot_objf = xent_objf;
163 
164  objf_info_[xent_name].UpdateStats(xent_name, "xent",
166  num_minibatches_processed_, xent_stats);
167  }
168 
169  if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
170  CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
171  nnet_output_deriv.MulRowsVec(cu_deriv_weights);
172  if (use_xent)
173  xent_deriv.MulRowsVec(cu_deriv_weights);
174  }
175 
176  computer->AcceptInput(sup.name, &nnet_output_deriv);
177 
178  objf_info_[sup.name].UpdateStats(sup.name, opts_.discriminative_config.criterion,
181  stats);
182 
183  if (use_xent) {
184  xent_deriv.Scale(opts_.discriminative_config.xent_regularize);
185  computer->AcceptInput(xent_name, &xent_deriv);
186  }
187  }
188 }
kaldi::int32 int32
float BaseFloat
Definition: kaldi-types.h:29
discriminative::DiscriminativeOptions discriminative_config
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
void ComputeDiscriminativeObjfAndDeriv(const DiscriminativeOptions &opts, const TransitionModel &tmodel, const CuVectorBase< BaseFloat > &log_priors, const DiscriminativeSupervision &supervision, const CuMatrixBase< BaseFloat > &nnet_output, DiscriminativeObjectiveInfo *stats, CuMatrixBase< BaseFloat > *nnet_output_deriv, CuMatrixBase< BaseFloat > *xent_output_deriv)
This function does forward-backward on the numerator and denominator lattices and computes derivatives ...
#define KALDI_ERR
Definition: kaldi-error.h:147
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
unordered_map< std::string, DiscriminativeObjectiveFunctionInfo, StringHasher > objf_info_
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466

◆ Train()

void Train ( const NnetDiscriminativeExample &  eg)

Definition at line 63 of file nnet-discriminative-training.cc.

References NnetComputer::AcceptInputs(), kaldi::nnet3::AddNnet(), CachingOptimizingCompiler::Compile(), NnetDiscriminativeTrainer::compiler_, NnetTrainerOptions::compute_config, NnetDiscriminativeTrainer::delta_nnet_, NnetDiscriminativeOptions::discriminative_config, kaldi::nnet3::DotProduct(), kaldi::nnet3::GetDiscriminativeComputationRequest(), NnetDiscriminativeExample::inputs, KALDI_LOG, KALDI_WARN, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetDiscriminativeTrainer::nnet_, NnetDiscriminativeOptions::nnet_config, NnetDiscriminativeTrainer::opts_, NnetDiscriminativeTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleNnet(), NnetTrainerOptions::store_component_stats, and DiscriminativeOptions::xent_regularize.

Referenced by main().

63  {
64  bool need_model_derivative = true;
65  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
66  bool use_xent_regularization = (opts_.discriminative_config.xent_regularize != 0.0);
67  ComputationRequest request;
68  GetDiscriminativeComputationRequest(*nnet_, eg, need_model_derivative,
69  nnet_config.store_component_stats,
70  use_xent_regularization,
71  need_model_derivative,
72  &request);
73  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
74 
75  NnetComputer computer(nnet_config.compute_config, *computation,
76  *nnet_,
77  (delta_nnet_ == NULL ? nnet_ : delta_nnet_));
78  // give the inputs to the computer object.
79  computer.AcceptInputs(*nnet_, eg.inputs);
80  computer.Run();
81 
82  this->ProcessOutputs(eg, &computer);
83  computer.Run();
84 
85  if (delta_nnet_ != NULL) {
86  BaseFloat scale = (1.0 - nnet_config.momentum);
87  if (nnet_config.max_param_change != 0.0) {
88  BaseFloat param_delta =
89  std::sqrt(DotProduct(*delta_nnet_, *delta_nnet_)) * scale;
90  if (param_delta > nnet_config.max_param_change) {
91  if (param_delta - param_delta != 0.0) {
92  KALDI_WARN << "Infinite parameter change, will not apply.";
93  ScaleNnet(0.0, delta_nnet_);
94  } else {
95  scale *= nnet_config.max_param_change / param_delta;
96  KALDI_LOG << "Parameter change too big: " << param_delta << " > "
97  << "--max-param-change=" << nnet_config.max_param_change
98  << ", scaling by "
99  << nnet_config.max_param_change / param_delta;
100  }
101  }
102  }
103  AddNnet(*delta_nnet_, scale, nnet_);
104  ScaleNnet(nnet_config.momentum, delta_nnet_);
105  }
106 }
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void ProcessOutputs(const NnetDiscriminativeExample &eg, NnetComputer *computer)
float BaseFloat
Definition: kaldi-types.h:29
discriminative::DiscriminativeOptions discriminative_config
#define KALDI_WARN
Definition: kaldi-error.h:150
BaseFloat DotProduct(const Nnet &nnet1, const Nnet &nnet2)
Returns dot product between two networks of the same structure (calls the DotProduct functions of the...
Definition: nnet-utils.cc:250
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
void GetDiscriminativeComputationRequest(const Nnet &nnet, const NnetDiscriminativeExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetDiscriminativeExample and produces a ComputationRequest.
#define KALDI_LOG
Definition: kaldi-error.h:153
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest)
Does *dest += alpha * src (affects nnet parameters and stored stats).
Definition: nnet-utils.cc:349

Member Data Documentation

◆ compiler_

◆ delta_nnet_

◆ log_priors_

◆ nnet_

◆ num_minibatches_processed_

int32 num_minibatches_processed_
private

◆ objf_info_

◆ opts_

◆ tmodel_

const TransitionModel& tmodel_
private

The documentation for this class was generated from the following files: