online-audio-client.cc
Go to the documentation of this file.
1 // onlinebin/online-audio-client.cc
2 
3 // Copyright 2012 Cisco Systems (author: Matthias Paulik)
4 // Copyright 2013 Polish-Japanese Institute of Information Technology (author: Danijel Korzinek)
5 
6 // Modifications to the original contribution by Cisco Systems made by:
7 // Vassil Panayotov
8 
9 // See ../../COPYING for clarification regarding multiple authors
10 //
11 // Licensed under the Apache License, Version 2.0 (the "License");
12 // you may not use this file except in compliance with the License.
13 // You may obtain a copy of the License at
14 //
15 // http://www.apache.org/licenses/LICENSE-2.0
16 //
17 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
18 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
19 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
20 // MERCHANTABLITY OR NON-INFRINGEMENT.
21 // See the Apache 2 License for the specific language governing permissions and
22 // limitations under the License.
23 
#include <algorithm>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#if !defined(_MSC_VER)
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#endif

#include "util/parse-options.h"
#include "util/kaldi-table.h"
#include "feat/wave-reader.h"
#include "online/online-audio-source.h"
38 
39 namespace kaldi {
40 
41 bool WriteFull(int32 desc, char* data, int32 size);
42 bool ReadLine(int32 desc, std::string* str);
43 std::string TimeToTimecode(float time);
44 
46  std::string word;
47  float start, end;
48 };
49 
50 } //namespace kaldi
51 
52 int main(int argc, char** argv) {
53  using namespace kaldi;
54  typedef kaldi::int32 int32;
55  #if !defined(_MSC_VER)
56  try {
57 
58  const char *usage =
59  "Sends an audio file to the KALDI audio server (onlinebin/online-audio-server-decode-faster)\n"
60  "and prints the result optionally saving it to an HTK label file or WebVTT subtitles file\n\n"
61  "e.g.: ./online-audio-client 192.168.50.12 9012 'scp:wav_files.scp'\n\n";
62  ParseOptions po(usage);
63 
64  bool htk = false, vtt = false;
65  int32 channel = -1;
66  int32 packet_size = 1024;
67 
68  po.Register("htk", &htk, "Save the result to an HTK label file");
69  po.Register("vtt", &vtt, "Save the result to a WebVTT subtitle file");
70  po.Register(
71  "channel", &channel,
72  "Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right)");
73  po.Register("packet-size", &packet_size, "Send this many bytes per packet");
74 
75  po.Read(argc, argv);
76  if (po.NumArgs() != 3) {
77  po.PrintUsage();
78  return 1;
79  }
80 
81  std::string server_addr_str = po.GetArg(1);
82  std::string server_port_str = po.GetArg(2);
83  int32 server_port = strtol(server_port_str.c_str(), 0, 10);
84  std::string wav_rspecifier = po.GetArg(3);
85 
86  int32 client_desc = socket(AF_INET, SOCK_STREAM, 0);
87  if (client_desc == -1) {
88  std::cerr << "ERROR: couldn't create socket!\n";
89  return -1;
90  }
91 
92  struct hostent* hp;
93  unsigned long addr;
94 
95  addr = inet_addr(server_addr_str.c_str());
96  if (addr == INADDR_NONE) {
97  hp = gethostbyname(server_addr_str.c_str());
98  if (hp == NULL) {
99  std::cerr << "ERROR: couldn't resolve host string: "
100  << server_addr_str << '\n';
101  close(client_desc);
102  return -1;
103  }
104 
105  addr = *((unsigned long*) hp->h_addr);
106  }
107 
108  sockaddr_in server;
109  server.sin_addr.s_addr = addr;
110  server.sin_family = AF_INET;
111  server.sin_port = htons(server_port);
112  if (::connect(client_desc, (struct sockaddr*) &server, sizeof(server))) {
113  std::cerr << "ERROR: couldn't connect to server!\n";
114  close(client_desc);
115  return -1;
116  }
117 
118  KALDI_VLOG(2) << "Connected to KALDI server at host " << server_addr_str
119  << " port " << server_port;
120 
121  char* pack_buffer = new char[packet_size];
122 
123  SequentialTableReader < WaveHolder > reader(wav_rspecifier);
124  for (; !reader.Done(); reader.Next()) {
125  std::string wav_key = reader.Key();
126 
127  KALDI_VLOG(2) << "File: " << wav_key;
128 
129  const WaveData &wav_data = reader.Value();
130 
131  if (wav_data.SampFreq() != 16000)
132  KALDI_ERR << "Sampling rates other than 16kHz are not supported!";
133 
134  int32 num_chan = wav_data.Data().NumRows(), this_chan = channel;
135  { // This block works out the channel (0=left, 1=right...)
136  KALDI_ASSERT(num_chan > 0); // should have been caught in
137  // reading code if no channels.
138  if (channel == -1) {
139  this_chan = 0;
140  if (num_chan != 1)
141  KALDI_WARN << "Channel not specified but you have data with "
142  << num_chan << " channels; defaulting to zero";
143  } else {
144  if (this_chan >= num_chan) {
145  KALDI_WARN << "File with id " << wav_key << " has " << num_chan
146  << " channels but you specified channel " << channel
147  << ", producing no output.";
148  continue;
149  }
150  }
151  }
152 
153  OnlineVectorSource au_src(wav_data.Data().Row(this_chan));
154  Vector < BaseFloat > data(packet_size / 2);
155  while (au_src.Read(&data)) {
156  for (int32 i = 0; i < data.Dim(); i++) {
157  short sample = (short) data(i);
158  memcpy(&pack_buffer[i * 2], (char*) &sample, 2);
159  }
160 
161  int32 size = data.Dim() * 2;
162  WriteFull(client_desc, (char*) &size, 4);
163 
164  WriteFull(client_desc, pack_buffer, size);
165  }
166 
167  //send last packet
168  int32 size = 0;
169  WriteFull(client_desc, (char*) &size, 4);
170 
171  std::string reco_output;
172  std::vector<RecognizedWord> results;
173  float total_input_dur = 0.0f, total_reco_dur = 0.0f;
174 
175  while (true) {
176  std::string line;
177  if (!ReadLine(client_desc, &line))
178  KALDI_ERR << "Server disconnected!";
179 
180  if (line.substr(0, 7) != "RESULT:") {
181  if (line.substr(0, 8) == "PARTIAL:") {
182  std::cout << line.substr(8) << " " << std::flush;
183  continue;
184  }
185  KALDI_ERR << "Header parse error: " << line;
186  }
187 
188  std::cout << std::endl;
189 
190  if (line == "RESULT:DONE")
191  break;
192 
193  int32 res_num = 0;
194  float input_dur = 0;
195  float reco_dur = 0;
196 
197  std::string tok, key, val;
198  size_t beg = 7, end, eq;
199 
200  do {
201  end = line.find_first_of(',', beg);
202  tok = line.substr(beg, end - beg);
203  beg = end + 1;
204  eq = tok.find_first_of('=');
205  if (eq == std::string::npos || eq >= tok.size() - 1) {
206  KALDI_WARN << "Error parsing header token " << tok;
207  continue;
208  }
209 
210  key = tok.substr(0, eq);
211  val = tok.substr(eq + 1);
212 
213  if (key == "NUM") {
214  res_num = strtol(val.c_str(), 0, 10);
215  } else if (key == "FORMAT") {
216  if (val != "WSE") {
217  KALDI_ERR << "Only WSE format supported by this program!";
218  }
219  } else if (key == "RECO-DUR") {
220  reco_dur = strtof(val.c_str(), 0);
221  } else if (key == "INPUT-DUR") {
222  input_dur = strtof(val.c_str(), 0);
223  } else {
224  KALDI_WARN << "Unknown header key: " << key;
225  }
226  } while (end != std::string::npos);
227 
228  total_input_dur += input_dur;
229  total_reco_dur += reco_dur;
230 
231  for (int32 i = 0; i < res_num; i++) {
232  std::string line;
233  if (!ReadLine(client_desc, &line))
234  KALDI_ERR << "Server disconnected!";
235 
236  std::string word_str, start_str, end_str;
237 
238  end = line.find_first_of(',');
239  word_str = line.substr(0, end);
240  beg = end + 1;
241  end = line.find_first_of(',', beg);
242  start_str = line.substr(beg, end - beg);
243  beg = end + 1;
244  end = line.find_first_of(',', beg);
245  end_str = line.substr(beg, end - beg);
246 
248  word.word = word_str;
249  word.start = strtof(start_str.c_str(), 0);
250  word.end = strtof(end_str.c_str(), 0);
251 
252  results.push_back(word);
253 
254  reco_output += word_str + " ";
255  }
256  }
257 
258  {
259  float speed = total_input_dur / total_reco_dur;
260  KALDI_VLOG(2) << "Recognized (" << speed << "xRT): " << reco_output;
261  }
262 
263  if (htk) {
264  std::string name = wav_key + ".lab";
265  std::ofstream htk_file(name.c_str());
266  for (size_t i = 0; i < results.size(); i++)
267  htk_file << (int) (results[i].start * 10000000) << " "
268  << (int) (results[i].end * 10000000) << " "
269  << results[i].word << "\n";
270  htk_file.close();
271  }
272 
273  if (vtt && !results.empty()) {
274  std::vector<RecognizedWord> subtitles;
275  RecognizedWord subtitle_cue;
276 
277  subtitle_cue.start = -1;
278  subtitle_cue.end = -1;
279  subtitle_cue.word = "";
280 
281  for (size_t i = 0; i < results.size(); i++) {
282  if (subtitle_cue.end >= 0) {
283  if (results[i].start - subtitle_cue.end > 3.0f
284  || results[i].word.size() + subtitle_cue.word.size() > 64) {
285 
286  if (results[i].start - subtitle_cue.end < 0.1f)
287  subtitle_cue.end = results[i].start - 0.1f;
288 
289  subtitles.push_back(subtitle_cue);
290  subtitle_cue.start = -1;
291  subtitle_cue.end = -1;
292  subtitle_cue.word = "";
293 
294  }
295  }
296 
297  if (subtitle_cue.start < 0)
298  subtitle_cue.start = results[i].start;
299  else
300  subtitle_cue.word += " ";
301 
302  subtitle_cue.end = results[i].end + 1.0f;
303 
304  subtitle_cue.word += results[i].word;
305  }
306 
307  subtitles.push_back(subtitle_cue);
308 
309  std::string name = wav_key + ".vtt";
310  std::ofstream vtt_file(name.c_str());
311 
312  vtt_file << "WEBVTT FILE\n\n";
313 
314  for (size_t i = 0; i < subtitles.size(); i++)
315  vtt_file << (i + 1) << "\n"
316  << TimeToTimecode(subtitles[i].start) << " --> "
317  << TimeToTimecode(subtitles[i].end) << "\n"
318  << subtitles[i].word << "\n\n";
319 
320  vtt_file.close();
321  }
322  }
323 
324  close(client_desc);
325  delete[] pack_buffer;
326  }
327 
328  catch (const std::exception& e) {
329  std::cerr << e.what();
330  return -1;
331  }
332 
333 #endif
334  return 0;
335 }
336 
337 
338 namespace kaldi {
339 
340 bool WriteFull(int32 desc, char* data, int32 size) {
341  int32 to_write = size;
342  int32 wrote = 0;
343  while (to_write > 0) {
344  int32 ret = write(desc, data + wrote, to_write);
345  if (ret <= 0)
346  return false;
347 
348  to_write -= ret;
349  wrote += ret;
350  }
351 
352  return true;
353 }
354 
357 char read_buffer[1025];
358 
359 bool ReadLine(int32 desc, std::string* str) {
360  *str = "";
361 
362  while (true) {
363  if (buffer_offset >= buffer_fill) {
364  buffer_fill = read(desc, read_buffer, 1024);
365 
366  if (buffer_fill <= 0)
367  break;
368 
369  buffer_offset = 0;
370  }
371 
372  for (int32 i = buffer_offset; i < buffer_fill; i++) {
373  if (read_buffer[i] == '\r' || read_buffer[i] == '\n') {
374  read_buffer[i] = 0;
375  *str += (read_buffer + buffer_offset);
376 
377  buffer_offset = i + 1;
378 
379  if (i < buffer_fill) {
380  if (read_buffer[i] == '\n' && read_buffer[i + 1] == '\r') {
381  read_buffer[i + 1] = 0;
382  buffer_offset = i + 2;
383  }
384  if (read_buffer[i] == '\r' && read_buffer[i + 1] == '\n') {
385  read_buffer[i + 1] = 0;
386  buffer_offset = i + 2;
387  }
388  }
389 
390  return true;
391  }
392  }
393 
394  read_buffer[buffer_fill] = 0;
395  *str += (read_buffer + buffer_offset);
396  buffer_offset = buffer_fill;
397  }
398 
399  return false;
400 }
401 
402 std::string TimeToTimecode(float time) {
403 
404  char buf[64];
405 
406  int32 h, m, s, ms;
407  s = (int32) time;
408  ms = (int32)((time - (float) s) * 1000.0f);
409  m = s / 60;
410  s %= 60;
411  h = m / 60;
412  m %= 60;
413 
414 #if !defined(_MSC_VER)
415  snprintf(buf, 64, "%02d:%02d:%02d.%03d", h, m, s, ms);
416 #endif
417 
418  return buf;
419 }
420 
421 } //namespace kaldi
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
kaldi::int32 int32
bool ReadLine(int32 desc, std::string *str)
BaseFloat SampFreq() const
Definition: wave-reader.h:126
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
void Register(const std::string &name, bool *ptr, const std::string &doc)
int32 buffer_offset
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
const SubVector< Real > Row(MatrixIndexT i) const
Return specific row of matrix [const].
Definition: kaldi-matrix.h:188
char read_buffer[1025]
A templated class for reading objects sequentially from an archive or script file; see The Table conc...
Definition: kaldi-table.h:287
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
#define KALDI_WARN
Definition: kaldi-error.h:150
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
int NumArgs() const
Number of positional parameters (c.f. argc-1).
std::string TimeToTimecode(float time)
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
bool WriteFull(int32 desc, char *data, int32 size)
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
int main(int argc, char **argv)