NnetChainTrainer Class Reference

This class is for single-threaded training of neural nets using the 'chain' model. More...

#include <nnet-chain-training.h>

Collaboration diagram for NnetChainTrainer:

Public Member Functions

 NnetChainTrainer (const NnetChainTrainingOptions &config, const fst::StdVectorFst &den_fst, Nnet *nnet)
 
void Train (const NnetChainExample &eg)
 
bool PrintTotalStats () const
 
 ~NnetChainTrainer ()
 

Private Member Functions

void TrainInternal (const NnetChainExample &eg, const NnetComputation &computation)
 
void TrainInternalBackstitch (const NnetChainExample &eg, const NnetComputation &computation, bool is_backstitch_step1)
 
void ProcessOutputs (bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer)
 

Private Attributes

const NnetChainTrainingOptions opts_
 
chain::DenominatorGraph den_graph_
 
Nnet * nnet_
 
Nnet * delta_nnet_
 
CachingOptimizingCompiler compiler_
 
int32 num_minibatches_processed_
 
MaxChangeStats max_change_stats_
 
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_
 
int32 srand_seed_
 

Detailed Description

This class is for single-threaded training of neural nets using the 'chain' model.

Definition at line 55 of file nnet-chain-training.h.

Constructor & Destructor Documentation

◆ NnetChainTrainer()

NnetChainTrainer ( const NnetChainTrainingOptions & config,
const fst::StdVectorFst & den_fst,
Nnet * nnet 
)

Definition at line 27 of file nnet-chain-training.cc.

References NnetTrainerOptions::backstitch_training_interval, NnetChainTrainer::compiler_, Nnet::Copy(), NnetChainTrainer::delta_nnet_, KALDI_ASSERT, KALDI_LOG, KALDI_WARN, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetTrainerOptions::read_cache, CachingOptimizingCompiler::ReadCache(), kaldi::nnet3::ScaleNnet(), Input::Stream(), NnetTrainerOptions::zero_component_stats, and kaldi::nnet3::ZeroComponentStats().

29  :
30  opts_(opts),
31  den_graph_(den_fst, nnet->OutputDim("output")),
32  nnet_(nnet),
36  max_change_stats_(*nnet),
37  srand_seed_(RandInt(0, 100000)) {
38  if (opts.nnet_config.zero_component_stats)
39  ZeroComponentStats(nnet);
40  KALDI_ASSERT(opts.nnet_config.momentum >= 0.0 &&
41  opts.nnet_config.max_param_change >= 0.0 &&
42  opts.nnet_config.backstitch_training_interval > 0);
43  delta_nnet_ = nnet_->Copy();
44  ScaleNnet(0.0, delta_nnet_);
45 
46  if (opts.nnet_config.read_cache != "") {
47  bool binary;
48  try {
49  Input ki(opts.nnet_config.read_cache, &binary);
50  compiler_.ReadCache(ki.Stream(), binary);
51  KALDI_LOG << "Read computation cache from " << opts.nnet_config.read_cache;
52  } catch (...) {
53  KALDI_WARN << "Could not open cached computation. "
54  "Probably this is the first training iteration.";
55  }
56  }
57 }
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
NnetOptimizeOptions optimize_config
Definition: nnet-training.h:48
const NnetChainTrainingOptions opts_
CachingOptimizingCompilerOptions compiler_config
Definition: nnet-training.h:50
chain::DenominatorGraph den_graph_
CachingOptimizingCompiler compiler_
Nnet * Copy() const
Definition: nnet-nnet.h:246
void ZeroComponentStats(Nnet *nnet)
Zeroes the component stats in all nonlinear components in the nnet.
Definition: nnet-utils.cc:269
#define KALDI_WARN
Definition: kaldi-error.h:150
void ReadCache(std::istream &is, bool binary)
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define KALDI_LOG
Definition: kaldi-error.h:153
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state)
Definition: kaldi-math.cc:95

◆ ~NnetChainTrainer()

Definition at line 287 of file nnet-chain-training.cc.

References NnetTrainerOptions::binary_write_cache, NnetChainTrainer::compiler_, NnetChainTrainer::delta_nnet_, KALDI_LOG, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::opts_, Output::Stream(), NnetTrainerOptions::write_cache, and CachingOptimizingCompiler::WriteCache().

