『网盘赛』基于自定义训练模板的文档阴影消除

本文基于文档阴影消除网盘赛,提供了一套PaddlePaddle训练模板。模板实现了定制输出、中断续训、保存最优模型等功能,涵盖数据增强、模型训练等全流程,还支持图像分块提升精度。示例用KIUnet和UNet_3Plus模型,提交结果0.59951,方便用户快速修改实现想法。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

『网盘赛』基于自定义训练模板的文档阴影消除 - 创想鸟

「自定义训练模板」基于文档阴影消除网盘赛的示例

在使用Paddle写项目的时候,相信你也有和我一样的困惑:

高级API的封装太好,需要定制一些函数功能的时候不好添加;低级的API功能又很单一,要实现很多功能需要额外写很多代码。

如果能有一个模板,本身就具有比较完善的功能,只需要修改一部分代码就能快速实现自己的想法就好了。

基于这个想法,我完成了一套训练模型的模板,具有定制输出,模型中断后继续训练,保存最优模型等功能。

大家可以基于这个模板,快速修改并实现自己的想法。

0 项目背景

生活中在使用手机进行文档扫描时,由于角度原因,难免会在照片中留下举手拍摄时遗留的恼人阴影。为了帮助人们解决这样的问题,减少阴影带来的干扰,选手需要通过深度学习技术训练模型,对给定的真实场景下采集得到的带有阴影的图片进行处理,还原图片原本的样子,并最终输出处理后的结果图片。

有阴影图片 无阴影图片

@@##@@            @@##@@            

1 数据介绍

本次比赛最新发布的数据集共包含训练集、A榜测试集、B榜测试集三个部分,其中,训练集共1410个样本(图片编号非连续),A榜测试集共300个样本(图片编号连续),B榜测试集共397个样本;

images为带阴影的源图像数据,gts为无阴影的真值数据;(测试数据集的GT不开放)

images与gts中的图片根据图片名称一一对应。

2 模板已实现功能

1 :每个epoch输出定制信息,使用tqdm进度条实时查看每个epoch运行时间(修正windows终端或异常中断tqdm出现的混乱)

2 :可以直接定义并修改损失函数和模型,并给出示例

3 :可以选择传入模型继续训练或从头训练

4 :在训练完成后,以csv形式保存log文件

5 :自定义评价指标,并根据评价指标结果自动保存最优模型

6 :完成数据增强,模型训练,测试,推理全流程代码

7 :实现输入图像的分块,能有效提升图像类任务精度


参考项目:牛尾扫暗影,虎头开盲盒。三卷网盘赛,榜评0.59168

3 代码实现及讲解

In [ ]

# 解压数据集!unzip  -o data/data125945/shadowRemovalOfDoc.zip!unzip delight_testA_dataset.zip -d data/ >>/dev/null!unzip delight_train_dataset.zip -d data/ >>/dev/null

   In [ ]

# 安装需要的库函数!pip install scikit_image==0.15.0

   

3.1 图像分块


该部分代码用于将图像分割成小块后保存起来,包括翻转,旋转等七种不同组合的数据增强方法。


prepare_data( patch_size=256, stride=200, aug_times=1)

patch_size:切分的正方形块的大小

stride:切分步长,每隔切分步长进行切块

aug_times:增强次数,每个图块生成几张增强图像

CODE 4 第32行 scales = [1] # 对数据进行随机放缩

该参数是一个List,用于对图像进行缩放。

[1,1.2] 表示原图和将图像扩大1.2倍后进行切分,这样增加训练数据量为 [1] 的两倍


In [ ]

# 定义离线数据增强方法def data_augmentation(image,label, mode):    out = np.transpose(image, (1,2,0))    out_label = np.transpose(label, (1,2,0))    if mode == 0:        # original        out = out        out_label = out_label    elif mode == 1:        # flip up and down        out = np.flipud(out)        out_label = np.flipud(out_label)    elif mode == 2:        # rotate counterwise 90 degree        out = np.rot90(out)        out_label = np.rot90(out_label)    elif mode == 3:        # rotate 90 degree and flip up and down        out = np.rot90(out)        out = np.flipud(out)        out_label = np.rot90(out_label)        out_label = np.flipud(out_label)    elif mode == 4:        # rotate 180 degree        out = np.rot90(out, k=2)        out_label = np.rot90(out_label, k=2)    elif mode == 5:        # rotate 180 degree and flip        out = np.rot90(out, k=2)        out = np.flipud(out)        out_label = np.rot90(out_label, k=2)        out_label = np.flipud(out_label)    elif mode == 6:        # rotate 270 degree        out = np.rot90(out, k=3)        out_label = np.rot90(out_label, k=3)    elif mode == 7:        # rotate 270 degree and flip        out = np.rot90(out, k=3)        out = np.flipud(out)        out_label = np.rot90(out_label, k=3)        out_label = np.flipud(out_label)    return  out,out_label

   In [ ]

## 制作分块数据集import cv2import numpy as npimport mathimport glob import osdef Im2Patch(img, win, stride=1):    k = 0    endc = img.shape[0]    endw = img.shape[1]    endh = img.shape[2]    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]    TotalPatNum = patch.shape[1] * patch.shape[2]    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)    for i in range(win):        for j in range(win):            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)            k = k + 1    return Y.reshape([endc, win, win, TotalPatNum])def prepare_data(patch_size, stride, aug_times=1):    '''    该函数用于将图像切成方块,并进行数据增强    patch_size: 图像块的大小,本项目200*200        stride: 步长,每个图像块的间隔     aug_times: 数据增强次数,默认从八种增强方式中选择一种    '''    # train    print('process training data')    scales = [1] # 对数据进行随机放缩    files = glob.glob(os.path.join('data/delight_train_dataset/images', '*.jpg'))    files.sort()    img_folder = 'work/img_patch'    if  not os.path.exists(img_folder):        os.mkdir(img_folder)    label_folder = 'work/label_patch'    if  not os.path.exists(label_folder):        os.mkdir(label_folder)    train_num = 0    for i in range(len(files)):        img = cv2.imread(files[i])        label = cv2.imread(files[i].replace('images','gts'))        h, w, c = img.shape        for k in range(len(scales)):            Img = cv2.resize(img, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)            Label = cv2.resize(label, (int(h*scales[k]), int(w*scales[k])), interpolation=cv2.INTER_CUBIC)            Img = np.transpose(Img, (2,0,1))            Label = np.transpose(Label, (2,0,1))            Img = np.float32(np.clip(Img,0,255))            Label = np.float32(np.clip(Label,0,255))            patches = Im2Patch(Img, win=patch_size, stride=stride)            label_patches = Im2Patch(Label, win=patch_size, stride=stride)            print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], patches.shape[3]*aug_times))            for n in range(patches.shape[3]):                data = patches[:,:,:,n].copy()                label_data = label_patches[:,:,:,n].copy()                            for m in range(aug_times):                    data_aug,label_aug = data_augmentation(data,label_data, np.random.randint(1,8))                    label_name = os.path.join(label_folder,str(train_num)+"_aug_%d" % (m+1)+'.jpg')                    image_name = os.path.join(img_folder,str(train_num)+"_aug_%d" % (m+1)+'.jpg')                                        cv2.imwrite(image_name, data_aug,[int( cv2.IMWRITE_JPEG_QUALITY), 100])                    cv2.imwrite(label_name, label_aug,[int( cv2.IMWRITE_JPEG_QUALITY), 100])                    train_num += 1    print('training set, # samples %dn' % train_num)   ## 生成数据prepare_data( patch_size=256, stride=200, aug_times=1)

   

