【PPSIG】基于PaddlePaddle复现STANet

本文介绍基于PaddlePaddle复现STANet的过程。STANet是遥感影像变化检测算法,引入时空注意力机制与多尺度子区域方案。复现忠实原文,与官方实现存在两点差异,含特征提取器、时空注意力模块等结构,在LEVIR-CD数据集上完成训练测试,验证了复现效果。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【ppsig】基于paddlepaddle复现stanet - 创想鸟

基于PaddlePaddle复现STANet

深度学习变化检测算法STANet的PaddlePaddle复现。已集成至PaddleRS项目。

1 简介


1.1 遥感影像变化检测

我们生活在一个变化的世界。无论是地表生态系统还是人类社会活动,都是不断演进和动态发展的过程。实时,准确地获取地表的变化信息,对于研究自然环境、人类活动以及二者之间的相互作用具有十分重要的意义。得益于卫星传感器技术的发展,人类可以通过遥感手段对地表实施大范围、长时间的监测。而遥感影像的变化检测,正是通过对同一地区不同时间的重复观测来分析区域内地物状态变化的技术。

如下图所示,变化检测任务的输入一般是多时相遥感影像。对于二值变化检测任务而言,输出是二值变化图,标示了变化发生的位置。例如图中用白色标记发生变化的像素,黑色标记未发生变化的像素。

【PPSIG】基于PaddlePaddle复现STANet - 创想鸟        

变化检测在实际生产生活中已经得到了十分广泛的应用,例如用于生态系统监测、城市发展规划、农业用地研究、受灾情况评估等等。

“The only thing that never changes is that everything changes.” – Louis L’Amour

1.2 STANet算法

STANet是Hao Chen和Zhenwei Shi提出的一种基于深度学习的遥感影像变化检测方法,论文于2020年4月发表在《Remote Sensing》期刊。STANet创新性地将时空注意力(spatial-temporal attention)机制引入变化检测领域,通过计算两幅影像的特征图中任意两个像素间的相关关系,构建变化检测任务所需要的特征时空依赖,从而实现对变化区域更加精准的识别。另一方面,考虑到遥感影像中常见的地物多尺度问题,作者提出一种将影像分割为多尺度子区域(multi-scale subregion)、并在每个子区域内进行自注意力计算的方案,以便捕捉不同尺度的时空依赖信息。STANet的整体框架如下图所示:

【PPSIG】基于PaddlePaddle复现STANet - 创想鸟        

Chen H, Shi Z. A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection. Remote Sensing. 2020; 12(10):1662. https://doi.org/10.3390/rs12101662

1.3 本项目内容

本项目遵照原论文,使用PaddlePaddle框架对STANet进行复现,并在LEVIR-CD数据集上对复现算法的效果进行评估。

2 STANet的PaddlePaddle实现


本节将基于PaddlePaddle框架,根据原论文实现STANet模型。复现过程总体上忠实于原文,但需要指出的是,原作者的官方实现和本项目中的实现有两点主要不同:

官方实现与本项目实现对ResNet backbone中最后一个stage的卷积空洞率的设置不同;官方实现的STANet采用Metric Module配合度量学习(metric learning)的方式训练,而本项目实现的STANet在最终层使用分割任务风格的dense classification head,使用交叉熵损失训练。

2.1 模型整体结构

遵循自顶向下的设计范式,首先对STANet模型的整体结构进行规划。在此之前,出于复用性方面考虑,对PaddlePaddle的一些原生模块进行封装:

In [ ]

import paddleimport paddle.nn as nnimport paddle.nn.functional as Fdef get_norm_layer():    # TODO: select appropriate norm layer.    return nn.BatchNorm2Ddef make_norm(*args, **kwargs):    norm_layer = get_norm_layer()    return norm_layer(*args, **kwargs)class BasicConv(nn.Layer):    def __init__(self,                 in_ch,                 out_ch,                 kernel_size,                 pad_mode='constant',                 bias='auto',                 norm=False,                 act=False,                 **kwargs):        super(BasicConv, self).__init__()        seq = []        if kernel_size >= 2:            seq.append(nn.Pad2D(kernel_size // 2, mode=pad_mode))        seq.append(            nn.Conv2D(                in_ch,                out_ch,                kernel_size,                stride=1,                padding=0,                bias_attr=(False if norm else None) if bias == 'auto' else bias,                **kwargs))        if norm:            if norm is True:                norm = make_norm(out_ch)            seq.append(norm)        if act:            if act is True:                act = make_act()            seq.append(act)        self.seq = nn.Sequential(*seq)    def forward(self, x):        return self.seq(x)class Conv1x1(BasicConv):    def __init__(self,                 in_ch,                 out_ch,                 pad_mode='constant',                 bias='auto',                 norm=False,                 act=False,                 **kwargs):        super(Conv1x1, self).__init__(            in_ch,            out_ch,            1,            pad_mode=pad_mode,            bias=bias,            norm=norm,            act=act,            **kwargs)class Conv3x3(BasicConv):    def __init__(self,                 in_ch,                 out_ch,                 pad_mode='constant',                 bias='auto',                 norm=False,                 act=False,                 **kwargs):        super(Conv3x3, self).__init__(            in_ch,            out_ch,            3,            pad_mode=pad_mode,            bias=bias,            norm=norm,            act=act,            **kwargs)class Identity(nn.Layer):    """A placeholder identity operator that accepts exactly one argument."""    def __init__(self, *args, **kwargs):        super(Identity, self).__init__()    def forward(self, x):        return x

   

