『网盘赛』基于自定义训练模板的文档阴影消除

本文基于文档阴影消除网盘赛,提供了一套PaddlePaddle训练模板。模板实现了定制输出、中断续训、保存最优模型等功能,涵盖数据增强、模型训练等全流程,还支持图像分块提升精度。示例用KIUnet和UNet_3Plus模型,提交结果0.59951,方便用户快速修改实现想法。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

『网盘赛』基于自定义训练模板的文档阴影消除 - 创想鸟

「自定义训练模板」基于文档阴影消除网盘赛的示例

在使用Paddle写项目的时候,相信你也有和我一样的困惑:

高级API的封装太好,需要定制一些函数功能的时候不好添加;低级的API功能又很单一,要实现很多功能需要额外写很多代码。

如果能有一个模板,本身就具有比较完善的功能,只需要修改一部分代码就能快速实现自己的想法就好了。

基于这个想法,我完成了一套训练模型的模板,具有定制输出,模型中断后继续训练,保存最优模型等功能。

大家可以基于这个模板,快速修改并实现自己的想法。

0 项目背景

生活中在使用手机进行文档扫描时,由于角度原因,难免会在照片中留下举手拍摄时遗留的恼人阴影。为了帮助人们解决这样的问题,减少阴影带来的干扰,选手需要通过深度学习技术训练模型,对给定的真实场景下采集得到的带有阴影的图片进行处理,还原图片原本的样子,并最终输出处理后的结果图片。

有阴影图片 无阴影图片

@@##@@            @@##@@            

1 数据介绍

本次比赛最新发布的数据集共包含训练集、A榜测试集、B榜测试集三个部分,其中,训练集共1410个样本(图片编号非连续),A榜测试集共300个样本(图片编号连续),B榜测试集共397个样本;

images为带阴影的源图像数据,gts为无阴影的真值数据;(测试数据集的GT不开放)

images与gts中的图片根据图片名称一一对应。

2 模板已实现功能

1 :每个epoch输出定制信息,使用tqdm进度条实时查看每个epoch运行时间(修正windows终端或异常中断tqdm出现的混乱)

2 :可以直接定义并修改损失函数和模型,并给出示例

3 :可以选择传入模型继续训练或从头训练

4 :在训练完成后,以csv形式保存log文件

5 :自定义评价指标,并根据评价指标结果自动保存最优模型

6 :完成数据增强,模型训练,测试,推理全流程代码

7 :实现输入图像的分块,能有效提升图像类任务精度


参考项目:牛尾扫暗影,虎头开盲盒。三卷网盘赛,榜评0.59168

3 代码实现及讲解

In [ ]

# 解压数据集!unzip  -o data/data125945/shadowRemovalOfDoc.zip!unzip delight_testA_dataset.zip -d data/ >>/dev/null!unzip delight_train_dataset.zip -d data/ >>/dev/null

   In [ ]

# 安装需要的库函数!pip install scikit_image==0.15.0

   

3.1 图像分块


该部分代码用于将图像分割成小块后保存起来,包括翻转,旋转等七种不同组合的数据增强方法。


prepare_data( patch_size=256, stride=200, aug_times=1)

patch_size:切分的正方形块的大小

stride:切分步长,每隔切分步长进行切块

aug_times:增强次数,每个图块生成几张增强图像

CODE 4 第32行 scales = [1] # 对数据进行随机放缩

该参数是一个List,用于对图像进行缩放。

[1,1.2] 表示原图和将图像扩大1.2倍后进行切分,这样增加训练数据量为 [1] 的两倍


In [ ]

# 定义离线数据增强方法def data_augmentation(image,label, mode):    out = np.transpose(image, (1,2,0))    out_label = np.transpose(label, (1,2,0))    if mode == 0:        # original        out = out        out_label = out_label    elif mode == 1:        # flip up and down        out = np.flipud(out)        out_label = np.flipud(out_label)    elif mode == 2:        # rotate counterwise 90 degree        out = np.rot90(out)        out_label = np.rot90(out_label)    elif mode == 3:        # rotate 90 degree and flip up and down        out = np.rot90(out)        out = np.flipud(out)        out_label = np.rot90(out_label)        out_label = np.flipud(out_label)    elif mode == 4:        # rotate 180 degree        out = np.rot90(out, k=2)        out_label = np.rot90(out_label, k=2)    elif mode == 5:        # rotate 180 degree and flip        out = np.rot90(out, k=2)        out = np.flipud(out)        out_label = np.rot90(out_label, k=2)        out_label = np.flipud(out_label)    elif mode == 6:        # rotate 270 degree        out = np.rot90(out, k=3)        out_label = np.rot90(out_label, k=3)    elif mode == 7:        # rotate 270 degree and flip        out = np.rot90(out, k=3)        out = np.flipud(out)        out_label = np.rot90(out_label, k=3)        out_label = np.flipud(out_label)    return  out,out_label

   In [ ]

