在TensorFlow 2中实现完全卷积网络(FCN)

在TensorFlow 2中实现完全卷积网络(FCN)

作者 | himanshu rawlani

来源 | Medium

编辑 | 代码医生团队

卷积神经网络(CNN)非常适合计算机视觉任务。使用对大型图像集(如ImageNet,COCO等)进行训练的预训练模型,可以快速使这些体系结构专业化,以适合独特数据集。此过程称为迁移学习。但是有一个陷阱!用于图像分类和对象检测任务的预训练模型通常在固定的输入图像尺寸上训练。这些通常从224x224x3到某个范围变化,512x512x3并且大多数具有1的长宽比,即图像的宽度和高度相等。如果它们不相等,则将图像调整为相等的高度和宽度。

较新的体系结构确实能够处理可变的输入图像大小,但是与图像分类任务相比,它在对象检测和分割任务中更为常见。最近遇到了一个有趣的用例,其中有5个不同类别的图像,每个类别都有微小的差异。此外图像的纵横比也比平常高。图像的平均高度约为30像素,宽度约为300像素。这是一个有趣的原因,其原因如下:

调整图像大小容易使重要功能失真预训练的架构非常庞大,并且总是过度拟合数据集任务要求低延迟

需要具有可变输入尺寸的CNN

尝试了MobileNet和EfficientNet的基本模型,但没有任何效果。需要一种对输入图像大小没有任何限制并且可以执行手边的图像分类任务的网络。震惊的第一件事是完全卷积网络(FCN)。FCN是一个不包含任何“密集”层的网络(如在传统的CNN中一样),而是包含1×1卷积,用于执行完全连接的层(密集层)的任务。尽管没有密集层可以输入可变的输入,但是有两种技术可以在保留可变输入尺寸的同时使用密集层。本教程描述了其中一些技术。在本教程中,将执行以下步骤:

使用Keras在TensorFlow中构建完全卷积网络(FCN)下载并拆分样本数据集在Keras中创建生成器以加载和处理内存中的一批数据训练具有可变批次尺寸的网络使用TensorFlow Serving部署模型

获取代码

本文中的代码片段仅突出实际脚本的一部分,有关完整代码,请参阅GitHub存储库。

https://github.com/himanshurawlani/fully_convolutional_network.git

1.设计引擎(model.py)

通过堆叠由2D卷积层(Conv2D)和所需的正则化(Dropout和BatchNormalization)组成的卷积块来构建FCN模型。正则化可防止过度拟合并有助于快速收敛。还添加了一个激活层来合并非线性。在Keras中,输入批次尺寸是自动添加的,不需要在输入层中指定它。由于输入图像的高度和宽度是可变的,因此将输入形状指定为(None, None, 3)。3表示图像中的通道数,该数量对于彩色图像(RGB)是固定的。

代码语言:javascript代码运行次数:0运行复制

import tensorflow as tf def FCN_model(len_classes=5, dropout_rate=0.2):        # Input layer    input = tf.keras.layers.Input(shape=(None, None, 3))     # A convolution block    x = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=1)(input)    x = tf.keras.layers.Dropout(dropout_rate)(x)    x = tf.keras.layers.BatchNormalization()(x)    x = tf.keras.layers.Activation('relu')(x)        # Stack of convolution blocks    .    .    .

最小图像尺寸要求

在输入施加卷积块之后,输入的高度和宽度将降低基于所述值kernel_size和strides。如果输入图像的尺寸太小,那么可能无法达到下一个卷积块所需的最小高度和宽度(应大于或等于内核尺寸)。确定最小输入尺寸的尝试和错误方法如下:

确定要堆叠的卷积块数选择任何输入形状以说出(32, 32, 3)并堆叠数量越来越多的通道的卷积块尝试构建模型并打印model.summary()以查看每个图层的输出形状。确保(1, 1, num_of_filters)从最后一个卷积块获得输出尺寸(这将被输入到完全连接的层)。尝试减小/增大输入形状,内核大小或步幅,以满足步骤4中的条件。满足条件的输入形状以及其他配置是网络所需的最小输入尺寸。