接着,利用封装好的组件“搭积木”,定义模型:

In [ ]

class STANet(nn.Layer):    """    The STANet implementation based on PaddlePaddle.    The original article refers to        H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection"        (https://www.mdpi.com/2072-4292/12/10/1662).    Note that this implementation differs from the original work in two aspects:    1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.    2. A classification head is used in place of the original metric learning-based head to stablize the training process.    Args:        in_channels (int): The number of bands of the input images.        num_classes (int): The number of target classes.        att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.        ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values             greater than 1, the input features will first be processed by an average pooling layer with the kernel size of             `ds_factor`, before being used to calculate the attention scores. Default: 1.    Raises:        ValueError: When `att_type` has an illeagal value (unsupported attention type).    """    def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):        super(STANet, self).__init__()        WIDTH = 64        self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)        self.attend = build_sta_module(            in_ch=WIDTH, att_type=att_type, ds=ds_factor)        self.conv_out = nn.Sequential(            Conv3x3(                WIDTH, WIDTH, norm=True, act=True),            Conv3x3(WIDTH, num_classes))        self.init_weight()    def forward(self, t1, t2):        f1 = self.extract(t1)        f2 = self.extract(t2)        f1, f2 = self.attend(f1, f2)        y = paddle.abs(f1 - f2)        y = F.interpolate(            y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)        pred = self.conv_out(y)        return [pred]    def init_weight(self):        # Do nothing here as the encoder and decoder weights have already been initialized.        # Note however that currently self.attend and self.conv_out use the default initilization method.        pass

   

其中,build_feat_extractor()函数负责定义特征提取器(Feature Extractor),而build_sta_module()函数则负责定义时空注意力模块(Spatial-Temporal Attention Module)。

2.2 特征提取器

根据原文,将特征提取器分为骨干网络(backbone,一个修改版的ResNet-18)和解码器(decoder,将骨干网络提取到的不同层级特征上采样并融合)。编写build_feat_extractor()如下:

In [ ]

def build_feat_extractor(in_ch, width):    return nn.Sequential(Backbone(in_ch, 'resnet18'), Decoder(width))

   

其中,Backbone是对骨干网络部分的封装(为了使代码更具可扩展性,不止支持ResNet-18一种结构),Decoder则是对解码器部分的封装,用nn.Sequential将二者级联。

骨干网络部分的实现如下:

In [ ]

class Backbone(nn.Layer):    def __init__(self, in_ch, arch, pretrained=True, strides=(2, 1, 2, 2, 2)):        super(Backbone, self).__init__()        if arch == 'resnet18':            self.resnet = resnet18(                pretrained=pretrained,                strides=strides,                norm_layer=get_norm_layer())        elif arch == 'resnet34':            self.resnet = resnet34(                pretrained=pretrained,                strides=strides,                norm_layer=get_norm_layer())        elif arch == 'resnet50':            self.resnet = resnet50(                pretrained=pretrained,                strides=strides,                norm_layer=get_norm_layer())        else:            raise ValueError        self._trim_resnet()        if in_ch != 3:            self.resnet.conv1 = nn.Conv2D(                in_ch,                64,                kernel_size=7,                stride=strides[0],                padding=3,                bias_attr=False)    def forward(self, x):        x = self.resnet.conv1(x)        x = self.resnet.bn1(x)        x = self.resnet.relu(x)        x = self.resnet.maxpool(x)        x1 = self.resnet.layer1(x)        x2 = self.resnet.layer2(x1)        x3 = self.resnet.layer3(x2)        x4 = self.resnet.layer4(x3)        return x1, x2, x3, x4    def _trim_resnet(self):        self.resnet.avgpool = Identity()        self.resnet.fc = Identity()

   

