본문 바로가기

Deep Learning/Tensorflow

Tensorflow - callback class를 이용해서, 원하는 조건이 되면 학습을 멈추게 하는 코드

반응형
callback class를 이용해서, 원하는 조건이 되면 학습을 멈추게 하는 코드

 

 

 

텐서플로우의 콜백 클래스를 상속해서 만든다.


함수 on_epoch_end 함수 안에,

 

epoch가 끝날때마다 하고싶은 작업을, 코딩을 해주면 된다.

 

ealry_stop과 함께 사용 가능하다.

 

 

class myCallback(tf.keras.callbacks.Callback) :
  def on_epoch_end(self, epoch, logs={}) :
    if logs.get('val_accuracy') > 0.87 :
      print('\n Validation accuracy가 87% 넘으므로, 학습을 멈추게 합니다.')
      self.model.stop_training = True
my_callback = myCallback()
def build_model() :
  model = tf.keras.models.Sequential()
  model.add( tf.keras.layers.Flatten() )
  model.add( tf.keras.layers.Dense( 128, 'relu'))
  model.add( tf.keras.layers.Dense( 10, 'softmax'))
  model.compile('adam', loss = 'sparse_categorical_crossentropy', metrics=['accuracy'])
  return model
model = build_model()
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
epoch_history = model.fit( training_images, training_labels, epochs=1000, validation_split = 0.2, callbacks = [ my_callback, early_stop ] )

 

 

 

 

 

 

반응형