Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - How to run a function on all Spark workers before processing data in PySpark?

I'm running a Spark Streaming task in a cluster using YARN. Each node in the cluster runs multiple Spark workers. Before streaming starts, I want to execute a "setup" function on every worker on every node in the cluster.

The streaming task classifies incoming messages as spam or not spam, but before it can do that it needs to download the latest pre-trained models from HDFS to local disk, as in this pseudocode:

def fetch_models():
    if hadoop.version > local.version:
        hadoop.download()
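The pseudocode above can be made concrete along these lines. This is a minimal sketch, assuming versions are integers (for example a file's modification time on HDFS) and that the local version is kept in a marker file next to the models; `get_remote_version`, `download` and the marker file are hypothetical names, not part of any Hadoop API.

```python
import os
from typing import Callable


def fetch_models(
    local_version_file: str,
    get_remote_version: Callable[[], int],
    download: Callable[[], None],
) -> bool:
    """Download the models only if the remote copy is newer than the local one.

    `get_remote_version` and `download` are hypothetical callables standing in
    for whatever HDFS client is used. Returns True if a download happened.
    """
    remote_version = get_remote_version()

    # A missing marker file means nothing has been downloaded yet.
    local_version = -1
    if os.path.exists(local_version_file):
        with open(local_version_file) as f:
            local_version = int(f.read().strip())

    if remote_version <= local_version:
        return False

    download()
    # Record the new version only after the download has succeeded.
    with open(local_version_file, "w") as f:
        f.write(str(remote_version))
    return True
```

Writing the marker last means an interrupted download is retried on the next call instead of being mistaken for an up-to-date copy.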

I've seen the following examples here on SO:

sc.parallelize().map(fetch_models)

But in Spark 1.6, parallelize() requires some data, so I'm using this ugly workaround:

sc.parallelize(range(1, 1000)).foreach(lambda _: fetch_models())

I set the range to 1000 just to be fairly sure that the function runs on ALL workers. I also don't know exactly how many workers the cluster will have at runtime.
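The workaround can be made a bit more deliberate. This is a minimal sketch, not an official Spark API: `run_on_workers` and `tasks_per_core` are hypothetical names. It sizes the dummy job from `sc.defaultParallelism` (on YARN this typically reflects the total number of executor cores registered so far) and uses `foreachPartition`, which is an action, because `map()` alone is lazy and never triggers a job. Spark still gives no guarantee that every executor receives one of these tasks.

```python
def run_on_workers(sc, setup, tasks_per_core=2):
    """Best effort: run `setup()` in as many executor Python workers as possible.

    Spark has no "run on every executor" call, so this schedules a dummy job
    with more partitions than there are cores. That makes it likely, but not
    certain, that every executor gets at least one task.
    """
    # Executors that have not registered yet are not counted here.
    num_partitions = max(1, sc.defaultParallelism * tasks_per_core)

    def _run(_partition):
        # The partition contents are irrelevant; only the side effect matters.
        setup()

    # foreachPartition() is an action, so the job actually executes.
    sc.parallelize(range(num_partitions), num_partitions).foreachPartition(_run)
    return num_partitions
```

Because several of these tasks can land in the same Python worker, `setup` should be idempotent, for example by checking a module-level flag the way MyClassifier.is_loaded() does.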

I've read the programming documentation and searched relentlessly, but I can't find any way to simply distribute anything to all workers without any data.

After this initialization phase, the streaming task runs as usual, operating on incoming data from Kafka.

I use the models by running a function similar to this:

spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
# Outer parentheses let the method chain span several lines
(stream.union(*create_kafka_streams())
    .repartition(spark_partitions)
    .foreachRDD(lambda rdd: rdd.foreachPartition(
        lambda partition: spam.on_partition(config, partition))))

In theory I could check whether the models are up to date in the on_partition function, but doing that on every batch would be very wasteful. I'd like to do it before Spark starts retrieving batches from Kafka, since downloading from HDFS can take a couple of minutes.
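If a check inside on_partition turned out to be acceptable after all, its cost could be bounded by throttling it. A minimal sketch, assuming a time-based policy is good enough; `ThrottledCheck` and its parameters are hypothetical names. It does not make the first download happen before the first batch. The instance would have to live at module level in a module shipped to the workers, because an object captured in a task closure is generally deserialized afresh for each task and would not keep its timestamp between batches.

```python
import time


class ThrottledCheck:
    """Run `check()` at most once every `interval_s` seconds in this process.

    `check` is any callable, for example a function that compares model versions.
    """

    def __init__(self, check, interval_s=600.0, clock=time.monotonic):
        self._check = check
        self._interval_s = interval_s
        self._clock = clock
        self._last_run = None  # None means the check has never run

    def maybe_run(self):
        now = self._clock()
        if self._last_run is None or now - self._last_run >= self._interval_s:
            self._last_run = now
            self._check()
            return True
        return False
```

Injecting `clock` keeps the class testable without sleeping; in production the default `time.monotonic` is unaffected by wall-clock adjustments.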

UPDATE:

To be clear: the question is not how to distribute the files or how to load them; it's how to run an arbitrary method on all workers without operating on any data.

To clarify what loading the models currently means:

def on_partition(config, partition):
    if not MyClassifier.is_loaded():
        MyClassifier.load_models(config)

    handle_partition(config, partition)

MyClassifier looks something like this:

class MyClassifier:
    clf = None

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(config):
        MyClassifier.clf = load_from_file(config)

I use static methods because PySpark doesn't seem able to serialize classes with non-static methods (and the state of the class on one worker is irrelevant to the others). This way load_models() only has to be called once per Python worker, and all later batches find MyClassifier.clf already set. This really shouldn't be done for each batch; it's a one-time thing. The same goes for downloading the files from HDFS with fetch_models().
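An alternative to the static-method class, sketched under the assumption that a loaded model can be keyed by its file path, is a module-level function memoized with `functools.lru_cache`. The cache lives in the Python worker process, so it gives the same "load once per worker" behaviour as MyClassifier.clf. `get_classifier` is a hypothetical name and `load_from_file` is a stand-in for the question's loader; the function should be defined in a module shipped to the workers (for example with sc.addPyFile) so that it is imported there rather than re-serialized with each task.

```python
from functools import lru_cache


def load_from_file(path):
    """Hypothetical stand-in for the real model loader from the question."""
    with open(path) as f:
        return f.read()


@lru_cache(maxsize=None)
def get_classifier(path):
    # The first call in each Python worker process loads the model; later
    # calls with the same path return the cached object without disk access.
    return load_from_file(path)
```

Inside on_partition, `clf = get_classifier(model_path)` then replaces the is_loaded()/load_models() pair.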

1 Reply

If all you want is to distribute a file to the worker machines, the simplest approach is the SparkFiles mechanism:

some_path = ...  # local file, a file in DFS, an HTTP, HTTPS or FTP URI.
sc.addFile(some_path)

and retrieve it on the workers with SparkFiles.get and standard IO tools:

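import os

from pyspark import SparkFiles

# SparkFiles.get() expects the bare file name that was passed to addFile(), not the full path
with open(SparkFiles.get(os.path.basename(some_path))) as f:
    ... # Do something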
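Because SparkFiles.get() takes only the file name, a small helper can derive it from whatever was passed to addFile, including URIs with a scheme or query string. A minimal sketch; `spark_file_name` is a hypothetical helper, and only the standard library is used, so it can be tested without Spark.

```python
import posixpath
from urllib.parse import urlparse


def spark_file_name(uri):
    """Return the name SparkFiles.get() expects for a path or URI given to addFile()."""
    # urlparse handles plain paths as well as URIs such as hdfs://nn:8020/models/x.pkl;
    # its .path component drops the scheme, host and any query string.
    return posixpath.basename(urlparse(uri).path)
```

On a worker this would be used as `open(SparkFiles.get(spark_file_name(some_path)))`.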

If you want to make sure the model is actually loaded, the simplest approach is to load it on module import. Assuming config can be used to retrieve the model path:

  • model.py:

    from pyspark import SparkFiles
    
    config = ...
    class MyClassifier:
        clf = None
    
        @staticmethod
        def is_loaded():
            return MyClassifier.clf is not None
    
        @staticmethod
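        def load_models(config):
            # SparkFiles.get() expects the bare file name passed to addFile(), not the full path
            file_name = config.get("model_file").rsplit("/", 1)[-1]
            path = SparkFiles.get(file_name)
            MyClassifier.clf = load_from_file(path)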
    
    # Executed once per interpreter 
    MyClassifier.load_models(config)  
    
  • main.py:

    from pyspark import SparkContext
    
    config = ...
    
    sc = SparkContext("local", "foo")
    
    # Executed before StreamingContext starts
    sc.addFile(config.get("model_file"))
    sc.addPyFile("model.py")
    
    import model
    
    ssc = ...
    stream = ...
    stream.map(model.MyClassifier.do_something).pprint()
    
    ssc.start()
    ssc.awaitTermination()
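The "Executed once per interpreter" comment relies on Python caching imported modules in sys.modules: later imports of the same module in the same process return the cached module without re-running its top-level code. This self-contained sketch (standard library only; `demo_model` and `count_module_executions` are made-up names) demonstrates that behaviour.

```python
import importlib
import os
import sys
import tempfile


def count_module_executions(n_imports=3):
    """Import a throwaway module several times and count how often its body runs."""
    tmp_dir = tempfile.mkdtemp()
    log_path = os.path.join(tmp_dir, "loads.log")
    # The module's top-level code stands in for MyClassifier.load_models(config).
    with open(os.path.join(tmp_dir, "demo_model.py"), "w") as f:
        f.write(f"with open({log_path!r}, 'a') as log:\n    log.write('loaded\\n')\n")
    sys.path.insert(0, tmp_dir)
    importlib.invalidate_caches()
    try:
        for _ in range(n_imports):
            importlib.import_module("demo_model")
    finally:
        sys.path.remove(tmp_dir)
        sys.modules.pop("demo_model", None)
    with open(log_path) as log:
        return len(log.readlines())
```

On Spark, "once per interpreter" means once per Python worker process, so the load can still happen more than once per node (for example with several executors per node, or with spark.python.worker.reuse disabled). Note also that `import model` in main.py runs the load on the driver as well.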
    

...