juooo1117
Transfer Learning (CNN in TensorFlow)
Hyo__ni 2023. 12. 21. 13:43
Transfer learning: reuse an already trained network to help solve a problem similar to the one it was originally trained on.
Let's build an image classifier using horse and human images.
Dataset
The sample image has a resolution of 300x300, and the last dimension holds the RGB color channels, so each image in the dataset has shape (300, 300, 3), as the check below confirms.
import os
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Load the first example of a horse
# (train_horses_dir is assumed to point at the horse training images)
sample_image = load_img(os.path.join(train_horses_dir, os.listdir(train_horses_dir)[0]))
# Convert the image into its NumPy array representation
sample_array = img_to_array(sample_image)
print(f"Each image has shape: {sample_array.shape}")  # (300, 300, 3)
Training and Validation Generators
The images have a resolution of 300x300, but the flow_from_directory method lets you set a target resolution. Here target_size is set to (150, 150). A smaller input produces a smaller feature map at the point where the pre-trained network is cut, so the dense layer stacked on top of it has far fewer trainable parameters. Training is therefore much faster, without compromising accuracy.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def train_val_generators(TRAINING_DIR, VALIDATION_DIR):
    # Normalize pixel values to [0, 1]; no augmentation arguments are set here
    train_datagen = ImageDataGenerator(rescale=1./255.)
    train_generator = train_datagen.flow_from_directory(directory=TRAINING_DIR,
                                                        batch_size=32,
                                                        class_mode='binary',
                                                        target_size=(150, 150))
    # Validation data is only rescaled and should never be augmented
    validation_datagen = ImageDataGenerator(rescale=1./255.)
    validation_generator = validation_datagen.flow_from_directory(directory=VALIDATION_DIR,
                                                                  batch_size=32,
                                                                  class_mode='binary',
                                                                  target_size=(150, 150))
    return train_generator, validation_generator

# train_dir and validation_dir are assumed to be defined earlier
train_generator, validation_generator = train_val_generators(train_dir, validation_dir)
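As a rough illustration of what the generators do with each file, the sketch below mimics two behaviours in plain Python without calling Keras. flow_from_directory infers class indices from the subdirectory names in alphanumeric order, and rescale=1./255. multiplies every pixel value by that factor. The folder names `horses` and `humans` and both helper functions are assumptions made for this example, not part of the original notebook.

```python
import os
import tempfile

def infer_class_indices(root_dir):
    """Map each subdirectory name to an index, in sorted order."""
    classes = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))
    )
    return {name: index for index, name in enumerate(classes)}

def rescale_pixels(pixels, factor=1.0 / 255.0):
    """Scale raw 0-255 pixel values into the [0, 1] range."""
    return [value * factor for value in pixels]

# Build a throwaway directory tree with hypothetical class folders
with tempfile.TemporaryDirectory() as root:
    for class_name in ("humans", "horses"):
        os.makedirs(os.path.join(root, class_name))
    class_indices = infer_class_indices(root)

print(class_indices)
print(rescale_pixels([0, 51, 255]))
```

With class_mode='binary', a label of 1 then corresponds to whichever folder sorts second, which matters when interpreting the sigmoid output of the final model.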
Transfer learning - Create the pre-trained model
Download the InceptionV3 weights, save the path to the weights file, and load the InceptionV3 model from it.
The create_pre_trained_model function below specifies the input_shape for the model and makes all of its layers non-trainable, so the model summary should report:
Total params: 21,802,784
Trainable params: 0
Non-trainable params: 21,802,784
# Import the inception model
from tensorflow.keras.applications.inception_v3 import InceptionV3

# Path to the locally downloaded pre-trained weights
local_weights_file = '/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'

def create_pre_trained_model(local_weights_file):
    # Build the architecture without its classification head or any weights
    pre_trained_model = InceptionV3(input_shape=(150, 150, 3),
                                    include_top=False,
                                    weights=None)
    # Load the pre-trained weights from the local file
    pre_trained_model.load_weights(local_weights_file)
    # Make all the layers in the pre-trained model non-trainable
    for layer in pre_trained_model.layers:
        layer.trainable = False
    return pre_trained_model

pre_trained_model = create_pre_trained_model(local_weights_file)
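To make the effect of `layer.trainable = False` concrete, here is a toy sketch in plain Python rather than Keras. It assumes a simple gradient-descent update and uses a hypothetical `sgd_step` helper with made-up weights and gradients. Frozen weights are skipped, so only the newly added part of the model changes during training.

```python
def sgd_step(weights, grads, trainable, learning_rate=0.1):
    """Apply one gradient-descent step, leaving frozen weights untouched."""
    return [
        w - learning_rate * g if is_trainable else w
        for w, g, is_trainable in zip(weights, grads, trainable)
    ]

# Hypothetical values: two "pre-trained" weights and one new head weight
weights = [0.5, -1.2, 0.0]
grads = [0.3, 0.3, 0.3]
trainable = [False, False, True]

updated = sgd_step(weights, grads, trainable)
print(updated)
```

This is why the summary above lists 0 trainable parameters for the base network: none of its weights take part in the update.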
Creating callbacks for later
Define a Callback class that stops training once accuracy reaches 99.9%
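import tensorflow as tf

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Guard against a missing 'accuracy' entry before comparing
        accuracy = (logs or {}).get('accuracy')
        if accuracy is not None and accuracy > 0.999:
            print("\nReached 99.9% accuracy so cancelling training!")
            self.model.stop_training = True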
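The stopping rule can be traced without running any training. The sketch below simulates it on a list of per-epoch accuracies. The function name `epochs_until_stop` and the accuracy values are hypothetical and are not results from this model.

```python
def epochs_until_stop(accuracy_per_epoch, threshold=0.999):
    """Return how many epochs run before the threshold check stops training."""
    for epoch, accuracy in enumerate(accuracy_per_epoch, start=1):
        # The check happens at the end of an epoch, so that epoch still completes
        if accuracy is not None and accuracy > threshold:
            return epoch
    return len(accuracy_per_epoch)

# Hypothetical per-epoch training accuracies, not real results
simulated = [0.95, 0.98, 0.9995, 1.0]
print(epochs_until_stop(simulated))  # prints 3
```

Note that the comparison is strict: an accuracy of exactly 0.999 does not stop training.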
Pipelining the pre-trained model with your own
Now that the pre-trained model is ready, you need to "glue" it to your own model to solve the task at hand.
For this you need the last output of the pre-trained model, since it becomes the input to your own layers.
# Use the mixed7 layer as the last layer of the pre-trained model
def output_of_last_layer(pre_trained_model):
    last_desired_layer = pre_trained_model.get_layer('mixed7')
    last_output = last_desired_layer.output
    # Read the shape from the output tensor; layer.output_shape is not available in every Keras version
    print('last layer output shape: ', last_output.shape)
    print('last layer output: ', last_output)
    return last_output

last_output = output_of_last_layer(pre_trained_model)
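As a sanity check on the shape printed here, the spatial size at mixed7 can be estimated by hand. The sketch assumes the standard Keras InceptionV3 layout, in which six unpadded ('valid') 3x3 convolution or pooling steps sit between the input and mixed7, every other layer preserves the spatial size, and mixed7 has 768 channels. The list of (kernel, stride) pairs is an assumption about that layout, not something stated in this post.

```python
def valid_output_size(size, kernel, stride):
    """Output length of an unpadded ('valid') convolution or pooling step."""
    return (size - kernel) // stride + 1

# Assumed (kernel, stride) pairs of the size-reducing steps before mixed7
REDUCING_STEPS = [(3, 2), (3, 1), (3, 2), (3, 1), (3, 2), (3, 2)]

def mixed7_spatial_size(input_size):
    """Follow one spatial dimension from the input down to mixed7."""
    size = input_size
    for kernel, stride in REDUCING_STEPS:
        size = valid_output_size(size, kernel, stride)
    return size

for input_size in (150, 300):
    side = mixed7_spatial_size(input_size)
    print(f"{input_size}x{input_size} input -> {side}x{side}x768 at mixed7")
```

Under these assumptions, a 150x150 input gives a 7x7 feature map and a 300x300 input gives 17x17, roughly six times as many values to flatten.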
Let's complete the final, connected model.
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import RMSprop

def create_final_model(pre_trained_model, last_output):
    # Flatten the output layer to 1 dimension
    x = layers.Flatten()(last_output)
    # Add a fully connected layer with 1024 hidden units and ReLU activation
    x = layers.Dense(1024, activation='relu')(x)
    # Add a dropout rate of 0.2
    x = layers.Dropout(0.2)(x)
    # Add a final sigmoid layer for classification
    x = layers.Dense(1, activation='sigmoid')(x)
    # Create the complete model by using the Model class
    model = Model(inputs=pre_trained_model.input, outputs=x)
    # Compile the model
    model.compile(optimizer=RMSprop(learning_rate=0.0001),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

model = create_final_model(pre_trained_model, last_output)
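A quick calculation shows how much the input resolution affects the size of this head. The sketch below counts only the parameters of the Flatten, Dense(1024), Dropout, and Dense(1) stack. The feature-map shapes (7, 7, 768) for a 150x150 input and (17, 17, 768) for a 300x300 input are assumptions about the mixed7 output, and the helper names are hypothetical.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units

def head_trainable_params(feature_shape, hidden_units=1024):
    """Trainable parameters of Flatten -> Dense -> Dropout -> Dense(1)."""
    height, width, channels = feature_shape
    flattened = height * width * channels            # Flatten adds no parameters
    hidden = dense_params(flattened, hidden_units)   # Dense(1024, relu)
    output = dense_params(hidden_units, 1)           # Dense(1, sigmoid); Dropout adds none
    return hidden + output

# Assumed mixed7 output shapes for 150x150 and 300x300 inputs
print(head_trainable_params((7, 7, 768)))
print(head_trainable_params((17, 17, 768)))
```

Under these assumptions the head has about 38.5 million trainable parameters at 150x150 and about 227 million at 300x300, which is the reduction the smaller target_size buys.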
Now train the model:
callbacks = myCallback()
history = model.fit(train_generator,
                    validation_data=validation_generator,
                    epochs=100,
                    verbose=2,
                    callbacks=[callbacks])
Training should stop after fewer than 10 epochs.
Because the model builds on an already pre-trained network, it reached high accuracy very quickly.
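For reference, the quantity minimized during training is the binary_crossentropy loss applied to the sigmoid output of the last layer. The following is a minimal plain-Python sketch of that computation, not the Keras implementation. The logits, labels, and the clipping constant `eps` are made up for illustration.

```python
import math

def sigmoid(z):
    """Squash a real-valued score into a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over a batch, with clipping for stability."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)

# Hypothetical scores and labels (1 = second class, 0 = first class)
logits = [2.0, -1.5, 0.3]
labels = [1, 0, 1]
probs = [sigmoid(z) for z in logits]
print(binary_crossentropy(labels, probs))
```

The loss shrinks as the predicted probabilities move toward the true labels, which is what drives the accuracy that the callback monitors.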