其中,_trim_resnet()方法用于去除没有用到的全局平均池化和全连接层,以避免不必要的开销和可能导致的参数量统计错误。

对于ResNet的具体实现,为了避免重复造轮子,本项目直接在PaddlePaddle内置的ResNet基础上修改。主要修改的地方是为ResNet类的构造函数添加strides和norm_layer两个参数,前者控制stem及每个stage的output stride,后者控制使用的归一化层的类型。尽管本项目中并没有涉及到对归一化层类型的改变(仍然使用nn.BatchNorm2D),但应用这一设计有利于日后进一步的功能扩展。

In [ ]

# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py## Original head information# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at##     http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.from __future__ import divisionfrom __future__ import print_functionimport paddleimport paddle.nn as nnfrom paddle.utils.download import get_weights_path_from_url__all__ = []model_urls = {    'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',                 'cf548f46534aa3560945be4b95cd11c4'),    'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',                 '8d2275cf8706028345f78ac0e1d31969'),    'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',                 'ca6f485ee1ab0492d38f323885b0ad80'),    'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',                  '02f35f034ca3858e1e54d4036443c92d'),    'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',                  '7ad16a2f1e7333859ff986138630fd7a'),}class BasicBlock(nn.Layer):    expansion = 1    def __init__(self,                 inplanes,                 planes,                 stride=1,                 downsample=None,                 groups=1,                 base_width=64,                 dilation=1,                 norm_layer=None):        super(BasicBlock, self).__init__()        if norm_layer is None:            norm_layer = nn.BatchNorm2D        if dilation > 1:            raise NotImplementedError(                "Dilation > 1 not supported in BasicBlock")        self.conv1 = nn.Conv2D(            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)        self.bn1 = norm_layer(planes)        self.relu = nn.ReLU()        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)        self.bn2 = norm_layer(planes)        self.downsample = downsample        self.stride = stride    def forward(self, x):        identity = x        out = self.conv1(x)        out = self.bn1(out)        out = self.relu(out)        out = self.conv2(out)        out = self.bn2(out)        if self.downsample is not None:            identity = self.downsample(x)        out += identity        out = self.relu(out)        return outclass BottleneckBlock(nn.Layer):    expansion = 4    def __init__(self,                 inplanes,                 planes,                 stride=1,                 downsample=None,                 groups=1,                 base_width=64,                 dilation=1,                 norm_layer=None):        super(BottleneckBlock, self).__init__()        if norm_layer is None:            norm_layer = nn.BatchNorm2D        width = int(planes * (base_width / 64.)) * groups        self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)        self.bn1 = norm_layer(width)        self.conv2 = nn.Conv2D(            width,            width,            3,            padding=dilation,            stride=stride,            groups=groups,            dilation=dilation,            bias_attr=False)        self.bn2 = norm_layer(width)        self.conv3 = nn.Conv2D(            width, planes * self.expansion, 1, bias_attr=False)        self.bn3 = norm_layer(planes * self.expansion)        self.relu = nn.ReLU()        self.downsample = downsample        self.stride = stride    def forward(self, x):        identity = x        out = self.conv1(x)        out = self.bn1(out)        out = self.relu(out)        out = self.conv2(out)        out = self.bn2(out)        out = self.relu(out)        out = self.conv3(out)        out = self.bn3(out)        if self.downsample is not None:            identity = self.downsample(x)        out += identity        out = self.relu(out)        return outclass ResNet(nn.Layer):    """ResNet model from    `"Deep Residual Learning for Image Recognition" `_    Args:        Block (BasicBlock|BottleneckBlock): block module of model.        depth (int): layers of resnet, default: 50.        num_classes (int): output dim of last fc layer. If num_classes  0:            self.fc = nn.Linear(512 * block.expansion, num_classes)    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):        norm_layer = self._norm_layer        downsample = None        previous_dilation = self.dilation        if dilate:            self.dilation *= stride            stride = 1        if stride != 1 or self.inplanes != planes * block.expansion:            downsample = nn.Sequential(                nn.Conv2D(                    self.inplanes,                    planes * block.expansion,                    1,                    stride=stride,                    bias_attr=False),                norm_layer(planes * block.expansion), )        layers = []        layers.append(            block(self.inplanes, planes, stride, downsample, 1, 64,                  previous_dilation, norm_layer))        self.inplanes = planes * block.expansion        for _ in range(1, blocks):            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))        return nn.Sequential(*layers)    def forward(self, x):        x = self.conv1(x)        x = self.bn1(x)        x = self.relu(x)        x = self.maxpool(x)        x = self.layer1(x)        x = self.layer2(x)        x = self.layer3(x)        x = self.layer4(x)        if self.with_pool:            x = self.avgpool(x)        if self.num_classes > 0:            x = paddle.flatten(x, 1)            x = self.fc(x)        return xdef _resnet(arch, Block, depth, pretrained, **kwargs):    model = ResNet(Block, depth, **kwargs)    if pretrained:        assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(            arch)        weight_path = get_weights_path_from_url(model_urls[arch][0],                                                model_urls[arch][1])        param = paddle.load(weight_path)        model.set_dict(param)    return modeldef resnet18(pretrained=False, **kwargs):    """ResNet 18-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet18            # build model            model = resnet18()            # build model and load imagenet pretrained weight            # model = resnet18(pretrained=True)    """    return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)def resnet34(pretrained=False, **kwargs):    """ResNet 34-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet        Examples:        .. code-block:: python            from paddle.vision.models import resnet34            # build model            model = resnet34()            # build model and load imagenet pretrained weight            # model = resnet34(pretrained=True)    """    return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)def resnet50(pretrained=False, **kwargs):    """ResNet 50-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet50            # build model            model = resnet50()            # build model and load imagenet pretrained weight            # model = resnet50(pretrained=True)    """    return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)def resnet101(pretrained=False, **kwargs):    """ResNet 101-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet101            # build model            model = resnet101()            # build model and load imagenet pretrained weight            # model = resnet101(pretrained=True)    """    return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)def resnet152(pretrained=False, **kwargs):    """ResNet 152-layer model        Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    Examples:        .. code-block:: python            from paddle.vision.models import resnet152            # build model            model = resnet152()            # build model and load imagenet pretrained weight            # model = resnet152(pretrained=True)    """    return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)

   