3.2 重写数据读取类


该部分代码使用了paddle.vision.transforms内置的图像增强方法


In [ ]

# 重写数据读取类import paddleimport paddle.vision.transforms as Timport numpy as npimport globimport cv2# 重写数据读取类class DEshadowDataset(paddle.io.Dataset):    def __init__(self,mode = 'train',is_transforms = False):               label_path_ ='work/label_patch/*.jpg'        jpg_path_ ='work/img_patch/*.jpg'        self.label_list_ = glob.glob(label_path_)                self.jpg_list_ = glob.glob(jpg_path_)                self.is_transforms = is_transforms        self.mode = mode        scale_point = 0.95                self.transforms =T.Compose([            T.Normalize(data_format='HWC',),            T.HueTransform(0.4),            T.SaturationTransform(0.4),            T.HueTransform(0.4),            T.ToTensor(),            ])        # 选择前95%训练,后5%验证        if self.mode == 'train':            self.jpg_list_ = self.jpg_list_[:int(scale_point*len(self.jpg_list_))]            self.label_list_ = self.label_list_[:int(scale_point*len(self.label_list_))]        else:            self.jpg_list_ = self.jpg_list_[int(scale_point*len(self.jpg_list_)):]            self.label_list_ = self.label_list_[int(scale_point*len(self.label_list_)):]    def __getitem__(self, index):        jpg_ = self.jpg_list_[index]        label_ =  self.label_list_[index]        data = cv2.imread(jpg_) # 读取和代码处于同一目录下的 lena.png # 转为 0-1        mask = cv2.imread(label_)        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB) # BGR 2 RGB        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB) # BGR 2 RGB                data = np.uint8(data)        mask = np.uint8(mask)        if self.is_transforms:            data = self.transforms(data)            data = data/255            mask = T.functional.to_tensor(mask)         return  data,mask      def __len__(self):        return len(self.jpg_list_)

   In [13]

# 数据读取及增强可视化import paddle.vision.transforms as Timport matplotlib.pyplot as pltfrom PIL import Imagedataset = DEshadowDataset(mode='train',is_transforms = False )print('=============train dataset=============')img_,mask_ = dataset[3] # mask 始终大于 imgimg = Image.fromarray(img_)mask = Image.fromarray(mask_)#当要保存的图片为灰度图像时,灰度图像的 numpy 尺度是 [1, h, w]。需要将 [1, h, w] 改变为 [h, w]plt.figure(figsize=(12, 6))plt.subplot(1,2,1),plt.xticks([]),plt.yticks([]),plt.imshow(img)plt.subplot(1,2,2),plt.xticks([]),plt.yticks([]),plt.imshow(mask)plt.show()

   

3.3 重写模型


该部分代码定义KIUnet 和UNet_3Plus作为示例。

网络介绍详情请点击链接查看。

In [ ]

