【PPSIG】基于PaddlePaddle复现STANet

本文介绍基于PaddlePaddle复现STANet的过程。STANet是遥感影像变化检测算法,引入时空注意力机制与多尺度子区域方案。复现忠实原文,与官方实现存在两点差异,含特征提取器、时空注意力模块等结构,在LEVIR-CD数据集上完成训练测试,验证了复现效果。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【ppsig】基于paddlepaddle复现stanet - 创想鸟

基于PaddlePaddle复现STANet

深度学习变化检测算法STANet的PaddlePaddle复现。已集成至PaddleRS项目。

1 简介


1.1 遥感影像变化检测

我们生活在一个变化的世界。无论是地表生态系统还是人类社会活动,都是不断演进和动态发展的过程。实时,准确地获取地表的变化信息,对于研究自然环境、人类活动以及二者之间的相互作用具有十分重要的意义。得益于卫星传感器技术的发展,人类可以通过遥感手段对地表实施大范围、长时间的监测。而遥感影像的变化检测,正是通过对同一地区不同时间的重复观测来分析区域内地物状态变化的技术。

如下图所示,变化检测任务的输入一般是多时相遥感影像。对于二值变化检测任务而言,输出是二值变化图,标示了变化发生的位置。例如图中用白色标记发生变化的像素,黑色标记未发生变化的像素。

【PPSIG】基于PaddlePaddle复现STANet - 创想鸟        

变化检测在实际生产生活中已经得到了十分广泛的应用,例如用于生态系统监测、城市发展规划、农业用地研究、受灾情况评估等等。

“The only thing that never changes is that everything changes.” – Louis L’Amour

1.2 STANet算法

STANet是Hao Chen和Zhenwei Shi提出的一种基于深度学习的遥感影像变化检测方法,论文于2020年4月发表在《Remote Sensing》期刊。STANet创新性地将时空注意力(spatial-temporal attention)机制引入变化检测领域,通过计算两幅影像的特征图中任意两个像素间的相关关系,构建变化检测任务所需要的特征时空依赖,从而实现对变化区域更加精准的识别。另一方面,考虑到遥感影像中常见的地物多尺度问题,作者提出一种将影像分割为多尺度子区域(multi-scale subregion)、并在每个子区域内进行自注意力计算的方案,以便捕捉不同尺度的时空依赖信息。STANet的整体框架如下图所示:

【PPSIG】基于PaddlePaddle复现STANet - 创想鸟        

Chen H, Shi Z. A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection. Remote Sensing. 2020; 12(10):1662. https://doi.org/10.3390/rs12101662

1.3 本项目内容

本项目遵照原论文,使用PaddlePaddle框架对STANet进行复现,并在LEVIR-CD数据集上对复现算法的效果进行评估。

2 STANet的PaddlePaddle实现


本节将基于PaddlePaddle框架,根据原论文实现STANet模型。复现过程总体上忠实于原文,但需要指出的是,原作者的官方实现和本项目中的实现有两点主要不同:

官方实现与本项目实现对ResNet backbone中最后一个stage的卷积空洞率的设置不同;官方实现的STANet采用Metric Module配合度量学习(metric learning)的方式训练,而本项目实现的STANet在最终层使用分割任务风格的dense classification head,使用交叉熵损失训练。

2.1 模型整体结构

遵循自顶向下的设计范式,首先对STANet模型的整体结构进行规划。在此之前,出于复用性方面考虑,对PaddlePaddle的一些原生模块进行封装:

In [ ]

import paddleimport paddle.nn as nnimport paddle.nn.functional as Fdef get_norm_layer():    # TODO: select appropriate norm layer.    return nn.BatchNorm2Ddef make_norm(*args, **kwargs):    norm_layer = get_norm_layer()    return norm_layer(*args, **kwargs)class BasicConv(nn.Layer):    def __init__(self,                 in_ch,                 out_ch,                 kernel_size,                 pad_mode='constant',                 bias='auto',                 norm=False,                 act=False,                 **kwargs):        super(BasicConv, self).__init__()        seq = []        if kernel_size >= 2:            seq.append(nn.Pad2D(kernel_size // 2, mode=pad_mode))        seq.append(            nn.Conv2D(                in_ch,                out_ch,                kernel_size,                stride=1,                padding=0,                bias_attr=(False if norm else None) if bias == 'auto' else bias,                **kwargs))        if norm:            if norm is True:                norm = make_norm(out_ch)            seq.append(norm)        if act:            if act is True:                act = make_act()            seq.append(act)        self.seq = nn.Sequential(*seq)    def forward(self, x):        return self.seq(x)class Conv1x1(BasicConv):    def __init__(self,                 in_ch,                 out_ch,                 pad_mode='constant',                 bias='auto',                 norm=False,                 act=False,                 **kwargs):        super(Conv1x1, self).__init__(            in_ch,            out_ch,            1,            pad_mode=pad_mode,            bias=bias,            norm=norm,            act=act,            **kwargs)class Conv3x3(BasicConv):    def __init__(self,                 in_ch,                 out_ch,                 pad_mode='constant',                 bias='auto',                 norm=False,                 act=False,                 **kwargs):        super(Conv3x3, self).__init__(            in_ch,            out_ch,            3,            pad_mode=pad_mode,            bias=bias,            norm=norm,            act=act,            **kwargs)class Identity(nn.Layer):    """A placeholder identity operator that accepts exactly one argument."""    def __init__(self, *args, **kwargs):        super(Identity, self).__init__()    def forward(self, x):        return x

   

