Apply grad limit for CTC training #1793

Open · wants to merge 5 commits into master
65 changes: 64 additions & 1 deletion egs/librispeech/ASR/zipformer/model.py
@@ -16,18 +16,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import random
from typing import Optional, Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from lhotse.dataset import SpecAugment
-from scaling import ScaledLinear
+from scaling import ScaledLinear, scale_grad

from icefall.utils import add_sos, make_pad_mask, time_warp


@contextlib.contextmanager
def fork_rng(cpu_state, cuda_state, rng_state, device):
    with torch.random.fork_rng(devices=[device]):
        torch.set_rng_state(cpu_state)
        torch.cuda.set_rng_state(cuda_state, device)

        rng_state2 = random.getstate()
        random.setstate(rng_state)

        try:
            yield
        finally:
            random.setstate(rng_state2)


class AsrModel(nn.Module):
    def __init__(
        self,
@@ -159,6 +176,9 @@ def forward_ctc(
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        encoder_out_prev: Optional[torch.Tensor] = None,
        encoder_out_lens_prev: Optional[torch.Tensor] = None,
        model_prev=None,
    ) -> torch.Tensor:
        """Compute CTC loss.
        Args:
@@ -170,9 +190,28 @@ def forward_ctc(
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
        """
        device = encoder_out.device
        if model_prev:
            cpu_state = torch.get_rng_state()
            cuda_state = torch.cuda.get_rng_state(device)
            rng_state = random.getstate()

        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        if model_prev:
            with fork_rng(
                cpu_state=cpu_state,
                cuda_state=cuda_state,
                rng_state=rng_state,
                device=device,
            ):
                ctc_output_prev = model_prev.ctc_output(encoder_out_prev)

            # Where the current log-prob exceeds 0.8 * the previous model's
            # log-prob, scale its gradient by 0.5; other entries are unscaled.
            has_grown = ctc_output > 0.8 * ctc_output_prev
            grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
            ctc_output = scale_grad(ctc_output, grad_scale_tensor)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
@@ -345,6 +384,7 @@ def forward(
        spec_augment: Optional[SpecAugment] = None,
        supervision_segments: Optional[torch.Tensor] = None,
        time_warp_factor: Optional[int] = 80,
        model_prev=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -418,9 +458,29 @@ def forward(
            x_lens = x_lens.repeat(2)
            y = k2.ragged.cat([y, y], axis=0)

        device = x.device
        if model_prev:
            cpu_state = torch.get_rng_state()
            cuda_state = torch.cuda.get_rng_state(device)
            rng_state = random.getstate()

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        if model_prev:
            with fork_rng(
                cpu_state=cpu_state,
                cuda_state=cuda_state,
                rng_state=rng_state,
                device=device,
            ):
                encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
                    x, x_lens
                )
        else:
            encoder_out_prev = None
            encoder_out_lens_prev = None

        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

@@ -451,6 +511,9 @@ def forward(
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=y_lens,
                    encoder_out_prev=encoder_out_prev,
                    encoder_out_lens_prev=encoder_out_lens_prev,
                    model_prev=model_prev,
                )
                cr_loss = torch.empty(0)
            else:
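The RNG-state bookkeeping added above makes the previous model's forward pass consume the same torch CPU, CUDA and python `random` state as the current model's, so the two CTC outputs differ only because of the weights rather than because of different dropout masks. A minimal sketch of that replay mechanism (illustrative only, not part of the PR):

import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.5)  # modules are in training mode by default
x = torch.ones(2, 4)

# Capture the CPU RNG state before the first forward pass.
cpu_state = torch.get_rng_state()
y1 = drop(x)

# Replay the captured state for a second pass, as fork_rng() does in model.py.
with torch.random.fork_rng(devices=[]):
    torch.set_rng_state(cpu_state)
    y2 = drop(x)

assert torch.equal(y1, y2)  # both passes drew the identical dropout mask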
16 changes: 12 additions & 4 deletions egs/librispeech/ASR/zipformer/scaling.py
@@ -1136,16 +1136,24 @@ def with_loss(x, y, name):

class ScaleGradFunction(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
-        ctx.alpha = alpha
+    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
+        if isinstance(alpha, Tensor):
+            ctx.save_for_backward(alpha)
+        else:
+            ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
-        return grad * ctx.alpha, None
+        if hasattr(ctx, "alpha"):
+            alpha = ctx.alpha
+        else:
+            (alpha,) = ctx.saved_tensors
+
+        return grad * alpha, None


-def scale_grad(x: Tensor, alpha: float):
+def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
    return ScaleGradFunction.apply(x, alpha)


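With this change `scale_grad` also accepts a tensor-valued `alpha`, which is what model.py uses above to scale gradients element-wise while leaving the forward value untouched. A small usage sketch (assuming it is run from the zipformer recipe directory so that `scaling` is importable):

import torch
from scaling import scale_grad

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
alpha = torch.tensor([0.5, 1.0, 1.0])  # down-scale only the first element's gradient

y = scale_grad(x, alpha)  # the forward pass is the identity
y.sum().backward()

print(x.grad)  # tensor([0.5000, 1.0000, 1.0000])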