## 制作分块数据集import cv2import numpy as npimport mathimport glob import osdef Im2Patch(img, win, stride=1):    k = 0    endc = img.shape[0]    endw = img.shape[1]    endh = img.shape[2]    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]    TotalPatNum = patch.shape[1] * patch.shape[2]    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)    for i in range(win):        for j in range(win):            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)            k = k + 1    return Y.reshape([endc, win, win, TotalPatNum])def prepare_data(patch_size, stride, aug_times=1):    '''    该函数用于将图像切成方块,并进行数据增强    patch_size: 图像块的大小,本项目200*200        stride: 步长,每个图像块的间隔     aug_times: 数据增强次数,默认从八种增强方式中选择一种    '''    # train    print('process training data')    scales = [1] # 对数据进行随机放缩    files = glob.glob(os.path.join('data/delight_train_dataset/images', '*.jpg'))    files.sort()    img_folder = 'work/img_patch'    if  not os.path.exists(img_folder):        os.mkdir(img_folder)    label_folder = 'work/label_patch'    if  not os.path.exists(label_folder):        os.mkdir(label_folder)    train_num = 0    for i in range(len(files)):        img = cv2.imread(files[i])        label = cv2.imread(files[i].replace('images','gts'))        h, w, c = img.shape        for k in range(len(scales)):            Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)            Label = cv2.resize(label, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)            Img = np.transpose(Img, (2,0,1))            Label = np.transpose(Label, (2,0,1))            Img = np.float32(np.clip(Img,0,255))            Label = np.float32(np.clip(Label,0,255))            patches = Im2Patch(Img, win=patch_size, stride=stride)            label_patches = Im2Patch(Label, win=patch_size, stride=stride)            print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], patches.shape[3]*aug_times))            for n in range(patches.shape[3]):                data = patches[:,:,:,n].copy()                label_data = label_patches[:,:,:,n].copy()                            for m in range(aug_times):                    data_aug,label_aug = data_augmentation(data,label_data, np.random.randint(1,8))                    label_name = os.path.join(label_folder,str(train_num)+"_aug_%d" % (m+1)+'.jpg')                    image_name = os.path.join(img_folder,str(train_num)+"_aug_%d" % (m+1)+'.jpg')                                        cv2.imwrite(image_name, data_aug,[int( cv2.IMWRITE_JPEG_QUALITY), 100])                    cv2.imwrite(label_name, label_aug,[int( cv2.IMWRITE_JPEG_QUALITY), 100])                    train_num += 1    print('training set, # samples %dn' % train_num)   ## 生成数据prepare_data( patch_size=256, stride=200, aug_times=1)

   

3.2 重写数据读取类


该部分代码使用了paddle.vision.transforms内置的图像增强方法


In [ ]

# 重写数据读取类import paddleimport paddle.vision.transforms as Timport numpy as npimport globimport cv2# 重写数据读取类class DEshadowDataset(paddle.io.Dataset):    def __init__(self,mode = 'train',is_transforms = False):               label_path_ ='work/label_patch/*.jpg'        jpg_path_ ='work/img_patch/*.jpg'        self.label_list_ = glob.glob(label_path_)                self.jpg_list_ = glob.glob(jpg_path_)                self.is_transforms = is_transforms        self.mode = mode        scale_point = 0.95                self.transforms =T.Compose([            T.Normalize(data_format='HWC',),            T.HueTransform(0.4),            T.SaturationTransform(0.4),            T.HueTransform(0.4),            T.ToTensor(),            ])        # 选择前95%训练,后5%验证        if self.mode == 'train':            self.jpg_list_ = self.jpg_list_[:int(scale_point*len(self.jpg_list_))]            self.label_list_ = self.label_list_[:int(scale_point*len(self.label_list_))]        else:            self.jpg_list_ = self.jpg_list_[int(scale_point*len(self.jpg_list_)):]            self.label_list_ = self.label_list_[int(scale_point*len(self.label_list_)):]    def __getitem__(self, index):        jpg_ = self.jpg_list_[index]        label_ =  self.label_list_[index]        data = cv2.imread(jpg_) # 读取和代码处于同一目录下的 lena.png # 转为 0-1        mask = cv2.imread(label_)        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB) # BGR 2 RGB        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB) # BGR 2 RGB                data = np.uint8(data)        mask = np.uint8(mask)        if self.is_transforms:            data = self.transforms(data)            data = data/255            mask = T.functional.to_tensor(mask)         return  data,mask      def __len__(self):        return len(self.jpg_list_)

   In [13]

# 数据读取及增强可视化import paddle.vision.transforms as Timport matplotlib.pyplot as pltfrom PIL import Imagedataset = DEshadowDataset(mode='train',is_transforms = False )print('=============train dataset=============')img_,mask_ = dataset[3] # mask 始终大于 imgimg = Image.fromarray(img_)mask = Image.fromarray(mask_)#当要保存的图片为灰度图像时,灰度图像的 numpy 尺度是 [1, h, w]。需要将 [1, h, w] 改变为 [h, w]plt.figure(figsize=(12, 6))plt.subplot(1,2,1),plt.xticks([]),plt.yticks([]),plt.imshow(img)plt.subplot(1,2,2),plt.xticks([]),plt.yticks([]),plt.imshow(mask)plt.show()

   

3.3 重写模型


该部分代码定义KIUnet 和UNet_3Plus作为示例。

网络介绍详情请点击链接查看。

In [ ]

# 定义KIUnet import paddleimport paddle.nn as nnimport paddle.nn.functional as Ffrom paddle.nn import initializerdef init_weights(init_type='kaiming'):    if init_type == 'normal':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())    elif init_type == 'xavier':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())    elif init_type == 'kaiming':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal)    else:        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)class kiunet(nn.Layer):    def __init__(self ,in_channels = 3, n_classes =3):        super(kiunet,self).__init__()        self.in_channels = in_channels        self.n_class = n_classes        self.encoder1 = nn.Conv2D(self.in_channels, 16, 3, stride=1, padding=1)  # First Layer GrayScale Image , change to input channels to 3 in case of RGB         self.en1_bn = nn.BatchNorm(16)        self.encoder2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)          self.en2_bn = nn.BatchNorm(32)        self.encoder3=   nn.Conv2D(32, 64, 3, stride=1, padding=1)        self.en3_bn = nn.BatchNorm(64)        self.decoder1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)           self.de1_bn = nn.BatchNorm(32)        self.decoder2 =   nn.Conv2D(32,16, 3, stride=1, padding=1)        self.de2_bn = nn.BatchNorm(16)        self.decoder3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)        self.de3_bn = nn.BatchNorm(8)        self.decoderf1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)        self.def1_bn = nn.BatchNorm(32)        self.decoderf2=   nn.Conv2D(32, 16, 3, stride=1, padding=1)        self.def2_bn = nn.BatchNorm(16)        self.decoderf3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)        self.def3_bn = nn.BatchNorm(8)        self.encoderf1 =   nn.Conv2D(in_channels, 16, 3, stride=1, padding=1)  # First Layer GrayScale Image , change to input channels to 3 in case of RGB         self.enf1_bn = nn.BatchNorm(16)        self.encoderf2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)        self.enf2_bn = nn.BatchNorm(32)        self.encoderf3 =   nn.Conv2D(32, 64, 3, stride=1, padding=1)        self.enf3_bn = nn.BatchNorm(64)        self.intere1_1 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.inte1_1bn = nn.BatchNorm(16)        self.intere2_1 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.inte2_1bn = nn.BatchNorm(32)        self.intere3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.inte3_1bn = nn.BatchNorm(64)        self.intere1_2 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.inte1_2bn = nn.BatchNorm(16)        self.intere2_2 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.inte2_2bn = nn.BatchNorm(32)        self.intere3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.inte3_2bn = nn.BatchNorm(64)        self.interd1_1 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.intd1_1bn = nn.BatchNorm(32)        self.interd2_1 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.intd2_1bn = nn.BatchNorm(16)        self.interd3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.intd3_1bn = nn.BatchNorm(64)        self.interd1_2 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.intd1_2bn = nn.BatchNorm(32)        self.interd2_2 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.intd2_2bn = nn.BatchNorm(16)        self.interd3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.intd3_2bn = nn.BatchNorm(64)        self.final = nn.Sequential(            nn.Conv2D(8,self.n_class,1,stride=1,padding=0),            nn.AdaptiveAvgPool2D(output_size=1))        # initialise weights        for m in self.sublayers ():            if isinstance(m, nn.Conv2D):                m.weight_attr = init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')            elif isinstance(m, nn.BatchNorm):                m.param_attr =init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')     def forward(self, x):        # input: c * h * w -> 16 * h/2 * w/2        out = F.relu(self.en1_bn(F.max_pool2d(self.encoder1(x),2,2)))  #U-Net branch        # c * h * w -> 16 * 2h * 2w        out1 = F.relu(self.enf1_bn(F.interpolate(self.encoderf1(x),scale_factor=(2,2),mode ='bicubic'))) #Ki-Net branch        # 16 * h/2 * w/2        tmp = out        # 16 * 2h * 2w -> 16 * h/2 * w/2        out = paddle.add(out,F.interpolate(F.relu(self.inte1_1bn(self.intere1_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic')) #CRFB        # 16 * h/2 * w/2 -> 16 * 2h * 2w        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte1_2bn(self.intere1_2(tmp))),scale_factor=(4,4),mode ='bicubic')) #CRFB                # 16 * h/2 * w/2        u1 = out  #skip conn        # 16 * 2h * 2w        o1 = out1  #skip conn        # 16 * h/2 * w/2 -> 32 * h/4 * w/4        out = F.relu(self.en2_bn(F.max_pool2d(self.encoder2(out),2,2)))        # 16 * 2h * 2w -> 32 * 4h * 4w        out1 = F.relu(self.enf2_bn(F.interpolate(self.encoderf2(out1),scale_factor=(2,2),mode ='bicubic')))        #  32 * h/4 * w/4        tmp = out        # 32 * 4h * 4w -> 32 * h/4 *w/4        out = paddle.add(out,F.interpolate(F.relu(self.inte2_1bn(self.intere2_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))        # 32 * h/4 * w/4 -> 32 *4h *4w        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte2_2bn(self.intere2_2(tmp))),scale_factor=(16,16),mode ='bicubic'))                #  32 * h/4 *w/4        u2 = out        #  32 *4h *4w        o2 = out1                # 32 * h/4 *w/4 -> 64 * h/8 *w/8        out = F.relu(self.en3_bn(F.max_pool2d(self.encoder3(out),2,2)))        # 32 *4h *4w -> 64 * 8h *8w        out1 = F.relu(self.enf3_bn(F.interpolate(self.encoderf3(out1),scale_factor=(2,2),mode ='bicubic')))        #  64 * h/8 *w/8         tmp = out        #  64 * 8h *8w -> 64 * h/8 * w/8        out = paddle.add(out,F.interpolate(F.relu(self.inte3_1bn(self.intere3_1(out1))),scale_factor=(0.015625,0.015625),mode ='bicubic'))        #  64 * h/8 *w/8 -> 64 * 8h * 8w        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte3_2bn(self.intere3_2(tmp))),scale_factor=(64,64),mode ='bicubic'))                ### End of encoder block        ### Start Decoder                # 64 * h/8 * w/8 -> 32 * h/4 * w/4         out = F.relu(self.de1_bn(F.interpolate(self.decoder1(out),scale_factor=(2,2),mode ='bicubic')))  #U-NET        # 64 * 8h * 8w -> 32 * 4h * 4w         out1 = F.relu(self.def1_bn(F.max_pool2d(self.decoderf1(out1),2,2))) #Ki-NET        # 32 * h/4 * w/4         tmp = out        # 32 * 4h * 4w  -> 32 * h/4 * w/4         out = paddle.add(out,F.interpolate(F.relu(self.intd1_1bn(self.interd1_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))        # 32 * h/4 * w/4  -> 32 * 4h * 4w         out1 = paddle.add(out1,F.interpolate(F.relu(self.intd1_2bn(self.interd1_2(tmp))),scale_factor=(16,16),mode ='bicubic'))                # 32 * h/4 * w/4         out = paddle.add(out,u2)  #skip conn        # 32 * 4h * 4w         out1 = paddle.add(out1,o2)  #skip conn        # 32 * h/4 * w/4 -> 16 * h/2 * w/2         out = F.relu(self.de2_bn(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bicubic')))        # 32 * 4h * 4w  -> 16 * 2h * 2w        out1 = F.relu(self.def2_bn(F.max_pool2d(self.decoderf2(out1),2,2)))        # 16 * h/2 * w/2         tmp = out        # 16 * 2h * 2w -> 16 * h/2 * w/2        out = paddle.add(out,F.interpolate(F.relu(self.intd2_1bn(self.interd2_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic'))        # 16 * h/2 * w/2 -> 16 * 2h * 2w        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd2_2bn(self.interd2_2(tmp))),scale_factor=(4,4),mode ='bicubic'))                # 16 * h/2 * w/2        out = paddle.add(out,u1)        # 16 * 2h * 2w        out1 = paddle.add(out1,o1)        # 16 * h/2 * w/2 -> 8 * h * w        out = F.relu(self.de3_bn(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bicubic')))        # 16 * 2h * 2w -> 8 * h * w        out1 = F.relu(self.def3_bn(F.max_pool2d(self.decoderf3(out1),2,2)))        # 8 * h * w        out = paddle.add(out,out1) # fusion of both branches        # 最后一层用sigmoid激活函数        out = F.sigmoid(self.final(out))  #1*1 conv                return out

   In [ ]

import paddleimport paddle.nn as nnimport paddle.nn.functional as Ffrom paddle.nn import initializerdef init_weights(init_type='kaiming'):    if init_type == 'normal':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())    elif init_type == 'xavier':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())    elif init_type == 'kaiming':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal)    else:        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)class unetConv2(nn.Layer):    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):        super(unetConv2, self).__init__()        self.n = n        self.ks = ks        self.stride = stride        self.padding = padding        s = stride        p = padding        if is_batchnorm:            for i in range(1, n + 1):                conv = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),                                     nn.BatchNorm(out_size),                                     nn.ReLU(), )                setattr(self, 'conv%d' % i, conv)                in_size = out_size        else:            for i in range(1, n + 1):                conv = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),                                     nn.ReLU(), )                setattr(self, 'conv%d' % i, conv)                in_size = out_size        # initialise the blocks        for m in self.children():            m.weight_attr=init_weights(init_type='kaiming')            m.bias_attr=init_weights(init_type='kaiming')    def forward(self, inputs):        x = inputs        for i in range(1, self.n + 1):            conv = getattr(self, 'conv%d' % i)            x = conv(x)        return x'''    UNet 3+'''class UNet_3Plus(nn.Layer):    def __init__(self, in_channels=3, n_classes=1, is_deconv=True, is_batchnorm=True, end_sigmoid=True):        super(UNet_3Plus, self).__init__()        self.is_deconv = is_deconv        self.in_channels = in_channels        self.is_batchnorm = is_batchnorm        self.end_sigmoid = end_sigmoid        filters = [16, 32, 64, 128, 256]        ## -------------Encoder--------------        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)        self.maxpool1 = nn.MaxPool2D(kernel_size=2)        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)        self.maxpool2 = nn.MaxPool2D(kernel_size=2)        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)        self.maxpool3 = nn.MaxPool2D(kernel_size=2)        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)        self.maxpool4 = nn.MaxPool2D(kernel_size=2)        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)        ## -------------Decoder--------------        self.CatChannels = filters[0]        self.CatBlocks = 5        self.UpChannels = self.CatChannels * self.CatBlocks        '''stage 4d'''        # h2->320*320, hd4->40*40, Pooling 8 times        self.h2_PT_hd4 = nn.MaxPool2D(8, 8, ceil_mode=True)        self.h2_PT_hd4_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_PT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd4_relu = nn.ReLU()        # h2->160*160, hd4->40*40, Pooling 4 times        self.h2_PT_hd4 = nn.MaxPool2D(4, 4, ceil_mode=True)        self.h2_PT_hd4_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)        self.h2_PT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd4_relu = nn.ReLU()        # h3->80*80, hd4->40*40, Pooling 2 times        self.h3_PT_hd4 = nn.MaxPool2D(2, 2, ceil_mode=True)        self.h3_PT_hd4_conv = nn.Conv2D(filters[2], self.CatChannels, 3, padding=1)        self.h3_PT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h3_PT_hd4_relu = nn.ReLU()        # h4->40*40, hd4->40*40, Concatenation        self.h4_Cat_hd4_conv = nn.Conv2D(filters[3], self.CatChannels, 3, padding=1)        self.h4_Cat_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h4_Cat_hd4_relu = nn.ReLU()        # hd5->20*20, hd4->40*40, Upsample 2 times        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd5_UT_hd4_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd4_relu = nn.ReLU()        # fusion(h2_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)        self.conv4d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn4d_1 = nn.BatchNorm(self.UpChannels)        self.relu4d_1 = nn.ReLU()        '''stage 3d'''        # h2->320*320, hd3->80*80, Pooling 4 times        self.h2_PT_hd3 = nn.MaxPool2D(4, 4, ceil_mode=True)        self.h2_PT_hd3_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_PT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd3_relu = nn.ReLU()        # h2->160*160, hd3->80*80, Pooling 2 times        self.h2_PT_hd3 = nn.MaxPool2D(2, 2, ceil_mode=True)        self.h2_PT_hd3_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)        self.h2_PT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd3_relu = nn.ReLU()        # h3->80*80, hd3->80*80, Concatenation        self.h3_Cat_hd3_conv = nn.Conv2D(filters[2], self.CatChannels, 3, padding=1)        self.h3_Cat_hd3_bn = nn.BatchNorm(self.CatChannels)        self.h3_Cat_hd3_relu = nn.ReLU()        # hd4->40*40, hd4->80*80, Upsample 2 times        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd4_UT_hd3_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd4_UT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.hd4_UT_hd3_relu = nn.ReLU()        # hd5->20*20, hd4->80*80, Upsample 4 times        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14        self.hd5_UT_hd3_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd3_relu = nn.ReLU()        # fusion(h2_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)        self.conv3d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn3d_1 = nn.BatchNorm(self.UpChannels)        self.relu3d_1 = nn.ReLU()        '''stage 2d '''        # h2->320*320, hd2->160*160, Pooling 2 times        self.h2_PT_hd2 = nn.MaxPool2D(2, 2, ceil_mode=True)        self.h2_PT_hd2_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_PT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd2_relu = nn.ReLU()        # h2->160*160, hd2->160*160, Concatenation        self.h2_Cat_hd2_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)        self.h2_Cat_hd2_bn = nn.BatchNorm(self.CatChannels)        self.h2_Cat_hd2_relu = nn.ReLU()        # hd3->80*80, hd2->160*160, Upsample 2 times        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd3_UT_hd2_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd3_UT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.hd3_UT_hd2_relu = nn.ReLU()        # hd4->40*40, hd2->160*160, Upsample 4 times        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14        self.hd4_UT_hd2_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd4_UT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.hd4_UT_hd2_relu = nn.ReLU()        # hd5->20*20, hd2->160*160, Upsample 8 times        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14        self.hd5_UT_hd2_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd2_relu = nn.ReLU()        # fusion(h2_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)        self.Conv2D_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn2d_1 = nn.BatchNorm(self.UpChannels)        self.relu2d_1 = nn.ReLU()        '''stage 1d'''        # h2->320*320, hd1->320*320, Concatenation        self.h2_Cat_hd1_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_Cat_hd1_bn = nn.BatchNorm(self.CatChannels)        self.h2_Cat_hd1_relu = nn.ReLU()        # hd2->160*160, hd1->320*320, Upsample 2 times        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd2_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd2_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd2_UT_hd1_relu = nn.ReLU()        # hd3->80*80, hd1->320*320, Upsample 4 times        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14        self.hd3_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd3_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd3_UT_hd1_relu = nn.ReLU()        # hd4->40*40, hd1->320*320, Upsample 8 times        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14        self.hd4_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd4_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd4_UT_hd1_relu = nn.ReLU()        # hd5->20*20, hd1->320*320, Upsample 16 times        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14        self.hd5_UT_hd1_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd1_relu = nn.ReLU()        # fusion(h2_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)        self.conv1d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn1d_1 = nn.BatchNorm(self.UpChannels)        self.relu1d_1 = nn.ReLU()        # output        self.outconv1 = nn.Conv2D(self.UpChannels, n_classes, 3, padding=1)        # initialise weights        for m in self.sublayers ():            if isinstance(m, nn.Conv2D):                m.weight_attr = init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')            elif isinstance(m, nn.BatchNorm):                m.param_attr =init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')    def forward(self, inputs):        ## -------------Encoder-------------        h2 = self.conv1(inputs)  # h2->320*320*64        h2 = self.maxpool1(h2)        h2 = self.conv2(h2)  # h2->160*160*128        h3 = self.maxpool2(h2)        h3 = self.conv3(h3)  # h3->80*80*256        h4 = self.maxpool3(h3)        h4 = self.conv4(h4)  # h4->40*40*512        h5 = self.maxpool4(h4)        hd5 = self.conv5(h5)  # h5->20*20*1024        ## -------------Decoder-------------        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(            paddle.concat([h2_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4], 1)))) # hd4->40*40*UpChannels        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(            paddle.concat([h2_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3], 1)))) # hd3->80*80*UpChannels        h2_PT_hd2 = self.h2_PT_hd2_relu(self.h2_PT_hd2_bn(self.h2_PT_hd2_conv(self.h2_PT_hd2(h2))))        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))        hd2 = self.relu2d_1(self.bn2d_1(self.Conv2D_1(            paddle.concat([h2_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2], 1)))) # hd2->160*160*UpChannels        h2_Cat_hd1 = self.h2_Cat_hd1_relu(self.h2_Cat_hd1_bn(self.h2_Cat_hd1_conv(h2)))        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(            paddle.concat([h2_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1], 1)))) # hd1->320*320*UpChannels        d1 = self.outconv1(hd1)  # d1->320*320*n_classes        if self.end_sigmoid:            out = F.sigmoid(d1)        else:            out = d1        return out

   In [ ]

