import random

import numpy
from algorithm.k_nearest_neighbors.distance_measure.euclidean_distance import euclidean_distance
from algorithm.pla.perceptron import Perceptron
from algorithm.pla.perceptron_learning_algorithm import train, train_pocket
from algorithm.decision_tree.decision_tree import DecisionTree
from algorithm.k_nearest_neighbors.k_nearest_neighbors_algorithm import KNearestNeighborsAlgorithm
from aufgaben.p4.testdata import get_labeled_testdata
from aufgaben.p6.error_rate import ErrorRate
from aufgaben.p6.multiclass_error_rate import Multiclass_ErrorRate


def evaluate_algorithm(training_data, test_data, algorithm, evaluator, args=None):
    """Train ``algorithm``, classify every test sample and print the evaluator's table.

    :param training_data: passed unchanged to ``algorithm.train``.
    :param test_data: iterable of ``(features, correct_class)`` pairs.
    :param algorithm: object offering ``train(training_data)`` and ``classify(...)``.
    :param evaluator: object offering ``evaluate(correct_class, result)`` and
        ``print_table()``.
    :param args: extra options; for a ``KNearestNeighborsAlgorithm`` the keys
        ``'distance'`` and ``'k'`` are read and forwarded to ``classify``.
    """
    options = {} if args is None else args
    algorithm.train(training_data)

    # KNN needs the distance function and k on every classify call; all other
    # algorithms classify from the features alone.
    if isinstance(algorithm, KNearestNeighborsAlgorithm):
        def predict(sample):
            return algorithm.classify(sample, options['distance'], options['k'])
    else:
        def predict(sample):
            return algorithm.classify(sample)

    # Compare every prediction against the expected class.
    for sample, expected in test_data:
        evaluator.evaluate(expected, predict(sample))
    evaluator.print_table()


def _evaluate_perceptron(train_function, training_data, test_data, threshold):
    """Train a fresh tanh perceptron with ``train_function`` and print its error table.

    :param train_function: ``train`` or ``train_pocket``; called as
        ``train_function(perceptron, training_data, 100, 0.1)``.
    :param training_data: data handed to ``train_function``.
    :param test_data: iterable of ``(features, correct_class)`` pairs.
    :param threshold: threshold passed to the ``Perceptron`` constructor.
    """
    # NOTE(review): a single random weight presumably assumes one-dimensional
    # feature vectors -- confirm against get_labeled_testdata().
    weights = [random.random()]
    perceptron = Perceptron(weights, threshold, numpy.tanh)
    # Positional hyper-parameters are kept exactly as before (presumably the
    # iteration count and learning rate -- confirm in perceptron_learning_algorithm).
    train_function(perceptron, training_data, 100, 0.1)

    error_rate = Multiclass_ErrorRate()
    for features, correct_class in test_data:
        error_rate.evaluate(correct_class, perceptron.classify(features))
    error_rate.print_table()


def evaluate():
    """Evaluate all classifiers on the labelled test data and print one table each.

    Runs, in order: a decision tree, k-nearest neighbours (euclidean distance,
    k=10), the perceptron learning algorithm (PLA) and its pocket variant.
    Each classifier is scored with its own ``Multiclass_ErrorRate``.
    """
    test_data, training_data = get_labeled_testdata()

    # Decision tree
    print("\nDecision Tree:")
    evaluate_algorithm(training_data, test_data,
                       DecisionTree(entropy_threshold=0.5, number_segments=10, print_after_train=True),
                       Multiclass_ErrorRate())

    # k-nearest neighbours
    print("\nKNN")
    evaluate_algorithm(training_data, test_data, KNearestNeighborsAlgorithm(), Multiclass_ErrorRate(),
                       {'distance': euclidean_distance, 'k': 10})

    # Plain perceptron learning algorithm
    print("\nPLA")
    _evaluate_perceptron(train, training_data, test_data, threshold=0)

    # Pocket variant of the perceptron learning algorithm
    print("\nPocket")
    _evaluate_perceptron(train_pocket, training_data, test_data, threshold=0.5)


if __name__ == "__main__":
    # Run the full evaluation only when executed as a script, not on import.
    evaluate()