接着,利用封装好的组件“搭积木”,定义模型:

In [ ]

class STANet(nn.Layer):    """    The STANet implementation based on PaddlePaddle.    The original article refers to        H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection"        (https://www.mdpi.com/2072-4292/12/10/1662).    Note that this implementation differs from the original work in two aspects:    1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.    2. A classification head is used in place of the original metric learning-based head to stablize the training process.    Args:        in_channels (int): The number of bands of the input images.        num_classes (int): The number of target classes.        att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.        ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values             greater than 1, the input features will first be processed by an average pooling layer with the kernel size of             `ds_factor`, before being used to calculate the attention scores. Default: 1.    Raises:        ValueError: When `att_type` has an illeagal value (unsupported attention type).    """    def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):        super(STANet, self).__init__()        WIDTH = 64        self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)        self.attend = build_sta_module(            in_ch=WIDTH, att_type=att_type, ds=ds_factor)        self.conv_out = nn.Sequential(            Conv3x3(                WIDTH, WIDTH, norm=True, act=True),            Conv3x3(WIDTH, num_classes))        self.init_weight()    def forward(self, t1, t2):        f1 = self.extract(t1)        f2 = self.extract(t2)        f1, f2 = self.attend(f1, f2)        y = paddle.abs(f1 - f2)        y = F.interpolate(            y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)        pred = self.conv_out(y)        return [pred]    def init_weight(self):        # Do nothing here as the encoder and decoder weights have already been initialized.        # Note however that currently self.attend and self.conv_out use the default initilization method.        pass

   

其中,build_feat_extractor()函数负责定义特征提取器(Feature Extractor),而build_sta_module()函数则负责定义时空注意力模块(Spatial-Temporal Attention Module)。

2.2 特征提取器

根据原文,将特征提取器分为骨干网络(backbone,一个修改版的ResNet-18)和解码器(decoder,将骨干网络提取到的不同层级特征上采样并融合)。编写build_feat_extractor()如下:

In [ ]

def build_feat_extractor(in_ch, width):    return nn.Sequential(Backbone(in_ch, 'resnet18'), Decoder(width))

   

其中,Backbone是对骨干网络部分的封装(为了使代码更具可扩展性,不止支持ResNet-18一种结构),Decoder则是对解码器部分的封装,用nn.Sequential将二者级联。

骨干网络部分的实现如下:

In [ ]

class Backbone(nn.Layer):    def __init__(self, in_ch, arch, pretrained=True, strides=(2, 1, 2, 2, 2)):        super(Backbone, self).__init__()        if arch == 'resnet18':            self.resnet = resnet18(                pretrained=pretrained,                strides=strides,                norm_layer=get_norm_layer())        elif arch == 'resnet34':            self.resnet = resnet34(                pretrained=pretrained,                strides=strides,                norm_layer=get_norm_layer())        elif arch == 'resnet50':            self.resnet = resnet50(                pretrained=pretrained,                strides=strides,                norm_layer=get_norm_layer())        else:            raise ValueError        self._trim_resnet()        if in_ch != 3:            self.resnet.conv1 = nn.Conv2D(                in_ch,                64,                kernel_size=7,                stride=strides[0],                padding=3,                bias_attr=False)    def forward(self, x):        x = self.resnet.conv1(x)        x = self.resnet.bn1(x)        x = self.resnet.relu(x)        x = self.resnet.maxpool(x)        x1 = self.resnet.layer1(x)        x2 = self.resnet.layer2(x1)        x3 = self.resnet.layer3(x2)        x4 = self.resnet.layer4(x3)        return x1, x2, x3, x4    def _trim_resnet(self):        self.resnet.avgpool = Identity()        self.resnet.fc = Identity()

   