287  {
288  if (opts_.nnet_config.write_cache != "") {
291  KALDI_LOG << "Wrote computation cache to " << opts_.nnet_config.write_cache;
292  }
293  delete delta_nnet_;
294 }
const NnetChainTrainingOptions opts_
CachingOptimizingCompiler compiler_
void WriteCache(std::ostream &os, bool binary)
#define KALDI_LOG
Definition: kaldi-error.h:153

Member Function Documentation

◆ PrintTotalStats()

bool PrintTotalStats ( ) const

Definition at line 273 of file nnet-chain-training.cc.

References NnetChainTrainer::max_change_stats_, NnetChainTrainer::nnet_, NnetChainTrainer::objf_info_, MaxChangeStats::Print(), and ObjectiveFunctionInfo::PrintTotalStats().

273  {
274  unordered_map<std::string, ObjectiveFunctionInfo, StringHasher>::const_iterator
275  iter = objf_info_.begin(),
276  end = objf_info_.end();
277  bool ans = false;
278  for (; iter != end; ++iter) {
279  const std::string &name = iter->first;
280  const ObjectiveFunctionInfo &info = iter->second;
281  ans = info.PrintTotalStats(name) || ans;
282  }
284  return ans;
285 }
void Print(const Nnet &nnet) const
Definition: nnet-utils.cc:2284
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_

◆ ProcessOutputs()

void ProcessOutputs ( bool is_backstitch_step2,
const NnetChainExample & eg,
NnetComputer * computer 
)
private

Definition at line 204 of file nnet-chain-training.cc.

References NnetComputer::AcceptInput(), NnetChainTrainingOptions::apply_deriv_weights, NnetChainTrainingOptions::chain_config, NnetChainTrainer::den_graph_, NnetChainSupervision::deriv_weights, Nnet::GetNodeIndex(), NnetComputer::GetOutput(), Nnet::IsOutputNode(), KALDI_ERR, kaldi::kTrans, kaldi::kUndefined, CuMatrixBase< Real >::MulRowsVec(), NnetChainSupervision::name, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::num_minibatches_processed_, NVTX_RANGE, NnetChainTrainer::objf_info_, NnetChainTrainer::opts_, NnetChainExample::outputs, NnetTrainerOptions::print_interval, CuMatrixBase< Real >::Scale(), NnetChainSupervision::supervision, and kaldi::TraceMatMat().

Referenced by NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().

206  {
207  NVTX_RANGE(__func__);
208  // normally the eg will have just one output named 'output', but
209  // we don't assume this.
210  // In backstitch training, the output-name with the "_backstitch" suffix is
211  // the one computed after the first, backward step of backstitch.
212  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
213  std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
214  end = eg.outputs.end();
215  for (; iter != end; ++iter) {
216  const NnetChainSupervision &sup = *iter;
217  int32 node_index = nnet_->GetNodeIndex(sup.name);
218  if (node_index < 0 ||
219  !nnet_->IsOutputNode(node_index))
220  KALDI_ERR << "Network has no output named " << sup.name;
221 
222  const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
223  CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
224  nnet_output.NumCols(),
225  kUndefined);
226 
227  bool use_xent = (opts_.chain_config.xent_regularize != 0.0);
228  std::string xent_name = sup.name + "-xent"; // typically "output-xent".
229  CuMatrix<BaseFloat> xent_deriv;
230 
231  BaseFloat tot_objf, tot_l2_term, tot_weight;
232 
233  ComputeChainObjfAndDeriv(opts_.chain_config, den_graph_,
234  sup.supervision, nnet_output,
235  &tot_objf, &tot_l2_term, &tot_weight,
236  &nnet_output_deriv,
237  (use_xent ? &xent_deriv : NULL));
238 
239  if (use_xent) {
240  // this block computes the cross-entropy objective.
241  const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(
242  xent_name);
243  // at this point, xent_deriv is posteriors derived from the numerator
244  // computation. note, xent_objf has a factor of '.supervision.weight'
245  BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
246  objf_info_[xent_name + suffix].UpdateStats(xent_name + suffix,
249  tot_weight, xent_objf);
250  }
251 
252  if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
253  CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
254  nnet_output_deriv.MulRowsVec(cu_deriv_weights);
255  if (use_xent)
256  xent_deriv.MulRowsVec(cu_deriv_weights);
257  }
258 
259  computer->AcceptInput(sup.name, &nnet_output_deriv);
260 
261  objf_info_[sup.name + suffix].UpdateStats(sup.name + suffix,
264  tot_weight, tot_objf, tot_l2_term);
265 
266  if (use_xent) {
267  xent_deriv.Scale(opts_.chain_config.xent_regularize);
268  computer->AcceptInput(xent_name, &xent_deriv);
269  }
270  }
271 }
const NnetChainTrainingOptions opts_
kaldi::int32 int32
chain::DenominatorGraph den_graph_
float BaseFloat
Definition: kaldi-types.h:29
bool IsOutputNode(int32 node) const
Returns true if this is an output node, meaning that it is of type kDescriptor and is not directly fo...
Definition: nnet-nnet.cc:112
#define KALDI_ERR
Definition: kaldi-error.h:147
Real TraceMatMat(const MatrixBase< Real > &A, const MatrixBase< Real > &B, MatrixTransposeType trans)
We need to declare this here as it will be a friend function.
#define NVTX_RANGE(name)
Definition: cu-common.h:143
chain::ChainTrainingOptions chain_config
int32 GetNodeIndex(const std::string &node_name) const
returns index associated with this node name, or -1 if no such index.
Definition: nnet-nnet.cc:466
unordered_map< std::string, ObjectiveFunctionInfo, StringHasher > objf_info_

