[KYUUBI apache#6265] Resource isolation in Spark Scala mode
wangjunbo committed Apr 22, 2024
1 parent 1591157 commit e3c652e
Showing 5 changed files with 157 additions and 10 deletions.
6 changes: 6 additions & 0 deletions externals/kyuubi-spark-sql-engine/pom.xml
@@ -115,6 +115,12 @@
             <scope>test</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-hive_${scala.binary.version}</artifactId>
+            <scope>provided</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.kyuubi</groupId>
             <artifactId>${hive.jdbc.artifact}</artifactId>
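A note on the new dependency: spark-hive is provided-scope because the new HiveClientHelper below compiles against HiveSessionResourceLoader and HiveClientImpl from spark-hive, while at runtime those classes already ship with the Spark distribution the engine runs on.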
21 changes: 17 additions & 4 deletions externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteScala.scala

@@ -25,7 +25,9 @@ import scala.tools.nsc.interpreter.Results.{Error, Incomplete, Success}
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkFiles
+import org.apache.spark.kyuubi.SparkJobArtifactHelper
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.hive.HiveClientHelper
 import org.apache.spark.sql.types.StructType
 
 import org.apache.kyuubi.KyuubiSQLException
@@ -90,13 +92,14 @@ class ExecuteScala(
warn(s"Clearing legacy output from last interpreting:\n $legacyOutput")
}
val replUrls = repl.classLoader.getParent.asInstanceOf[URLClassLoader].getURLs
spark.sharedState.jarClassLoader.getURLs.filterNot(replUrls.contains).foreach { jar =>
val root = new File(SparkFiles.getRootDirectory(), session.handle.identifier.toString)
HiveClientHelper.getLoadedClasses(spark).filterNot(replUrls.contains).foreach { jar =>
try {
if ("file".equals(jar.toURI.getScheme)) {
repl.addUrlsToClassPath(jar)
} else {
spark.sparkContext.addFile(jar.toString)
val localJarFile = new File(SparkFiles.get(new Path(jar.toURI.getPath).getName))
val localJarFile = new File(root, new Path(jar.toURI.getPath).getName)
val localJarUrl = localJarFile.toURI.toURL
if (!replUrls.contains(localJarUrl)) {
repl.addUrlsToClassPath(localJarUrl)
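Why the lookup changed from SparkFiles.get(name) to new File(root, name): with a session's JobArtifactState active (see withSessionArtifactState below), Spark 3.5+ places addFile downloads in a per-state subdirectory of the driver's SparkFiles root, keyed by the state's uuid, rather than in the shared root. A minimal sketch of the resulting path arithmetic, with the root and session id as hypothetical stand-ins:

    import java.io.File

    val sparkFilesRoot = "/tmp/spark-0a1b/userFiles-2c3d" // stand-in for SparkFiles.getRootDirectory()
    val sessionId = "b3f1c2d4-session"                    // stand-in for session.handle.identifier.toString
    val localJar = new File(new File(sparkFilesRoot, sessionId), "test-function.jar")
    println(localJar) // /tmp/spark-0a1b/userFiles-2c3d/b3f1c2d4-session/test-function.jar

The old SparkFiles.get("test-function.jar") resolved against the shared root, so it would miss a jar that addFile had isolated under the session's directory.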
@@ -140,7 +143,9 @@
       val asyncOperation = new Runnable {
         override def run(): Unit = {
           OperationLog.setCurrentOperationLog(operationLog)
-          executeScala()
+          withSessionArtifactState {
+            executeScala()
+          }
         }
       }

@@ -157,7 +162,15 @@
           throw ke
       }
     } else {
-      executeScala()
+      withSessionArtifactState {
+        executeScala()
+      }
     }
   }
+
+  private def withSessionArtifactState(f: => Unit): Unit = {
+    SparkJobArtifactHelper.withActiveJobArtifactState(session.handle) {
+      f
+    }
+  }
 }
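withSessionArtifactState makes the session handle the active JobArtifactState for everything the Scala interpreter runs, so jars and files registered during execution are tagged to that session rather than to the engine-global artifact set. A hedged usage sketch (spark, session, and the jar path are placeholders, not lines from the commit):

    SparkJobArtifactHelper.withActiveJobArtifactState(session.handle) {
      // recorded under this session's uuid, not globally
      spark.sparkContext.addJar("hdfs:///tmp/jars/test-function.jar")
      // jobs submitted here carry JobArtifactState(uuid, None), so executors
      // fetch only this session's jars and files when running the tasks
      spark.sparkContext.parallelize(1 to 2).count()
    }
    // under a different (or no) active session state, the jar added above is
    // not part of the job's artifact set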
32 changes: 32 additions & 0 deletions externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkJobArtifactHelper.scala

@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.kyuubi

import org.apache.spark.{JobArtifactSet, JobArtifactState}

import org.apache.kyuubi.session.SessionHandle

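// JobArtifactSet and JobArtifactState are private[spark] in Spark 3.5+, which
// is why this thin adapter sits in an org.apache.spark subpackage: it maps a
// Kyuubi SessionHandle onto an artifact-state uuid and lets Spark do the scoping.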
object SparkJobArtifactHelper {

  def withActiveJobArtifactState(handler: SessionHandle)(f: => Unit): Unit = {
    val state = JobArtifactState(handler.identifier.toString, None)
    JobArtifactSet.withActiveJobArtifactState(state) {
      f
    }
  }
}
53 changes: 53 additions & 0 deletions externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/hive/HiveClientHelper.scala

@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive

import java.net.URL

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SessionResourceLoader
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION

object HiveClientHelper {

  type HiveClientImpl = org.apache.spark.sql.hive.client.HiveClientImpl

  def getLoadedClasses(spark: SparkSession): Array[URL] = {
    if (spark.conf.get(CATALOG_IMPLEMENTATION).equals("hive")) {
      val loader = spark.sessionState.resourceLoader
      getHiveLoadedClasses(loader)
    } else {
      spark.sharedState.jarClassLoader.getURLs
    }
  }

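  // With a Hive catalog, ADD JAR is routed through the Hive client, whose
  // isolated classloader accumulates the jar URLs; that client sits in a
  // private field of HiveSessionResourceLoader, hence the reflective read.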
  private def getHiveLoadedClasses(loader: SessionResourceLoader): Array[URL] = {
    if (loader != null) {
      val field = classOf[HiveSessionResourceLoader].getDeclaredField("client")
      field.setAccessible(true)
      val client = field.get(loader).asInstanceOf[HiveClientImpl]
      if (client != null) {
        client.clientLoader.classLoader.getURLs
      } else {
        Array.empty
      }
    } else {
      Array.empty
    }
  }
}
55 changes: 49 additions & 6 deletions kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala

@@ -203,6 +203,17 @@ class KyuubiOperationPerUserSuite
   }
 
   test("scala NPE issue with hdfs jar") {
+    val dfsJarPath = prepareHdfsJar
+    withJdbcStatement() { statement =>
+      val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+      statement.executeQuery(s"add jar $dfsJarPath")
+      val rs = kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
+      rs.next()
+      assert(rs.getString(1) === "3")
+    }
+  }
+
+  private def prepareHdfsJar: Path = {
     val jarDir = Utils.createTempDir().toFile
     val udfCode =
       """
@@ -225,12 +236,44 @@ class KyuubiOperationPerUserSuite
     val localPath = new Path(jarFile.getAbsolutePath)
     val dfsJarPath = new Path(dfsJarDir, "test-function.jar")
     FileUtil.copy(localFs, localPath, dfs, dfsJarPath, false, false, hadoopConf)
-    withJdbcStatement() { statement =>
-      val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
-      statement.executeQuery(s"add jar $dfsJarPath")
-      val rs = kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
-      rs.next()
-      assert(rs.getString(1) === "3")
+    dfsJarPath
+  }
+
+  test("support scala mode resource isolation") {
+    val dfsJarPath = prepareHdfsJar
+    withSessionConf()(
+      Map(
+        KyuubiConf.ENGINE_SHARE_LEVEL_SUBDOMAIN.key -> "resource_isolation",
+        "spark.sql.catalogImplementation" -> "hive"))(
+      Map.empty) {
+      var r1: String = null
+      var exception: Exception = null
+
+      new Thread {
+        override def run(): Unit = withJdbcStatement() { statement =>
+          val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+          kyuubiStatement.executeQuery(s"add jar $dfsJarPath")
+          val rs = kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
+          rs.next()
+          r1 = rs.getString(1)
+        }
+      }.start()
+
+      new Thread {
+        override def run(): Unit = withJdbcStatement() { statement =>
+          val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+          exception = intercept[Exception] {
+            kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
+          }
+        }
+      }.start()
+
+      eventually(timeout(120.seconds), interval(100.milliseconds)) {
+        assert(r1 != null && exception != null)
+      }
+
+      assert(r1 === "3")
+      assert(exception.getMessage.contains("not found: value test"))
     }
   }
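The scenario, briefly: both connections share one engine via the same share-level subdomain; the session that added the jar sees test.utils.Math, while the concurrent session's REPL fails with "not found: value test", demonstrating that ADD JAR no longer leaks across sessions in Scala mode.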