# 查看网络结构KIunet = kiunet(in_channels = 3, n_classes =3)Unet3p = UNet_3Plus( in_channels=3, n_classes=3)model = paddle.Model(Unet3p)model.summary((2,3, 256, 256))

   

3.4 训练函数介绍

该部分介绍模板的具体使用:


模型运行

参数 self.is_Train self.PATH

False进行推理从头训练True进行训练读取模型继续训练

注意: self.PATH的True指输入模型路径,例如:当self.is_Train为False时, self.PATH为模型路径,此时不需要训练,直接进行推理生成结果。


损失函数

在work目录下loss.py文件中实现了:

—- TVLoss

AiPPT模板广场 AiPPT模板广场

AiPPT模板广场-PPT模板-word文档模板-excel表格模板

AiPPT模板广场 147 查看详情 AiPPT模板广场

—- SSIMLoss

—- LossVGG19(感知损失)

……

可以根据自己的需求进行修改或添加


最优模型

在Eval方法中定义了score变量,该变量记录验证集上的评价指标。

每个EPOCH会计算一次验证集上的score。

只有当获得更高的score时,才会删除原模型,记录新的最优模型。

注意: 需要自定义评价指标时,只需修改score的计算即可


保存结果

函数会zi’dong在work目录下生成一系列目录:

log

用于保存训练过程的结果

outputs

用于以模型名称为文件夹保存推理的结果

saveModel

用于记录最优的模型,名称为模型名+验证集score分数

注意: 需要自定义修改self.Modelname,用于区分不同模型的生成结果

In [ ]