接着,定义特征提取器的解码器部分:

In [ ]

class Decoder(nn.Layer):    def __init__(self, f_ch):        super(Decoder, self).__init__()        self.dr1 = Conv1x1(64, 96, norm=True, act=True)        self.dr2 = Conv1x1(128, 96, norm=True, act=True)        self.dr3 = Conv1x1(256, 96, norm=True, act=True)        self.dr4 = Conv1x1(512, 96, norm=True, act=True)        self.conv_out = nn.Sequential(            Conv3x3(                384, 256, norm=True, act=True),            nn.Dropout(0.5),            Conv1x1(                256, f_ch, norm=True, act=True))        self.init_weight()    def forward(self, feats):        f1 = self.dr1(feats[0])        f2 = self.dr2(feats[1])        f3 = self.dr3(feats[2])        f4 = self.dr4(feats[3])        f2 = F.interpolate(            f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)        f3 = F.interpolate(            f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)        f4 = F.interpolate(            f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)        x = paddle.concat([f1, f2, f3, f4], axis=1)        y = self.conv_out(x)        return y

   

需要指出的是,在本项目的实现中,若没有使用预训练权重参数,则整个特征提取器中的所有权重参数均使用PaddlePaddle默认的方式初始化。

2.3 时空注意力模块

原论文提出了基础时空注意力模块(Basic spatial-temporal Attention Module, BAM)和金字塔时空注意力模块(Pyramid spatial-temporal Attention Module, PAM)两种时空注意力模块。提出BAM和PAM是论文的主要贡献点之一,而二者的实现也是STANet复现过程中相对困难的部分。仍然遵照自顶向下的原则,首先编写build_sta_module()函数:

In [ ]

def build_sta_module(in_ch, att_type, ds):    if att_type == 'BAM':        return Attention(BAM(in_ch, ds))    elif att_type == 'PAM':        return Attention(PAM(in_ch, ds))    else:        raise ValueError

   

使用Attention类封装时空注意力模块的相关逻辑,作为适配器(adapter)。

In [ ]

class Attention(nn.Layer):    def __init__(self, att):        super(Attention, self).__init__()        self.att = att    def forward(self, x1, x2):        x = paddle.stack([x1, x2], axis=-1)        y = self.att(x)        return y[..., 0], y[..., 1]

   

因为时空注意力机制涉及到两个时相的影像特征,在这里我使用一种简单的策略,即,将两个时相的输入特征图堆叠(stack)后送入具体的注意力模块实现中,然后以元组的形式返回输出特征图在堆叠维度的拆分(split)结果。

考虑到PAM是BAM的多尺度改进版本,首先实现BAM:

In [ ]

