Python/Matplotlib

[Matplotlib] 파이썬 회귀선 그리기, 결정계수(R2) 구하고 그래프에 표시하기

jimmy_AI 2022. 3. 29. 20:58
반응형

파이썬 plt 회귀선 그래프 및 결정계수 텍스트 예제

파이썬에서 선형 회귀 추세선의 식을 구하여 그래프를 그려보고,

결정계수 값을 구하여 그래프 내에 텍스트를 표시하는 예시에 대해서 다루어보겠습니다.

 

먼저, 아래와 같은 간단한 데이터로 그려진 산점도가 있다고 가정하겠습니다.

import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5, 6, 7]
y = [2, 6, 5, 8, 9, 13, 12]

plt.scatter(x, y, color = 'r', s = 20)
plt.xlabel('x')
plt.ylabel('y')
plt.show()

 

 

선형 회귀식 구하기

선형 회귀식을 구하기 위하여 넘파이 라이브러리의 np.polyfit 함수를 사용하겠습니다.

x축, y축에 해당하는 데이터와 1차식을 나타내는 1을 순서대로 input으로 넣어주면 됩니다.

import numpy as np

fit_line = np.polyfit(x, y, 1) # input 의미 : x축 데이터, y축 데이터, 1차원

print(fit_line) # [1.71428571, 1.        ] : y = 1.71428571x + 1로 선형 회귀

구한 두 값은 각각 기울기와 y절편을 의미하며, 선형 회귀식은 대략 y = 1.7143x + 1 이었습니다.

 

 

회귀선 그래프 그리기

위에서 구한 식을 이용하여, 회귀선을 그래프 내에 같이 그려보도록 하겠습니다.

 

x축의 최소값과 최대값을 각각 위의 회귀식에 대입하고,

이를 y값으로 잡아 plot을 그려주시면 해당 범위 내에서 회귀선 그래프가 그려지게 됩니다.

 

실제로 회귀선 그래프를 그리는 코드와 그려진 그래프의 결과는 아래와 같았습니다.

x_minmax = np.array([min(x), max(x)]) # x축 최소값, 최대값

fit_y = x_minmax * fit_line[0] + fit_line[1] # x축 최소, 최대값을 회귀식에 대입한 값

plt.scatter(x, y, color = 'r', s = 20)
plt.plot(x_minmax, fit_y, color = 'orange') # 회귀선 그래프 그리기
plt.xlabel('x')
plt.ylabel('y')
plt.show()

반응형

결정계수(R2) 값 구하기

회귀식에 대한 결정계수를 구하는 과정은 사이킷런의 r2_score 함수를 사용해주시면 됩니다.

 

실제 데이터 내의 각 x 값을 회귀식에 대입하여 y 값의 추정치를 구한 뒤,

실제 y 값과 추정치 y 값을 input으로 같이 넣어주시면 됩니다.

from sklearn.metrics import r2_score

est_y = np.array(x) * fit_line[0] + fit_line[1] # x의 실제 값들을 회귀식에 대입한 y 추정치

r2 = r2_score(y, est_y) # 0.9056603773584906

구한 결정계수의 값은 대략 0.9057 정도였습니다.

 

 

결정계수 값과 추세식 텍스트 그래프에 새기기

마지막으로 위에서 구한 R2 값과 선형 회귀식에 대한 텍스트를 그래프의 적당한 위치에

새기는 예시를 보여드리면서 이 글을 마무리하도록 하겠습니다.

 

plt.text(x축 위치, y축 위치, 새길 값 포맷팅 양식) 형태로 input을 지정해주시면 되며,

size 인자로 글자 크기를 조정하는 등의 옵션 설정도 가능합니다.

 

텍스트를 새기는 예제 코드와 그려진 그래프의 결과는 다음과 같았습니다.

plt.scatter(x, y, color = 'r', s = 20)
plt.plot(new_x, fit_y, color = 'orange') # 회귀선 그래프 그리기
plt.text(5, 7, '$R^2$ = %.4f'%r2, size = 12) # (5, 7)의 위치에 크기 12로 R값 새김
plt.text(4.8, 6, 'y = %.4fx + %d'%(fit_line[0], fit_line[1]), size = 12) # (4.8, 6)위치에 추세선 식 표현
plt.xlabel('x')
plt.ylabel('y')
plt.show()