keras K.function获取某层的输出操作
如下所示:
fromkerasimportbackendasK fromkeras.modelsimportload_model models=load_model('models.hdf5') image=r'image.png' images=cv2.imread(r'image.png') image_arr=process_image(image,(224,224,3)) image_arr=np.expand_dims(image_arr,axis=0) layer_1=K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output]) f1=layer_1([image_arr])[0]
加载训练好并保存的网络模型
加载数据(图像),并将数据处理成array形式
指定输出层
将处理后的数据输入,然后获取输出
其中,K.function有两种不同的写法:
1.获取名为layer_name的层的输出
layer_1=K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output]) #指定输出层的名称
2.获取第n层的输出
layer_1=K.function([model.get_input_at(0)],[model.layers[5].output]) #指定输出层的序号(层号从0开始)
另外,需要注意的是,书写不规范会导致报错:
报错:
TypeError:inputstoaTensorFlowbackendfunctionshouldbealistortuple
将该句:
f1=layer_1(image_arr)[0]
修改为:
f1=layer_1([image_arr])[0]
补充知识:keras.backend.function()
如下所示:
deffunction(inputs,outputs,updates=None,**kwargs): """InstantiatesaKerasfunction. Arguments: inputs:Listofplaceholdertensors. outputs:Listofoutputtensors. updates:Listofupdateops. **kwargs:Passedto`tf.Session.run`. Returns: OutputvaluesasNumpyarrays. Raises: ValueError:ifinvalidkwargsarepassedin. """ ifkwargs: forkeyinkwargs: if(keynotintf_inspect.getargspec(session_module.Session.run)[0]and keynotintf_inspect.getargspec(Function.__init__)[0]): msg=('Invalidargument"%s"passedtoK.functionwithTensorflow' 'backend')%key raiseValueError(msg) returnFunction(inputs,outputs,updates=updates,**kwargs)
这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。
我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。
classFunction(object): """Runsacomputationgraph. Arguments: inputs:Feedplaceholderstothecomputationgraph. outputs:Outputtensorstofetch. updates:Additionalupdateopstoberunatfunctioncall. name:anametohelpusersidentifywhatthisfunctiondoes. """ def__init__(self,inputs,outputs,updates=None,name=None, **session_kwargs): updates=updatesor[] ifnotisinstance(inputs,(list,tuple)): raiseTypeError('`inputs`toaTensorFlowbackendfunction' 'shouldbealistortuple.') ifnotisinstance(outputs,(list,tuple)): raiseTypeError('`outputs`ofaTensorFlowbackendfunction' 'shouldbealistortuple.') ifnotisinstance(updates,(list,tuple)): raiseTypeError('`updates`inaTensorFlowbackendfunction' 'shouldbealistortuple.') self.inputs=list(inputs) self.outputs=list(outputs) withops.control_dependencies(self.outputs): updates_ops=[] forupdateinupdates: ifisinstance(update,tuple): p,new_p=update updates_ops.append(state_ops.assign(p,new_p)) else: #assumedalreadyanop updates_ops.append(update) self.updates_op=control_flow_ops.group(*updates_ops) self.name=name self.session_kwargs=session_kwargs def__call__(self,inputs): ifnotisinstance(inputs,(list,tuple)): raiseTypeError('`inputs`shouldbealistortuple.') feed_dict={} fortensor,valueinzip(self.inputs,inputs): ifis_sparse(tensor): sparse_coo=value.tocoo() indices=np.concatenate((np.expand_dims(sparse_coo.row,1), np.expand_dims(sparse_coo.col,1)),1) value=(indices,sparse_coo.data,sparse_coo.shape) feed_dict[tensor]=value session=get_session() updated=session.run( self.outputs+[self.updates_op], feed_dict=feed_dict, **self.session_kwargs) returnupdated[:len(self.outputs)]
所以,function函数利用我们之前已经创建好的comuptationgraph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。
以上这篇kerasK.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。