ft_linear_regression/train.py

56 lines
2.0 KiB
Python

import argparse
import numpy as np
from LinearRegression import LinearRegression
from matplotlib import animation
def get_thetas(path='thetas.csv'):
    """Load a saved thetas file and return two of its values.

    The file is read with ``np.genfromtxt`` (comma delimiter) and the result
    is ``[values[2], values[1]]``: the third stored number, then the second.
    NOTE(review): when ``train`` writes the file from two raw thetas, index 2
    is the ``lr.thetas[0]`` value it appends last — confirm this is the
    ordering callers expect.

    Args:
        path: file to read. Defaults to ``thetas.csv`` in the current working
            directory, which is what this function always read before.

    Returns:
        A two-element list. ``[0, 0]`` is returned (after printing an
        info/warning message) when the file is missing, cannot be parsed, or
        cannot be indexed at positions 2 and 1 (e.g. it holds fewer than
        three values).
    """
    try:
        thetas = np.genfromtxt(path, delimiter=',')
    except OSError:
        # Missing/unreadable file: a fresh start with zero thetas is expected.
        print(f"info: {path} not found, setting thetas to 0")
        return [0, 0]
    except ValueError:
        # genfromtxt raises ValueError e.g. on rows with inconsistent columns.
        print(f"warning: {path} in wrong format, setting thetas to 0")
        return [0, 0]
    try:
        return [thetas[2], thetas[1]]
    except IndexError:
        # Too few values (a single value even loads as a 0-d array).
        print(f"warning: {path} in wrong format, setting thetas to 0")
        return [0, 0]
def train(raw_data, raw_thetas, thetas_path, visu, epochs=10000, learning_rate=0.1):
    """Run gradient descent on a LinearRegression model and save its thetas.

    Args:
        raw_data: data loaded from the training CSV, forwarded to
            ``LinearRegression(data=...)``.
        raw_thetas: initial parameter values, forwarded to
            ``LinearRegression(thetas=...)``.
        thetas_path: file the trained values are written to with
            ``np.savetxt`` (one value per line).
        visu: when true, call ``lr.show()`` after training, but only if the
            model has exactly two entries in ``lr.thetas``.
        epochs: forwarded to ``LinearRegression``; default matches the value
            that used to be hard-coded.
        learning_rate: forwarded to ``LinearRegression``; default matches the
            value that used to be hard-coded.
    """
    lr = LinearRegression(thetas=raw_thetas, data=raw_data,
                          epochs=epochs, learning_rate=learning_rate)
    print(lr.raw_data)
    print("thetas before train")
    # The last raw theta is left out of this print only.
    # NOTE(review): presumably it is not a model parameter — confirm against
    # LinearRegression.
    print(lr.raw_thetas[:-1])
    lr.gradient_descent()
    print("thetas after train")
    print(lr.raw_thetas)
    # Persist every raw theta followed by lr.thetas[0], in that order.
    tosave = list(lr.raw_thetas) + [lr.thetas[0]]
    np.savetxt(thetas_path, tosave, delimiter=',')
    if visu and len(lr.thetas) == 2:
        lr.show()
def main():
    """Command-line entry point: load data and initial thetas, then train.

    Options:
        -p/--path    training data CSV (default ``data.csv``); its first
                     line is skipped as a header.
        -t/--thetas  file the initial thetas are read from and the trained
                     thetas are written back to (default ``thetas.csv``).
        --visu       forwarded to ``train`` as ``visu``.

    Prints ``csv file not found or wrong`` and returns without training when
    the data file cannot be read, does not load as a 2-D table, or contains
    non-numeric/missing cells. When the thetas file is missing, unreadable or
    empty, training starts from zeros (one per data column).
    """
    parser = argparse.ArgumentParser(description='DSLR is a 2 day project, if you do Linear Regression in two weeks :pepethefrog:')
    parser.add_argument("-p", "--path", type=str, default='data.csv', help="data file path")
    parser.add_argument("-t", "--thetas", type=str, default='thetas.csv', help="thetas file path")
    parser.add_argument("--visu", default=False, help="plot data on graph", action="store_true")
    args = parser.parse_args()

    try:
        raw_data = np.genfromtxt(args.path, delimiter=',', skip_header=1)
    except (OSError, ValueError):
        print('csv file not found or wrong')
        return
    # genfromtxt does not raise for every bad input: a header-only or
    # single-row file loads as a 1-D array, and unparseable/missing cells
    # become NaN. Reject both instead of crashing on shape[1] below or
    # training on NaN and overwriting the thetas file with it.
    if raw_data.ndim != 2 or np.isnan(raw_data).any():
        print('csv file not found or wrong')
        return

    try:
        raw_thetas = np.genfromtxt(args.thetas, delimiter=',')
    except (OSError, ValueError):
        raw_thetas = None
    # An existing but empty thetas file loads as an empty array without
    # raising; treat it the same as a missing one.
    if raw_thetas is None or raw_thetas.size == 0:
        print('thetas file not found or wrong, setting thetas to 0')
        raw_thetas = np.zeros(raw_data.shape[1])
    train(raw_data, raw_thetas, args.thetas, args.visu)
if __name__ == '__main__':
    # Run the command-line trainer only when executed as a script, not on import.
    main()