AmNnetSimple Class Reference

#include <am-nnet-simple.h>

Collaboration diagram for AmNnetSimple:

Public Member Functions

 AmNnetSimple ()
 
 AmNnetSimple (const AmNnetSimple &other)
 
 AmNnetSimple (const Nnet &nnet)
 
int32 NumPdfs () const
 
void Write (std::ostream &os, bool binary) const
 
void Read (std::istream &is, bool binary)
 
const Nnet & GetNnet () const
 
Nnet & GetNnet ()
 Caution: if you structurally change the nnet, you should call SetContext() afterward. More...
 
void SetNnet (const Nnet &nnet)
 
void SetPriors (const VectorBase< BaseFloat > &priors)
 
const VectorBase< BaseFloat > & Priors () const
 
std::string Info () const
 
int32 LeftContext () const
 Minimum left context required to compute an output. More...
 
int32 RightContext () const
 Minimum right context required to compute an output. More...
 
int32 InputDim () const
 Returns the input feature dim. More...
 
int32 IvectorDim () const
 Returns the iVector dimension, or -1 if there is no such input. More...
 
void SetContext ()
 This function works out the left_context_ and right_context_ variables from the network (it's a rather complex calculation). More...
 

Private Member Functions

const AmNnetSimple & operator= (const AmNnetSimple &other)
 

Private Attributes

Nnet nnet_
 
Vector< BaseFloat > priors_
 
int32 left_context_
 
int32 right_context_
 

Detailed Description

Definition at line 49 of file am-nnet-simple.h.

Constructor & Destructor Documentation

◆ AmNnetSimple() [1/3]

AmNnetSimple ( )
inline

Definition at line 51 of file am-nnet-simple.h.

51 { }

◆ AmNnetSimple() [2/3]

AmNnetSimple ( const AmNnetSimple &  other)
inline

Definition at line 53 of file am-nnet-simple.h.

53  :
54  nnet_(other.nnet_),
55  priors_(other.priors_),
56  left_context_(other.left_context_),
57  right_context_(other.right_context_) { }
Vector< BaseFloat > priors_

◆ AmNnetSimple() [3/3]

AmNnetSimple ( const Nnet &  nnet)
inline explicit

Definition at line 59 of file am-nnet-simple.h.

References AmNnetSimple::SetContext().

59  :
60  nnet_(nnet) { SetContext(); }
void SetContext()
This function works out the left_context_ and right_context_ variables from the network (it's a rathe...

Member Function Documentation

◆ GetNnet() [1/2]

◆ GetNnet() [2/2]

Nnet& GetNnet ( )
inline

Caution: if you structurally change the nnet, you should call SetContext() afterward.

Definition at line 72 of file am-nnet-simple.h.

References AmNnetSimple::nnet_.

72 { return nnet_; }

◆ Info()

std::string Info ( ) const

Definition at line 80 of file am-nnet-simple.cc.

References Nnet::Info(), Nnet::InputDim(), AmNnetSimple::nnet_, Nnet::OutputDim(), and AmNnetSimple::priors_.

Referenced by main().

80  {
81  std::ostringstream ostr;
82  ostr << "input-dim: " << nnet_.InputDim("input") << "\n";
83  ostr << "ivector-dim: " << nnet_.InputDim("ivector") << "\n";
84  ostr << "num-pdfs: " << nnet_.OutputDim("output") << "\n";
85  ostr << "prior-dimension: " << priors_.Dim() << "\n";
86  if (priors_.Dim() != 0) {
87  ostr << "prior-sum: " << priors_.Sum() << "\n";
88  ostr << "prior-min: " << priors_.Min() << "\n";
89  ostr << "prior-max: " << priors_.Max() << "\n";
90  }
91  ostr << "# Nnet info follows.\n";
92  return ostr.str() + nnet_.Info();
93 }
Vector< BaseFloat > priors_
int32 InputDim(const std::string &input_name) const
Definition: nnet-nnet.cc:669
int32 OutputDim(const std::string &output_name) const
Definition: nnet-nnet.cc:677
std::string Info() const
returns some human-readable information about the network, mostly for debugging purposes.
Definition: nnet-nnet.cc:821

◆ InputDim()

int32 InputDim ( ) const
inline

Returns the input feature dim.

Definition at line 89 of file am-nnet-simple.h.

References Nnet::InputDim(), and AmNnetSimple::nnet_.

89 { return nnet_.InputDim("input"); }
int32 InputDim(const std::string &input_name) const
Definition: nnet-nnet.cc:669

◆ IvectorDim()

int32 IvectorDim ( ) const
inline

Returns the iVector dimension, or -1 if there is no such input.

Definition at line 92 of file am-nnet-simple.h.

References Nnet::InputDim(), and AmNnetSimple::nnet_.

92 { return nnet_.InputDim("ivector"); }
int32 InputDim(const std::string &input_name) const
Definition: nnet-nnet.cc:669

◆ LeftContext()

int32 LeftContext ( ) const
inline

Minimum left context required to compute an output.

Definition at line 83 of file am-nnet-simple.h.

References AmNnetSimple::left_context_.

◆ NumPdfs()

int32 NumPdfs ( ) const

Definition at line 28 of file am-nnet-simple.cc.

References KALDI_ASSERT, AmNnetSimple::nnet_, and Nnet::OutputDim().

Referenced by kaldi::nnet3::SetPriors().

28  {
29  int32 ans = nnet_.OutputDim("output");
30  KALDI_ASSERT(ans > 0);
31  return ans;
32 }
kaldi::int32 int32
int32 OutputDim(const std::string &output_name) const
Definition: nnet-nnet.cc:677
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185

◆ operator=()

const AmNnetSimple& operator= ( const AmNnetSimple &  other)
private

◆ Priors()

const VectorBase<BaseFloat>& Priors ( ) const
inline

Definition at line 78 of file am-nnet-simple.h.

References AmNnetSimple::priors_.

Referenced by DecodableAmNnetSimpleParallel::DecodableAmNnetSimpleParallel(), and main().

78 { return priors_; }
Vector< BaseFloat > priors_

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)

Definition at line 47 of file am-nnet-simple.cc.

References kaldi::nnet3::ExpectToken(), AmNnetSimple::left_context_, AmNnetSimple::nnet_, AmNnetSimple::priors_, Nnet::Read(), kaldi::ReadBasicType(), AmNnetSimple::right_context_, and AmNnetSimple::SetContext().

Referenced by main(), and Nnet::Read().

47  {
48  nnet_.Read(is, binary);
49  ExpectToken(is, binary, "<LeftContext>");
50  ReadBasicType(is, binary, &left_context_);
51  ExpectToken(is, binary, "<RightContext>");
52  ReadBasicType(is, binary, &right_context_);
53  SetContext(); // temporarily, I'm not trusting the written ones (there was
54  // briefly a bug)
55  ExpectToken(is, binary, "<Priors>");
56  priors_.Read(is, binary);
57 }
Vector< BaseFloat > priors_
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
void Read(std::istream &istream, bool binary)
Definition: nnet-nnet.cc:586
void SetContext()
This function works out the left_context_ and right_context_ variables from the network (it's a rathe...

◆ RightContext()

int32 RightContext ( ) const
inline

Minimum right context required to compute an output.

Definition at line 86 of file am-nnet-simple.h.

References AmNnetSimple::right_context_.

◆ SetContext()

void SetContext ( )

This function works out the left_context_ and right_context_ variables from the network (it's a rather complex calculation).

You should call this if you have structurally changed the nnet without calling SetNnet(), e.g. using non-const GetNnet().

Definition at line 96 of file am-nnet-simple.cc.

References kaldi::nnet3::ComputeSimpleNnetContext(), kaldi::nnet3::IsSimpleNnet(), KALDI_ERR, AmNnetSimple::left_context_, AmNnetSimple::nnet_, and AmNnetSimple::right_context_.

Referenced by AmNnetSimple::AmNnetSimple(), main(), AmNnetSimple::Read(), and AmNnetSimple::SetNnet().

96  {
97  if (!IsSimpleNnet(nnet_)) {
98  KALDI_ERR << "Class AmNnetSimple is only intended for a restricted type of "
99  << "nnet, and this one does not meet the conditions.";
100  }
101  ComputeSimpleNnetContext(nnet_,
102  &left_context_,
103  &right_context_);
104 }
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
Definition: nnet-utils.cc:146
#define KALDI_ERR
Definition: kaldi-error.h:147
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
Definition: nnet-utils.cc:52

◆ SetNnet()

void SetNnet ( const Nnet &  nnet)

Definition at line 59 of file am-nnet-simple.cc.

References KALDI_WARN, AmNnetSimple::nnet_, Nnet::OutputDim(), AmNnetSimple::priors_, and AmNnetSimple::SetContext().

Referenced by main().

59  {
60  nnet_ = nnet;
61  SetContext();
62  if (priors_.Dim() != 0 && priors_.Dim() != nnet_.OutputDim("output")) {
63  KALDI_WARN << "Removing priors since there is a dimension mismatch after "
64  << "changing the nnet: " << priors_.Dim() << " vs. "
65  << nnet_.OutputDim("output");
66  priors_.Resize(0);
67  }
68 }
Vector< BaseFloat > priors_
int32 OutputDim(const std::string &output_name) const
Definition: nnet-nnet.cc:677
#define KALDI_WARN
Definition: kaldi-error.h:150
void SetContext()
This function works out the left_context_ and right_context_ variables from the network (it's a rathe...

◆ SetPriors()

void SetPriors ( const VectorBase< BaseFloat > &  priors)

Definition at line 70 of file am-nnet-simple.cc.

References VectorBase< Real >::Dim(), KALDI_ERR, AmNnetSimple::nnet_, Nnet::OutputDim(), and AmNnetSimple::priors_.

Referenced by main(), and kaldi::nnet3::SetPriors().

70  {
71  priors_ = priors;
72  if (priors_.Dim() != nnet_.OutputDim("output") &&
73  priors_.Dim() != 0) {
74  KALDI_ERR << "Dimension mismatch when setting priors: priors have dim "
75  << priors.Dim() << ", model expects "
76  << nnet_.OutputDim("output");
77  }
78 }
Vector< BaseFloat > priors_
int32 OutputDim(const std::string &output_name) const
Definition: nnet-nnet.cc:677
#define KALDI_ERR
Definition: kaldi-error.h:147

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const

Definition at line 34 of file am-nnet-simple.cc.

References AmNnetSimple::left_context_, AmNnetSimple::nnet_, AmNnetSimple::priors_, AmNnetSimple::right_context_, Nnet::Write(), kaldi::WriteBasicType(), and kaldi::WriteToken().

Referenced by main().

34  {
35  // We don't write any header or footer like <AmNnetSimple> and </AmNnetSimple> -- we just
36  // write the neural net and then the priors. Who knows, there might be some
37  // situation where we want to just read the neural net.
38  nnet_.Write(os, binary);
39  WriteToken(os, binary, "<LeftContext>");
40  WriteBasicType(os, binary, left_context_);
41  WriteToken(os, binary, "<RightContext>");
42  WriteBasicType(os, binary, right_context_);
43  WriteToken(os, binary, "<Priors>");
44  priors_.Write(os, binary);
45 }
Vector< BaseFloat > priors_
void Write(std::ostream &ostream, bool binary) const
Definition: nnet-nnet.cc:630
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34

Member Data Documentation

◆ left_context_

int32 left_context_
private

◆ nnet_

◆ priors_

◆ right_context_

int32 right_context_
private

The documentation for this class was generated from the following files: