> AI Gravity :: Leaner Regression (ft. Gradient descent, BackPropagation)
본문으로 바로가기

선형회귀  ---->  목적= "많은 데이터로 선긋기"

$$y=ax+b\ \ \ <->\ \ \hat{y}=wx+b$$

\[ \begin{align} &\hat{y}는 예측값 \\ &w는 가중치 \\ &y는 타켓 \end{align} \]

 

1. w, b값 선정

w=1,b=1일때,

\( \begin{align} \hat{y}&=wx+b \\&=1\cdot x\left[0\right]+1 \\ &\fallingdotseq 0.0619 \end{align} \)

 

2. w값 수정 (w 0.1 증가)

\( \begin{align} \hat{y}' &=w'x+b \\&=\left(1+0.1\right)x\left[0\right]+1 \\ &\fallingdotseq 0.0675 \end{align} \)

 

3. Gradient descent

\( \begin{align}1case 가중치 증가,예측값 증가->가중치\cdot (+1) \\ 2case 가중치 증가,예측값 감소->가중치\cdot (-1)\end{align} \)

err가 줄어드는 방향으로 1case과 2case중 선택

이 경우는 1case이므로 가중치에 양수를 더한다.

 

\( \begin{align} w_{rate} \end{align} \) 은 가중치가 변할때, 예측값의 변화율이다.

