在TensorFlow 2中实现完全卷积网络(FCN)

在TensorFlow 2中实现完全卷积网络(FCN)

作者 | himanshu rawlani

来源 | Medium

编辑 | 代码医生团队

卷积神经网络(CNN)非常适合计算机视觉任务。使用对大型图像集(如ImageNet,COCO等)进行训练的预训练模型,可以快速使这些体系结构专业化,以适合独特数据集。此过程称为迁移学习。但是有一个陷阱!用于图像分类和对象检测任务的预训练模型通常在固定的输入图像尺寸上训练。这些通常从224x224x3到某个范围变化,512x512x3并且大多数具有1的长宽比,即图像的宽度和高度相等。如果它们不相等,则将图像调整为相等的高度和宽度。

较新的体系结构确实能够处理可变的输入图像大小,但是与图像分类任务相比,它在对象检测和分割任务中更为常见。最近遇到了一个有趣的用例,其中有5个不同类别的图像,每个类别都有微小的差异。此外图像的纵横比也比平常高。图像的平均高度约为30像素,宽度约为300像素。这是一个有趣的原因,其原因如下:

调整图像大小容易使重要功能失真预训练的架构非常庞大,并且总是过度拟合数据集任务要求低延迟

需要具有可变输入尺寸的CNN

尝试了MobileNet和EfficientNet的基本模型,但没有任何效果。需要一种对输入图像大小没有任何限制并且可以执行手边的图像分类任务的网络。震惊的第一件事是完全卷积网络(FCN)。FCN是一个不包含任何“密集”层的网络(如在传统的CNN中一样),而是包含1×1卷积,用于执行完全连接的层(密集层)的任务。尽管没有密集层可以输入可变的输入,但是有两种技术可以在保留可变输入尺寸的同时使用密集层。本教程描述了其中一些技术。在本教程中,将执行以下步骤:

使用Keras在TensorFlow中构建完全卷积网络(FCN)下载并拆分样本数据集在Keras中创建生成器以加载和处理内存中的一批数据训练具有可变批次尺寸的网络使用TensorFlow Serving部署模型

获取代码

本文中的代码片段仅突出实际脚本的一部分,有关完整代码,请参阅GitHub存储库。

https://github.com/himanshurawlani/fully_convolutional_network.git

1.设计引擎(model.py)

通过堆叠由2D卷积层(Conv2D)和所需的正则化(Dropout和BatchNormalization)组成的卷积块来构建FCN模型。正则化可防止过度拟合并有助于快速收敛。还添加了一个激活层来合并非线性。在Keras中,输入批次尺寸是自动添加的,不需要在输入层中指定它。由于输入图像的高度和宽度是可变的,因此将输入形状指定为(None, None, 3)。3表示图像中的通道数,该数量对于彩色图像(RGB)是固定的。

代码语言:javascript代码运行次数:0运行复制

import tensorflow as tf def FCN_model(len_classes=5, dropout_rate=0.2):        # Input layer    input = tf.keras.layers.Input(shape=(None, None, 3))     # A convolution block    x = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=1)(input)    x = tf.keras.layers.Dropout(dropout_rate)(x)    x = tf.keras.layers.BatchNormalization()(x)    x = tf.keras.layers.Activation('relu')(x)        # Stack of convolution blocks    .    .    .

最小图像尺寸要求

在输入施加卷积块之后,输入的高度和宽度将降低基于所述值kernel_size和strides。如果输入图像的尺寸太小,那么可能无法达到下一个卷积块所需的最小高度和宽度(应大于或等于内核尺寸)。确定最小输入尺寸的尝试和错误方法如下:

确定要堆叠的卷积块数选择任何输入形状以说出(32, 32, 3)并堆叠数量越来越多的通道的卷积块尝试构建模型并打印model.summary()以查看每个图层的输出形状。确保(1, 1, num_of_filters)从最后一个卷积块获得输出尺寸(这将被输入到完全连接的层)。尝试减小/增大输入形状,内核大小或步幅,以满足步骤4中的条件。满足条件的输入形状以及其他配置是网络所需的最小输入尺寸。

