
本文深入探讨了pytorch中`crossentropyloss`常见的`runtimeerror: expected scalar type long but found float`错误。该错误通常源于目标标签(target)的数据类型不符合`crossentropyloss`的预期。我们将详细解析错误原因,并提供如何在训练循环中正确使用`crossentropyloss`,包括标签类型转换、输入顺序以及避免重复应用softmax等关键最佳实践,以确保模型训练的稳定性和准确性。
在深度学习的分类任务中,torch.nn.CrossEntropyLoss是一个非常常用的损失函数。它结合了LogSoftmax和负对数似然损失(NLLLoss),能够高效地处理多分类问题。然而,初学者在使用时常会遇到一个特定的运行时错误:RuntimeError: expected scalar type Long but found Float。这个错误明确指出,CrossEntropyLoss在处理其目标标签(target)时,期望的数据类型是torch.Long(即64位整数),但实际接收到的是torch.Float。
理解CrossEntropyLoss的工作原理
CrossEntropyLoss函数在PyTorch中通常接收两个主要参数:
input (或 logits):这是模型的原始输出,通常是未经Softmax激活函数处理的“对数几率”(logits)。它的形状通常是 (N, C),其中 N 是批量大小,C 是类别数量。对于图像任务,如果模型输出是像素级别的分类(如U-Net),则形状可能是 (N, C, H, W)。target (或 labels):这是真实的类别标签。它应该包含每个样本的类别索引,其数据类型必须是torch.long(或torch.int64)。它的形状通常是 (N),对于像素级别的分类,形状可能是 (N, H, W)。target中的值应介于 0 到 C-1 之间,代表对应的类别索引。
关键点: CrossEntropyLoss内部会自行执行Softmax操作,因此,向其传入经过Softmax处理的概率值是不正确的,这可能导致数值不稳定或不准确的损失计算。
RuntimeError: expected scalar type Long but found Float 错误解析与修正
这个错误的核心在于target张量的数据类型不匹配。在提供的代码片段中,错误发生在以下这行:
loss = criterion(output, labels.float())
尽管labels张量在创建时已经被明确指定为long类型:
labels = Variable(torch.FloatTensor(10).uniform_(0, 120).long())
但在计算损失时,又通过.float()方法将其强制转换回了float类型。这就是导致CrossEntropyLoss抛出错误的原因。
修正方法:只需移除对labels的.float()调用,确保target张量保持其long类型即可。
# 错误代码# loss = criterion(output, labels.float())# 正确代码loss = criterion(output, labels)
训练循环中的常见误用及修正
除了上述直接的类型转换错误,在提供的train_one_epoch函数中,也存在一些与CrossEntropyLoss使用相关的常见误区。
1. 标签数据类型转换错误
在train_one_epoch函数内部,标签被错误地转换成了float类型:
labels = labels.to(device).float() # 错误:将标签转换为float类型
这会直接导致CrossEntropyLoss接收到float类型的标签,再次触发同样的RuntimeError。
修正方法:确保标签在送入损失函数前是long类型。
labels = labels.to(device).long() # 正确:将标签转换为long类型
2. CrossEntropyLoss输入参数顺序和类型错误
在train_one_epoch函数中,计算损失的行是:
Spacely AI
为您的房间提供AI室内设计解决方案,寻找无限的创意
67 查看详情
loss = criterion(labels, torch.argmax(outputs, dim=1)) # 错误:参数顺序和类型不符
这里存在两个问题:
参数顺序错误: criterion(即CrossEntropyLoss)期望的第一个参数是模型的输出(logits),第二个参数是真实标签(target)。这里却反了过来。target参数类型错误: torch.argmax(outputs, dim=1) 已经是一个预测结果的类别索引,它不应该作为CrossEntropyLoss的target参数传入。target参数应是真实的、未经模型处理的类别标签。
修正方法:将模型的原始输出(logits)作为第一个参数,真实的long类型标签作为第二个参数。
3. 预先应用Softmax的错误
在计算outputs时,代码中显式地应用了F.softmax:
outputs = F.softmax(model(inputs.float()), dim=1) # 错误:CrossEntropyLoss内部已包含Softmax
由于CrossEntropyLoss内部已经包含了Softmax操作,再次应用F.softmax会导致:
冗余计算: 增加了不必要的计算开销。数值稳定性问题: 两次Softmax操作可能导致数值精度下降,尤其是在处理非常大或非常小的对数几率时。
修正方法:直接将模型的原始输出(logits)传递给CrossEntropyLoss。
优化后的训练函数示例
综合以上修正,以下是train_one_epoch函数的一个优化版本,遵循了CrossEntropyLoss的最佳实践:
import torchimport torch.nn as nnimport torch.nn.functional as Fimport time# 假设 model, optimizer, dataloaders, device 已经定义def train_one_epoch(model, optimizer, data_loader, device): model.train() running_loss = 0.0 start_time = time.time() total = 0 correct = 0 # 确保 data_loader 是实际的 DataLoader 对象 # 这里假设 dataloaders['train'] 是一个可迭代的 DataLoader current_data_loader = data_loader # 如果传入的是字符串'train',需要根据实际情况获取 if isinstance(data_loader, str): current_data_loader = dataloaders[data_loader] # 假设 dataloaders 是一个全局字典 for i, (inputs, labels) in enumerate(current_data_loader): inputs = inputs.to(device) # 核心修正:确保标签是long类型 labels = labels.to(device).long() optimizer.zero_grad() # 修正:直接使用模型的原始输出(logits),不应用Softmax # 假设 model(inputs.float()) 返回的是 logits logits = model(inputs.float()) # 打印形状以调试 # print("Inputs shape:", inputs.shape) # print("Logits shape:", logits.shape) # print("Labels shape:", labels.shape) # 修正:CrossEntropyLoss的正确使用方式是 (logits, target_indices) loss = criterion(logits, labels) loss.backward() optimizer.step() # 计算准确率时,需要对logits应用argmax _, predicted = torch.max(logits.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = 100 * correct / total running_loss += loss.item() if i % 10 == 0: # print every 10 batches batch_time = time.time() speed = (i+1)/(batch_time-start_time) print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' % (i, running_loss, speed, accuracy)) running_loss = 0.0 total = 0 correct = 0
验证模型函数 (val_model) 的注意事项
val_model函数在处理标签时使用了labels = labels.to(device).long(),这是正确的。同时,outputs = model(inputs.float()) 假设模型输出的是logits,然后用 torch.max(outputs.data, 1) 来获取预测类别,这也是标准做法。
唯一需要注意的是,model.val() 应该更正为 model.eval(),这会将模型设置为评估模式,禁用Dropout和BatchNorm等层,以确保评估结果的稳定性。
def val_model(model, data_loader, device): # 添加 device 参数 model.eval() # 修正:使用 model.eval() start_time = time.time() total = 0 correct = 0 current_data_loader = data_loader if isinstance(data_loader, str): current_data_loader = dataloaders[data_loader] with torch.no_grad(): for i, (inputs, labels) in enumerate(current_data_loader): inputs = inputs.to(device) labels = labels.to(device).long() # 正确 outputs = model(inputs.float()) # 假设 model 输出 logits _, predicted = torch.max(outputs.data, 1) total += labels.size(0) # 修正:(predicted == labels).sum() 返回一个标量,直接 .item() 即可 correct += (predicted == labels).sum().item() accuracy = 100 * correct / total print('Finished Testing') print('Testing accuracy: %.1f %%' %(accuracy))
总结与最佳实践
处理PyTorch中的CrossEntropyLoss时,请牢记以下关键点:
目标标签的数据类型: CrossEntropyLoss的target参数必须是torch.long类型(即64位整数),且包含类别索引(从0到C-1)。模型输出: CrossEntropyLoss的input参数应是模型的原始输出(logits),即未经Softmax激活函数处理的对数几率。避免重复Softmax: 不要在将模型输出传递给CrossEntropyLoss之前手动应用F.softmax,因为CrossEntropyLoss内部已经包含了此操作。参数顺序: CrossEntropyLoss的调用格式是 loss = criterion(logits, target_labels)。评估模式: 在验证或测试模型时,务必使用model.eval()来设置模型为评估模式,并在torch.no_grad()上下文管理器中执行前向传播,以节省内存和计算。
遵循这些原则,可以有效避免RuntimeError: expected scalar type Long but found Float以及其他与CrossEntropyLoss使用相关的常见问题,确保模型训练的顺利进行。
以上就是PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/847285.html
微信扫一扫
支付宝扫一扫