大橙子网站建设,新征程启航
为企业提供网站建设、域名注册、服务器等服务
pytorch怎样实现特征图可视化,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。
站在用户的角度思考问题,与客户深入沟通,找到师宗网站设计与师宗网站推广的解决方案,凭借多年的经验,让设计与互联网技术结合,创造个性化、用户体验好的作品,建站类型包括:成都网站制作、成都做网站、企业官网、英文网站、手机端网站、网站推广、空间域名、雅安服务器托管、企业邮箱。业务覆盖师宗地区。
是不是要这样的效果技术要点 1.选择一层网络,将图片的tensor放进去 2.将网络的输出plt.imshow
代码可直接复制使用,需要改的就是你的图片位置
import torchfrom torchvision import models, transformsfrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as npimport scipy.misc plt.rcParams['font.sans-serif']=['STSong']import torchvision.models as models model = models.alexnet(pretrained=True)#1.模型查看# print(model)#可以看出网络一共有3层,两个Sequential()+avgpool# model_features = list(model.children())# print(model_features[0][3])#取第0层Sequential()中的第四层# for index,layer in enumerate(model_features[0]):# print(layer)#2. 导入数据# 以RGB格式打开图像# Pytorch DataLoader就是使用PIL所读取的图像格式# 建议就用这种方法读取图像,当读入灰度图像时convert('')def get_image_info(image_dir):image_info = Image.open(image_dir).convert('RGB')#是一幅图片# 数据预处理方法image_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])image_info = image_transform(image_info)#torch.Size([3, 224, 224])image_info = image_info.unsqueeze(0)#torch.Size([1, 3, 224, 224])因为model的输入要求是4维,所以变成4维return image_info#变成tensor数据#2. 获取第k层的特征图''' args: k:定义提取第几层的feature map x:图片的tensor model_layer:是一个Sequential()特征层 '''def get_k_layer_feature_map(model_layer, k, x):with torch.no_grad():for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道if k == index:return x# 可视化特征图def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55]) # feature_map[2].shape out of boundsfeature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])feature_map_num = feature_map.shape[0]#返回通道数row_num = np.ceil(np.sqrt(feature_map_num))#8plt.figure()for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出plt.subplot(row_num, row_num, index)plt.imshow(feature_map[index - 1], cmap='gray')#feature_map[0].shape=torch.Size([55, 55])plt.axis('off')scipy.misc.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])plt.show()if __name__ == '__main__':image_dir = r"car_logol.png"# 定义提取第几层的feature mapk = 0image_info = get_image_info(image_dir)model = models.alexnet(pretrained=True)model_layer= list(model.children())model_layer=model_layer[0]#这里选择model的第一个Sequential()feature_map = get_k_layer_feature_map(model_layer, k, image_info)show_feature_map(feature_map)
彩色图显示
#在show_feature_map函数中加上一句,tensor数据变成Img的操作image_PIL=transforms.ToPILImage()(feature_map[index - 1])
如果对于matplotlib不熟练
matplotlib绘制多个子图(汉字标题,XY轴标签)& PIL.Image 11行读取文件夹中照片
看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注创新互联行业资讯频道,感谢您对创新互联的支持。