人工智能实践 | VGG-16迁移模型

人工智能实践 | VGG-16迁移模型
文章图片
传统的机器学习训练模型需要大量的标签数据 , 而且每一个模型是为了解决特定任务设计的 , 所以当面对全新领域问题就显得无能为力 , 因此采用迁移学习来解决不同领域之间知识迁移问题 , 能达到“举一反三”的作用 , 使学习性能显著提高 。
1
VGG-16结构
VGG-16共包括13个卷积层、3个全连接层、5个池化层 , 卷积层与全连接层具有权重系数 , 而池化层不涉及权重 , 因此这就是VGG-16的来源 , 如图6.36所示 。
人工智能实践 | VGG-16迁移模型
文章图片
■图6.36VGG-16模型结构图
2
迁移学习过程
首先 , 获取想要进行训练的数据集 , 本实验采用1000个分类中的猫和老虎的数据 。 然后 , 自定义设置猫和老虎的体长参数 , 如图6.37所示 。
人工智能实践 | VGG-16迁移模型
文章图片
■图6.37猫和虎的体长数据分布
利用VGG-16训练好的modelparameters,然后保留Convolution和pooling层 , 修改fullyconnected层 , 使其变为可以被训练的两层结构 , 最终输出数字代表猫和老虎的体长 。
self.flatten之前的layers都是不能被训练的.而tf.layers.dense建立的layers是可以被训练的.训练成功之后,再定义一个Saver来保存由tf.layers.dense建立的parameters 。
训练好后的VGG-16的Convolution相当于一个featureextractor , 提取或压缩图片的特征 , 这些特征用作训练regressor , 即softmax 。
至此 , 迁移学习已经完成 , 进行测试 。
3
迁移学习结果
通过传入两张分别为猫和虎的图片 , 应用迁移学习给出各自体长结果 , 如图6.38所示 。
人工智能实践 | VGG-16迁移模型
文章图片
■图6.38迁移学习模型输出结果
实践示例程序参见附录E 。
附录EVGG-16迁移学习
importos
importnumpyasnp
importtensorflowastf
importskimage.io
importskimage.transform
importmatplotlib.pyplotasplt
defload_img(path):
img=skimage.io.imread(path)
img=img/255.0
#print"OriginalImageShape:",img.shape
#wecropimagefromcenter
short_edge=min(img.shape[:2])
yy=int((img.shape[0]-short_edge)/2)
xx=int((img.shape[1]-short_edge)/2)
crop_img=img[yy:yy+short_edge,xx:xx+short_edge]
#resizeto224,224
resized_img=skimage.transform.resize(crop_img,(224,224))[None,:,:,:]#shape[1,224,224,3]
returnresized_img
defload_data:
imgs={'tiger':[],'kittycat':[]}
forkinimgs.keys:
dir='./for_transfer_learning/data/'+k
forfileinos.listdir(dir):
ifnotfile.lower.endswith('.jpg'):
continue
try:
resized_img=load_img(os.path.join(dir,file))
exceptOSError:
continue
imgs[k].append(resized_img)#[1,height,width,depth]*n
iflen(imgs[k])==400:#onlyuse400imgstoreducemymemoryload
break
#fakelengthdatafortigerandcat
tigers_y=np.maximum(20,np.random.randn(len(imgs['tiger']),1)*30+100)
cat_y=np.maximum(10,np.random.randn(len(imgs['kittycat']),1)*8+40)
returnimgs['tiger'],imgs['kittycat'],tigers_y,cat_y
classVgg16:
vgg_mean=[103.939,116.779,123.68]
def__init__(self,vgg16_npy_path=None,restore_from=None):
#pre-trainedparameters
try:
self.data_dict=np.load(vgg16_npy_path,allow_pickle=True,encoding='latin1').item
exceptFileNotFoundError:
print('请下载')
self.tfx=tf.placeholder(tf.float32,[None,224,224,3])
self.tfy=tf.placeholder(tf.float32,[None,1])
#ConvertRGBtoBGR
red,green,blue=tf.split(axis=3,num_or_size_splits=3,value=https://pcff.toutiao.jxnews.com.cn/p/20220718/self.tfx*255.0)