Commit
refactoring + renaming
Andrei Zhabinski committed Apr 14, 2016
1 parent 93c5953 commit 6e1ed30
Showing 18 changed files with 401 additions and 283 deletions.
6 changes: 3 additions & 3 deletions .travis.yml
@@ -8,6 +8,6 @@ julia:
notifications:
  email: false
# uncomment the following lines to override the default test script
-#script:
-# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-# - julia --check-bounds=yes -e 'Pkg.clone(pwd()); Pkg.build("Sparta"); Pkg.test("Sparta"; coverage=true)'
+script:
+  - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
+  - julia --check-bounds=yes -e 'Pkg.clone(pwd()); Pkg.build("Spark"); Pkg.test("Spark"; coverage=true)'
2 changes: 1 addition & 1 deletion LICENSE.md
@@ -1,4 +1,4 @@
-The Sparta.jl package is licensed under the MIT "Expat" License:
+The Spark.jl package is licensed under the MIT "Expat" License:

> Copyright (c) 2015: dfdx.
>
5 changes: 2 additions & 3 deletions README.md
@@ -1,7 +1,6 @@
-# Sparta
+# Spark.jl

-[![Build Status](https://travis-ci.org/dfdx/Sparta.jl.svg?branch=master)](https://travis-ci.org/dfdx/Sparta.jl)
+[![Build Status](https://travis-ci.org/dfdx/Spark.jl.svg?branch=master)](https://travis-ci.org/dfdx/Spark.jl)

Julia interface to Apache Spark

-See [Roadmap](https://github.com/dfdx/Sparta.jl/issues/1) for the current status.
2 changes: 1 addition & 1 deletion deps/build.jl
@@ -5,5 +5,5 @@ catch
error("Cannot find maven. Is it installed")
end

cd("../jvm/sparta")
cd("../jvm/sparkjl")
run(`mvn clean package`)
6 changes: 3 additions & 3 deletions jvm/sparta/pom.xml → jvm/sparkjl/pom.xml
@@ -1,8 +1,8 @@
<project>
-  <groupId>julia-sparta</groupId>
-  <artifactId>sparta</artifactId>
+  <groupId>sparkjl</groupId>
+  <artifactId>sparkjl</artifactId>
  <modelVersion>4.0.0</modelVersion>
-  <name>sparta</name>
+  <name>sparkjl</name>
  <packaging>jar</packaging>
  <version>0.1</version>
  <repositories>
77 changes: 77 additions & 0 deletions jvm/sparkjl/src/main/scala/org/apache/spark/api/julia/InputIterator.scala
@@ -0,0 +1,77 @@
package org.apache.spark.api.julia

import java.io.{BufferedInputStream, DataInputStream, EOFException}
import java.net.Socket

import org.apache.commons.compress.utils.Charsets
import org.apache.spark._

/**
 * Iterator that connects to a Julia process and reads data back to the JVM.
 */
class InputIterator(context: TaskContext, worker: Socket, outputThread: OutputThread) extends Iterator[Array[Byte]] with Logging {

  val BUFFER_SIZE = 65536

  val env = SparkEnv.get
  val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, BUFFER_SIZE))

  override def next(): Array[Byte] = {
    val obj = _nextObj
    if (hasNext) {
      _nextObj = read()
    }
    obj
  }

  private def read(): Array[Byte] = {
    if (outputThread.exception.isDefined) {
      throw outputThread.exception.get
    }
    try {
      stream.readInt() match {
        case length if length > 0 =>
          val obj = new Array[Byte](length)
          stream.readFully(obj)
          obj
        case 0 => Array.empty[Byte]
        case SpecialLengths.JULIA_EXCEPTION_THROWN =>
          // Signals that an exception has been thrown in Julia
          val exLength = stream.readInt()
          val obj = new Array[Byte](exLength)
          stream.readFully(obj)
          throw new Exception(new String(obj, Charsets.UTF_8),
            outputThread.exception.getOrElse(null))
        case SpecialLengths.END_OF_DATA_SECTION =>
          if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
            null
          } else {
            throw new RuntimeException("Protocol error")
          }
      }
    } catch {

      case e: Exception if context.isInterrupted =>
        logDebug("Exception thrown after task interruption", e)
        throw new TaskKilledException

      case e: Exception if env.isStopped =>
        logDebug("Exception thrown after context is stopped", e)
        null // exit silently

      case e: Exception if outputThread.exception.isDefined =>
        logError("Julia worker exited unexpectedly (crashed)", e)
        logError("This may have been caused by a prior exception:", outputThread.exception.get)
        throw outputThread.exception.get

      case eof: EOFException =>
        throw new SparkException("Julia worker exited unexpectedly (crashed)", eof)
    }
  }

  var _nextObj = read()

  override def hasNext: Boolean = _nextObj != null

}
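The JULIA_EXCEPTION_THROWN branch in read() implies a simple error convention for the worker: send the code -2, then a length-prefixed UTF-8 message (the wire codes are defined in SpecialLengths in JuliaRDD.scala below). Purely as a hypothetical illustration, not part of this commit and not necessarily how Spark.launch_worker() does it, the Julia side of that convention could look like:

```julia
# Hypothetical sketch: report a worker failure so that InputIterator.read()
# surfaces it as a JVM exception. Assumes 4-byte big-endian ints, matching
# DataInputStream.readInt(); -2 is SpecialLengths.JULIA_EXCEPTION_THROWN.
function report_exception(sock::IO, err)
    msg = Vector{UInt8}(codeunits(sprint(showerror, err)))  # UTF-8 message bytes
    write(sock, hton(Int32(-2)))                             # JULIA_EXCEPTION_THROWN
    write(sock, hton(Int32(length(msg))))                    # message length
    write(sock, msg)
    flush(sock)
end
```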
125 changes: 125 additions & 0 deletions jvm/sparkjl/src/main/scala/org/apache/spark/api/julia/JuliaRDD.scala
@@ -0,0 +1,125 @@
package org.apache.spark.api.julia

import java.io._
import java.net._

import org.apache.commons.compress.utils.Charsets
import org.apache.spark._
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD

import scala.collection.JavaConversions._
import scala.language.existentials


class JuliaRDD(
    @transient parent: RDD[_],
    command: Array[Byte],
    inputType: Array[Byte]
  ) extends RDD[Array[Byte]](parent) {

  val preservePartitioning = true
  val reuseWorker = true

  override def getPartitions: Array[Partition] = firstParent.partitions

  override val partitioner: Option[Partitioner] = {
    if (preservePartitioning) firstParent.partitioner else None
  }


  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    val worker: Socket = JuliaRDD.createWorker()
    // Start a thread to feed the process input from our parent's iterator
    val outputThread = new OutputThread(context, firstParent.iterator(split, context), worker, command, inputType, split)
    outputThread.start()
    // Return an iterator that reads data from the process's stdout
    val resultIterator = new InputIterator(context, worker, outputThread)
    new InterruptibleIterator(context, resultIterator)
  }

  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

  override def collect(): Array[Array[Byte]] = {
    super.collect()
  }


}

private object SpecialLengths {
  val END_OF_DATA_SECTION = -1
  val JULIA_EXCEPTION_THROWN = -2
  val TIMING_DATA = -3
  val END_OF_STREAM = -4
  val NULL = -5
}

object JuliaRDD extends Logging {

  def fromJavaRDD[T](javaRdd: JavaRDD[T], command: Array[Byte], inputType: Array[Byte]): JuliaRDD =
    new JuliaRDD(JavaRDD.toRDD(javaRdd), command, inputType)

  def createWorker(): Socket = {
    var serverSocket: ServerSocket = null
    try {
      serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1).map(_.toByte)))

      // Create and start the worker
      val pb = new ProcessBuilder(Seq("julia", "-e", "using Spark; using Iterators; Spark.launch_worker()"))
      // val workerEnv = pb.environment()
      // workerEnv.putAll(envVars)
      val worker = pb.start()

      // Redirect worker stdout and stderr
      StreamUtils.redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)

      // Tell the worker our port
      val out = new OutputStreamWriter(worker.getOutputStream)
      out.write(serverSocket.getLocalPort + "\n")
      out.flush()

      // Wait for it to connect to our socket
      serverSocket.setSoTimeout(10000)
      try {
        val socket = serverSocket.accept()
        // workers.put(socket, worker)
        return socket
      } catch {
        case e: Exception =>
          throw new SparkException("Julia worker did not connect back in time", e)
      }
    } finally {
      if (serverSocket != null) {
        serverSocket.close()
      }
    }
    null
  }


  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {

    def write(obj: Any): Unit = {
      obj match {
        case arr: Array[Byte] =>
          dataOut.writeInt(arr.length)
          dataOut.write(arr)
        case str: String =>
          writeUTF(str, dataOut)
        case other =>
          throw new SparkException("Unexpected element type " + other.getClass)
      }
    }

    iter.foreach(write)
  }

  def writeUTF(str: String, dataOut: DataOutputStream) {
    val bytes = str.getBytes(Charsets.UTF_8)
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  }

}
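createWorker() fixes the JVM half of the handshake: it listens on a loopback ServerSocket, launches `julia -e 'using Spark; using Iterators; Spark.launch_worker()'`, writes its port to the worker's stdin, and waits for the worker to connect back; OutputThread then streams the partition index, input type, serialized command, and length-prefixed records, while InputIterator reads length-prefixed results. The Julia half lives in Spark.launch_worker(), which is not part of this diff. As a hypothetical orientation sketch only, assuming big-endian 4-byte framing (to match DataInput/DataOutputStream), a single-byte input type, and a user-supplied record-processing function, the worker loop implied by this protocol would be roughly:

```julia
# Hypothetical worker-side sketch, not the actual Spark.launch_worker() implementation.
using Sockets

readint(io)     = ntoh(read(io, Int32))        # big-endian Int32, as written by writeInt()
writeint(io, x) = write(io, hton(Int32(x)))    # big-endian Int32, as read by readInt()

function worker_loop(process_record::Function)
    port  = parse(Int, readline(stdin))        # createWorker() writes its port to our stdin
    sock  = connect(ip"127.0.0.1", port)       # connect back to the loopback ServerSocket
    index = readint(sock)                      # partition index
    itype = read(sock, 1)                      # input type (assumed to be a single byte)
    cmd   = read(sock, readint(sock))          # length-prefixed serialized command
    while (len = readint(sock)) != -1          # records until END_OF_DATA_SECTION (-1)
        result = process_record(read(sock, len))
        writeint(sock, length(result))         # stream results back, also length-prefixed
        write(sock, result)
    end
    readint(sock)                              # END_OF_STREAM (-4) from the JVM
    writeint(sock, -1)                         # END_OF_DATA_SECTION back to InputIterator
    writeint(sock, -4)                         # END_OF_STREAM
    flush(sock)
end
```

The real worker additionally has to deserialize `cmd` into a Julia closure and use the error path shown after InputIterator.scala above; those details live in the Julia part of the repo, outside this commit.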

23 changes: 23 additions & 0 deletions jvm/sparkjl/src/main/scala/org/apache/spark/api/julia/JuliaRunner.scala
@@ -0,0 +1,23 @@
package org.apache.spark.api.julia

import scala.collection.JavaConversions._

/**
 * Class for execution of Julia scripts on a cluster.
 * WARNING: this class isn't currently used; it will be utilized later.
 */
object JuliaRunner {

  def main(args: Array[String]): Unit = {
    val juliaScript = args(0)
    val scriptArgs = args.slice(1, args.length)
    val pb = new ProcessBuilder(Seq("julia", juliaScript) ++ scriptArgs)
    val process = pb.start()
    StreamUtils.redirectStreamsToStderr(process.getInputStream, process.getErrorStream)
    val errorCode = process.waitFor()
    if (errorCode != 0) {
      throw new RuntimeException("Julia script exited with an error")
    }
  }

}
74 changes: 74 additions & 0 deletions jvm/sparkjl/src/main/scala/org/apache/spark/api/julia/OutputThread.scala
@@ -0,0 +1,74 @@
package org.apache.spark.api.julia

import java.io.{DataOutputStream, BufferedOutputStream}
import java.net.Socket

import org.apache.spark.util.Utils
import org.apache.spark.{TaskContext, Partition, SparkEnv}

/**
 * The thread responsible for writing the data from the JuliaRDD's parent iterator to the
 * Julia process.
 */
class OutputThread(context: TaskContext, it: Iterator[Array[Byte]], worker: Socket, command: Array[Byte], inputType: Array[Byte], split: Partition)
  extends Thread(s"stdout writer for julia") {

  val BUFFER_SIZE = 65536

  val env = SparkEnv.get

  @volatile private var _exception: Exception = null

  /** Contains the exception thrown while writing the parent iterator to the Julia process. */
  def exception: Option[Exception] = Option(_exception)

  /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
  def shutdownOnTaskCompletion() {
    assert(context.isCompleted)
    this.interrupt()
  }

  override def run(): Unit = Utils.logUncaughtExceptions {
    try {
      val stream = new BufferedOutputStream(worker.getOutputStream, BUFFER_SIZE)
      val dataOut = new DataOutputStream(stream)
      // partition index
      dataOut.writeInt(split.index)
      dataOut.flush()
      // input type
      dataOut.write(inputType)
      dataOut.flush()
      // serialized command:
      dataOut.writeInt(command.length)
      dataOut.write(command)
      dataOut.flush()
      // data values
      JuliaRDD.writeIteratorToStream(it, dataOut)
      dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
      dataOut.flush()
    } catch {
      case e: Exception if context.isCompleted || context.isInterrupted =>
        // FIXME: logDebug("Exception thrown after task completion (likely due to cleanup)", e)
        println("Exception thrown after task completion (likely due to cleanup)", e)
        if (!worker.isClosed) {
          Utils.tryLog(worker.shutdownOutput())
        }

      case e: Exception =>
        // We must avoid throwing exceptions here, because the thread uncaught exception handler
        // will kill the whole executor (see org.apache.spark.executor.Executor).
        _exception = e
        if (!worker.isClosed) {
          Utils.tryLog(worker.shutdownOutput())
        }
    } finally {
      // Release memory used by this thread for shuffles
      // env.shuffleMemoryManager.releaseMemoryForThisThread()
      env.shuffleMemoryManager.releaseMemoryForThisTask()
      // Release memory used by this thread for unrolling blocks
      // env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
      env.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
    }
  }
}
24 changes: 24 additions & 0 deletions jvm/sparkjl/src/main/scala/org/apache/spark/api/julia/StreamUtils.scala
@@ -0,0 +1,24 @@
package org.apache.spark.api.julia

import java.io.InputStream

import org.apache.spark.Logging
import org.apache.spark.util.RedirectThread


object StreamUtils extends Logging {

  /**
   * Redirect the given streams to our stderr in separate threads.
   */
  def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream) {
    try {
      new RedirectThread(stdout, System.err, "stdout reader for julia").start()
      new RedirectThread(stderr, System.err, "stderr reader for julia").start()
    } catch {
      case e: Exception =>
        logError("Exception in redirecting streams", e)
    }
  }

}