class BAM(nn.Layer):    def __init__(self, in_ch, ds):        super(BAM, self).__init__()        self.ds = ds        self.pool = nn.AvgPool2D(self.ds)        self.val_ch = in_ch        self.key_ch = in_ch // 8        self.conv_q = Conv1x1(in_ch, self.key_ch)        self.conv_k = Conv1x1(in_ch, self.key_ch)        self.conv_v = Conv1x1(in_ch, self.val_ch)        self.softmax = nn.Softmax(axis=-1)    def forward(self, x):        x = x.flatten(-2)        x_rs = self.pool(x)        b, c, h, w = paddle.shape(x_rs)        query = self.conv_q(x_rs).reshape((b, -1, h * w)).transpose((0, 2, 1))        key = self.conv_k(x_rs).reshape((b, -1, h * w))        energy = paddle.bmm(query, key)        energy = (self.key_ch**(-0.5)) * energy        attention = self.softmax(energy)        value = self.conv_v(x_rs).reshape((b, -1, w * h))        out = paddle.bmm(value, attention.transpose((0, 2, 1)))        out = out.reshape((b, c, h, w))        out = F.interpolate(out, scale_factor=self.ds)        out = out + x        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))

   

由于BAM需要计算两个时相的特征图中所有像素的密集连接,因此其计算复杂度较高。通过设置大于1的ds参数可以对输入的特征图进行池化以减小其空间尺寸,从而达到降低计算量的目的。

相较BAM,PAM更为复杂。但通过抽象可以将其拆解为不同尺度上“子模块”结果的融合,定义如下所示:

In [ ]

class PAM(nn.Layer):    def __init__(self, in_ch, ds, scales=(1, 2, 4, 8)):        super(PAM, self).__init__()        self.stages = nn.LayerList(            [PAMBlock(                in_ch, scale=s, ds=ds) for s in scales])        self.conv_out = Conv1x1(in_ch * len(scales), in_ch, bias=False)    def forward(self, x):        x = x.flatten(-2)        res = [stage(x) for stage in self.stages]        out = self.conv_out(paddle.concat(res, axis=1))        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))

   

对于每种尺度,“子模块”PAMBlock的定义为:

In [ ]

class PAMBlock(nn.Layer):    def __init__(self, in_ch, scale=1, ds=1):        super(PAMBlock, self).__init__()        self.scale = scale        self.ds = ds        self.pool = nn.AvgPool2D(self.ds)        self.val_ch = in_ch        self.key_ch = in_ch // 8        self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)        self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)        self.conv_v = Conv1x1(in_ch, self.val_ch)    def forward(self, x):        x_rs = self.pool(x)        # Get query, key, and value.        query = self.conv_q(x_rs)        key = self.conv_k(x_rs)        value = self.conv_v(x_rs)        # Split the whole image into subregions.        b, c, h, w = x_rs.shape        query = self._split_subregions(query)        key = self._split_subregions(key)        value = self._split_subregions(value)        # Perform subregion-wise attention.        out = self._attend(query, key, value)        # Stack subregions to reconstruct the whole image.        out = self._recons_whole(out, b, c, h, w)        out = F.interpolate(out, scale_factor=self.ds)        return out    def _attend(self, query, key, value):        energy = paddle.bmm(query.transpose((0, 2, 1)),                            key)  # batch matrix multiplication        energy = (self.key_ch**(-0.5)) * energy        attention = F.softmax(energy, axis=-1)        out = paddle.bmm(value, attention.transpose((0, 2, 1)))        return out    def _split_subregions(self, x):        b, c, h, w = x.shape        assert h % self.scale == 0 and w % self.scale == 0        x = x.reshape(            (b, c, self.scale, h // self.scale, self.scale, w // self.scale))        x = x.transpose((0, 2, 4, 1, 3, 5))        x = x.reshape((b * self.scale * self.scale, c, -1))        return x    def _recons_whole(self, x, b, c, h, w):        x = x.reshape(            (b, self.scale, self.scale, c, h // self.scale, w // self.scale))        x = x.transpose((0, 3, 1, 4, 2, 5)).reshape((b, c, h, w))        return x

   

PAMBlock首先将输入特征在空间维度上划分为一系列子区域,然后在每个子区域执行和BAM相同的操作,最后将得到的结果重建为和输入具有相同形状的张量。在具体代码编写方面,本项目使用向量化的方式实现上述操作,具有较高的并行度。

3 模型训练与测试


本节将对上节中实现的STANet模型进行训练与测试,以验证复现效果。

3.1 准备工作

In [ ]

# 下载安装PaddleRS(个人创建的Gitee镜像版本)!git clone https://gitee.com/bobholamovic/PaddleRS!pip install -r PaddleRS/requirements.txt!pip install -e PaddleRS/import syssys.path.append("/home/aistudio/PaddleRS")

   In [ ]

# 解压数据集!unzip -oq data/data136610/LEVIR-CD.zip -d data/data136610/

   In [ ]