## 主函数定义 Baseline # 基本假设 干净图像 + 噪声 = 受污染图像# 阴影部分的像素值要低一些,因此,需要对受污染图像的像素值进行增强# 受污染图像 + 阴影 = 干净图像"""@author: xupeng"""from work.loss import *from work.util import batch_psnr_ssimfrom PIL import Imageimport numpy as npimport pandas as pdfrom tqdm import tqdmimport osimport matplotlib.pyplot as pltimport globfrom PIL import Imagefrom skimage.measure.simple_metrics import compare_psnrimport skimageimport paddle.vision.transforms as Timport cv2class Solver(object):    def __init__(self):        self.model = None        self.lr = 1e-4 # 学习率        self.epochs = 10  # 训练的代数        self.batch_size = 4 # 训练批次数量        self.optimizer = None        self.scheduler = None        self.saveSetDir = r'data/delight_testA_dataset/images/' # 测试集地址                self.train_set = None        self.eval_set = None        self.train_loader = None        self.eval_loader = None                self.Modelname = 'Unet3p' # 使用的网络声明(生成的文件会以该声明为区分)                self.evaTOP = 0        self.is_Train = True # #是否对执行模型训练,当为False时,给出self.PATH,直接进行推理        self.PATH = False #False# 用于记录最优模型的名称,需要预训练时,此项不为空    def InitModeAndData(self):        print("---------------trainInit:---------------")        # API文档搜索:vision.transforms 查看数据增强方法        is_transforms = True        self.train_set =  DEshadowDataset(mode='train',is_transforms = is_transforms)        self.eval_set  =  DEshadowDataset(mode='eval',is_transforms = is_transforms)        # 使用paddle.io.DataLoader 定义DataLoader对象用于加载Python生成器产生的数据        self.train_loader = paddle.io.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)        self.eval_loader = paddle.io.DataLoader(self.eval_set, batch_size=self.batch_size, shuffle=False)        self.model = UNet_3Plus( in_channels=3, n_classes=3)        if self.is_Train and self.PATH: # 已经有模型还要训练时,进行二次训练            params = paddle.load(self.PATH)            self.model.set_state_dict(params)            self.evaTOP = float(self.PATH.split(self.Modelname)[-1].replace('.pdparams',''))        self.scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=self.lr, T_max=30, verbose=False)        self.optimizer = paddle.optimizer.Adam(parameters=self.model.parameters(), learning_rate=self.scheduler,beta1=0.5, beta2=0.999)        # 创建文件夹         for item in ['log','outputs','saveModel']:            make_folder = os.path.join('work',item)            if  not os.path.exists(make_folder):                os.mkdir(make_folder)    def train(self,epoch):        print("=======Epoch:{}/{}=======".format(epoch,self.epochs))        self.model.train()        TV_criterion = nTVLoss()    # 用于规范图像噪声        ssim_criterion = SSIMLoss() # 用于对原图进行        lossnet = LossVGG19() # 感知loss        l1_loss = paddle.nn.L1Loss()        Mse_loss = nn.MSELoss()        try: # 使用这种写法(try-except 和 ascii=True),可以避免windows终端或异常中断tqdm出现的混乱            with tqdm(enumerate(self.train_loader),total=len(self.train_loader), ascii=True) as tqdmData:                mean_loss = []                                for idx, (img_train,mask_train) in tqdmData:                    tqdmData.set_description('train')                    img_train = paddle.to_tensor(img_train,dtype="float32")                    mask_train =  paddle.to_tensor(mask_train,dtype="float32")                    # MODEL                    outputs_noise = self.model(img_train) # mask > img  因此需要增加像素值                    mask_noise = mask_train - img_train # 真实图像与阴影的差值                    # 恢复出来的图像                    restore_trian =img_train + outputs_noise                    '''                    # 去均值                    tensor_c = paddle.to_tensor(np.array([123.6800, 116.7790, 103.9390]).astype(np.float32).reshape((1, 3, 1, 1)))                    # 感知损失                    # preceptual loss                    loss_fake_B = lossnet(restore_trian * 255 - tensor_c)                    loss_real_B = lossnet(mask_train * 255 - tensor_c)                    p0 = l1_loss(restore_trian * 255 - tensor_c, mask_train * 255 - tensor_c) * 2                    p1 = l1_loss(loss_fake_B['relu1'], loss_real_B['relu1']) / 2.6                    p2 = l1_loss(loss_fake_B['relu2'], loss_real_B['relu2']) / 4.8                    loss_p = p0 + p1 + p2                    loss = loss_p  + ssim_criterion(restore_trian,mask_train) + 10*l1_loss(outputs_noise,mask_noise)                    '''                    loss = l1_loss(restore_trian,mask_train)                    self.optimizer.clear_grad()                    loss.backward()                    self.optimizer.step()                                        self.scheduler.step() ### 改优化器记得改这个                    mean_loss.append(loss.item())        except KeyboardInterrupt:            tqdmData.close()            os._exit(0)        tqdmData.close()        # 清除中间变量,释放内存        del loss,img_train,mask_train,outputs_noise,mask_noise,restore_trian        paddle.device.cuda.empty_cache()        return {'Mean_trainLoss':np.mean(mean_loss)}        def Eval(self,modelname):                self.model.eval()        temp_eval_psnr ,temp_eval_ssim= [],[]        with paddle.no_grad():            try:                with tqdm(enumerate(self.eval_loader),total=len(self.eval_loader), ascii=True) as tqdmData:                    for idx, (img_eval,mask_eval) in tqdmData:                        tqdmData.set_description(' eval')                        img_eval=  paddle.to_tensor(img_eval,dtype="float32")                        mask_eval = paddle.to_tensor(mask_eval,dtype="float32")                        outputs_denoise = self.model(img_eval) # 模型输出                        outputs_denoise = img_eval + outputs_denoise # 恢复后的图像                        psnr_test,ssim_test = batch_psnr_ssim(outputs_denoise, mask_eval, 1.)                        temp_eval_psnr.append(psnr_test)                        temp_eval_ssim.append(ssim_test)                                except KeyboardInterrupt:                tqdmData.close()                os._exit(0)            tqdmData.close()            paddle.device.cuda.empty_cache()            # 打印test psnr & ssim        # print('eval_psnr:',np.mean(temp_eval_psnr),'eval_ssim:',np.mean(temp_eval_ssim))        # 实现评价指标        score = 0.05*np.mean(temp_eval_psnr)+0.5*np.mean(temp_eval_ssim)        return {'eval_psnr':np.mean(temp_eval_psnr),'eval_ssim':np.mean(temp_eval_ssim),'SCORE':score}        def saveModel(self,trainloss,modelname):                trainLoss = trainloss['SCORE']        if trainLoss < self.evaTOP and self.evaTOP!=0:             return 0        else:            folder = 'work/saveModel/'            self.PATH = folder+modelname+str(trainLoss)+'.pdparams'            removePATH = folder+modelname+str(self.evaTOP)+'.pdparams'            paddle.save(self.model.state_dict(), self.PATH)            if self.evaTOP!=0:                os.remove(removePATH)                        self.evaTOP = trainLoss            return 1            def saveResult(self):        print("---------------saveResult:---------------")        self.model.set_state_dict(paddle.load(self.PATH))        self.model.eval()        paddle.set_grad_enabled(False)        paddle.device.cuda.empty_cache()        data_dir = glob.glob(self.saveSetDir+'*.jpg')        # 创建保存文件夹        make_save_result = os.path.join('work/outputs',self.Modelname)        if  not os.path.exists(make_save_result):            os.mkdir(make_save_result)        saveSet = pd.DataFrame()        tpsnr,tssim = [],[]                for idx,ori_path in enumerate(data_dir):            print(len(data_dir),'|',idx+1,end = 'r',flush = True)                        ori = cv2.imread(ori_path) # W,H,C            ori = cv2.cvtColor(ori, cv2.COLOR_BGR2RGB) # BGR 2 RGB            ori_w,ori_h,ori_c = ori.shape            # normalize_test = T.Normalize(             # [0.610621, 0.5989216, 0.5876396],             # [0.1835931, 0.18701428, 0.19362564],            # data_format='HWC')            #ori = normalize_test(ori)            ori = T.functional.resize(ori,(1024,1024),interpolation = 'bicubic') ###不切块送进去 1024,1024 上采样            # from HWC to CHW ,[0,255] to [0,1]            ori = np.transpose(ori,(2,0,1))/255            ori_img = paddle.to_tensor(ori,dtype="float32")            ori_img = paddle.unsqueeze(ori_img,0) # N C W H            out_noise = self.model(ori_img) #不切块送进去            out_noise = ori_img + out_noise #加上恢复的像素            img_cpu = out_noise.cpu().numpy()            #保存结果            img_cpu = np.squeeze(img_cpu)            img_cpu = np.transpose(img_cpu,(1,2,0)) # C,W,H to W,H,C                        img_cpu = np.clip(img_cpu, 0., 1.)            savepic = np.uint8(img_cpu*255)            savepic = T.functional.resize(savepic,(ori_w,ori_h),interpolation = 'bicubic') ## 采样回原始大小            # 保存路径            savedir = os.path.join(make_save_result,ori_path.split('/')[-1])            savepic = cv2.cvtColor(savepic, cv2.COLOR_RGB2BGR) # BGR            cv2.imwrite(savedir, savepic, [int( cv2.IMWRITE_JPEG_QUALITY), 100])    def run(self):                self.InitModeAndData()                    if self.is_Train:            modelname = self.Modelname #  使用的网络名称            result = pd.DataFrame()                        for epoch in range(1, self.epochs + 1):                                trainloss = self.train(epoch)                evalloss =  self.Eval(modelname)#                Type = self.saveModel(evalloss,modelname)                                type_ = {'Type':Type}                trainloss.update(evalloss)#                trainloss.update(type_)                result = result.append(trainloss,ignore_index=True)                print('Epoch:',epoch,trainloss)                                #self.scheduler.step()            evalloss =  self.Eval(modelname)#            result.to_csv('work/log/' +modelname+str(evalloss['SCORE'])+'.csv')#            self.saveResult()        else:            self.saveResult()def main():    solver = Solver()    solver.run()    if __name__ == '__main__':    main()

   