还有,以计算输出体积的空间大小,其所示的输入体积的函数的数学方式这里。找到最小输入尺寸后,现在需要将最后一个卷积块的输出传递到完全连接的层。但是任何尺寸大于最小输入尺寸的输入都需要汇总以满足步骤4中的条件。了解如何使用我们的主要成分来做到这一点。

http://cs231n.github.io/convolutional-networks/#conv

主要成分

全连接层(FC层)将执行分类任务。可以通过两种方式构建FC层:

致密层1x1卷积

如果要使用密集层,则必须固定模型输入尺寸,因为必须预先定义作为密集层输入的参数数量才能创建密集层。具体来说,希望(height, width, num_of_filters)最后一个卷积块的输出中的高度和宽度为常数或1。滤波器的数量始终是固定的,因为这些值是在每个卷积块中定义的。

1x1卷积的输入尺寸可以是(1, 1, num_of_filters)或(height, width, num_of_filters)模仿它们沿num_of_filters尺寸方向FC层的功能。但是,在1x1卷积之后,最后一层(Softmax激活层)的输入必须具有固定的长度(类数)。

主要成分:GlobalMaxPooling2D() / GlobalAveragePooling2D()。Keras中的这些层将尺寸的输入转换(height, width, num_of_filters)为(1, 1, num_of_filters)实质上沿尺寸的每个值的最大值或平均值,用于沿尺寸的每个过滤器num_of_filters。

代码语言:javascript代码运行次数:0运行复制

# Uncomment the below line if you're using dense layers# x = tf.keras.layers.GlobalMaxPooling2D()(x) # Fully connected layer 1# x = tf.keras.layers.Dropout(dropout_rate)(x)# x = tf.keras.layers.BatchNormalization()(x)# x = tf.keras.layers.Dense(units=64)(x)# x = tf.keras.layers.Activation('relu')(x) # Fully connected layer 1x = tf.keras.layers.Conv2D(filters=64, kernel_size=1, strides=1)(x)x = tf.keras.layers.Dropout(dropout_rate)(x)x = tf.keras.layers.BatchNormalization()(x)x = tf.keras.layers.Activation('relu')(x) # Fully connected layer 2# x = tf.keras.layers.Dropout(dropout_rate)(x)# x = tf.keras.layers.BatchNormalization()(x)# x = tf.keras.layers.Dense(units=len_classes)(x)# predictions = tf.keras.layers.Activation('softmax')(x) # Fully connected layer 2x = tf.keras.layers.Conv2D(filters=len_classes, kernel_size=1, strides=1)(x)x = tf.keras.layers.Dropout(dropout_rate)(x)x = tf.keras.layers.BatchNormalization()(x)x = tf.keras.layers.GlobalMaxPooling2D()(x)predictions = tf.keras.layers.Activation('softmax')(x) model = tf.keras.Model(inputs=input, outputs=predictions)print(model.summary())

密集层与1x1卷积

该代码包括密集层(注释掉)和1x1卷积。在使用两种配置构建和训练模型之后,这里是一些观察结果:

两种模型都包含相同数量的可训练参数。类似的训练和推理时间。密集层比1x1卷积的泛化效果更好。

第三点不能一概而论,因为它取决于诸如数据集中的图像数量,使用的数据扩充,模型初始化等因素。但是这些是实验中的观察结果。可以通过执行命令来独立运行脚本,以测试是否已成功构建模型$python model.py。

2.下载fuel(data.py)

采风问卷 采风问卷

采风问卷是一款全新体验的调查问卷、表单、投票、评测的调研平台,新奇的交互形式,漂亮的作品,让客户眼前一亮,让创作者获得更多的回复。

采风问卷 20 查看详情 采风问卷

本教程中使用的flowers数据集主要旨在了解在训练具有可变输入维度的模型时面临的挑战。测试FCN模型的一些有趣的数据集可能来自医学成像领域,其中包含对图像分类至关重要的微观特征,而其他数据集包含的几何图案/形状在调整图像大小后可能会失真。

1.提供的脚本(data.py)需要独立运行($python data.py)。它将执行以下任务:

