成都网站建设设计

将想法与焦点和您一起共享

如何对比pytorch的ReLU和自定义的classGuidedBackpropReLU

这篇文章将为大家详细讲解有关如何对比pytorch的ReLU和自定义的class GuidedBackpropReLU,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。

牟平网站建设公司成都创新互联公司,牟平网站设计制作,有大型网站制作公司丰富经验。已为牟平成百上千家提供企业网站建设服务。企业网站搭建\外贸网站制作要多少钱,请找那个售后服务好的牟平做网站的公司定做!

总结说明:GuidedBackpropReLU和ReLU的区别很明显,在反向传播时候,仅传播从上一层接受到的正数梯度,将负数梯度直接置零.而ReLU则全部接受上一层的梯度,不论该梯度值是正数还是负数.

实验代码展示(实验中在第58和59行,将coefficient设置成-1和+1会出现不同的效果):

import torchfrom torch.autograd import Functionclass GuidedBackpropReLU(Function):'''特殊的ReLU,区别在于反向传播时候只考虑大于零的输入和大于零的梯度'''
    # @staticmethod# def forward(ctx, input_img):  # torch.Size([1, 64, 112, 112])#     positive_mask = (input_img > 0).type_as(input_img)  # torch.Size([1, 64, 112, 112])#     # output = torch.addcmul(torch.zeros(input_img.size()).type_as(input_img), input_img, positive_mask)#     output = input_img * positive_mask  # 这行代码和上一行的功能相同#     ctx.save_for_backward(input_img, output)#     return output  # torch.Size([1, 64, 112, 112])# 上部分定义的函数功能和以下定义的函数一致@staticmethoddef forward(ctx, input_img):  # torch.Size([1, 64, 112, 112])output = torch.clamp(input_img, min=0.0)# print('函数中的输入张量requires_grad',input_img.requires_grad)ctx.save_for_backward(input_img, output)return output  # torch.Size([1, 64, 112, 112])@staticmethoddef backward(ctx, grad_output):  # torch.Size([1, 2048, 7, 7])input_img, output = ctx.saved_tensors  # torch.Size([1, 2048, 7, 7]) torch.Size([1, 2048, 7, 7])# grad_input = None  # 这行代码没作用positive_mask_1 = (input_img > 0).type_as(grad_output)  # torch.Size([1, 2048, 7, 7])  输入的特征大于零positive_mask_2 = (grad_output > 0).type_as(grad_output)  # torch.Size([1, 2048, 7, 7])  梯度大于零# grad_input = torch.addcmul(#                             torch.zeros(input_img.size()).type_as(input_img),#                             torch.addcmul(#                                             torch.zeros(input_img.size()).type_as(input_img), #                                             grad_output,#                                             positive_mask_1#                             ), #                             positive_mask_2# )grad_input = grad_output * positive_mask_1 * positive_mask_2  # 这行代码的作用和上一行代码相同return grad_input


torch.manual_seed(seed=20200910)size = (3,5)input_data_1 = input = torch.randn(*size, requires_grad=True)torch.manual_seed(seed=20200910)input_data_2 = input = torch.randn(*size, requires_grad=True)torch.manual_seed(seed=20200910)input_data_3 = input = torch.randn(*size, requires_grad=True)print('这三个输入数据的维度分别是:', input_data_1.shape, input_data_2.shape, input_data_3.shape)# print(input_data_1)# print(input_data_2)# print(input_data_3)coefficient = -1.0# coefficient = 1.0loss_1 = coefficient * torch.sum(torch.nn.ReLU()(input_data_1))loss_2 = coefficient * torch.sum(torch.nn.functional.relu(input_data_2))loss_3 = coefficient * torch.sum(GuidedBackpropReLU.apply(input_data_3))loss_1.backward()loss_2.backward()loss_3.backward()print(loss_1, loss_2, loss_3)print(loss_1.item(), loss_2.item(), loss_3.item())print('三个损失值是否相等', loss_1.item() == loss_2.item() == loss_3.item())print('简略打印三个梯度信息...')print(input_data_1.grad)print(input_data_2.grad)print(input_data_3.grad)print('这三个梯度的维度分别是:', input_data_1.grad.shape, input_data_2.grad.shape, input_data_3.grad.shape)print('检查这三个梯度是否两两相等...')print(torch.equal(input_data_1.grad, input_data_2.grad))print(torch.equal(input_data_1.grad, input_data_3.grad))print(torch.equal(input_data_2.grad, input_data_3.grad))

控制台输出(#58 coefficient = -1.0):

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 842 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch2_2_0
(ssd4pytorch2_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch2_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2021.1.502429796\pythonFiles\lib\python\debugpy\launcher' '62123' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\testReLU.py'
这三个输入数据的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])
tensor(-7.1553, grad_fn=) tensor(-7.1553, grad_fn=) tensor(-7.1553, grad_fn=)      
-7.155285835266113 -7.155285835266113 -7.155285835266113
三个损失值是否相等 True
简略打印三个梯度信息...
tensor([[-1.,  0., -1.,  0.,  0.],
        [-1., -1.,  0., -1., -1.],
        [-1., -1.,  0.,  0.,  0.]])
tensor([[-1.,  0., -1.,  0.,  0.],
        [-1., -1.,  0., -1., -1.],
        [-1., -1.,  0.,  0.,  0.]])
tensor([[-0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0.]])
这三个梯度的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])
检查这三个梯度是否两两相等...
True
False
False
(ssd4pytorch2_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>

控制台输出(#59 coefficient = 1.0):

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 846 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch2_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2021.1.502429796\pythonFiles\lib\python\debugpy\launcher' '62135' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\testReLU.py'
这三个输入数据的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])
tensor(7.1553, grad_fn=) tensor(7.1553, grad_fn=) tensor(7.1553, grad_fn=)
7.155285835266113 7.155285835266113 7.155285835266113
三个损失值是否相等 True
简略打印三个梯度信息...
tensor([[1., 0., 1., 0., 0.],
        [1., 1., 0., 1., 1.],
        [1., 1., 0., 0., 0.]])
tensor([[1., 0., 1., 0., 0.],
        [1., 1., 0., 1., 1.],
        [1., 1., 0., 0., 0.]])
tensor([[1., 0., 1., 0., 0.],
        [1., 1., 0., 1., 1.],
        [1., 1., 0., 0., 0.]])
这三个梯度的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])
检查这三个梯度是否两两相等...
True
True
True
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch2_2_0
(ssd4pytorch2_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>

关于如何对比pytorch的ReLU和自定义的class GuidedBackpropReLU就分享到这里了,希望以上内容可以对大家有一定的帮助,可以学到更多知识。如果觉得文章不错,可以把它分享出去让更多的人看到。


新闻标题:如何对比pytorch的ReLU和自定义的classGuidedBackpropReLU
标题链接:http://chengdu.cdxwcx.cn/article/gsojdo.html