还有,以计算输出体积的空间大小,其所示的输入体积的函数的数学方式这里。找到最小输入尺寸后,现在需要将最后一个卷积块的输出传递到完全连接的层。但是任何尺寸大于最小输入尺寸的输入都需要汇总以满足步骤4中的条件。了解如何使用我们的主要成分来做到这一点。

http://cs231n.github.io/convolutional-networks/#conv

主要成分

全连接层(FC层)将执行分类任务。可以通过两种方式构建FC层:

致密层1x1卷积

如果要使用密集层,则必须固定模型输入尺寸,因为必须预先定义作为密集层输入的参数数量才能创建密集层。具体来说,希望(height, width, num_of_filters)最后一个卷积块的输出中的高度和宽度为常数或1。滤波器的数量始终是固定的,因为这些值是在每个卷积块中定义的。

1x1卷积的输入尺寸可以是(1, 1, num_of_filters)或(height, width, num_of_filters)模仿它们沿num_of_filters尺寸方向FC层的功能。但是,在1x1卷积之后,最后一层(Softmax激活层)的输入必须具有固定的长度(类数)。

主要成分:GlobalMaxPooling2D() / GlobalAveragePooling2D()。Keras中的这些层将尺寸的输入转换(height, width, num_of_filters)为(1, 1, num_of_filters)实质上沿尺寸的每个值的最大值或平均值,用于沿尺寸的每个过滤器num_of_filters。

代码语言:javascript代码运行次数:0运行复制

# Uncomment the below line if you're using dense layers# x = tf.keras.layers.GlobalMaxPooling2D()(x) # Fully connected layer 1# x = tf.keras.layers.Dropout(dropout_rate)(x)# x = tf.keras.layers.BatchNormalization()(x)# x = tf.keras.layers.Dense(units=64)(x)# x = tf.keras.layers.Activation('relu')(x) # Fully connected layer 1x = tf.keras.layers.Conv2D(filters=64, kernel_size=1, strides=1)(x)x = tf.keras.layers.Dropout(dropout_rate)(x)x = tf.keras.layers.BatchNormalization()(x)x = tf.keras.layers.Activation('relu')(x) # Fully connected layer 2# x = tf.keras.layers.Dropout(dropout_rate)(x)# x = tf.keras.layers.BatchNormalization()(x)# x = tf.keras.layers.Dense(units=len_classes)(x)# predictions = tf.keras.layers.Activation('softmax')(x) # Fully connected layer 2x = tf.keras.layers.Conv2D(filters=len_classes, kernel_size=1, strides=1)(x)x = tf.keras.layers.Dropout(dropout_rate)(x)x = tf.keras.layers.BatchNormalization()(x)x = tf.keras.layers.GlobalMaxPooling2D()(x)predictions = tf.keras.layers.Activation('softmax')(x) model = tf.keras.Model(inputs=input, outputs=predictions)print(model.summary())

密集层与1x1卷积

该代码包括密集层(注释掉)和1x1卷积。在使用两种配置构建和训练模型之后,这里是一些观察结果:

两种模型都包含相同数量的可训练参数。类似的训练和推理时间。密集层比1x1卷积的泛化效果更好。

第三点不能一概而论,因为它取决于诸如数据集中的图像数量,使用的数据扩充,模型初始化等因素。但是这些是实验中的观察结果。可以通过执行命令来独立运行脚本,以测试是否已成功构建模型$python model.py。

2.下载fuel(data.py)

采风问卷 采风问卷

采风问卷是一款全新体验的调查问卷、表单、投票、评测的调研平台,新奇的交互形式,漂亮的作品,让客户眼前一亮,让创作者获得更多的回复。

采风问卷 20 查看详情 采风问卷

本教程中使用的flowers数据集主要旨在了解在训练具有可变输入维度的模型时面临的挑战。测试FCN模型的一些有趣的数据集可能来自医学成像领域,其中包含对图像分类至关重要的微观特征,而其他数据集包含的几何图案/形状在调整图像大小后可能会失真。

