-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathST.py
More file actions
138 lines (115 loc) · 4.4 KB
/
ST.py
File metadata and controls
138 lines (115 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Specifies the inference interfaces for Speech Translation (ST) modules.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023
* Adel Moumen 2023
* Pradnya Kandarkar 2023
"""
import torch
from speechbrain.inference.interfaces import Pretrained
class EncoderDecoderS2UT(Pretrained):
    """A ready-to-use encoder-decoder speech-to-unit translation (S2UT) model.

    The class can be used to run the entire encoder-decoder S2UT model
    (translate_file()) to translate speech, or only its encoder
    (encode_batch()). The given YAML must contain the fields specified in
    the *_NEEDED[] lists.

    Arguments
    ---------
    *args : tuple
    **kwargs : dict
        Arguments are forwarded to ``Pretrained`` parent class.

    Example
    -------
    >>> from speechbrain.inference.ST import EncoderDecoderS2UT
    >>> tmpdir = getfixture("tmpdir")
    >>> s2ut_model = EncoderDecoderS2UT.from_hparams(
    ...     source="speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss",
    ...     savedir=tmpdir,
    ... ) # doctest: +SKIP
    >>> s2ut_model.translate_file(
    ...     "speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss/example-fr.wav"
    ... ) # doctest: +SKIP
    """

    HPARAMS_NEEDED = ["sample_rate"]
    MODULES_NEEDED = ["encoder", "decoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Expose the sample rate from the hyperparameters on the instance.
        self.sample_rate = self.hparams.sample_rate

    def translate_file(self, path):
        """Translates the given audio file into a sequence of speech units.

        Arguments
        ---------
        path : str
            Path to the audio file to translate.

        Returns
        -------
        int[]
            The translation of the audio file produced by this
            speech-to-unit translation model.
        """
        waveform = self.load_audio(path).to(self.device)
        # The single utterance is the longest item of its one-element batch,
        # hence a relative length of 1.0.
        full_length = torch.tensor([1.0])
        hypotheses = self.translate_batch(waveform.unsqueeze(0), full_length)
        # Drop the fake batch dimension again.
        return hypotheses[0]

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states.

        Apart from a cast to float and a transfer to the model's device, no
        preprocessing is applied here, so the waveforms should already be in
        the model's desired format.

        NOTE(review): earlier documentation suggested calling
        ``EncoderDecoderS2UT.normalizer(signal, sample_rate)`` to convert a
        signal; no such attribute is defined in this class -- confirm the
        helper's name on the ``Pretrained`` parent.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.tensor
            The encoded batch, as returned by the ``encoder`` module.
        """
        target_device = self.device
        signal = wavs.float().to(target_device)
        rel_lens = wav_lens.to(target_device)
        return self.mods.encoder(signal, rel_lens)

    @torch.no_grad()
    def translate_batch(self, wavs, wav_lens):
        """Translates the input audio into sequences of speech units.

        The batch is encoded with encode_batch() and then decoded by the
        ``decoder`` module. Gradient tracking is disabled for the whole call.
        The waveforms should already be in the model's desired format.

        NOTE(review): earlier documentation suggested calling
        ``EncoderDecoderS2UT.normalizer(signal, sample_rate)`` to convert a
        signal; no such attribute is defined in this class -- confirm the
        helper's name on the ``Pretrained`` parent.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            The predicted token ids for each waveform in the batch. This is
            the first of the four values returned by the decoder; the other
            three are discarded.
        """
        # The decoder receives the lengths too, so they must be on the
        # device here and not only inside encode_batch().
        rel_lens = wav_lens.to(self.device)
        hidden_states = self.encode_batch(wavs, rel_lens)
        unit_ids, _, _, _ = self.mods.decoder(hidden_states, rel_lens)
        return unit_ids

    def forward(self, wavs, wav_lens):
        """Runs the encoder on a batch and returns its output.

        This simply delegates to encode_batch(); no decoding is performed.
        Use translate_batch() to obtain the predicted units.

        NOTE(review): this method was previously documented as "Runs full
        translation" although it only encodes -- confirm whether it was
        meant to call translate_batch() instead.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms, see encode_batch().
        wav_lens : torch.tensor
            Relative lengths of the waveforms, see encode_batch().

        Returns
        -------
        torch.tensor
            The encoded batch.
        """
        encoded = self.encode_batch(wavs, wav_lens)
        return encoded