Pytorch抽取vgg各层并进行定制化处理的方法-创新互联
工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式
创新互联建站是一家从事企业网站建设、做网站、网站制作、行业门户网站建设、网页设计制作的专业网络公司,拥有经验丰富的网站建设工程师和网页设计人员,具备各种规模与类型网站建设的实力,在网站建设领域树立了自己独特的设计风格。自公司成立以来曾独立设计制作的站点数千家。def define_vgg(vgg,input_channels,endlayer,use_maxpool=False): vgg_ad = copy.deepcopy(vgg) model = nn.Sequential() i = 0 for layer in list(vgg_ad.features): if i > endlayer: break if isinstance(layer, nn.Conv2d) and i is 0: name = "conv_" + str(i) layer = nn.Conv2d(input_channels, layer.out_channels, layer.kernel_size, stride = layer.stride, padding=layer.padding) model.add_module(name, layer) if isinstance(layer, nn.Conv2d): name = "conv_" + str(i) model.add_module(name, layer) if isinstance(layer, nn.ReLU): name = "leakyrelu_" + str(i) layer = nn.LeakyReLU(inplace=True) model.add_module(name, layer) if isinstance(layer, nn.MaxPool2d): name = "pool_" + str(i) if use_maxpool: model.add_module(name, layer) else: avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding) model.add_module(name, avgpool) i += 1 return model
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。
本文标题:Pytorch抽取vgg各层并进行定制化处理的方法-创新互联
当前URL:http://azwzsj.com/article/iedid.html