Skip to content

Commit c51dd28

Browse files
committed
allow env based injection of env configuration for connect, recognise SPARK_REMOTE - allows remote testing ~cut-release
1 parent 75c3cc4 commit c51dd28

4 files changed

Lines changed: 88 additions & 48 deletions

File tree

connect/src/main/4.0.0-scala/org/apache/spark/sql/SparkConnectServerUtils.scala

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package org.apache.spark.sql
1818

1919
import com.sparkutils.testing.ConnectSession
20-
import com.sparkutils.testing.SparkTestUtils.{DEBUG_CONNECT_LOGS_SYS, FLAT_JVM_OPTION, MAIN_CLASSPATH, booleanEnvOrProp, classPathJars, connectServerJars, testClassPaths}
20+
import com.sparkutils.testing.SparkTestUtils.{DEBUG_CONNECT_LOGS_SYS, FLAT_JVM_OPTION, MAIN_CLASSPATH, booleanEnvOrProp, classPathJars, connectServerJars, stringEnvOrProp, testClassPaths}
2121
import com.sparkutils.testing.TestUtilsEnvironment.{onDatabricksFS, onFabricOrSynapse}
2222
import org.apache.spark.{SparkBuildInfo, sql}
2323
import org.apache.spark.sql.connect.SparkSession
@@ -288,56 +288,61 @@ object SparkConnectServerUtils {
288288
// classic non-shared databricks setup, should also be the case for a normal submitted job on OSS / Fabric
289289
val localOnly = booleanEnvOrProp("SPARKUTILS_TESTING_USE_LOCAL_CONNECT")
290290

291+
// despite the function name, we may want to test against remote servers
292+
val connectURL = stringEnvOrProp("SPARK_REMOTE")
293+
291294
Some(
295+
if (connectURL ne null)
296+
ExistingSession(SparkSession.builder.config(clientConfig).getOrCreate())
297+
else
298+
if (spawnConnect) {
299+
// if there is a forced local connect, e.g. running 4.0.0 full shades on a later Fabric 1.4 that doesn't force a
300+
// connect setup, we have a way out
301+
if (localOnly)
302+
ExistingSession(SparkSession.builder.config("spark.api.mode", "connect").getOrCreate())
303+
else
304+
new ConnectSession {
305+
306+
val filter =
307+
serverConfig.get(DEBUG_CONNECT_LOGS_SYS).exists { v =>
308+
System.setProperty(DEBUG_CONNECT_LOGS_SYS, v)
309+
true
310+
}
292311

293-
if (spawnConnect) {
294-
// if there is a forced local connect, e.g. running 4.0.0 full shades on a later Fabric 1.4 that doesn't force a
295-
// connect setup, we have a way out
296-
if (localOnly)
297-
ExistingSession(SparkSession.builder.config("spark.api.mode", "connect").getOrCreate())
298-
else
299-
new ConnectSession {
300-
301-
val filter =
302-
serverConfig.get(DEBUG_CONNECT_LOGS_SYS).exists { v =>
303-
System.setProperty(DEBUG_CONNECT_LOGS_SYS, v)
304-
true
312+
val utils = SparkConnectServerUtils(
313+
if (filter)
314+
serverConfig - DEBUG_CONNECT_LOGS_SYS
315+
else
316+
serverConfig
317+
)
318+
319+
val th = System.getProperty("spark.test.home")
320+
if (th eq null) {
321+
new File("./testing_connect_tmp").mkdirs()
322+
System.setProperty("spark.test.home","./testing_connect_tmp")
305323
}
306324

307-
val utils = SparkConnectServerUtils(
308-
if (filter)
309-
serverConfig - DEBUG_CONNECT_LOGS_SYS
310-
else
311-
serverConfig
312-
)
313-
314-
val th = System.getProperty("spark.test.home")
315-
if (th eq null) {
316-
new File("./testing_connect_tmp").mkdirs()
317-
System.setProperty("spark.test.home","./testing_connect_tmp")
318-
}
319-
320325

321-
utils.start()
326+
utils.start()
322327

323-
private def createSparkSession: SparkSession = SparkConnectServerUtils.createSparkSession(utils.port, clientConfig)
328+
private def createSparkSession: SparkSession = SparkConnectServerUtils.createSparkSession(utils.port, clientConfig)
324329

325-
private var _sparkSession: SparkSession = createSparkSession
330+
private var _sparkSession: SparkSession = createSparkSession
326331

327-
override def sparkSession: sql.SparkSession = _sparkSession
332+
override def sparkSession: sql.SparkSession = _sparkSession
328333

329-
override def stopServer(): Unit = utils.stop()
334+
override def stopServer(): Unit = utils.stop()
330335

331-
override def resetSession(): Unit =
332-
if (!(onFabricOrSynapse(_sparkSession) || onDatabricksFS)) {
333-
if (sparkSession.isUsable) {
334-
sparkSession.stop()
335-
}
336-
_sparkSession = createSparkSession
337-
} // else leave as is, no reset to do
338-
}
339-
} else
340-
ExistingSession(SparkSession.active)
336+
override def resetSession(): Unit =
337+
if (!(onFabricOrSynapse(_sparkSession) || onDatabricksFS)) {
338+
if (sparkSession.isUsable) {
339+
sparkSession.stop()
340+
}
341+
_sparkSession = createSparkSession
342+
} // else leave as is, no reset to do
343+
}
344+
} else
345+
ExistingSession(SparkSession.active)
341346
)
342347
}
343348

runtime/src/main/scala/com/sparkutils/testing/SparkTestUtils.scala

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import org.apache.spark.sql.SparkSession
55
import java.io.File
66
import java.util.concurrent.atomic.AtomicReference
77
import scala.util.Try
8+
import scala.jdk.CollectionConverters._
89

910
object SparkTestUtils {
1011

@@ -21,6 +22,23 @@ object SparkTestUtils {
2122
parseBoolean( System.getProperty(env) )
2223
).getOrElse(default)
2324

25+
def stringEnvOrProp(env: String, default: String = null): String =
26+
Option( System.getenv(env) ).orElse(
27+
Option( System.getProperty(env) )
28+
).getOrElse(default)
29+
30+
def configFromPrefix(prefix: String): Map[String, String] = {
31+
def hasPrefix(p: (String, String)) =
32+
if (p._1.startsWith(prefix))
33+
Some((p._1.drop(prefix.length), p._2))
34+
else
35+
None
36+
37+
(System.getenv().asScala.flatMap(hasPrefix) ++
38+
System.getProperties.asScala.flatMap(hasPrefix)).toMap
39+
}
40+
41+
2442
/**
2543
* If there is a sparkSession already _and_ it's connect - then default to true, otherwise false.
2644
*
@@ -135,7 +153,7 @@ object SparkTestUtils {
135153
*/
136154
def jvmOpt(pair: (String, String)) = FLAT_JVM_OPTION+pair._1 -> pair._2
137155

138-
private val _runtimeConnectClientConfig = new AtomicReference[Map[String,String]](Map.empty)
156+
private val _runtimeConnectClientConfig = new AtomicReference[Map[String,String]](configFromPrefix("SPARKUTILS_CONNECT_CLIENT."))
139157

140158
def setRuntimeConnectClientConfig(config: Map[String, String]): Unit = {
141159
_runtimeConnectClientConfig.set(config)
@@ -146,7 +164,8 @@ object SparkTestUtils {
146164
*/
147165
lazy val runtimeConnectClientConfig: Map[String, String] = _runtimeConnectClientConfig.get()
148166

149-
private val _runtimeClassicConfig = new AtomicReference[Map[String,String]](Map.empty)
167+
168+
private val _runtimeClassicConfig = new AtomicReference[Map[String,String]](configFromPrefix("SPARKUTILS_CONNECT_SERVER."))
150169

151170
def setRuntimeClassicConfig(config: Map[String, String]): Unit = {
152171
_runtimeClassicConfig.set(config)
@@ -157,9 +176,9 @@ object SparkTestUtils {
157176
*/
158177
lazy val runtimeClassicConfig: Map[String, String] = _runtimeClassicConfig.get()
159178

160-
161-
162-
protected val tpath = new AtomicReference[String]("./target/testData")
179+
protected val tpath = new AtomicReference[String](
180+
stringEnvOrProp("SPARKUTILS_TEST_OUTPUTDIR", "./target/testData")
181+
)
163182

164183
def ouputDir = tpath.get
165184

testing/src/main/scala/com/sparkutils/testing/TestRunner.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package com.sparkutils.testing
22

33
import com.google.common.reflect.ClassPath
44
import com.sparkutils.testing.SparkTestUtils.disableClassicTesting
5-
import com.sparkutils.testing.markers.ConnectSafe
5+
import com.sparkutils.testing.markers.{ConnectSafe, DontRunOnPureConnect}
66
import org.scalatest.Suite
77

88
import java.io.IOException
@@ -55,7 +55,7 @@ trait TestRunner {
5555
*/
5656
def usableTestSuite(clazz: Class[_]): Boolean =
5757
if (disableClassicTesting)
58-
ConnectSafe.isConnectSafe(clazz)
58+
ConnectSafe.isConnectSafe(clazz) && !DontRunOnPureConnect.shouldNotRunOnPureConnect(clazz)
5959
else
6060
true
6161

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.sparkutils.testing.markers
2+
3+
/**
4+
* Simple marker trait to indicate a test suite should not be run when only pure connect is available
5+
* (e.g. classic tests that you also want to have run with local connect, but never remote or on Shared Compute).
6+
*/
7+
trait DontRunOnPureConnect
8+
9+
object DontRunOnPureConnect {
10+
/**
11+
* Checks if the ConnectSafe trait is present, indicating a test suite is safe for use in a pure connect setup
12+
*/
13+
def shouldNotRunOnPureConnect(clazz: Class[_]): Boolean =
14+
classOf[DontRunOnPureConnect].isAssignableFrom(clazz)
15+
}
16+

0 commit comments

Comments
 (0)