shmosel - 3 months ago
Java Question

Can I limit Collectors.toMap() entries?

I'm looking for a way to limit the number of entries produced by Collectors.toMap() with a merge function. Consider the following example:

Map<String, Integer> m = Stream.of("a", "a", "b", "c", "d")
.collect(toMap(Function.identity(), s -> 1, Integer::sum));

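The problem is that if I simply add .limit(3) to the stream, it limits the number of elements rather than the number of distinct keys, so I'll only have 2 elements in the resulting map (a=2, b=1). Is there any convenient way to short-circuit the stream once it's processed 3 distinct keys?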
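For reference, here is a minimal self-contained sketch of the element-based limit(3) attempt that produces the 2-entry result. The class name LimitProblem and the method limitedByElements are hypothetical, chosen only for illustration. Because limit(3) counts stream elements, only "a", "a" and "b" ever reach the collector.

```java
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class LimitProblem {
    // Hypothetical helper: the question's pipeline with a plain limit(3) inserted.
    static Map<String, Integer> limitedByElements() {
        return Stream.of("a", "a", "b", "c", "d")
                .limit(3) // keeps the first 3 elements: "a", "a", "b"
                .collect(Collectors.toMap(Function.identity(), s -> 1, Integer::sum));
    }

    public static void main(String[] args) {
        System.out.println(limitedByElements()); // prints {a=2, b=1}
    }
}
```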


A possible solution would be to write your own Spliterator that wraps the spliterator of the given Stream. This Spliterator delegates the advancing calls to the wrapped spliterator and contains the logic for counting how many distinct elements have appeared.

For that, we can subclass AbstractSpliterator and provide our own tryAdvance logic. In the following, every element encountered is added to a set. When the size of that set becomes greater than our maximum, or when the wrapped spliterator has no remaining elements, we return false to indicate that there are no remaining elements to consider. This way, traversal stops as soon as the maximum number of distinct elements is exceeded, without consuming the rest of the source.

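import java.util.*;
import java.util.Spliterators.AbstractSpliterator;
import java.util.function.Consumer;
import java.util.stream.*;

private static <T> Stream<T> distinctLimit(Stream<T> stream, int max) {
    Spliterator<T> spltr = stream.spliterator();
    // Drop SIZED/SUBSIZED: the wrapper may yield fewer elements than the source reports.
    int characteristics = spltr.characteristics() & ~(Spliterator.SIZED | Spliterator.SUBSIZED);
    Spliterator<T> res = new AbstractSpliterator<T>(spltr.estimateSize(), characteristics) {
        private final Set<T> distincts = new HashSet<>();
        private boolean stillGoing = true;

        @Override
        public boolean tryAdvance(Consumer<? super T> action) {
            boolean hasRemaining = spltr.tryAdvance(elem -> {
                distincts.add(elem);
                if (distincts.size() > max) {
                    stillGoing = false;   // a (max+1)-th distinct element: stop here
                } else {
                    action.accept(elem);  // still within the limit: pass it downstream
                }
            });
            return hasRemaining && stillGoing;
        }
    };
    return StreamSupport.stream(res, stream.isParallel()).onClose(stream::close);
}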

With your example code, you would have:

Map<String, Long> m =
    distinctLimit(Stream.of("a", "a", "b", "c", "d"), 3)
        .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));

and the output would be the expected {a=2, b=1, c=1}, i.e. a map with 3 distinct keys.
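If you don't need to stay inside the Stream API, the same short-circuit can be written as a plain loop over the stream's iterator. This is only a sketch: countFirstDistinct is a hypothetical helper name, and it returns Map<T, Integer> to match the question's Integer::sum merge function rather than the Long counts produced by groupingBy/counting.

```java
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.stream.Stream;

public class DistinctLimitLoop {
    // Hypothetical helper: counts occurrences per key and stops consuming the
    // source as soon as an element would become the (max+1)-th distinct key.
    static <T> Map<T, Integer> countFirstDistinct(Iterator<T> it, int max) {
        Map<T, Integer> counts = new HashMap<>();
        while (it.hasNext()) {
            T elem = it.next();
            // A new key when we already hold max keys: stop here.
            if (!counts.containsKey(elem) && counts.size() >= max) {
                break;
            }
            counts.merge(elem, 1, Integer::sum); // same merge function as the question
        }
        return counts;
    }

    public static void main(String[] args) {
        Map<String, Integer> m =
                countFirstDistinct(Stream.of("a", "a", "b", "c", "d").iterator(), 3);
        System.out.println(m); // prints {a=2, b=1, c=1}

        // Terminates even on an infinite source, because the loop breaks early.
        Map<Integer, Integer> n =
                countFirstDistinct(Stream.iterate(0, i -> i + 1).iterator(), 3);
        System.out.println(n); // prints {0=1, 1=1, 2=1}
    }
}
```

The trade-off is that this version is inherently sequential and ends the pipeline at the loop, whereas the Spliterator wrapper returns a Stream that can still be passed to any collector.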