Perfect Snake

Ich habe auf diesem Blog schon über eine Reihe von Snake Clonen [1, 2, 3, 4, 5] geschrieben, die zum Teil auch Autopilot-Strategien hatten [6, 7]. Die Autopiloten waren zwar meist interessant anzusehen — vor allem bei hohen Geschwindigkeiten — aber bei weitem nicht perfekt.

Auch wenn der Titel etwas zu viel verspricht, schafft es dieser Autopilot (zumindest manchmal) perfekte Spiele zu spielen.

Eine perfekte Partie Snake

Und falls dieses gif nicht überzeugt, kann man den Autopiloten online — dank TensorFlow.js — direkt im Browser ausprobieren auf snake.schawe.me.

Aber was steckt dahinter?

Neuronale Netze

Wenn man nicht clever genug ist, eine direkte Lösung für ein Problem zu finden, kann man versuchen ein neuronales Netz auf die Lösung des Problems zu trainieren. Vor einigen Jahren hat ein Artikel, in dem ein neuronales Netz trainiert wurde alte Atari-Spiele zu spielen, für mediale Aufmerksamkeit gesorgt. Und die gleiche Idee des Reinforcement Learning werde ich hier (nicht als erster [8, 9]) auf Snake anwenden.

Die grundlegende Idee von Reinforcement Learning ist relativ einsichtig: Wir belohnen das Modell für gute Entscheidungen, sodass es lernt mehr gute Entscheidungen zu treffen. In unserem Fall werden gute Entscheidungen dadurch definiert, dass sie zu einer hohen Punktzahl, also Länge der Schlange am Spielende, führen.

Glücklicherweise können wir auf die Literatur zurückgreifen, wie wir diese grundsätzliche Idee umsetzen können. Das Modell, für das ich mich entschieden habe, ist ein Actor-Critic Ansatz. Dabei nutze ich ein neuronales Netz, das als Input den aktuellen Zustand des Spielfeldes bekommt — wie genau dieser Zustand aussieht, diskutieren wir weiter unten. Dann geht es durch ein paar Schichten und endet in zwei „Köpfen“. Einer ist der Actor, mit drei Output-Neuronen, die für „nach links“, „nach rechts“ und „geradeaus weiter“ stehen. Der andere ist der Critic, der ein Output-Neuron hat, das abschätzt wie lang die Schlange, ausgehend von der aktuellen Situation, noch werden kann — also wie gut die aktuelle Situation ist.

Das Training läuft dann so ab, dass ein ganzes Spiel gespielt wird, folgend den Vorschlägen des Actors mit etwas rauschen, um neue Strategien zu erkunden. Sobald es beendet ist, weil die Schlange sich oder eine Wand gebissen hat, wird der Critic mit allen Zuständen des Spielverlaufs darauf trainiert, Schätzungen abzugeben, die möglichst gut zu der tatsächlich erreichten Länge am Spielende passen. Außerdem wird der Actor darauf trainiert gute Entscheidungen zu treffen, indem zu den Zuständen des Spielverlaufs andere Entscheidungen getroffen werden und die Bewertung des Critic der resultierenden Situationen als Qualität der Entscheidung genutzt wird. Actor und Critic helfen sich also gegenseitig besser zu werden. Der gemeinsame Teil des neuronalen Netzes sollte im Idealfall nach genügend gespielten Spielen dabei ein „Verständnis“ für Snake entwickeln. Genial!

Technische Nebensächlichkeiten

Meine Implementierung benutzt die Python Bibliotheken Keras und Tensorflow zum Training und multiJSnake als Environment. Wir steuern also einen Java-Prozess, um unser neuronales Netz in Python zu trainieren. Diese Entscheidung ist etwas unorthodox, aber bot Potential für einen Blogpost auf dem Blog meines Arbeitgebers.

Wir können das Environment getrost als Black-Box betrachten, die dafür sorgt, dass die Regeln von Snake befolgt werden.

Lokale Informationen

Eine der wichtigsten Entscheidungen ist nun, wie der Input in das Modell aussieht. Die einfachste Variante, die sich auch gut zum Testen eignet, ist die lokale Information rund um den Kopf der Schlange: Drei Neuronen, die jeweils 1 oder 0 sind, wenn das Feld links, rechts und geradeaus vom Kopf belegt sind (und acht weitere für etwas mehr Weitsicht auf die Diagonalen und übernächste Felder vorne, rechts, links und diesmal auch zurück). Damit die Schlange auch das Futter finden kann, fügen wir noch 4 weitere Neuronen hinzu, die per 1 oder 0 anzeigen, ob das Futter in, rechts, links oder entgegengesetzt der Bewegungsrichtung der Schlange ist.

Mit diesem Input füttern wir eine einzelne vollvernetzte Schicht, hinter der wir direkt die Actor und Critic Köpfe anschließen.

Layout des neuronalen Netzes mit lokaler Information (Visualisierung: netron)

Das reicht aus, damit die Schlange nach ein paar tausend Trainingsspielen zielstrebig auf das Futter zusteuert und sich selbst ausweicht. Allerdings reicht es noch nicht, um zu verhindern, dass sie sich selbst in Schlaufen fängt. Da war der Autopilot von rsnake besser.

Ein paar Spiele mit lokaler Information

Globale Informationen

Um der Schlange eine Chance zu geben zu erkennen, dass sie sich gerade selbst fängt, sollte man ihr erlauben das ganze Spielfeld zu sehen — schließlich sehen menschliche Spieler auch das ganze Spielfeld. Bei einem \(10 \times 10\) Spielfeld haben wir also schon mindestens 100 Input-Neuronen, sodass vollvernetzte Schichten zu sehr großen Modellen führen würden. Stattdessen bietet es sich bei solchen zweidimensionalen Daten an convolutional neuronale Netze zu nutzen. Um es unserer Schlange etwas einfacher zu machen, werden wir unser Spielfeld in drei Kanäle aufteilen:

  1. der Kopf: nur an der Position des Kopfes ist eine 1, der Rest ist 0
  2. der Körper: die Positionen an denen sich der Körper befindet zeigen wie viele Zeitschritte der Körper noch an dieser Position sein wird
  3. das Futter: nur an der Position des Futters ist eine 1, der Rest ist 0

