nnet-parallel-component.h
Go to the documentation of this file.
1 // nnet/nnet-parallel-component.h
2 
3 // Copyright 2014 Brno University of Technology (Author: Karel Vesely)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 
21 #ifndef KALDI_NNET_NNET_PARALLEL_COMPONENT_H_
22 #define KALDI_NNET_NNET_PARALLEL_COMPONENT_H_
23 
24 #include <string>
25 #include <vector>
26 #include <sstream>
27 
28 #include "nnet/nnet-component.h"
29 #include "nnet/nnet-utils.h"
30 #include "cudamatrix/cu-math.h"
31 
32 
33 namespace kaldi {
34 namespace nnet1 {
35 
37  public:
38  ParallelComponent(int32 dim_in, int32 dim_out):
39  MultistreamComponent(dim_in, dim_out)
40  { }
41 
43  { }
44 
45  Component* Copy() const { return new ParallelComponent(*this); }
47 
48  const Nnet& GetNestedNnet(int32 id) const { return nnet_.at(id); }
49  Nnet& GetNestedNnet(int32 id) { return nnet_.at(id); }
50 
51  void InitData(std::istream &is) {
52  // define options
53  std::vector<std::string> nested_nnet_proto;
54  std::vector<std::string> nested_nnet_filename;
55  // parse config
56  std::string token;
57  while (is >> std::ws, !is.eof()) {
58  ReadToken(is, false, &token);
59  if (token == "<NestedNnet>" || token == "<NestedNnetFilename>") {
60  while (is >> std::ws, !is.eof()) {
61  std::string file_or_end;
62  ReadToken(is, false, &file_or_end);
63  if (file_or_end == "</NestedNnet>" ||
64  file_or_end == "</NestedNnetFilename>") break;
65  nested_nnet_filename.push_back(file_or_end);
66  }
67  } else if (token == "<NestedNnetProto>") {
68  while (is >> std::ws, !is.eof()) {
69  std::string file_or_end;
70  ReadToken(is, false, &file_or_end);
71  if (file_or_end == "</NestedNnetProto>") break;
72  nested_nnet_proto.push_back(file_or_end);
73  }
74  } else { KALDI_ERR << "Unknown token " << token << ", typo in config?"
75  << " (NestedNnet|NestedNnetFilename|NestedNnetProto)";
76  }
77  }
78  // Initialize,
79  // First, read nnets from files,
80  if (nested_nnet_filename.size() > 0) {
81  for (int32 i = 0; i < nested_nnet_filename.size(); i++) {
82  Nnet nnet;
83  nnet.Read(nested_nnet_filename[i]);
84  nnet_.push_back(nnet);
85  KALDI_LOG << "Loaded nested <Nnet> from file : "
86  << nested_nnet_filename[i];
87  }
88  }
89  // Second, initialize nnets from prototypes,
90  if (nested_nnet_proto.size() > 0) {
91  for (int32 i = 0; i < nested_nnet_proto.size(); i++) {
92  Nnet nnet;
93  nnet.Init(nested_nnet_proto[i]);
94  nnet_.push_back(nnet);
95  KALDI_LOG << "Initialized nested <Nnet> from prototype : "
96  << nested_nnet_proto[i];
97  }
98  }
99  // Check dim-sum of nested nnets,
100  int32 nnet_input_sum = 0, nnet_output_sum = 0;
101  for (int32 i = 0; i < nnet_.size(); i++) {
102  nnet_input_sum += nnet_[i].InputDim();
103  nnet_output_sum += nnet_[i].OutputDim();
104  }
105  KALDI_ASSERT(InputDim() == nnet_input_sum);
106  KALDI_ASSERT(OutputDim() == nnet_output_sum);
107  }
108 
109  void ReadData(std::istream &is, bool binary) {
110  // read
111  ExpectToken(is, binary, "<NestedNnetCount>");
112  int32 nnet_count;
113  ReadBasicType(is, binary, &nnet_count);
114  for (int32 i = 0; i < nnet_count; i++) {
115  ExpectToken(is, binary, "<NestedNnet>");
116  int32 dummy;
117  ReadBasicType(is, binary, &dummy);
118  Nnet nnet;
119  nnet.Read(is, binary);
120  nnet_.push_back(nnet);
121  }
122  ExpectToken(is, binary, "</ParallelComponent>");
123 
124  // check dim-sum of nested nnets
125  int32 nnet_input_sum = 0, nnet_output_sum = 0;
126  for (int32 i = 0; i < nnet_.size(); i++) {
127  nnet_input_sum += nnet_[i].InputDim();
128  nnet_output_sum += nnet_[i].OutputDim();
129  }
130  KALDI_ASSERT(InputDim() == nnet_input_sum);
131  KALDI_ASSERT(OutputDim() == nnet_output_sum);
132  }
133 
134  void WriteData(std::ostream &os, bool binary) const {
135  // useful dims
136  int32 nnet_count = nnet_.size();
137  //
138  WriteToken(os, binary, "<NestedNnetCount>");
139  WriteBasicType(os, binary, nnet_count);
140  if (!binary) os << "\n";
141  for (int32 i = 0; i < nnet_count; i++) {
142  WriteToken(os, binary, "<NestedNnet>");
143  WriteBasicType(os, binary, i+1);
144  if (!binary) os << "\n";
145  nnet_[i].Write(os, binary);
146  }
147  WriteToken(os, binary, "</ParallelComponent>");
148  }
149 
150  int32 NumParams() const {
151  int32 ans = 0;
152  for (int32 i = 0; i < nnet_.size(); i++) {
153  ans += nnet_[i].NumParams();
154  }
155  return ans;
156  }
157 
158  void GetGradient(VectorBase<BaseFloat>* gradient) const {
159  KALDI_ASSERT(gradient->Dim() == NumParams());
160  int32 offset = 0;
161  for (int32 i = 0; i < nnet_.size(); i++) {
162  int32 n_params = nnet_[i].NumParams();
163  Vector<BaseFloat> gradient_aux; // we need 'Vector<>',
164  nnet_[i].GetGradient(&gradient_aux); // copy gradient from Nnet,
165  gradient->Range(offset, n_params).CopyFromVec(gradient_aux);
166  offset += n_params;
167  }
168  KALDI_ASSERT(offset == NumParams());
169  }
170 
171  void GetParams(VectorBase<BaseFloat>* params) const {
172  KALDI_ASSERT(params->Dim() == NumParams());
173  int32 offset = 0;
174  for (int32 i = 0; i < nnet_.size(); i++) {
175  int32 n_params = nnet_[i].NumParams();
176  Vector<BaseFloat> params_aux; // we need 'Vector<>',
177  nnet_[i].GetParams(&params_aux); // copy params from Nnet,
178  params->Range(offset, n_params).CopyFromVec(params_aux);
179  offset += n_params;
180  }
181  KALDI_ASSERT(offset == NumParams());
182  }
183 
184  void SetParams(const VectorBase<BaseFloat>& params) {
185  KALDI_ASSERT(params.Dim() == NumParams());
186  int32 offset = 0;
187  for (int32 i = 0; i < nnet_.size(); i++) {
188  int32 n_params = nnet_[i].NumParams();
189  nnet_[i].SetParams(params.Range(offset, n_params));
190  offset += n_params;
191  }
192  KALDI_ASSERT(offset == NumParams());
193  }
194 
195  std::string Info() const {
196  std::ostringstream os;
197  os << "\n";
198  for (int32 i = 0; i < nnet_.size(); i++) {
199  os << "nested_network #" << i+1 << " {\n"
200  << nnet_[i].Info()
201  << "}\n";
202  }
203  std::string s(os.str());
204  s.erase(s.end() -1); // removing last '\n'
205  return s;
206  }
207 
208  std::string InfoGradient() const {
209  std::ostringstream os;
210  os << "\n";
211  for (int32 i = 0; i < nnet_.size(); i++) {
212  os << "nested_gradient #" << i+1 << " {\n"
213  << nnet_[i].InfoGradient(false)
214  << "}\n";
215  }
216  std::string s(os.str());
217  s.erase(s.end() -1); // removing last '\n'
218  return s;
219  }
220 
221  std::string InfoPropagate() const {
222  std::ostringstream os;
223  for (int32 i = 0; i < nnet_.size(); i++) {
224  os << "nested_propagate #" << i+1 << " {\n"
225  << nnet_[i].InfoPropagate(false)
226  << "}\n";
227  }
228  return os.str();
229  }
230 
231  std::string InfoBackPropagate() const {
232  std::ostringstream os;
233  for (int32 i = 0; i < nnet_.size(); i++) {
234  os << "nested_backpropagate #" << i+1 << " {\n"
235  << nnet_[i].InfoBackPropagate(false)
236  << "}\n";
237  }
238  return os.str();
239  }
240 
243  // column-offsets for data buffers 'in,out',
244  int32 input_offset = 0, output_offset = 0;
245  // loop over nnets,
246  for (int32 i = 0; i < nnet_.size(); i++) {
247  // get the data 'windows',
249  in.ColRange(input_offset, nnet_[i].InputDim())
250  );
252  out->ColRange(output_offset, nnet_[i].OutputDim())
253  );
254  // forward through auxiliary matrix, as 'Propagate' requires 'CuMatrix',
255  CuMatrix<BaseFloat> tgt_aux;
256  nnet_[i].Propagate(src, &tgt_aux);
257  tgt.CopyFromMat(tgt_aux);
258  // advance the offsets,
259  input_offset += nnet_[i].InputDim();
260  output_offset += nnet_[i].OutputDim();
261  }
262  }
263 
265  const CuMatrixBase<BaseFloat> &out,
266  const CuMatrixBase<BaseFloat> &out_diff,
267  CuMatrixBase<BaseFloat> *in_diff) {
268  // column-offsets for data buffers 'in,out',
269  int32 input_offset = 0, output_offset = 0;
270  // loop over nnets,
271  for (int32 i = 0; i < nnet_.size(); i++) {
272  // get the data 'windows',
274  out_diff.ColRange(output_offset, nnet_[i].OutputDim())
275  );
277  in_diff->ColRange(input_offset, nnet_[i].InputDim())
278  );
279  // ::Backpropagate through auxiliary matrix (CuMatrix in the interface),
280  CuMatrix<BaseFloat> tgt_aux;
281  nnet_[i].Backpropagate(src, &tgt_aux);
282  tgt.CopyFromMat(tgt_aux);
283  // advance the offsets,
284  input_offset += nnet_[i].InputDim();
285  output_offset += nnet_[i].OutputDim();
286  }
287  }
288 
289  void Update(const CuMatrixBase<BaseFloat> &input,
290  const CuMatrixBase<BaseFloat> &diff) {
291  { } // do nothing
292  }
293 
298  void SetTrainOptions(const NnetTrainOptions &opts) {
299  for (int32 i = 0; i < nnet_.size(); i++) {
300  nnet_[i].SetTrainOptions(opts);
301  }
302  }
303 
309  // loop over nnets,
310  for (int32 i = 0; i < nnet_.size(); i++) {
311  // loop over components,
312  for (int32 j = 0; j < nnet_[i].NumComponents(); j++) {
313  if (nnet_[i].GetComponent(j).IsUpdatable()) {
314  UpdatableComponent& comp =
315  dynamic_cast<UpdatableComponent&>(nnet_[i].GetComponent(j));
316  // set the value,
317  comp.SetLearnRateCoef(val);
318  }
319  }
320  }
321  }
322 
328  // loop over nnets,
329  for (int32 i = 0; i < nnet_.size(); i++) {
330  // loop over components,
331  for (int32 j = 0; j < nnet_[i].NumComponents(); j++) {
332  if (nnet_[i].GetComponent(j).IsUpdatable()) {
333  UpdatableComponent& comp =
334  dynamic_cast<UpdatableComponent&>(nnet_[i].GetComponent(j));
335  // set the value,
336  comp.SetBiasLearnRateCoef(val);
337  }
338  }
339  }
340  }
341 
346  void SetSeqLengths(const std::vector<int32> &sequence_lengths) {
347  sequence_lengths_ = sequence_lengths;
348  // loop over nnets,
349  for (int32 i = 0; i < nnet_.size(); i++) {
350  nnet_[i].SetSeqLengths(sequence_lengths);
351  }
352  }
353 
354  private:
355  std::vector<Nnet> nnet_;
356 };
357 
358 } // namespace nnet1
359 } // namespace kaldi
360 
361 #endif // KALDI_NNET_NNET_PARALLEL_COMPONENT_H_
void WriteData(std::ostream &os, bool binary) const
Writes the component content.
void CopyFromMat(const MatrixBase< OtherReal > &src, MatrixTransposeType trans=kNoTrans)
Definition: cu-matrix.cc:344
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PropagateFnc(const CuMatrixBase< BaseFloat > &in, CuMatrixBase< BaseFloat > *out)
Abstract interface for propagation/backpropagation.
std::string Info() const
Print some additional info (after <ComponentName> and the dims),.
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
Class UpdatableComponent is a Component which has trainable parameters, it contains SGD training hype...
int32 NumParams() const
Number of trainable parameters,.
kaldi::int32 int32
void ReadToken(std::istream &is, bool binary, std::string *str)
ReadToken gets the next token and puts it in str (exception on failure).
Definition: io-funcs.cc:154
void SetTrainOptions(const NnetTrainOptions &opts)
Overriding the default, which was UpdatableComponent::SetTrainOptions(...)
This class represents a matrix that's stored on the GPU if we have one, and in memory if not...
Definition: matrix-common.h:71
const Nnet & GetNestedNnet(int32 id) const
ComponentType GetType() const
Get Type Identification of the component,.
virtual void SetLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient,.
ComponentType
Component type identification mechanism,.
void SetParams(const VectorBase< BaseFloat > &params)
Set the trainable parameters from, reshaped as a vector,.
void SetSeqLengths(const std::vector< int32 > &sequence_lengths)
Overriding the default, which was MultistreamComponent::SetSeqLengths(...)
void SetBiasLearnRateCoef(BaseFloat val)
Overriding the default, which was UpdatableComponent::SetBiasLearnRateCoef(...)
void ReadData(std::istream &is, bool binary)
Reads the component content.
virtual void SetBiasLearnRateCoef(BaseFloat val)
Set the learn-rate coefficient for bias,.
void ExpectToken(std::istream &is, bool binary, const char *token)
ExpectToken tries to read in the given token, and throws an exception on failure. ...
Definition: io-funcs.cc:191
int32 InputDim() const
Get the dimension of the input,.
#define KALDI_ERR
Definition: kaldi-error.h:147
void Read(const std::string &rxfilename)
Read Nnet from 'rxfilename'.
Definition: nnet-nnet.cc:333
This class is used for a piece of a CuMatrix.
Definition: matrix-common.h:70
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
Class MultistreamComponent is an extension of UpdatableComponent for recurrent networks, which are trained with parallel sequences.
std::string InfoGradient() const
Print some additional info about gradient (after <...> and dims),.
ParallelComponent(int32 dim_in, int32 dim_out)
CuSubMatrix< Real > ColRange(const MatrixIndexT col_offset, const MatrixIndexT num_cols) const
Definition: cu-matrix.h:665
void InitData(std::istream &is)
Initialize the content of the component from the 'line' of the prototype.
Matrix for CUDA computing.
Definition: matrix-common.h:69
void Init(const std::string &proto_file)
Initialize the Nnet from the prototype,.
Definition: nnet-nnet.cc:301
Component * Copy() const
Copy component (deep copy),.
void GetParams(VectorBase< BaseFloat > *params) const
Get the trainable parameters reshaped as a vector,.
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void BackpropagateFnc(const CuMatrixBase< BaseFloat > &in, const CuMatrixBase< BaseFloat > &out, const CuMatrixBase< BaseFloat > &out_diff, CuMatrixBase< BaseFloat > *in_diff)
Backward pass transformation (to be implemented by descending class...)
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Abstract class, building block of the network.
std::vector< int32 > sequence_lengths_
int32 OutputDim() const
Get the dimension of the output,.
void Update(const CuMatrixBase< BaseFloat > &input, const CuMatrixBase< BaseFloat > &diff)
Compute gradient and update parameters,.
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
#define KALDI_LOG
Definition: kaldi-error.h:153
void GetGradient(VectorBase< BaseFloat > *gradient) const
Get gradient reshaped as a vector,.
void SetLearnRateCoef(BaseFloat val)
Overriding the default, which was UpdatableComponent::SetLearnRateCoef(...)
SubVector< Real > Range(const MatrixIndexT o, const MatrixIndexT l)
Returns a sub-vector of a vector (a range of elements).
Definition: kaldi-vector.h:94