In this reading, we'll see how np.dot (often expressed with the @ operator) and np.linalg.solve relate to predict and fit, respectively, for sklearn's LinearRegression.
Say we've seen a few houses sell recently, with the following characteristics (features) and prices (label):
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
train = pd.DataFrame([[2, 1, 1985, 196.55],
                      [3, 1, 1998, 260.56],
                      [4, 3, 2005, 334.55],
                      [4, 2, 2020, 349.6]],
                     columns=["beds", "baths", "year", "price"])
train
Can we fit a LinearRegression model to the above, then use it to predict prices for the following three houses that haven't sold yet?
live = pd.DataFrame([[2, 2, 1999],
                     [5, 1, 1950],
                     [3, 4, 2000]],
                    columns=["beds", "baths", "year"])
live
lr = LinearRegression()
lr.fit(train[["beds", "baths", "year"]], train["price"])
lr.predict(live)
The above tells us that the model thinks the three houses will sell for \$229.93K, \$265K, and \$293.9K, respectively. Underlying this prediction was a dot product based on some arrays calculated during fitting.
c = lr.coef_
c
b = lr.intercept_
b
Let's pull the features from the first row of the live data into an array too.
house = live.iloc[0].values
house
The array we pulled from lr is called coef_ because those numbers are meant to be coefficients on the features for each house.
c[0]*house[0] + c[1]*house[1] + c[2]*house[2] + b
That was the same amount predicted by lr.predict for the first house! Better yet, if we put our houses in the right shape, we can simplify the expression to a dot product.
house = house.reshape(1,-1)
c = c.reshape(-1,1)
np.dot(house, c) + b
Same thing as before! Or using the @ operator, which is shorthand for np.dot (for the 2-D arrays we're using here):
house @ c + b
We've seen how to do row @ col. If we do matrix @ col, you can think of it as looping over each row in the matrix, computing the dot product of that row with col, then stacking the results in the output. This means we can do all the predictions at once!
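To make that row-by-row view concrete, here's a rough sketch (the helper name naive_matmul_col is made up for illustration; numpy does this far more efficiently internally):

# Sketch: matrix @ col viewed as a loop of row-by-row dot products
def naive_matmul_col(matrix, col):
    results = []
    for row in matrix:                    # each row of the matrix...
        results.append(np.dot(row, col))  # ...dotted with the column
    return np.vstack(results)             # stack into a column of results

naive_matmul_col(live.values, c)          # same values as live.values @ c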
live.values
live.values @ c + b
Recall that these are the same values that LinearRegression predicted earlier -- it's just using the dot product internally:
lr.predict(live)
Ok, how did fit determine what values to use in coef_ and intercept_? Let's think about how we could use X and y values extracted from our training data:
X = train.values[:, :-1]
X
y = train.values[:, -1:]
y
We know that for predictions, LinearRegression wants to use this:
X @ coef_ + intercept_ = y
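As a quick sanity check (a sketch using the lr model we fit above), the fitted coefficients do satisfy this equation on our training data; with 4 houses and 4 unknowns (3 coefficients plus the intercept), the fit happens to be exact here:

# Sketch: check that X @ coef_ + intercept_ reproduces the training prices
preds = X @ lr.coef_ + lr.intercept_   # shape (4,)
np.allclose(preds, y.reshape(-1))      # True for this particular data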
coef_ is a vector and intercept_ is a single number; we can eliminate intercept_ as a separate variable if we add an entry to coef_ and add a column of ones to X.
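To see why this works, here's a quick check (a sketch using the coefficients and intercept from lr): appending the intercept to the coefficient vector and a column of ones to the features gives exactly the same numbers as keeping them separate.

# Sketch: folding intercept_ into coef_ by adding a column of ones
feats = train[["beds", "baths", "year"]].values
coefs_plus_intercept = np.append(lr.coef_, lr.intercept_)                    # length 4
feats_plus_ones = np.concatenate((feats, np.ones((len(feats), 1))), axis=1)
np.allclose(feats @ lr.coef_ + lr.intercept_,
            feats_plus_ones @ coefs_plus_intercept)                          # True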
X = np.concatenate((train.values[:, :-1], np.ones((len(train), 1))), axis=1)
X
This gives us a simple equation:
X @ coef_ = y
We know X and y (from the train DataFrame) -- can we use those to solve for coef_? If the dot product were a regular multiplication, we would divide both sides by X, but that's not valid for matrices and the dot product. Solving is a little trickier, but fortunately numpy can do it for us:
# Solves for c in X@c=y, given X and y as arguments
c = np.linalg.solve(X, y)
list(c.reshape(-1))
That contains the same coefficients and intercept that LinearRegression.fit found earlier:
print(list(lr.coef_), lr.intercept_)
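And as one more check (a sketch), the solved values reproduce the training prices exactly, since we had exactly as many equations as unknowns:

# Sketch: with a square X, the solved coefficients fit the training data exactly
np.allclose(X @ c, y)   # True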
How does np.linalg.solve work? It solves a system of N equations with N variables. It turns out it is possible to convert a table to such an algebra problem, treating each row as an equation and each column as a variable.
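For instance, here's a tiny made-up example (not our housing data): a 2-row, 2-column table whose rows translate to the equations 1*a + 2*b = 5 and 3*a + 5*b = 13.

# Toy system: each row is an equation, each column is a variable
A = np.array([[1, 2],
              [3, 5]])
rhs = np.array([5, 13])
np.linalg.solve(A, rhs)   # array([1., 2.]), i.e. a=1 and b=2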
Of course, this means it only works for square tables (same number of rows and columns), such as the sub-table of train that contains the features, if we were to add a column of ones:
train[["beds", "baths", "year"]]
One implication is that np.linalg.solve won't work for us if there are M rows and N columns, where M > N. That would mean solving a system of M equations with only N variables, and it is rarely possible to find exact solutions in such cases. In the near future, however, we'll learn how to find good solutions ("good" remains to be defined) for systems of M equations and N variables. Of course, most of the tables you've worked with probably have more rows than columns, so this is a very important problem.
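To see that limitation concretely, here's a quick sketch: handing np.linalg.solve a non-square matrix (say, our 4x3 feature table without the column of ones) raises an error rather than guessing at an answer.

# Sketch: np.linalg.solve requires a square matrix
X_tall = train[["beds", "baths", "year"]].values    # 4 rows, 3 columns
try:
    np.linalg.solve(X_tall, train["price"].values)
except np.linalg.LinAlgError as e:
    print("LinAlgError:", e)                        # not square, so it fails

Finding a good (if not exact) answer for tall tables like this is exactly what we'll build toward next.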