am-nnet-simple.cc
Go to the documentation of this file.
1 // nnet3/am-nnet-simple.cc
2 
3 // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "nnet3/am-nnet-simple.h"
21 #include "nnet3/nnet-utils.h"
22 
23 namespace kaldi {
24 namespace nnet3 {
25 
26 
27 
29  int32 ans = nnet_.OutputDim("output");
30  KALDI_ASSERT(ans > 0);
31  return ans;
32 }
33 
34 void AmNnetSimple::Write(std::ostream &os, bool binary) const {
35  // We don't write any header or footer like <AmNnetSimple> and </AmNnetSimple> -- we just
36  // write the neural net and then the priors. Who knows, there might be some
37  // situation where we want to just read the neural net.
38  nnet_.Write(os, binary);
39  WriteToken(os, binary, "<LeftContext>");
40  WriteBasicType(os, binary, left_context_);
41  WriteToken(os, binary, "<RightContext>");
42  WriteBasicType(os, binary, right_context_);
43  WriteToken(os, binary, "<Priors>");
44  priors_.Write(os, binary);
45 }
46 
47 void AmNnetSimple::Read(std::istream &is, bool binary) {
48  nnet_.Read(is, binary);
49  ExpectToken(is, binary, "<LeftContext>");
50  ReadBasicType(is, binary, &left_context_);
51  ExpectToken(is, binary, "<RightContext>");
52  ReadBasicType(is, binary, &right_context_);
53  SetContext(); // temporarily, I'm not trusting the written ones (there was
54  // briefly a bug)
55  ExpectToken(is, binary, "<Priors>");
56  priors_.Read(is, binary);
57 }
58 
59 void AmNnetSimple::SetNnet(const Nnet &nnet) {
60  nnet_ = nnet;
61  SetContext();
62  if (priors_.Dim() != 0 && priors_.Dim() != nnet_.OutputDim("output")) {
63  KALDI_WARN << "Removing priors since there is a dimension mismatch after "
64  << "changing the nnet: " << priors_.Dim() << " vs. "
65  << nnet_.OutputDim("output");
66  priors_.Resize(0);
67  }
68 }
69 
71  priors_ = priors;
72  if (priors_.Dim() != nnet_.OutputDim("output") &&
73  priors_.Dim() != 0) {
74  KALDI_ERR << "Dimension mismatch when setting priors: priors have dim "
75  << priors.Dim() << ", model expects "
76  << nnet_.OutputDim("output");
77  }
78 }
79 
80 std::string AmNnetSimple::Info() const {
81  std::ostringstream ostr;
82  ostr << "input-dim: " << nnet_.InputDim("input") << "\n";
83  ostr << "ivector-dim: " << nnet_.InputDim("ivector") << "\n";
84  ostr << "num-pdfs: " << nnet_.OutputDim("output") << "\n";
85  ostr << "prior-dimension: " << priors_.Dim() << "\n";
86  if (priors_.Dim() != 0) {
87  ostr << "prior-sum: " << priors_.Sum() << "\n";
88  ostr << "prior-min: " << priors_.Min() << "\n";
89  ostr << "prior-max: " << priors_.Max() << "\n";
90  }
91  ostr << "# Nnet info follows.\n";
92  return ostr.str() + nnet_.Info();
93 }
94 
95 
97  if (!IsSimpleNnet(nnet_)) {
98  KALDI_ERR << "Class AmNnetSimple is only intended for a restricted type of "
99  << "nnet, and this one does not meet the conditions.";
100  }
102  &left_context_,
103  &right_context_);
104 }
105 
106 
107 } // namespace nnet3
108 } // namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
std::string Info() const
Vector< BaseFloat > priors_
int32 InputDim(const std::string &input_name) const
Definition: nnet-nnet.cc:669
void Write(std::ostream &ostream, bool binary) const
Definition: nnet-nnet.cc:630
void ReadBasicType(std::istream &is, bool binary, T *t)
ReadBasicType is the name of the read function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:55
kaldi::int32 int32
void Read(std::istream &is, bool binary)
int32 OutputDim(const std::string &output_name) const
Definition: nnet-nnet.cc:677
This file contains some miscellaneous functions dealing with class Nnet.
std::string Info() const
returns some human-readable information about the network, mostly for debugging purposes.
Definition: nnet-nnet.cc:821
void SetPriors(const VectorBase< BaseFloat > &priors)
static void ExpectToken(const std::string &token, const std::string &what_we_are_parsing, const std::string **next_token)
void ComputeSimpleNnetContext(const Nnet &nnet, int32 *left_context, int32 *right_context)
ComputeSimpleNnetContext computes the left-context and right-context of a nnet.
Definition: nnet-utils.cc:146
void Read(std::istream &istream, bool binary)
Definition: nnet-nnet.cc:586
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
void WriteToken(std::ostream &os, bool binary, const char *token)
The WriteToken functions are for writing nonempty sequences of non-space characters.
Definition: io-funcs.cc:134
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void SetNnet(const Nnet &nnet)
bool IsSimpleNnet(const Nnet &nnet)
This function returns true if the nnet has the following properties: It has an output called "output"...
Definition: nnet-utils.cc:52
void SetContext()
This function works out the left_context_ and right_context_ variables from the network (it's a rathe...
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void WriteBasicType(std::ostream &os, bool binary, T t)
WriteBasicType is the name of the write function for bool, integer types, and floating-point types...
Definition: io-funcs-inl.h:34
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void Write(std::ostream &os, bool binary) const