其中,_trim_resnet()方法用于去除没有用到的全局平均池化和全连接层,以避免不必要的开销和可能导致的参数量统计错误。

对于ResNet的具体实现,为了避免重复造轮子,本项目直接在PaddlePaddle内置的ResNet基础上修改。主要修改的地方是为ResNet类的构造函数添加strides和norm_layer两个参数,前者控制stem及每个stage的output stride,后者控制使用的归一化层的类型。尽管本项目中并没有涉及到对归一化层类型的改变(仍然使用nn.BatchNorm2D),但应用这一设计有利于日后进一步的功能扩展。

In [ ]

# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py## Original head information# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at##     http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.from __future__ import divisionfrom __future__ import print_functionimport paddleimport paddle.nn as nnfrom paddle.utils.download import get_weights_path_from_url__all__ = []model_urls = {    'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',                 'cf548f46534aa3560945be4b95cd11c4'),    'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',                 '8d2275cf8706028345f78ac0e1d31969'),    'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',                 'ca6f485ee1ab0492d38f323885b0ad80'),    'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',                  '02f35f034ca3858e1e54d4036443c92d'),    'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',                  '7ad16a2f1e7333859ff986138630fd7a'),}class BasicBlock(nn.Layer):    expansion = 1    def __init__(self,                 inplanes,                 planes,                 stride=1,                 downsample=None,                 groups=1,                 base_width=64,                 dilation=1,                 norm_layer=None):        super(BasicBlock, self).__init__()        if norm_layer is None:            norm_layer = nn.BatchNorm2D        if dilation > 1:            raise NotImplementedError(                "Dilation > 1 not supported in BasicBlock")        self.conv1 = nn.Conv2D(            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)        self.bn1 = norm_layer(planes)        self.relu = nn.ReLU()        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)        self.bn2 = norm_layer(planes)        self.downsample = downsample        self.stride = stride    def forward(self, x):        identity = x        out = self.conv1(x)        out = self.bn1(out)        out = self.relu(out)        out = self.conv2(out)        out = self.bn2(out)        if self.downsample is not None:            identity = self.downsample(x)        out += identity        out = self.relu(out)        return outclass BottleneckBlock(nn.Layer):    expansion = 4    def __init__(self,                 inplanes,                 planes,                 stride=1,                 downsample=None,                 groups=1,                 base_width=64,                 dilation=1,                 norm_layer=None):        super(BottleneckBlock, self).__init__()        if norm_layer is None:            norm_layer = nn.BatchNorm2D        width = int(planes * (base_width / 64.)) * groups        self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)        self.bn1 = norm_layer(width)        self.conv2 = nn.Conv2D(            width,            width,            3,            padding=dilation,            stride=stride,            groups=groups,            dilation=dilation,            bias_attr=False)        self.bn2 = norm_layer(width)        self.conv3 = nn.Conv2D(            width, planes * self.expansion, 1, bias_attr=False)        self.bn3 = norm_layer(planes * self.expansion)        self.relu = nn.ReLU()        self.downsample = downsample        self.stride = stride    def forward(self, x):        identity = x        out = self.conv1(x)        out = self.bn1(out)        out = self.relu(out)        out = self.conv2(out)        out = self.bn2(out)        out = self.relu(out)        out = self.conv3(out)        out = self.bn3(out)        if self.downsample is not None:            identity = self.downsample(x)        out += identity        out = self.relu(out)        return outclass ResNet(nn.Layer):    """ResNet model from    `"Deep Residual Learning for Image Recognition" `_    Args:        Block (BasicBlock|BottleneckBlock): block module of model.        depth (int): layers of resnet, default: 50.        num_classes (int): output dim of last fc layer. If num_classes  0:            self.fc = nn.Linear(512 * block.expansion, num_classes)    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):        norm_layer = self._norm_layer        downsample = None        previous_dilation = self.dilation        if dilate:            self.dilation *= stride            stride = 1        if stride != 1 or self.inplanes != planes * block.expansion:            downsample = nn.Sequential(                nn.Conv2D(                    self.inplanes,                    planes * block.expansion,                    1,                    stride=stride,                    bias_attr=False),                norm_layer(planes * block.expansion), )        layers = []        layers.append(            block(self.inplanes, planes, stride, downsample, 1, 64,                  previous_dilation, norm_layer))        self.inplanes = planes * block.expansion        for _ in range(1, blocks):            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))        return nn.Sequential(*layers)    def forward(self, x):        x = self.conv1(x)        x = self.bn1(x)        x = self.relu(x)        x = self.maxpool(x)        x = self.layer1(x)        x = self.layer2(x)        x = self.layer3(x)        x = self.layer4(x)        if self.with_pool:            x = self.avgpool(x)        if self.num_classes > 0:            x = paddle.flatten(x, 1)            x = self.fc(x)        return xdef _resnet(arch, Block, depth, pretrained, **kwargs):    model = ResNet(Block, depth, **kwargs)    if pretrained:        assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(            arch)        weight_path = get_weights_path_from_url(model_urls[arch][0],                                                model_urls[arch][1])        param = paddle.load(weight_path)        model.set_dict(param)    return modeldef resnet18(pretrained=False, **kwargs):    """ResNet 18-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet18            # build model            model = resnet18()            # build model and load imagenet pretrained weight            # model = resnet18(pretrained=True)    """    return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)def resnet34(pretrained=False, **kwargs):    """ResNet 34-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet        Examples:        .. code-block:: python            from paddle.vision.models import resnet34            # build model            model = resnet34()            # build model and load imagenet pretrained weight            # model = resnet34(pretrained=True)    """    return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)def resnet50(pretrained=False, **kwargs):    """ResNet 50-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet50            # build model            model = resnet50()            # build model and load imagenet pretrained weight            # model = resnet50(pretrained=True)    """    return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)def resnet101(pretrained=False, **kwargs):    """ResNet 101-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet101            # build model            model = resnet101()            # build model and load imagenet pretrained weight            # model = resnet101(pretrained=True)    """    return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)def resnet152(pretrained=False, **kwargs):    """ResNet 152-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet152            # build model            model = resnet152()            # build model and load imagenet pretrained weight            # model = resnet152(pretrained=True)    """    return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)

   

