케라스의 구현 방식에는 Sequential API, Functional API 외에도 Subclassing API라는 구현 방식이 존재합니다.
앞서 ( https://wikidocs.net/111472 ) 에서 Sequential API로 구현했던 선형 회귀를 Subclassing API로 구현한다면 다음과 같습니다.
import tensorflowas as tf
class LinearRegression(tf.keras.Model):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear_layer = tf.keras.layers.Dense(1, input_dim=1, activation='linear')
def call(self, x):
y_pred = self.linear_layer(x)
return y_pred
model = LinearRegression()
X = [1, 2, 3, 4, 5, 6, 7, 8, 9] # 공부하는 시간
y = [11, 22, 33, 44, 53, 66, 77, 87, 95] # 각 공부하는 시간에 맵핑되는 성적
sgd = tf.keras.optimizers.SGD(lr=0.01)
model.compile(optimizer=sgd, loss='mse', metrics=['mse'])
model.fit(X, y, epochs=300)
클래스(class) 형태의 모델은 tf.keras.Model을 상속받습니다. 그리고 init()에서 모델의 구조와 동적을 정의하는 생성자를 정의합니다. 이는 파이썬에서 객체가 갖는 속성값을 초기화하는 역할로, 객체가 생성될 때 자동으호 호출됩니다. super() 함수를 부르면 여기서 만든 클래스는 tf.keras.Model 클래스의 속성들을 가지고 초기화 됩니다. call() 함수는 모델이 데이터를 입력받아 예측값을 리턴하는 포워드(forward) 연산을 진행시키는 함수입니다.
H(x) 식에 입력 로부터 예측된 를 얻는 것을 forward 연산이라고 합니다.
x
y
Sequential API는 간단한 모델을 구현하기에 적합합니다. Functional API로는 Sequential API로 구현할 수 없는 복잡한 모델들을 구현가능합니다. 그런데 Subclassing API로는 Functional API가 구현할 수 없는 모델들조차 구현할 수 있는 경우가 있습니다. Functional API는 기본적으로 딥 러닝 모델을 DAG(directed acyclic graph)로 취급합니다. 실제로 대부분의 딥 러닝 모델이 이에 속하기는 하지만, 항상 그렇지는 않습니다. 예를 들어서 재귀 네트워크나 트리 RNN은 이 가정을 따르지 않으며 Functional API에서 구현할 수 없습니다.
이를 반대로 해석하면 대부분의 딥 러닝 모델은 Functional API 수준에서도 전부 구현이 가능하다는 의미이기도 합니다. 그래서 Subclassing API는 밑바닥부터 새로운 수준의 아키텍처를 구현해야 하는 실험적 연구를 하는 연구자들에게 적합합니다.
장점 : 단순하게 층을 쌓는 방식으로 쉽고 사용하기가 간단합니다.단점 : 다수의 입력(multi-input), 다수의 출력(multi-output)을 가진 모델 또는 층 간의 연결(concatenate)이나 덧셈(Add)과 같은 연산을 하는 모델을 구현하기에는 적합하지 않습니다. 이런 모델들의 구현은 Functional API를 사용해야 합니다.
장점 : Sequential API로는 구현하기 어려운 복잡한 모델들을 구현할 수 있습니다.단점 : 입력의 크기(shape)를 명시한 입력층(Input layer)을 모델의 앞단에 정의해주어야 합니다. 가령, 아래의 코드를 봅시다.
# 선형 회귀 구현 코드의 일부 발췌
inputs = Input(shape=(1,)) # <-- 해당 부분
output = Dense(1, activation='linear')(inputs)
linear_model = Model(inputs, output)
sgd = optimizers.SGD(lr=0.01)
linear_model.compile(optimizer=sgd, loss='mse', metrics=['mse'])
linear_model.fit(X, y, epochs=300)
장점 : Functional API로도 구현할 수 없는 모델들조차 구현이 가능합니다.단점 : 객체 지향 프로그래밍(Object-oriented programming)에 익숙해야 하므로 코드 사용이 가장 까다롭습니다.