2.下载包含5类(“雏菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”)的花卉数据集。有关数据集的更多细节在这里。

https://www.tensorflow.org/datasets/catalog/tf_flowers

3.将数据集分为训练和验证集。可以设置要复制到训练和验证集中的图像数量。

提供有关数据集的统计信息,例如图像的最小,平均和最大高度和宽度。

此脚本使用来下载.tar文件并将其内容提取到当前目录中keras.utils.get_file()。如果想使用TensorFlow数据集(TFDS),可以查看本教程,该教程说明了TFDS以及数据扩充的用法。

3.特殊化carburetor(generator.py)

想在不同的输入维度上训练模型。给定批次和批次之间的每个图像都有不同的尺寸。所以有什么问题?退后一步,回顾一下如何训练传统的图像分类器。在传统的图像分类器中,将图像调整为给定尺寸,通过转换为numpy数组或张量将其打包成批,然后将这批数据通过模型进行正向传播。在整个批次中评估指标(损失,准确性等)。根据这些指标计算要反向传播的梯度。

无法调整图像大小(因为我们将失去微观特征)。现在由于无法调整图像的大小,因此无法将其转换为成批的numpy数组。这是因为如果有一个10张图像的列表,(height, width, 3)它们的height和值不同,width并且尝试将其传递给np.array(),则结果数组的形状将为(10,)and not (10, height, width, 3)!但是模型期望输入尺寸为后一种形状。一种解决方法是编写一个自定义训练循环,该循环执行以下操作:

通过将通过每个图像,在列表中(分批),通过模型(height, width, 3)来(1, height, width, 3)使用np.expand_dims(img, axis=0)。累积python列表(批处理)中每个图像的度量。使用累积的指标计算损耗和梯度。将渐变更新应用到模型。重置指标的值并创建新的图像列表(批次)。

尝试了上述步骤,但建议不要采用上述策略。它很费力,导致代码复杂且不可持续,并且运行速度非常慢!每个人都喜欢优雅的 model.fit()和model.fit_generator()。后者是将在这里使用的!但是首先是化油器。

化油器是一种以合适的空燃比混合用于内燃机的空气和燃料的装置。这就是所需要的,空气!找到批处理中图像的最大高度和宽度,并用零填充每个其他图像,以使批处理中的每个图像都具有相等的尺寸。现在可以轻松地将其转换为numpy数组或张量,并将其传递给fit_generator()。该模型会自动学习忽略零(基本上是黑色像素),并从填充图像的预期部分学习特征。这样就有了一个具有相等图像尺寸的批处理,但是每个批处理具有不同的形状(由于批处理中图像的最大高度和宽度不同)。可以generator.py使用独立运行文件$python generator.py并交叉检查输出。

代码语言:javascript代码运行次数:0运行复制

def construct_image_batch(image_group, BATCH_SIZE):    # get the max image shape    max_shape = tuple(max(image.shape[x] for image in image_group) for x in range(3))     # construct an image batch object    image_batch = np.zeros((BATCH_SIZE,) + max_shape, dtype='float32')     # copy all images to the upper left part of the image batch object    for image_index, image in enumerate(image_group):        image_batch[image_index, :image.shape[0], :image.shape[1], :image.shape[2]] = image     return image_batch

4.点燃认知(train.py)

训练脚本导入并实例化以下类:

生成器:需要指定到创建的路径train和val目录data.py。FCN_model:需要指定最终输出层中所需的类数。

将上述对象传递给train()使用Adam优化器和分类交叉熵损失函数编译模型的函数。创建一个检查点回调,以在训练期间保存最佳模型。最佳模型是根据每个时期结束时的验证集计算出的损失值确定的。fit_generator()函数在很大程度上简化了代码。

代码语言:javascript代码运行次数:0运行复制