接着,定义特征提取器的解码器部分:

In [ ]

class Decoder(nn.Layer):    def __init__(self, f_ch):        super(Decoder, self).__init__()        self.dr1 = Conv1x1(64, 96, norm=True, act=True)        self.dr2 = Conv1x1(128, 96, norm=True, act=True)        self.dr3 = Conv1x1(256, 96, norm=True, act=True)        self.dr4 = Conv1x1(512, 96, norm=True, act=True)        self.conv_out = nn.Sequential(            Conv3x3(                384, 256, norm=True, act=True),            nn.Dropout(0.5),            Conv1x1(                256, f_ch, norm=True, act=True))        self.init_weight()    def forward(self, feats):        f1 = self.dr1(feats[0])        f2 = self.dr2(feats[1])        f3 = self.dr3(feats[2])        f4 = self.dr4(feats[3])        f2 = F.interpolate(            f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)        f3 = F.interpolate(            f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)        f4 = F.interpolate(            f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)        x = paddle.concat([f1, f2, f3, f4], axis=1)        y = self.conv_out(x)        return y

   

需要指出的是,在本项目的实现中,若没有使用预训练权重参数,则整个特征提取器中的所有权重参数均使用PaddlePaddle默认的方式初始化。

2.3 时空注意力模块

原论文提出了基础时空注意力模块(Basic spatial-temporal Attention Module, BAM)和金字塔时空注意力模块(Pyramid spatial-temporal Attention Module, PAM)两种时空注意力模块。提出BAM和PAM是论文的主要贡献点之一,而二者的实现也是STANet复现过程中相对困难的部分。仍然遵照自顶向下的原则,首先编写build_sta_module()函数:

In [ ]

def build_sta_module(in_ch, att_type, ds):    if att_type == 'BAM':        return Attention(BAM(in_ch, ds))    elif att_type == 'PAM':        return Attention(PAM(in_ch, ds))    else:        raise ValueError

   

使用Attention类封装时空注意力模块的相关逻辑,作为适配器(adapter)。

In [ ]

class Attention(nn.Layer):    def __init__(self, att):        super(Attention, self).__init__()        self.att = att    def forward(self, x1, x2):        x = paddle.stack([x1, x2], axis=-1)        y = self.att(x)        return y[..., 0], y[..., 1]

   

因为时空注意力机制涉及到两个时相的影像特征,在这里我使用一种简单的策略,即,将两个时相的输入特征图堆叠(stack)后送入具体的注意力模块实现中,然后以元组的形式返回输出特征图在堆叠维度的拆分(split)结果。

考虑到PAM是BAM的多尺度改进版本,首先实现BAM:

In [ ]

