CNN 모델과 혈액 도말 이미지를 활용해서 말라리아 감염 여부를 예측했다.
CNN 관련 용어를 이해하고 네트워크 구조를 코드로 작성했다.
라이브러리 로드
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
수치 연산, 데이터 핸들링, 시각화 라이브러리 로드
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense
from tensorflow.keras.callbacks import EarlyStopping
모델 구성, 학습에 필요한 라이브러리와 함수 로드
이미지 확인
이미지는 https://data.lhncbc.nlm.nih.gov/public/Malaria/cell_images.zip 파일을 다운로드, 압축 해제하여 사용한다.
말라리아 감염 이미지는 Parasitized 폴더에, 비감염 이미지는 Uninfected 폴더에 저장되었다.
import glob
upics = glob.glob('./cell_images/Uninfected/*.png')
apics = glob.glob('./cell_images/Parasitized/*.png')
아래에서 감염/비감염 각각 사진 한 장씩 확인
upics_0 = upics[0]
upics_0_img = plt.imread(upics_0)
plt.imshow(upics_0_img)
감염되지 않은 사진 중 첫 번째 사진
apics_0 = apics[0]
apics_0_img = plt.imread(apics_0)
plt.imshow(apics_0_img)
감염된 사진 중 첫 번째 사진
OpenCV를 사용하여 크기와 함께 여러 장의 사진을 보자.
import cv2
plt.figure(figsize=(8, 8))
labels = "Uninfected"
for i, images in enumerate(upics[:9]):
ax = plt.subplot(3, 3, i + 1)
img = cv2.imread(images)
plt.imshow(img)
plt.title(f'{labels} {img.shape}')
plt.axis("off")
감염되지 않은 사진 9장 확인
plt.figure(figsize=(8, 8))
labels = "Infected"
for i, images in enumerate(apics[:9]):
ax = plt.subplot(3, 3, i + 1)
img = cv2.imread(images)
plt.imshow(img)
plt.title(f'{labels} {img.shape}')
plt.axis("off")
감염된 사진 9장 확인
감염 여부와 상관없이 이미지 크기가 제각각이어서 크기를 통일시킬 필요가 있다.
데이터셋 나누기
from tensorflow.keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(rescale=1/255.0, validation_split=0.2)
validation_split 값을 통해 학습 데이터와 검증 데이터를 8:2로 split 한다.
train set
trainDatagen = datagen.flow_from_directory(directory = 'cell_images/',
target_size = (64, 64),
class_mode = 'binary',
batch_size = 64,
subset='training')
# 실행 결과
Found 22048 images belonging to 2 classes.
train set을 생성한다. 원본 이미지는 가로, 세로 100~200이지만 64, 64로 resize 했다. class_mode는 이진 분류이기 때문에 binary로 지정했다.
test set
valDatagen = datagen.flow_from_directory(directory = 'cell_images/',
target_size =(64, 64),
class_mode = 'binary',
batch_size = 64,
subset='validation')
# 실행 결과
Found 5510 images belonging to 2 classes.
train set과 마찬가지로 test set도 생성.
모델 생성
CNN 모델은 위 그림과 같이 Conv2D와 MaxPooling2D로 구성된 Convolution Layer와 그 뒤의 Dense Layer로 구성된다.
예시로 위 그림은 Conv(filters=6)-Pool(pool_size=2)-Conv(filters=16)-Pool(pool_size=2)-Dense(120)-Dense(84)-Dense(10)으로 표현할 수 있다.
그리고 (이미지 높이, 이미지 너비, 컬러 채널) 크기의 텐서를 입력으로 받는다. 컬러 이미지는 R, G, B 총 3개의 채널을 가지며 흑백 이미지는 하나의 채널을 가진다.
model = Sequential()
# 입력층
model.add(Conv2D(filters=16, kernel_size=(3,3), activation='relu', input_shape=(64, 64, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=32, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=16, kernel_size=(3,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
# Dense Layer(Fully-connected layer)
model.add(Flatten())
model.add(Dense(32, activation='relu'))
# 출력층 => 말라리아 감염 여부 분류
model.add(Dense(1, activation='sigmoid')) # 이진 분류의 출력층(0~1 값)
위와 같이 모델을 생성했다. 출력층의 활성화함수는 이진 분류이므로 0~1의 확률값을 출력하는 sigmoid를 지정했다.
모델 요약
model.summary()
레이어 시각화
plot_model 함수를 사용했다. model.add 한 층들을 볼 수 있다.
모델 컴파일
optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optim,
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=['accuracy']
)
모델을 훈련하기 전 몇 가지 설정이 컴파일 단계에서 이루어진다.
옵티마이저는 'adam'을 사용하고 이진 분류이므로 binary crossentropy를 손실함수로 지정했다.
모델 fit
early_stop= EarlyStopping(monitor='val_loss', patience=5)
history = model.fit(trainDatagen, epochs=1000, callbacks=early_stop, validation_data=valDatagen)
5회 이상 성능 개선이 이루어지지 않을 경우(val_loss가 낮아지지 않을 경우) 조기 종료하도록 early stop을 지정했다.
epoch 18만에 조기종료되었다. 아래처럼 성능 개선을 데이터프레임으로 만들어 확인할 수 있다.
생성한 CNN 모델은 혈액 도말 이미지를 학습하여 말라리아 감염 여부 분류에서 90% 이상의 정확도를 보이는 것을 확인할 수 있었다. 모델의 구성이나 학습 코드를 개선한다면 더 높은 정확도를 기대할 수 있을 것 같다.
'AI SCHOOL > TIL' 카테고리의 다른 글
[DAY 80] Bidirectional RNN을 통해 삼성전자 주가 예측하기 (0) | 2023.04.19 |
---|---|
[DAY 79] CNN과 날씨 이미지를 활용한 멀티클래스 분류 (1) | 2023.04.18 |
[DAY 77] 코딩테스트 연습 - 정규표현식, 페이지 교체 알고리즘, 카카오 문제 (0) | 2023.04.14 |
[DAY 76] Week 17 Insight Day 미니프로젝트5 시작, Resume & Portfolio 특강 (0) | 2023.04.13 |
[DAY 75] PyTorch를 활용한 자동차 연비 회귀 예측 (0) | 2023.04.12 |
댓글