# 将每幅影像切成256x256的不重叠小块,这一做法与STANet原论文一致# 关于该步骤的具体说明可以参看crop_images.py脚本中的注释!python crop_images.py 'data/data136610/LEVIR-CD/' 'data/data136610/LEVIR-CD_crop/' 256 256

   In [ ]

# 生成PaddleRS训练所需的file_list!IN_DIR='data/data136610/LEVIR-CD_crop/train' && paste <(ls ${IN_DIR}/A/*/*.png) <(ls ${IN_DIR}/B/*/*.png)  data/data136610/train_list.txt!IN_DIR='data/data136610/LEVIR-CD_crop/val' && paste <(ls ${IN_DIR}/A/*/*.png) <(ls ${IN_DIR}/B/*/*.png)  data/data136610/val_list.txt!IN_DIR='data/data136610/LEVIR-CD_crop/test' && paste <(ls ${IN_DIR}/A/*/*.png) <(ls ${IN_DIR}/B/*/*.png)  data/data136610/test_list.txt

   

3.2 模型训练

In [ ]

import paddlers as pdrsimport paddlers.transforms as T# 定义训练和验证时的transformstrain_transforms = T.Compose([    T.Resize(target_size=256),    T.RandomHorizontalFlip(),    T.Normalize(        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),])eval_transforms = T.Compose([    T.Resize(target_size=256),    T.Normalize(        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),])# 定义训练和验证所用的数据集train_dataset = pdrs.datasets.CDDataset(    data_dir='', # 由于file_list中存储的是完整的绝对路径,此处只需要给定空字符串即可    file_list='/home/aistudio/data/data136610/train_list.txt',    transforms=train_transforms,    num_workers=2,    binarize_labels=True,    shuffle=True)val_dataset = pdrs.datasets.CDDataset(    data_dir='',    file_list='/home/aistudio/data/data136610/val_list.txt',    transforms=eval_transforms,    num_workers=2,    binarize_labels=True,    shuffle=False)

   In [ ]

# 构造STANet模型,使用PAM作为时空注意力模块model = pdrs.tasks.STANet(in_channels=3, num_classes=2, att_type='PAM')# 制定定步长学习率衰减策略lr_scheduler = paddle.optimizer.lr.StepDecay(    0.001,    step_size=5000,    # 学习率衰减系数,这里指定每次减半    gamma=0.5)# 构造AdamW优化器optimizer = paddle.optimizer.AdamW(    learning_rate=lr_scheduler,    parameters=model.net.parameters())

   In [ ]

# 执行模型训练model.train(    num_epochs=100,    train_dataset=train_dataset,    train_batch_size=8,    eval_dataset=val_dataset,    optimizer=optimizer,    save_interval_epochs=10,    # 每多少次迭代记录一次日志    log_interval_steps=200,    save_dir='exp/',    # 是否使用early stopping策略,当精度不再改善时提前终止训练    early_stop=False,    # 是否启用VisualDL日志功能    use_vdl=True,    # 指定从某个检查点继续训练    resume_checkpoint=None)

   

3.3 模型测试

In [ ]

# 计算测试集上的定量指标test_dataset = pdrs.datasets.CDDataset(    data_dir='',    file_list='/home/aistudio/data/data136610/test_list.txt',    transforms=eval_transforms,    num_workers=2,    binarize_labels=True,    shuffle=False)model = pdrs.tasks.load_model('exp/best_model')res = model.evaluate(test_dataset)print(f"IoU={res['iou']:.4f}, F1-score={res['f1']:.4f}, OA={res['oacc']:.4f}, Kappa={res['kappa']:.4f}")

   In [21]

# 可视化预测结果%matplotlib inlineimport numpy as npfrom PIL import Imagefrom matplotlib import pyplot as pltT1_PATH = "data/data136610/LEVIR-CD_crop/test/A/test_1/test_1_3.png"T2_PATH = "data/data136610/LEVIR-CD_crop/test/B/test_1/test_1_3.png"GT_PATH = "data/data136610/LEVIR-CD_crop/test/label/test_1/test_1_3.png"pred = model.predict(    (T1_PATH, T2_PATH),     transforms=eval_transforms)# 从左到右依次显示:第一时相影像、第二时相影像、STANet预测结果以及真值标签plt.figure(constrained_layout=True)plt.subplot(141)plt.imshow(Image.open(T1_PATH))plt.gca().set_axis_off()plt.title("Image1")plt.subplot(142)plt.imshow(Image.open(T2_PATH))plt.gca().set_axis_off()plt.title("Image2")plt.subplot(143)plt.imshow((pred['label_map']*255).astype('uint8'))plt.gca().set_axis_off()plt.title("Pred")plt.subplot(144)plt.imshow((np.asarray(Image.open(GT_PATH))*255).astype('uint8'))plt.gca().set_axis_off()plt.title("GT")plt.show()

       

               