# 定义KIUnet import paddleimport paddle.nn as nnimport paddle.nn.functional as Ffrom paddle.nn import initializerdef init_weights(init_type='kaiming'):    if init_type == 'normal':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())    elif init_type == 'xavier':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())    elif init_type == 'kaiming':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal)    else:        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)class kiunet(nn.Layer):    def __init__(self ,in_channels = 3, n_classes =3):        super(kiunet,self).__init__()        self.in_channels = in_channels        self.n_class = n_classes        self.encoder1 = nn.Conv2D(self.in_channels, 16, 3, stride=1, padding=1)  # First Layer GrayScale Image , change to input channels to 3 in case of RGB         self.en1_bn = nn.BatchNorm(16)        self.encoder2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)          self.en2_bn = nn.BatchNorm(32)        self.encoder3=   nn.Conv2D(32, 64, 3, stride=1, padding=1)        self.en3_bn = nn.BatchNorm(64)        self.decoder1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)           self.de1_bn = nn.BatchNorm(32)        self.decoder2 =   nn.Conv2D(32,16, 3, stride=1, padding=1)        self.de2_bn = nn.BatchNorm(16)        self.decoder3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)        self.de3_bn = nn.BatchNorm(8)        self.decoderf1 =   nn.Conv2D(64, 32, 3, stride=1, padding=1)        self.def1_bn = nn.BatchNorm(32)        self.decoderf2=   nn.Conv2D(32, 16, 3, stride=1, padding=1)        self.def2_bn = nn.BatchNorm(16)        self.decoderf3 =   nn.Conv2D(16, 8, 3, stride=1, padding=1)        self.def3_bn = nn.BatchNorm(8)        self.encoderf1 =   nn.Conv2D(in_channels, 16, 3, stride=1, padding=1)  # First Layer GrayScale Image , change to input channels to 3 in case of RGB         self.enf1_bn = nn.BatchNorm(16)        self.encoderf2=   nn.Conv2D(16, 32, 3, stride=1, padding=1)        self.enf2_bn = nn.BatchNorm(32)        self.encoderf3 =   nn.Conv2D(32, 64, 3, stride=1, padding=1)        self.enf3_bn = nn.BatchNorm(64)        self.intere1_1 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.inte1_1bn = nn.BatchNorm(16)        self.intere2_1 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.inte2_1bn = nn.BatchNorm(32)        self.intere3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.inte3_1bn = nn.BatchNorm(64)        self.intere1_2 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.inte1_2bn = nn.BatchNorm(16)        self.intere2_2 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.inte2_2bn = nn.BatchNorm(32)        self.intere3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.inte3_2bn = nn.BatchNorm(64)        self.interd1_1 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.intd1_1bn = nn.BatchNorm(32)        self.interd2_1 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.intd2_1bn = nn.BatchNorm(16)        self.interd3_1 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.intd3_1bn = nn.BatchNorm(64)        self.interd1_2 = nn.Conv2D(32,32,3, stride=1, padding=1)        self.intd1_2bn = nn.BatchNorm(32)        self.interd2_2 = nn.Conv2D(16,16,3, stride=1, padding=1)        self.intd2_2bn = nn.BatchNorm(16)        self.interd3_2 = nn.Conv2D(64,64,3, stride=1, padding=1)        self.intd3_2bn = nn.BatchNorm(64)        self.final = nn.Sequential(            nn.Conv2D(8,self.n_class,1,stride=1,padding=0),            nn.AdaptiveAvgPool2D(output_size=1))        # initialise weights        for m in self.sublayers ():            if isinstance(m, nn.Conv2D):                m.weight_attr = init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')            elif isinstance(m, nn.BatchNorm):                m.param_attr =init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')     def forward(self, x):        # input: c * h * w -> 16 * h/2 * w/2        out = F.relu(self.en1_bn(F.max_pool2d(self.encoder1(x),2,2)))  #U-Net branch        # c * h * w -> 16 * 2h * 2w        out1 = F.relu(self.enf1_bn(F.interpolate(self.encoderf1(x),scale_factor=(2,2),mode ='bicubic'))) #Ki-Net branch        # 16 * h/2 * w/2        tmp = out        # 16 * 2h * 2w -> 16 * h/2 * w/2        out = paddle.add(out,F.interpolate(F.relu(self.inte1_1bn(self.intere1_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic')) #CRFB        # 16 * h/2 * w/2 -> 16 * 2h * 2w        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte1_2bn(self.intere1_2(tmp))),scale_factor=(4,4),mode ='bicubic')) #CRFB                # 16 * h/2 * w/2        u1 = out  #skip conn        # 16 * 2h * 2w        o1 = out1  #skip conn        # 16 * h/2 * w/2 -> 32 * h/4 * w/4        out = F.relu(self.en2_bn(F.max_pool2d(self.encoder2(out),2,2)))        # 16 * 2h * 2w -> 32 * 4h * 4w        out1 = F.relu(self.enf2_bn(F.interpolate(self.encoderf2(out1),scale_factor=(2,2),mode ='bicubic')))        #  32 * h/4 * w/4        tmp = out        # 32 * 4h * 4w -> 32 * h/4 *w/4        out = paddle.add(out,F.interpolate(F.relu(self.inte2_1bn(self.intere2_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))        # 32 * h/4 * w/4 -> 32 *4h *4w        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte2_2bn(self.intere2_2(tmp))),scale_factor=(16,16),mode ='bicubic'))                #  32 * h/4 *w/4        u2 = out        #  32 *4h *4w        o2 = out1                # 32 * h/4 *w/4 -> 64 * h/8 *w/8        out = F.relu(self.en3_bn(F.max_pool2d(self.encoder3(out),2,2)))        # 32 *4h *4w -> 64 * 8h *8w        out1 = F.relu(self.enf3_bn(F.interpolate(self.encoderf3(out1),scale_factor=(2,2),mode ='bicubic')))        #  64 * h/8 *w/8         tmp = out        #  64 * 8h *8w -> 64 * h/8 * w/8        out = paddle.add(out,F.interpolate(F.relu(self.inte3_1bn(self.intere3_1(out1))),scale_factor=(0.015625,0.015625),mode ='bicubic'))        #  64 * h/8 *w/8 -> 64 * 8h * 8w        out1 = paddle.add(out1,F.interpolate(F.relu(self.inte3_2bn(self.intere3_2(tmp))),scale_factor=(64,64),mode ='bicubic'))                ### End of encoder block        ### Start Decoder                # 64 * h/8 * w/8 -> 32 * h/4 * w/4         out = F.relu(self.de1_bn(F.interpolate(self.decoder1(out),scale_factor=(2,2),mode ='bicubic')))  #U-NET        # 64 * 8h * 8w -> 32 * 4h * 4w         out1 = F.relu(self.def1_bn(F.max_pool2d(self.decoderf1(out1),2,2))) #Ki-NET        # 32 * h/4 * w/4         tmp = out        # 32 * 4h * 4w  -> 32 * h/4 * w/4         out = paddle.add(out,F.interpolate(F.relu(self.intd1_1bn(self.interd1_1(out1))),scale_factor=(0.0625,0.0625),mode ='bicubic'))        # 32 * h/4 * w/4  -> 32 * 4h * 4w         out1 = paddle.add(out1,F.interpolate(F.relu(self.intd1_2bn(self.interd1_2(tmp))),scale_factor=(16,16),mode ='bicubic'))                # 32 * h/4 * w/4         out = paddle.add(out,u2)  #skip conn        # 32 * 4h * 4w         out1 = paddle.add(out1,o2)  #skip conn        # 32 * h/4 * w/4 -> 16 * h/2 * w/2         out = F.relu(self.de2_bn(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bicubic')))        # 32 * 4h * 4w  -> 16 * 2h * 2w        out1 = F.relu(self.def2_bn(F.max_pool2d(self.decoderf2(out1),2,2)))        # 16 * h/2 * w/2         tmp = out        # 16 * 2h * 2w -> 16 * h/2 * w/2        out = paddle.add(out,F.interpolate(F.relu(self.intd2_1bn(self.interd2_1(out1))),scale_factor=(0.25,0.25),mode ='bicubic'))        # 16 * h/2 * w/2 -> 16 * 2h * 2w        out1 = paddle.add(out1,F.interpolate(F.relu(self.intd2_2bn(self.interd2_2(tmp))),scale_factor=(4,4),mode ='bicubic'))                # 16 * h/2 * w/2        out = paddle.add(out,u1)        # 16 * 2h * 2w        out1 = paddle.add(out1,o1)        # 16 * h/2 * w/2 -> 8 * h * w        out = F.relu(self.de3_bn(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bicubic')))        # 16 * 2h * 2w -> 8 * h * w        out1 = F.relu(self.def3_bn(F.max_pool2d(self.decoderf3(out1),2,2)))        # 8 * h * w        out = paddle.add(out,out1) # fusion of both branches        # 最后一层用sigmoid激活函数        out = F.sigmoid(self.final(out))  #1*1 conv                return out

   In [ ]

