scala - How do I create a shared SparkSession fixture in scalatest that persists between test suites?

I am very new to Scala and ScalaTest, but I have some experience with PySpark, and I'm trying to learn Scala from a Spark perspective.

I'm currently trying to get my head around the correct way to set up and use fixtures within ScalaTest.

The way I imagine this working (and this may not be how it's done in Scala) is that I would set up a SparkSession as a global fixture shared amongst test suites, with potentially several sample datasets hooking into that SparkSession, each of which could then be used by an individual test suite containing several tests.
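Very roughly, the shape I have in mind is something like this (just a sketch with illustrative names, not code I'm actually using):

import org.apache.spark.sql.SparkSession

// Sketch only: a single session shared by every suite, never stopped mid-run
object GlobalSparkFixture {
  lazy val spark: SparkSession = SparkSession
    .builder
    .master("local")
    .appName("shared-test-session")
    .getOrCreate()
}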

At the moment I have some code that works for running several tests in the same suite against the shared fixture, using the BeforeAndAfterAll trait; however, if I run several suites at the same time, the suite that completes first appears to terminate the SparkSession, and any further tests fail with java.lang.IllegalStateException: Cannot call methods on a stopped SparkContext.

So, I was wondering if there is a way to create the SparkSession so that it is only stopped once all running suites have completed, or whether I'm barking up the wrong tree and there is a better approach altogether. As I say, I'm very new to Scala, so this might just not be how you do this, in which case alternative suggestions are very welcome.

First, I have a package testSetup in which I'm creating a trait for the SparkSession:

package com.example.test

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

import org.scalatest._
import org.scalatest.FixtureSuite
import org.scalatest.funsuite.FixtureAnyFunSuite

package testSetup {

  // Shared SparkSession used by all of the test fixtures below
  trait SparkSetup {
    val spark = SparkSession
      .builder
      .master("local")
      .appName(getClass.getSimpleName.replace("$", ""))
      .getOrCreate()

    spark.sparkContext.setLogLevel("ERROR")
  }

I then use this in a trait that sets up some sample data:

 trait TestData extends SparkSetup {

    // Single-row sample DataFrame handed to each test as the fixture
    def data(): DataFrame = {

      val testDataStruct = StructType(List(
                              StructField("date", StringType, true),
                              StructField("period", StringType, true),
                              StructField("ID", IntegerType, true),
                              StructField("SomeText", StringType, true)))

      val testData = Seq(Row("01012020", "10:00", 20, "Some Text"))

      spark.createDataFrame(spark.sparkContext.parallelize(testData), testDataStruct)
    }
  }

I'm then putting these together to run the tests via withFixture, using afterAll to close the SparkSession; this is clearly where something isn't quite right, but I'm not really sure what:

  trait DataFixture extends funsuite.FixtureAnyFunSuite with TestData with BeforeAndAfterAll { this: FixtureSuite =>

    type FixtureParam = DataFrame

    // "Loan" the sample DataFrame fixture to each test in the suite
    def withFixture(test: OneArgTest) = {
      super.withFixture(test.toNoArgTest(data()))
    }

    // Stop the SparkSession once this suite has finished
    override def afterAll(): Unit = {
      spark.close()
    }
  }
}

I'm currently testing a basic function to hash the columns in a DataFrame dynamically, with the option to exclude some; here's the code:

package com.example.utilities

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

object GeneralTransforms {
  // Adds a RowHash column containing the SHA-256 hash of every column not listed in
  // exclusionCols, concatenated with "|" as the separator
  def addHashColumn(inputDataFrame: DataFrame, exclusionCols: List[String]): DataFrame = {
    val columnsToHash = inputDataFrame.columns.filterNot(exclusionCols.contains(_)).toList

    inputDataFrame.withColumn("RowHash", sha2(concat_ws("|", columnsToHash.map(col): _*), 256))
  }
}
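For context, this is roughly how I'd expect it to be used, e.g. from a REPL (illustrative only, assuming a local SparkSession):

import org.apache.spark.sql.SparkSession
import com.example.utilities.GeneralTransforms.addHashColumn

val spark = SparkSession.builder.master("local").appName("example").getOrCreate()
import spark.implicits._

// Single-row DataFrame matching the sample data in the TestData trait
val df = Seq(("01012020", "10:00", 20, "Some Text")).toDF("date", "period", "ID", "SomeText")

// Hash every column except ID into a new RowHash column
val hashed = addHashColumn(df, List("ID"))
hashed.select("RowHash").show(false)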

And the current test cases:

import testSetup._
import com.example.utilities.GeneralTransforms._

import org.apache.spark.sql.DataFrame

class TestData extends funsuite.FixtureAnyFunSuite with DataFixture {
  test("Test data has correct columns") { inputData => 
    val cols = inputData.columns.toSeq
    val expectedCols = Array("date", "period", "ID", "SomeText").toSeq
    
    assert(cols == expectedCols)
  }
}

class TestAddHashColumn extends funsuite.FixtureAnyFunSuite with DataFixture {
  
  test("Test new hash column added") { inputData =>
    val hashedDf = addHashColumn(inputData, List())
    val initialCols = inputData.columns.toSeq
    val cols = hashedDf.columns.toSeq
    
    assert(!initialCols.contains("RowHash"))
    assert(cols.contains("RowHash"))
  }

  test("Test all columns hashed - no exclusion") { inputData =>
    val hashedDf = addHashColumn(inputData, List())
    val rowHashColumn = hashedDf.select("RowHash").first().getString(0)
    val checkString = "01012020|10:00|20|Some Text"
    val expectedHash = String.format("%064x", new java.math.BigInteger(1, java.security.MessageDigest.getInstance("SHA-256").digest(checkString.getBytes("UTF-8"))))

    assert(rowHashColumn == expectedHash)
  }

  test("Test all columns hashed - with exclusion") { inputData =>

    val excludedColumns = List("ID", "SomeText")
    val hashedDf = addHashColumn(inputData, excludedColumns)
    val rowHashColumn = hashedDf.select("RowHash").first().getString(0)
    val checkString = "01012020|10:00"
    val expectedHash = String.format("%064x", new java.math.BigInteger(1, java.security.MessageDigest.getInstance("SHA-256").digest(checkString.getBytes("UTF-8"))))

    assert(rowHashColumn == expectedHash)

  }
}

Both test suites work absolutely fine in isolation; it's only when running both together that I have an issue. This can also be resolved by adding parallelExecution in Test := false to my build.sbt, but it would be nice to allow the suites to run in parallel as I add more and more tests.
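For reference, that workaround is just the standard sbt setting to disable parallel test execution:

// build.sbt - forces test suites to run sequentially, which masks the problem
parallelExecution in Test := false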

I also wondered if it could be resolved by running something in beforeAll/afterAll that checks for other instances of SparkSession against the context, but I'm not sure how to do that, and I wanted to exhaust this avenue first before going down another rabbit hole!
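The rough idea I had in mind (completely untested, and the names here are just placeholders) was something like a counter of live suites, so that only the last suite to finish actually stops the session:

import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.sql.SparkSession

// Hypothetical sketch: track how many suites are still using the shared session
object SparkSessionTracker {
  private val activeSuites = new AtomicInteger(0)

  def acquire(): Unit = activeSuites.incrementAndGet()

  def release(spark: SparkSession): Unit = {
    // Only stop the shared context once the last suite has finished
    if (activeSuites.decrementAndGet() == 0) {
      spark.stop()
    }
  }
}

Each suite would then call acquire() in beforeAll and release(spark) in afterAll, but I haven't tried wiring this in yet.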

Edit

Since posting, I've spent some more time on this and have switched things around a bit, with a helper class for dealing with the Spark setup. In it I create a pseudo "master" SparkSession using SparkSession.builder.getOrCreate, and then create a new SparkSession for the actual tests - this allows me to have different configs and to do things like register different temp tables per suite. However, I've still not been able to solve the shutdown of Spark - obviously, if I run spark.stop() against any of the sessions running on the SparkContext, it stops the context for all sessions.

It also appears that the context is not stopped until sbt itself exits?
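For reference, the helper I've switched to looks roughly like this (simplified, and the names are my own):

import org.apache.spark.sql.SparkSession

object SparkTestHelper {
  // Pseudo "master" session - getOrCreate means every suite attaches to the same SparkContext
  private lazy val masterSession: SparkSession = SparkSession
    .builder
    .master("local")
    .appName("test-master")
    .getOrCreate()

  // Each suite gets its own session (own conf and temp views) on top of the shared context
  def newTestSession(): SparkSession = masterSession.newSession()
}

newSession() shares the underlying SparkContext but gives each session its own configuration and temporary views, which is why stopping any one of them still brings down the context for all of them.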



1 Reply

Awaiting an expert's reply.
