online-audio-client.cc File Reference
#include <iostream>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#include "util/parse-options.h"
#include "util/kaldi-table.h"
#include "feat/wave-reader.h"
#include "online/online-audio-source.h"
Include dependency graph for online-audio-client.cc:

Go to the source code of this file.

Classes

struct  RecognizedWord
 

Namespaces

 kaldi
 This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for mispronunciation detection tasks, the reference:
 

Functions

bool WriteFull (int32 desc, char *data, int32 size)
 
bool ReadLine (int32 desc, std::string *str)
 
std::string TimeToTimecode (float time)
 
int main (int argc, char **argv)
 

Variables

int32 buffer_offset = 0
 
int32 buffer_fill = 0
 
char read_buffer [1025]
 

Function Documentation

◆ main()

int main ( int  argc,
char **  argv 
)

Definition at line 52 of file online-audio-client.cc.

References WaveData::Data(), SequentialTableReader< Holder >::Done(), RecognizedWord::end, ParseOptions::GetArg(), rnnlm::i, KALDI_ASSERT, KALDI_ERR, KALDI_VLOG, KALDI_WARN, SequentialTableReader< Holder >::Key(), SequentialTableReader< Holder >::Next(), ParseOptions::NumArgs(), MatrixBase< Real >::NumRows(), ParseOptions::PrintUsage(), ParseOptions::Read(), kaldi::ReadLine(), ParseOptions::Register(), MatrixBase< Real >::Row(), WaveData::SampFreq(), RecognizedWord::start, kaldi::TimeToTimecode(), SequentialTableReader< Holder >::Value(), RecognizedWord::word, and kaldi::WriteFull().