import paddleimport paddle.nn as nnimport paddle.nn.functional as Ffrom paddle.nn import initializerdef init_weights(init_type='kaiming'):    if init_type == 'normal':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.Normal())    elif init_type == 'xavier':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())    elif init_type == 'kaiming':        return paddle.framework.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal)    else:        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)class unetConv2(nn.Layer):    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):        super(unetConv2, self).__init__()        self.n = n        self.ks = ks        self.stride = stride        self.padding = padding        s = stride        p = padding        if is_batchnorm:            for i in range(1, n + 1):                conv = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),                                     nn.BatchNorm(out_size),                                     nn.ReLU(), )                setattr(self, 'conv%d' % i, conv)                in_size = out_size        else:            for i in range(1, n + 1):                conv = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),                                     nn.ReLU(), )                setattr(self, 'conv%d' % i, conv)                in_size = out_size        # initialise the blocks        for m in self.children():            m.weight_attr=init_weights(init_type='kaiming')            m.bias_attr=init_weights(init_type='kaiming')    def forward(self, inputs):        x = inputs        for i in range(1, self.n + 1):            conv = getattr(self, 'conv%d' % i)            x = conv(x)        return x'''    UNet 3+'''class UNet_3Plus(nn.Layer):    def __init__(self, in_channels=3, n_classes=1, is_deconv=True, is_batchnorm=True, end_sigmoid=True):        super(UNet_3Plus, self).__init__()        self.is_deconv = is_deconv        self.in_channels = in_channels        self.is_batchnorm = is_batchnorm        self.end_sigmoid = end_sigmoid        filters = [16, 32, 64, 128, 256]        ## -------------Encoder--------------        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)        self.maxpool1 = nn.MaxPool2D(kernel_size=2)        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)        self.maxpool2 = nn.MaxPool2D(kernel_size=2)        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)        self.maxpool3 = nn.MaxPool2D(kernel_size=2)        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)        self.maxpool4 = nn.MaxPool2D(kernel_size=2)        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)        ## -------------Decoder--------------        self.CatChannels = filters[0]        self.CatBlocks = 5        self.UpChannels = self.CatChannels * self.CatBlocks        '''stage 4d'''        # h2->320*320, hd4->40*40, Pooling 8 times        self.h2_PT_hd4 = nn.MaxPool2D(8, 8, ceil_mode=True)        self.h2_PT_hd4_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_PT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd4_relu = nn.ReLU()        # h2->160*160, hd4->40*40, Pooling 4 times        self.h2_PT_hd4 = nn.MaxPool2D(4, 4, ceil_mode=True)        self.h2_PT_hd4_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)        self.h2_PT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd4_relu = nn.ReLU()        # h3->80*80, hd4->40*40, Pooling 2 times        self.h3_PT_hd4 = nn.MaxPool2D(2, 2, ceil_mode=True)        self.h3_PT_hd4_conv = nn.Conv2D(filters[2], self.CatChannels, 3, padding=1)        self.h3_PT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h3_PT_hd4_relu = nn.ReLU()        # h4->40*40, hd4->40*40, Concatenation        self.h4_Cat_hd4_conv = nn.Conv2D(filters[3], self.CatChannels, 3, padding=1)        self.h4_Cat_hd4_bn = nn.BatchNorm(self.CatChannels)        self.h4_Cat_hd4_relu = nn.ReLU()        # hd5->20*20, hd4->40*40, Upsample 2 times        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd5_UT_hd4_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd4_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd4_relu = nn.ReLU()        # fusion(h2_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)        self.conv4d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn4d_1 = nn.BatchNorm(self.UpChannels)        self.relu4d_1 = nn.ReLU()        '''stage 3d'''        # h2->320*320, hd3->80*80, Pooling 4 times        self.h2_PT_hd3 = nn.MaxPool2D(4, 4, ceil_mode=True)        self.h2_PT_hd3_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_PT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd3_relu = nn.ReLU()        # h2->160*160, hd3->80*80, Pooling 2 times        self.h2_PT_hd3 = nn.MaxPool2D(2, 2, ceil_mode=True)        self.h2_PT_hd3_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)        self.h2_PT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd3_relu = nn.ReLU()        # h3->80*80, hd3->80*80, Concatenation        self.h3_Cat_hd3_conv = nn.Conv2D(filters[2], self.CatChannels, 3, padding=1)        self.h3_Cat_hd3_bn = nn.BatchNorm(self.CatChannels)        self.h3_Cat_hd3_relu = nn.ReLU()        # hd4->40*40, hd4->80*80, Upsample 2 times        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd4_UT_hd3_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd4_UT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.hd4_UT_hd3_relu = nn.ReLU()        # hd5->20*20, hd4->80*80, Upsample 4 times        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14        self.hd5_UT_hd3_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd3_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd3_relu = nn.ReLU()        # fusion(h2_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)        self.conv3d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn3d_1 = nn.BatchNorm(self.UpChannels)        self.relu3d_1 = nn.ReLU()        '''stage 2d '''        # h2->320*320, hd2->160*160, Pooling 2 times        self.h2_PT_hd2 = nn.MaxPool2D(2, 2, ceil_mode=True)        self.h2_PT_hd2_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_PT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.h2_PT_hd2_relu = nn.ReLU()        # h2->160*160, hd2->160*160, Concatenation        self.h2_Cat_hd2_conv = nn.Conv2D(filters[1], self.CatChannels, 3, padding=1)        self.h2_Cat_hd2_bn = nn.BatchNorm(self.CatChannels)        self.h2_Cat_hd2_relu = nn.ReLU()        # hd3->80*80, hd2->160*160, Upsample 2 times        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd3_UT_hd2_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd3_UT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.hd3_UT_hd2_relu = nn.ReLU()        # hd4->40*40, hd2->160*160, Upsample 4 times        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14        self.hd4_UT_hd2_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd4_UT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.hd4_UT_hd2_relu = nn.ReLU()        # hd5->20*20, hd2->160*160, Upsample 8 times        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14        self.hd5_UT_hd2_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd2_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd2_relu = nn.ReLU()        # fusion(h2_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)        self.Conv2D_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn2d_1 = nn.BatchNorm(self.UpChannels)        self.relu2d_1 = nn.ReLU()        '''stage 1d'''        # h2->320*320, hd1->320*320, Concatenation        self.h2_Cat_hd1_conv = nn.Conv2D(filters[0], self.CatChannels, 3, padding=1)        self.h2_Cat_hd1_bn = nn.BatchNorm(self.CatChannels)        self.h2_Cat_hd1_relu = nn.ReLU()        # hd2->160*160, hd1->320*320, Upsample 2 times        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14        self.hd2_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd2_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd2_UT_hd1_relu = nn.ReLU()        # hd3->80*80, hd1->320*320, Upsample 4 times        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14        self.hd3_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd3_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd3_UT_hd1_relu = nn.ReLU()        # hd4->40*40, hd1->320*320, Upsample 8 times        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14        self.hd4_UT_hd1_conv = nn.Conv2D(self.UpChannels, self.CatChannels, 3, padding=1)        self.hd4_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd4_UT_hd1_relu = nn.ReLU()        # hd5->20*20, hd1->320*320, Upsample 16 times        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14        self.hd5_UT_hd1_conv = nn.Conv2D(filters[4], self.CatChannels, 3, padding=1)        self.hd5_UT_hd1_bn = nn.BatchNorm(self.CatChannels)        self.hd5_UT_hd1_relu = nn.ReLU()        # fusion(h2_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)        self.conv1d_1 = nn.Conv2D(self.UpChannels, self.UpChannels, 3, padding=1)  # 16        self.bn1d_1 = nn.BatchNorm(self.UpChannels)        self.relu1d_1 = nn.ReLU()        # output        self.outconv1 = nn.Conv2D(self.UpChannels, n_classes, 3, padding=1)        # initialise weights        for m in self.sublayers ():            if isinstance(m, nn.Conv2D):                m.weight_attr = init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')            elif isinstance(m, nn.BatchNorm):                m.param_attr =init_weights(init_type='kaiming')                m.bias_attr = init_weights(init_type='kaiming')    def forward(self, inputs):        ## -------------Encoder-------------        h2 = self.conv1(inputs)  # h2->320*320*64        h2 = self.maxpool1(h2)        h2 = self.conv2(h2)  # h2->160*160*128        h3 = self.maxpool2(h2)        h3 = self.conv3(h3)  # h3->80*80*256        h4 = self.maxpool3(h3)        h4 = self.conv4(h4)  # h4->40*40*512        h5 = self.maxpool4(h4)        hd5 = self.conv5(h5)  # h5->20*20*1024        ## -------------Decoder-------------        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(            paddle.concat([h2_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4], 1)))) # hd4->40*40*UpChannels        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(            paddle.concat([h2_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3], 1)))) # hd3->80*80*UpChannels        h2_PT_hd2 = self.h2_PT_hd2_relu(self.h2_PT_hd2_bn(self.h2_PT_hd2_conv(self.h2_PT_hd2(h2))))        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))        hd2 = self.relu2d_1(self.bn2d_1(self.Conv2D_1(            paddle.concat([h2_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2], 1)))) # hd2->160*160*UpChannels        h2_Cat_hd1 = self.h2_Cat_hd1_relu(self.h2_Cat_hd1_bn(self.h2_Cat_hd1_conv(h2)))        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(            paddle.concat([h2_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1], 1)))) # hd1->320*320*UpChannels        d1 = self.outconv1(hd1)  # d1->320*320*n_classes        if self.end_sigmoid:            out = F.sigmoid(d1)        else:            out = d1        return out

   In [ ]

