Implementing Q-Learning in Java

In this project, we’ll implement a basic Q-Learning algorithm and use Swagger and Spring Boot to build a server that solves mazes.

Project Setup

Because Swagger will generate much of the code for us, we’ll start there.

Open a browser and go to https://editor.swagger.io.

From the menu, select File then Import URL.

Enter the URL of the Maze Swagger file (listed in the scenario documentation): https://www.aisandbox.dev/scenarios/maze.yaml

[Figure: Swagger Editor]

Select “Generate Server” and choose “Spring” as the server type.

Your browser will download a zip file containing a skeleton project. Save it to your hard drive, extract it, and open it as a Maven project in your favourite IDE. Depending on your IDE, you should see something like this:

[Figure: Java IDE]

Changing the project settings

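Although the generated code is functional, the template it comes from targets Java 1.7. That works on older JDKs, but with JDK 9 or higher you will get compiler errors, typically because the project relies on JAXB (javax.xml.bind), which is no longer on the default classpath from JDK 9 and was removed from the JDK entirely in Java 11. To fix this, we need to update the project’s pom.xml file. Open the file and make the following changes: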

At the top of the file, locate the properties section:

<properties>
    <java.version>1.7</java.version>
    <maven.compiler.source>${java.version}</maven.compiler.source>
    <maven.compiler.target>${java.version}</maven.compiler.target>
    <springfox-version>2.9.2</springfox-version>
</properties>

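Change java.version from 1.7 to 1.8 (Maven’s name for Java 8). The Q-Learning code later in this tutorial uses lambdas and streams, which require at least Java 8.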

Scroll down to the dependencies section. It should look something like this:

<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-tomcat</artifactId>
    </dependency>

    …
    …
</dependencies>

Add the following dependencies to it:

<dependency>
  <groupId>javax.xml.bind</groupId>
  <artifactId>jaxb-api</artifactId>
  <version>2.3.0</version>
</dependency>
<dependency>
  <groupId>com.sun.xml.bind</groupId>
  <artifactId>jaxb-core</artifactId>
  <version>2.3.0</version>
</dependency>
<dependency>
  <groupId>com.sun.xml.bind</groupId>
  <artifactId>jaxb-impl</artifactId>
  <version>2.3.0</version>
</dependency>

Alternatively, you can replace your POM file with the updated one here.

Now that you have updated the dependencies, you can run the Maven goal “spring-boot:run” (for example, by running mvn spring-boot:run in the project directory). This compiles the project and starts the server on port 8080.

Random Responses

Although it isn’t essential, before spending too much time coding an AI it’s often useful to start with a simple random responder. This AI returns a random answer to each API request, and in doing so proves that the server is working correctly. In your project, the API endpoint is in the file src/main/java/io/swagger/api/ApiApiController.java. Open it in your IDE, locate the method apiMazePost, and replace it with the following code:

// Needs java.util.List and java.util.Random imported at the top of ApiApiController.java
private final Random rand = new Random();

public ResponseEntity<MazeResponse> apiMazePost(@ApiParam(value = "", required = true) @Valid @RequestBody MazeRequest body) {
    // choose one of the moves the scenario reports as valid, uniformly at random
    List<String> moves = body.getConfig().getValidMoves();
    MazeResponse response = new MazeResponse();
    response.setMove(moves.get(rand.nextInt(moves.size())));
    return new ResponseEntity<MazeResponse>(response, HttpStatus.OK);
}

Restart your server (stop it first, otherwise the new instance won’t be able to bind to port 8080), then start the AI Sandbox application. Select the Maze scenario, then configure it as follows:

Add a single agent, then set its URL to http://localhost:8080/api/maze.

[Figure: Sandbox setup]

Select Next, then start the simulation. You should see an agent wandering the maze at random.

Implementing the Q-Learning algorithm

Now that we have a working server, the next step is to code the Q-Learning algorithm. First, let’s remind ourselves of the key points.
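In brief, Q-Learning keeps a table of values Q(s, a), one for each state s and action a, estimating how good it is to take action a in state s. After each move from state s to state s' that earns reward r, the entry for (s, a) is nudged towards the reward plus the discounted value of the best action available from s'. With learning rate α and discount γ (the learningRate and discount fields of the class below), the update performed by the learn method is:

```latex
Q(s, a) \leftarrow (1 - \alpha)\, Q(s, a) + \alpha \left( r + \gamma \max_{a'} Q(s', a') \right)
```

Choosing a move then means picking the action with the highest Q(s, a) for the current state, which is what getBestMove does.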

Let’s start by defining the data structures we’ll need. We’ll store the main table (the Q-table) as a two-level map: the outer map takes a state string to an inner map, which maps each action to its Q-value.

package io.swagger.qlearning;

import java.util.HashMap;
import java.util.Map;
import java.util.OptionalDouble;

public class QLearing {

    private double learningRate;

    private double discount;

    private Map<String, Map<String,Double>> qtable = new HashMap<>();

}

There are more efficient ways to store this (a two-dimensional array, for instance), but for this implementation we’re prioritising readability over performance.
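For comparison, here is a rough sketch of the array-based layout mentioned above. It assumes that states are grid cells (x, y) and that each action is identified by its position in a fixed list; the class name ArrayQTable and its methods are hypothetical and not part of the tutorial’s code.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical array-backed Q-table: one row per grid cell, one column per action.
final class ArrayQTable {
    private final int width;
    private final List<String> actions; // an action's column index is its position in this list
    private final double[][] q;         // q[cellIndex][actionIndex]

    ArrayQTable(int width, int height, List<String> actions, double initialValue) {
        this.width = width;
        this.actions = actions;
        this.q = new double[width * height][actions.size()];
        for (double[] row : q) {
            Arrays.fill(row, initialValue);
        }
    }

    // Flatten (x, y) into a single row index (row-major: y * width + x).
    private int cell(int x, int y) {
        return y * width + x;
    }

    // Throws ArrayIndexOutOfBoundsException if the action is not in the list.
    double get(int x, int y, String action) {
        return q[cell(x, y)][actions.indexOf(action)];
    }

    void set(int x, int y, String action, double value) {
        q[cell(x, y)][actions.indexOf(action)] = value;
    }

    public static void main(String[] args) {
        ArrayQTable table = new ArrayQTable(3, 2, Arrays.asList("North", "South", "East", "West"), 0.0);
        table.set(2, 1, "East", 1.5);
        System.out.println(table.get(2, 1, "East"));  // 1.5
        System.out.println(table.get(0, 0, "North")); // 0.0
    }
}
```

The trade-off is that every lookup now needs coordinates and an action index instead of a string key, which is the readability cost the map-based version avoids.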

To initialise the table, we’ll pass an array of all possible states, an array of all actions, and the initial value for every entry.

    public QLearing(String[] states, String[] actions, double initialValue) {
        for (String state : states) {
            Map<String,Double> row = new HashMap<>();
            for (String action : actions) {
                row.put(action,initialValue);
            }
            qtable.put(state,row);
        }
    }
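The constructor needs every state up front. If you would rather not enumerate them, one alternative (our own variation, not part of the tutorial’s code) is to create rows on demand with Map.computeIfAbsent, so that an unseen state simply starts with the initial value for every action. The class name LazyQTable and its methods are hypothetical.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical Q-table that creates a state's row the first time it is requested.
final class LazyQTable {
    private final List<String> actions;
    private final double initialValue;
    private final Map<String, Map<String, Double>> qtable = new HashMap<>();

    LazyQTable(List<String> actions, double initialValue) {
        this.actions = actions;
        this.initialValue = initialValue;
    }

    // Returns the row for a state, filling it with initialValue for every action if it is new.
    Map<String, Double> row(String state) {
        return qtable.computeIfAbsent(state, s -> {
            Map<String, Double> newRow = new HashMap<>();
            for (String action : actions) {
                newRow.put(action, initialValue);
            }
            return newRow;
        });
    }

    int knownStates() {
        return qtable.size();
    }

    public static void main(String[] args) {
        LazyQTable table = new LazyQTable(Arrays.asList("North", "South", "East", "West"), 0.0);
        System.out.println(table.knownStates());          // 0
        System.out.println(table.row("2:3").get("East")); // 0.0
        System.out.println(table.knownStates());          // 1
    }
}
```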

Next, we’ll define some basic getters and setters, as well as a utility method that returns the highest Q-value stored for a state.

    public double getLearningRate() {
        return learningRate;
    }

    public double getDiscount() {
        return discount;
    }

    public void setLearningRate(double rate) {
        learningRate = rate;
    }

    public void setDiscount(double value) {
        discount = value;
    }

    // Returns the largest Q-value stored for the state. Unknown or empty rows count as 0.0
    // (Double.MIN_VALUE would not mean "lowest" here: it is the smallest positive double).
    private double getHighestValueFromState(String state) {
        Map<String, Double> map = qtable.get(state);
        if (map == null) {
            return 0.0;
        }
        OptionalDouble max = map.values().stream().mapToDouble(v -> v).max();
        return max.isPresent() ? max.getAsDouble() : 0.0;
    }

Finally, we need to implement the two main public methods: one to choose the next move and one to learn from the last move.

    public String getBestMove(String state) {
        String bestMove = "North"; // default value - used if the state is unknown or has no actions
        Map<String, Double> row = qtable.get(state);
        if (row != null) {
            double bestScore = Double.NEGATIVE_INFINITY;
            // pick the action with the highest Q-value; on a tie, the first one found wins
            for (Map.Entry<String, Double> entry : row.entrySet()) {
                if (entry.getValue() > bestScore) {
                    bestMove = entry.getKey();
                    bestScore = entry.getValue();
                }
            }
        }
        return bestMove;
    }
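Because the constructor fills every entry with the same initial value, many rows start out as ties, and getBestMove then always returns whichever tied action the HashMap happens to iterate first. A small variation, our own suggestion rather than part of the original code, is to collect every action that shares the highest value and pick one of them at random. The class TieBreaking and the method bestMoveRandomTies are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

final class TieBreaking {
    // Returns one of the highest-valued actions in the row, chosen uniformly at random.
    // Falls back to defaultMove if the row is null or has no usable entries.
    static String bestMoveRandomTies(Map<String, Double> row, String defaultMove, Random rand) {
        if (row == null || row.isEmpty()) {
            return defaultMove;
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        List<String> best = new ArrayList<>();
        for (Map.Entry<String, Double> entry : row.entrySet()) {
            double value = entry.getValue();
            if (value > bestScore) {
                // strictly better: start a new list of candidates
                bestScore = value;
                best.clear();
                best.add(entry.getKey());
            } else if (value == bestScore) {
                // tie: keep this action as another candidate
                best.add(entry.getKey());
            }
        }
        return best.isEmpty() ? defaultMove : best.get(rand.nextInt(best.size()));
    }

    public static void main(String[] args) {
        Map<String, Double> row = new HashMap<>();
        row.put("North", 0.0);
        row.put("South", -1.0);
        row.put("East", 0.0);
        row.put("West", -2.0);
        // Prints either North or East, never South or West
        System.out.println(bestMoveRandomTies(row, "North", new Random()));
    }
}
```

Passing the Random instance in as a parameter keeps the method easy to test with a seeded generator.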

    public void learn(String startState, String action, double reward, String endState) {
        Map<String, Double> oldRow = qtable.get(startState);
        if (oldRow == null || !oldRow.containsKey(action)) {
            return; // unknown state or action: nothing to update
        }
        // best value achievable from the state we ended up in: max over a' of Q(s', a')
        double maxNewRow = getHighestValueFromState(endState);
        // Q(s,a) <- (1 - learningRate) * Q(s,a) + learningRate * (reward + discount * maxNewRow)
        oldRow.put(action, (1 - learningRate) * oldRow.get(action) + learningRate * (reward + discount * maxNewRow));
    }
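To see the update in action, here is one step worked through with illustrative numbers of our own choosing: a learning rate of 0.5, a discount of 0.9, a current value Q(s, a) = 0, a reward of -1 for the move, and a best value of 10 from the new state. The update gives (1 - 0.5) × 0 + 0.5 × (-1 + 0.9 × 10) = 0.5 × 8 = 4. The standalone snippet below repeats the arithmetic with the same formula as learn; the class QUpdateExample and method qUpdate are hypothetical.

```java
final class QUpdateExample {
    // One Q-Learning update, using the same formula as learn():
    // Q(s,a) <- (1 - learningRate) * Q(s,a) + learningRate * (reward + discount * maxNext)
    static double qUpdate(double oldValue, double learningRate, double discount,
                          double reward, double maxNext) {
        return (1 - learningRate) * oldValue + learningRate * (reward + discount * maxNext);
    }

    public static void main(String[] args) {
        System.out.println(qUpdate(0.0, 0.5, 0.9, -1.0, 10.0)); // 4.0
    }
}
```

Note that with a learning rate of 0 (the default value of an uninitialised double field) the update returns the old value unchanged, so setLearningRate needs to be called with a positive value before the agent can learn anything.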

Implementing a maze solver

Now that we have a Q-Learning class and an API server, the final step is to put the two together by going back to the API controller class and reimplementing the apiMazePost method.

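    // Besides the generated imports, this needs java.util.ArrayList and java.util.List
    // (if not already imported), io.swagger.qlearning.QLearing and the generated Position model.
    private QLearing q = null;
    private String currentBoard = null;

    public ResponseEntity<MazeResponse> apiMazePost(@ApiParam(value = "", required = true) @Valid @RequestBody MazeRequest body) {
        // check we're on the same board as last time
        if ((currentBoard == null) || (!currentBoard.equals(body.getConfig().getBoardID()))) {
            // this is a new board - regenerate the Q-Table
            q = new QLearing(
                    getStateArray(body.getConfig().getWidth(), body.getConfig().getHeight()),
                    // the constructor takes a String[], but getValidMoves() returns a List
                    body.getConfig().getValidMoves().toArray(new String[0]),
                    0.0
            );
            // learningRate and discount default to 0.0, and a learning rate of 0 means nothing
            // is ever learned. These example values are only a starting point; tune them.
            q.setLearningRate(0.5);
            q.setDiscount(0.9);
            currentBoard = body.getConfig().getBoardID();
        }
        // if there is any history, learn from it
        if (body.getHistory() != null) {
            q.learn(
                    getState(body.getHistory().getLastPosition()),
                    body.getHistory().getAction(),
                    body.getHistory().getReward(),
                    getState(body.getHistory().getNewPosition())
            );
        }
        // get the next move
        String move = q.getBestMove(getState(body.getCurrentPosition()));
        MazeResponse response = new MazeResponse();
        response.setMove(move);
        return new ResponseEntity<MazeResponse>(response, HttpStatus.OK);
    }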

    private String getState(int x,int y) {
        return x + ":" + y;
    }

    private String getState(Position p) {
        return getState(p.getX(),p.getY());
    }

    private String[] getStateArray(int width,int height) {
        List<String> states = new ArrayList<>();
        for (int x=0;x<width;x++) {
            for (int y=0;y<height;y++) {
                states.add(getState(x,y));
            }
        }
        return states.toArray(new String[0]);
    }
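To check what the helper methods produce, here is a standalone copy of getState and getStateArray (made static so they can run outside the controller, inside a hypothetical StateEncodingExample class) applied to a 2 × 2 maze. States come out column by column, because x is the outer loop.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class StateEncodingExample {
    // Same "x:y" encoding as the controller's getState
    static String getState(int x, int y) {
        return x + ":" + y;
    }

    // Same enumeration as the controller's getStateArray: x in the outer loop, y in the inner
    static String[] getStateArray(int width, int height) {
        List<String> states = new ArrayList<>();
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                states.add(getState(x, y));
            }
        }
        return states.toArray(new String[0]);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(getStateArray(2, 2))); // [0:0, 0:1, 1:0, 1:1]
    }
}
```

Because the state string is just the agent’s coordinates, the Q-table has width × height rows, and the same cell always maps to the same row for as long as the board ID stays the same.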