"""Predict a value from a linear model whose parameters are stored in a CSV file.

Without arguments the script runs the car-price predictor backed by
``thetas.csv``. With ``-p/--path`` it evaluates an arbitrary linear model
read from the given CSV file.
"""
import argparse
import numpy as np
from train import train
from LinearRegression import LinearRegression  # NOTE(review): unused in this module
from matplotlib import animation  # NOTE(review): unused in this module


def get_thetas(path):
    """Read the two model parameters (theta0, theta1) from a CSV file.

    Args:
        path: path of the CSV file to read.

    Returns:
        ``[theta0, theta1]`` on success, or the sentinel string ``"error"``
        when the file cannot be read or does not hold at least two values.
    """
    try:
        thetas = np.genfromtxt(path, delimiter=',')
    except Exception:
        print('info: path not found, asking for train')
        return "error"
    try:
        return [thetas[0], thetas[1]]
    except Exception:
        # e.g. an empty file, or a single value (genfromtxt then yields an
        # array that cannot be indexed twice).
        print('warning: path in wrong format, asking for train')
        return "error"


def get_price(mileage, thetas):
    """Return the linear estimate ``thetas[0] + mileage * thetas[1]``."""
    return thetas[0] + mileage * thetas[1]


def _thetas_after_train_prompt():
    """Offer to train the model and return the thetas to predict with.

    On 'y', trains on ``data.csv``, then reloads ``thetas.csv``. Any other
    answer, or any failure while prompting/training, falls back to ``[0, 0]``.
    KeyboardInterrupt/SystemExit are deliberately not caught.
    """
    print("")
    try:
        answer = input("model isn't trained, would you like to train it before predicting car price ? y/n\n")
        if answer == 'y':
            raw_data = np.genfromtxt('data.csv', delimiter=',', skip_header=1)
            train(raw_data, np.zeros(2), 'thetas.csv', False)
            return np.genfromtxt('thetas.csv', delimiter=',')
        print('info: wrong input format, setting thetas to 0')
        return [0, 0]
    except Exception:
        print('info: wrong input format or fail to train, setting thetas to 0')
        return [0, 0]


def predict_subject():
    """Interactively predict a car price from its mileage.

    Loads thetas from ``thetas.csv`` (offering to train when that fails),
    reads a non-negative mileage from stdin, and prints the predicted price.
    Invalid input prints an info message and returns without predicting.
    """
    thetas = get_thetas('thetas.csv')
    if thetas == "error":
        thetas = _thetas_after_train_prompt()
    try:
        # NOTE(review): eval() executes arbitrary Python typed at the prompt;
        # consider float()/ast.literal_eval if input may be untrusted.
        mileage = int(eval(input("Enter mileage\n")))
    except Exception:
        print('info: input a number')
        return
    if mileage < 0:
        print('info: mileage should be superior to 0 ! aborting')
        return
    price = get_price(mileage, thetas)
    if price < 0:
        print('This car belongs in a museum ! \n(price inferior to 0)')
    else:
        print('Predicted car value is ', price)


def get_y(x, thetas):
    """Return ``sum(x[i] * thetas[i])`` over every index of ``x``.

    ``thetas`` must have at least ``len(x)`` entries; extra trailing thetas
    are ignored.
    """
    return sum(xi * thetas[i] for i, xi in enumerate(x))


def predict(thetas_path):
    """Interactively evaluate a linear model stored in a CSV file.

    Feature 0 is fixed to 1 (so ``thetas[0]`` acts as the intercept); the
    remaining features are read from stdin and the resulting ``y`` is printed.
    Unreadable/too-short files and non-numeric input print a message and
    return without predicting.
    """
    try:
        # atleast_1d: a single-value file comes back as a 0-d array.
        thetas = np.atleast_1d(np.genfromtxt(thetas_path, delimiter=','))
    except Exception:
        print('wrong name or format')
        return
    if len(thetas) < 2:
        # Not enough values to build even the intercept-only feature vector.
        print('wrong name or format')
        return
    # NOTE(review): x has len(thetas) - 1 entries, so the last theta is never
    # used by get_y. This may be intended if the file ends with a trailing
    # delimiter (genfromtxt then yields a trailing nan) — confirm against the
    # format written by train.
    x = np.empty(len(thetas) - 1)
    x[0] = 1
    for i in range(1, len(thetas) - 1):
        try:
            # NOTE(review): eval() executes arbitrary Python typed at the prompt.
            x[i] = int(eval(input("Enter feature " + str(i) + "\n")))
        except Exception:
            print('input a number')
            return
    print("y is equal to ", get_y(x, thetas))


def main():
    """Parse CLI arguments and dispatch to the matching predictor."""
    parser = argparse.ArgumentParser(description='DSLR is a 2 day project, if you do Linear Regression in two weeks :pepethefrog:')
    parser.add_argument("-p", "--path", type=str, default=None, help="thetas file path")
    args = parser.parse_args()
    if args.path is None:
        predict_subject()
    else:
        predict(args.path)


if __name__ == '__main__':
    main()