class BAM(nn.Layer):    def __init__(self, in_ch, ds):        super(BAM, self).__init__()        self.ds = ds        self.pool = nn.AvgPool2D(self.ds)        self.val_ch = in_ch        self.key_ch = in_ch // 8        self.conv_q = Conv1x1(in_ch, self.key_ch)        self.conv_k = Conv1x1(in_ch, self.key_ch)        self.conv_v = Conv1x1(in_ch, self.val_ch)        self.softmax = nn.Softmax(axis=-1)    def forward(self, x):        x = x.flatten(-2)        x_rs = self.pool(x)        b, c, h, w = paddle.shape(x_rs)        query = self.conv_q(x_rs).reshape((b, -1, h * w)).transpose((0, 2, 1))        key = self.conv_k(x_rs).reshape((b, -1, h * w))        energy = paddle.bmm(query, key)        energy = (self.key_ch**(-0.5)) * energy        attention = self.softmax(energy)        value = self.conv_v(x_rs).reshape((b, -1, w * h))        out = paddle.bmm(value, attention.transpose((0, 2, 1)))        out = out.reshape((b, c, h, w))        out = F.interpolate(out, scale_factor=self.ds)        out = out + x        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))

   

由于BAM需要计算两个时相的特征图中所有像素的密集连接,因此其计算复杂度较高。通过设置大于1的ds参数可以对输入的特征图进行池化以减小其空间尺寸,从而达到降低计算量的目的。

相较BAM,PAM更为复杂。但通过抽象可以将其拆解为不同尺度上“子模块”结果的融合,定义如下所示:

In [ ]

class PAM(nn.Layer):    def __init__(self, in_ch, ds, scales=(1, 2, 4, 8)):        super(PAM, self).__init__()        self.stages = nn.LayerList(            [PAMBlock(                in_ch, scale=s, ds=ds) for s in scales])        self.conv_out = Conv1x1(in_ch * len(scales), in_ch, bias=False)    def forward(self, x):        x = x.flatten(-2)        res = [stage(x) for stage in self.stages]        out = self.conv_out(paddle.concat(res, axis=1))        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))

   

对于每种尺度,“子模块”PAMBlock的定义为:

In [ ]

class PAMBlock(nn.Layer):    def __init__(self, in_ch, scale=1, ds=1):        super(PAMBlock, self).__init__()        self.scale = scale        self.ds = ds        self.pool = nn.AvgPool2D(self.ds)        self.val_ch = in_ch        self.key_ch = in_ch // 8        self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)        self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)        self.conv_v = Conv1x1(in_ch, self.val_ch)    def forward(self, x):        x_rs = self.pool(x)        # Get query, key, and value.        query = self.conv_q(x_rs)        key = self.conv_k(x_rs)        value = self.conv_v(x_rs)        # Split the whole image into subregions.        b, c, h, w = x_rs.shape        query = self._split_subregions(query)        key = self._split_subregions(key)        value = self._split_subregions(value)        # Perform subregion-wise attention.        out = self._attend(query, key, value)        # Stack subregions to reconstruct the whole image.        out = self._recons_whole(out, b, c, h, w)        out = F.interpolate(out, scale_factor=self.ds)        return out    def _attend(self, query, key, value):        energy = paddle.bmm(query.transpose((0, 2, 1)),                            key)  # batch matrix multiplication        energy = (self.key_ch**(-0.5)) * energy        attention = F.softmax(energy, axis=-1)        out = paddle.bmm(value, attention.transpose((0, 2, 1)))        return out    def _split_subregions(self, x):        b, c, h, w = x.shape        assert h % self.scale == 0 and w % self.scale == 0        x = x.reshape(            (b, c, self.scale, h // self.scale, self.scale, w // self.scale))        x = x.transpose((0, 2, 4, 1, 3, 5))        x = x.reshape((b * self.scale * self.scale, c, -1))        return x    def _recons_whole(self, x, b, c, h, w):        x = x.reshape(            (b, self.scale, self.scale, c, h // self.scale, w // self.scale))        x = x.transpose((0, 3, 1, 4, 2, 5)).reshape((b, c, h, w))        return x

   

PAMBlock首先将输入特征在空间维度上划分为一系列子区域,然后在每个子区域执行和BAM相同的操作,最后将得到的结果重建为和输入具有相同形状的张量。在具体代码编写方面,本项目使用向量化的方式实现上述操作,具有较高的并行度。

3 模型训练与测试


本节将对上节中实现的STANet模型进行训练与测试,以验证复现效果。

3.1 准备工作

In [ ]

# 下载安装PaddleRS(个人创建的Gitee镜像版本)!git clone https://gitee.com/bobholamovic/PaddleRS!pip install -r PaddleRS/requirements.txt!pip install -e PaddleRS/import syssys.path.append("/home/aistudio/PaddleRS")

   In [ ]

# 解压数据集!unzip -oq data/data136610/LEVIR-CD.zip -d data/data136610/

   In [ ]