1.提供的脚本(data.py)需要独立运行($python data.py)。它将执行以下任务:

2.下载包含5类(“雏菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”)的花卉数据集。有关数据集的更多细节在这里。

https://www.tensorflow.org/datasets/catalog/tf_flowers

3.将数据集分为训练和验证集。可以设置要复制到训练和验证集中的图像数量。

提供有关数据集的统计信息,例如图像的最小,平均和最大高度和宽度。

此脚本使用来下载.tar文件并将其内容提取到当前目录中keras.utils.get_file()。如果想使用TensorFlow数据集(TFDS),可以查看本教程,该教程说明了TFDS以及数据扩充的用法。

3.特殊化carburetor(generator.py)

想在不同的输入维度上训练模型。给定批次和批次之间的每个图像都有不同的尺寸。所以有什么问题?退后一步,回顾一下如何训练传统的图像分类器。在传统的图像分类器中,将图像调整为给定尺寸,通过转换为numpy数组或张量将其打包成批,然后将这批数据通过模型进行正向传播。在整个批次中评估指标(损失,准确性等)。根据这些指标计算要反向传播的梯度。

无法调整图像大小(因为我们将失去微观特征)。现在由于无法调整图像的大小,因此无法将其转换为成批的numpy数组。这是因为如果有一个10张图像的列表,(height, width, 3)它们的height和值不同,width并且尝试将其传递给np.array(),则结果数组的形状将为(10,)and not (10, height, width, 3)!但是模型期望输入尺寸为后一种形状。一种解决方法是编写一个自定义训练循环,该循环执行以下操作:

通过将通过每个图像,在列表中(分批),通过模型(height, width, 3)来(1, height, width, 3)使用np.expand_dims(img, axis=0)。累积python列表(批处理)中每个图像的度量。使用累积的指标计算损耗和梯度。将渐变更新应用到模型。重置指标的值并创建新的图像列表(批次)。

尝试了上述步骤,但建议不要采用上述策略。它很费力,导致代码复杂且不可持续,并且运行速度非常慢!每个人都喜欢优雅的 model.fit()和model.fit_generator()。后者是将在这里使用的!但是首先是化油器。

化油器是一种以合适的空燃比混合用于内燃机的空气和燃料的装置。这就是所需要的,空气!找到批处理中图像的最大高度和宽度,并用零填充每个其他图像,以使批处理中的每个图像都具有相等的尺寸。现在可以轻松地将其转换为numpy数组或张量,并将其传递给fit_generator()。该模型会自动学习忽略零(基本上是黑色像素),并从填充图像的预期部分学习特征。这样就有了一个具有相等图像尺寸的批处理,但是每个批处理具有不同的形状(由于批处理中图像的最大高度和宽度不同)。可以generator.py使用独立运行文件$python generator.py并交叉检查输出。

代码语言:javascript代码运行次数:0运行复制

def construct_image_batch(image_group, BATCH_SIZE):    # get the max image shape    max_shape = tuple(max(image.shape[x] for image in image_group) for x in range(3))     # construct an image batch object    image_batch = np.zeros((BATCH_SIZE,) + max_shape, dtype='float32')     # copy all images to the upper left part of the image batch object    for image_index, image in enumerate(image_group):        image_batch[image_index, :image.shape[0], :image.shape[1], :image.shape[2]] = image     return image_batch

4.点燃认知(train.py)

训练脚本导入并实例化以下类:

生成器:需要指定到创建的路径train和val目录data.py。FCN_model:需要指定最终输出层中所需的类数。

将上述对象传递给train()使用Adam优化器和分类交叉熵损失函数编译模型的函数。创建一个检查点回调,以在训练期间保存最佳模型。最佳模型是根据每个时期结束时的验证集计算出的损失值确定的。fit_generator()函数在很大程度上简化了代码。

代码语言:javascript代码运行次数:0运行复制

