agglomerative-cluster.cc
Go to the documentation of this file.
1 // ivectorbin/agglomerative-cluster.cc
2 
3 // Copyright 2016-2018 David Snyder
4 // 2017-2018 Matthew Maciejewski
5 // 2019 Dogan Can
6 
7 // See ../../COPYING for clarification regarding multiple authors
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
17 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
18 // MERCHANTABLITY OR NON-INFRINGEMENT.
19 // See the Apache 2 License for the specific language governing permissions and
20 // limitations under the License.
21 
22 
23 #include "base/kaldi-common.h"
24 #include "util/common-utils.h"
25 #include "util/stl-utils.h"
27 
28 int main(int argc, char *argv[]) {
29  using namespace kaldi;
30  typedef kaldi::int32 int32;
31  try {
32  const char *usage =
33  "Cluster utterances by similarity score, used in diarization.\n"
34  "Takes a table of score matrices indexed by recording, with the\n"
35  "rows/columns corresponding to the utterances of that recording in\n"
36  "sorted order and a reco2utt file that contains the mapping from\n"
37  "recordings to utterances, and outputs a list of labels in the form\n"
38  "<utt> <label>. Clustering is done using agglomerative hierarchical\n"
39  "clustering with a score threshold as stop criterion. By default, the\n"
40  "program reads in similarity scores, but with --read-costs=true\n"
41  "the scores are interpreted as costs (i.e. a smaller value indicates\n"
42  "utterance similarity).\n"
43  "Usage: agglomerative-cluster [options] <scores-rspecifier> "
44  "<reco2utt-rspecifier> <labels-wspecifier>\n"
45  "e.g.: \n"
46  " agglomerative-cluster ark:scores.ark ark:reco2utt \n"
47  " ark,t:labels.txt\n";
48 
49  ParseOptions po(usage);
50  std::string reco2num_spk_rspecifier;
51  BaseFloat threshold = 0.0, max_spk_fraction = 1.0;
52  bool read_costs = false;
53  int32 first_pass_max_utterances = std::numeric_limits<int16>::max();
54 
55  po.Register("reco2num-spk-rspecifier", &reco2num_spk_rspecifier,
56  "If supplied, clustering creates exactly this many clusters for each"
57  " recording and the option --threshold is ignored.");
58  po.Register("threshold", &threshold, "Merge clusters if their distance"
59  " is less than this threshold.");
60  po.Register("read-costs", &read_costs, "If true, the first"
61  " argument is interpreted as a matrix of costs rather than a"
62  " similarity matrix.");
63  po.Register("first-pass-max-utterances", &first_pass_max_utterances,
64  "If the number of utterances is larger than first-pass-max-utterances,"
65  " then clustering is done in two passes. In the first pass, input points"
66  " are divided into contiguous subsets of size first-pass-max-utterances"
67  " and each subset is clustered separately. In the second pass, the first"
68  " pass clusters are merged into the final set of clusters.");
69  po.Register("max-spk-fraction", &max_spk_fraction, "Merge clusters if the"
70  " total fraction of utterances in them is less than this threshold."
71  " This is active only when reco2num-spk-rspecifier is supplied and"
72  " 1.0 / num-spk <= max-spk-fraction <= 1.0.");
73 
74  po.Read(argc, argv);
75 
76  if (po.NumArgs() != 3) {
77  po.PrintUsage();
78  exit(1);
79  }
80 
81  std::string scores_rspecifier = po.GetArg(1),
82  reco2utt_rspecifier = po.GetArg(2),
83  label_wspecifier = po.GetArg(3);
84 
85  SequentialBaseFloatMatrixReader scores_reader(scores_rspecifier);
86  RandomAccessTokenVectorReader reco2utt_reader(reco2utt_rspecifier);
87  RandomAccessInt32Reader reco2num_spk_reader(reco2num_spk_rspecifier);
88  Int32Writer label_writer(label_wspecifier);
89 
90  if (!read_costs)
91  threshold = -threshold;
92  for (; !scores_reader.Done(); scores_reader.Next()) {
93  std::string reco = scores_reader.Key();
94  Matrix<BaseFloat> costs = scores_reader.Value();
95  // By default, the scores give the similarity between pairs of
96  // utterances. We need to multiply the scores by -1 to reinterpet
97  // them as costs (unless --read-costs=true) as the agglomerative
98  // clustering code requires.
99  if (!read_costs)
100  costs.Scale(-1);
101  std::vector<std::string> uttlist = reco2utt_reader.Value(reco);
102  std::vector<int32> spk_ids;
103  if (reco2num_spk_rspecifier.size()) {
104  int32 num_speakers = reco2num_spk_reader.Value(reco);
105  if (1.0 / num_speakers <= max_spk_fraction && max_spk_fraction <= 1.0)
106  AgglomerativeCluster(costs, std::numeric_limits<BaseFloat>::max(),
107  num_speakers, first_pass_max_utterances,
108  max_spk_fraction, &spk_ids);
109  else
110  AgglomerativeCluster(costs, std::numeric_limits<BaseFloat>::max(),
111  num_speakers, first_pass_max_utterances,
112  1.0, &spk_ids);
113  } else {
114  AgglomerativeCluster(costs, threshold, 1, first_pass_max_utterances,
115  1.0, &spk_ids);
116  }
117  for (int32 i = 0; i < spk_ids.size(); i++)
118  label_writer.Write(uttlist[i], spk_ids[i]);
119  }
120  return 0;
121 
122  } catch(const std::exception &e) {
123  std::cerr << e.what();
124  return -1;
125  }
126 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
A templated class for writing objects to an archive or script file; see The Table concept...
Definition: kaldi-table.h:368
kaldi::int32 int32
int main(int argc, char *argv[])
void Write(const std::string &key, const T &value) const
void Register(const std::string &name, bool *ptr, const std::string &doc)
Allows random access to a collection of objects in an archive or script file; see The Table concept...
Definition: kaldi-table.h:233
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const T & Value(const std::string &key)
void Scale(Real alpha)
Multiply each element with a scalar value.
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void AgglomerativeCluster(const Matrix< BaseFloat > &costs, BaseFloat threshold, int32 min_clusters, int32 first_pass_max_points, BaseFloat max_cluster_fraction, std::vector< int32 > *assignments_out)
This is the function that is called to perform the agglomerative clustering.