AmNnet Class Reference

#include <am-nnet.h>

Collaboration diagram for AmNnet:

Public Member Functions

 AmNnet ()
 
 AmNnet (const AmNnet &other)
 
 AmNnet (const Nnet &nnet)
 
void Init (std::istream &config_is)
 Initialize the neural network based acoustic model from a config file. More...
 
void Init (const Nnet &nnet)
 Initialize from a neural network that's already been set up. More...
 
int32 NumPdfs () const
 
void Write (std::ostream &os, bool binary) const
 
void Read (std::istream &is, bool binary)
 
const Nnet & GetNnet () const
 
Nnet & GetNnet ()
 
void SetPriors (const VectorBase< BaseFloat > &priors)
 
const VectorBase< BaseFloat > & Priors () const
 
std::string Info () const
 
void ResizeOutputLayer (int32 new_num_pdfs)
 This function is used when doing transfer learning to a new system. More...
 

Private Member Functions

const AmNnet & operator= (const AmNnet &other)
 

Private Attributes

Nnet nnet_
 
Vector< BaseFloat > priors_
 

Detailed Description

Definition at line 38 of file am-nnet.h.

Constructor & Destructor Documentation

◆ AmNnet() [1/3]

AmNnet ( )
inline

Definition at line 40 of file am-nnet.h.

40 { }

◆ AmNnet() [2/3]

AmNnet ( const AmNnet & other)
inline

Definition at line 42 of file am-nnet.h.

42 : nnet_(other.nnet_), priors_(other.priors_) { }
Vector< BaseFloat > priors_
Definition: am-nnet.h:78

◆ AmNnet() [3/3]

AmNnet ( const Nnet & nnet)
inline explicit

Definition at line 44 of file am-nnet.h.

References AmNnet::Init().

44 : nnet_(nnet) { }

Member Function Documentation

◆ GetNnet() [1/2]

◆ GetNnet() [2/2]

Nnet& GetNnet ( )
inline

Definition at line 63 of file am-nnet.h.

References AmNnet::nnet_, and AmNnet::SetPriors().

63 { return nnet_; }

◆ Info()

std::string Info ( ) const

Definition at line 57 of file am-nnet.cc.

References Nnet::Info(), AmNnet::nnet_, and AmNnet::priors_.

Referenced by main(), and AmNnet::Priors().

57  {
58  std::ostringstream ostr;
59  ostr << "prior dimension: " << priors_.Dim();
60  if (priors_.Dim() != 0) {
61  ostr << ", prior sum: " << priors_.Sum() << ", prior min: " << priors_.Min()
62  << "\n";
63  }
64  return nnet_.Info() + ostr.str();
65 }
Vector< BaseFloat > priors_
Definition: am-nnet.h:78
std::string Info() const
Definition: nnet-nnet.cc:257

◆ Init() [1/2]

void Init ( std::istream &  config_is)

Initialize the neural network based acoustic model from a config file.

At this point the priors won't be initialized; you'd have to do SetPriors for that.

Definition at line 26 of file am-nnet.cc.

References Nnet::Init(), and AmNnet::nnet_.

Referenced by AmNnet::AmNnet(), and main().

26  {
27  nnet_.Init(config_is);
28 }
void Init(std::istream &is)
Initialize from config file.
Definition: nnet-nnet.cc:281

◆ Init() [2/2]

void Init ( const Nnet & nnet)

Initialize from a neural network that's already been set up.

Again, the priors will be empty at this point.

Definition at line 67 of file am-nnet.cc.

References KALDI_WARN, AmNnet::nnet_, Nnet::OutputDim(), and AmNnet::priors_.

67  {
68  nnet_ = nnet;
69  if (priors_.Dim() != 0 && priors_.Dim() != nnet.OutputDim()) {
70  KALDI_WARN << "Initializing neural net: prior dimension mismatch, "
71  << "discarding old priors.";
72  priors_.Resize(0);
73  }
74 }
Vector< BaseFloat > priors_
Definition: am-nnet.h:78
#define KALDI_WARN
Definition: kaldi-error.h:150

◆ NumPdfs()