def train(model, train_generator, val_generator, epochs = 50):    model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.0001),                    loss='categorical_crossentropy',                    metrics=['accuracy'])     checkpoint_path = './snapshots'    os.makedirs(checkpoint_path, exist_ok=True)    model_path = os.path.join(checkpoint_path, 'model_epoch_{epoch:02d}_loss_{loss:.2f}_acc_{acc:.2f}_val_loss_{val_loss:.2f}_val_acc_{val_acc:.2f}.h5')        history = model.fit_generator(generator=train_generator,                                    steps_per_epoch=len(train_generator),                                    epochs=epochs,                                    callbacks=[tf.keras.callbacks.ModelCheckpoint(model_path, monitor='val_loss', save_best_only=True, verbose=1)],                                    validation_data=val_generator,                                    validation_steps=len(val_generator))     return history

建议在Google Colab上进行训练,除非本地计算机上有GPU。GitHub存储库包含一个Colab笔记本,该笔记本将训练所需的所有内容组合在一起。可以在Colab本身中修改python脚本,并在选择的数据集上训练不同的模型配置。完成训练后,可以从Colab中的“文件”选项卡将最佳快照下载到本地计算机。

5.使用TensorFlow Serving(inference.py)部署模型

下载模型后,需要使用将其导出为SavedModel格式export_savedmodel.py。.h5在主要功能中指定下载模型(文件)的路径,然后使用命令执行脚本$python export_savedmodel.py。该脚本使用TensorFlow 2.0中的新功能,该功能从.h5文件中加载Keras模型并将其保存为TensorFlow SavedModel格式。SavedModel将导出到export_path脚本中指定的位置。TensorFlow服务docker映像需要此SavedModel。

代码语言:javascript代码运行次数:0运行复制

def export(input_h5_file, export_path):    # The export path contains the name and the version of the model    tf.keras.backend.set_learning_phase(0)  # Ignore dropout at inference    model = tf.keras.models.load_model(input_h5_file)    model.save(export_path, save_format='tf')    print(f"SavedModel created at {export_path}")

要启动TensorFlow Serving服务器,请转到导出SavedModel的目录(./flower_classifier在这种情况下)并运行以下命令(注意:计算机上必须安装了Docker):

代码语言:javascript代码运行次数:0运行复制

$ docker run --rm -t -p 8501:8501 -v "$(pwd):/models/flower_classifier" -e MODEL_NAME=flower_classifier --name flower_classifier tensorflow/serving

可以使用$ docker ps命令验证容器在后台运行。还可以使用查看容器日志$ docker logs your_container_id。该inference.py脚本包含用于构建具有统一图像尺寸的批次的代码,并将这些批次作为POST请求发送到TensorFlow服务服务器。从服务器接收的输出被解码并在终端中打印。

代码语言:javascript代码运行次数:0运行复制

def make_serving_request(image_batch):    data = json.dumps({"signature_name": "serving_default",                       "instances": image_batch.tolist()})     headers = {"content-type": "application/json"}     os.environ['NO_PROXY'] = 'localhost'    json_response = requests.post(        'http://localhost:8501/v1/models/flower_classifier:predict', data=data, headers=headers)     predictions = json.loads(json_response.text)['predictions']     return predictions

梦想的传达

本教程仅介绍机器学习工作流程中的单个组件。机器学习管道包括针对组织及其用例的大量训练,推断和监视周期。建立这些管道需要对驾驶员,乘客和车辆路线有更深入的了解。只有这样,才能实现理想的运输工具!

以上就是在TensorFlow 2中实现完全卷积网络(FCN)的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/259537.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月4日 08:36:36
下一篇 2025年11月4日 08:38:25

