Newer
Older
import random
import numpy
from algorithm.k_nearest_neighbors.distance_measure.euclidean_distance import euclidean_distance
from algorithm.pla.perceptron import Perceptron
from algorithm.pla.perceptron_learning_algorithm import train, train_pocket
from algorithm.decision_tree.decision_tree import DecisionTree
from algorithm.k_nearest_neighbors.k_nearest_neighbors_algorithm import KNearestNeighborsAlgorithm
from aufgaben.p4.testdata import get_evaluation_data
from aufgaben.p6.multiclass_error_rate import Multiclass_ErrorRate
from features.standard_deviation import standard_deviation
from features.arithmetic_mean import arithmetic_mean
from features.median import median
from features.extremwerte import maximum
def evaluate_algorithm(training_data, test_data, algorithm, evaluator, args=None):
    """Train ``algorithm``, score it on ``test_data`` and print the evaluator's table.

    Args:
        training_data: Data handed unchanged to the algorithm's ``train`` method.
        test_data: Iterable of ``(features, correct_class)`` pairs.
        algorithm: Classifier exposing ``train`` and ``classify``. A
            ``DecisionTree`` is trained as ``train(None, training_data)``;
            every other algorithm as ``train(training_data)``.
        evaluator: Object exposing ``evaluate(correct_class, result)`` and
            ``print_table()``.
        args: Extra classification options. For a
            ``KNearestNeighborsAlgorithm`` it must provide the keys
            ``'distance'`` and ``'k'``, which are forwarded to ``classify``.
            Defaults to an empty dict.
    """
    options = {} if args is None else args

    # DecisionTree.train takes an extra leading argument (passed as None here).
    if isinstance(algorithm, DecisionTree):
        algorithm.train(None, training_data)
    else:
        algorithm.train(training_data)

    uses_neighbors = isinstance(algorithm, KNearestNeighborsAlgorithm)

    # Compare every prediction with the expected class.
    for sample, expected in test_data:
        if uses_neighbors:
            # Options are read per sample, so a missing key only fails once
            # there is something to classify.
            predicted = algorithm.classify(sample, options['distance'], options['k'])
        else:
            predicted = algorithm.classify(sample)
        evaluator.evaluate(expected, predicted)

    evaluator.print_table()
def _collect_classes(*datasets):
    """Return the distinct class labels of ``(features, class)`` pairs in first-seen order.

    ``dict.fromkeys`` keeps insertion order and de-duplicates, so labels only
    need to be hashable, not sortable.
    """
    return list(dict.fromkeys(label for data in datasets for _, label in data))


def _evaluate_perceptron(perceptron, test_data, evaluator):
    """Classify every ``(features, correct_class)`` pair and print the evaluator's table."""
    for sample, expected in test_data:
        evaluator.evaluate(expected, perceptron.classify(sample))
    evaluator.print_table()


def evaluate():
    """Compare a decision tree, k-NN and a perceptron on the same evaluation data.

    The data come from ``get_evaluation_data(200, features)`` using four
    feature functions. Each classifier's results are printed by a
    ``Multiclass_ErrorRate`` built over the class labels found in the data.
    """
    features = [standard_deviation, arithmetic_mean, median, maximum]
    test_data, training_data = get_evaluation_data(200, features)

    # The data sets are iterated several times below, so they are assumed to
    # be re-iterable sequences; the class list is derived from them.
    # NOTE(review): assumes Multiclass_ErrorRate accepts a list of the labels
    # occurring in the data -- confirm against its implementation.
    classes = _collect_classes(training_data, test_data)

    print("\nDecision Tree:")
    evaluate_algorithm(
        training_data,
        test_data,
        DecisionTree(entropy_threshold=0.5, number_segments=25, print_after_train=True),
        Multiclass_ErrorRate(classes),
    )

    print("\nk-Nearest Neighbors:")
    evaluate_algorithm(
        training_data,
        test_data,
        KNearestNeighborsAlgorithm(),
        Multiclass_ErrorRate(classes),
        {'distance': euclidean_distance, 'k': 5},
    )

    print("\nPerceptron:")
    # Random initial weights (one per feature function) and threshold.
    # NOTE(review): assumes Perceptron expects a sequence of len(features)
    # weights and a scalar threshold -- confirm against Perceptron.__init__.
    weights = [random.uniform(-1, 1) for _ in features]
    threshold = random.uniform(-1, 1)
    perceptron = Perceptron(weights, threshold, numpy.tanh)
    # Perceptrons are trained by the module-level train() function rather
    # than a train() method, so evaluate_algorithm cannot be reused here.
    train(perceptron, training_data, 10000, 0.1)
    # NOTE(review): with a tanh activation classify() may return a continuous
    # value rather than a class label -- verify it is comparable to the labels.
    _evaluate_perceptron(perceptron, test_data, Multiclass_ErrorRate(classes))
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    evaluate()