解决Keras中循环使用K.ctc_decode内存不释放的问题
如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。
data=generator(...) model=init_model(...) foriinrange(NUM): x,y=next(data) _y=model.predict(x) shape=_y.shape input_length=np.ones(shape[0])*shape[1] ctc_decode=K.ctc_decode(_y,input_length)[0][0] out=K.get_value(ctc_decode)
原因
每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。
PS:有资料说是由于get_value导致的,其中也给出了解决方案。
但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?
解决方案
通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)
data=generator(...) model=init_model(...) x=model.output#[batch_sizes,series_length,classes] input_length=KL.Input(batch_shape=[None],dtype='int32') ctc_decode=K.ctc_decode(x,input_length=input_length*K.shape(x)[1]) decode=K.function([model.input,input_length],[ctc_decode[0][0]]) foriinrange(NUM): _x,_y=next(data) out=decode([_x,np.ones(1)])
补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加
该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。
测试方法:在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。
fromkerasimportbackendasK fromkeras.layersimportLambda,Input fromkerasimportModel fromtensorflow.python.opsimportctc_opsasctc importtensorflowastf fromkeras.layersimportLayer classCTC_Batch_Cost(): ''' 用于计算CTCloss ''' defctc_lambda_func(self,args): """RunsCTClossalgorithmoneachbatchelement. #Arguments y_true:tensor`(samples,max_string_length)`真实标签 y_pred:tensor`(samples,time_steps,num_categories)`预测前未经过softmax的向量 input_length:tensor`(samples,1)`每一个y_pred的长度 label_length:tensor`(samples,1)`每一个y_true的长度 #Returns Tensorwithshape(samples,1)包含了每一个样本的ctcloss """ y_true,y_pred,input_length,label_length=args #y_pred=y_pred[:,:,:] #y_pred=y_pred[:,2:,:] returnself.ctc_batch_cost(y_true,y_pred,input_length,label_length) def__call__(self,args): ''' ctc_decode每次创建会生成一个节点,这里参考了上面的内容 将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题 ''' y_true=Input(shape=(None,)) y_pred=Input(shape=(None,None)) input_length=Input(shape=(1,)) label_length=Input(shape=(1,)) lamd=Lambda(self.ctc_lambda_func,output_shape=(1,),name='ctc')([y_true,y_pred,input_length,label_length]) model=Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc") #returnLambda(self.ctc_lambda_func,output_shape=(1,),name='ctc')(args) returnmodel(args) defctc_batch_cost(self,y_true,y_pred,input_length,label_length): """RunsCTClossalgorithmoneachbatchelement. #Arguments y_true:tensor`(samples,max_string_length)` containingthetruthlabels. y_pred:tensor`(samples,time_steps,num_categories)` containingtheprediction,oroutputofthesoftmax. input_length:tensor`(samples,1)`containingthesequencelengthfor eachbatchitemin`y_pred`. label_length:tensor`(samples,1)`containingthesequencelengthfor eachbatchitemin`y_true`. #Returns Tensorwithshape(samples,1)containingthe CTClossofeachelement. """ label_length=tf.to_int32(tf.squeeze(label_length,axis=-1)) input_length=tf.to_int32(tf.squeeze(input_length,axis=-1)) sparse_labels=tf.to_int32(K.ctc_label_dense_to_sparse(y_true,label_length)) y_pred=tf.log(tf.transpose(y_pred,perm=[1,0,2])+1e-7) #注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch returntf.expand_dims(ctc.ctc_loss(inputs=y_pred, labels=sparse_labels, sequence_length=input_length, ignore_longer_outputs_than_inputs=True),1) #使用方法:(注意shape) loss_out=CTC_Batch_Cost()([y_true,y_pred,audio_length,label_length])
fromkerasimportbackendasK fromkeras.layersimportLambda,Input fromkerasimportModel fromtensorflow.python.opsimportctc_opsasctc importtensorflowastf fromkeras.layersimportLayer classCTCDecodeLayer(Layer): def__init__(self,**kwargs): super().__init__(**kwargs) def_ctc_decode(self,args): base_pred,in_len=args in_len=K.squeeze(in_len,axis=-1) r=K.ctc_decode(base_pred,in_len,greedy=True,beam_width=100,top_paths=1) r1=r[0][0] prob=r[1][0] return[r1,prob] defcall(self,inputs,**kwargs): returnself._ctc_decode(inputs) defcompute_output_shape(self,input_shape): return[(None,None),(1,)] classCTCDecode(): '''用与CTC解码,得到真实语音序列 2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生 ''' def__init__(self): base_pred=Input(shape=[None,None],name="pred") feature_len=Input(shape=[1,],name="feature_len") r1,prob=CTCDecodeLayer()([base_pred,feature_len]) self.model=Model([base_pred,feature_len],[r1,prob]) pass defctc_decode(self,base_pred,in_len,return_prob=False): ''' :parambase_pred:[sample,timestamp,vector] :paramin_len:[sample,1] :return: ''' result,prob=self.model.predict([base_pred,in_len]) ifreturn_prob: returnresult,prob returnresult def__call__(self,base_pred,in_len,return_prob=False): returnself.ctc_decode(base_pred,in_len,return_prob) #使用方法:(注意shape,是batch级的输入) ctc_decoder=CTCDecode() ctc_decoder.ctc_decode(result,feature_len)
以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。