\( \begin{align} w_{rate}=\frac{\hat{y'}-\hat{y}}{w'-w}=\frac{\left(w'x+b\right)-\left(wx+b\right)}{\left(w+0.1\right)-w}=\frac{0.1}{0.1}\cdot x\left[0\right]=x\left[0\right] \end{align} \)

같은 방법으로 

\( \begin{align} b_{rate}=1 \end{align} \)

 

경사하강법에 따라 가중치와 절편은 다음과 같다.

\( \begin{align} &w'=w+w_{rate}\\ &b'=b+1 \end{align} \)

 

4. BackPropagation

최적화를 위해 오차err를 곱하여 계산에 반영한다.

\( \begin{align} &w'=w+w_{rate}\cdot err\\ &b'=b+1\cdot err\\ &err=\hat{y}-y \end{align} \)

 

\( \begin{align}  &x[0]\ \ \hat{y}=w'x+b',\left(10.25,150.0\right)\\ &x[1]\ \ \hat{y}=w'x+b'\ ,\left(14.13,75.5\right)\\   \end{align} \)

...

x[n]까지 반복 -> 에포크(epoch)

에포크 -> 100번 반복 

 최종 결과: 

w=913, b=123

\[ \begin{align} y=913x+123 \end{align} \]

 


from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
print(diabetes.data.shape, diabetes.target.shape) #(442, 10)(442,)

diabetes.data=
array([[ 0.03807591,  0.05068012,  0.06169621, ..., -0.00259226,
         0.01990842, -0.01764613],
       [-0.00188202, -0.04464164, -0.05147406, ..., -0.03949338,
        -0.06832974, -0.09220405],
       [ 0.08529891,  0.05068012,  0.04445121, ..., -0.00259226,
         0.00286377, -0.02593034],
       ...,
       [ 0.04170844,  0.05068012, -0.01590626, ..., -0.01107952,
        -0.04687948,  0.01549073],
       [-0.04547248, -0.04464164,  0.03906215, ...,  0.02655962,
         0.04452837, -0.02593034],
       [-0.04547248, -0.04464164, -0.0730303 , ..., -0.03949338,
        -0.00421986,  0.00306441]])
diabetes.target=
array([151.,  75., 141., 206., 135.,  97., 138.,  63., 110., 310., 101.,
        69., 179., 185., 118., 171., 166., 144.,  97., 168.,  68.,  49.,
        68., 245., 184., 202., 137.,  85., 131., 283., 129.,  59., 341.,
        87.,  65., 102., 265., 276., 252.,  90., 100.,  55.,  61.,  92.,
       259.,  53., 190., 142.,  75., 142., 155., 225.,  59., 104., 182.,
       128.,  52.,  37., 170., 170.,  61., 144.,  52., 128.,  71., 163.,
       150.,  97., 160., 178.,  48., 270., 202., 111.,  85.,  42., 170.,
       200., 252., 113., 143.,  51.,  52., 210.,  65., 141.,  55., 134.,
        42., 111.,  98., 164.,  48.,  96.,  90., 162., 150., 279.,  92.,
        83., 128., 102., 302., 198.,  95.,  53., 134., 144., 232.,  81.,
       104.,  59., 246., 297., 258., 229., 275., 281., 179., 200., 200.,
       173., 180.,  84., 121., 161.,  99., 109., 115., 268., 274., 158.,
       107.,  83., 103., 272.,  85., 280., 336., 281., 118., 317., 235.,
        60., 174., 259., 178., 128.,  96., 126., 288.,  88., 292.,  71.,
       197., 186.,  25.,  84.,  96., 195.,  53., 217., 172., 131., 214.,
        59.,  70., 220., 268., 152.,  47.,  74., 295., 101., 151., 127.,
       237., 225.,  81., 151., 107.,  64., 138., 185., 265., 101., 137.,
       143., 141.,  79., 292., 178.,  91., 116.,  86., 122.,  72., 129.,
       142.,  90., 158.,  39., 196., 222., 277.,  99., 196., 202., 155.,
        77., 191.,  70.,  73.,  49.,  65., 263., 248., 296., 214., 185.,
        78.,  93., 252., 150.,  77., 208.,  77., 108., 160.,  53., 220.,
       154., 259.,  90., 246., 124.,  67.,  72., 257., 262., 275., 177.,
        71.,  47., 187., 125.,  78.,  51., 258., 215., 303., 243.,  91.,
       150., 310., 153., 346.,  63.,  89.,  50.,  39., 103., 308., 116.,
       145.,  74.,  45., 115., 264.,  87., 202., 127., 182., 241.,  66.,
        94., 283.,  64., 102., 200., 265.,  94., 230., 181., 156., 233.,
        60., 219.,  80.,  68., 332., 248.,  84., 200.,  55.,  85.,  89.,
        31., 129.,  83., 275.,  65., 198., 236., 253., 124.,  44., 172.,
       114., 142., 109., 180., 144., 163., 147.,  97., 220., 190., 109.,
       191., 122., 230., 242., 248., 249., 192., 131., 237.,  78., 135.,
       244., 199., 270., 164.,  72.,  96., 306.,  91., 214.,  95., 216.,
       263., 178., 113., 200., 139., 139.,  88., 148.,  88., 243.,  71.,
        77., 109., 272.,  60.,  54., 221.,  90., 311., 281., 182., 321.,
        58., 262., 206., 233., 242., 123., 167.,  63., 197.,  71., 168.,
       140., 217., 121., 235., 245.,  40.,  52., 104., 132.,  88.,  69.,
       219.,  72., 201., 110.,  51., 277.,  63., 118.,  69., 273., 258.,
        43., 198., 242., 232., 175.,  93., 168., 275., 293., 281.,  72.,
       140., 189., 181., 209., 136., 261., 113., 131., 174., 257.,  55.,
        84.,  42., 146., 212., 233.,  91., 111., 152., 120.,  67., 310.,
        94., 183.,  66., 173.,  72.,  49.,  64.,  48., 178., 104., 132.,
       220.,  57.])
import matplotlib.pyplot as plt

x = diabetes.data[:, 2]
y = diabetes.target
w = 1.0
b = 1.0

for i in range(1, 100):
    for x_i, y_i in zip(x, y):
        y_hat = x_i * w + b
        err = y_i - y_hat
        w_rate = x_i
        w = w + w_rate * err
        b = b + 1 * err
print(w, b)	#913.5973364345905 123.39414383177204

plt.scatter(x, y)
pt1 = (-0.1, -0.1 * w + b)
pt2 = (0.15, 0.15 * w + b)
plt.plot([pt1[0], pt2[0]], [pt1[1], pt2[1]])
plt.xlabel('x')
plt.ylabel('y')
plt.show()