NnetTrainer Class Reference

This class is for single-threaded training of neural nets using standard objective functions such as cross-entropy (implemented with logsoftmax nonlinearity and a linear objective function) and quadratic loss. More...

#include <nnet-training.h>

Collaboration diagram for NnetTrainer:

Public Member Functions

 NnetTrainer (const NnetTrainerOptions &config, Nnet *nnet)
 
void Train (const NnetExample &eg)
 
bool PrintTotalStats () const
 
 ~NnetTrainer ()
 

Private Member Functions

void TrainInternal (const NnetExample &eg, const NnetComputation &computation)
 
void TrainInternalBackstitch (const NnetExample &eg, const NnetComputation &computation, bool is_backstitch_step1)
 
void ProcessOutputs (bool is_backstitch_step2, const NnetExample &eg, NnetComputer *computer)
 

Private Attributes

const NnetTrainerOptions config_
 
Nnet * nnet_
 
Nnet * delta_nnet_
 
CachingOptimizingCompiler compiler_
 
int32 num_minibatches_processed_
 
MaxChangeStats max_change_stats_
 
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_
 
int32 srand_seed_
 

Detailed Description

This class is for single-threaded training of neural nets using standard objective functions such as cross-entropy (implemented with logsoftmax nonlinearity and a linear objective function) and quadratic loss.

Something that we should do in the future is to make it possible to have two different threads, one for the compilation, and one for the computation. This would only improve efficiency in the cases where the structure of the input example was different each time, which isn't what we expect to see in speech-recognition training. (If the structure is the same each time, the CachingOptimizingCompiler notices this and uses the computation from last time).

Definition at line 180 of file nnet-training.h.

Constructor & Destructor Documentation

◆ NnetTrainer()

NnetTrainer ( const NnetTrainerOptions &  config,
Nnet *  nnet 
)

Definition at line 27 of file nnet-training.cc.

References NnetTrainerOptions::backstitch_training_interval, NnetTrainer::compiler_, NnetTrainer::config_, Nnet::Copy(), NnetTrainer::delta_nnet_, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetTrainer::nnet_, Input::Open(), NnetTrainerOptions::read_cache, CachingOptimizingCompiler::ReadCache(), kaldi::nnet3::ScaleNnet(), Input::Stream(), NnetTrainerOptions::zero_component_stats, and kaldi::nnet3::ZeroComponentStats().

NnetTrainer::NnetTrainer(const NnetTrainerOptions &config,
                         Nnet *nnet):
    config_(config),
    nnet_(nnet),
    compiler_(*nnet, config_.optimize_config, config_.compiler_config),
    num_minibatches_processed_(0),
    max_change_stats_(*nnet),
    srand_seed_(RandInt(0, 100000)) {
  if (config.zero_component_stats)
    ZeroComponentStats(nnet);
  KALDI_ASSERT(config.momentum >= 0.0 &&
               config.max_param_change >= 0.0 &&
               config.backstitch_training_interval > 0);
  delta_nnet_ = nnet_->Copy();
  ScaleNnet(0.0, delta_nnet_);

  if (config_.read_cache != "") {
    bool binary;
    Input ki;
    if (ki.Open(config_.read_cache, &binary)) {
      compiler_.ReadCache(ki.Stream(), binary);
      KALDI_LOG << "Read computation cache from " << config_.read_cache;
    } else {
      KALDI_WARN << "Could not open cached computation. "
                 "Probably this is the first training iteration.";
    }
  }
}

◆ ~NnetTrainer()

Member Function Documentation

◆ PrintTotalStats()

bool PrintTotalStats ( ) const

Definition at line 220 of file nnet-training.cc.

References rnnlm::i, NnetTrainer::max_change_stats_, NnetTrainer::nnet_, NnetTrainer::objf_info_, MaxChangeStats::Print(), and ObjectiveFunctionInfo::PrintTotalStats().

Referenced by main().

bool NnetTrainer::PrintTotalStats() const {
  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
      iter = objf_info_.begin(),
      end = objf_info_.end();
  std::vector<std::pair<std::string, const ObjectiveFunctionInfo*> > all_pairs;
  for (; iter != end; ++iter)
    all_pairs.push_back(std::pair<std::string, const ObjectiveFunctionInfo*>(
        iter->first, &(iter->second)));
  // ensure deterministic order of these names (this will matter in situations
  // where a script greps for the objective from the log).
  std::sort(all_pairs.begin(), all_pairs.end());
  bool ans = false;
  for (size_t i = 0; i < all_pairs.size(); i++) {
    const std::string &name = all_pairs[i].first;
    const ObjectiveFunctionInfo &info = *(all_pairs[i].second);
    bool ok = info.PrintTotalStats(name);
    ans = ans || ok;
  }
  max_change_stats_.Print(*nnet_);
  return ans;
}

◆ ProcessOutputs()

void ProcessOutputs ( bool  is_backstitch_step2,
const NnetExample &  eg,
NnetComputer *  computer 
)
private

