56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
import argparse
|
|
import numpy as np
|
|
from LinearRegression import LinearRegression
|
|
from matplotlib import animation
|
|
|
|
def get_thetas(path='thetas.csv'):
    """Load two saved model parameters from a comma-separated file.

    The file is parsed with ``np.genfromtxt`` and the values at index 2 and
    index 1 are returned, in that order.

    Args:
        path: CSV file to read. Defaults to ``'thetas.csv'`` in the current
            working directory, which is the file this function always read
            before the parameter existed.

    Returns:
        ``[values[2], values[1]]`` on success; ``[0, 0]`` when the file cannot
        be opened or parsed, or holds fewer than three values.
    """
    try:
        thetas = np.genfromtxt(path, delimiter=',')
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file (FileNotFoundError is a
        # subclass); ValueError covers malformed contents. Anything else
        # (e.g. KeyboardInterrupt) is no longer silently swallowed.
        print(f"info: {path} not found, setting thetas to 0")
        return [0, 0]
    try:
        # NOTE(review): index order (2, then 1) presumably matches the layout
        # written by the training step; confirm against LinearRegression.
        return [thetas[2], thetas[1]]
    except IndexError:
        # Fewer than three values, or a 0-d array from a one-value file.
        print(f"warning: {path} in wrong format, setting thetas to 0")
        return [0, 0]
|
|
|
|
def train(raw_data, raw_thetas, thetas_path, visu):
    """Run gradient descent on a LinearRegression model and save its thetas.

    Prints ``model.raw_data`` and the model's raw thetas before and after
    training, then writes every trained raw theta followed by
    ``model.thetas[0]`` to ``thetas_path`` with ``np.savetxt``.

    Args:
        raw_data: passed to LinearRegression as ``data``.
        raw_thetas: starting parameters, passed as ``thetas``.
        thetas_path: destination file for the trained parameters.
        visu: when truthy, ``model.show()`` is called, but only if the model
            holds exactly two thetas.
    """
    model = LinearRegression(
        thetas=raw_thetas,
        data=raw_data,
        epochs=10000,
        learning_rate=0.1,
    )
    print(model.raw_data)

    print("thetas before train")
    # All raw thetas except the last element.
    print(model.raw_thetas[:-1])

    model.gradient_descent()

    print("thetas after train")
    print(model.raw_thetas)

    # Saved layout: each trained raw theta, then the model's thetas[0].
    to_save = [theta for theta in model.raw_thetas]
    to_save.append(model.thetas[0])
    np.savetxt(thetas_path, to_save, delimiter=',')

    if visu and len(model.thetas) == 2:
        model.show()
|
|
|
|
def main(argv=None):
    """Command-line entry point: load the data and starting thetas, then train.

    Options:
        -p/--path    data CSV (first row skipped as a header), default
                     ``data.csv``.
        -t/--thetas  thetas CSV used as the starting point and passed to
                     train() as the output path, default ``thetas.csv``.
        --visu       forwarded to train() to request a plot.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Returns:
        None. Returns early, after printing a message, when the data file
        cannot be read.
    """
    parser = argparse.ArgumentParser(description='DSLR is a 2 day project, if you do Linear Regression in two weeks :pepethefrog:')
    parser.add_argument("-p", "--path", type=str, default='data.csv', help="data file path")
    # Direct default instead of the old `default=False` + `== False` sentinel;
    # the resulting value is the same.
    parser.add_argument("-t", "--thetas", type=str, default='thetas.csv', help="thetas file path")
    parser.add_argument("--visu", default=False, help="plot data on graph", action="store_true")
    args = parser.parse_args(argv)

    try:
        raw_data = np.genfromtxt(args.path, delimiter=',', skip_header=1)
    except (OSError, ValueError):
        # Missing/unreadable file or malformed rows; no longer a bare except.
        print('csv file not found or wrong')
        return

    try:
        raw_thetas = np.genfromtxt(args.thetas, delimiter=',')
    except (OSError, ValueError):
        print('thetas file not found or wrong, setting thetas to 0')
        # One zero per data column.
        # NOTE(review): requires raw_data to be 2-D; a data file with a single
        # row or single column parses to 1-D and shape[1] raises IndexError.
        raw_thetas = np.zeros(raw_data.shape[1])

    train(raw_data, raw_thetas, args.thetas, args.visu)
|
|
|
|
if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()
|