wav-reverberate.cc
Go to the documentation of this file.
1 // featbin/wav-reverberate.cc
2 
3 // Copyright 2015 Tom Ko
4 
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
19 
20 #include "base/kaldi-common.h"
21 #include "util/common-utils.h"
22 #include "feat/wave-reader.h"
23 #include "feat/signal.h"
24 
25 namespace kaldi {
26 
27 /*
28  This function is to repeatedly concatenate signal1 by itself
29  to match the length of signal2 and add the two signals together.
30 */
32  Vector<BaseFloat> *signal2) {
33  for (int32 po = 0; po < signal2->Dim(); po += signal1.Dim()) {
34  int32 block_length = signal1.Dim();
35  if (signal2->Dim() - po < block_length) block_length = signal2->Dim() - po;
36  signal2->Range(po, block_length).AddVec(1.0, signal1.Range(0, block_length));
37  }
38 }
39 
40 /*
41  This function is to add signal1 to signal2 starting at the offset of signal2
42  This will not extend the length of signal2.
43 */
44 void AddVectorsWithOffset(const Vector<BaseFloat> &signal1, int32 offset,
45  Vector<BaseFloat> *signal2) {
46  int32 add_length = std::min(signal2->Dim() - offset, signal1.Dim());
47  if (add_length > 0)
48  signal2->Range(offset, add_length).AddVec(1.0, signal1.Range(0, add_length));
49 }
50 
51 
53  return std::max(std::abs(vector.Max()), std::abs(vector.Min()));
54 }
55 
56 /*
57  Early reverberation component of the signal is composed of reflections
58  within 0.05 seconds of the direct path signal (assumed to be the peak of
59  the room impulse response). This function returns the energy in
60  this early reverberation component of the signal.
61  The input parameters to this function are the room impulse response, the signal
62  and their sampling frequency respectively.
63 */
65  BaseFloat samp_freq) {
66  int32 peak_index = 0;
67  rir.Max(&peak_index);
68  KALDI_VLOG(1) << "peak index is " << peak_index;
69 
70  const float sec_before_peak = 0.001;
71  const float sec_after_peak = 0.05;
72  int32 early_rir_start_index = peak_index - sec_before_peak * samp_freq;
73  int32 early_rir_end_index = peak_index + sec_after_peak * samp_freq;
74  if (early_rir_start_index < 0) early_rir_start_index = 0;
75  if (early_rir_end_index > rir.Dim()) early_rir_end_index = rir.Dim();
76 
77  int32 duration = early_rir_end_index - early_rir_start_index;
78  Vector<BaseFloat> early_rir(rir.Range(early_rir_start_index, duration));
79  Vector<BaseFloat> early_reverb(signal);
80  FFTbasedBlockConvolveSignals(early_rir, &early_reverb);
81 
82  // compute the energy
83  return VecVec(early_reverb, early_reverb) / early_reverb.Dim();
84 }
85 
86 /*
87  This is the core function to do reverberation on the given signal.
88  The input parameters to this function are the room impulse response,
89  the sampling frequency and the signal respectively.
90  The length of the signal will be extended to (original signal length +
91  rir length - 1) after the reverberation.
92 */
93 float DoReverberation(const Vector<BaseFloat> &rir, BaseFloat samp_freq,
94  Vector<BaseFloat> *signal) {
95  float signal_power = ComputeEarlyReverbEnergy(rir, *signal, samp_freq);
96  FFTbasedBlockConvolveSignals(rir, signal);
97  return signal_power;
98 }
99 
100 /*
101  The noise will be scaled before the addition
102  to match the given signal-to-noise ratio (SNR).
103 */
105  BaseFloat time, BaseFloat samp_freq,
106  BaseFloat signal_power, Vector<BaseFloat> *signal) {
107  float noise_power = VecVec(*noise, *noise) / noise->Dim();
108  float scale_factor = sqrt(pow(10, -snr_db / 10) * signal_power / noise_power);
109  noise->Scale(scale_factor);
110  KALDI_VLOG(1) << "Noise signal is being scaled with " << scale_factor
111  << " to generate output with SNR " << snr_db << "db\n";
112  int32 offset = time * samp_freq;
113  AddVectorsWithOffset(*noise, offset, signal);
114 }
115 
116 /*
117  This function converts comma-spearted string into float vector.
118 */
119 void ReadCommaSeparatedCommand(const std::string &s,
120  std::vector<BaseFloat> *v) {
121  std::vector<std::string> split_string;
122  SplitStringToVector(s, ",", true, &split_string);
123  for (size_t i = 0; i < split_string.size(); i++) {
124  float ret;
125  ConvertStringToReal(split_string[i], &ret);
126  v->push_back(ret);
127  }
128 }
129 }
130 
131 int main(int argc, char *argv[]) {
132  try {
133  using namespace kaldi;
134 
135  const char *usage =
136  "Corrupts the wave files supplied via input pipe with the specified\n"
137  "room-impulse response (rir_matrix) and additive noise distortions\n"
138  "(specified by corresponding files).\n"
139  "Usage: wav-reverberate [options...] <wav-in-rxfilename> "
140  "<wav-out-wxfilename>\n"
141  "e.g.\n"
142  "wav-reverberate --duration=20.25 --impulse-response=rir.wav "
143  "--additive-signals='noise1.wav,noise2.wav' --snrs='20.0,15.0' "
144  "--start-times='0,17.8' input.wav output.wav\n";
145 
146  ParseOptions po(usage);
147  std::string rir_file;
148  std::string additive_signals;
149  std::string snrs;
150  std::string start_times;
151  bool multi_channel_output = false;
152  bool shift_output = true;
153  int32 input_channel = 0;
154  int32 rir_channel = 0;
155  int32 noise_channel = 0;
156  bool normalize_output = true;
157  BaseFloat volume = 0;
158  BaseFloat duration = 0;
159 
160  po.Register("multi-channel-output", &multi_channel_output,
161  "Specifies if the output should be multi-channel or not");
162  po.Register("shift-output", &shift_output,
163  "If true, the reverberated waveform will be shifted by the "
164  "amount of the peak position of the RIR and the length of "
165  "the output waveform will be equal to the input waveform. "
166  "If false, the length of the output waveform will be "
167  "equal to (original input length + rir length - 1). "
168  "This value is true by default and "
169  "it only affects the output when RIR file is provided.");
170  po.Register("input-wave-channel", &input_channel,
171  "Specifies the channel to be used from input as only a "
172  "single channel will be used to generate reverberated output");
173  po.Register("rir-channel", &rir_channel,
174  "Specifies the channel of the room impulse response, "
175  "it will only be used when multi-channel-output is false");
176  po.Register("noise-channel", &noise_channel,
177  "Specifies the channel of the noise file, "
178  "it will only be used when multi-channel-output is false");
179  po.Register("impulse-response", &rir_file,
180  "File with the impulse response for reverberating the input wave"
181  "It can be either a file in wav format or a piped command. "
182  "E.g. --impulse-response='rir.wav' or 'sox rir.wav - |' ");
183  po.Register("additive-signals", &additive_signals,
184  "A comma separated list of additive signals. "
185  "They can be either filenames or piped commands. "
186  "E.g. --additive-signals='noise1.wav,noise2.wav' or "
187  "'sox noise1.wav - |,sox noise2.wav - |'. "
188  "Requires --snrs and --start-times.");
189  po.Register("snrs", &snrs,
190  "A comma separated list of SNRs(dB). "
191  "The additive signals will be scaled according to these SNRs. "
192  "E.g. --snrs='20.0,0.0,5.0,10.0' ");
193  po.Register("start-times", &start_times,
194  "A comma separated list of start times referring to the "
195  "input signal. The additive signals will be added to the "
196  "input signal starting at the offset. If the start time "
197  "exceed the length of the input signal, the addition will "
198  "be ignored.");
199  po.Register("normalize-output", &normalize_output,
200  "If true, then after reverberating and "
201  "possibly adding noise, scale so that the signal "
202  "energy is the same as the original input signal. "
203  "See also the --volume option.");
204  po.Register("duration", &duration,
205  "If nonzero, it specified the duration (secs) of the output "
206  "signal. If the duration t is less than the length of the "
207  "input signal, the first t secs of the signal is trimmed, "
208  "otherwise, the signal will be repeated to "
209  "fulfill the duration specified.");
210  po.Register("volume", &volume,
211  "If nonzero, a scaling factor on the signal that is applied "
212  "after reverberating and possibly adding noise. "
213  "If you set this option to a nonzero value, it will be as "
214  "if you had also specified --normalize-output=false.");
215 
216  po.Read(argc, argv);
217  if (po.NumArgs() != 2) {
218  po.PrintUsage();
219  exit(1);
220  }
221 
222  if (multi_channel_output) {
223  if (rir_channel != 0 || noise_channel != 0)
224  KALDI_WARN << "options for --rir-channel and --noise-channel"
225  "are ignored as --multi-channel-output is true.";
226  }
227 
228  std::string input_wave_file = po.GetArg(1);
229  std::string output_wave_file = po.GetArg(2);
230 
231  WaveData input_wave;
232  {
233  WaveHolder waveholder;
234  Input ki(input_wave_file);
235  waveholder.Read(ki.Stream());
236  input_wave = waveholder.Value();
237  }
238 
239  const Matrix<BaseFloat> &input_matrix = input_wave.Data();
240  BaseFloat samp_freq_input = input_wave.SampFreq();
241  int32 num_samp_input = input_matrix.NumCols(), // #samples in the input
242  num_input_channel = input_matrix.NumRows(); // #channels in the input
243  KALDI_VLOG(1) << "sampling frequency of input: " << samp_freq_input
244  << " #samples: " << num_samp_input
245  << " #channel: " << num_input_channel;
246  KALDI_ASSERT(input_channel < num_input_channel);
247 
248  Matrix<BaseFloat> rir_matrix;
249  BaseFloat samp_freq_rir = samp_freq_input;
250  int32 num_samp_rir = 0,
251  num_rir_channel = 0;
252  if (!rir_file.empty()) {
253  WaveData rir_wave;
254  {
255  WaveHolder waveholder;
256  Input ki(rir_file);
257  waveholder.Read(ki.Stream());
258  rir_wave = waveholder.Value();
259  }
260  rir_matrix = rir_wave.Data();
261  samp_freq_rir = rir_wave.SampFreq();
262  num_samp_rir = rir_matrix.NumCols();
263  num_rir_channel = rir_matrix.NumRows();
264  KALDI_VLOG(1) << "sampling frequency of rir: " << samp_freq_rir
265  << " #samples: " << num_samp_rir
266  << " #channel: " << num_rir_channel;
267  if (!multi_channel_output) {
268  KALDI_ASSERT(rir_channel < num_rir_channel);
269  }
270  }
271 
272  std::vector<Matrix<BaseFloat> > additive_signal_matrices;
273  if (!additive_signals.empty()) {
274  if (snrs.empty() || start_times.empty())
275  KALDI_ERR << "--additive-signals option requires "
276  "--snrs and --start-times to be set.";
277  std::vector<std::string> split_string;
278  SplitStringToVector(additive_signals, ",", true, &split_string);
279  for (size_t i = 0; i < split_string.size(); i++) {
280  WaveHolder waveholder;
281  Input ki(split_string[i]);
282  waveholder.Read(ki.Stream());
283  WaveData additive_signal_wave = waveholder.Value();
284  Matrix<BaseFloat> additive_signal_matrix = additive_signal_wave.Data();
285  BaseFloat samp_freq = additive_signal_wave.SampFreq();
286  KALDI_ASSERT(samp_freq == samp_freq_input);
287  int32 num_samp = additive_signal_matrix.NumCols(),
288  num_channel = additive_signal_matrix.NumRows();
289  KALDI_VLOG(1) << "sampling frequency of additive signal: " << samp_freq
290  << " #samples: " << num_samp
291  << " #channel: " << num_channel;
292  if (multi_channel_output) {
293  KALDI_ASSERT(num_rir_channel == num_channel);
294  } else {
295  KALDI_ASSERT(noise_channel < num_channel);
296  }
297 
298  additive_signal_matrices.push_back(additive_signal_matrix);
299  }
300  }
301 
302  std::vector<BaseFloat> snr_vector;
303  if (!snrs.empty()) {
304  ReadCommaSeparatedCommand(snrs, &snr_vector);
305  }
306 
307  std::vector<BaseFloat> start_time_vector;
308  if (!start_times.empty()) {
309  ReadCommaSeparatedCommand(start_times, &start_time_vector);
310  }
311 
312  int32 shift_index = 0;
313  int32 num_output_channels = (multi_channel_output ? num_rir_channel : 1);
314  int32 num_samp_output = (duration > 0 ? samp_freq_input * duration :
315  (shift_output ? num_samp_input :
316  num_samp_input + num_samp_rir - 1));
317  Matrix<BaseFloat> out_matrix(num_output_channels, num_samp_output);
318 
319  for (int32 output_channel = 0; output_channel < num_output_channels; output_channel++) {
320  Vector<BaseFloat> input(num_samp_input);
321  input.CopyRowFromMat(input_matrix, input_channel);
322  float power_before_reverb = VecVec(input, input) / input.Dim();
323 
324  int32 this_rir_channel = (multi_channel_output ? output_channel : rir_channel);
325 
326  float early_energy = power_before_reverb;
327  if (!rir_file.empty()) {
328  Vector<BaseFloat> rir;
329  rir.Resize(num_samp_rir);
330  rir.CopyRowFromMat(rir_matrix, this_rir_channel);
331  rir.Scale(1.0 / (1 << 15));
332  early_energy = DoReverberation(rir, samp_freq_rir, &input);
333  if (shift_output) {
334  // find the position of the peak of the impulse response
335  // and shift the output waveform by this amount
336  rir.Max(&shift_index);
337  }
338  }
339 
340  if (additive_signal_matrices.size() > 0) {
341  Vector<BaseFloat> noise(0);
342  int32 this_noise_channel = (multi_channel_output ? output_channel : noise_channel);
343  KALDI_ASSERT(additive_signal_matrices.size() == snr_vector.size());
344  KALDI_ASSERT(additive_signal_matrices.size() == start_time_vector.size());
345  for (int32 i = 0; i < additive_signal_matrices.size(); i++) {
346  noise.Resize(additive_signal_matrices[i].NumCols());
347  noise.CopyRowFromMat(additive_signal_matrices[i], this_noise_channel);
348  AddNoise(&noise, snr_vector[i], start_time_vector[i],
349  samp_freq_input, early_energy, &input);
350  }
351  }
352 
353  float power_after_reverb = VecVec(input, input) / input.Dim();
354 
355  if (volume > 0)
356  input.Scale(volume);
357  else if (normalize_output)
358  input.Scale(sqrt(power_before_reverb / power_after_reverb));
359 
360  if (num_samp_output <= num_samp_input) {
361  // trim the signal from the start
362  out_matrix.CopyRowFromVec(input.Range(shift_index, num_samp_output), output_channel);
363  } else {
364  // repeat the signal to fill up the duration
365  Vector<BaseFloat> extended_input(num_samp_output);
366  extended_input.SetZero();
367  AddVectorsOfUnequalLength(input.Range(shift_index, num_samp_input), &extended_input);
368  out_matrix.CopyRowFromVec(extended_input, output_channel);
369  }
370  }
371 
372  WaveData out_wave(samp_freq_input, out_matrix);
373  Output ko(output_wave_file, false);
374  out_wave.Write(ko.Stream());
375 
376  return 0;
377  } catch(const std::exception &e) {
378  std::cerr << e.what();
379  return -1;
380  }
381 }
382 
This code computes Goodness of Pronunciation (GOP) and extracts phone-level pronunciation feature for...
Definition: chain.dox:20
bool Read(std::istream &is)
Definition: wave-reader.h:191
MatrixIndexT NumCols() const
Returns number of columns (or zero for empty matrix).
Definition: kaldi-matrix.h:67
void ReadCommaSeparatedCommand(const std::string &s, std::vector< BaseFloat > *v)
float DoReverberation(const Vector< BaseFloat > &rir, BaseFloat samp_freq, Vector< BaseFloat > *signal)
void PrintUsage(bool print_command_line=false)
Prints the usage documentation [provided in the constructor].
BaseFloat ComputeEarlyReverbEnergy(const Vector< BaseFloat > &rir, const Vector< BaseFloat > &signal, BaseFloat samp_freq)
kaldi::int32 int32
int main(int argc, char *argv[])
BaseFloat SampFreq() const
Definition: wave-reader.h:126
void Resize(MatrixIndexT length, MatrixResizeType resize_type=kSetZero)
Set vector to a specified size (can be zero).
Real Min() const
Returns the minimum value of any element, or +infinity for the empty vector.
const Matrix< BaseFloat > & Data() const
Definition: wave-reader.h:124
void CopyRowFromMat(const MatrixBase< Real > &M, MatrixIndexT row)
Extracts a row of the matrix M.
void Register(const std::string &name, bool *ptr, const std::string &doc)
void FFTbasedBlockConvolveSignals(const Vector< BaseFloat > &filter, Vector< BaseFloat > *signal)
Definition: signal.cc:77
std::istream & Stream()
Definition: kaldi-io.cc:826
float BaseFloat
Definition: kaldi-types.h:29
The class ParseOptions is for parsing command-line options; see Parsing command-line options for more...
Definition: parse-options.h:36
void AddVectorsOfUnequalLength(const VectorBase< BaseFloat > &signal1, Vector< BaseFloat > *signal2)
std::ostream & Stream()
Definition: kaldi-io.cc:701
void SplitStringToVector(const std::string &full, const char *delim, bool omit_empty_strings, std::vector< std::string > *out)
Split a string using any of the single character delimiters.
Definition: text-utils.cc:63
int Read(int argc, const char *const *argv)
Parses the command line options and fills the ParseOptions-registered variables.
#define KALDI_ERR
Definition: kaldi-error.h:147
Real Max() const
Returns the maximum value of any element, or -infinity for the empty vector.
bool ConvertStringToReal(const std::string &str, T *out)
ConvertStringToReal converts a string into either float or double and returns false if there was any ...
Definition: text-utils.cc:238
std::string GetArg(int param) const
Returns one of the positional parameters; 1-based indexing for argc/argv compatibility.
#define KALDI_WARN
Definition: kaldi-error.h:150
MatrixIndexT Dim() const
Returns the dimension of the vector.
Definition: kaldi-vector.h:64
void Scale(Real alpha)
Multiplies all elements by this constant.
This class's purpose is to read in Wave files.
Definition: wave-reader.h:106
int NumArgs() const
Number of positional parameters (c.f. argc-1).
void AddVectorsWithOffset(const Vector< BaseFloat > &signal1, int32 offset, Vector< BaseFloat > *signal2)
A class representing a vector.
Definition: kaldi-vector.h:406
#define KALDI_ASSERT(cond)
Definition: kaldi-error.h:185
void AddNoise(Vector< BaseFloat > *noise, BaseFloat snr_db, BaseFloat time, BaseFloat samp_freq, BaseFloat signal_power, Vector< BaseFloat > *signal)
MatrixIndexT NumRows() const
Returns number of rows (or zero for empty matrix).
Definition: kaldi-matrix.h:64
#define KALDI_VLOG(v)
Definition: kaldi-error.h:156
void CopyRowFromVec(const VectorBase< Real > &v, const MatrixIndexT row)
Copy vector into specific row of matrix.
BaseFloat MaxAbsolute(const Vector< BaseFloat > &vector)
void Write(std::ostream &os) const
Write() will throw on error. os should be opened in binary mode.
Definition: wave-reader.cc:332
Provides a vector abstraction class.
Definition: kaldi-vector.h:41
void SetZero()
Set vector to all zeros.
Real VecVec(const VectorBase< Real > &a, const VectorBase< Real > &b)
Returns dot product between v1 and v2.
Definition: kaldi-vector.cc:37
SubVector< Real > Range(const MatrixIndexT o, const MatrixIndexT l)
Returns a sub-vector of a vector (a range of elements).
Definition: kaldi-vector.h:94