Skip to content
Snippets Groups Projects
test_pla.py 1.25 KiB
Newer Older
from algorithm.pla.perceptron import Perceptron
from algorithm.pla.transfer_functions import normalized_tanh
from algorithm.pla.perceptron_learning_algorithm import train_pocket, update_weight
from server.testserver import get_testdata, send_result


def test_pla():
    """Run one perceptron-training round trip against the test server.

    Fetches a ``"perceptron_training"`` task via ``get_testdata``, builds a
    ``Perceptron`` (with ``normalized_tanh`` as transfer function) from the
    task's initial weights, then for every training sample classifies it and
    applies one ``update_weight`` step. The weights observed after each step
    are collected and submitted with ``send_result``.

    The task payload is read for the keys ``'session'``, ``'learning-rate'``,
    ``'initial-weights'`` and ``'training-data'``; each training sample is
    read for ``'id'``, ``'input'`` and ``'class'``.
    """
    import copy

    response = get_testdata("perceptron_training")

    learning_rate = response['learning-rate']
    initial_weights = response['initial-weights']

    # The first entry of the weight vector is the bias weight; the Perceptron
    # receives its negation as the threshold. The remaining entries are the
    # input weights.
    threshold = -initial_weights[0]
    weights = initial_weights[1:]

    perceptron = Perceptron(weights, threshold, normalized_tanh)

    results = []
    for sample in response['training-data']:
        features = sample['input']
        correct_class = sample['class']

        classification = perceptron.classify(features)
        update_weight(perceptron, features, correct_class, classification, learning_rate)

        # Snapshot the weights after this step. If update_weight mutates the
        # weight container in place, appending perceptron.weights directly
        # would make every entry alias one object and report only the final
        # weights. copy.copy keeps the container type (list stays list).
        results.append({'id': sample['id'], 'weights': copy.copy(perceptron.weights)})

    send_result('perceptron_training', {'session': response['session'], 'results': results})


if __name__ == '__main__':
    test_pla()