am-nnet.cc
Go to the documentation of this file.
1 // nnet2/am-nnet.cc
2 
3 // Copyright 2012 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "nnet2/am-nnet.h"
21 
22 namespace kaldi {
23 namespace nnet2 {
24 
25 
26 void AmNnet::Init(std::istream &config_is) {
27  nnet_.Init(config_is);
28 }
29 
30 
31 void AmNnet::Write(std::ostream &os, bool binary) const {
32  // We don't write any header or footer like <AmNnet> and </AmNnet> -- we just
33  // write the neural net and then the priors. Who knows, there might be some
34  // situation where we want to just read the neural net.
35  nnet_.Write(os, binary);
36  priors_.Write(os, binary);
37 }
38 
39 void AmNnet::Read(std::istream &is, bool binary) {
40  nnet_.Read(is, binary);
41  priors_.Read(is, binary);
42 }
43 
45  priors_ = priors;
46  if (priors_.Dim() > NumPdfs())
47  KALDI_ERR << "Dimension of priors cannot exceed number of pdfs.";
48 
49  if (priors_.Dim() > 0 && priors_.Dim() < NumPdfs()) {
50  KALDI_WARN << "Dimension of priors is " << priors_.Dim() << " < "
51  << NumPdfs() << ": extending with zeros, in case you had "
52  << "unseen pdf's, but this possibly indicates a serious problem.";
53  priors_.Resize(NumPdfs(), kCopyData);
54  }
55 }
56 
57 std::string AmNnet::Info() const {
58  std::ostringstream ostr;
59  ostr << "prior dimension: " << priors_.Dim();
60  if (priors_.Dim() != 0) {
61  ostr << ", prior sum: " << priors_.Sum() << ", prior min: " << priors_.Min()
62  << "\n";
63  }
64  return nnet_.Info() + ostr.str();
65 }
66 
67 void AmNnet::Init(const Nnet &nnet) {
68  nnet_ = nnet;
69  if (priors_.Dim() != 0 && priors_.Dim() != nnet.OutputDim()) {
70  KALDI_WARN << "Initializing neural net: prior dimension mismatch, "
71  << "discarding old priors.";
72  priors_.Resize(0);
73  }
74 }
75 
76 void AmNnet::ResizeOutputLayer(int32 new_num_pdfs) {
77  nnet_.ResizeOutputLayer(new_num_pdfs);
78  priors_.Resize(new_num_pdfs);
79  priors_.Set(1.0 / new_num_pdfs);
80 }
81 
82 } // namespace nnet2
83 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void Read(std::istream &is, bool binary)
Definition: nnet-nnet.cc:175
Vector< BaseFloat > priors_
Definition: am-nnet.h:78
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
int32 OutputDim() const
The output dimension of the network – typically the number of pdfs.
Definition: nnet-nnet.cc:31
kaldi::int32 int32
std::string Info() const
Definition: am-nnet.cc:57
void Write(std::ostream &os, bool binary) const
Definition: nnet-nnet.cc:160
void ResizeOutputLayer(int32 new_num_pdfs)
This function is used when doing transfer learning to a new system.
Definition: nnet-nnet.cc:356
void Init(std::istream &config_is)
Initialize the neural network based acoustic model from a config file.
Definition: am-nnet.cc:26
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
#define KALDI_ERR
Definition: kaldi-error.h:147
std::string Info() const
Definition: nnet-nnet.cc:257
#define KALDI_WARN
Definition: kaldi-error.h:150
void Init(std::istream &is)
Initialize from config file.
Definition: nnet-nnet.cc:281
void ResizeOutputLayer(int32 new_num_pdfs)
This function is used when doing transfer learning to a new system.
Definition: am-nnet.cc:76
int32 NumPdfs() const
Definition: am-nnet.h:55
void SetPriors(const VectorBase< BaseFloat > &priors)
Definition: am-nnet.cc:44
Provides a vector abstraction class.
Definition: kaldi-vector.h:41