nnet-compile-looped.cc
// nnet3/nnet-compile-looped.cc

// Copyright 2016  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet-compile-looped.h"
#include "nnet3/nnet-optimize-utils.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {
namespace nnet3 {

void ModifyNnetIvectorPeriod(int32 ivector_period,
                             Nnet *nnet) {
  KALDI_ASSERT(ivector_period > 0);
  std::vector<std::string> config_lines;
  nnet->GetConfigLines(false, &config_lines);
  std::ostringstream config_to_read;
  for (size_t i = 0; i < config_lines.size(); i++) {
    std::string s = config_lines[i];
    ConfigLine config_line;
    bool b = config_line.ParseLine(config_lines[i]);
    KALDI_ASSERT(b && "Could not parse config line.");
    if (config_line.FirstToken() == "component-node") {
      // What we're trying to do here is: find a line like:
      //   component-node name=foo component=foo input=Append(bar, ReplaceIndex(ivector, t, 0))
      // and replace it with something like:
      //   component-node name=foo component=foo input=Append(bar, Round(ivector, 10))
      // (assuming ivector_period is 10).  We also want this to work if instead
      // of 'ivector' it has something like Scale(0.5, ivector).  We assume that
      // ReplaceIndex() expressions only occur in this type of context.
      std::string whole_line = config_lines[i];
      std::string to_search_for = "ReplaceIndex(";
      std::string::size_type to_search_for_size = to_search_for.size();
      std::string::size_type pos = whole_line.find(to_search_for);
      if (pos != std::string::npos) {
        std::string::size_type comma_pos = whole_line.find(", t, 0)", pos);
        if (comma_pos != std::string::npos) {
          // if the line contained ReplaceIndex(ivector, t, 0),
          // descriptor_name would now be 'ivector'.
          std::string descriptor_name =
              whole_line.substr(pos + to_search_for_size,
                                comma_pos - (pos + to_search_for_size));
          // Note: 7, below, is the length of ", t, 0)".
          std::string::size_type end_pos = comma_pos + 7;
          std::string::size_type expr_size = end_pos - pos;
          // e.g. expr_size would be strlen("ReplaceIndex(ivector, t, 0)").
          std::ostringstream to_replace_with;
          to_replace_with << "Round(" << descriptor_name << ", " << ivector_period << ")";
          whole_line.replace(pos, expr_size, to_replace_with.str());
          config_to_read << whole_line << "\n";
        } else {
          KALDI_ERR << "Could not process the ReplaceIndex expression in: "
                    << whole_line;
        }
      }
    }
  }
  if (!config_to_read.str().empty()) {
    std::istringstream is(config_to_read.str());
    nnet->ReadConfig(is);
  }
}
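
// Illustrative example of the rewrite done above (the node name 'L0' is
// hypothetical, chosen only for this sketch): if the network's config
// contained the line
//   component-node name=L0 component=L0 input=Append(input, ReplaceIndex(ivector, t, 0))
// then ModifyNnetIvectorPeriod(10, &nnet) would re-read that node as
//   component-node name=L0 component=L0 input=Append(input, Round(ivector, 10))
// so that the 'ivector' input is only requested at 't' values that are
// multiples of 10.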


int32 GetChunkSize(const Nnet &nnet,
                   int32 frame_subsampling_factor,
                   int32 advised_chunk_size) {
  int32 modulus = nnet.Modulus();
  KALDI_ASSERT(modulus > 0 && frame_subsampling_factor > 0 &&
               advised_chunk_size > 0);
  int32 chunk_size = advised_chunk_size;
  while (1) {
    if (chunk_size % modulus == 0 &&
        chunk_size % frame_subsampling_factor == 0)
      return chunk_size;
    chunk_size++;
  }
}
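
// Example (numbers chosen only for illustration): for a network with
// nnet.Modulus() == 1 and a frame-subsampling factor of 3, as is typical for
// 'chain' models, an advised chunk size of 20 is rounded up to the smallest
// value >= 20 that is divisible by both the modulus and the subsampling
// factor:
//   int32 chunk_size = GetChunkSize(nnet, 3, 20);  // returns 21.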


/**
   Mod(m, n), defined for integers m and n where n > 0, returns the modulus
   m % n, defined as the integer 0 <= i < n such that i and m are congruent
   modulo n; for instance, Mod(13, 10) = 3.
 */
template<class I> I Mod(I m, I n) {
  I ans = m % n;
  if (ans < 0) ans += n;
  return ans;
}
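
// Note, as a worked example: unlike the built-in % operator, this always
// returns a value in [0, n), even for negative m.  E.g. Mod(13, 10) == 3 and
// Mod(-1, 10) == 9, whereas in C++ (-1) % 10 == -1.  This matters in
// CreateLoopedComputationRequest() below, where input 't' values to the left
// of the first chunk are negative.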


static void CreateComputationRequestInternal(
    int32 begin_input_t, int32 end_input_t,
    int32 begin_output_t, int32 end_output_t,
    int32 num_sequences,
    int32 frame_subsampling_factor,
    const std::set<int32> &ivector_times,
    ComputationRequest *request) {
  request->inputs.reserve(2);
  request->inputs.clear();
  request->inputs.resize(1 + (ivector_times.empty() ? 0 : 1));
  request->inputs[0].name = "input";
  request->inputs[0].has_deriv = false;
  request->outputs.clear();
  request->outputs.resize(1);
  request->outputs[0].name = "output";
  request->outputs[0].has_deriv = false;
  if (!ivector_times.empty()) {
    request->inputs[1].name = "ivector";
    request->inputs[1].has_deriv = false;
  }

  // in the computation request the 'n' indexes (the sequence/utterance
  // indexes) have a larger stride than 't', although this is opposite to the
  // way it's done inside the computation.  This is for user convenience, where
  // it may be easier to deal with submatrices per sequence.
  for (int32 n = 0; n < num_sequences; n++) {
    int32 x = 0;
    for (int32 t = begin_input_t; t < end_input_t; t++) {
      request->inputs[0].indexes.push_back(Index(n, t, x));
    }
    for (int32 t = begin_output_t;
         t < end_output_t;
         t += frame_subsampling_factor)
      request->outputs[0].indexes.push_back(Index(n, t, x));
  }
  if (!ivector_times.empty()) {
    request->inputs.resize(2);
    request->inputs[1].name = "ivector";
    request->inputs[1].has_deriv = false;
    for (int32 n = 0; n < num_sequences; n++) {
      // note: std::sets store things in sorted order.
      for (std::set<int32>::const_iterator iter = ivector_times.begin();
           iter != ivector_times.end(); ++iter) {
        int32 t = *iter, x = 0;
        request->inputs[1].indexes.push_back(Index(n, t, x));
      }
    }
  }
}
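
// As an illustration of the index ordering above (numbers are hypothetical):
// with num_sequences == 2, begin_input_t == -10 and end_input_t == 30, the
// 'input' indexes come out as
//   (n=0, t=-10), (n=0, t=-9), ..., (n=0, t=29), (n=1, t=-10), ..., (n=1, t=29)
// i.e. all frames of sequence 0 first, then all frames of sequence 1; the
// 'output' indexes are laid out the same way, but with a stride of
// frame_subsampling_factor in t.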


void CreateLoopedComputationRequest(const Nnet &nnet,
                                    int32 chunk_size,
                                    int32 frame_subsampling_factor,
                                    int32 ivector_period,
                                    int32 left_context_begin,
                                    int32 right_context,
                                    int32 num_sequences,
                                    ComputationRequest *request1,
                                    ComputationRequest *request2,
                                    ComputationRequest *request3) {
  bool has_ivector = (nnet.InputDim("ivector") > 0);
  KALDI_ASSERT(chunk_size % frame_subsampling_factor == 0 &&
               chunk_size % nnet.Modulus() == 0 &&
               chunk_size % ivector_period == 0);
  KALDI_ASSERT(left_context_begin >= 0 && right_context >= 0);
  // note, 'end' is one past the last one.
  int32 chunk1_input_begin_t = - left_context_begin,
      chunk1_input_end_t = chunk_size + right_context,
      chunk2_input_begin_t = chunk1_input_end_t,
      chunk2_input_end_t = chunk2_input_begin_t + chunk_size,
      chunk3_input_begin_t = chunk2_input_end_t,
      chunk3_input_end_t = chunk3_input_begin_t + chunk_size;


  // work out the times at which i-vectors are required.
  std::set<int32> ivector_times1, ivector_times2, ivector_times3;
  if (has_ivector) {
    for (int32 t = chunk1_input_begin_t; t < chunk1_input_end_t; t++) {
      int32 ivector_t = t - Mod(t, ivector_period);
      ivector_times1.insert(ivector_t);
    }
    for (int32 t = chunk2_input_begin_t; t < chunk2_input_end_t; t++) {
      int32 ivector_t = t - Mod(t, ivector_period);
      if (ivector_times2.count(ivector_t) == 0 &&
          ivector_times1.count(ivector_t) == 0)
        ivector_times2.insert(ivector_t);
    }
    for (int32 t = chunk3_input_begin_t; t < chunk3_input_end_t; t++) {
      int32 ivector_t = t - Mod(t, ivector_period);
      if (ivector_times3.count(ivector_t) == 0 &&
          ivector_times2.count(ivector_t) == 0 &&
          ivector_times1.count(ivector_t) == 0)
        ivector_times3.insert(ivector_t);
    }
  }

  CreateComputationRequestInternal(
      chunk1_input_begin_t, chunk1_input_end_t,
      0, chunk_size,
      num_sequences, frame_subsampling_factor,
      ivector_times1,
      request1);

  CreateComputationRequestInternal(
      chunk2_input_begin_t, chunk2_input_end_t,
      chunk_size, chunk_size * 2,
      num_sequences, frame_subsampling_factor,
      ivector_times2,
      request2);

  CreateComputationRequestInternal(
      chunk3_input_begin_t, chunk3_input_end_t,
      chunk_size * 2, chunk_size * 3,
      num_sequences, frame_subsampling_factor,
      ivector_times3,
      request3);

}
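
// Worked example (numbers chosen purely for illustration; the i-vector input
// is ignored here): with chunk_size == 21, left_context_begin == 30,
// right_context == 10 and frame_subsampling_factor == 3,
//   request1 covers input t in [-30, 31) and output t in {0, 3, ..., 18},
//   request2 covers input t in [31, 52)  and output t in {21, 24, ..., 39},
//   request3 covers input t in [52, 73)  and output t in {42, 45, ..., 60}.
// request2 and request3 are identical up to a shift of chunk_size in t, which
// is the property the looped compiler relies on below.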



void AddTimeOffsetToComputationRequest(int32 t_offset,
                                       ComputationRequest *request) {
  for (size_t i = 0; i < request->inputs.size(); i++) {
    size_t size = request->inputs[i].indexes.size();
    for (size_t j = 0; j < size; j++)
      request->inputs[i].indexes[j].t += t_offset;
  }
  for (size_t i = 0; i < request->outputs.size(); i++) {
    size_t size = request->outputs[i].indexes.size();
    for (size_t j = 0; j < size; j++)
      request->outputs[i].indexes[j].t += t_offset;
  }
}



static bool ExtrapolateComputationRequest(
    const ComputationRequest &request1,
    const ComputationRequest &request2,
    ComputationRequest *request3) {
  // accepts two computation requests 'request1' and 'request2' that must be
  // identical except for a time offset, and creates 'request3' that is the
  // extrapolation of the next term in the sequence.
  *request3 = request2;
  KALDI_ASSERT(!request1.inputs.empty() && !request1.inputs[0].indexes.empty() &&
               !request2.inputs.empty() && !request2.inputs[0].indexes.empty());
  int32 t_offset = request2.inputs[0].indexes[0].t -
      request1.inputs[0].indexes[0].t;
  // the following is just to make sure that the inputs are structurally
  // equivalent.
  AddTimeOffsetToComputationRequest(-t_offset, request3);
  if (!(*request3 == request1))
    return false;  // there is some structural difference, or the time offset
                   // is not consistent.
  // the following reverses the last call to AddTimeOffsetToComputationRequest,
  // then adds the offset we want.
  AddTimeOffsetToComputationRequest(2 * t_offset, request3);
  return true;
}


/* Internal version of CompileLooped, where you specify the number of
   computation requests (must be >= 3).  Returns true on success.
   It's possible for the optimization to fail if you give too small a value of
   'num_requests' (this depends on the network topology), and in that case this
   function will return false and you should re-try with a higher value of
   num_requests.
 */
static bool CompileLoopedInternal(
    const Nnet &nnet,
    NnetOptimizeOptions optimize_opts,
    const ComputationRequest &request1,
    const ComputationRequest &request2,
    const ComputationRequest &request3,
    int32 num_requests,
    NnetComputation *computation) {

  KALDI_ASSERT(num_requests >= 3);
  std::vector<ComputationRequest> extra_requests(num_requests - 3);
  const ComputationRequest *prev_request = &request2;
  const ComputationRequest *cur_request = &request3;
  for (int32 i = 0; i < num_requests - 3; i++) {
    if (!ExtrapolateComputationRequest(*prev_request, *cur_request,
                                       &(extra_requests[i]))) {
      KALDI_LOG << "prev_request is:";
      prev_request->Print(std::cerr);
      KALDI_LOG << "cur_request is:";
      cur_request->Print(std::cerr);
      KALDI_ERR << "Computation requests do not have the right relationship";
    }
    prev_request = cur_request;
    cur_request = &(extra_requests[i]);
  }

  std::vector<const ComputationRequest*> requests;
  requests.push_back(&request1);
  requests.push_back(&request2);
  requests.push_back(&request3);
  for (int32 i = 0; i < num_requests - 3; i++)
    requests.push_back(&(extra_requests[i]));
  Compiler compiler(requests, nnet);
  CompilerOptions compiler_opts;
  compiler.CreateComputation(compiler_opts, computation);
  optimize_opts.optimize_looped_computation = true;

  int32 dont_really_care = MaxOutputTimeInRequest(request3);
  Optimize(optimize_opts, nnet,
           dont_really_care, computation);

  return computation->commands.size() != 0 &&
      computation->commands.back().command_type == kGotoLabel;
}

void CompileLooped(const Nnet &nnet,
                   const NnetOptimizeOptions &optimize_opts,
                   const ComputationRequest &request1,
                   const ComputationRequest &request2,
                   const ComputationRequest &request3,
                   NnetComputation *computation) {
  int32 num_requests1 = 5, factor = 2, max_requests = 100,
      num_requests;

  Timer timer;

  for (num_requests = num_requests1; num_requests <= max_requests;
       num_requests *= factor) {
    if (CompileLoopedInternal(nnet, optimize_opts,
                              request1, request2, request3,
                              num_requests, computation)) {
      KALDI_LOG << "Spent " << timer.Elapsed()
                << " seconds in looped compilation.";
      return;
    } else {
      KALDI_VLOG(2) << "Looped compilation failed with "
                    << num_requests << " requests, trying "
                    << (num_requests * factor);
    }
  }
  KALDI_ERR << "Looped compilation failed with "
            << (num_requests/factor) << " requests, which "
            << "we expect should be enough... something "
            << "went wrong.";
}
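
// Typical call sequence, shown only as a sketch: 'advised_chunk_size',
// 'frame_subsampling_factor', 'extra_left_context_begin' and
// 'extra_right_context' are assumed to come from the caller's configuration,
// and setting ivector_period equal to the chunk size is just one simple
// choice that satisfies the chunk_size % ivector_period == 0 requirement
// asserted in CreateLoopedComputationRequest().
//
//   int32 chunk_size = GetChunkSize(nnet, frame_subsampling_factor,
//                                   advised_chunk_size);
//   int32 ivector_period = chunk_size;
//   ModifyNnetIvectorPeriod(ivector_period, &nnet);
//   int32 left_context, right_context;
//   ComputeSimpleNnetContext(nnet, &left_context, &right_context);
//   ComputationRequest request1, request2, request3;
//   CreateLoopedComputationRequest(nnet, chunk_size, frame_subsampling_factor,
//                                  ivector_period,
//                                  extra_left_context_begin + left_context,
//                                  extra_right_context + right_context,
//                                  1 /* num_sequences */,
//                                  &request1, &request2, &request3);
//   NnetOptimizeOptions optimize_opts;
//   NnetComputation computation;
//   CompileLooped(nnet, optimize_opts, request1, request2, request3,
//                 &computation);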


void CreateLoopedComputationRequestSimple(const Nnet &nnet,
                                          int32 chunk_size,
                                          int32 frame_subsampling_factor,
                                          int32 ivector_period,
                                          int32 extra_left_context_begin,
                                          int32 extra_right_context,
                                          int32 num_sequences,
                                          ComputationRequest *request1,
                                          ComputationRequest *request2,
                                          ComputationRequest *request3) {
  int32 left_context, right_context;
  ComputeSimpleNnetContext(nnet, &left_context, &right_context);

  CreateLoopedComputationRequest(nnet, chunk_size, frame_subsampling_factor,
                                 ivector_period,
                                 extra_left_context_begin + left_context,
                                 extra_right_context + right_context,
                                 num_sequences, request1, request2, request3);
}

}  // namespace nnet3
}  // namespace kaldi