def train(model, train_generator, val_generator, epochs = 50):    model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.0001),                    loss='categorical_crossentropy',                    metrics=['accuracy'])     checkpoint_path = './snapshots'    os.makedirs(checkpoint_path, exist_ok=True)    model_path = os.path.join(checkpoint_path, 'model_epoch_{epoch:02d}_loss_{loss:.2f}_acc_{acc:.2f}_val_loss_{val_loss:.2f}_val_acc_{val_acc:.2f}.h5')        history = model.fit_generator(generator=train_generator,                                    steps_per_epoch=len(train_generator),                                    epochs=epochs,                                    callbacks=[tf.keras.callbacks.ModelCheckpoint(model_path, monitor='val_loss', save_best_only=True, verbose=1)],                                    validation_data=val_generator,                                    validation_steps=len(val_generator))     return history

建议在Google Colab上进行训练,除非本地计算机上有GPU。GitHub存储库包含一个Colab笔记本,该笔记本将训练所需的所有内容组合在一起。可以在Colab本身中修改python脚本,并在选择的数据集上训练不同的模型配置。完成训练后,可以从Colab中的“文件”选项卡将最佳快照下载到本地计算机。

5.使用TensorFlow Serving(inference.py)部署模型

下载模型后,需要使用将其导出为SavedModel格式export_savedmodel.py。.h5在主要功能中指定下载模型(文件)的路径,然后使用命令执行脚本$python export_savedmodel.py。该脚本使用TensorFlow 2.0中的新功能,该功能从.h5文件中加载Keras模型并将其保存为TensorFlow SavedModel格式。SavedModel将导出到export_path脚本中指定的位置。TensorFlow服务docker映像需要此SavedModel。

代码语言:javascript代码运行次数:0运行复制

def export(input_h5_file, export_path):    # The export path contains the name and the version of the model    tf.keras.backend.set_learning_phase(0)  # Ignore dropout at inference    model = tf.keras.models.load_model(input_h5_file)    model.save(export_path, save_format='tf')    print(f"SavedModel created at {export_path}")

要启动TensorFlow Serving服务器,请转到导出SavedModel的目录(./flower_classifier在这种情况下)并运行以下命令(注意:计算机上必须安装了Docker):

代码语言:javascript代码运行次数:0运行复制

$ docker run --rm -t -p 8501:8501 -v "$(pwd):/models/flower_classifier" -e MODEL_NAME=flower_classifier --name flower_classifier tensorflow/serving

可以使用$ docker ps命令验证容器在后台运行。还可以使用查看容器日志$ docker logs your_container_id。该inference.py脚本包含用于构建具有统一图像尺寸的批次的代码,并将这些批次作为POST请求发送到TensorFlow服务服务器。从服务器接收的输出被解码并在终端中打印。

代码语言:javascript代码运行次数:0运行复制

def make_serving_request(image_batch):    data = json.dumps({"signature_name": "serving_default",                       "instances": image_batch.tolist()})     headers = {"content-type": "application/json"}     os.environ['NO_PROXY'] = 'localhost'    json_response = requests.post(        'http://localhost:8501/v1/models/flower_classifier:predict', data=data, headers=headers)     predictions = json.loads(json_response.text)['predictions']     return predictions

梦想的传达

本教程仅介绍机器学习工作流程中的单个组件。机器学习管道包括针对组织及其用例的大量训练,推断和监视周期。建立这些管道需要对驾驶员,乘客和车辆路线有更深入的了解。只有这样,才能实现理想的运输工具!

以上就是在TensorFlow 2中实现完全卷积网络(FCN)的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/259537.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
Mysql长事务的影响是什么
上一篇 2025年11月4日 08:38:07
linux rm是什么意思
下一篇 2025年11月4日 08:38:10

