

Tensorflow - Complete Code Example for a CNN Task

import tensorflow as tf

def train_mnist_conv():
    # YOUR CODE STARTS HERE
    class myCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            # Stop training once validation accuracy exceeds 98%
            if logs is not None and logs.get('val_accuracy', 0) > 0.98:
                print('\nReached 98% validation accuracy so cancelling training!')
                self.model.stop_training = True

    my_callback = myCallback()
    # YOUR CODE ENDS HERE

    mnist = tf.keras.datasets.mnist
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # YOUR CODE STARTS HERE
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    # YOUR CODE ENDS HERE
    X_train = X_train.reshape(60000, 28, 28, 1)
    X_test = X_test.reshape(10000, 28, 28, 1)

    model = tf.keras.models.Sequential([
            # YOUR CODE STARTS HERE
            tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),  # convolution layer with a ReLU activation
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2),  # max pooling downsizes the feature map
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(units = 128, activation = 'relu'),
            tf.keras.layers.Dense(units = 10, activation = 'softmax')
  
            # YOUR CODE ENDS HERE
    ])

    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # model fitting
    history = model.fit(
        # YOUR CODE STARTS HERE
         X_train, y_train, epochs=20, validation_data=(X_test, y_test), callbacks=[my_callback]
        # YOUR CODE ENDS HERE
    )
    # model fitting
    return history.epoch, history.history['accuracy'][-1]
_, _ = train_mnist_conv()  # assigning to _, _ means we don't keep the returned values
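
As a quick check (not part of the original exercise), you can build the same stack of layers on its own and call model.summary() to see how each step reshapes the data: the 3x3 convolution turns the 28x28x1 image into a 26x26x64 feature map, max pooling halves it to 13x13x64, and Flatten hands a 10,816-dimensional vector to the dense (ANN) part. A minimal sketch (the name cnn_layers is mine):

import tensorflow as tf

cnn_layers = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),  # output: (26, 26, 64)
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2),  # output: (13, 13, 64)
    tf.keras.layers.Flatten(),                                  # output: (10816,)
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=10, activation='softmax')
])
cnn_layers.summary()  # prints each layer's output shape and parameter count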

 

As the code above shows, a CNN is essentially an ANN with convolution and pooling (downsampling) layers added in front.
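
To make that concrete, here is a minimal sketch (my own, not from the original post) of the plain ANN counterpart: just Flatten plus the same Dense layers, with no Conv2D or MaxPooling2D in front.

import tensorflow as tf

ann_model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),      # no convolution/pooling: flatten the raw pixels directly
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=10, activation='softmax')
])
ann_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

Everything the CNN adds on top of this is the Conv2D + MaxPooling2D front end that extracts and downsamples features before they reach the dense layers.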

 