Definition at line 191 of file nnet-training.cc.

References kaldi::nnet3::ComputeObjectiveFunction(), NnetTrainer::config_, NnetIo::features, Nnet::GetNode(), Nnet::GetNodeIndex(), NnetExample::io, Nnet::IsOutputNode(), KALDI_ASSERT, NnetIo::name, NnetTrainer::nnet_, NnetTrainer::num_minibatches_processed_, NetworkNode::objective_type, NnetTrainer::objf_info_, NnetTrainerOptions::print_interval, and NetworkNode::u.

Referenced by NnetTrainer::TrainInternal(), and NnetTrainer::TrainInternalBackstitch().

void NnetTrainer::ProcessOutputs(bool is_backstitch_step2,
                                 const NnetExample &eg,
                                 NnetComputer *computer) {
  // normally the eg will have just one output named 'output', but
  // we don't assume this.
  // In backstitch training, the output-name with the "_backstitch" suffix is
  // the one computed after the first, backward step of backstitch.
  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
  std::vector<NnetIo>::const_iterator iter = eg.io.begin(),
      end = eg.io.end();
  for (; iter != end; ++iter) {
    const NnetIo &io = *iter;
    int32 node_index = nnet_->GetNodeIndex(io.name);
    KALDI_ASSERT(node_index >= 0);
    if (nnet_->IsOutputNode(node_index)) {
      ObjectiveType obj_type = nnet_->GetNode(node_index).u.objective_type;
      BaseFloat tot_weight, tot_objf;
      bool supply_deriv = true;
      ComputeObjectiveFunction(io.features, obj_type, io.name,
                               supply_deriv, computer,
                               &tot_weight, &tot_objf);
      objf_info_[io.name + suffix].UpdateStats(io.name + suffix,
                                               config_.print_interval,
                                               num_minibatches_processed_,
                                               tot_weight, tot_objf);
    }
  }
}

◆ Train()

void Train ( const NnetExample &  eg )

Definition at line 57 of file nnet-training.cc.

References NnetTrainerOptions::backstitch_training_interval, NnetTrainerOptions::backstitch_training_scale, CachingOptimizingCompiler::Compile(), NnetTrainer::compiler_, NnetTrainer::config_, kaldi::nnet3::ConsolidateMemory(), NnetTrainer::delta_nnet_, kaldi::nnet3::FreezeNaturalGradient(), kaldi::nnet3::GetComputationRequest(), KALDI_ASSERT, NnetTrainerOptions::momentum, NnetTrainer::nnet_, NnetTrainer::num_minibatches_processed_, kaldi::nnet3::ResetGenerators(), NnetTrainer::srand_seed_, NnetTrainerOptions::store_component_stats, NnetTrainer::TrainInternal(), and NnetTrainer::TrainInternalBackstitch().

Referenced by main().

void NnetTrainer::Train(const NnetExample &eg) {
  bool need_model_derivative = true;
  ComputationRequest request;
  GetComputationRequest(*nnet_, eg, need_model_derivative,
                        config_.store_component_stats,
                        &request);
  std::shared_ptr<const NnetComputation> computation =
      compiler_.Compile(request);

  if (config_.backstitch_training_scale > 0.0 &&
      num_minibatches_processed_ % config_.backstitch_training_interval ==
      srand_seed_ % config_.backstitch_training_interval) {
    // backstitch training is incompatible with momentum > 0
    KALDI_ASSERT(config_.momentum == 0.0);
    FreezeNaturalGradient(true, delta_nnet_);
    bool is_backstitch_step1 = true;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
    FreezeNaturalGradient(false, delta_nnet_);  // un-freeze natural gradient
    is_backstitch_step1 = false;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(eg, *computation, is_backstitch_step1);
  } else {  // conventional training
    TrainInternal(eg, *computation);
  }
  if (num_minibatches_processed_ == 0) {
    ConsolidateMemory(nnet_);
    ConsolidateMemory(delta_nnet_);
  }
  num_minibatches_processed_++;
}

◆ TrainInternal()

void TrainInternal ( const NnetExample &  eg,
const NnetComputation &  computation 
)
private

Definition at line 91 of file nnet-training.cc.

References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, NnetTrainer::config_, kaldi::nnet3::ConstrainOrthonormal(), NnetTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetExample::io, NnetTrainerOptions::l2_regularize_factor, NnetTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetTrainer::nnet_, NnetTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().

Referenced by NnetTrainer::Train().

