nnet3-xvector-compute-batched.cc File Reference


Classes

struct  BatchedXvectorComputerOptions
 
class  BatchedXvectorComputer
 
struct  BatchedXvectorComputer::XvectorTask
 

Namespaces

 kaldi
 
 kaldi::nnet3
 

Functions

void DivideIntoPieces (int32 a, int32 b, std::vector< int32 > *pieces)
 This function divides the number 'a' into 'b' pieces, such that the sum of the pieces equals 'a' and no two pieces differ by more than 1.
 
int main (int argc, char *argv[])
 
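The contract stated above for DivideIntoPieces can be met with one division and one remainder. The sketch below is a standalone illustration, not Kaldi's source: it uses `std::int32_t` in place of `kaldi::int32`, assumes `a >= 0` and `b > 0` (how the real function treats a negative `a` is not documented on this page), and places the larger pieces last, which is an arbitrary choice.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch of the DivideIntoPieces contract (not Kaldi's code):
// split a into b integers that sum to a and differ from each other by at
// most 1. Assumes a >= 0 and b > 0.
void DivideIntoPieces(std::int32_t a, std::int32_t b,
                      std::vector<std::int32_t> *pieces) {
  assert(a >= 0 && b > 0);
  std::int32_t base = a / b;       // size of the smaller pieces
  std::int32_t remainder = a % b;  // number of pieces that get one extra unit
  pieces->clear();
  pieces->reserve(b);
  for (std::int32_t i = 0; i < b; ++i) {
    // The last `remainder` pieces are one larger than the others.
    pieces->push_back(i < b - remainder ? base : base + 1);
  }
}
```

For example, dividing 10 into 3 pieces gives {3, 3, 4}. When `a < b`, some pieces are 0 (dividing 2 into 4 gives {0, 0, 1, 1}), which still satisfies the "differ by at most 1" condition.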

Function Documentation

main()

int main (int argc, char *argv[])

Definition at line 524 of file nnet3-xvector-compute-batched.cc.

References BatchedXvectorComputer::AcceptUtterance(), kaldi::nnet3::CollapseModel(), kaldi::nnet3::ComputeSimpleNnetContext(), SequentialTableReader< Holder >::Done(), Timer::Elapsed(), BatchedXvectorComputer::Flush(), ParseOptions::GetArg(), KALDI_LOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), BatchedXvectorComputer::OutputXvector(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadKaldiObject(), BatchedXvectorComputerOptions::Register(), ParseOptions::Register(), kaldi::nnet3::SetBatchnormTestMode(), kaldi::nnet3::SetDropoutTestMode(), kaldi::nnet3::SetRequireDirectInput(), SequentialTableReader< Holder >::Value(), TableWriter< Holder >::Write(), and BatchedXvectorComputer::XvectorReady().

```cpp
{
  try {
    using namespace kaldi;
    using namespace kaldi::nnet3;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Propagate features through an xvector neural network model and write\n"
        "the output vectors. \"Xvector\" is our term for a vector or\n"
        "embedding which is the output of a particular type of neural network\n"
        "architecture found in speaker recognition. This architecture\n"
        "consists of several layers that operate on frames, a statistics\n"
        "pooling layer that aggregates over the frame-level representations\n"
        "and possibly additional layers that operate on segment-level\n"
        "representations. The xvectors are generally extracted from an\n"
        "output layer after the statistics pooling layer. By default, one\n"
        "xvector is extracted directly from the set of features for each\n"
        "utterance. Optionally, xvectors are extracted from chunks of input\n"
        "features and averaged, to produce a single vector.\n"
        "\n"
        "Usage: nnet3-xvector-compute [options] <raw-nnet-in> "
        "<features-rspecifier> <vector-wspecifier>\n"
        "e.g.: nnet3-xvector-compute final.raw scp:feats.scp "
        "ark:nnet_prediction.ark\n"
        "See also: nnet3-compute\n";

    ParseOptions po(usage);
    Timer timer;

    BatchedXvectorComputerOptions opts;

    std::string use_gpu = "no";

    opts.Register(&po);

    po.Register("use-gpu", &use_gpu,
                "yes|no|optional|wait, only has effect if compiled with CUDA");

#if HAVE_CUDA==1
    CuDevice::RegisterDeviceOptions(&po);
#endif
    po.Read(argc, argv);

    if (po.NumArgs() != 3) {
      po.PrintUsage();
      exit(1);
    }

#if HAVE_CUDA==1
    CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

    std::string nnet_rxfilename = po.GetArg(1),
        feature_rspecifier = po.GetArg(2),
        vector_wspecifier = po.GetArg(3);

    Nnet nnet;
    ReadKaldiObject(nnet_rxfilename, &nnet);
    SetBatchnormTestMode(true, &nnet);
    SetDropoutTestMode(true, &nnet);
    // Simplify the network for efficient test-time use.
    CollapseModel(CollapseModelConfig(), &nnet);

    int32 total_context;
    {
      int32 left_context, right_context;
      // Compute left_context, right_context as the 'real' left/right context
      // of the network; they'll tell us how many frames on the chunk boundaries
      // won't really participate in the statistics averaging.
      // SetRequireDirectInput() modifies how the StatisticsPoolingComponent
      // treats its dependences, so we'll get the 'real' left/right context.
      SetRequireDirectInput(true, &nnet);
      ComputeSimpleNnetContext(nnet, &left_context, &right_context);
      KALDI_LOG << "Left/right context is " << left_context << ", "
                << right_context;
      SetRequireDirectInput(false, &nnet);
      total_context = left_context + right_context;
    }

    BatchedXvectorComputer computer(opts, nnet, total_context);
    BaseFloatVectorWriter vector_writer(vector_wspecifier);

    int32 num_utts_read = 0, num_xvectors_written = 0;
    int64 frame_count = 0;

    SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);

    for (; !feature_reader.Done(); feature_reader.Next()) {
      std::string utt = feature_reader.Key();
      const Matrix<BaseFloat> &features (feature_reader.Value());
      if (features.NumRows() == 0) {
        KALDI_WARN << "Zero-length utterance: " << utt;
        continue;
      }

      frame_count += features.NumRows();

      computer.AcceptUtterance(utt, features);
      num_utts_read++;

      // Write out any xvectors that are already available.
      while (computer.XvectorReady()) {
        std::string utt;
        Vector<BaseFloat> xvector;
        computer.OutputXvector(&utt, &xvector);
        vector_writer.Write(utt, xvector);
        num_xvectors_written++;
      }
    }

    // Force computation of any partially filled batch, then write the rest.
    computer.Flush();
    while (computer.XvectorReady()) {
      std::string utt;
      Vector<BaseFloat> xvector;
      computer.OutputXvector(&utt, &xvector);
      vector_writer.Write(utt, xvector);
      num_xvectors_written++;
    }

#if HAVE_CUDA==1
    CuDevice::Instantiate().PrintProfile();
#endif
    double elapsed = timer.Elapsed();
    KALDI_LOG << "Time taken "<< elapsed
              << "s: real-time factor assuming 100 frames/sec is "
              << (elapsed*100.0/frame_count);
    KALDI_LOG << "Read " << num_utts_read << " utterances, wrote "
              << num_xvectors_written << " xvectors.";

    // Note: the following rule does something reasonable even if there are 0, 1
    // or 2 utterances read.
    if (num_xvectors_written > num_utts_read / 2)
      return 0;
    else
      return 1;
  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
```
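The read loop in main() drains finished xvectors after every AcceptUtterance() call and then drains once more after Flush(), because results only become available when a batch is computed, and the last, partially filled batch would otherwise never be written. The sketch below mimics that accept/ready/output/flush protocol with a hypothetical `ToyBatcher` class and `RunAll` driver. It is not Kaldi's BatchedXvectorComputer: each "utterance" here is a list of scalar frames and its "xvector" is simply the frame mean, so none of the real chunking or neural-network computation is reproduced.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the protocol used by main(); NOT Kaldi's class.
// Each utterance is a list of scalar frames; its "xvector" is the frame mean.
class ToyBatcher {
 public:
  explicit ToyBatcher(std::size_t batch_size) : batch_size_(batch_size) {}

  // Queue a (non-empty) utterance; compute once batch_size_ are pending.
  void AcceptUtterance(const std::string &utt, const std::vector<float> &frames) {
    pending_.emplace_back(utt, frames);
    if (pending_.size() >= batch_size_) ComputeBatch();
  }

  // Force computation of a partially filled batch.
  void Flush() {
    if (!pending_.empty()) ComputeBatch();
  }

  bool XvectorReady() const { return !ready_.empty(); }

  // Pop the oldest finished result (first in, first out).
  void OutputXvector(std::string *utt, float *xvector) {
    assert(!ready_.empty());
    *utt = ready_.front().first;
    *xvector = ready_.front().second;
    ready_.pop_front();
  }

 private:
  void ComputeBatch() {
    for (const auto &p : pending_) {
      float sum = 0.0f;
      for (float f : p.second) sum += f;
      ready_.emplace_back(p.first, sum / static_cast<float>(p.second.size()));
    }
    pending_.clear();
  }

  std::size_t batch_size_;
  std::vector<std::pair<std::string, std::vector<float>>> pending_;
  std::deque<std::pair<std::string, float>> ready_;
};

// Driver with the same shape as main(): drain after each accept, then
// flush and drain once more so no utterance is left behind.
std::vector<std::pair<std::string, float>> RunAll(
    const std::vector<std::pair<std::string, std::vector<float>>> &utts,
    std::size_t batch_size) {
  ToyBatcher batcher(batch_size);
  std::vector<std::pair<std::string, float>> written;
  auto drain = [&]() {
    while (batcher.XvectorReady()) {
      std::string utt;
      float xvector;
      batcher.OutputXvector(&utt, &xvector);
      written.emplace_back(utt, xvector);
    }
  };
  for (const auto &u : utts) {
    if (u.second.empty()) continue;  // main() likewise skips zero-length input
    batcher.AcceptUtterance(u.first, u.second);
    drain();
  }
  batcher.Flush();
  drain();
  return written;
}
```

With a batch size of 2, an utterance accepted after the last full batch is only written because of the final Flush() and drain. Folding the drain loop into a lambda is one way to avoid the duplicated `while` loop that appears twice in main().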
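The end of main() reports a real-time factor and chooses the exit status from two counters. The sketch below restates both rules as standalone functions with hypothetical names (`RealTimeFactor`, `ExitStatus`). The real-time factor is elapsed time divided by audio duration, where the duration is inferred from the frame count at 100 frames per second (a 10 ms frame shift). For example, 50 s spent on 20000 frames (200 s of audio) gives 0.25. main() does not guard against a zero frame count; this sketch asserts instead. The exit rule uses integer division, so with 0 utterances read the program fails, with 1 it succeeds only if that xvector was written, and with 2 both must be written.

```cpp
#include <cassert>
#include <cstdint>

// Real-time factor as logged by main(): elapsed seconds divided by the audio
// duration, assuming 100 frames per second. Unlike main(), this sketch
// asserts that at least one frame was processed.
double RealTimeFactor(double elapsed_seconds, std::int64_t frame_count) {
  assert(frame_count > 0);
  return elapsed_seconds * 100.0 / static_cast<double>(frame_count);
}

// Exit-status rule from main(): 0 (success) only if more than half of the
// utterances read produced an xvector, using integer division as main() does.
int ExitStatus(std::int32_t num_utts_read, std::int32_t num_xvectors_written) {
  return (num_xvectors_written > num_utts_read / 2) ? 0 : 1;
}
```

With 10 utterances read, 6 written xvectors give success (6 > 5) but 5 do not.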