87 lines
2.6 KiB
Python
87 lines
2.6 KiB
Python
import argparse
|
|
import numpy as np
|
|
from train import train
|
|
from LinearRegression import LinearRegression
|
|
from matplotlib import animation
|
|
|
|
|
|
def get_thetas(path):
    """Load the two linear-model parameters ``[theta0, theta1]`` from a CSV file.

    Parameters
    ----------
    path : str
        CSV file read with ``np.genfromtxt(path, delimiter=',')``; its first
        two values are used as theta0 and theta1.

    Returns
    -------
    list of float or str
        ``[theta0, theta1]`` on success, or the sentinel string ``"error"``
        (which callers test for) when the file cannot be read, contains fewer
        than two scalar values, or contains non-numeric / non-finite values.
    """
    try:
        thetas = np.genfromtxt(path, delimiter=',')
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers rows
        # with inconsistent column counts.
        print('info: path not found, asking for train')
        return "error"
    try:
        # float() rejects non-scalar entries (e.g. rows of a 2-D file) instead
        # of silently returning arrays that break get_price later.
        theta0, theta1 = float(thetas[0]), float(thetas[1])
    except (IndexError, TypeError, ValueError):
        print('warning: path in wrong format, asking for train')
        return "error"
    # genfromtxt maps unparsable fields to nan rather than raising.
    if not (np.isfinite(theta0) and np.isfinite(theta1)):
        print('warning: path in wrong format, asking for train')
        return "error"
    return [theta0, theta1]
|
|
|
|
def get_price(mileage, thetas):
    """Evaluate the linear price model ``thetas[0] + mileage * thetas[1]``."""
    intercept = thetas[0]
    slope = thetas[1]
    return intercept + mileage * slope
|
|
|
|
def predict_subject():
    """Interactively predict a car price from its mileage.

    Thetas are loaded from ``thetas.csv`` via ``get_thetas``. If that fails,
    the user is offered to train first: on ``y`` the model is trained on
    ``data.csv`` (header skipped) starting from zero thetas and the result is
    re-read from ``thetas.csv``; on any other answer, or if training fails,
    thetas fall back to ``[0, 0]``. The mileage is then read from stdin,
    negative values abort, and the predicted price (or a "museum" message for
    a negative price) is printed.
    """
    thetas = get_thetas('thetas.csv')
    # get_thetas signals failure with the string "error"; isinstance avoids an
    # element-wise comparison should thetas ever be a numpy array.
    if isinstance(thetas, str):
        print("")
        try:
            input_thetas = input("model isn't trained, would you like to train it before predicting car price ? y/n\n")
            if input_thetas == 'y':
                raw_data = np.genfromtxt('data.csv', delimiter=',', skip_header=1)
                train(raw_data, np.zeros(2), 'thetas.csv', False)
                thetas = np.genfromtxt('thetas.csv', delimiter=',')
            else:
                print('info: wrong input format, setting thetas to 0')
                thetas = [0, 0]
        except Exception:
            # train() is opaque, so any failure there falls back to zeros;
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print('info: wrong input format or fail to train, setting thetas to 0')
            thetas = [0, 0]
    try:
        # Parse as a number instead of eval()-ing user text (arbitrary code
        # execution). float() first keeps inputs like "1e5" or "12000.5".
        mileage = int(float(input("Enter mileage\n")))
    except (ValueError, OverflowError, EOFError):
        print('info: input a number')
        return
    if mileage < 0:
        print('info: mileage should be superior to 0 ! aborting')
        return
    price = get_price(mileage, thetas)
    if price < 0:
        print('This car belongs in a museum ! (price inferior to 0)')
    else:
        print('Predicted car value is ', price,)
|
|
|
|
def get_y(x, thetas):
    """Return the dot product of ``x`` with the first ``len(x)`` entries of ``thetas``.

    Raises IndexError if ``thetas`` is shorter than ``x``; returns 0 for an
    empty ``x``.
    """
    return sum(value * thetas[idx] for idx, value in enumerate(x))
|
|
|
|
def predict(thetas_path):
    """Prompt for each feature value and print the model output.

    Parameters
    ----------
    thetas_path : str
        CSV file of model parameters ``[theta0, theta1, ..., thetak]``, where
        theta0 is the bias and theta_i weights feature i.

    The user is asked for features 1..k on stdin. The result
    ``get_y(x, thetas)`` with ``x = [1, f1, ..., fk]`` is printed. An
    unreadable/malformed file or a non-numeric feature prints a message and
    returns early.
    """
    try:
        # atleast_1d: a single-value file comes back from genfromtxt as a 0-d
        # array, which has no len().
        thetas = np.atleast_1d(np.genfromtxt(thetas_path, delimiter=','))
    except (OSError, ValueError):
        print('wrong name or format')
        return
    # Reject empty files, multi-row/multi-column tables and non-numeric fields
    # (genfromtxt turns those into nan instead of raising).
    if thetas.ndim != 1 or thetas.size == 0 or not np.all(np.isfinite(thetas)):
        print('wrong name or format')
        return

    # One slot per theta: x[0] = 1 pairs with the bias, x[1:] are the features.
    # (Previously x had len(thetas) - 1 slots, so the last feature was never
    # asked for and the last theta was ignored.)
    x = np.empty(len(thetas))
    x[0] = 1
    for i in range(1, len(thetas)):
        try:
            # Numeric parsing instead of eval() on user input; the original
            # int() truncation is kept.
            x[i] = int(float(input("Enter feature " + str(i) + "\n")))
        except (ValueError, OverflowError, EOFError):
            print('input a number')
            return

    print("y is equal to ", get_y(x, thetas))
|
|
|
|
def main():
    """Parse the command line and dispatch to a prediction mode.

    Without ``-p/--path``, ``predict_subject()`` runs; it uses ``thetas.csv``
    and prompts for a mileage. With ``-p PATH``, ``predict(PATH)`` reads the
    thetas from that file and prompts for each feature.
    """
    parser = argparse.ArgumentParser(description='DSLR is a 2 day project, if you do Linear Regression in two weeks :pepethefrog:')
    # None (not False) as the "not given" marker for a str-typed option.
    parser.add_argument("-p", "--path", type=str, default=None, help="thetas file path")
    args = parser.parse_args()
    if args.path is None:
        predict_subject()
    else:
        predict(args.path)
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed directly, not on import.
    main()
|