void NnetTrainer::TrainInternal(const NnetExample &eg,
                                const NnetComputation &computation) {
  // note: because we give the 1st arg (nnet_) as a pointer to the
  // constructor of 'computer', it will use that copy of the nnet to
  // store stats.
  NnetComputer computer(config_.compute_config, computation,
                        nnet_, delta_nnet_);
  // give the inputs to the computer object.
  computer.AcceptInputs(*nnet_, eg.io);
  computer.Run();

  this->ProcessOutputs(false, eg, &computer);
  computer.Run();

  // If relevant, add in the part of the gradient that comes from L2
  // regularization.
  ApplyL2Regularization(*nnet_,
                        GetNumNvalues(eg.io, false) *
                        config_.l2_regularize_factor,
                        delta_nnet_);

  // Update the parameters of nnet
  bool success = UpdateNnetWithMaxChange(
      *delta_nnet_, config_.max_param_change,
      1.0, 1.0 - config_.momentum, nnet_, &max_change_stats_);

  // Scale down the batchnorm stats (keeps them fresh... this affects what
  // happens when we use the model with batchnorm test-mode set).
  ScaleBatchnormStats(config_.batchnorm_stats_scale, nnet_);

  // The following will only do something if we have a LinearComponent
  // or AffineComponent with orthonormal-constraint set to a nonzero value.
  ConstrainOrthonormal(nnet_);

  // Scale delta_nnet
  if (success)
    ScaleNnet(config_.momentum, delta_nnet_);
  else
    ScaleNnet(0.0, delta_nnet_);
}

◆ TrainInternalBackstitch()

void TrainInternalBackstitch ( const NnetExample &  eg,
const NnetComputation &  computation,
bool  is_backstitch_step1 
)
private

Definition at line 131 of file nnet-training.cc.

References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::backstitch_training_scale, NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, NnetTrainer::config_, kaldi::nnet3::ConstrainOrthonormal(), NnetTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetExample::io, NnetTrainerOptions::l2_regularize_factor, NnetTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetTrainer::nnet_, NnetTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().

Referenced by NnetTrainer::Train().

void NnetTrainer::TrainInternalBackstitch(const NnetExample &eg,
                                          const NnetComputation &computation,
                                          bool is_backstitch_step1) {
  // note: because we give the 1st arg (nnet_) as a pointer to the
  // constructor of 'computer', it will use that copy of the nnet to
  // store stats.
  NnetComputer computer(config_.compute_config, computation,
                        nnet_, delta_nnet_);
  // give the inputs to the computer object.
  computer.AcceptInputs(*nnet_, eg.io);
  computer.Run();

  bool is_backstitch_step2 = !is_backstitch_step1;
  this->ProcessOutputs(is_backstitch_step2, eg, &computer);
  computer.Run();

  BaseFloat max_change_scale, scale_adding;
  if (is_backstitch_step1) {
    // max-change is scaled by backstitch_training_scale;
    // delta_nnet is scaled by -backstitch_training_scale when added to nnet;
    max_change_scale = config_.backstitch_training_scale;
    scale_adding = -config_.backstitch_training_scale;
  } else {
    // max-change is scaled by 1 + backstitch_training_scale;
    // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet;
    max_change_scale = 1.0 + config_.backstitch_training_scale;
    scale_adding = 1.0 + config_.backstitch_training_scale;
    // If relevant, add in the part of the gradient that comes from L2
    // regularization. It may not be optimally efficient to do it on both
    // passes of the backstitch, like we do here, but it probably minimizes
    // any harmful interactions with the max-change.
    ApplyL2Regularization(*nnet_,
                          1.0 / scale_adding * GetNumNvalues(eg.io, false) *
                          config_.l2_regularize_factor, delta_nnet_);
  }

  // Updates the parameters of nnet
  UpdateNnetWithMaxChange(*delta_nnet_, config_.max_param_change,
                          max_change_scale, scale_adding, nnet_,
                          &max_change_stats_);

  if (is_backstitch_step1) {
    // The following will only do something if we have a LinearComponent or
    // AffineComponent with orthonormal-constraint set to a nonzero value. We
    // choose to do this only on the 1st backstitch step, for efficiency.
    ConstrainOrthonormal(nnet_);
  }

  if (!is_backstitch_step1) {
    // Scale down the batchnorm stats (keeps them fresh... this affects what
    // happens when we use the model with batchnorm test-mode set). Do this
    // after backstitch step 2 so that the stats are scaled down before we start
    // the next minibatch.
    ScaleBatchnormStats(config_.batchnorm_stats_scale, nnet_);
  }

  ScaleNnet(0.0, delta_nnet_);
}

Member Data Documentation

◆ compiler_

◆ config_

◆ delta_nnet_

◆ max_change_stats_

◆ nnet_

◆ num_minibatches_processed_

int32 num_minibatches_processed_
private

Definition at line 217 of file nnet-training.h.

Referenced by NnetTrainer::ProcessOutputs(), and NnetTrainer::Train().

◆ objf_info_

unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_
private

Definition at line 222 of file nnet-training.h.

Referenced by NnetTrainer::PrintTotalStats(), and NnetTrainer::ProcessOutputs().

◆ srand_seed_

int32 srand_seed_
private

Definition at line 227 of file nnet-training.h.

Referenced by NnetTrainer::Train().


The documentation for this class was generated from the following files:

nnet-training.h
nnet-training.cc