DSLR/logreg_predict.py

60 lines
2.2 KiB
Python

import pandas as pd
import numpy as np
import csv
from Ft_logistic_regression import Ft_logistic_regression
import argparse
def get_value(thetas, elem):
test = [1] + elem[2:]
for i, v in enumerate(test):
if np.isnan(v):
test[i] = 0
return np.dot(thetas, test)
def predict():
raw_data = pd.read_csv('dataset_test.csv',delimiter=',')
data = raw_data.drop(columns = ['First Name', 'Last Name', 'Birthday', 'Best Hand'])
data = data.drop(columns = ['Arithmancy', 'Care of Magical Creatures', 'Astronomy'])
hogwarts = data.copy()
for key in hogwarts:
if key != 'Index' and key != 'Hogwarts House':
hogwarts = hogwarts.drop(columns = [key])
houses = [None] * 4
thetas = np.genfromtxt('thetas.csv', delimiter=',')
for elem in hogwarts['Index']:
test = data.loc[elem].tolist()
i = 2
for key in data:
if key != 'Index' and key != 'Hogwarts House':
test[i] = (test[i] - min(data[key])) / (max(data[key]) - min(data[key]))
i += 1
for i in range(4):
houses[i] = get_value(thetas[i], test)
if (houses[0] > houses[1] and houses[0] > houses[2] and houses[0] > houses[3]):
hogwarts.loc[elem, 'Hogwarts House'] = 'Gryffindor'
elif (houses[1] > houses[2] and houses[1] > houses[3]):
hogwarts.loc[elem, 'Hogwarts House'] = 'Slytherin'
elif (houses[2] > houses[3]):
hogwarts.loc[elem, 'Hogwarts House'] = 'Hufflepuff'
else:
hogwarts.loc[elem, 'Hogwarts House'] = 'Ravenclaw'
file_content = hogwarts.to_csv(index=False)
f = open('houses.csv', 'w+')
f.write(file_content)
f.close()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("path", type=str, default="none", help="dataset_train.csv")
parser.add_argument("path2", type=str, default="none", help="thetas.csv")
args = parser.parse_args()
try:
if (args.path == "dataset_test.csv" and args.path2 == "thetas.csv"):
predict()
else:
print("wrong path")
except:
print("Error")
if __name__ == '__main__':
main()