◆ Train()

void Train ( const NnetChainExample & eg)

Definition at line 60 of file nnet-chain-training.cc.

References NnetTrainerOptions::backstitch_training_interval, NnetTrainerOptions::backstitch_training_scale, NnetChainTrainingOptions::chain_config, CachingOptimizingCompiler::Compile(), NnetChainTrainer::compiler_, kaldi::nnet3::ConsolidateMemory(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::FreezeNaturalGradient(), kaldi::nnet3::GetChainComputationRequest(), KALDI_ASSERT, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::num_minibatches_processed_, NVTX_RANGE, NnetChainTrainer::opts_, kaldi::nnet3::ResetGenerators(), NnetChainTrainer::srand_seed_, NnetTrainerOptions::store_component_stats, NnetChainTrainer::TrainInternal(), and NnetChainTrainer::TrainInternalBackstitch().

60  {
61  NVTX_RANGE(__func__);
62  bool need_model_derivative = true;
63  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
64  bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0);
65  ComputationRequest request;
66  GetChainComputationRequest(*nnet_, chain_eg, need_model_derivative,
67  nnet_config.store_component_stats,
68  use_xent_regularization, need_model_derivative,
69  &request);
70  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
71 
72  if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_
73  % nnet_config.backstitch_training_interval ==
74  srand_seed_ % nnet_config.backstitch_training_interval) {
75  // backstitch training is incompatible with momentum > 0
76  KALDI_ASSERT(nnet_config.momentum == 0.0);
78  bool is_backstitch_step1 = true;
81  TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
82  FreezeNaturalGradient(false, delta_nnet_); // un-freeze natural gradient
83  is_backstitch_step1 = false;
86  TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
87  } else { // conventional training
88  TrainInternal(chain_eg, *computation);
89  }
90  if (num_minibatches_processed_ == 0) {
93  }
95 }
void TrainInternal(const NnetChainExample &eg, const NnetComputation &computation)
void TrainInternalBackstitch(const NnetChainExample &eg, const NnetComputation &computation, bool is_backstitch_step1)
const NnetChainTrainingOptions opts_
void FreezeNaturalGradient(bool freeze, Nnet *nnet)
Controls if natural gradient will be updated.
Definition: nnet-utils.cc:432
void ResetGenerators(Nnet *nnet)
This function calls 'ResetGenerator()' on all components in 'nnet' that inherit from class RandomComp...
Definition: nnet-utils.cc:582
CachingOptimizingCompiler compiler_
std::shared_ptr< const NnetComputation > Compile(const ComputationRequest &request)
Does the compilation and returns a const pointer to the result, which is owned by this class...
void ConsolidateMemory(Nnet *nnet)
This just calls ConsolidateMemory() on all the components of the nnet.
Definition: nnet-utils.cc:1147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
#define NVTX_RANGE(name)
Definition: cu-common.h:143
chain::ChainTrainingOptions chain_config
void GetChainComputationRequest(const Nnet &nnet, const NnetChainExample &eg, bool need_model_derivative, bool store_component_stats, bool use_xent_regularization, bool use_xent_derivative, ComputationRequest *request)
This function takes a NnetChainExample and produces a ComputationRequest.

◆ TrainInternal()

void TrainInternal ( const NnetChainExample & eg,
const NnetComputation & computation 
)
private