# 查看网络结构KIunet = kiunet(in_channels = 3, n_classes =3)Unet3p = UNet_3Plus( in_channels=3, n_classes=3)model = paddle.Model(Unet3p)model.summary((2,3, 256, 256))

   

3.4 训练函数介绍

该部分介绍模板的具体使用:


模型运行

参数 self.is_Train self.PATH

False进行推理从头训练True进行训练读取模型继续训练

注意: self.PATH的True指输入模型路径,例如:当self.is_Train为False时, self.PATH为模型路径,此时不需要训练,直接进行推理生成结果。


损失函数

在work目录下loss.py文件中实现了:

—- TVLoss

AiPPT模板广场 AiPPT模板广场

AiPPT模板广场-PPT模板-word文档模板-excel表格模板

AiPPT模板广场 147 查看详情 AiPPT模板广场

—- SSIMLoss

—- LossVGG19(感知损失)

……

可以根据自己的需求进行修改或添加


最优模型

在Eval方法中定义了score变量,该变量记录验证集上的评价指标。

每个EPOCH会计算一次验证集上的score。

只有当获得更高的score时,才会删除原模型,记录新的最优模型。

注意: 需要自定义评价指标时,只需修改score的计算即可


保存结果

函数会zi’dong在work目录下生成一系列目录:

log

用于保存训练过程的结果

outputs

用于以模型名称为文件夹保存推理的结果

saveModel

用于记录最优的模型,名称为模型名+验证集score分数

注意: 需要自定义修改self.Modelname,用于区分不同模型的生成结果

In [ ]