进入输出路径创建readme.txt文件,输入要求的内容:

训练框架:PaddlePaddle

代码运行环境:V100

是否使用GPU:是

单张图片耗时/s:1

模型大小:45

其他说明:算法参考UNet+++

In [ ]

# %cd /home/aistudio/work/outputs/Unet3p# !zip result.zip *.jpg *.txt

   

4 项目总结

项目实现了一套简单的训练模板。

项目基于文档阴影消除任务创建,提交结果为0.59951

欢迎大家在该基础上修改自己的算法!

『网盘赛』基于自定义训练模板的文档阴影消除 - 创想鸟『网盘赛』基于自定义训练模板的文档阴影消除 - 创想鸟

以上就是『网盘赛』基于自定义训练模板的文档阴影消除的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/315368.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
mysql中分组命令是
上一篇 2025年11月5日 07:26:39
《纪念碑谷3》免费DLC“生命花园”12月3日更新
下一篇 2025年11月5日 07:26:49

相关推荐

  • Matplotlib 地图中多类型图例的创建与优化

    Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化

    本教程旨在解决matplotlib地图可视化中,如何在一个图例中同时展示颜色块(如区域分类)和自定义标记(如特定兴趣点)的问题。文章详细介绍了当传统`patch`对象无法正确显示标记时,如何利用`matplotlib.lines.line2d`创建标记图例句柄,并将其与颜色块图例句柄合并,从而生成一…

    2026年5月10日 用户投稿
    100
  • Golang JSON序列化:控制敏感字段暴露的最佳实践

    本教程探讨golang中如何高效控制结构体字段在json序列化时的可见性。当需要将包含敏感信息的结构体数组转换为json响应时,通过利用`encoding/json`包提供的结构体标签,特别是`json:”-“`,可以轻松实现对特定字段的忽略,从而避免敏感数据泄露,确保api…

    2026年5月10日
    000
  • 利用海象运算符简化条件赋值:Python教程与最佳实践

    本文旨在探讨Python中海象运算符(:=)在条件赋值场景下的应用。通过对比传统if/else语句与海象运算符,以及条件表达式,分析海象运算符在简化代码、提高可读性方面的优势与局限性。并通过具体示例,展示如何在列表推导式等场景下合理使用海象运算符,同时强调其潜在的复杂性及替代方案,帮助开发者更好地掌…

    2026年5月10日
    100
  • 怎么在PHP代码中实现图片上传功能_PHP图片上传功能实现与安全处理教程

    首先创建含enctype的HTML表单,再用PHP接收文件,检查目录、移动临时文件,验证类型与大小,生成唯一文件名,并调整php.ini限制以确保上传成功。 如果您尝试在PHP项目中添加图片上传功能,但服务器无法正确接收或保存文件,则可能是由于表单配置、文件处理逻辑或安全限制的问题。以下是实现该功能…

    2026年5月10日
    100
  • 比特币新手教程 比特币交易平台有哪些

    比特币是一种去中心化的数字货币,基于区块链技术实现点对点交易,具有匿名性、有限发行和不可篡改等特点;新手可通过交易所购买,P2P交易获得比特币,常用平台包括Binance、OKX和Huobi;交易流程包括注册账户、实名认证、绑定支付方式、充值法币并下单购买,可选择市价单或限价单;比特币存储方式有交易…

    2026年5月10日
    000
  • c++中的SFINAE技术是什么_c++模板编程中的SFINAE原理与应用

    SFINAE 是“替换失败不是错误”的原则,指模板实例化时若参数替换导致错误,只要存在其他合法候选,编译器不报错而是继续重载决议。它用于条件启用模板、类型检测等场景,如通过 decltype 或 enable_if 控制函数重载,实现类型特征判断。尽管 C++20 引入 Concepts 简化了部分…

    2026年5月10日
    000
  • Go语言mgo查询构建:深入理解bson.M与日期范围查询的正确实践

    本文旨在解决go语言mgo库中构建复杂查询时,特别是涉及嵌套`bson.m`和日期范围筛选的常见错误。我们将深入剖析`bson.m`的类型特性,解释为何直接索引`interface{}`会导致“invalid operation”错误,并提供一种推荐的、结构清晰的代码重构方案,以确保查询条件能够正确…

    2026年5月10日
    100
  • RichHandler与Rich Progress集成:解决显示冲突的教程

    在使用rich库的`richhandler`进行日志输出并同时使用`progress`组件时,可能会遇到显示错乱或溢出问题。这通常是由于为`richhandler`和`progress`分别创建了独立的`console`实例导致的。解决方案是确保日志处理器和进度条组件共享同一个`console`实例…

    2026年5月10日
    000
  • Golang goroutine与channel调试技巧

    使用go run -race检测数据竞争,结合runtime.NumGoroutine监控协程数量,通过pprof分析阻塞调用栈,利用select超时避免永久阻塞,有效排查goroutine泄漏、死锁和数据竞争问题。 Go语言的goroutine和channel是并发编程的核心,但它们也带来了调试上…

    2026年5月10日
    000
  • 《魔兽世界》将于6月11日开启国服回归技术测试

    《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试

    《%ign%ignore_a_1%re_a_1%》官方宣布,将于6月11日开启国服回归技术测试,时间为7天,并称可以在6月内正式开服,玩家们可以访问官网下载战网客户端并预下载“巫妖王之怒”客户端,技术测试详情见下图。 WordAi WordAI是一个AI驱动的内容重写平台 53 查看详情 以上就是《…

    2026年5月10日 用户投稿
    200
  • 使用 Jupyter Notebook 进行探索性数据分析

    Jupyter Notebook通过单元格实现代码与Markdown结合,支持数据导入(pandas)、清洗(fillna)、探索(matplotlib/seaborn可视化)、统计分析(describe/corr)和特征工程,便于记录与分享分析过程。 Jupyter Notebook 是进行探索性…

    2026年5月10日
    000
  • 如何在HTML中插入表单元素_HTML表单控件与输入类型使用指南

    HTML表单通过标签构建,包含action和method属性定义数据提交目标与方式,常用input类型如text、password、email等适配不同输入需求,配合label、required、placeholder提升可用性,结合textarea、select、button等控件实现完整交互,是…

    2026年5月10日
    000
  • 创建指定大小并填充特定数据的Golang文件教程

    本文将介绍如何使用Golang创建一个指定大小的文件,并用特定数据填充它。我们将使用 `os` 包提供的函数来创建和截断文件,从而实现快速生成大文件的目的。示例代码展示了如何创建一个10MB的文件,并将其填充为全零数据。掌握这些方法,可以方便地在例如日志系统或磁盘队列等场景中,预先创建测试文件或初始…

    2026年5月10日
    000
  • Python命令怎样使用profile分析脚本性能 Python命令性能分析的基础教程

    使用Python的cProfile模块分析脚本性能最直接的方式是通过命令行执行python -m cProfile your_script.py,它会输出每个函数的调用次数、总耗时、累积耗时等关键指标,帮助定位性能瓶颈;为进一步分析,可将结果保存为文件python -m cProfile -o ou…

    2026年5月10日
    000
  • 如何插入查询结果数据_SQL插入Select查询结果方法

    如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法如何插入查询结果数据_SQL插入Select查询结果方法

    使用INSERT INTO…SELECT语句可高效插入数据,通过NOT EXISTS、LEFT JOIN、MERGE语句或唯一约束避免重复;表结构不一致时可通过别名、类型转换、默认值或计算字段处理;结合存储过程可提升可维护性,支持参数化与动态SQL。 将查询结果数据插入到另一个表中,可以…

    2026年5月10日 用户投稿
    000
  • 使用 WebCodecs VideoDecoder 实现精确逐帧回退

    本文档旨在解决在使用 WebCodecs VideoDecoder 进行视频解码时,实现精确逐帧回退的问题。通过比较帧的时间戳与目标帧的时间戳,可以避免渲染中间帧,从而提高用户体验。本文将提供详细的解决方案和示例代码,帮助开发者实现精确的视频帧控制。 在使用 WebCodecs VideoDecod…

    2026年5月10日
    000
  • Discord.py 交互按钮超时与持久化解决方案

    本教程旨在解决Discord.py中交互按钮在一段时间后出现“This Interaction Failed”错误的问题。我们将深入探讨视图(View)的超时机制,并提供通过正确设置timeout参数以及利用bot.add_view()方法实现按钮持久化的具体方案,确保您的机器人交互功能稳定可靠,即…

    2026年5月10日
    000
  • Debian Copilot的社区活跃度如何

    debian copilot是codeberg社区维护的ai助手,旨在为debian用户提供服务。尽管搜索结果中没有直接提供关于debian copilot社区支持活跃度的具体数据,但我们可以通过debian社区的整体活跃度和特点来推断其活跃性。 Debian社区的一般情况: Debian拥有详尽的…

    2026年5月10日
    000
  • Python递归函数追踪与性能考量:以序列打印为例

    本文深入探讨了Python中一种递归打印序列元素的方法,并着重演示了如何通过引入缩进参数来有效追踪递归函数的执行流程和参数变化。通过实际代码示例,文章揭示了递归调用可能带来的潜在性能开销,特别是对调用栈空间的需求,以及Python默认递归深度限制可能导致的错误,为读者提供了理解和优化递归算法的实用见…

    2026年5月10日
    000
  • python中zip函数详解 python多序列压缩zip函数应用场景

    zip函数的应用场景包括:1) 同时遍历多个序列,2) 合并多个列表的数据,3) 数据分析和科学计算中的元素运算,4) 处理csv文件,5) 性能优化。zip函数是一个强大的工具,能够简化代码并提高处理多个序列时的效率。 在Python中,zip函数是一个非常有用的工具,它能够将多个可迭代对象打包成…

    2026年5月10日
    000

发表回复

登录后才能评论
关注微信