相关推荐

  • 如何解决本地图片在使用 mask JS 库时出现的跨域错误?

    如何跨越localhost使用本地图片? 问题: 在本地使用mask js库时,引入本地图片会报跨域错误。 解决方案: 要解决此问题,需要使用本地服务器启动文件,以http或https协议访问图片,而不是使用file://协议。例如: python -m http.server 8000 然后,可以…

    2025年12月24日
    200
  • 使用 Mask 导入本地图片时,如何解决跨域问题?

    跨域疑难:如何解决 mask 引入本地图片产生的跨域问题? 在使用 mask 导入本地图片时,你可能会遇到令人沮丧的跨域错误。为什么会出现跨域问题呢?让我们深入了解一下: mask 框架假设你以 http(s) 协议加载你的 html 文件,而当使用 file:// 协议打开本地文件时,就会产生跨域…

    2025年12月24日
    200
  • HTML、CSS 和 JavaScript 中的简单侧边栏菜单

    构建一个简单的侧边栏菜单是一个很好的主意,它可以为您的网站添加有价值的功能和令人惊叹的外观。 侧边栏菜单对于客户找到不同项目的方式很有用,而不会让他们觉得自己有太多选择,从而创造了简单性和秩序。 今天,我将分享一个简单的 HTML、CSS 和 JavaScript 源代码来创建一个简单的侧边栏菜单。…

    2025年12月24日
    200
  • 前端代码辅助工具:如何选择最可靠的AI工具?

    前端代码辅助工具:可靠性探讨 对于前端工程师来说,在HTML、CSS和JavaScript开发中借助AI工具是司空见惯的事情。然而,并非所有工具都能提供同等的可靠性。 个性化需求 关于哪个AI工具最可靠,这个问题没有一刀切的答案。每个人的使用习惯和项目需求各不相同。以下是一些影响选择的重要因素: 立…

    2025年12月24日
    300
  • 带有 HTML、CSS 和 JavaScript 工具提示的响应式侧边导航栏

    响应式侧边导航栏不仅有助于改善网站的导航,还可以解决整齐放置链接的问题,从而增强用户体验。通过使用工具提示,可以让用户了解每个链接的功能,包括设计紧凑的情况。 在本教程中,我将解释使用 html、css、javascript 创建带有工具提示的响应式侧栏导航的完整代码。 对于那些一直想要一个干净、简…

    2025年12月24日
    000
  • 布局 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在这里查看视觉效果: 固定导航 – 布局 – codesandbox两列 – 布局 – codesandbox三列 – 布局 – codesandbox圣杯 &#8…

    2025年12月24日
    000
  • 隐藏元素 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看隐藏元素的视觉效果 – codesandbox 隐藏元素 hiding elements hiding elements hiding elements hiding elements hiding element…

    2025年12月24日
    400
  • 居中 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看垂直中心 – codesandbox 和水平中心的视觉效果。 通过 css 居中 垂直居中 centering centering centering centering centering centering立即…

    2025年12月24日 好文分享
    300
  • 如何在 Laravel 框架中轻松集成微信支付和支付宝支付?

    如何用 laravel 框架集成微信支付和支付宝支付 问题:如何在 laravel 框架中集成微信支付和支付宝支付? 回答: 建议使用 easywechat 的 laravel 版,easywechat 是一个由腾讯工程师开发的高质量微信开放平台 sdk,已被广泛地应用于许多 laravel 项目中…

    2025年12月24日
    000
  • 如何在移动端实现子 div 在父 div 内任意滑动查看?

    如何在移动端中实现让子 div 在父 div 内任意滑动查看 在移动端开发中,有时我们需要让子 div 在父 div 内任意滑动查看。然而,使用滚动条无法实现负值移动,因此需要采用其他方法。 解决方案: 使用绝对布局(absolute)或相对布局(relative):将子 div 设置为绝对或相对定…

    2025年12月24日
    000
  • 移动端嵌套 DIV 中子 DIV 如何水平滑动?

    移动端嵌套 DIV 中子 DIV 滑动 在移动端开发中,遇到这样的问题:当子 DIV 的高度小于父 DIV 时,无法在父 DIV 中水平滚动子 DIV。 无限画布 要实现子 DIV 在父 DIV 中任意滑动,需要创建一个无限画布。使用滚动无法达到负值,因此需要使用其他方法。 相对定位 一种方法是将子…

    2025年12月24日
    000
  • 移动端项目中,如何消除rem字体大小计算带来的CSS扭曲?

    移动端项目中消除rem字体大小计算带来的css扭曲 在移动端项目中,使用rem计算根节点字体大小可以实现自适应布局。但是,此方法可能会导致页面打开时出现css扭曲,这是因为页面内容在根节点字体大小赋值后重新渲染造成的。 解决方案: 要避免这种情况,将计算根节点字体大小的js脚本移动到页面的最前面,即…

    2025年12月24日
    000
  • Nuxt 移动端项目中 rem 计算导致 CSS 变形,如何解决?

    Nuxt 移动端项目中解决 rem 计算导致 CSS 变形 在 Nuxt 移动端项目中使用 rem 计算根节点字体大小时,可能会遇到一个问题:页面内容在字体大小发生变化时会重绘,导致 CSS 变形。 解决方案: 可将计算根节点字体大小的 JS 代码块置于页面最前端的 标签内,确保在其他资源加载之前执…

    2025年12月24日
    200
  • Nuxt 移动端项目使用 rem 计算字体大小导致页面变形,如何解决?

    rem 计算导致移动端页面变形的解决方法 在 nuxt 移动端项目中使用 rem 计算根节点字体大小时,页面会发生内容重绘,导致页面打开时出现样式变形。如何避免这种现象? 解决方案: 移动根节点字体大小计算代码到页面顶部,即 head 中。 原理: flexível.js 也遇到了类似问题,它的解决…

    2025年12月24日
    000
  • 形状 – CSS 挑战

    您可以在 github 仓库中找到这篇文章中的所有代码。 您可以在此处查看 codesandbox 的视觉效果。 通过css绘制各种形状 如何在 css 中绘制正方形、梯形、三角形、异形三角形、扇形、圆形、半圆、固定宽高比、0.5px 线? shapes 0.5px line .square { w…

    2025年12月24日
    000
  • 有哪些美观的开源数字大屏驾驶舱框架?

    开源数字大屏驾驶舱框架推荐 问题:有哪些美观的开源数字大屏驾驶舱框架? 答案: 资源包 [弗若恩智能大屏驾驶舱开发资源包](https://www.fanruan.com/resource/152) 软件 [弗若恩报表 – 数字大屏可视化组件](https://www.fanruan.c…

    2025年12月24日
    000
  • 网站底部如何实现飘彩带效果?

    网站底部飘彩带效果的 js 库实现 许多网站都会在特殊节日或活动中添加一些趣味性的视觉效果,例如点击按钮后散发的五彩缤纷的彩带。对于一个特定的网站来说,其飘彩带效果的实现方式可能有以下几个方面: 以 https://dub.sh/ 网站为例,它底部按钮点击后的彩带效果是由 javascript 库实…

    2025年12月24日
    000
  • 网站彩带效果背后是哪个JS库?

    网站彩带效果背后是哪个js库? 当你访问某些网站时,点击按钮后,屏幕上会飘出五颜六色的彩带,营造出庆祝的氛围。这些效果是通过使用javascript库实现的。 问题: 哪个javascript库能够实现网站上点击按钮散发彩带的效果? 答案: 根据给定网站的源代码分析: 可以发现,该网站使用了以下js…

    好文分享 2025年12月24日
    100
  • 产品预览卡项目

    这个项目最初是来自 Frontend Mentor 的挑战,旨在使用 HTML 和 CSS 创建响应式产品预览卡。最初的任务是设计一张具有视觉吸引力和功能性的产品卡,能够无缝适应各种屏幕尺寸。这涉及使用 CSS 媒体查询来确保布局在不同设备上保持一致且用户友好。产品卡包含产品图像、标签、标题、描述和…

    2025年12月24日
    100
  • 如何利用 echarts-gl 绘制带发光的 3D 图表?

    如何绘制带发光的 3d 图表,类似于 echarts 中的示例? 为了实现类似的 3d 图表效果,需要引入 echarts-gl 库:https://github.com/ecomfe/echarts-gl。 echarts-gl 专用于在 webgl 环境中渲染 3d 图形。它提供了各种 3d 图…

    2025年12月24日
    000

发表回复

登录后才能评论
关注微信