ft_linear_regression/predict.py

87 lines
2.6 KiB
Python

import argparse
import numpy as np
from train import train
from LinearRegression import LinearRegression
from matplotlib import animation
def get_thetas(path):
    """Load the two model parameters ``[theta0, theta1]`` from a CSV file.

    Args:
        path: path of a comma-separated file holding at least two numbers.

    Returns:
        ``[theta0, theta1]`` as floats, or the string ``"error"`` when the file
        cannot be opened or does not contain two finite numbers (callers compare
        against ``"error"`` to decide whether to offer training).
    """
    try:
        raw = np.genfromtxt(path, delimiter=',')
    except OSError:
        print ('info: path not found, asking for train')
        return "error"
    except ValueError:
        # genfromtxt raises ValueError e.g. for rows with inconsistent column counts.
        print ('warning: path in wrong format, asking for train')
        return "error"
    # genfromtxt yields a 0-d array for a single value and silently converts
    # unparsable fields to NaN instead of raising, so validate explicitly.
    values = np.atleast_1d(raw).ravel()
    if values.size < 2 or not np.all(np.isfinite(values[:2])):
        print ('warning: path in wrong format, asking for train')
        return "error"
    return [float(values[0]), float(values[1])]
def get_price(mileage, thetas):
    """Return the linear price estimate ``thetas[0] + mileage * thetas[1]``.

    ``thetas[0]`` acts as the intercept and ``thetas[1]`` as the per-unit
    mileage coefficient; any further entries are ignored.
    """
    intercept = thetas[0]
    slope = thetas[1]
    return intercept + mileage * slope
def _load_or_train_thetas():
    """Return ``[theta0, theta1]`` from thetas.csv, offering to train first if it is unusable.

    Falls back to ``[0, 0]`` when the user declines or training fails.
    """
    thetas = get_thetas('thetas.csv')
    if not isinstance(thetas, str):
        return thetas
    print("")
    try:
        answer = input("model isn't trained, would you like to train it before predicting car price ? y/n\n")
    except EOFError:
        print ('info: wrong input format or fail to train, setting thetas to 0')
        return [0, 0]
    answer = answer.strip().lower()
    if answer == 'n':
        print ('info: setting thetas to 0')
        return [0, 0]
    if answer != 'y':
        print ('info: wrong input format, setting thetas to 0')
        return [0, 0]
    try:
        raw_data = np.genfromtxt('data.csv', delimiter=',', skip_header=1)
        train(raw_data, np.zeros(2), 'thetas.csv', False)
    except Exception:
        # train() lives in another module; report any failure instead of crashing,
        # but let KeyboardInterrupt/SystemExit propagate (unlike a bare except).
        print ('info: wrong input format or fail to train, setting thetas to 0')
        return [0, 0]
    # Re-read through get_thetas so the freshly written file is validated too.
    thetas = get_thetas('thetas.csv')
    if isinstance(thetas, str):
        print ('info: wrong input format or fail to train, setting thetas to 0')
        return [0, 0]
    return thetas


def _read_mileage():
    """Prompt for a mileage; return it as a non-negative int, or None after printing why not."""
    try:
        # Parse numerically instead of eval(): eval on raw input executes arbitrary code.
        mileage = int(float(input("Enter mileage\n")))
    except (ValueError, OverflowError, EOFError):
        print ('info: input a number')
        return None
    if mileage < 0:
        print ('info: mileage should be superior to 0 ! aborting')
        return None
    return mileage


def predict_subject():
    """Interactively predict a car price from its mileage.

    Loads thetas from thetas.csv (offering to train on data.csv when missing or
    malformed), asks for a mileage and prints the predicted price, or a notice
    when the predicted price is negative.
    """
    thetas = _load_or_train_thetas()
    mileage = _read_mileage()
    if mileage is None:
        return
    price = get_price(mileage, thetas)
    if price < 0:
        print('This car belongs in a museum ! (price inferior to 0)')
    else:
        print('Predicted car value is ', price,)
def get_y(x, thetas):
    """Return the weighted sum ``sum(x[i] * thetas[i])`` over the indices of *x*.

    Only the first ``len(x)`` entries of ``thetas`` are used; an ``IndexError``
    is raised if ``thetas`` is shorter than ``x``. An empty ``x`` yields ``0``.
    """
    return sum(value * thetas[idx] for idx, value in enumerate(x))
def predict(thetas_path):
    """Interactively evaluate a linear model whose parameters are stored in a CSV file.

    The file is read as a flat vector ``[theta0, theta1, ..., thetaN]``. theta0
    is paired with a constant feature of 1, and the user is prompted for
    features 1..N. The result of ``get_y`` is then printed.

    Args:
        thetas_path: path of the comma-separated thetas file.
    """
    try:
        raw = np.genfromtxt(thetas_path, delimiter=',')
    except (OSError, ValueError):
        print('wrong name or format')
        return
    # genfromtxt returns a 0-d array for one value and NaN for unparsable fields.
    thetas = np.atleast_1d(raw).ravel()
    if thetas.size == 0 or not np.all(np.isfinite(thetas)):
        print('wrong name or format')
        return
    # One slot per theta (x[0] = 1 for the intercept) so every parameter,
    # including the last, contributes; the old len(thetas) - 1 dropped it.
    x = np.empty(thetas.size)
    x[0] = 1
    for i in range(1, thetas.size):
        try:
            # float() instead of eval(): never execute user input, and keep
            # fractional feature values instead of truncating them to int.
            x[i] = float(input("Enter feature " + str(i) + "\n"))
        except (ValueError, EOFError):
            print('input a number')
            return
    print("y is equal to ", get_y(x, thetas))
def main(argv=None):
    """Command-line entry point.

    Without ``-p`` runs the interactive mileage/price prediction
    (``predict_subject``); with ``-p PATH`` runs the generic prediction using
    the thetas stored at PATH (``predict``).

    Args:
        argv: argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='DSLR is a 2 day project, if you do Linear Regression in two weeks :pepethefrog:')
    parser.add_argument("-p", "--path", type=str, default=None, help="thetas file path")
    args = parser.parse_args(argv)
    if args.path is None:
        predict_subject()
    else:
        predict(args.path)


if __name__ == '__main__':
    main()