import argparse

import numpy as np
from LinearRegression import LinearRegression
# NOTE(review): `animation` is not referenced in the visible code; the import
# is kept in case it is needed for its side effects -- confirm before removing.
from matplotlib import animation


def get_thetas():
    """Read thetas from 'thetas.csv' in the current directory.

    Returns:
        A two-element list ``[thetas[2], thetas[1]]`` taken from the loaded
        array, or ``[0, 0]`` when the file cannot be loaded or has too few
        entries to be indexed that way.
    """
    try:
        thetas = np.genfromtxt('thetas.csv', delimiter=',')
    except Exception:
        # Best-effort load: any loading failure falls back to zeros, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print("info: thetas.csv not found, setting thetas to 0")
        return [0, 0]
    try:
        # NOTE(review): the [2], [1] ordering presumably mirrors the layout
        # written by train() -- confirm against LinearRegression.
        return [thetas[2], thetas[1]]
    except IndexError:
        print("warning: thetas.csv in wrong format, setting thetas to 0")
        return [0, 0]


def train(raw_data, raw_thetas, thetas_path, visu):
    """Run gradient descent on the data and save the resulting thetas.

    Args:
        raw_data: data array handed to ``LinearRegression`` as ``data``.
        raw_thetas: initial thetas handed to ``LinearRegression`` as
            ``thetas``.
        thetas_path: path of the CSV file the trained thetas are written to.
        visu: when truthy, and the model has exactly two thetas, call
            ``lr.show()`` after saving.
    """
    lr = LinearRegression(
        thetas=raw_thetas,
        data=raw_data,
        epochs=10000,
        learning_rate=0.1,
    )
    print(lr.raw_data)

    print("thetas before train")
    # The last entry is left out of this printout.
    # NOTE(review): presumably it is the extra lr.thetas[0] value appended by
    # a previous save (see below) -- confirm.
    print(lr.raw_thetas[:len(lr.raw_thetas) - 1])

    lr.gradient_descent()

    print("thetas after train")
    print(lr.raw_thetas)

    # Saved layout: every entry of lr.raw_thetas followed by lr.thetas[0].
    tosave = list(lr.raw_thetas)
    tosave.append(lr.thetas[0])
    np.savetxt(thetas_path, tosave, delimiter=',')

    if visu and len(lr.thetas) == 2:
        lr.show()


def main():
    """Parse the command line, load data and thetas, then train.

    Prints a message and returns without training when the data file cannot
    be loaded. When the thetas file cannot be loaded, training starts from a
    zero vector with one entry per data column.
    """
    parser = argparse.ArgumentParser(description='DSLR is a 2 day project, if you do Linear Regression in two weeks :pepethefrog:')
    parser.add_argument("-p", "--path", type=str, default='data.csv', help="data file path")
    # The default path is given directly instead of a False sentinel that is
    # replaced after parsing; the resulting value is the same.
    parser.add_argument("-t", "--thetas", type=str, default='thetas.csv', help="thetas file path")
    parser.add_argument("--visu", default=False, help="plot data on graph", action="store_true")
    args = parser.parse_args()

    try:
        raw_data = np.genfromtxt(args.path, delimiter=',', skip_header=1)
    except Exception:
        print('csv file not found or wrong')
        return

    try:
        raw_thetas = np.genfromtxt(args.thetas, delimiter=',')
    except Exception:
        print('thetas file not found or wrong, setting thetas to 0')
        raw_thetas = np.zeros(raw_data.shape[1])

    train(raw_data, raw_thetas, args.thetas, args.visu)


if __name__ == '__main__':
    main()