Definition at line 97 of file nnet-chain-training.cc.

References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, kaldi::nnet3::ConstrainOrthonormal(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetChainExample::inputs, NnetTrainerOptions::l2_regularize_factor, NnetChainTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetTrainerOptions::momentum, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NVTX_RANGE, NnetChainTrainer::opts_, NnetChainTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().

Referenced by NnetChainTrainer::Train().

98  {
99  NVTX_RANGE(__func__);
100  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
101  // note: because we give the 1st arg (nnet_) as a pointer to the
102  // constructor of 'computer', it will use that copy of the nnet to
103  // store stats.
104  NnetComputer computer(nnet_config.compute_config, computation,
105  nnet_, delta_nnet_);
106 
107  // give the inputs to the computer object.
108  computer.AcceptInputs(*nnet_, eg.inputs);
109  computer.Run();
110 
111  this->ProcessOutputs(false, eg, &computer);
112  computer.Run();
113 
114  // If relevant, add in the part of the gradient that comes from
115  // parameter-level L2 regularization.
117  GetNumNvalues(eg.inputs, false) *
118  nnet_config.l2_regularize_factor,
119  delta_nnet_);
120 
121  // Updates the parameters of nnet
122  bool success = UpdateNnetWithMaxChange(
123  *delta_nnet_,
124  nnet_config.max_param_change,
125  1.0, 1.0 - nnet_config.momentum, nnet_,
127 
128  // Scale down the batchnorm stats (keeps them fresh... this affects what
129  // happens when we use the model with batchnorm test-mode set).
130  ScaleBatchnormStats(nnet_config.batchnorm_stats_scale, nnet_);
131 
132  // The following will only do something if we have a LinearComponent
133  // or AffineComponent with orthonormal-constraint set to a nonzero value.
135 
136  // Scale delta_nnet
137  if (success)
138  ScaleNnet(nnet_config.momentum, delta_nnet_);
139  else
140  ScaleNnet(0.0, delta_nnet_);
141 }
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale, Nnet *nnet)
This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComp...
Definition: nnet-utils.cc:536
const NnetChainTrainingOptions opts_
void ConstrainOrthonormal(Nnet *nnet)
This function, to be called after processing every minibatch, is responsible for enforcing the orthog...
Definition: nnet-utils.cc:1108
int32 GetNumNvalues(const std::vector< NnetIo > &io_vec, bool exhaustive)
This utility function can be used to obtain the number of distinct 'n' values in a training example...
Definition: nnet-utils.cc:2198
void ApplyL2Regularization(const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange().
Definition: nnet-utils.cc:2244
#define NVTX_RANGE(name)
Definition: cu-common.h:143
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
This function does the operation '*nnet += scale * delta_nnet', while respecting any max-parameter-ch...
Definition: nnet-utils.cc:2106
void ProcessOutputs(bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer)

◆ TrainInternalBackstitch()

void TrainInternalBackstitch ( const NnetChainExample & eg,
const NnetComputation & computation,
bool is_backstitch_step1 
)
private

Definition at line 143 of file nnet-chain-training.cc.

References NnetComputer::AcceptInputs(), kaldi::nnet3::ApplyL2Regularization(), NnetTrainerOptions::backstitch_training_scale, NnetTrainerOptions::batchnorm_stats_scale, NnetTrainerOptions::compute_config, kaldi::nnet3::ConstrainOrthonormal(), NnetChainTrainer::delta_nnet_, kaldi::nnet3::GetNumNvalues(), NnetChainExample::inputs, NnetTrainerOptions::l2_regularize_factor, NnetChainTrainer::max_change_stats_, NnetTrainerOptions::max_param_change, NnetChainTrainer::nnet_, NnetChainTrainingOptions::nnet_config, NnetChainTrainer::opts_, NnetChainTrainer::ProcessOutputs(), NnetComputer::Run(), kaldi::nnet3::ScaleBatchnormStats(), kaldi::nnet3::ScaleNnet(), and kaldi::nnet3::UpdateNnetWithMaxChange().

Referenced by NnetChainTrainer::Train().

145  {
146  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
147  // note: because we give the 1st arg (nnet_) as a pointer to the
148  // constructor of 'computer', it will use that copy of the nnet to
149  // store stats.
150  NnetComputer computer(nnet_config.compute_config, computation,
151  nnet_, delta_nnet_);
152  // give the inputs to the computer object.
153  computer.AcceptInputs(*nnet_, eg.inputs);
154  computer.Run();
155 
156  bool is_backstitch_step2 = !is_backstitch_step1;
157  this->ProcessOutputs(is_backstitch_step2, eg, &computer);
158  computer.Run();
159 
160  BaseFloat max_change_scale, scale_adding;
161  if (is_backstitch_step1) {
162  // max-change is scaled by backstitch_training_scale;
163  // delta_nnet is scaled by -backstitch_training_scale when added to nnet;
164  max_change_scale = nnet_config.backstitch_training_scale;
165  scale_adding = -nnet_config.backstitch_training_scale;
166  } else {
167  // max-change is scaled by 1 + backstitch_training_scale;
168  // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet;
169  max_change_scale = 1.0 + nnet_config.backstitch_training_scale;
170  scale_adding = 1.0 + nnet_config.backstitch_training_scale;
171  // If relevant, add in the part of the gradient that comes from L2
172  // regularization. It may not be optimally inefficient to do it on both
173  // passes of the backstitch, like we do here, but it probably minimizes
174  // any harmful interactions with the max-change.
176  1.0 / scale_adding * GetNumNvalues(eg.inputs, false) *
177  nnet_config.l2_regularize_factor, delta_nnet_);
178  }
179 
180  // Updates the parameters of nnet
182  *delta_nnet_, nnet_config.max_param_change,
183  max_change_scale, scale_adding, nnet_,
185 
186  if (is_backstitch_step1) {
187  // The following will only do something if we have a LinearComponent or
188  // AffineComponent with orthonormal-constraint set to a nonzero value. We
189  // choose to do this only on the 1st backstitch step, for efficiency.
191  }
192 
193  if (!is_backstitch_step1) {
194  // Scale down the batchnorm stats (keeps them fresh... this affects what
195  // happens when we use the model with batchnorm test-mode set). Do this
196  // after backstitch step 2 so that the stats are scaled down before we start
197  // the next minibatch.
198  ScaleBatchnormStats(nnet_config.batchnorm_stats_scale, nnet_);
199  }
200 
201  ScaleNnet(0.0, delta_nnet_);
202 }
void ScaleNnet(BaseFloat scale, Nnet *nnet)
Scales the nnet parameters and stats by this scale.
Definition: nnet-utils.cc:312
void ScaleBatchnormStats(BaseFloat batchnorm_stats_scale, Nnet *nnet)
This function scales the batchnorm stats of any batchnorm components (components of type BatchNormComp...
Definition: nnet-utils.cc:536
const NnetChainTrainingOptions opts_
void ConstrainOrthonormal(Nnet *nnet)
This function, to be called after processing every minibatch, is responsible for enforcing the orthog...
Definition: nnet-utils.cc:1108
int32 GetNumNvalues(const std::vector< NnetIo > &io_vec, bool exhaustive)
This utility function can be used to obtain the number of distinct 'n' values in a training example...
Definition: nnet-utils.cc:2198
float BaseFloat
Definition: kaldi-types.h:29
void ApplyL2Regularization(const Nnet &nnet, BaseFloat l2_regularize_scale, Nnet *delta_nnet)
This function is used as part of the regular training workflow, prior to UpdateNnetWithMaxChange().
Definition: nnet-utils.cc:2244
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet, BaseFloat max_param_change, BaseFloat max_change_scale, BaseFloat scale, Nnet *nnet, std::vector< int32 > *num_max_change_per_component_applied, int32 *num_max_change_global_applied)
This function does the operation '*nnet += scale * delta_nnet', while respecting any max-parameter-ch...
Definition: nnet-utils.cc:2106
void ProcessOutputs(bool is_backstitch_step2, const NnetChainExample &eg, NnetComputer *computer)

Member Data Documentation

◆ compiler_

◆ delta_nnet_

◆ den_graph_

chain::DenominatorGraph den_graph_
private

Definition at line 85 of file nnet-chain-training.h.

Referenced by NnetChainTrainer::ProcessOutputs().

◆ max_change_stats_

◆ nnet_

◆ num_minibatches_processed_

int32 num_minibatches_processed_
private

◆ objf_info_

unordered_map<std::string, ObjectiveFunctionInfo, StringHasher> objf_info_
private

◆ opts_

◆ srand_seed_

int32 srand_seed_
private

Definition at line 104 of file nnet-chain-training.h.

Referenced by NnetChainTrainer::Train().


The documentation for this class was generated from the following files: