33 #include <netinet/in.h> 34 #include <sys/socket.h> 35 #include <sys/types.h> 38 #include <arpa/inet.h> 56 bool Write(
const std::string &msg);
57 bool WriteLn(
const std::string &msg,
const std::string &eol =
"\n");
72 std::vector<int32> alignment;
73 std::vector<int32>
words;
76 std::ostringstream msg;
77 for (
size_t i = 0;
i < words.size();
i++) {
78 std::string s = word_syms.Find(words[
i]);
80 KALDI_WARN <<
"Word-id " << words[
i] <<
" not in symbol table.";
81 msg <<
"<#" << std::to_string(i) <<
"> ";
90 double t_beg2 = t_beg * time_unit;
91 double t_end2 = t_end * time_unit;
92 snprintf(buffer, 100,
"%.2f %.2f", t_beg2, t_end2);
93 return std::string(buffer);
97 std::vector<int32> times;
103 if (clat.NumStates() == 0) {
116 int main(
int argc,
char *argv[]) {
118 using namespace kaldi;
122 typedef kaldi::int64 int64;
125 "Reads in audio from a network socket and performs online\n" 126 "decoding with neural nets (nnet3 setup), with iVector-based\n" 127 "speaker adaptation and endpointing.\n" 128 "Note: some configuration values and inputs are set via config\n" 129 "files whose filenames are passed as options\n" 131 "Usage: online2-tcp-nnet3-decode-faster [options] <nnet3-in> " 132 "<fst-in> <word-symbol-table>\n";
148 int read_timeout = 3;
149 bool produce_time =
false;
151 po.
Register(
"samp-freq", &samp_freq,
152 "Sampling frequency of the input signal (coded as 16-bit slinear).");
153 po.
Register(
"chunk-length", &chunk_length_secs,
154 "Length of chunk size in seconds, that we process.");
155 po.
Register(
"output-period", &output_period,
156 "How often in seconds, do we check for changes in output.");
158 "Number of threads used when initializing iVector extractor.");
159 po.
Register(
"read-timeout", &read_timeout,
160 "Number of seconds of timout for TCP audio data to appear on the stream. Use -1 for blocking.");
162 "Port number the server will listen on.");
163 po.
Register(
"produce-time", &produce_time,
164 "Prepend begin/end times between endpoints (e.g. '5.46 6.81 <text_output>', in seconds)");
178 std::string nnet3_rxfilename = po.
GetArg(1),
179 fst_rxfilename = po.
GetArg(2),
180 word_syms_filename = po.
GetArg(3);
193 Input ki(nnet3_rxfilename, &binary);
211 fst::SymbolTable *word_syms = NULL;
212 if (!word_syms_filename.empty())
213 if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
214 KALDI_ERR <<
"Could not read symbol table from file " 215 << word_syms_filename;
217 signal(SIGPIPE, SIG_IGN);
221 server.Listen(port_num);
227 int32 samp_count = 0;
228 size_t chunk_len =
static_cast<size_t>(chunk_length_secs * samp_freq);
229 int32 check_period =
static_cast<int32
>(samp_freq * output_period);
230 int32 check_count = check_period;
232 int32 frame_offset = 0;
239 *decode_fst, &feature_pipeline);
248 std::vector<std::pair<int32, BaseFloat>> delta_weights;
251 eos = !server.ReadChunk(chunk_len);
266 int32 t_end = frame_offset;
267 msg =
GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) +
" " + msg;
270 KALDI_VLOG(1) <<
"EndOfAudio, sending message: " << msg;
278 Vector<BaseFloat> wave_part = server.GetChunk();
280 samp_count += chunk_len;
282 if (silence_weighting.
Active() &&
293 if (samp_count > check_count) {
302 int32 t_beg = frame_offset;
304 msg =
GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) +
" " + msg;
307 KALDI_VLOG(1) <<
"Temporary transcript: " << msg;
308 server.WriteLn(msg,
"\r");
310 check_count += check_period;
323 int32 t_end = frame_offset;
324 msg =
GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) +
" " + msg;
327 KALDI_VLOG(1) <<
"Endpoint, sending message: " << msg;
334 }
catch (
const std::exception &e) {
335 std::cerr << e.what();
351 h_addr_.sin_addr.s_addr = INADDR_ANY;
352 h_addr_.sin_port = htons(port);
358 KALDI_ERR <<
"Cannot create TCP socket!";
364 if (setsockopt(
server_desc_, SOL_SOCKET, SO_REUSEADDR, &flag, len) == -1) {
365 KALDI_ERR <<
"Cannot set socket options!";
370 KALDI_ERR <<
"Cannot bind to port: " << port <<
" (is it taken?)";
379 KALDI_LOG <<
"TcpServer: Listening on port: " << port;
397 len =
sizeof(
struct sockaddr);
400 struct sockaddr_storage addr;
404 getpeername(
client_desc_, (
struct sockaddr *) &addr, &len);
406 struct sockaddr_in *s = (
struct sockaddr_in *) &addr;
407 inet_ntop(AF_INET, &s->sin_addr, ipstr,
sizeof ipstr);
412 KALDI_LOG <<
"Accepted connection from: " << ipstr;
426 size_t to_read = len;
428 while (to_read > 0) {
431 KALDI_WARN <<
"Socket timeout! Disconnecting...";
435 KALDI_WARN <<
"Socket error! Disconnecting...";
443 to_read -= ret /
sizeof(int16);
463 const char *p = msg.c_str();
464 size_t to_write = msg.size();
466 while (to_write > 0) {
467 ssize_t ret = write(
client_desc_, static_cast<const void *>(p + wrote), to_write);
bool Write(const std::string &msg)
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
void CollapseModel(const CollapseModelConfig &config, Nnet *nnet)
This function modifies the neural net for efficiency, in a way that is suitable to be done at test time...
This configuration class is to set up OnlineNnet2FeaturePipelineInfo, which in turn is the configurat...
struct ::sockaddr_in h_addr_
int32 frame_subsampling_factor
void UpdateFrameWeights(const std::vector< std::pair< int32, BaseFloat > > &delta_weights)
If you are downweighting silence, you can call OnlineSilenceWeighting::GetDeltaWeights and supply the...
Fst< StdArc > * ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err)
int32 LatticeStateTimes(const Lattice &lat, vector< int32 > *times)
This function iterates over the states of a topologically sorted lattice and counts the time instance...
For an extended explanation of the framework of which grammar-fsts are a part, please see Support for...
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
void SetBatchnormTestMode(bool test_mode, Nnet *nnet)
This function affects only components of type BatchNormComponent.
void InputFinished()
If you call InputFinished(), it tells the class you won't be providing any more waveform.
struct sockaddr_in h_addr_
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
OnlineIvectorFeature * IvectorFeature()
This function returns the iVector-extracting part of the feature pipeline (or NULL if iVectors are no...
virtual int32 NumFramesReady() const
returns the feature dimension.
const Nnet & GetNnet() const
void Register(OptionsItf *opts)
bool GetLinearSymbolSequence(const Fst< Arc > &fst, std::vector< I > *isymbols_out, std::vector< I > *osymbols_out, typename Arc::Weight *tot_weight_out)
GetLinearSymbolSequence gets the symbol sequence from a linear FST.
This file contains a different version of the feature-extraction pipeline in online-feature-pipeline...
This class is responsible for storing configuration variables, objects and options for OnlineNnet2Fea...
void Read(std::istream &is, bool binary)
void Register(const std::string &name, bool *ptr, const std::string &doc)
This file contains some miscellaneous functions dealing with class Nnet.
void InitDecoding(int32 frame_offset=0)
Initializes the decoding and sets the frame offset of the underlying decodable object.
void CompactLatticeShortestPath(const CompactLattice &clat, CompactLattice *shortest_path)
A form of the shortest-path/best-path algorithm that's specially coded for CompactLattice.
void SetDropoutTestMode(bool test_mode, Nnet *nnet)
This function affects components of child-classes of RandomComponent.
BaseFloat FrameShiftInSeconds() const
void AcceptWaveform(BaseFloat sampling_rate, const VectorBase< BaseFloat > &waveform)
Accept more data to process.
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl< FST > &decoder)
int32 GetLatticeTimeSpan(const Lattice &lat)
void Read(std::istream &is, bool binary)
void ConvertLattice(const ExpandedFst< ArcTpl< Weight > > &ifst, MutableFst< ArcTpl< CompactLatticeWeightTpl< Weight, Int > > > *ofst, bool invert)
Convert lattice from a normal FST to a CompactLattice FST.
void GetLattice(bool end_of_utterance, CompactLattice *clat) const
Gets the lattice.
int32 NumFramesDecoded() const
int main(int argc, char *argv[])
bool EndpointDetected(const OnlineEndpointConfig &config)
This function calls EndpointDetected from online-endpoint.h, with the required arguments.
You will instantiate this class when you want to decode a single utterance using the online-decoding ...
fst::VectorFst< LatticeArc > Lattice
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
void Register(OptionsItf *opts)
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
fst::VectorFst< CompactLatticeArc > CompactLattice
void GetBestPath(bool end_of_utterance, Lattice *best_path) const
Outputs an FST corresponding to the single best path through the current lattice. ...
int NumArgs() const
Number of positional parameters (c.f. argc-1).
A class representing a vector.
OnlineNnet2FeaturePipeline is a class that's responsible for putting together the various parts of th...
OnlineSilenceWeightingConfig silence_weighting_config
Config for weighting silence in iVector adaptation.
void AdvanceDecoding()
Advances the decoding as far as we can.
std::string LatticeToString(const Lattice &lat, const fst::SymbolTable &word_syms)
std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit)
Vector< BaseFloat > GetChunk()
void Register(OptionsItf *opts)
bool WriteLn(const std::string &msg, const std::string &eol="\n")
When you instantiate class DecodableNnetSimpleLooped, you should give it a const reference to this cl...
void Register(OptionsItf *opts)
const LatticeFasterOnlineDecoderTpl< FST > & Decoder() const
void FinalizeDecoding()
Finalizes the decoding.
bool ReadChunk(size_t len)
Config class for the CollapseModel function.
void GetDeltaWeights(int32 num_frames_ready, int32 first_decoder_frame, std::vector< std::pair< int32, BaseFloat > > *delta_weights)