dice_loss shows NaN during training #6
Hello, the repo includes an example that applies dice loss to the MRPC dataset for the binary paraphrase identification task. Please take a look at it.
You can try gradient clipping to prevent NaN.
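Gradient clipping limits how large an update can get when gradients spike. In tf.keras it is set on the optimizer. For example, `keras.optimizers.RMSprop(clipnorm=1.0)` clips each weight's gradient to norm 1.0, and recent TF 2.x releases also accept `global_clipnorm` to clip all gradients jointly. The threshold 1.0 is an arbitrary choice. The NumPy sketch below shows the global-norm variant. The helper `clip_by_global_norm` is hypothetical and only mirrors the idea behind TensorFlow's `tf.clip_by_global_norm`. It also shows one caveat: clipping rescues large but finite gradients. An infinite gradient still ends up as NaN.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    if global_norm <= max_norm:
        return grads, global_norm  # already small enough: leave untouched
    scale = max_norm / global_norm
    return [g * scale for g in grads], global_norm

# A large but finite gradient is scaled down to norm 1.0.
clipped, norm = clip_by_global_norm([np.array([3.0, 4.0])], max_norm=1.0)
# norm is 5.0; clipped[0] is approximately [0.6, 0.8]

# An infinite gradient cannot be rescued: inf norm -> scale 0.0 -> 0 * inf = nan.
with np.errstate(invalid="ignore"):
    bad, bad_norm = clip_by_global_norm([np.array([np.inf, 1.0])], max_norm=1.0)
```

So if the NaN comes from an infinite gradient rather than a merely large one, clipping alone will not remove it.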
Shouldn't `1 - 2 * intersection / denominator` be `1.0 - 2 * intersection / denominator`?
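On the `1` versus `1.0` question: when a Python numeric literal is combined with a floating-point array or tensor, it is converted to the array's dtype. Both spellings therefore give the same float result, and neither one introduces NaN. Here is a quick NumPy check with made-up values. TensorFlow and PyTorch tensors treat Python scalars the same way.

```python
import numpy as np

# Made-up values, only to compare the two spellings.
intersection = np.array([0.3, 0.5], dtype=np.float32)
denominator = np.array([0.9, 1.2], dtype=np.float32)

loss_int = 1 - 2 * intersection / denominator      # integer literal 1
loss_float = 1.0 - 2 * intersection / denominator  # float literal 1.0

# Both stay float32 and hold identical values.
print(loss_int.dtype, loss_float.dtype)       # float32 float32
print(np.array_equal(loss_int, loss_float))   # True
```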
------------------ Original message ------------------
From: ***@***.***
Sent: Monday, January 24, 2022, 6:29 PM
To: ***@***.***
Cc: ***@***.***; ***@***.***
Subject: Re: [ShannonAI/dice_loss_for_NLP] dice_loss shows NaN during training (#6)
Hello, for a binary classification task I followed this code from adaptive_dice_loss.py:
```python
intersection = torch.sum((1 - flat_input) ** self.alpha * flat_input * flat_target, -1) + self.smooth
denominator = torch.sum((1 - flat_input) ** self.alpha * flat_input) + flat_target.sum() + self.smooth
return 1 - 2 * intersection / denominator
```
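Here is the excerpt written out as a formula. It assumes `flat_input` and `flat_target` are 1-D, so the `-1` sum and the full sum coincide. The symbol $p_i$ stands for `flat_input`, $y_i$ for `flat_target`, and $\gamma$ for `self.smooth`. The second line is the derivative of the weighting factor, which is relevant to the NaN question: for $0 < \alpha < 1$ its magnitude grows without bound as $p_i \to 1$.

$$
\mathcal{L} = 1 - \frac{2\left(\sum_i (1-p_i)^{\alpha}\, p_i\, y_i + \gamma\right)}{\sum_i (1-p_i)^{\alpha}\, p_i + \sum_i y_i + \gamma}
$$

$$
\frac{\partial}{\partial p_i}\,(1-p_i)^{\alpha} = -\alpha\,(1-p_i)^{\alpha-1}
$$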
and wrote a corresponding TensorFlow version of the loss function:
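```python
from tensorflow.keras import backend as K

def dice_loss(alpha=0.1, smooth=1e-8):
    # NOTE: Keras calls a loss as fn(y_true, y_pred); the parameter order below is reversed.
    def dice_loss_fixed(y_pred, y_true):
        intersection = K.sum((1 - y_pred) ** alpha * y_pred * y_true, -1) + smooth
        denominator = K.sum((1 - y_pred) ** alpha * y_pred, -1) + K.sum(y_true) + smooth
        return 1 - 2 * intersection / denominator
    return dice_loss_fixed
```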
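One plausible source of the NaN is the `(1 - y_pred) ** alpha` factor with `alpha = 0.1`, though the thread does not confirm the cause. There are two ways it can break. If the model's output is not squashed into [0, 1], then `1 - p` is negative for `p > 1`, and a negative base raised to 0.1 is already NaN in the forward pass. If a sigmoid output saturates to exactly 1.0 in float32, the derivative $-\alpha(1-p)^{\alpha-1}$ is infinite, and backpropagation turns it into inf or NaN (for example 0 × inf when the label is 0). The NumPy sketch below evaluates the factor and its analytic derivative at three made-up predictions, then repeats the calculation after clipping predictions into `[EPS, 1 - EPS]`. `EPS = 1e-7` is an assumption.

```python
import numpy as np

ALPHA = 0.1   # alpha used in the question
EPS = 1e-7    # clipping margin; an assumption (it matches Keras's default epsilon)

def pow_factor(p, alpha=ALPHA):
    """Forward value of (1 - p) ** alpha."""
    with np.errstate(invalid="ignore"):
        return (1.0 - p) ** alpha

def pow_factor_grad(p, alpha=ALPHA):
    """Analytic derivative d/dp (1 - p) ** alpha = -alpha * (1 - p) ** (alpha - 1)."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return -alpha * (1.0 - p) ** (alpha - 1.0)

# 0.5: an ordinary probability; 1.0: a saturated sigmoid rounded to 1 in float32;
# 1.5: an unsquashed output (e.g. no sigmoid on the last layer).
p = np.array([0.5, 1.0, 1.5], dtype=np.float32)

forward = pow_factor(p)    # [finite, 0.0, nan]: negative base with a fractional power
grad = pow_factor_grad(p)  # [finite, -inf, nan]

# Clipping predictions into [EPS, 1 - EPS] keeps both the value and the gradient finite.
p_safe = np.clip(p, EPS, 1.0 - EPS)
forward_safe = pow_factor(p_safe)
grad_safe = pow_factor_grad(p_safe)
```

In a Keras loss, the corresponding step would be `K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())` applied before the power. The clipped gradient can still be large near the bounds, which is where gradient clipping can help further.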
But during training the loss keeps showing NaN, and I can't figure out why. Could you please help me find the problem? Thanks!
```python
from tensorflow import keras

model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss=[dice_loss(alpha=0.1, smooth=1e-8)],
    metrics=['accuracy'],
)
history = model.fit(x_train, y_train, batch_size=64, epochs=5,
                    validation_data=(x_test, y_test))
```
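One way to check the port outside of training is a NumPy reference of the same formula. This sketch differs from `dice_loss_fixed` in three ways, and all three are assumptions rather than fixes confirmed in the thread:

- It takes arguments in the `(y_true, y_pred)` order Keras uses when calling a loss. `dice_loss_fixed` declares `(y_pred, y_true)`, so Keras would pass the labels in as `y_pred`.
- It casts labels to float and clips predictions into `[eps, 1 - eps]`.
- It sums every term over the whole flattened batch, as the `flat_` names in the `adaptive_dice_loss.py` excerpt suggest. The Keras version instead mixes a per-sample `K.sum(..., -1)` with a batch-wide `K.sum(y_true)`.

The name `adaptive_dice_loss_reference` is hypothetical. In Keras, the cast and clip would be `K.cast(y_true, y_pred.dtype)` and `K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())`. This also assumes the last layer outputs probabilities, for example through a sigmoid.

```python
import numpy as np

def adaptive_dice_loss_reference(y_true, y_pred, alpha=0.1, smooth=1e-8, eps=1e-7):
    """NumPy reference of the adaptive_dice_loss.py formula, for sanity checks.

    Arguments follow Keras order (y_true, y_pred). Clipping predictions into
    [eps, 1 - eps] is an assumption, not part of the original code.
    """
    p = np.clip(np.asarray(y_pred, dtype=np.float64).ravel(), eps, 1.0 - eps)
    y = np.asarray(y_true, dtype=np.float64).ravel()  # labels may arrive as ints
    weight = (1.0 - p) ** alpha                        # self-adjusting factor
    intersection = np.sum(weight * p * y) + smooth
    denominator = np.sum(weight * p) + np.sum(y) + smooth
    return 1.0 - 2.0 * intersection / denominator

y_true = np.array([[1], [0]])  # shape (batch, 1), made-up labels
good = adaptive_dice_loss_reference(y_true, [[0.9], [0.1]])       # confident and correct
bad = adaptive_dice_loss_reference(y_true, [[0.1], [0.9]])        # inverted
saturated = adaptive_dice_loss_reference(y_true, [[1.0], [0.0]])  # hits the clip bounds
```

With these made-up inputs, all three values are finite, and the inverted prediction scores higher than the correct one. The saturated prediction does not score lowest. This is because the `(1 - p) ** alpha` factor shrinks the contribution of predictions close to 1.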