GeneralDescriptor Struct Reference

This class is only used when parsing Descriptors. More...

#include <nnet-descriptor.h>

Collaboration diagram for GeneralDescriptor:

Public Types

enum  DescriptorType {
  kAppend, kSum, kFailover, kIfDefined,
  kOffset, kSwitch, kRound, kReplaceIndex,
  kScale, kConst, kNodeName
}
 

Public Member Functions

 GeneralDescriptor (DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
 
 ~GeneralDescriptor ()
 
GeneralDescriptor * GetNormalizedDescriptor () const
 
Descriptor * ConvertToDescriptor ()
 
void Print (const std::vector< std::string > &node_names, std::ostream &os)
 

Static Public Member Functions

static GeneralDescriptor * Parse (const std::vector< std::string > &node_names, const std::string **next_token)
 

Private Member Functions

 KALDI_DISALLOW_COPY_AND_ASSIGN (GeneralDescriptor)
 
void ParseAppendOrSumOrSwitch (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseIfDefined (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseOffset (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseSwitch (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseFailover (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseRound (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseScale (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseConst (const std::vector< std::string > &node_names, const std::string **next_token)
 
void ParseReplaceIndex (const std::vector< std::string > &node_names, const std::string **next_token)
 
int32 NumAppendTerms () const
 
GeneralDescriptor * GetAppendTerm (int32 term) const
 
GeneralDescriptor * NormalizeAppend () const
 
SumDescriptor * ConvertToSumDescriptor () const
 
ForwardingDescriptor * ConvertToForwardingDescriptor () const
 

Static Private Member Functions

static bool Normalize (GeneralDescriptor *ptr)
 

Private Attributes

DescriptorType descriptor_type_
 
int32 value1_
 
int32 value2_
 
BaseFloat alpha_
 
std::vector< GeneralDescriptor * > descriptors_
 

Detailed Description

This class is only used when parsing Descriptors.

It is useful for normalizing descriptors that are structured in an invalid or redundant way, into a form that can be turned into a real Descriptor.

Definition at line 609 of file nnet-descriptor.h.

Member Enumeration Documentation

◆ DescriptorType

Constructor & Destructor Documentation

◆ GeneralDescriptor()

GeneralDescriptor ( DescriptorType  t,
int32  value1 = -1,
int32  value2 = -1,
BaseFloat  alpha = 0.0 
)
inline explicit

Definition at line 623 of file nnet-descriptor.h.

◆ ~GeneralDescriptor()

~GeneralDescriptor ( )
inline

Definition at line 629 of file nnet-descriptor.h.

References kaldi::DeletePointers(), ForwardingDescriptor::KALDI_DISALLOW_COPY_AND_ASSIGN(), and fst::Print().

void DeletePointers(std::vector< A *> *v)
Deletes any non-NULL pointers in the vector v, and sets the corresponding entries of v to NULL...
Definition: stl-utils.h:184
std::vector< GeneralDescriptor * > descriptors_

Member Function Documentation

◆ ConvertToDescriptor()

Descriptor * ConvertToDescriptor ( )

Definition at line 1029 of file nnet-descriptor.cc.

References GeneralDescriptor::ConvertToSumDescriptor(), GeneralDescriptor::descriptor_type_, GeneralDescriptor::descriptors_, and rnnlm::i.

Referenced by kaldi::nnet3::NormalizeTextDescriptor(), Descriptor::Parse(), and kaldi::nnet3::UnitTestGeneralDescriptor().

1029  {
1031  std::vector<SumDescriptor*> sum_descriptors;
1032  if (normalized->descriptor_type_ == kAppend) {
1033  for (size_t i = 0; i < normalized->descriptors_.size(); i++)
1034  sum_descriptors.push_back(
1035  normalized->descriptors_[i]->ConvertToSumDescriptor());
1036  } else {
1037  sum_descriptors.push_back(normalized->ConvertToSumDescriptor());
1038  }
1039  Descriptor *ans = new Descriptor(sum_descriptors);
1040  delete normalized;
1041  return ans;
1042 }
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
GeneralDescriptor * GetNormalizedDescriptor() const

◆ ConvertToForwardingDescriptor()

ForwardingDescriptor * ConvertToForwardingDescriptor ( ) const
private

Definition at line 1075 of file nnet-descriptor.cc.

References rnnlm::i, KALDI_ASSERT, KALDI_ERR, ReplaceIndexForwardingDescriptor::kT, and ReplaceIndexForwardingDescriptor::kX.

1075  {
1076  switch (this->descriptor_type_) {
1077  case kNodeName: return new SimpleForwardingDescriptor(value1_);
1078  case kOffset: {
1079  KALDI_ASSERT(descriptors_.size() == 1 && "bad descriptor");
1080  return new OffsetForwardingDescriptor(
1082  Index(0, value1_, value2_));
1083  }
1084  case kSwitch: {
1085  std::vector<ForwardingDescriptor*> descriptors;
1086  for (size_t i = 0; i < descriptors_.size(); i++)
1087  descriptors.push_back(descriptors_[i]->ConvertToForwardingDescriptor());
1088  return new SwitchingForwardingDescriptor(descriptors);
1089  }
1090  case kRound: {
1091  KALDI_ASSERT(descriptors_.size() == 1 && "bad descriptor");
1092  return new RoundingForwardingDescriptor(
1094  value1_);
1095  }
1096  case kReplaceIndex: {
1097  KALDI_ASSERT(descriptors_.size() == 1 && "bad descriptor");
1100  return new ReplaceIndexForwardingDescriptor(
1105  value2_);
1106  }
1107  case kScale: {
1108  if (!(descriptors_.size() == 1 &&
1109  descriptors_[0]->descriptor_type_ == kNodeName)) {
1110  KALDI_ERR << "Invalid combination of Scale() expression and other "
1111  "expressions encountered in descriptor.";
1112  }
1113  return new SimpleForwardingDescriptor(descriptors_[0]->value1_,
1114  alpha_);
1115  }
1116  case kConst: {
1117  KALDI_ERR << "Error in Descriptor: Const() "
1118  "appeared too deep in the expression.";
1119  }
1120  default:
1121  KALDI_ERR << "Invalid descriptor type (failure in normalization?)";
1122  return NULL;
1123  }
1124 }
std::vector< GeneralDescriptor * > descriptors_
kaldi::int32 int32
#define KALDI_ERR
Definition: kaldi-error.h:147
ForwardingDescriptor * ConvertToForwardingDescriptor() const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ ConvertToSumDescriptor()

SumDescriptor * ConvertToSumDescriptor ( ) const
private

Definition at line 1044 of file nnet-descriptor.cc.

References ConstantSumDescriptor::ConstantSumDescriptor(), KALDI_ASSERT, KALDI_ERR, BinarySumDescriptor::kFailoverOperation, and BinarySumDescriptor::kSumOperation.

Referenced by GeneralDescriptor::ConvertToDescriptor().

1044  {
1046  "Badly normalized descriptor");
1047  switch (descriptor_type_) {
1048  case kAppend:
1049  KALDI_ERR << "Badly normalized descriptor";
1050  case kSum: case kFailover: {
1051  KALDI_ASSERT(descriptors_.size() == 2 && "Bad descriptor");
1052  return new BinarySumDescriptor(
1053  descriptor_type_ == kSum ?
1058  }
1059  case kIfDefined: {
1060  KALDI_ASSERT(descriptors_.size() == 1 && "Bad descriptor");
1061  return new OptionalSumDescriptor(
1063  }
1064  case kConst: {
1065  KALDI_ASSERT(descriptors_.empty() && value1_ > 0);
1066  return new ConstantSumDescriptor(alpha_, value1_);
1067  }
1068  default: {
1069  return new SimpleSumDescriptor(this->ConvertToForwardingDescriptor());
1070  }
1071  }
1072 }
SumDescriptor * ConvertToSumDescriptor() const
std::vector< GeneralDescriptor * > descriptors_
#define KALDI_ERR
Definition: kaldi-error.h:147
ForwardingDescriptor * ConvertToForwardingDescriptor() const
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetAppendTerm()

GeneralDescriptor * GetAppendTerm ( int32  term) const
private

Definition at line 775 of file nnet-descriptor.cc.

References GeneralDescriptor::descriptors_, rnnlm::i, KALDI_ASSERT, and KALDI_ERR.

775  {
776  switch (descriptor_type_) {
777  case kNodeName:
778  KALDI_ASSERT(term == 0);
779  return new GeneralDescriptor(kNodeName, value1_);
780  case kAppend: {
781  int32 cur_term = term;
782  for (size_t i = 0; i < descriptors_.size(); i++) {
783  int32 this_num_terms = descriptors_[i]->NumAppendTerms();
784  if (cur_term < this_num_terms)
785  return descriptors_[i]->GetAppendTerm(cur_term);
786  else
787  cur_term -= this_num_terms;
788  }
789  KALDI_ERR << "Code error, getting append term.";
790  return NULL; // avoid compiler warning
791  }
792  default: {
794  value1_, value2_,
795  alpha_);
796  ans->descriptors_.resize(descriptors_.size());
797  for (size_t i = 0; i < descriptors_.size(); i++)
798  ans->descriptors_[i] = descriptors_[i]->GetAppendTerm(term);
799  return ans;
800  }
801  }
802 }
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
std::vector< GeneralDescriptor * > descriptors_
kaldi::int32 int32
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ GetNormalizedDescriptor()

GeneralDescriptor * GetNormalizedDescriptor ( ) const

Definition at line 969 of file nnet-descriptor.cc.

969  {
971  while (Normalize(ans)); // keep normalizing as long as it changes.
972  return ans;
973 }
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
static bool Normalize(GeneralDescriptor *ptr)
GeneralDescriptor * NormalizeAppend() const

◆ KALDI_DISALLOW_COPY_AND_ASSIGN()

KALDI_DISALLOW_COPY_AND_ASSIGN ( GeneralDescriptor  )
private

◆ Normalize()

bool Normalize ( GeneralDescriptor *  ptr)
static private

Definition at line 823 of file nnet-descriptor.cc.

References GeneralDescriptor::alpha_, GeneralDescriptor::descriptor_type_, GeneralDescriptor::descriptors_, rnnlm::i, KALDI_ASSERT, KALDI_ERR, kaldi::swap(), GeneralDescriptor::value1_, and GeneralDescriptor::value2_.

823  {
824  bool changed = false;
825  switch (desc->descriptor_type_) {
826  case kOffset: { // this block combines Offset(Offset(x, ..), ..).
827  KALDI_ASSERT(desc->descriptors_.size() == 1);
828  GeneralDescriptor *child = desc->descriptors_[0];
829  if (child->descriptor_type_ == kOffset) {
830  KALDI_ASSERT(child->descriptors_.size() == 1);
831  GeneralDescriptor *grandchild = child->descriptors_[0];
832  desc->value1_ += child->value1_;
833  desc->value2_ += child->value2_;
834  child->descriptors_.clear(); // avoid delete in destructor.
835  delete child;
836  desc->descriptors_[0] = grandchild;
837  changed = true;
838  } else if (desc->value1_ == 0 && desc->value2_ == 0) {
839  // remove redundant Offset expression like Offset(x, 0).
840  desc->descriptors_.swap(child->descriptors_);
841  desc->descriptor_type_ = child->descriptor_type_;
842  desc->value1_ = child->value1_;
843  desc->value2_ = child->value2_;
844  desc->alpha_ = child->alpha_;
845  child->descriptors_.clear(); // avoid delete in destructor.
846  delete child;
847  changed = true;
848  break; // break from the switch ('desc' is no longer of type
849  // kOffset)', so we don't want to carry through.
850  }
851  }
852  // ... and continue through to the next case statement.
853  case kSwitch: case kRound: case kReplaceIndex: { // ..and kOffset:
854  KALDI_ASSERT(desc->descriptors_.size() >= 1);
855  GeneralDescriptor *child = desc->descriptors_[0];
856  // If child->descriptor_type_ == kAppend, it would be code error since we
857  // already did NormalizeAppend().
858  KALDI_ASSERT(child->descriptor_type_ != kAppend);
859  if (child->descriptor_type_ == kSum ||
860  child->descriptor_type_ == kFailover ||
861  child->descriptor_type_ == kIfDefined) {
862  if (desc->descriptors_.size() > 1) {
863  KALDI_ASSERT(desc->descriptor_type_ == kSwitch);
864  KALDI_ERR << "Sum(), Failover() or IfDefined() expression inside Switch(), "
865  << "we can't currently normalize this.";
866  }
867  // this is a forbidden case of a sum descriptor inside a forwarding
868  // descriptor. we need to rearrange. E.g. Offset(Sum(x, y), 1) becomes
869  // Sum(Offset(x, 1), Offset(y, 1)).
870  for (size_t i = 0; i < child->descriptors_.size(); i++) {
871  GeneralDescriptor *grandchild = child->descriptors_[i];
872  GeneralDescriptor *modified_grandchild =
873  new GeneralDescriptor(desc->descriptor_type_,
874  desc->value1_,
875  desc->value2_,
876  desc->alpha_);
877  // modified_grandchild takes ownership of grandchild.
878  modified_grandchild->descriptors_.push_back(grandchild);
879  child->descriptors_[i] = modified_grandchild;
880  }
881  // copy all members from child to desc.
882  desc->descriptor_type_ = child->descriptor_type_;
883  desc->value1_ = child->value1_;
884  desc->value2_ = child->value2_;
885  desc->descriptors_.swap(child->descriptors_);
886  child->descriptors_.clear(); // avoid delete in destructor of 'child'
887  delete child;
888  changed = true;
889  }
890  break;
891  }
892  case kSum: {
893  KALDI_ASSERT(!desc->descriptors_.empty());
894  if (desc->descriptors_.size() == 1) {
895  // convert Sum(x) to just x.
896  GeneralDescriptor *child = desc->descriptors_[0];
897  desc->descriptor_type_ = child->descriptor_type_;
898  desc->descriptors_.swap(child->descriptors_);
899  desc->value1_ = child->value1_;
900  desc->value2_ = child->value2_;
901  desc->alpha_ = child->alpha_;
902  child->descriptors_.clear(); // avoid delete in destructor.
903  delete child;
904  changed = true;
905  } else if (desc->descriptors_.size() > 2) {
906  // convert Sum(a, b, c, ...) to Sum(a, Sum(b, c, ...)).
907  GeneralDescriptor *new_child = new GeneralDescriptor(kSum);
908  // assign b, c, .. to the descriptors of new_child.
909  new_child->descriptors_.insert(new_child->descriptors_.begin(),
910  desc->descriptors_.begin() + 1,
911  desc->descriptors_.end());
912  desc->descriptors_.erase(desc->descriptors_.begin() + 1,
913  desc->descriptors_.end());
914  desc->descriptors_.push_back(new_child);
915  changed = true;
916  }
917  break;
918  }
919  case kScale: {
920  KALDI_ASSERT(desc->descriptors_.size() == 1);
921  GeneralDescriptor *child = desc->descriptors_[0];
922  if (child->descriptor_type_ == kOffset ||
923  child->descriptor_type_ == kReplaceIndex ||
924  child->descriptor_type_ == kRound) {
925  // push the Scale() inside those expressions.
926  std::swap(desc->descriptor_type_, child->descriptor_type_);
927  std::swap(desc->alpha_, child->alpha_);
928  std::swap(desc->value1_, child->value1_);
929  std::swap(desc->value2_, child->value2_);
930  changed = true;
931  } else if (child->descriptor_type_ == kSum) {
932  // Push the Scale() inside the sum expression.
933  desc->descriptors_.clear();
934  for (size_t i = 0; i < child->descriptors_.size(); i++) {
935  GeneralDescriptor *new_child =
936  new GeneralDescriptor(kScale, -1, -1, desc->alpha_);
937  new_child->descriptors_.push_back(child->descriptors_[i]);
938  desc->descriptors_.push_back(new_child);
939  }
940  desc->descriptor_type_ = kSum;
941  desc->alpha_ = 0.0;
942  child->descriptors_.clear(); // prevent them being freed.
943  delete child;
944  changed = true;
945  } else if (child->descriptor_type_ == kScale) {
946  // Combine the 'scale' expressions.
947  KALDI_ASSERT(child->descriptors_.size() == 1);
948  GeneralDescriptor *grandchild = child->descriptors_[0];
949  desc->alpha_ *= child->alpha_;
950  desc->descriptors_[0] = grandchild;
951  child->descriptors_.clear(); // prevent them being freed.
952  delete child;
953  changed = true;
954  } else if (child->descriptor_type_ != kNodeName) {
955  KALDI_ERR << "Unhandled case encountered when normalizing Descriptor; "
956  "you can work around this by pushing Scale() inside "
957  "other expressions.";
958  }
959  break;
960  }
961  default: { } // empty statement
962  }
963  // ... and recurse.
964  for (size_t i = 0; i < desc->descriptors_.size(); i++)
965  changed = changed || Normalize(desc->descriptors_[i]);
966  return changed;
967 }
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
static bool Normalize(GeneralDescriptor *ptr)
void swap(basic_filebuf< CharT, Traits > &x, basic_filebuf< CharT, Traits > &y)
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ NormalizeAppend()

GeneralDescriptor * NormalizeAppend ( ) const
private

Definition at line 806 of file nnet-descriptor.cc.

References GeneralDescriptor::descriptors_, rnnlm::i, and KALDI_ASSERT.

806  {
807  int32 num_terms = NumAppendTerms();
808  KALDI_ASSERT(num_terms > 0);
809  if (num_terms == 1) {
810  return GetAppendTerm(0);
811  } else {
813  ans->descriptors_.resize(num_terms);
814  for (size_t i = 0; i < num_terms; i++) {
815  ans->descriptors_[i] = GetAppendTerm(i);
816  }
817  return ans;
818  }
819 }
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
kaldi::int32 int32
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
GeneralDescriptor * GetAppendTerm(int32 term) const

◆ NumAppendTerms()

int32 NumAppendTerms ( ) const
private

Definition at line 756 of file nnet-descriptor.cc.

References rnnlm::i, and KALDI_ASSERT.

756  {
757  int32 ans = 0;
758  switch (descriptor_type_) {
759  case kNodeName: ans = 1; break;
760  case kConst: ans = 1; break;
761  case kAppend: {
762  for (size_t i = 0; i < descriptors_.size(); i++)
763  ans += descriptors_[i]->NumAppendTerms();
764  break;
765  }
766  default:
767  KALDI_ASSERT(descriptors_.size() > 0);
768  ans = descriptors_[0]->NumAppendTerms();
769  for (size_t i = 1; i < descriptors_.size(); i++)
771  }
772  return ans;
773 }
std::vector< GeneralDescriptor * > descriptors_
kaldi::int32 int32
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ Parse()

GeneralDescriptor * Parse ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
static

Definition at line 585 of file nnet-descriptor.cc.

References kaldi::nnet3::ExpectToken(), rnnlm::i, KALDI_ERR, GeneralDescriptor::ParseAppendOrSumOrSwitch(), GeneralDescriptor::ParseConst(), GeneralDescriptor::ParseFailover(), GeneralDescriptor::ParseIfDefined(), GeneralDescriptor::ParseOffset(), GeneralDescriptor::ParseReplaceIndex(), GeneralDescriptor::ParseRound(), and GeneralDescriptor::ParseScale().

Referenced by kaldi::nnet3::NormalizeTextDescriptor(), Descriptor::Parse(), and kaldi::nnet3::UnitTestGeneralDescriptor().

587  {
588 
589  DescriptorType t;
590  if (**next_token == "Append") {
591  t = kAppend;
592  } else if (**next_token == "Sum") {
593  t = kSum;
594  } else if (**next_token == "Failover") {
595  t = kFailover;
596  } else if (**next_token == "IfDefined") {
597  t = kIfDefined;
598  } else if (**next_token == "Offset") {
599  t = kOffset;
600  } else if (**next_token == "Switch") {
601  t = kSwitch;
602  } else if (**next_token == "Scale") {
603  t = kScale;
604  } else if (**next_token == "Const") {
605  t = kConst;
606  } else if (**next_token == "Round") {
607  t = kRound;
608  } else if (**next_token == "ReplaceIndex") {
609  t = kReplaceIndex;
610  } else {
611  // what we read wasn't a reserved name like Offset, etc.
612  // We expect a node name in that case.
613  for (size_t i = 0; i < node_names.size(); i++) {
614  if (**next_token == node_names[i]) {
616  (*next_token)++;
617  return ans;
618  }
619  }
620  KALDI_ERR << "Expected a Descriptor, got instead "
621  << **next_token;
622  t = kNodeName; // suppress compiler warning.
623  }
624  (*next_token)++;
625  ExpectToken("(", "Descriptor", next_token);
626  GeneralDescriptor *ans = new GeneralDescriptor(t);
627  switch (t) {
628  case kAppend: case kSum: case kSwitch:
629  ans->ParseAppendOrSumOrSwitch(node_names, next_token); break;
630  case kFailover: ans->ParseFailover(node_names, next_token); break;
631  case kIfDefined: ans->ParseIfDefined(node_names, next_token); break;
632  case kOffset: ans->ParseOffset(node_names, next_token); break;
633  case kRound: ans->ParseRound(node_names, next_token); break;
634  case kReplaceIndex: ans->ParseReplaceIndex(node_names, next_token); break;
635  case kScale: ans->ParseScale(node_names, next_token); break;
636  case kConst: ans->ParseConst(node_names, next_token); break;
637  default:
638  KALDI_ERR << "Code error";
639  }
640  return ans;
641 }
GeneralDescriptor(DescriptorType t, int32 value1=-1, int32 value2=-1, BaseFloat alpha=0.0)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
#define KALDI_ERR
Definition: kaldi-error.h:147

◆ ParseAppendOrSumOrSwitch()

void ParseAppendOrSumOrSwitch ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 643 of file nnet-descriptor.cc.

References KALDI_ERR.

Referenced by GeneralDescriptor::Parse().

645  {
646  descriptors_.push_back(Parse(node_names, next_token));
647  while (true) {
648  if (**next_token == ")") {
649  (*next_token)++;
650  return;
651  } else if (**next_token == ",") {
652  (*next_token)++;
653  descriptors_.push_back(Parse(node_names, next_token));
654  } else {
655  KALDI_ERR << "Expected ',' or ')', got "
656  << **next_token;
657  }
658  }
659 }
std::vector< GeneralDescriptor * > descriptors_
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
#define KALDI_ERR
Definition: kaldi-error.h:147

◆ ParseConst()

void ParseConst ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 690 of file nnet-descriptor.cc.

References kaldi::ConvertStringToInteger(), kaldi::ConvertStringToReal(), kaldi::nnet3::ExpectToken(), and KALDI_ERR.

Referenced by GeneralDescriptor::Parse().

692  {
693  if (!ConvertStringToReal(**next_token, &alpha_)) {
694  KALDI_ERR << "Parsing Const() in descriptor: expected floating-point value"
695  ", got: " << **next_token;
696  }
697  (*next_token)++; // Consume the float.
698  ExpectToken(",", "Const", next_token);
699  if (!ConvertStringToInteger(**next_token, &value1_) ||
700  value1_ <= 0) {
701  KALDI_ERR << "Parsing Const() in descriptor: expected nonnegative integer, "
702  "got: " << **next_token;
703  }
704  (*next_token)++; // Consume the int.
705  ExpectToken(")", "Const", next_token);
706 }
bool ConvertStringToInteger(const std::string &str, Int *out)
Converts a string into an integer via strtoll and returns false if there was any kind of problem (i...
Definition: text-utils.h:118
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
#define KALDI_ERR
Definition: kaldi-error.h:147
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238

◆ ParseFailover()

void ParseFailover ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 668 of file nnet-descriptor.cc.

References kaldi::nnet3::ExpectToken().

Referenced by GeneralDescriptor::Parse().

670  {
671  descriptors_.push_back(Parse(node_names, next_token));
672  ExpectToken(",", "Failover", next_token);
673  descriptors_.push_back(Parse(node_names, next_token));
674  ExpectToken(")", "Failover", next_token);
675 }
std::vector< GeneralDescriptor * > descriptors_
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)

◆ ParseIfDefined()

void ParseIfDefined ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 661 of file nnet-descriptor.cc.

References kaldi::nnet3::ExpectToken().

Referenced by GeneralDescriptor::Parse().

663  {
664  descriptors_.push_back(Parse(node_names, next_token));
665  ExpectToken(")", "IfDefined", next_token);
666 }
std::vector< GeneralDescriptor * > descriptors_
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)

◆ ParseOffset()

void ParseOffset ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 710 of file nnet-descriptor.cc.

References kaldi::nnet3::ExpectToken(), and kaldi::nnet3::ReadIntegerToken().

Referenced by GeneralDescriptor::Parse().

712  {
713  descriptors_.push_back(Parse(node_names, next_token));
714  ExpectToken(",", "Offset", next_token);
715  value1_ = ReadIntegerToken("Offset", next_token);
716  if (**next_token == ",") {
717  (*next_token)++;
718  value2_ = ReadIntegerToken("Offset", next_token);
719  } else {
720  value2_ = 0;
721  }
722  ExpectToken(")", "Offset", next_token);
723 }
static int32 ReadIntegerToken(const std::string &what_we_are_parsing, const std::string **next_token)
std::vector< GeneralDescriptor * > descriptors_
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)

◆ ParseReplaceIndex()

void ParseReplaceIndex ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 736 of file nnet-descriptor.cc.

References kaldi::nnet3::ExpectToken(), KALDI_ERR, ReplaceIndexForwardingDescriptor::kT, ReplaceIndexForwardingDescriptor::kX, and kaldi::nnet3::ReadIntegerToken().

Referenced by GeneralDescriptor::Parse().

738  {
739  descriptors_.push_back(Parse(node_names, next_token));
740  ExpectToken(",", "ReplaceIndex", next_token);
741  if (**next_token == "t") {
743  (*next_token)++;
744  } else if (**next_token == "x") {
746  (*next_token)++;
747  } else {
748  KALDI_ERR << "Expected 't' or 'x', got " << **next_token;
749  }
750  ExpectToken(",", "ReplaceIndex", next_token);
751  value2_ = ReadIntegerToken("Replace", next_token);
752  ExpectToken(")", "ReplaceIndex", next_token);
753 }
static int32 ReadIntegerToken(const std::string &what_we_are_parsing, const std::string **next_token)
std::vector< GeneralDescriptor * > descriptors_
kaldi::int32 int32
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
#define KALDI_ERR
Definition: kaldi-error.h:147

◆ ParseRound()

void ParseRound ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 726 of file nnet-descriptor.cc.

References kaldi::nnet3::ExpectToken(), and kaldi::nnet3::ReadIntegerToken().

Referenced by GeneralDescriptor::Parse().

728  {
729  descriptors_.push_back(Parse(node_names, next_token));
730  ExpectToken(",", "Round", next_token);
731  value1_ = ReadIntegerToken("Round", next_token);
732  ExpectToken(")", "Round", next_token);
733 }
static int32 ReadIntegerToken(const std::string &what_we_are_parsing, const std::string **next_token)
std::vector< GeneralDescriptor * > descriptors_
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)

◆ ParseScale()

void ParseScale ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

Definition at line 677 of file nnet-descriptor.cc.

References kaldi::ConvertStringToReal(), kaldi::nnet3::ExpectToken(), and KALDI_ERR.

Referenced by GeneralDescriptor::Parse().

679  {
680  if (!ConvertStringToReal(**next_token, &alpha_)) {
681  KALDI_ERR << "Parsing Scale() in descriptor: expected floating-point scale"
682  ", got: " << **next_token;
683  }
684  (*next_token)++; // Consume the float.
685  ExpectToken(",", "Scale", next_token);
686  descriptors_.push_back(Parse(node_names, next_token));
687  ExpectToken(")", "Scale", next_token);
688 }
std::vector< GeneralDescriptor * > descriptors_
static GeneralDescriptor * Parse(const std::vector< std::string > &node_names, const std::string **next_token)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
#define KALDI_ERR
Definition: kaldi-error.h:147
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238

◆ ParseSwitch()

void ParseSwitch ( const std::vector< std::string > &  node_names,
const std::string **  next_token 
)
private

◆ Print()

void Print ( const std::vector< std::string > &  node_names,
std::ostream &  os 
)

Definition at line 975 of file nnet-descriptor.cc.

References rnnlm::i, KALDI_ASSERT, ReplaceIndexForwardingDescriptor::kT, and ReplaceIndexForwardingDescriptor::kX.

976  {
977  switch (descriptor_type_) {
978  // first handle all the expressions of the form "Operator(<desc1>, ... <descN>)".
979  case kAppend: os << "Append("; break;
980  case kSum: os << "Sum("; break;
981  case kFailover: os << "Failover("; break;
982  case kIfDefined: os << "IfDefined("; break;
983  case kSwitch: os << "Switch("; break;
984  // Scale() ends in a descriptor, so we also break and let the generic code
985  // handle that.
986  case kScale: os << "Scale(" << alpha_ << ", "; break;
987  // now handle the exceptions.
988  case kOffset: case kRound: {
989  os << "Offset(";
990  KALDI_ASSERT(descriptors_.size() == 1);
991  descriptors_[0]->Print(node_names, os);
992  os << ", " << value1_;
993  if (descriptor_type_ == kOffset && value2_ != 0) os << ", " << value2_;
994  os << ")";
995  return;
996  }
997  case kReplaceIndex: {
998  os << "ReplaceIndex(";
999  KALDI_ASSERT(descriptors_.size() == 1);
1000  descriptors_[0]->Print(node_names, os);
1003  if (value1_ == int32(ReplaceIndexForwardingDescriptor::kT)) {
1004  os << ", t, ";
1005  } else {
1006  os << ", x, ";
1007  }
1008  os << value2_ << ")";
1009  return;
1010  }
1011  case kNodeName: {
1012  KALDI_ASSERT(static_cast<size_t>(value1_) < node_names.size());
1013  os << node_names[value1_];
1014  return;
1015  }
1016  case kConst: {
1017  os << "Const(" << alpha_ << ", " << value1_ << ")";
1018  return;
1019  }
1020  }
1021  for (size_t i = 0; i < descriptors_.size(); i++) {
1022  if (i > 0) os << ", ";
1023  descriptors_[i]->Print(node_names, os);
1024  }
1025  os << ")";
1026 }
std::vector< GeneralDescriptor * > descriptors_
kaldi::int32 int32
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

Member Data Documentation

◆ alpha_

BaseFloat alpha_
private

Definition at line 661 of file nnet-descriptor.h.

Referenced by GeneralDescriptor::Normalize().

◆ descriptor_type_

DescriptorType descriptor_type_
private

◆ descriptors_

◆ value1_

int32 value1_
private

Definition at line 650 of file nnet-descriptor.h.

Referenced by GeneralDescriptor::Normalize().

◆ value2_

int32 value2_
private

Definition at line 655 of file nnet-descriptor.h.

Referenced by GeneralDescriptor::Normalize().


The documentation for this struct was generated from the following files:

  * nnet-descriptor.h
  * nnet-descriptor.cc