相关推荐

  • 修复Django电商项目中AJAX过滤产品列表图片不显示问题

    在Django电商项目中,当使用AJAX动态加载过滤后的产品列表时,常遇到图片无法正常显示的问题。这通常是由于前端模板中图片加载方式(如data-setbg属性结合JavaScript库)与AJAX动态内容更新机制不兼容所致。解决方案是直接在AJAX返回的HTML中使用标准的标签来渲染图片,确保浏览…

    2026年5月10日
    000
  • 开源免费PHP工具 PHP开发效率提升利器

    推荐开源免费PHP开发工具以提升效率:VS Code、Sublime Text轻量高效,PhpStorm专业强大;调试用Xdebug、Kint、Ray;依赖管理选Composer;代码质量工具包括PHPStan、Psalm、PHP_CodeSniffer;数据库管理可用%ignore_a_1%MyA…

    2026年5月10日
    000
  • Matplotlib 地图中多类型图例的创建与优化

    Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化

    本教程旨在解决matplotlib地图可视化中,如何在一个图例中同时展示颜色块(如区域分类)和自定义标记(如特定兴趣点)的问题。文章详细介绍了当传统`patch`对象无法正确显示标记时,如何利用`matplotlib.lines.line2d`创建标记图例句柄,并将其与颜色块图例句柄合并,从而生成一…

    2026年5月10日 用户投稿
    100
  • Golang JSON序列化:控制敏感字段暴露的最佳实践

    本教程探讨golang中如何高效控制结构体字段在json序列化时的可见性。当需要将包含敏感信息的结构体数组转换为json响应时,通过利用`encoding/json`包提供的结构体标签,特别是`json:”-“`,可以轻松实现对特定字段的忽略,从而避免敏感数据泄露,确保api…

    2026年5月10日
    000
  • 利用海象运算符简化条件赋值:Python教程与最佳实践

    本文旨在探讨Python中海象运算符(:=)在条件赋值场景下的应用。通过对比传统if/else语句与海象运算符,以及条件表达式,分析海象运算符在简化代码、提高可读性方面的优势与局限性。并通过具体示例,展示如何在列表推导式等场景下合理使用海象运算符,同时强调其潜在的复杂性及替代方案,帮助开发者更好地掌…

    2026年5月10日
    100
  • 比特币新手教程 比特币交易平台有哪些

    比特币是一种去中心化的数字货币,基于区块链技术实现点对点交易,具有匿名性、有限发行和不可篡改等特点;新手可通过交易所购买,P2P交易获得比特币,常用平台包括Binance、OKX和Huobi;交易流程包括注册账户、实名认证、绑定支付方式、充值法币并下单购买,可选择市价单或限价单;比特币存储方式有交易…

    2026年5月10日
    000
  • Golang gRPC流式请求异常处理

    在Golang的gRPC流式通信中,必须通过context.Context处理异常。应监听上下文取消或超时,及时释放资源,设置合理超时,避免连接长时间挂起,并在goroutine中通过context控制生命周期。 在使用 Golang 和 gRPC 实现流式通信时,异常处理是确保服务健壮性的关键部分…

    2026年5月10日
    000
  • Go语言mgo查询构建:深入理解bson.M与日期范围查询的正确实践

    本文旨在解决go语言mgo库中构建复杂查询时,特别是涉及嵌套`bson.m`和日期范围筛选的常见错误。我们将深入剖析`bson.m`的类型特性,解释为何直接索引`interface{}`会导致“invalid operation”错误,并提供一种推荐的、结构清晰的代码重构方案,以确保查询条件能够正确…

    2026年5月10日
    100
  • vscode上怎么运行html_vscode上运行html步骤【指南】

    首先保存文件为.html格式,再通过浏览器或Live Server插件打开预览;推荐安装Live Server实现本地服务器运行与实时刷新,提升开发体验。 在 VS Code 上运行 HTML 文件并不需要复杂的配置,只需几个简单步骤即可预览页面效果。VS Code 本身是一个代码编辑器,不直接运行…

    2026年5月10日
    100
  • RichHandler与Rich Progress集成:解决显示冲突的教程

    在使用rich库的`richhandler`进行日志输出并同时使用`progress`组件时,可能会遇到显示错乱或溢出问题。这通常是由于为`richhandler`和`progress`分别创建了独立的`console`实例导致的。解决方案是确保日志处理器和进度条组件共享同一个`console`实例…

    2026年5月10日
    000
  • 修复点击时按钮抖动:CSS垂直对齐实践

    本文探讨了在Web开发中,交互式按钮(如播放/暂停按钮)在点击时发生意外垂直位移的问题。通过分析CSS样式变化对元素布局的影响,我们发现这是由于按钮不同状态下的边框样式和内边距改变,以及默认的垂直对齐行为共同作用所致。核心解决方案是利用CSS的vertical-align属性,将其设置为middle…

    2026年5月10日
    100
  • Golang goroutine与channel调试技巧

    使用go run -race检测数据竞争,结合runtime.NumGoroutine监控协程数量,通过pprof分析阻塞调用栈,利用select超时避免永久阻塞,有效排查goroutine泄漏、死锁和数据竞争问题。 Go语言的goroutine和channel是并发编程的核心,但它们也带来了调试上…

    2026年5月10日
    000
  • 使用 Jupyter Notebook 进行探索性数据分析

    Jupyter Notebook通过单元格实现代码与Markdown结合,支持数据导入(pandas)、清洗(fillna)、探索(matplotlib/seaborn可视化)、统计分析(describe/corr)和特征工程,便于记录与分享分析过程。 Jupyter Notebook 是进行探索性…

    2026年5月10日
    000
  • 如何在HTML中插入表单元素_HTML表单控件与输入类型使用指南

    HTML表单通过标签构建,包含action和method属性定义数据提交目标与方式,常用input类型如text、password、email等适配不同输入需求,配合label、required、placeholder提升可用性,结合textarea、select、button等控件实现完整交互,是…

    2026年5月10日
    100
  • 前端缓存策略与JavaScript存储管理

    根据数据特性选择合适的存储方式并制定清晰的读写与清理逻辑,能显著提升前端性能;合理运用Cookie、localStorage、sessionStorage、IndexedDB及Cache API,结合缓存策略与定期清理机制,可在保证用户体验的同时避免安全与性能隐患。 前端缓存和JavaScript存…

    2026年5月10日
    200
  • HTML5网页如何实现手势操作 HTML5网页移动端交互的处理技巧

    首先利用原生touch事件实现滑动判断,再通过preventDefault解决滚动冲突,接着引入Hammer.js处理复杂手势,最后通过优化点击区域、避免事件冲突和增加视觉反馈提升体验。 在移动端浏览器中,HTML5网页可以通过触摸事件实现手势操作,提升用户体验。虽然原生JavaScript提供了基…

    2026年5月10日
    000
  • 深入理解 Express.js 中 next() 参数的作用与中间件机制

    本文深入探讨 express.js 中间件函数中的 `next()` 参数。它负责将控制权传递给请求-响应周期中的下一个中间件或路由处理程序。文章将详细解释 `next()` 的工作原理、中间件的注册与执行顺序,以及不正确使用 `next()` 可能导致请求挂起的风险,并通过代码示例和实际应用场景,…

    2026年5月10日
    000
  • 创建指定大小并填充特定数据的Golang文件教程

    本文将介绍如何使用Golang创建一个指定大小的文件,并用特定数据填充它。我们将使用 `os` 包提供的函数来创建和截断文件,从而实现快速生成大文件的目的。示例代码展示了如何创建一个10MB的文件,并将其填充为全零数据。掌握这些方法,可以方便地在例如日志系统或磁盘队列等场景中,预先创建测试文件或初始…

    2026年5月10日
    000
  • Python命令怎样使用profile分析脚本性能 Python命令性能分析的基础教程

    使用Python的cProfile模块分析脚本性能最直接的方式是通过命令行执行python -m cProfile your_script.py,它会输出每个函数的调用次数、总耗时、累积耗时等关键指标,帮助定位性能瓶颈;为进一步分析,可将结果保存为文件python -m cProfile -o ou…

    2026年5月10日
    000
  • Python递归函数追踪与性能考量:以序列打印为例

    本文深入探讨了Python中一种递归打印序列元素的方法,并着重演示了如何通过引入缩进参数来有效追踪递归函数的执行流程和参数变化。通过实际代码示例,文章揭示了递归调用可能带来的潜在性能开销,特别是对调用栈空间的需求,以及Python默认递归深度限制可能导致的错误,为读者提供了理解和优化递归算法的实用见…

    2026年5月10日
    000

发表回复

登录后才能评论
关注微信