Java Programming Tutorials

Java 8 flatMap practical example

In this post we’ll show a practical example of processing a tree-like structure using streams and Java 8 flatMap.

As a tree-like structure, let’s take… flights. A flight may go through a number of airports, and each connection between two adjacent airports is a leg. So we have a collection of flights, each consisting of one or more legs, where each leg has an origin and a destination airport. The problem: how do we get all the airports? Before Java 8 we had to write nested for-each loops with a bunch of nested if-else statements, which hurt readability even more. Since Java 8 this can be solved conveniently using Stream.flatMap().
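
For comparison, here is a minimal sketch of what the pre-Java 8 version might look like with nested for-each loops. The LoopAirports class and its simplified Flight and Leg stand-ins are hypothetical, written only for this illustration, and the null check merely represents the kind of guard that tends to pile up in such loops.

```java
import java.util.*;

// Hypothetical pre-Java 8 style: one nested loop per level of the tree.
// Flight and Leg are simplified stand-ins, not the classes shown below.
class LoopAirports {

    static class Leg {
        final String origin, destination;

        Leg(String origin, String destination) {
            this.origin = origin;
            this.destination = destination;
        }
    }

    static class Flight {
        final List<Leg> legs;

        Flight(List<Leg> legs) {
            this.legs = legs;
        }
    }

    static Set<String> collectAirports(Collection<Flight> flights) {
        Set<String> airports = new HashSet<String>();
        for (Flight flight : flights) {           // level 1: flights
            if (flight.legs == null) {            // illustrative guard
                continue;
            }
            for (Leg leg : flight.legs) {         // level 2: legs
                airports.add(leg.origin);         // level 3: airports
                airports.add(leg.destination);
            }
        }
        return airports;
    }

    public static void main(String[] args) {
        List<Flight> flights = Arrays.asList(
                new Flight(Arrays.asList(new Leg("a", "b"), new Leg("b", "c"))),
                new Flight(Arrays.asList(new Leg("a", "d"))));
        Set<String> sorted = new TreeSet<String>(collectAirports(flights));
        System.out.println(sorted); // prints [a, b, c, d]
    }
}
```

Every extra level of the tree adds another loop and usually another guard, which is the readability cost the stream version avoids.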

Flight structure

A flight with its legs – list of connections between adjacent airports:

static class Flight {
    private final List<Leg> legs;

    public Flight(List<Leg> legs) {
        this.legs = legs;
        System.out.println("Flight with legs: " + legs);
    }

    public List<Leg> getLegs() {
        return legs;
    }
}

A leg connects two adjacent airports, the origin and the destination:

static class Leg {
    String origin, destination;

    public Leg(String origin, String destination) {
        this.origin = origin;
        this.destination = destination;
    }

    @Override
    public String toString() {
        return "Leg(" + origin + "->" + destination + ')';
    }
}

Java 8 flatMap in practice

Given a collection of flights, we can traverse it as a stream and collect all airports from all legs:

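// PracticalFlatMap needs: java.util.*, java.util.stream.Stream
// and: import static java.util.stream.Collectors.toSet;
public Set<String> collectAirports(Collection<Flight> flights) {
    return flights.stream()
            // each flight -> its list of legs (Stream<List<Leg>>)
            .map(Flight::getLegs)
            // each list of legs -> stream of legs, merged (Stream<Leg>)
            .flatMap(Collection::stream)
            // each leg -> stream of its two airports, merged (Stream<String>)
            .flatMap(this::getLegAirports)
            // collect results into a set, dropping duplicates
            .collect(toSet());
}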

// Extract leg airports as stream:
private Stream<String> getLegAirports(Leg leg) {
    return Stream.of(leg.origin, leg.destination);
}
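
Putting the snippets together, a self-contained version of the enclosing PracticalFlatMap class could look roughly like this. It is a sketch that assumes the package-less layout below; the constructor logging and Leg.toString() are left out for brevity, and the main method with its sample flights is an illustrative addition.

```java
import java.util.*;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toSet;

// Sketch assembling the snippets above into one class.
// main() and its sample data are illustrative additions.
class PracticalFlatMap {

    static class Flight {
        private final List<Leg> legs;

        public Flight(List<Leg> legs) {
            this.legs = legs;
        }

        public List<Leg> getLegs() {
            return legs;
        }
    }

    static class Leg {
        final String origin, destination;

        public Leg(String origin, String destination) {
            this.origin = origin;
            this.destination = destination;
        }
    }

    public Set<String> collectAirports(Collection<Flight> flights) {
        return flights.stream()
                .map(Flight::getLegs)          // Stream<List<Leg>>
                .flatMap(Collection::stream)   // Stream<Leg>
                .flatMap(this::getLegAirports) // Stream<String>
                .collect(toSet());             // unique airports
    }

    private Stream<String> getLegAirports(Leg leg) {
        return Stream.of(leg.origin, leg.destination);
    }

    public static void main(String[] args) {
        List<Flight> flights = Arrays.asList(
                new Flight(Arrays.asList(new Leg("a", "b"), new Leg("b", "c"))),
                new Flight(Arrays.asList(new Leg("a", "d"))));
        Set<String> airports = new PracticalFlatMap().collectAirports(flights);
        Set<String> sorted = new TreeSet<>(airports); // sort only for stable printing
        System.out.println(sorted); // prints [a, b, c, d]
    }
}
```

The result is copied into a TreeSet before printing because Collectors.toSet() makes no promise about iteration order.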

How it all works

The idea is to convert everything into one giant stream of airports. To do that, we turn each level of the tree into a stream.
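
To make the "one stream per tree level" idea concrete, the sketch below stores each stage in a typed variable. For brevity it assumes a leg is a two-element list [origin, destination] and a flight is a list of legs; these plain nested lists and the TreeLevels class are hypothetical stand-ins, not the article's Flight and Leg classes.

```java
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Sketch: each tree level becomes a stream, and flatMap merges it into the next.
// Assumption: a leg is a two-element list [origin, destination],
// and a flight is a list of legs (not the article's Flight/Leg classes).
class TreeLevels {

    static Set<String> airports(List<List<List<String>>> flights) {
        // level 1: a stream of flights
        Stream<List<List<String>>> flightStream = flights.stream();
        // level 2: the legs of all flights merged into one stream
        Stream<List<String>> legStream = flightStream.flatMap(List::stream);
        // level 3: the airports of all legs merged into one stream
        Stream<String> airportStream = legStream.flatMap(List::stream);
        return airportStream.collect(Collectors.toSet());
    }

    public static void main(String[] args) {
        List<List<List<String>>> flights = Arrays.asList(
                Arrays.asList(Arrays.asList("a", "b"), Arrays.asList("b", "c")),
                Arrays.asList(Arrays.asList("a", "d")));
        Set<String> sorted = new TreeSet<>(airports(flights));
        System.out.println(sorted); // prints [a, b, c, d]
    }
}
```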

The first map(Flight::getLegs) just extracts the list of legs from each flight, so we get rid of the flights. That’s better, but now we have a stream of lists of legs, whereas we want just a stream of legs. How do we get there? This is where flatMap() steps in! It takes a function that converts each element into a stream, and then merges all the resulting streams into one. So we can use Collection::stream to turn each list of legs into a stream of legs, which turns Stream<List<Leg>> into just Stream<Leg>!
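
The difference between map and flatMap is easiest to see on plain data. The sketch below uses a hypothetical list of integer lists: with map(List::stream) each inner list stays a single element (a nested stream), while flatMap(List::stream) emits the inner elements into one flat stream.

```java
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Sketch contrasting map and flatMap on a list of lists
// (the sample numbers are chosen for illustration only).
class MapVsFlatMap {

    static List<List<Integer>> input() {
        return Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(4, 5));
    }

    // map: each inner list becomes ONE element (a nested stream), so nesting remains
    static long countWithMap() {
        Stream<Stream<Integer>> nested = input().stream().map(List::stream);
        return nested.count();
    }

    // flatMap: the elements of every inner list end up in one flat stream
    static List<Integer> flattenWithFlatMap() {
        return input().stream()
                .flatMap(List::stream)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(countWithMap());       // prints 3
        System.out.println(flattenWithFlatMap()); // prints [1, 2, 3, 4, 5]
    }
}
```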

After the first flatMap() we end up with a Stream<Leg>. Now we need to extract the airports from each leg as a stream. The problem is that the airports are plain fields there, so we cannot just use Collection::stream to turn them into a stream. This is solved by a simple helper method that creates a stream of them using Stream.of(…). This way we turn a stream of legs into a stream of airports. Finally, we collect them into a set to get unique airports!
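
The sketch below shows the same helper written as a lambda and why the final collector matters: toList() keeps every occurrence of an airport, while toSet() keeps each one once. The FieldsToStream class and its simplified Leg are hypothetical stand-ins for illustration.

```java
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Sketch: two plain fields turned into a stream with Stream.of,
// then toList (keeps duplicates) compared with toSet (unique values).
// Leg is a simplified stand-in for the article's class.
class FieldsToStream {

    static class Leg {
        final String origin, destination;

        Leg(String origin, String destination) {
            this.origin = origin;
            this.destination = destination;
        }
    }

    // Same job as getLegAirports, written as a lambda
    static final Function<Leg, Stream<String>> LEG_AIRPORTS =
            leg -> Stream.of(leg.origin, leg.destination);

    static List<Leg> legs() {
        return Arrays.asList(new Leg("a", "b"), new Leg("b", "c"), new Leg("a", "d"));
    }

    static List<String> airportList() {
        return legs().stream().flatMap(LEG_AIRPORTS).collect(Collectors.toList());
    }

    static Set<String> airportSet() {
        return legs().stream().flatMap(LEG_AIRPORTS).collect(Collectors.toSet());
    }

    public static void main(String[] args) {
        System.out.println(airportList()); // prints [a, b, b, c, a, d]
        Set<String> sorted = new TreeSet<>(airportSet());
        System.out.println(sorted);        // prints [a, b, c, d]
    }
}
```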

Unit test

To see flatMap in action, let’s write a unit test that actually collects the airports:

package com.farenda.java.util.stream;

import com.farenda.java.util.stream.PracticalFlatMap.Flight;
import com.farenda.java.util.stream.PracticalFlatMap.Leg;
import org.junit.Test;

import java.util.*;

import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;
import static org.junit.Assert.*;

public class PracticalFlatMapTest {

    private PracticalFlatMap practical = new PracticalFlatMap();

    @Test
    public void shouldCollectAirports() {
        //given:
        List<Flight> flights = createFlights(
                createLegs("a", "b", "c"),
                createLegs("a", "d"),
                createLegs("b", "d", "e"));

        Set<String> expectedAirports
                = new HashSet<>(asList("a", "b", "c", "d", "e"));

        //when:
        Set<String> collected = practical.collectAirports(flights);

        //then:
        assertEquals(expectedAirports, collected);
    }

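    private List<Leg> createLegs(String... airports) {
        // n airports form n-1 legs; Math.max guards against an empty argument list
        List<Leg> legs = new ArrayList<>(Math.max(0, airports.length - 1));
        for (int i = 1; i < airports.length; ++i) {
            legs.add(new Leg(airports[i-1], airports[i]));
        }
        return legs;
    }

    // static + @SafeVarargs silences the generic-varargs warning on Java 8
    @SafeVarargs
    private static List<Flight> createFlights(List<Leg>... legs) {
        return Arrays.stream(legs)
                .map(Flight::new)
                .collect(toList());
    }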
}

Running the test prints the following output (from the Flight constructor):

Flight with legs: [Leg(a->b), Leg(b->c)]
Flight with legs: [Leg(a->d)]
Flight with legs: [Leg(b->d), Leg(d->e)]