以上就是【PPSIG】基于PaddlePaddle复现STANet的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/46118.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月7日 16:45:36
下一篇 2025年11月7日 16:46:16

相关推荐

  • Uniapp 中如何不拉伸不裁剪地展示图片?

    灵活展示图片:如何不拉伸不裁剪 在界面设计中,常常需要以原尺寸展示用户上传的图片。本文将介绍一种在 uniapp 框架中实现该功能的简单方法。 对于不同尺寸的图片,可以采用以下处理方式: 极端宽高比:撑满屏幕宽度或高度,再等比缩放居中。非极端宽高比:居中显示,若能撑满则撑满。 然而,如果需要不拉伸不…

    2025年12月24日
    400
  • 如何让小说网站控制台显示乱码,同时网页内容正常显示?

    如何在不影响用户界面的情况下实现控制台乱码? 当在小说网站上下载小说时,大家可能会遇到一个问题:网站上的文本在网页内正常显示,但是在控制台中却是乱码。如何实现此类操作,从而在不影响用户界面(UI)的情况下保持控制台乱码呢? 答案在于使用自定义字体。网站可以通过在服务器端配置自定义字体,并通过在客户端…

    2025年12月24日
    800
  • 如何在地图上轻松创建气泡信息框?

    地图上气泡信息框的巧妙生成 地图上气泡信息框是一种常用的交互功能,它简便易用,能够为用户提供额外信息。本文将探讨如何借助地图库的功能轻松创建这一功能。 利用地图库的原生功能 大多数地图库,如高德地图,都提供了现成的信息窗体和右键菜单功能。这些功能可以通过以下途径实现: 高德地图 JS API 参考文…

    2025年12月24日
    400
  • 如何使用 scroll-behavior 属性实现元素scrollLeft变化时的平滑动画?

    如何实现元素scrollleft变化时的平滑动画效果? 在许多网页应用中,滚动容器的水平滚动条(scrollleft)需要频繁使用。为了让滚动动作更加自然,你希望给scrollleft的变化添加动画效果。 解决方案:scroll-behavior 属性 要实现scrollleft变化时的平滑动画效果…

    2025年12月24日
    000
  • 如何为滚动元素添加平滑过渡,使滚动条滑动时更自然流畅?

    给滚动元素平滑过渡 如何在滚动条属性(scrollleft)发生改变时为元素添加平滑的过渡效果? 解决方案:scroll-behavior 属性 为滚动容器设置 scroll-behavior 属性可以实现平滑滚动。 html 代码: click the button to slide right!…

    2025年12月24日
    500
  • 如何选择元素个数不固定的指定类名子元素?

    灵活选择元素个数不固定的指定类名子元素 在网页布局中,有时需要选择特定类名的子元素,但这些元素的数量并不固定。例如,下面这段 html 代码中,activebar 和 item 元素的数量均不固定: *n *n 如果需要选择第一个 item元素,可以使用 css 选择器 :nth-child()。该…

    2025年12月24日
    200
  • 使用 SVG 如何实现自定义宽度、间距和半径的虚线边框?

    使用 svg 实现自定义虚线边框 如何实现一个具有自定义宽度、间距和半径的虚线边框是一个常见的前端开发问题。传统的解决方案通常涉及使用 border-image 引入切片图片,但是这种方法存在引入外部资源、性能低下的缺点。 为了避免上述问题,可以使用 svg(可缩放矢量图形)来创建纯代码实现。一种方…

    2025年12月24日
    100
  • 如何解决本地图片在使用 mask JS 库时出现的跨域错误?

    如何跨越localhost使用本地图片? 问题: 在本地使用mask js库时,引入本地图片会报跨域错误。 解决方案: 要解决此问题,需要使用本地服务器启动文件,以http或https协议访问图片,而不是使用file://协议。例如: python -m http.server 8000 然后,可以…

    2025年12月24日
    200
  • 如何让“元素跟随文本高度,而不是撑高父容器?

    如何让 元素跟随文本高度,而不是撑高父容器 在页面布局中,经常遇到父容器高度被子元素撑开的问题。在图例所示的案例中,父容器被较高的图片撑开,而文本的高度没有被考虑。本问答将提供纯css解决方案,让图片跟随文本高度,确保父容器的高度不会被图片影响。 解决方法 为了解决这个问题,需要将图片从文档流中脱离…

    2025年12月24日
    000
  • 为什么 CSS mask 属性未请求指定图片?

    解决 css mask 属性未请求图片的问题 在使用 css mask 属性时,指定了图片地址,但网络面板显示未请求获取该图片,这可能是由于浏览器兼容性问题造成的。 问题 如下代码所示: 立即学习“前端免费学习笔记(深入)”; icon [data-icon=”cloud”] { –icon-cl…

    2025年12月24日
    200
  • 如何利用 CSS 选中激活标签并影响相邻元素的样式?

    如何利用 css 选中激活标签并影响相邻元素? 为了实现激活标签影响相邻元素的样式需求,可以通过 :has 选择器来实现。以下是如何具体操作: 对于激活标签相邻后的元素,可以在 css 中使用以下代码进行设置: li:has(+li.active) { border-radius: 0 0 10px…

    2025年12月24日
    100
  • 如何模拟Windows 10 设置界面中的鼠标悬浮放大效果?

    win10设置界面的鼠标移动显示周边的样式(探照灯效果)的实现方式 在windows设置界面的鼠标悬浮效果中,光标周围会显示一个放大区域。在前端开发中,可以通过多种方式实现类似的效果。 使用css 使用css的transform和box-shadow属性。通过将transform: scale(1.…

    2025年12月24日
    200
  • 为什么我的 Safari 自定义样式表在百度页面上失效了?

    为什么在 Safari 中自定义样式表未能正常工作? 在 Safari 的偏好设置中设置自定义样式表后,您对其进行测试却发现效果不同。在您自己的网页中,样式有效,而在百度页面中却失效。 造成这种情况的原因是,第一个访问的项目使用了文件协议,可以访问本地目录中的图片文件。而第二个访问的百度使用了 ht…

    2025年12月24日
    000
  • 如何用前端实现 Windows 10 设置界面的鼠标移动探照灯效果?

    如何在前端实现 Windows 10 设置界面中的鼠标移动探照灯效果 想要在前端开发中实现 Windows 10 设置界面中类似的鼠标移动探照灯效果,可以通过以下途径: CSS 解决方案 DEMO 1: Windows 10 网格悬停效果:https://codepen.io/tr4553r7/pe…

    2025年12月24日
    000
  • 使用CSS mask属性指定图片URL时,为什么浏览器无法加载图片?

    css mask属性未能加载图片的解决方法 使用css mask属性指定图片url时,如示例中所示: mask: url(“https://api.iconify.design/mdi:apple-icloud.svg”) center / contain no-repeat; 但是,在网络面板中却…

    2025年12月24日
    000
  • 如何用CSS Paint API为网页元素添加时尚的斑马线边框?

    为元素添加时尚的斑马线边框 在网页设计中,有时我们需要添加时尚的边框来提升元素的视觉效果。其中,斑马线边框是一种既醒目又别致的设计元素。 实现斜向斑马线边框 要实现斜向斑马线间隔圆环,我们可以使用css paint api。该api提供了强大的功能,可以让我们在元素上绘制复杂的图形。 立即学习“前端…

    2025年12月24日
    000
  • 图片如何不撑高父容器?

    如何让图片不撑高父容器? 当父容器包含不同高度的子元素时,父容器的高度通常会被最高元素撑开。如果你希望父容器的高度由文本内容撑开,避免图片对其产生影响,可以通过以下 css 解决方法: 绝对定位元素: .child-image { position: absolute; top: 0; left: …

    2025年12月24日
    000
  • 使用 Mask 导入本地图片时,如何解决跨域问题?

    跨域疑难:如何解决 mask 引入本地图片产生的跨域问题? 在使用 mask 导入本地图片时,你可能会遇到令人沮丧的跨域错误。为什么会出现跨域问题呢?让我们深入了解一下: mask 框架假设你以 http(s) 协议加载你的 html 文件,而当使用 file:// 协议打开本地文件时,就会产生跨域…

    2025年12月24日
    200
  • CSS 帮助

    我正在尝试将文本附加到棕色框的左侧。我不能。我不知道代码有什么问题。请帮助我。 css .hero { position: relative; bottom: 80px; display: flex; justify-content: left; align-items: start; color:…

    2025年12月24日 好文分享
    200
  • HTML、CSS 和 JavaScript 中的简单侧边栏菜单

    构建一个简单的侧边栏菜单是一个很好的主意,它可以为您的网站添加有价值的功能和令人惊叹的外观。 侧边栏菜单对于客户找到不同项目的方式很有用,而不会让他们觉得自己有太多选择,从而创造了简单性和秩序。 今天,我将分享一个简单的 HTML、CSS 和 JavaScript 源代码来创建一个简单的侧边栏菜单。…

    2025年12月24日
    200

发表回复

登录后才能评论
关注微信