# 将每幅影像切成256x256的不重叠小块,这一做法与STANet原论文一致# 关于该步骤的具体说明可以参看crop_images.py脚本中的注释!python crop_images.py 'data/data136610/LEVIR-CD/' 'data/data136610/LEVIR-CD_crop/' 256 256

   In [ ]

# 生成PaddleRS训练所需的file_list!IN_DIR='data/data136610/LEVIR-CD_crop/train' && paste <(ls ${IN_DIR}/A/*/*.png) <(ls ${IN_DIR}/B/*/*.png)  data/data136610/train_list.txt!IN_DIR='data/data136610/LEVIR-CD_crop/val' && paste <(ls ${IN_DIR}/A/*/*.png) <(ls ${IN_DIR}/B/*/*.png)  data/data136610/val_list.txt!IN_DIR='data/data136610/LEVIR-CD_crop/test' && paste <(ls ${IN_DIR}/A/*/*.png) <(ls ${IN_DIR}/B/*/*.png)  data/data136610/test_list.txt

   

3.2 模型训练

In [ ]

import paddlers as pdrsimport paddlers.transforms as T# 定义训练和验证时的transformstrain_transforms = T.Compose([    T.Resize(target_size=256),    T.RandomHorizontalFlip(),    T.Normalize(        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),])eval_transforms = T.Compose([    T.Resize(target_size=256),    T.Normalize(        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),])# 定义训练和验证所用的数据集train_dataset = pdrs.datasets.CDDataset(    data_dir='', # 由于file_list中存储的是完整的绝对路径,此处只需要给定空字符串即可    file_list='/home/aistudio/data/data136610/train_list.txt',    transforms=train_transforms,    num_workers=2,    binarize_labels=True,    shuffle=True)val_dataset = pdrs.datasets.CDDataset(    data_dir='',    file_list='/home/aistudio/data/data136610/val_list.txt',    transforms=eval_transforms,    num_workers=2,    binarize_labels=True,    shuffle=False)

   In [ ]

# 构造STANet模型,使用PAM作为时空注意力模块model = pdrs.tasks.STANet(in_channels=3, num_classes=2, att_type='PAM')# 制定定步长学习率衰减策略lr_scheduler = paddle.optimizer.lr.StepDecay(    0.001,    step_size=5000,    # 学习率衰减系数,这里指定每次减半    gamma=0.5)# 构造AdamW优化器optimizer = paddle.optimizer.AdamW(    learning_rate=lr_scheduler,    parameters=model.net.parameters())

   In [ ]

# 执行模型训练model.train(    num_epochs=100,    train_dataset=train_dataset,    train_batch_size=8,    eval_dataset=val_dataset,    optimizer=optimizer,    save_interval_epochs=10,    # 每多少次迭代记录一次日志    log_interval_steps=200,    save_dir='exp/',    # 是否使用early stopping策略,当精度不再改善时提前终止训练    early_stop=False,    # 是否启用VisualDL日志功能    use_vdl=True,    # 指定从某个检查点继续训练    resume_checkpoint=None)

   

3.3 模型测试

In [ ]

# 计算测试集上的定量指标test_dataset = pdrs.datasets.CDDataset(    data_dir='',    file_list='/home/aistudio/data/data136610/test_list.txt',    transforms=eval_transforms,    num_workers=2,    binarize_labels=True,    shuffle=False)model = pdrs.tasks.load_model('exp/best_model')res = model.evaluate(test_dataset)print(f"IoU={res['iou']:.4f}, F1-score={res['f1']:.4f}, OA={res['oacc']:.4f}, Kappa={res['kappa']:.4f}")

   In [21]

# 可视化预测结果%matplotlib inlineimport numpy as npfrom PIL import Imagefrom matplotlib import pyplot as pltT1_PATH = "data/data136610/LEVIR-CD_crop/test/A/test_1/test_1_3.png"T2_PATH = "data/data136610/LEVIR-CD_crop/test/B/test_1/test_1_3.png"GT_PATH = "data/data136610/LEVIR-CD_crop/test/label/test_1/test_1_3.png"pred = model.predict(    (T1_PATH, T2_PATH),     transforms=eval_transforms)# 从左到右依次显示:第一时相影像、第二时相影像、STANet预测结果以及真值标签plt.figure(constrained_layout=True)plt.subplot(141)plt.imshow(Image.open(T1_PATH))plt.gca().set_axis_off()plt.title("Image1")plt.subplot(142)plt.imshow(Image.open(T2_PATH))plt.gca().set_axis_off()plt.title("Image2")plt.subplot(143)plt.imshow((pred['label_map']*255).astype('uint8'))plt.gca().set_axis_off()plt.title("Pred")plt.subplot(144)plt.imshow((np.asarray(Image.open(GT_PATH))*255).astype('uint8'))plt.gca().set_axis_off()plt.title("GT")plt.show()

       

               

