Apache Spark flatMap Example

The Spark flatMap operation is closely related to the RDD map operation. Like map, it is defined on the RDD abstract class in the Spark core library, and it is likewise a transformation, so it is lazily evaluated.

The Spark RDD flatMap function returns a new RDD by first applying a function to all elements of the source RDD and then flattening the results.
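For reference, the method signature as defined on Spark's RDD class looks like this (paraphrased from the Spark source):

def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U]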

Spark flatMap is a transformation operation on an RDD that accepts a function as an argument. As with map, the function is applied to each element of the source RDD, and a new RDD is produced from the resulting values. The difference from map is that the argument function returns an array, list, or sequence of elements rather than a single element; flatMap then flattens those collections, so the resulting RDD contains individual elements rather than nested collections. Let's check its behavior in the following image.

[Figure: Apache Spark flatMap Example: RDD X (lines/sentences) flattened into RDD Y (individual words)]

As the image shows, RDD X is the source RDD and RDD Y is the resulting RDD. As in the typical Spark word-count example, RDD X is made up of individual lines/sentences distributed across partitions, and the flatMap transformation extracts the separate words from each sentence. But where map would produce an RDD of arrays of words, flatMap returns an RDD of the individual words themselves.
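To make this concrete, here is a minimal word-count sketch built on flatMap (run in the Spark shell; variable names are illustrative):

scala> val lines = sc.parallelize(List("spark rdd example", "sample example"), 2)
scala> val counts = lines.flatMap(_.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
scala> counts.collect // element ordering may vary
res0: Array[(String, Int)] = Array((spark,1), (rdd,1), (example,2), (sample,1))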

Important points to note:

  • flatMap is a transformation operation in Spark, hence it is lazily evaluated
  • It is a narrow operation, as it does not shuffle data from one partition to multiple partitions
  • The output of flatMap is flattened
  • The function passed to flatMap must return an array, list, or sequence (any subtype of scala.TraversableOnce); see the sketch after this list
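Because any scala.TraversableOnce works, the function may even return an Option, in which case None values are simply dropped from the flattened output. A minimal sketch, assuming a Spark shell session with SparkContext sc:

scala> val raw = sc.parallelize(List("1", "two", "3"))
scala> val parsed = raw.flatMap(s => scala.util.Try(s.toInt).toOption) // None for "two" is dropped
scala> parsed.collect
res0: Array[Int] = Array(1, 3)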

Let's walk through some examples.

Spark flatMap Example Using Scala
scala> val x = sc.parallelize(List("spark rdd example",  "sample example"), 2)

// The map operation returns an Array of Arrays in the following case: check the type of res0
scala> val y = x.map(x => x.split(" ")) // split(" ") returns an array of words
scala> y.collect
res0: Array[Array[String]] = Array(Array(spark, rdd, example), Array(sample, example))

// The flatMap operation returns a flat Array of words in the following case: check the type of res1
scala> val y = x.flatMap(x => x.split(" "))
scala> y.collect
res1: Array[String] = Array(spark, rdd, example, sample, example)

// RDD y can be rewritten more concisely in Scala using the underscore placeholder
scala> val y = x.flatMap(_.split(" "))
scala> y.collect
res2: Array[String] = Array(spark, rdd, example, sample, example)
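Since flatMap is a narrow operation, the resulting RDD keeps the parent's partition count. Continuing the same shell session, a quick check:

scala> x.partitions.size
res3: Int = 2
scala> y.partitions.size
res4: Int = 2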
Spark flatMap Example Using Java 8
// Basic flatMap example in Java 8
package com.backtobazics.sparkexamples;

import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class FlatMapExample {
    public static void main(String[] args) throws Exception {
        // App name and local master are illustrative settings for running the example
        SparkConf conf = new SparkConf().setAppName("FlatMapExample").setMaster("local[2]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        
        // Parallelized with 2 partitions
        JavaRDD<String> rddX = sc.parallelize(
                Arrays.asList("spark rdd example", "sample example"),
                2);
        
        // The map operation returns a list of String arrays in the following case
        JavaRDD<String[]> rddY = rddX.map(e -> e.split(" "));
        List<String[]> listUsingMap = rddY.collect();
        
        // The flatMap operation returns a flat list of Strings in the following case.
        // Spark 2.x expects the lambda to return a java.util.Iterator,
        // hence the .iterator() call (see the note after this example)
        JavaRDD<String> rddY2 = rddX.flatMap(e -> Arrays.asList(e.split(" ")).iterator());
        List<String> listUsingFlatMap = rddY2.collect();

        sc.close();
    }
}
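Note: the .iterator() call in the flatMap lambda reflects an API change in Spark 2.0, where FlatMapFunction was altered to return a java.util.Iterator instead of an Iterable. On Spark 1.x, returning Arrays.asList(e.split(" ")) directly compiles as-is.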
Spark flatMap Example Using Python
# Basic flatMap example in Python
>>> x = sc.parallelize(["spark rdd example", "sample example"], 2)

# The map operation returns a list of lists in the following case (check the result)
>>> y = x.map(lambda x: x.split(' '))
>>> y.collect()
[['spark', 'rdd', 'example'], ['sample', 'example']]

# The flatMap operation returns a flat list of words in the following case (check the result)
>>> y = x.flatMap(lambda x: x.split(' '))
>>> y.collect()
['spark', 'rdd', 'example', 'sample', 'example']
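In PySpark, the function passed to flatMap may return any Python iterable (a list, a tuple, or even a generator); the per-element results are chained together to form the new RDD.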