Was ein Mensch sieht und was wir unserem neuronalen Netz zeigen

Dies ist auch kein unfairer Vorteil, schließlich sehen menschliche Spieler das Bild auch mit drei Farbkanälen.

Und damit die Schlange nicht auch noch lernen muss was rechts und links bedeutet, geben wir dem Actor 4 Outputs, die für Norden, Osten, Süden und Westen stehen.

Layout des Convolutional-Neural-Networks (Visualisierung: netron)

Dieses Modell-Layout verdient es dann schon eher als Deep Learning bezeichnet zu werden. Weitere Modell-Parameter, können auf github.com/surt91/multiJSnake nachgeschlagen werden.

Nach einigen zehntausend Trainingsspielen funktioniert dieses Modell dann tatsächlich gut genug, um regelmäßig perfekte Spiele auf einem \(10 \times 10\) Spielfeld zu erreichen. Aber da ich es nur auf \(10 \times 10\) Feldern trainiert habe, versagt es leider auf jeder anderen Größe.

Twitter Profilhintergrundfarben

Für ein Projekt habe ich Tweets von >8‘000‘000 Twitter-Usern eingesammelt. Dabei fallen noch eine Reihe weiterer Daten an, wie die Profilhintergrundfarbe. Es wäre eine Schande diese Daten einfach verkommen zu lassen, also habe ich nach einer Möglichkeit gesucht diese Information ansprechend darzustellen, was sich als weniger trivial herausgestellt hat, als ich ursprünglich angenommen hatte: Im Idealfall sollten ähnliche Farben nahe beieinander liegen, allerdings ist der RGB Farbraum ein dreidimensionaler Kubus, ein Bild aber nur zweidimensional, sodass es keine „richtige“ Art und Weise gibt, ähnliche Farben nebeneinander anzuordnen.

Ich habe mich hier dafür entschieden eine 2D Hilbert-Kurve durch mein Bild zu legen und die Farben in der Reihenfolge zu zeichnen, in der eine 3D Hilbert-Kurve ihnen im RGB-Kubus begegnet. Wenn man dann noch die beiden Standardhintergrundfarben #F5F8FA und #C0DEED ignoriert, sieht das Ergebnis so aus.

Twitter-Profil-Hintergrundfarbe

Und dank der Python Pakete hilbertcurve und pypng ist der Code sogar ziemlich harmlos:

from math import ceil, sqrt, log2

from hilbertcurve.hilbertcurve import HilbertCurve
import png


"""
    turn an RGB string like `#C0DEED` into a tuple of integers,
    i.e., coordinates of the RGB cube
"""
def str2rgb(s):
    s = s.strip("#")
    return (int(s[0:2], 16), int(s[2:4], 16), int(s[4:6], 16))


"""
    `color_histogram` is a dict mapping an rgb string like `#F5F8FA`
    to the number of usages of this color
"""
def plot_background_colors(color_histogram, filename="colors.png"):
    defaults = {"F5F8FA", "C0DEED"}

    data = {str2rgb(rgb): d for rgb, d in color_histogram if rgb not in defaults}

    # calculate the size of the resulting image
    # for a 2D Hilbert curve, it mus be square with a width, which is a power of 2
    num_pixels = sum(data.values())
    min_width = ceil(sqrt(num_pixels))
    exponent = ceil(log2(min_width))
    width = 2**exponent

    # output buffer for a width x width png, with 4 color values per pixel
    buf = [[0 for _ in range(4 * width)] for _ in range(width)]

    hc2 = HilbertCurve(exponent, 2)
    # there are 256 = 2^8 values in each direction of the RGB cube
    hc3 = HilbertCurve(8, 3)

    sorted_rgbs = sorted(data.keys(), key=lambda x: hc3.distance_from_point(x))

    idx = 0
    for rgb in sorted_rgbs:
        for _ in range(data[rgb]):
            # get the coordinate of the next pixel
            x, y = hc2.point_from_distance(idx)
            # assign the RGBA values to the pixel
            buf[x][4 * y] = rgb[0]
            buf[x][4 * y + 1] = rgb[1]
            buf[x][4 * y + 2] = rgb[2]
            buf[x][4 * y + 3] = 255

            idx += 1

    png.from_array(buf, 'RGBA').save(filename)

Das Histogram, das als Input benötigt wird war in meinem Fall nur eine SQL Query entfernt:

SELECT profile_background_color, COUNT(profile_background_color) FROM users
    GROUP BY profile_background_color;

inline-python

Für jeden Zweck das passende Werkzeug: In meinem Alltag bedeutet das, dass ich Simulationen in Rust schreibe und in Python visualisiere. Dank inline-python geht das sogar sehr reibungslos.

use inline_python::python;

fn main() {
    let x: Vec<f32> = (0..628).map(|i| i as f32 / 100.).collect();
    let y: Vec<f32> = x.iter().map(|x| x.sin()).collect();

    python! {
        import numpy as np
        from matplotlib import pyplot as plt

        plt.plot('x, 'y)
        plt.show()
    }
}

Dieses Minimalbeispiel ist natürlich nicht nützlich, aber ich habe es bereits produktiv genutzt, um Dynamik auf petgraph Graphen zu simulieren und ihren Zustand per graph-tool zu visualisieren.

Graph state visualized with graph-tool