특성의 개수가 많아 훈련 세트에서는 거의 완벽하게 학습이 되지만 과대적합으로 인해 test set에서는 형편없는 점수가 나온다.
규제 : 과대적합을 줄이는 방법
선형회귀 모델의 경우 계수(or 기울기)를 줄여 일반적인 모델로 만들 수 있다.
정규화
from sklearn.preprocessing import StandardScaler
ss = StandardScaler()
ss.fit(train_poly)
train_scaled = ss.transform(train_poly)
test_scaled = ss.transform(test_poly)
train set와 test set는 항상 함께!
Ridge : 계수를 제곱한 값을 기준으로 규제를 적용, alpha값이 클수록 규제가 심해진다.(default=1) solver를 지정할 수 있다.(ㅊ매개변수에 최적의 모델을 찾기 위한 방법 지정 default : solver=auto 추천 : solver=sag, solver=saga (saga는 sag의 개선버전으로 사이킷런 0.19 이후 적용)Lasso : 절댓값을 기준으로 규제를 적용
import matplotlib.pyplot as plt
# 50cm 농어의 이웃을 구합니다
distances, indexes = knr.kneighbors([[50]])
# 훈련 세트의 산점도를 그립니다
plt.scatter(train_input, train_target)
# 훈련 세트 중에서 이웃 샘플만 다시 그립니다
plt.scatter(train_input[indexes], train_target[indexes], marker='D')
# 50cm 농어 데이터
plt.scatter(50, 1033, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
농어의 길이가 50이나 되는데 최근접 이웃인 45보다 작은 농어의 평균 값을 50cm짜리 농어의 무게로 예측한다.
print(np.mean(train_target[indexes]))
1033.3333333333333
print(knr.predict([[100]]))
[1033.33333333]
# 100cm 농어의 이웃을 구합니다
distances, indexes = knr.kneighbors([[100]])
# 훈련 세트의 산점도를 그립니다
plt.scatter(train_input, train_target)
# 훈련 세트 중에서 이웃 샘플만 다시 그립니다
plt.scatter(train_input[indexes], train_target[indexes], marker='D')
# 100cm 농어 데이터
plt.scatter(100, 1033, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
선형회귀 : 직선으로 결과값 예측
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
# 선형 회귀 모델 훈련
lr.fit(train_input, train_target)
# 50cm 농어에 대한 예측
print(lr.predict([[50]]))
[1241.83860323]
print(lr.coef_, lr.intercept_)
[39.01714496] -709.0186449535477
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
# 선형 회귀 모델 훈련
lr.fit(train_input, train_target)
# 50cm 농어에 대한 예측
print(lr.predict([[50]]))
[1241.83860323]
print(lr.coef_, lr.intercept_)
[39.01714496] -709.0186449535477
# 훈련 세트의 산점도를 그립니다
plt.scatter(train_input, train_target)
# 15에서 50까지 1차 방정식 그래프를 그립니다
plt.plot([15, 50], [15*lr.coef_+lr.intercept_, 50*lr.coef_+lr.intercept_])
# 50cm 농어 데이터
plt.scatter(50, 1241.8, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()