以上就是【PPSIG】基于PaddlePaddle复现STANet的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/46118.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月7日 16:45:49
怎样调整VSCode的字体和字号?
下一篇 2025年11月7日 16:45:52

相关推荐

  • 开源免费PHP工具 PHP开发效率提升利器

    推荐开源免费PHP开发工具以提升效率:VS Code、Sublime Text轻量高效,PhpStorm专业强大;调试用Xdebug、Kint、Ray;依赖管理选Composer;代码质量工具包括PHPStan、Psalm、PHP_CodeSniffer;数据库管理可用%ignore_a_1%MyA…

    2026年5月10日
    000
  • Matplotlib 地图中多类型图例的创建与优化

    Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化

    本教程旨在解决matplotlib地图可视化中,如何在一个图例中同时展示颜色块(如区域分类)和自定义标记(如特定兴趣点)的问题。文章详细介绍了当传统`patch`对象无法正确显示标记时,如何利用`matplotlib.lines.line2d`创建标记图例句柄,并将其与颜色块图例句柄合并,从而生成一…

    2026年5月10日 用户投稿
    100
  • Golang JSON序列化:控制敏感字段暴露的最佳实践

    本教程探讨golang中如何高效控制结构体字段在json序列化时的可见性。当需要将包含敏感信息的结构体数组转换为json响应时,通过利用`encoding/json`包提供的结构体标签,特别是`json:”-“`,可以轻松实现对特定字段的忽略,从而避免敏感数据泄露,确保api…

    2026年5月10日
    000
  • 利用海象运算符简化条件赋值:Python教程与最佳实践

    本文旨在探讨Python中海象运算符(:=)在条件赋值场景下的应用。通过对比传统if/else语句与海象运算符,以及条件表达式,分析海象运算符在简化代码、提高可读性方面的优势与局限性。并通过具体示例,展示如何在列表推导式等场景下合理使用海象运算符,同时强调其潜在的复杂性及替代方案,帮助开发者更好地掌…

    2026年5月10日
    100
  • 怎么在PHP代码中实现图片上传功能_PHP图片上传功能实现与安全处理教程

    首先创建含enctype的HTML表单,再用PHP接收文件,检查目录、移动临时文件,验证类型与大小,生成唯一文件名,并调整php.ini限制以确保上传成功。 如果您尝试在PHP项目中添加图片上传功能,但服务器无法正确接收或保存文件,则可能是由于表单配置、文件处理逻辑或安全限制的问题。以下是实现该功能…

    2026年5月10日
    100
  • 比特币新手教程 比特币交易平台有哪些

    比特币是一种去中心化的数字货币,基于区块链技术实现点对点交易,具有匿名性、有限发行和不可篡改等特点;新手可通过交易所购买,P2P交易获得比特币,常用平台包括Binance、OKX和Huobi;交易流程包括注册账户、实名认证、绑定支付方式、充值法币并下单购买,可选择市价单或限价单;比特币存储方式有交易…

    2026年5月10日
    000
  • c++中的SFINAE技术是什么_c++模板编程中的SFINAE原理与应用

    SFINAE 是“替换失败不是错误”的原则,指模板实例化时若参数替换导致错误,只要存在其他合法候选,编译器不报错而是继续重载决议。它用于条件启用模板、类型检测等场景,如通过 decltype 或 enable_if 控制函数重载,实现类型特征判断。尽管 C++20 引入 Concepts 简化了部分…

    2026年5月10日
    000
  • Go语言mgo查询构建:深入理解bson.M与日期范围查询的正确实践

    本文旨在解决go语言mgo库中构建复杂查询时,特别是涉及嵌套`bson.m`和日期范围筛选的常见错误。我们将深入剖析`bson.m`的类型特性,解释为何直接索引`interface{}`会导致“invalid operation”错误,并提供一种推荐的、结构清晰的代码重构方案,以确保查询条件能够正确…

    2026年5月10日
    100
  • RichHandler与Rich Progress集成:解决显示冲突的教程

    在使用rich库的`richhandler`进行日志输出并同时使用`progress`组件时,可能会遇到显示错乱或溢出问题。这通常是由于为`richhandler`和`progress`分别创建了独立的`console`实例导致的。解决方案是确保日志处理器和进度条组件共享同一个`console`实例…

    2026年5月10日
    000
  • Golang goroutine与channel调试技巧

    使用go run -race检测数据竞争,结合runtime.NumGoroutine监控协程数量,通过pprof分析阻塞调用栈,利用select超时避免永久阻塞,有效排查goroutine泄漏、死锁和数据竞争问题。 Go语言的goroutine和channel是并发编程的核心,但它们也带来了调试上…

    2026年5月10日
    000
  • 使用 Jupyter Notebook 进行探索性数据分析

    Jupyter Notebook通过单元格实现代码与Markdown结合,支持数据导入(pandas)、清洗(fillna)、探索(matplotlib/seaborn可视化)、统计分析(describe/corr)和特征工程,便于记录与分享分析过程。 Jupyter Notebook 是进行探索性…

    2026年5月10日
    000
  • 《魔兽世界》将于6月11日开启国服回归技术测试

    《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试

    《%ign%ignore_a_1%re_a_1%》官方宣布,将于6月11日开启国服回归技术测试,时间为7天,并称可以在6月内正式开服,玩家们可以访问官网下载战网客户端并预下载“巫妖王之怒”客户端,技术测试详情见下图。 WordAi WordAI是一个AI驱动的内容重写平台 53 查看详情 以上就是《…

    2026年5月10日 用户投稿
    200
  • 如何在HTML中插入表单元素_HTML表单控件与输入类型使用指南

    HTML表单通过标签构建,包含action和method属性定义数据提交目标与方式,常用input类型如text、password、email等适配不同输入需求,配合label、required、placeholder提升可用性,结合textarea、select、button等控件实现完整交互,是…

    2026年5月10日
    100
  • 创建指定大小并填充特定数据的Golang文件教程

    本文将介绍如何使用Golang创建一个指定大小的文件,并用特定数据填充它。我们将使用 `os` 包提供的函数来创建和截断文件,从而实现快速生成大文件的目的。示例代码展示了如何创建一个10MB的文件,并将其填充为全零数据。掌握这些方法,可以方便地在例如日志系统或磁盘队列等场景中,预先创建测试文件或初始…

    2026年5月10日
    000
  • Python命令怎样使用profile分析脚本性能 Python命令性能分析的基础教程

    使用Python的cProfile模块分析脚本性能最直接的方式是通过命令行执行python -m cProfile your_script.py,它会输出每个函数的调用次数、总耗时、累积耗时等关键指标,帮助定位性能瓶颈;为进一步分析,可将结果保存为文件python -m cProfile -o ou…

    2026年5月10日
    000
  • 使用 WebCodecs VideoDecoder 实现精确逐帧回退

    本文档旨在解决在使用 WebCodecs VideoDecoder 进行视频解码时,实现精确逐帧回退的问题。通过比较帧的时间戳与目标帧的时间戳,可以避免渲染中间帧,从而提高用户体验。本文将提供详细的解决方案和示例代码,帮助开发者实现精确的视频帧控制。 在使用 WebCodecs VideoDecod…

    2026年5月10日
    000
  • 如何插入查询结果数据_SQL插入Select查询结果方法

    如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法

    使用INSERT INTO…SELECT语句可高效插入数据,通过NOT EXISTS、LEFT JOIN、MERGE语句或唯一约束避免重复;表结构不一致时可通过别名、类型转换、默认值或计算字段处理;结合存储过程可提升可维护性,支持参数化与动态SQL。 将查询结果数据插入到另一个表中,可以…

    2026年5月10日 用户投稿
    000
  • Debian Copilot的社区活跃度如何

    debian copilot是codeberg社区维护的ai助手,旨在为debian用户提供服务。尽管搜索结果中没有直接提供关于debian copilot社区支持活跃度的具体数据,但我们可以通过debian社区的整体活跃度和特点来推断其活跃性。 Debian社区的一般情况: Debian拥有详尽的…

    2026年5月10日
    000
  • Discord.py 交互按钮超时与持久化解决方案

    本教程旨在解决Discord.py中交互按钮在一段时间后出现“This Interaction Failed”错误的问题。我们将深入探讨视图(View)的超时机制,并提供通过正确设置timeout参数以及利用bot.add_view()方法实现按钮持久化的具体方案,确保您的机器人交互功能稳定可靠,即…

    2026年5月10日
    000
  • Python递归函数追踪与性能考量:以序列打印为例

    本文深入探讨了Python中一种递归打印序列元素的方法,并着重演示了如何通过引入缩进参数来有效追踪递归函数的执行流程和参数变化。通过实际代码示例,文章揭示了递归调用可能带来的潜在性能开销,特别是对调用栈空间的需求,以及Python默认递归深度限制可能导致的错误,为读者提供了理解和优化递归算法的实用见…

    2026年5月10日
    000

发表回复

登录后才能评论
关注微信