52  {
53  using namespace kaldi;
54  typedef kaldi::int32 int32;
55  #if !defined(_MSC_VER)
56  try {
57 
58  const char *usage =
59  "Sends an audio file to the KALDI audio server (onlinebin/online-audio-server-decode-faster)\n"
60  "and prints the result optionally saving it to an HTK label file or WebVTT subtitles file\n\n"
61  "e.g.: ./online-audio-client 192.168.50.12 9012 'scp:wav_files.scp'\n\n";
62  ParseOptions po(usage);
63 
64  bool htk = false, vtt = false;
65  int32 channel = -1;
66  int32 packet_size = 1024;
67 
68  po.Register("htk", &htk, "Save the result to an HTK label file");
69  po.Register("vtt", &vtt, "Save the result to a WebVTT subtitle file");
70  po.Register(
71  "channel", &channel,
72  "Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right)");
73  po.Register("packet-size", &packet_size, "Send this many bytes per packet");
74 
75  po.Read(argc, argv);
76  if (po.NumArgs() != 3) {
77  po.PrintUsage();
78  return 1;
79  }
80 
81  std::string server_addr_str = po.GetArg(1);
82  std::string server_port_str = po.GetArg(2);
83  int32 server_port = strtol(server_port_str.c_str(), 0, 10);
84  std::string wav_rspecifier = po.GetArg(3);
85 
86  int32 client_desc = socket(AF_INET, SOCK_STREAM, 0);
87  if (client_desc == -1) {
88  std::cerr << "ERROR: couldn't create socket!\n";
89  return -1;
90  }
91 
92  struct hostent* hp;
93  unsigned long addr;
94 
95  addr = inet_addr(server_addr_str.c_str());
96  if (addr == INADDR_NONE) {
97  hp = gethostbyname(server_addr_str.c_str());
98  if (hp == NULL) {
99  std::cerr << "ERROR: couldn't resolve host string: "
100  << server_addr_str << '\n';
101  close(client_desc);
102  return -1;
103  }
104 
105  addr = *((unsigned long*) hp->h_addr);
106  }
107 
108  sockaddr_in server;
109  server.sin_addr.s_addr = addr;
110  server.sin_family = AF_INET;
111  server.sin_port = htons(server_port);
112  if (::connect(client_desc, (struct sockaddr*) &server, sizeof(server))) {
113  std::cerr << "ERROR: couldn't connect to server!\n";
114  close(client_desc);
115  return -1;
116  }
117 
118  KALDI_VLOG(2) << "Connected to KALDI server at host " << server_addr_str
119  << " port " << server_port;
120 
121  char* pack_buffer = new char[packet_size];
122 
123  SequentialTableReader < WaveHolder > reader(wav_rspecifier);
124  for (; !reader.Done(); reader.Next()) {
125  std::string wav_key = reader.Key();
126 
127  KALDI_VLOG(2) << "File: " << wav_key;
128 
129  const WaveData &wav_data = reader.Value();
130 
131  if (wav_data.SampFreq() != 16000)
132  KALDI_ERR << "Sampling rates other than 16kHz are not supported!";
133 
134  int32 num_chan = wav_data.Data().NumRows(), this_chan = channel;
135  { // This block works out the channel (0=left, 1=right...)
136  KALDI_ASSERT(num_chan > 0); // should have been caught in
137  // reading code if no channels.
138  if (channel == -1) {
139  this_chan = 0;
140  if (num_chan != 1)
141  KALDI_WARN << "Channel not specified but you have data with "
142  << num_chan << " channels; defaulting to zero";
143  } else {
144  if (this_chan >= num_chan) {
145  KALDI_WARN << "File with id " << wav_key << " has " << num_chan
146  << " channels but you specified channel " << channel
147  << ", producing no output.";
148  continue;
149  }
150  }
151  }
152 
153  OnlineVectorSource au_src(wav_data.Data().Row(this_chan));
154  Vector < BaseFloat > data(packet_size / 2);
155  while (au_src.Read(&data)) {
156  for (int32 i = 0; i < data.Dim(); i++) {
157  short sample = (short) data(i);
158  memcpy(&pack_buffer[i * 2], (char*) &sample, 2);
159  }
160 
161  int32 size = data.Dim() * 2;
162  WriteFull(client_desc, (char*) &size, 4);
163 
164  WriteFull(client_desc, pack_buffer, size);
165  }
166 
167  //send last packet
168  int32 size = 0;
169  WriteFull(client_desc, (char*) &size, 4);
170 
171  std::string reco_output;
172  std::vector<RecognizedWord> results;
173  float total_input_dur = 0.0f, total_reco_dur = 0.0f;
174 
175  while (true) {
176  std::string line;
177  if (!ReadLine(client_desc, &line))
178  KALDI_ERR << "Server disconnected!";
179 
180  if (line.substr(0, 7) != "RESULT:") {
181  if (line.substr(0, 8) == "PARTIAL:") {
182  std::cout << line.substr(8) << " " << std::flush;
183  continue;
184  }
185  KALDI_ERR << "Header parse error: " << line;
186  }
187 
188  std::cout << std::endl;
189 
190  if (line == "RESULT:DONE")
191  break;
192 
193  int32 res_num = 0;
194  float input_dur = 0;
195  float reco_dur = 0;
196 
197  std::string tok, key, val;
198  size_t beg = 7, end, eq;
199 
200  do {
201  end = line.find_first_of(',', beg);
202  tok = line.substr(beg, end - beg);
203  beg = end + 1;
204  eq = tok.find_first_of('=');
205  if (eq == std::string::npos || eq >= tok.size() - 1) {
206  KALDI_WARN << "Error parsing header token " << tok;
207  continue;
208  }
209 
210  key = tok.substr(0, eq);
211  val = tok.substr(eq + 1);
212 
213  if (key == "NUM") {
214  res_num = strtol(val.c_str(), 0, 10);
215  } else if (key == "FORMAT") {
216  if (val != "WSE") {
217  KALDI_ERR << "Only WSE format supported by this program!";
218  }
219  } else if (key == "RECO-DUR") {
220  reco_dur = strtof(val.c_str(), 0);
221  } else if (key == "INPUT-DUR") {
222  input_dur = strtof(val.c_str(), 0);
223  } else {
224  KALDI_WARN << "Unknown header key: " << key;
225  }
226  } while (end != std::string::npos);
227 
228  total_input_dur += input_dur;
229  total_reco_dur += reco_dur;
230 
231  for (int32 i = 0; i < res_num; i++) {
232  std::string line;
233  if (!ReadLine(client_desc, &line))
234  KALDI_ERR << "Server disconnected!";
235 
236  std::string word_str, start_str, end_str;
237 
238  end = line.find_first_of(',');
239  word_str = line.substr(0, end);
240  beg = end + 1;
241  end = line.find_first_of(',', beg);
242  start_str = line.substr(beg, end - beg);
243  beg = end + 1;
244  end = line.find_first_of(',', beg);
245  end_str = line.substr(beg, end - beg);
246 
247  RecognizedWord word;
248  word.word = word_str;
249  word.start = strtof(start_str.c_str(), 0);
250  word.end = strtof(end_str.c_str(), 0);
251 
252  results.push_back(word);
253 
254  reco_output += word_str + " ";
255  }
256  }
257 
258  {
259  float speed = total_input_dur / total_reco_dur;
260  KALDI_VLOG(2) << "Recognized (" << speed << "xRT): " << reco_output;
261  }
262 
263  if (htk) {
264  std::string name = wav_key + ".lab";
265  std::ofstream htk_file(name.c_str());
266  for (size_t i = 0; i < results.size(); i++)
267  htk_file << (int) (results[i].start * 10000000) << " "
268  << (int) (results[i].end * 10000000) << " "
269  << results[i].word << "\n";
270  htk_file.close();
271  }
272 
273  if (vtt && !results.empty()) {
274  std::vector<RecognizedWord> subtitles;
275  RecognizedWord subtitle_cue;
276 
277  subtitle_cue.start = -1;
278  subtitle_cue.end = -1;
279  subtitle_cue.word = "";
280 
281  for (size_t i = 0; i < results.size(); i++) {
282  if (subtitle_cue.end >= 0) {
283  if (results[i].start - subtitle_cue.end > 3.0f
284  || results[i].word.size() + subtitle_cue.word.size() > 64) {
285 
286  if (results[i].start - subtitle_cue.end < 0.1f)
287  subtitle_cue.end = results[i].start - 0.1f;
288 
289  subtitles.push_back(subtitle_cue);
290  subtitle_cue.start = -1;
291  subtitle_cue.end = -1;
292  subtitle_cue.word = "";
293 
294  }
295  }
296 
297  if (subtitle_cue.start < 0)
298  subtitle_cue.start = results[i].start;
299  else
300  subtitle_cue.word += " ";
301 
302  subtitle_cue.end = results[i].end + 1.0f;
303 
304  subtitle_cue.word += results[i].word;
305  }
306 
307  subtitles.push_back(subtitle_cue);
308 
309  std::string name = wav_key + ".vtt";
310  std::ofstream vtt_file(name.c_str());
311 
312  vtt_file << "WEBVTT FILE\n\n";
313 
314  for (size_t i = 0; i < subtitles.size(); i++)
315  vtt_file << (i + 1) << "\n"
316  << TimeToTimecode(subtitles[i].start) << " --> "
317  << TimeToTimecode(subtitles[i].end) << "\n"
318  << subtitles[i].word << "\n\n";
319 
320  vtt_file.close();
321  }
322  }
323 
324  close(client_desc);
325  delete[] pack_buffer;
326  }
327 
328  catch (const std::exception& e) {
329  std::cerr << e.what();
330  return -1;
331  }
332 
333 #endif
334  return 0;
335 }
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation features for...
Definition: chain.dox:20
kaldi::int32 int32
bool ReadLine(int32 desc, std::string *str)
BaseFloat SampFreq() const
Definition: wave-reader.h:126
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
std::string TimeToTimecode(float time)
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
bool WriteFull(int32 desc, char *data, int32 size)
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156