머신러닝&딥러닝
모델 학습에 진전이 없을 때 멈추려면? / 딥러닝, TensorFlow, Callback, Early Stopping
LearnerToRunner
2022. 12. 5. 19:41
목표
모델을 학습하는데 진전이 없을 때 멈추도록하기
코드 예제
tensorflow.keras.callbacks import EarlyStopping
# 모델 생성/컴파일 부분은 주석처리하겠음
"""
# 모델 생성
model = Sequential()
model.add(Dense(50, input_dim=8, activation='relu')) # 500개의 노드 생성, input feature가 8개, 활성화 함수는 relu로 사용
model.add(Dense(10, activation= 'relu'))
model.add(Dense(1, activation= 'sigmoid'))
# 컴파일
params = {'optimizer':'adam', 'lr': 0.1,'loss': 'binary_crossentropy', 'metrics':['accuracy'], 'epoch': 1000, 'batch_size':10, 'validation_split': 0.3}
model.compile(loss=params['loss'], optimizer=params['optimizer'], metrics=params['metrics'])
model.summary()
"""
# Early Stopping 셋업
early_stopping = EarlyStopping(monitor='val_loss', patience=50) # 50회 동안 val_loss가 변화가없으면 학습을 멈춰라
# 학습
model.fit(x_prcd, y, epochs=params['epoch'], batch_size=params['batch_size'], validation_split=params['validation_split'], callbacks=[early_stopping])
Early Stopping 더 잘 활용하는 법
EarlyStopping()
관찰하는 지표가 개선되지 않을 때 학습을 멈추게함
학습을 멈추게하는 것이니 model.fit에 넣어줍시다
model.fit(callback = [변수이름])
리스트형태로 넣어줄 것
** 활용할만한 args
monitor =
무슨 지표를 기준으로 멈출까?
ex. loss, val_loss, accuracy, val_accuracy
min_delta =
개선되었다고 볼 기준은?
ex. min_delta = 0.1 >>
이전 에포크 지표와 0.1 이상 차이가 나야 개선으로 인정!
patience =
몇 epoch를 반복해도 개선되지 않았을 때 멈출까?
ex. patience = 20
>> 20 에포크 동안 개선이 없으면 학습중단!
verbose =
학습을 멈출 때 메시지 띄워줄까?
0: 놉! 메시지 필요 없음
1:예압! 멈추면 알려줘
start_from_epoch =
몇 번째 epoch부터 관찰을 시작할까?
(필자가 사용하는 tensorflow-gpu 2.5 버전에는 먹히지 않았다)
확인해보기
"""
모델 생성부분 생략
"""
# EarlyStopping Setup
early_stopping = EarlyStopping(
monitor='val_loss', # val_loss를 관찰해라
min_delta = 0.01, #학습해서 val_loss에 0.01 이상의 차이가 발생하면 개선으로 판단
patience=50, # 50회 동안 개선이 안되면 멈춰라
verbose=1 # 학습을 멈추게되면 메시지로 표시해줘
start_from_spoch= 10 # 에포크 10번 돌고 난 뒤부터 개선되는지 확인 시작
)
# 학습 - 학습을 멈추는 것이니 callback 은 fit에 적용한다
model.fit(x_prcd, y, epochs=params['epoch'], batch_size=params['batch_size'], validation_split=params['validation_split'],
callbacks=[early_stopping])
아래는 실행된 결과입니다.
Epoch 21에서 val_loss 가 최소가 되었고 (0.4239)
이후에는 개선이 되지 않아 Epoch 71에서 멈춘 것을 알 수 있습니다.
Verbose = 1로 설정해두었기 때문에
학습이 멈추었을 때 [Epoch 00071: early stopping] 라는 메시지가 출력되었습니다
54/54 [==============================] - 1s 10ms/step - loss: 0.6612 - accuracy: 0.6015 - val_loss: 0.6108 - val_accuracy: 0.6710
Epoch 2/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.5779 - accuracy: 0.6685 - val_loss: 0.5420 - val_accuracy: 0.7056
Epoch 3/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.5355 - accuracy: 0.7281 - val_loss: 0.5033 - val_accuracy: 0.7403
Epoch 4/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.5125 - accuracy: 0.7561 - val_loss: 0.4791 - val_accuracy: 0.7706
Epoch 5/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.4964 - accuracy: 0.7654 - val_loss: 0.4596 - val_accuracy: 0.7965
Epoch 6/1000
54/54 [==============================] - 1s 9ms/step - loss: 0.4844 - accuracy: 0.7691 - val_loss: 0.4459 - val_accuracy: 0.7965
Epoch 7/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.4746 - accuracy: 0.7877 - val_loss: 0.4412 - val_accuracy: 0.7879
Epoch 8/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4688 - accuracy: 0.7803 - val_loss: 0.4340 - val_accuracy: 0.7922
Epoch 9/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.4619 - accuracy: 0.7709 - val_loss: 0.4304 - val_accuracy: 0.7922
Epoch 10/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4563 - accuracy: 0.7821 - val_loss: 0.4289 - val_accuracy: 0.8009
Epoch 11/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4518 - accuracy: 0.7765 - val_loss: 0.4293 - val_accuracy: 0.7965
Epoch 12/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4505 - accuracy: 0.7784 - val_loss: 0.4289 - val_accuracy: 0.7922
Epoch 13/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4456 - accuracy: 0.7765 - val_loss: 0.4276 - val_accuracy: 0.7879
Epoch 14/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4423 - accuracy: 0.7840 - val_loss: 0.4260 - val_accuracy: 0.7965
Epoch 15/1000
54/54 [==============================] - 1s 9ms/step - loss: 0.4411 - accuracy: 0.7896 - val_loss: 0.4265 - val_accuracy: 0.7879
Epoch 16/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4389 - accuracy: 0.7784 - val_loss: 0.4254 - val_accuracy: 0.7922
Epoch 17/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4336 - accuracy: 0.7821 - val_loss: 0.4258 - val_accuracy: 0.7879
Epoch 18/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4322 - accuracy: 0.7858 - val_loss: 0.4249 - val_accuracy: 0.7922
Epoch 19/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4300 - accuracy: 0.7858 - val_loss: 0.4269 - val_accuracy: 0.7792
Epoch 20/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4271 - accuracy: 0.7877 - val_loss: 0.4263 - val_accuracy: 0.7879
Epoch 21/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4239 - accuracy: 0.7858 - val_loss: 0.4239 - val_accuracy: 0.7922
Epoch 22/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4200 - accuracy: 0.7970 - val_loss: 0.4257 - val_accuracy: 0.7922
Epoch 23/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4183 - accuracy: 0.7933 - val_loss: 0.4252 - val_accuracy: 0.7922
Epoch 24/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4171 - accuracy: 0.7970 - val_loss: 0.4242 - val_accuracy: 0.7922
Epoch 25/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4142 - accuracy: 0.7952 - val_loss: 0.4256 - val_accuracy: 0.8009
Epoch 26/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.4134 - accuracy: 0.7933 - val_loss: 0.4274 - val_accuracy: 0.7922
Epoch 27/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.4111 - accuracy: 0.7952 - val_loss: 0.4300 - val_accuracy: 0.7965
Epoch 28/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.4113 - accuracy: 0.8026 - val_loss: 0.4304 - val_accuracy: 0.7922
Epoch 29/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.4081 - accuracy: 0.8026 - val_loss: 0.4278 - val_accuracy: 0.7922
Epoch 30/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.4026 - accuracy: 0.8101 - val_loss: 0.4254 - val_accuracy: 0.8009
Epoch 31/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.4009 - accuracy: 0.8119 - val_loss: 0.4236 - val_accuracy: 0.7922
Epoch 32/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.4014 - accuracy: 0.8138 - val_loss: 0.4273 - val_accuracy: 0.7922
Epoch 33/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.3996 - accuracy: 0.8101 - val_loss: 0.4285 - val_accuracy: 0.7922
Epoch 34/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.3972 - accuracy: 0.7952 - val_loss: 0.4278 - val_accuracy: 0.7879
Epoch 35/1000
54/54 [==============================] - 1s 9ms/step - loss: 0.3947 - accuracy: 0.8007 - val_loss: 0.4273 - val_accuracy: 0.7965
Epoch 36/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3922 - accuracy: 0.8119 - val_loss: 0.4285 - val_accuracy: 0.7879
Epoch 37/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3919 - accuracy: 0.8212 - val_loss: 0.4295 - val_accuracy: 0.7879
Epoch 38/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3885 - accuracy: 0.8194 - val_loss: 0.4288 - val_accuracy: 0.7922
Epoch 39/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3883 - accuracy: 0.8101 - val_loss: 0.4302 - val_accuracy: 0.7922
Epoch 40/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3841 - accuracy: 0.8119 - val_loss: 0.4281 - val_accuracy: 0.7922
Epoch 41/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3838 - accuracy: 0.8231 - val_loss: 0.4303 - val_accuracy: 0.8009
Epoch 42/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3813 - accuracy: 0.8212 - val_loss: 0.4325 - val_accuracy: 0.8009
Epoch 43/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3800 - accuracy: 0.8194 - val_loss: 0.4367 - val_accuracy: 0.7965
Epoch 44/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3774 - accuracy: 0.8156 - val_loss: 0.4348 - val_accuracy: 0.7922
Epoch 45/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3765 - accuracy: 0.8231 - val_loss: 0.4336 - val_accuracy: 0.8052
Epoch 46/1000
54/54 [==============================] - 0s 7ms/step - loss: 0.3733 - accuracy: 0.8287 - val_loss: 0.4349 - val_accuracy: 0.8009
Epoch 47/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3725 - accuracy: 0.8175 - val_loss: 0.4364 - val_accuracy: 0.8052
Epoch 48/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3697 - accuracy: 0.8175 - val_loss: 0.4360 - val_accuracy: 0.8009
Epoch 49/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3682 - accuracy: 0.8268 - val_loss: 0.4427 - val_accuracy: 0.8095
Epoch 50/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.3673 - accuracy: 0.8175 - val_loss: 0.4389 - val_accuracy: 0.7965
Epoch 51/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3663 - accuracy: 0.8231 - val_loss: 0.4367 - val_accuracy: 0.8052
Epoch 52/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3654 - accuracy: 0.8287 - val_loss: 0.4392 - val_accuracy: 0.7922
Epoch 53/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.3648 - accuracy: 0.8231 - val_loss: 0.4396 - val_accuracy: 0.7879
Epoch 54/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.3579 - accuracy: 0.8250 - val_loss: 0.4443 - val_accuracy: 0.7965
Epoch 55/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3592 - accuracy: 0.8287 - val_loss: 0.4427 - val_accuracy: 0.7922
Epoch 56/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3583 - accuracy: 0.8399 - val_loss: 0.4440 - val_accuracy: 0.7879
Epoch 57/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3549 - accuracy: 0.8417 - val_loss: 0.4452 - val_accuracy: 0.7922
Epoch 58/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3521 - accuracy: 0.8492 - val_loss: 0.4440 - val_accuracy: 0.7879
Epoch 59/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3510 - accuracy: 0.8399 - val_loss: 0.4447 - val_accuracy: 0.7879
Epoch 60/1000
54/54 [==============================] - 1s 9ms/step - loss: 0.3512 - accuracy: 0.8454 - val_loss: 0.4435 - val_accuracy: 0.7922
Epoch 61/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3491 - accuracy: 0.8417 - val_loss: 0.4459 - val_accuracy: 0.7835
Epoch 62/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.3459 - accuracy: 0.8473 - val_loss: 0.4482 - val_accuracy: 0.7879
Epoch 63/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3460 - accuracy: 0.8492 - val_loss: 0.4478 - val_accuracy: 0.7879
Epoch 64/1000
54/54 [==============================] - 1s 10ms/step - loss: 0.3423 - accuracy: 0.8492 - val_loss: 0.4528 - val_accuracy: 0.7835
Epoch 65/1000
54/54 [==============================] - 1s 9ms/step - loss: 0.3426 - accuracy: 0.8492 - val_loss: 0.4525 - val_accuracy: 0.7835
Epoch 66/1000
54/54 [==============================] - 0s 9ms/step - loss: 0.3396 - accuracy: 0.8417 - val_loss: 0.4522 - val_accuracy: 0.7835
Epoch 67/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3403 - accuracy: 0.8436 - val_loss: 0.4539 - val_accuracy: 0.7792
Epoch 68/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3359 - accuracy: 0.8510 - val_loss: 0.4553 - val_accuracy: 0.7792
Epoch 69/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3364 - accuracy: 0.8510 - val_loss: 0.4608 - val_accuracy: 0.7922
Epoch 70/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3352 - accuracy: 0.8473 - val_loss: 0.4601 - val_accuracy: 0.7879
Epoch 71/1000
54/54 [==============================] - 0s 8ms/step - loss: 0.3321 - accuracy: 0.8585 - val_loss: 0.4607 - val_accuracy: 0.7835
Epoch 00071: early stopping
TensorFlow / EarlyStopping Documentation
728x90