73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
from Ft_logistic_regression import Ft_logistic_regression
|
|
from threading import Thread
|
|
import argparse
|
|
|
|
class multi_train(Thread):
    """Worker thread that trains a one-vs-all classifier for a single house.

    Args:
        key: House name. Rows whose 'Hogwarts House' equals it are labelled 1,
            every other row is labelled 0.
        hogwarts: Dict mapping a house name to the DataFrame this thread
            trains on. ``hogwarts[key]`` is replaced by the prepared training
            matrix when the thread runs.
        hogwarts_thetas: Dict that receives the trained ``lr.thetas`` under
            ``key`` once training has finished.
    """

    def __init__(self, key, hogwarts, hogwarts_thetas):
        super().__init__()
        self.key = key
        self.hogwarts = hogwarts
        self.hogwarts_thetas = hogwarts_thetas

    @staticmethod
    def _one_vs_all_matrix(frame, key):
        """Return the training matrix ``[features..., label]`` for ``key``.

        The 'Index' and 'Hogwarts House' columns are removed from the
        features, and a 1/0 label (1 where 'Hogwarts House' == key) is
        appended as the last column. ``frame`` itself is not modified.

        The label is computed for the whole column at once instead of cell by
        cell through ``.loc``, so it no longer depends on the 'Index' column
        matching the frame's row labels.
        """
        is_member = (frame['Hogwarts House'] == key).to_numpy()
        # Object dtype with plain Python ints: the same element types the
        # per-cell assignment into the string column used to produce.
        labels = np.where(is_member, 1, 0).astype(object).reshape(-1, 1)
        features = np.array(
            frame.drop(columns=['Index', 'Hogwarts House']), dtype=object
        )
        return np.hstack((features, labels))

    def run(self):
        """Prepare the matrix for ``self.key``, train, and publish the thetas."""
        key = self.key
        matrix = self._one_vs_all_matrix(self.hogwarts[key], key)
        # Keep the historical side effect: the shared dict entry ends up
        # holding the prepared matrix rather than the DataFrame.
        self.hogwarts[key] = matrix

        # NOTE(review): epochs/learning_rate are passed through unchanged;
        # their meaning is defined by Ft_logistic_regression, not visible here.
        lr = Ft_logistic_regression(epochs=2000, learning_rate=20, data=matrix)
        print('Training ' + key + '...')
        lr.gradient_descent()
        self.hogwarts_thetas[key] = lr.thetas
        print('Thetas values for ' + key + ' : ')
        print(self.hogwarts_thetas[key])
        print('Done training ' + key + ' !\n\n')
|
|
|
|
def train(path='dataset_train.csv', output='thetas.csv'):
    """Train one one-vs-all classifier per house and save their thetas.

    Reads the training CSV, drops the columns that are not used as features,
    runs one ``multi_train`` thread per house, and writes the four theta
    vectors as rows of a CSV file in the order Gryffindor, Slytherin,
    Hufflepuff, Ravenclaw.

    Args:
        path: Training CSV to read. Defaults to 'dataset_train.csv'.
        output: CSV file the thetas are written to. Defaults to 'thetas.csv'.

    Raises:
        RuntimeError: If a worker thread finished without storing its thetas
            (for example because it raised inside ``run``).
    """
    houses = ('Gryffindor', 'Slytherin', 'Hufflepuff', 'Ravenclaw')

    raw_data = pd.read_csv(path, delimiter=',')
    data = raw_data.drop(columns=[
        'First Name', 'Last Name', 'Birthday', 'Best Hand',
        'Arithmancy', 'Care of Magical Creatures', 'Astronomy',
    ])

    # Each worker gets its own copy because it may rewrite its frame.
    hogwarts = {house: data.copy() for house in houses}
    placeholders = {house: [] for house in houses}
    hogwarts_thetas = dict(placeholders)

    workers = [
        multi_train(house, hogwarts, hogwarts_thetas) for house in houses
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # An exception inside a thread is not re-raised by join(); a placeholder
    # that was never replaced is the only sign that a worker failed.
    failed = [
        house for house in houses
        if hogwarts_thetas[house] is placeholders[house]
    ]
    if failed:
        raise RuntimeError(
            'training did not produce thetas for: ' + ', '.join(failed)
        )

    thetas = np.array([hogwarts_thetas[house] for house in houses])
    np.savetxt(output, thetas, delimiter=',')
|
|
|
|
def main():
    """Command-line entry point: train when given 'dataset_train.csv'.

    Parses one required positional argument ``path``. Any value other than
    'dataset_train.csv' prints "wrong path" and returns without training.
    A failure during training prints "Error" instead of a traceback.
    """
    parser = argparse.ArgumentParser()
    # A positional argument without nargs is always required, so a default
    # value would never be used.
    parser.add_argument("path", type=str, help="dataset_train.csv")
    args = parser.parse_args()

    if args.path != "dataset_train.csv":
        print("wrong path")
        return

    try:
        train()
    except Exception:
        # Best-effort reporting is kept, but KeyboardInterrupt and SystemExit
        # are no longer swallowed as a bare ``except`` would.
        print("Error")
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|