am-nnet-test.cc
Go to the documentation of this file.
1 // nnet2/am-nnet-test.cc
2 
3 // Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABILITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "hmm/transition-model.h"
21 #include "hmm/hmm-test-utils.h"
22 #include "nnet2/am-nnet.h"
23 
24 
25 namespace kaldi {
26 namespace nnet2 {
27 
28 
30  std::vector<int32> phones;
31  phones.push_back(1);
32  for (int32 i = 2; i < 20; i++)
33  if (rand() % 2 == 0)
34  phones.push_back(i);
35  int32 N = 2 + rand() % 2, // context-size N is 2 or 3.
36  P = rand() % N; // Central-phone is random on [0, N)
37 
38  std::vector<int32> num_pdf_classes;
39 
40  ContextDependency *ctx_dep =
41  GenRandContextDependencyLarge(phones, N, P,
42  true, &num_pdf_classes);
43 
44  HmmTopology topo = GetDefaultTopology(phones);
45 
46  TransitionModel trans_model(*ctx_dep, topo);
47 
48  delete ctx_dep; // We won't need this further.
49  ctx_dep = NULL;
50 
51  int32 input_dim = 40, output_dim = trans_model.NumPdfs();
52  Nnet *nnet = GenRandomNnet(input_dim, output_dim);
53 
54  AmNnet am_nnet(*nnet);
55  delete nnet;
56  nnet = NULL;
57  Vector<BaseFloat> priors(output_dim);
58  priors.SetRandn();
59  priors.ApplyExp();
60  priors.Scale(1.0 / priors.Sum());
61 
62  am_nnet.SetPriors(priors);
63 
64  bool binary = (rand() % 2 == 0);
65  std::ostringstream os;
66  am_nnet.Write(os, binary);
67  AmNnet am_nnet2;
68  std::istringstream is(os.str());
69  am_nnet2.Read(is, binary);
70 
71  std::ostringstream os2;
72  am_nnet2.Write(os2, binary);
73 
74  KALDI_ASSERT(os2.str() == os.str());
75 }
76 
77 } // namespace nnet2
78 } // namespace kaldi
79 
80 
81 int main() {
82  using namespace kaldi;
83  using namespace kaldi::nnet2;
84 
86  return 0;
87 }
88 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void ApplyExp()
Apply exponential to each value in vector.
HmmTopology GetDefaultTopology(const std::vector< int32 > &phones_in)
This function returns a HmmTopology object giving a normal 3-state topology, covering all phones in t...
ContextDependency * GenRandContextDependencyLarge(const std::vector< int32 > &phone_ids, int N, int P, bool ensure_all_covered, std::vector< int32 > *hmm_lengths)
GenRandContextDependencyLarge is like GenRandContextDependency but generates a larger tree with speci...
Definition: context-dep.cc:97
A class for storing topology information for phones.
Definition: hmm-topology.h:93
Nnet * GenRandomNnet(int32 input_dim, int32 output_dim)
This function generates a random neural net, for testing purposes.
Definition: nnet-nnet.cc:772
void Read(std::istream &is, bool binary)
Definition: am-nnet.cc:39
kaldi::int32 int32
void Write(std::ostream &os, bool binary) const
Definition: am-nnet.cc:31
void Scale(Real alpha)
Multiplies all elements by this constant.
Real Sum() const
Returns sum of the elements.
void SetRandn()
Set vector to random normally-distributed noise.
void UnitTestAmNnet()
Definition: am-nnet-test.cc:29
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
int main()
Definition: am-nnet-test.cc:81
void SetPriors(const VectorBase< BaseFloat > &priors)
Definition: am-nnet.cc:44