31 std::vector<std::string> config_lines;
33 std::ostringstream config_to_read;
34 for (
size_t i = 0;
i < config_lines.size();
i++) {
35 std::string s = config_lines[
i];
37 bool b = config_line.
ParseLine(config_lines[
i]);
39 if (config_line.
FirstToken() ==
"component-node") {
47 std::string whole_line = config_lines[
i];
48 std::string to_search_for =
"ReplaceIndex(";
49 std::string::size_type to_search_for_size = to_search_for.size();
50 std::string::size_type pos = whole_line.find(to_search_for);
51 if (pos != std::string::npos) {
52 std::string::size_type comma_pos = whole_line.find(
", t, 0)", pos);
53 if (comma_pos != std::string::npos) {
56 std::string descriptor_name =
57 whole_line.substr(pos + to_search_for_size,
58 comma_pos - (pos + to_search_for_size));
60 std::string::size_type end_pos = comma_pos + 7;
61 std::string::size_type expr_size = end_pos - pos;
63 std::ostringstream to_replace_with;
64 to_replace_with <<
"Round(" << descriptor_name <<
", " << ivector_period <<
")";
65 whole_line.replace(pos, expr_size, to_replace_with.str());
66 config_to_read << whole_line <<
"\n";
68 KALDI_ERR <<
"Could not process the ReplaceIndex expression in: " 74 if (!config_to_read.str().empty()) {
75 std::istringstream is(config_to_read.str());
82 int32 frame_subsampling_factor,
83 int32 advised_chunk_size) {
85 KALDI_ASSERT(modulus > 0 && frame_subsampling_factor > 0 &&
86 advised_chunk_size > 0);
87 int32 chunk_size = advised_chunk_size;
89 if (chunk_size % modulus == 0 &&
90 chunk_size % frame_subsampling_factor == 0)
106 template<
class I> I
Mod(I m, I
n) {
108 if (ans < 0) ans +=
n;
117 int32 frame_subsampling_factor,
118 const std::set<int32> &ivector_times,
120 request->
inputs.reserve(2);
122 request->
inputs.resize(1 + (ivector_times.empty() ? 0 : 1));
123 request->
inputs[0].name =
"input";
124 request->
inputs[0].has_deriv =
false;
127 request->
outputs[0].name =
"output";
128 request->
outputs[0].has_deriv =
false;
129 if (!ivector_times.empty()) {
130 request->
inputs[1].name =
"ivector";
131 request->
inputs[1].has_deriv =
false;
138 for (
int32 n = 0;
n < num_sequences;
n++) {
140 for (
int32 t = begin_input_t; t < end_input_t; t++) {
143 for (
int32 t = begin_output_t;
145 t += frame_subsampling_factor)
148 if (!ivector_times.empty()) {
149 request->
inputs.resize(2);
150 request->
inputs[1].name =
"ivector";
151 request->
inputs[1].has_deriv =
false;
152 for (
int32 n = 0;
n < num_sequences;
n++) {
154 for (std::set<int32>::const_iterator iter = ivector_times.begin();
155 iter != ivector_times.end(); ++iter) {
156 int32 t = *iter, x = 0;
166 int32 frame_subsampling_factor,
167 int32 ivector_period,
168 int32 left_context_begin,
174 bool has_ivector = (nnet.
InputDim(
"ivector") > 0);
175 KALDI_ASSERT(chunk_size % frame_subsampling_factor == 0 &&
176 chunk_size % nnet.
Modulus() == 0 &&
177 chunk_size % ivector_period == 0);
178 KALDI_ASSERT(left_context_begin >= 0 && right_context >= 0);
180 int32 chunk1_input_begin_t = - left_context_begin,
181 chunk1_input_end_t = chunk_size + right_context,
182 chunk2_input_begin_t = chunk1_input_end_t,
183 chunk2_input_end_t = chunk2_input_begin_t + chunk_size,
184 chunk3_input_begin_t = chunk2_input_end_t,
185 chunk3_input_end_t = chunk3_input_begin_t + chunk_size;
189 std::set<int32> ivector_times1, ivector_times2, ivector_times3;
191 for (
int32 t = chunk1_input_begin_t; t < chunk1_input_end_t; t++) {
192 int32 ivector_t = t -
Mod(t, ivector_period);
193 ivector_times1.insert(ivector_t);
195 for (
int32 t = chunk2_input_begin_t; t < chunk2_input_end_t; t++) {
196 int32 ivector_t = t -
Mod(t, ivector_period);
197 if (ivector_times2.count(ivector_t) == 0 &&
198 ivector_times1.count(ivector_t) == 0)
199 ivector_times2.insert(ivector_t);
201 for (
int32 t = chunk3_input_begin_t; t < chunk3_input_end_t; t++) {
202 int32 ivector_t = t -
Mod(t, ivector_period);
203 if (ivector_times3.count(ivector_t) == 0 &&
204 ivector_times2.count(ivector_t) == 0 &&
205 ivector_times1.count(ivector_t) == 0)
206 ivector_times3.insert(ivector_t);
211 chunk1_input_begin_t, chunk1_input_end_t,
213 num_sequences, frame_subsampling_factor,
218 chunk2_input_begin_t, chunk2_input_end_t,
219 chunk_size, chunk_size * 2,
220 num_sequences, frame_subsampling_factor,
225 chunk3_input_begin_t, chunk3_input_end_t,
226 chunk_size * 2, chunk_size * 3,
227 num_sequences, frame_subsampling_factor,
237 for (
size_t i = 0;
i < request->
inputs.size();
i++) {
238 size_t size = request->
inputs[
i].indexes.size();
239 for (
size_t j = 0;
j < size;
j++)
240 request->
inputs[
i].indexes[
j].t += t_offset;
242 for (
size_t i = 0;
i < request->
outputs.size();
i++) {
243 size_t size = request->
outputs[
i].indexes.size();
244 for (
size_t j = 0;
j < size;
j++)
245 request->
outputs[
i].indexes[
j].t += t_offset;
258 *request3 = request2;
260 !request2.
inputs.empty() && !request2.
inputs[0].indexes.empty());
262 request1.
inputs[0].indexes[0].t;
266 if (!(*request3 == request1))
294 std::vector<ComputationRequest> extra_requests(num_requests - 3);
297 for (
int32 i = 0;
i < num_requests - 3;
i++) {
299 &(extra_requests[
i]))) {
301 prev_request->
Print(std::cerr);
303 cur_request->
Print(std::cerr);
304 KALDI_ERR <<
"Computation requests do not have the right relationship";
306 prev_request = cur_request;
307 cur_request = &(extra_requests[
i]);
310 std::vector<const ComputationRequest*> requests;
311 requests.push_back(&request1);
312 requests.push_back(&request2);
313 requests.push_back(&request3);
314 for (
int32 i = 0;
i < num_requests - 3;
i++)
315 requests.push_back(&(extra_requests[
i]));
323 dont_really_care, computation);
325 return computation->
commands.size() != 0 &&
335 int32 num_requests1 = 5, factor = 2, max_requests = 100,
340 for (num_requests = num_requests1; num_requests <= max_requests;
341 num_requests *= factor) {
343 request1, request2, request3,
344 num_requests, computation)) {
346 <<
" seconds in looped compilation.";
349 KALDI_VLOG(2) <<
"Looped compilation failed with " 350 << num_requests <<
" requests, trying " 351 << (num_requests * factor);
354 KALDI_ERR <<
"Looped compilation failed with " 355 << (num_requests/factor) <<
" requests, which " 356 <<
"we expect should be enough... something " 363 int32 frame_subsampling_factor,
364 int32 ivector_period,
365 int32 extra_left_context_begin,
366 int32 extra_right_context,
371 int32 left_context, right_context;
376 extra_left_context_begin + left_context,
377 extra_right_context + right_context,
378 num_sequences, request1, request2, request3);
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
void ModifyNnetIvectorPeriod(int32 ivector_period, Nnet *nnet)
This function modifies the descriptors in the neural network to change the periodicity with which it ...
int32 InputDim(const std::string &input_name) const
const std::string & FirstToken() const
static bool ExtrapolateComputationRequest(const ComputationRequest &request1, const ComputationRequest &request2, ComputationRequest *request3)
bool ParseLine(const std::string &line)
static void CreateComputationRequestInternal(int32 begin_input_t, int32 end_input_t, int32 begin_output_t, int32 end_output_t, int32 num_sequences, int32 frame_subsampling_factor, const std::set< int32 > &ivector_times, ComputationRequest *request)
void ReadConfig(std::istream &config_file)
int32 GetChunkSize(const Nnet &nnet, int32 frame_subsampling_factor, int32 advised_chunk_size)
void CreateLoopedComputationRequest(const Nnet &nnet, int32 chunk_size, int32 frame_subsampling_factor, int32 ivector_period, int32 left_context_begin, int32 right_context, int32 num_sequences, ComputationRequest *request1, ComputationRequest *request2, ComputationRequest *request3)
This function creates computation requests suitable for passing to ComputeLooped(). ...
std::vector< IoSpecification > inputs
static bool CompileLoopedInternal(const Nnet &nnet, NnetOptimizeOptions optimize_opts, const ComputationRequest &request1, const ComputationRequest &request2, const ComputationRequest &request3, int32 num_requests, NnetComputation *computation)
std::vector< Command > commands
struct Index is intended to represent the various indexes by which we number the rows of the matrices...
This file contains some miscellaneous functions dealing with class Nnet.
bool optimize_looped_computation
int32 Modulus() const
[Relevant for clockwork RNNs and similar].
void CompileLooped(const Nnet &nnet, const NnetOptimizeOptions &optimize_opts, const ComputationRequest &request1, const ComputationRequest &request2, const ComputationRequest &request3, NnetComputation *computation)
CompileLooped() provides an internal interface for 'looped' computation.
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
int32 MaxOutputTimeInRequest(const ComputationRequest &request)
I Mod(I m, I n)
Mod(m, n), defined for integers m and n where n > 0, returns the modulus m % n, defined as the intege...
void GetConfigLines(bool include_dim, std::vector< std::string > *config_lines) const
void CreateLoopedComputationRequestSimple(const Nnet &nnet, int32 chunk_size, int32 frame_subsampling_factor, int32 ivector_period, int32 extra_left_context_begin, int32 extra_right_context, int32 num_sequences, ComputationRequest *request1, ComputationRequest *request2, ComputationRequest *request3)
This function is deprecated.
void Optimize(const NnetOptimizeOptions &config, const Nnet &nnet, int32 max_output_time_in_request, NnetComputation *computation)
This is the top-level function for optimizing a computation.
This class is responsible for parsing input like hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" and giving you access to the fields, in this case.
void CreateComputation(const CompilerOptions &opts, NnetComputation *computation)
#define KALDI_ASSERT(cond)
std::vector< IoSpecification > outputs
This class creates an initial version of the NnetComputation, without any optimization or sharing of ...
void AddTimeOffsetToComputationRequest(int32 t_offset, ComputationRequest *request)
void Print(std::ostream &os) const
This function is for printing info about the computation request in a human-readable way...
double Elapsed() const
Returns time in seconds.