## 主函数定义 Baseline # 基本假设 干净图像 + 噪声 = 受污染图像# 阴影部分的像素值要低一些,因此,需要对受污染图像的像素值进行增强# 受污染图像 + 阴影 = 干净图像"""@author: xupeng"""from work.loss import *from work.util import batch_psnr_ssimfrom PIL import Imageimport numpy as npimport pandas as pdfrom tqdm import tqdmimport osimport matplotlib.pyplot as pltimport globfrom PIL import Imagefrom skimage.measure.simple_metrics import compare_psnrimport skimageimport paddle.vision.transforms as Timport cv2class Solver(object):    def __init__(self):        self.model = None        self.lr = 1e-4 # 学习率        self.epochs = 10  # 训练的代数        self.batch_size = 4 # 训练批次数量        self.optimizer = None        self.scheduler = None        self.saveSetDir = r'data/delight_testA_dataset/images/' # 测试集地址                self.train_set = None        self.eval_set = None        self.train_loader = None        self.eval_loader = None                self.Modelname = 'Unet3p' # 使用的网络声明(生成的文件会以该声明为区分)                self.evaTOP = 0        self.is_Train = True # #是否对执行模型训练,当为False时,给出self.PATH,直接进行推理        self.PATH = False #False# 用于记录最优模型的名称,需要预训练时,此项不为空    def InitModeAndData(self):        print("---------------trainInit:---------------")        # API文档搜索:vision.transforms 查看数据增强方法        is_transforms = True        self.train_set =  DEshadowDataset(mode='train',is_transforms = is_transforms)        self.eval_set  =  DEshadowDataset(mode='eval',is_transforms = is_transforms)        # 使用paddle.io.DataLoader 定义DataLoader对象用于加载Python生成器产生的数据        self.train_loader = paddle.io.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)        self.eval_loader = paddle.io.DataLoader(self.eval_set, batch_size=self.batch_size, shuffle=False)        self.model = UNet_3Plus( in_channels=3, n_classes=3)        if self.is_Train and self.PATH: # 已经有模型还要训练时,进行二次训练            params = paddle.load(self.PATH)            self.model.set_state_dict(params)            self.evaTOP = float(self.PATH.split(self.Modelname)[-1].replace('.pdparams',''))        self.scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=self.lr, T_max=30, verbose=False)        self.optimizer = paddle.optimizer.Adam(parameters=self.model.parameters(), learning_rate=self.scheduler,beta1=0.5, beta2=0.999)        # 创建文件夹         for item in ['log','outputs','saveModel']:            make_folder = os.path.join('work',item)            if  not os.path.exists(make_folder):                os.mkdir(make_folder)    def train(self,epoch):        print("=======Epoch:{}/{}=======".format(epoch,self.epochs))        self.model.train()        TV_criterion = nTVLoss()    # 用于规范图像噪声        ssim_criterion = SSIMLoss() # 用于对原图进行        lossnet = LossVGG19() # 感知loss        l1_loss = paddle.nn.L1Loss()        Mse_loss = nn.MSELoss()        try: # 使用这种写法(try-except 和 ascii=True),可以避免windows终端或异常中断tqdm出现的混乱            with tqdm(enumerate(self.train_loader),total=len(self.train_loader), ascii=True) as tqdmData:                mean_loss = []                                for idx, (img_train,mask_train) in tqdmData:                    tqdmData.set_description('train')                    img_train = paddle.to_tensor(img_train,dtype="float32")                    mask_train =  paddle.to_tensor(mask_train,dtype="float32")                    # MODEL                    outputs_noise = self.model(img_train) # mask > img  因此需要增加像素值                    mask_noise = mask_train - img_train # 真实图像与阴影的差值                    # 恢复出来的图像                    restore_trian =img_train + outputs_noise                    '''                    # 去均值                    tensor_c = paddle.to_tensor(np.array([123.6800, 116.7790, 103.9390]).astype(np.float32).reshape((1, 3, 1, 1)))                    # 感知损失                    # preceptual loss                    loss_fake_B = lossnet(restore_trian * 255 - tensor_c)                    loss_real_B = lossnet(mask_train * 255 - tensor_c)                    p0 = l1_loss(restore_trian * 255 - tensor_c, mask_train * 255 - tensor_c) * 2                    p1 = l1_loss(loss_fake_B['relu1'], loss_real_B['relu1']) / 2.6                    p2 = l1_loss(loss_fake_B['relu2'], loss_real_B['relu2']) / 4.8                    loss_p = p0 + p1 + p2                    loss = loss_p  + ssim_criterion(restore_trian,mask_train) + 10*l1_loss(outputs_noise,mask_noise)                    '''                    loss = l1_loss(restore_trian,mask_train)                    self.optimizer.clear_grad()                    loss.backward()                    self.optimizer.step()                                        self.scheduler.step() ### 改优化器记得改这个                    mean_loss.append(loss.item())        except KeyboardInterrupt:            tqdmData.close()            os._exit(0)        tqdmData.close()        # 清除中间变量,释放内存        del loss,img_train,mask_train,outputs_noise,mask_noise,restore_trian        paddle.device.cuda.empty_cache()        return {'Mean_trainLoss':np.mean(mean_loss)}        def Eval(self,modelname):                self.model.eval()        temp_eval_psnr ,temp_eval_ssim= [],[]        with paddle.no_grad():            try:                with tqdm(enumerate(self.eval_loader),total=len(self.eval_loader), ascii=True) as tqdmData:                    for idx, (img_eval,mask_eval) in tqdmData:                        tqdmData.set_description(' eval')                        img_eval=  paddle.to_tensor(img_eval,dtype="float32")                        mask_eval = paddle.to_tensor(mask_eval,dtype="float32")                        outputs_denoise = self.model(img_eval) # 模型输出                        outputs_denoise = img_eval + outputs_denoise # 恢复后的图像                        psnr_test,ssim_test = batch_psnr_ssim(outputs_denoise, mask_eval, 1.)                        temp_eval_psnr.append(psnr_test)                        temp_eval_ssim.append(ssim_test)                                except KeyboardInterrupt:                tqdmData.close()                os._exit(0)            tqdmData.close()            paddle.device.cuda.empty_cache()            # 打印test psnr & ssim        # print('eval_psnr:',np.mean(temp_eval_psnr),'eval_ssim:',np.mean(temp_eval_ssim))        # 实现评价指标        score = 0.05*np.mean(temp_eval_psnr)+0.5*np.mean(temp_eval_ssim)        return {'eval_psnr':np.mean(temp_eval_psnr),'eval_ssim':np.mean(temp_eval_ssim),'SCORE':score}        def saveModel(self,trainloss,modelname):                trainLoss = trainloss['SCORE']        if trainLoss < self.evaTOP and self.evaTOP!=0:             return 0        else:            folder = 'work/saveModel/'            self.PATH = folder+modelname+str(trainLoss)+'.pdparams'            removePATH = folder+modelname+str(self.evaTOP)+'.pdparams'            paddle.save(self.model.state_dict(), self.PATH)            if self.evaTOP!=0:                os.remove(removePATH)                        self.evaTOP = trainLoss            return 1            def saveResult(self):        print("---------------saveResult:---------------")        self.model.set_state_dict(paddle.load(self.PATH))        self.model.eval()        paddle.set_grad_enabled(False)        paddle.device.cuda.empty_cache()        data_dir = glob.glob(self.saveSetDir+'*.jpg')        # 创建保存文件夹        make_save_result = os.path.join('work/outputs',self.Modelname)        if  not os.path.exists(make_save_result):            os.mkdir(make_save_result)        saveSet = pd.DataFrame()        tpsnr,tssim = [],[]                for idx,ori_path in enumerate(data_dir):            print(len(data_dir),'|',idx+1,end = 'r',flush = True)                        ori = cv2.imread(ori_path) # W,H,C            ori = cv2.cvtColor(ori, cv2.COLOR_BGR2RGB) # BGR 2 RGB            ori_w,ori_h,ori_c = ori.shape            # normalize_test = T.Normalize(             # [0.610621, 0.5989216, 0.5876396],             # [0.1835931, 0.18701428, 0.19362564],            # data_format='HWC')            #ori = normalize_test(ori)            ori = T.functional.resize(ori,(1024,1024),interpolation = 'bicubic') ###不切块送进去 1024,1024 上采样            # from HWC to CHW ,[0,255] to [0,1]            ori = np.transpose(ori,(2,0,1))/255            ori_img = paddle.to_tensor(ori,dtype="float32")            ori_img = paddle.unsqueeze(ori_img,0) # N C W H            out_noise = self.model(ori_img) #不切块送进去            out_noise = ori_img + out_noise #加上恢复的像素            img_cpu = out_noise.cpu().numpy()            #保存结果            img_cpu = np.squeeze(img_cpu)            img_cpu = np.transpose(img_cpu,(1,2,0)) # C,W,H to W,H,C                        img_cpu = np.clip(img_cpu, 0., 1.)            savepic = np.uint8(img_cpu*255)            savepic = T.functional.resize(savepic,(ori_w,ori_h),interpolation = 'bicubic') ## 采样回原始大小            # 保存路径            savedir = os.path.join(make_save_result,ori_path.split('/')[-1])            savepic = cv2.cvtColor(savepic, cv2.COLOR_RGB2BGR) # BGR            cv2.imwrite(savedir, savepic, [int( cv2.IMWRITE_JPEG_QUALITY), 100])    def run(self):                self.InitModeAndData()                    if self.is_Train:            modelname = self.Modelname #  使用的网络名称            result = pd.DataFrame()                        for epoch in range(1, self.epochs + 1):                                trainloss = self.train(epoch)                evalloss =  self.Eval(modelname)#                Type = self.saveModel(evalloss,modelname)                                type_ = {'Type':Type}                trainloss.update(evalloss)#                trainloss.update(type_)                result = result.append(trainloss,ignore_index=True)                print('Epoch:',epoch,trainloss)                                #self.scheduler.step()            evalloss =  self.Eval(modelname)#            result.to_csv('work/log/' +modelname+str(evalloss['SCORE'])+'.csv')#            self.saveResult()        else:            self.saveResult()def main():    solver = Solver()    solver.run()    if __name__ == '__main__':    main()

   