int32 NumPdfs ( ) const
inline

Definition at line 55 of file am-nnet.h.

References AmNnet::nnet_, Nnet::OutputDim(), AmNnet::Read(), and AmNnet::Write().

Referenced by main(), kaldi::nnet2::SetPriors(), and AmNnet::SetPriors().

55 { return nnet_.OutputDim(); }
int32 OutputDim() const
The output dimension of the network – typically the number of pdfs.
Definition: nnet-nnet.cc:31

◆ operator=()

const AmNnet& operator= ( const AmNnet & other)
private

Referenced by AmNnet::Priors().

◆ Priors()

◆ Read()

void Read ( std::istream &  is,
bool  binary 
)

Definition at line 39 of file am-nnet.cc.

References AmNnet::nnet_, AmNnet::priors_, and Nnet::Read().

Referenced by main(), AmNnet::NumPdfs(), and kaldi::nnet2::UnitTestAmNnet().

39  {
40  nnet_.Read(is, binary);
41  priors_.Read(is, binary);
42 }
void Read(std::istream &is, bool binary)
Definition: nnet-nnet.cc:175
Vector< BaseFloat > priors_
Definition: am-nnet.h:78

◆ ResizeOutputLayer()

void ResizeOutputLayer ( int32  new_num_pdfs)

This function is used when doing transfer learning to a new system.

It will set the priors to be all the same.

Definition at line 76 of file am-nnet.cc.

References AmNnet::nnet_, AmNnet::priors_, and Nnet::ResizeOutputLayer().

Referenced by main(), and AmNnet::Priors().

76  {
77  nnet_.ResizeOutputLayer(new_num_pdfs);
78  priors_.Resize(new_num_pdfs);
79  priors_.Set(1.0 / new_num_pdfs);
80 }
Vector< BaseFloat > priors_
Definition: am-nnet.h:78
void ResizeOutputLayer(int32 new_num_pdfs)
This function is used when doing transfer learning to a new system.
Definition: nnet-nnet.cc:356

◆ SetPriors()

void SetPriors ( const VectorBase< BaseFloat > &  priors)

Definition at line 44 of file am-nnet.cc.

References KALDI_ERR, KALDI_WARN, kaldi::kCopyData, AmNnet::NumPdfs(), and AmNnet::priors_.

Referenced by AmNnet::GetNnet(), main(), kaldi::nnet2::SetPriors(), kaldi::nnet2::UnitTestAmNnet(), and kaldi::nnet2::UnitTestNnetDecodable().

44  {
45  priors_ = priors;
46  if (priors_.Dim() > NumPdfs())
47  KALDI_ERR << "Dimension of priors cannot exceed number of pdfs.";
48 
49  if (priors_.Dim() > 0 && priors_.Dim() < NumPdfs()) {
50  KALDI_WARN << "Dimension of priors is " << priors_.Dim() << " < "
51  << NumPdfs() << ": extending with zeros, in case you had "
52  << "unseen pdf's, but this possibly indicates a serious problem.";
53  priors_.Resize(NumPdfs(), kCopyData);
54  }
55 }
Vector< BaseFloat > priors_
Definition: am-nnet.h:78
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
int32 NumPdfs() const
Definition: am-nnet.h:55

◆ Write()

void Write ( std::ostream &  os,
bool  binary 
) const

Definition at line 31 of file am-nnet.cc.

References AmNnet::nnet_, AmNnet::priors_, and Nnet::Write().

Referenced by main(), AmNnet::NumPdfs(), and kaldi::nnet2::UnitTestAmNnet().

31  {
32  // We don't write any header or footer like <AmNnet> and </AmNnet> -- we just
33  // write the neural net and then the priors. Who knows, there might be some
34  // situation where we want to just read the neural net.
35  nnet_.Write(os, binary);
36  priors_.Write(os, binary);
37 }
Vector< BaseFloat > priors_
Definition: am-nnet.h:78
void Write(std::ostream &os, bool binary) const
Definition: nnet-nnet.cc:160

Member Data Documentation

◆ nnet_

◆ priors_


The documentation for this class was generated from the following files:

am-nnet.h
am-nnet.cc