进入输出路径创建readme.txt文件,输入要求的内容:

训练框架:PaddlePaddle

代码运行环境:V100

是否使用GPU:是

单张图片耗时/s:1

模型大小:45

其他说明:算法参考UNet+++

In [ ]

# %cd /home/aistudio/work/outputs/Unet3p# !zip result.zip *.jpg *.txt

   

4 项目总结

项目实现了一套简单的训练模板。

项目基于文档阴影消除任务创建,提交结果为0.59951

欢迎大家在该基础上修改自己的算法!

『网盘赛』基于自定义训练模板的文档阴影消除 - 创想鸟『网盘赛』基于自定义训练模板的文档阴影消除 - 创想鸟

以上就是『网盘赛』基于自定义训练模板的文档阴影消除的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/315368.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2025年11月5日 07:26:32
下一篇 2025年11月5日 07:27:04

相关推荐

  • Uniapp 中如何不拉伸不裁剪地展示图片?

    灵活展示图片:如何不拉伸不裁剪 在界面设计中,常常需要以原尺寸展示用户上传的图片。本文将介绍一种在 uniapp 框架中实现该功能的简单方法。 对于不同尺寸的图片,可以采用以下处理方式: 极端宽高比:撑满屏幕宽度或高度,再等比缩放居中。非极端宽高比:居中显示,若能撑满则撑满。 然而,如果需要不拉伸不…

    2025年12月24日
    400
  • 如何让小说网站控制台显示乱码,同时网页内容正常显示?

    如何在不影响用户界面的情况下实现控制台乱码? 当在小说网站上下载小说时,大家可能会遇到一个问题:网站上的文本在网页内正常显示,但是在控制台中却是乱码。如何实现此类操作,从而在不影响用户界面(UI)的情况下保持控制台乱码呢? 答案在于使用自定义字体。网站可以通过在服务器端配置自定义字体,并通过在客户端…

    2025年12月24日
    800
  • 如何在地图上轻松创建气泡信息框?

    地图上气泡信息框的巧妙生成 地图上气泡信息框是一种常用的交互功能,它简便易用,能够为用户提供额外信息。本文将探讨如何借助地图库的功能轻松创建这一功能。 利用地图库的原生功能 大多数地图库,如高德地图,都提供了现成的信息窗体和右键菜单功能。这些功能可以通过以下途径实现: 高德地图 JS API 参考文…

    2025年12月24日
    400
  • 如何使用 scroll-behavior 属性实现元素scrollLeft变化时的平滑动画?

    如何实现元素scrollleft变化时的平滑动画效果? 在许多网页应用中,滚动容器的水平滚动条(scrollleft)需要频繁使用。为了让滚动动作更加自然,你希望给scrollleft的变化添加动画效果。 解决方案:scroll-behavior 属性 要实现scrollleft变化时的平滑动画效果…

    2025年12月24日
    000
  • 如何为滚动元素添加平滑过渡,使滚动条滑动时更自然流畅?

    给滚动元素平滑过渡 如何在滚动条属性(scrollleft)发生改变时为元素添加平滑的过渡效果? 解决方案:scroll-behavior 属性 为滚动容器设置 scroll-behavior 属性可以实现平滑滚动。 html 代码: click the button to slide right!…

    2025年12月24日
    500
  • 如何选择元素个数不固定的指定类名子元素?

    灵活选择元素个数不固定的指定类名子元素 在网页布局中,有时需要选择特定类名的子元素,但这些元素的数量并不固定。例如,下面这段 html 代码中,activebar 和 item 元素的数量均不固定: *n *n 如果需要选择第一个 item元素,可以使用 css 选择器 :nth-child()。该…

    2025年12月24日
    200
  • 使用 SVG 如何实现自定义宽度、间距和半径的虚线边框?

    使用 svg 实现自定义虚线边框 如何实现一个具有自定义宽度、间距和半径的虚线边框是一个常见的前端开发问题。传统的解决方案通常涉及使用 border-image 引入切片图片,但是这种方法存在引入外部资源、性能低下的缺点。 为了避免上述问题,可以使用 svg(可缩放矢量图形)来创建纯代码实现。一种方…

    2025年12月24日
    100
  • 如何解决本地图片在使用 mask JS 库时出现的跨域错误?

    如何跨越localhost使用本地图片? 问题: 在本地使用mask js库时,引入本地图片会报跨域错误。 解决方案: 要解决此问题,需要使用本地服务器启动文件,以http或https协议访问图片,而不是使用file://协议。例如: python -m http.server 8000 然后,可以…

    2025年12月24日
    200
  • 旋转长方形后,如何计算其相对于画布左上角的轴距?

    绘制长方形并旋转,计算旋转后轴距 在拥有 1920×1080 画布中,放置一个宽高为 200×20 的长方形,其坐标位于 (100, 100)。当以任意角度旋转长方形时,如何计算它相对于画布左上角的 x、y 轴距? 以下代码提供了一个计算旋转后长方形轴距的解决方案: const x = 200;co…

    2025年12月24日
    000
  • 旋转长方形后,如何计算它与画布左上角的xy轴距?

    旋转后长方形在画布上的xy轴距计算 在画布中添加一个长方形,并将其旋转任意角度,如何计算旋转后的长方形与画布左上角之间的xy轴距? 问题分解: 要计算旋转后长方形的xy轴距,需要考虑旋转对长方形宽高和位置的影响。首先,旋转会改变长方形的长和宽,其次,旋转会改变长方形的中心点位置。 求解方法: 计算旋…

    2025年12月24日
    000
  • 旋转长方形后如何计算其在画布上的轴距?

    旋转长方形后计算轴距 假设长方形的宽、高分别为 200 和 20,初始坐标为 (100, 100),我们将它旋转一个任意角度。根据旋转矩阵公式,旋转后的新坐标 (x’, y’) 可以通过以下公式计算: x’ = x * cos(θ) – y * sin(θ)y’ = x * …

    2025年12月24日
    000
  • 如何让“元素跟随文本高度,而不是撑高父容器?

    如何让 元素跟随文本高度,而不是撑高父容器 在页面布局中,经常遇到父容器高度被子元素撑开的问题。在图例所示的案例中,父容器被较高的图片撑开,而文本的高度没有被考虑。本问答将提供纯css解决方案,让图片跟随文本高度,确保父容器的高度不会被图片影响。 解决方法 为了解决这个问题,需要将图片从文档流中脱离…

    2025年12月24日
    000
  • 如何计算旋转后长方形在画布上的轴距?

    旋转后长方形与画布轴距计算 在给定的画布中,有一个长方形,在随机旋转一定角度后,如何计算其在画布上的轴距,即距离左上角的距离? 以下提供一种计算长方形相对于画布左上角的新轴距的方法: const x = 200; // 初始 x 坐标const y = 90; // 初始 y 坐标const w =…

    2025年12月24日
    200
  • CSS元素设置em和transition后,为何载入页面无放大效果?

    css元素设置em和transition后,为何载入无放大效果 很多开发者在设置了em和transition后,却发现元素载入页面时无放大效果。本文将解答这一问题。 原问题:在视频演示中,将元素设置如下,载入页面会有放大效果。然而,在个人尝试中,并未出现该效果。这是由于macos和windows系统…

    2025年12月24日
    200
  • 为什么 CSS mask 属性未请求指定图片?

    解决 css mask 属性未请求图片的问题 在使用 css mask 属性时,指定了图片地址,但网络面板显示未请求获取该图片,这可能是由于浏览器兼容性问题造成的。 问题 如下代码所示: 立即学习“前端免费学习笔记(深入)”; icon [data-icon=”cloud”] { –icon-cl…

    2025年12月24日
    200
  • 如何利用 CSS 选中激活标签并影响相邻元素的样式?

    如何利用 css 选中激活标签并影响相邻元素? 为了实现激活标签影响相邻元素的样式需求,可以通过 :has 选择器来实现。以下是如何具体操作: 对于激活标签相邻后的元素,可以在 css 中使用以下代码进行设置: li:has(+li.active) { border-radius: 0 0 10px…

    2025年12月24日
    100
  • 如何模拟Windows 10 设置界面中的鼠标悬浮放大效果?

    win10设置界面的鼠标移动显示周边的样式(探照灯效果)的实现方式 在windows设置界面的鼠标悬浮效果中,光标周围会显示一个放大区域。在前端开发中,可以通过多种方式实现类似的效果。 使用css 使用css的transform和box-shadow属性。通过将transform: scale(1.…

    2025年12月24日
    200
  • 如何用HTML/JS实现Windows 10设置界面鼠标移动探照灯效果?

    Win10设置界面中的鼠标移动探照灯效果实现指南 想要在前端开发中实现类似于Windows 10设置界面的鼠标移动探照灯效果,有两种解决方案:CSS 和 HTML/JS 组合。 CSS 实现 不幸的是,仅使用CSS无法完全实现该效果。 立即学习“前端免费学习笔记(深入)”; HTML/JS 实现 要…

    2025年12月24日
    000
  • 如何计算旋转后的长方形在画布上的 XY 轴距?

    旋转长方形后计算其画布xy轴距 在创建的画布上添加了一个长方形,并提供其宽、高和初始坐标。为了视觉化旋转效果,还提供了一些旋转特定角度后的图片。 问题是如何计算任意角度旋转后,这个长方形的xy轴距。这涉及到使用三角学来计算旋转后的坐标。 以下是一个 javascript 代码示例,用于计算旋转后长方…

    2025年12月24日
    000
  • 为什么我的 Safari 自定义样式表在百度页面上失效了?

    为什么在 Safari 中自定义样式表未能正常工作? 在 Safari 的偏好设置中设置自定义样式表后,您对其进行测试却发现效果不同。在您自己的网页中,样式有效,而在百度页面中却失效。 造成这种情况的原因是,第一个访问的项目使用了文件协议,可以访问本地目录中的图片文件。而第二个访问的百度使用了 ht…

    2025年12月24日
    000

发表回复

登录后才能评论
关注微信