要将Scala嵌入为“脚本语言”,我需要能够将文本片段编译为简单对象,例如可以从磁盘序列化和反序列化的Function0[Unit]
,并且可以将其加载到当前运行时并执行。
我该怎么做?
比如说,我的文本片段是(纯粹是假设的):
Document.current.elements.headOption.foreach(_.open())
这可能包含在以下完整文本中:
package myapp.userscripts
import myapp.DSL._
object UserFunction1234 extends Function0[Unit] {
def apply(): Unit = {
Document.current.elements.headOption.foreach(_.open())
}
}
接下来会发生什么?我应该使用IMain
来编译此代码吗?我不想使用普通的解释器模式,因为编译应该是“无上下文”而不是累积请求。
我需要从编译中解脱出来的是我猜二进制类文件?在这种情况下,序列化是直接的(字节数组)。然后,我如何将该类加载到运行时并调用apply
方法?
如果代码编译为多个辅助类会发生什么?上面的示例包含一个闭包_.open()
。如何确保将所有辅助内容“打包”到一个对象中以进行序列化和类加载?
注意:鉴于Scala 2.11即将发布且编译器API可能已更改,我很高兴收到有关如何在Scala 2.11上解决此问题的提示
答案 0 :(得分:4)
这是一个想法:使用常规的Scala编译器实例。不幸的是,它似乎需要使用硬盘文件进行输入和输出。所以我们使用临时文件。输出将在JAR中压缩,JAR将存储为字节数组(将进入假设的序列化过程)。我们需要一个特殊的类加载器来从提取的JAR中再次检索类。
以下假设Scala 2.10.3在类路径上带有scala-compiler
库:
import scala.tools.nsc
import java.io._
import scala.annotation.tailrec
在函数类中包装用户提供的代码,其合成名称将针对每个新片段递增:
val packageName = "myapp"
var userCount = 0
def mkFunName(): String = {
val c = userCount
userCount += 1
s"Fun$c"
}
def wrapSource(source: String): (String, String) = {
val fun = mkFunName()
val code = s"""package $packageName
|
|class $fun extends Function0[Unit] {
| def apply(): Unit = {
| $source
| }
|}
|""".stripMargin
(fun, code)
}
编译源片段并返回生成的jar的字节数组的函数:
/** Compiles a source code consisting of a body which is wrapped in a `Function0`
* apply method, and returns the function's class name (without package) and the
* raw jar file produced in the compilation.
*/
def compile(source: String): (String, Array[Byte]) = {
val set = new nsc.Settings
val d = File.createTempFile("temp", ".out")
d.delete(); d.mkdir()
set.d.value = d.getPath
set.usejavacp.value = true
val compiler = new nsc.Global(set)
val f = File.createTempFile("temp", ".scala")
val out = new BufferedOutputStream(new FileOutputStream(f))
val (fun, code) = wrapSource(source)
out.write(code.getBytes("UTF-8"))
out.flush(); out.close()
val run = new compiler.Run()
run.compile(List(f.getPath))
f.delete()
val bytes = packJar(d)
deleteDir(d)
(fun, bytes)
}
def deleteDir(base: File): Unit = {
base.listFiles().foreach { f =>
if (f.isFile) f.delete()
else deleteDir(f)
}
base.delete()
}
注意:尚未处理编译器错误!
packJar
方法使用编译器输出目录并从中生成内存中的jar文件:
// cf. http://stackoverflow.com/questions/1281229
def packJar(base: File): Array[Byte] = {
import java.util.jar._
val mf = new Manifest
mf.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
val bs = new java.io.ByteArrayOutputStream
val out = new JarOutputStream(bs, mf)
def add(prefix: String, f: File): Unit = {
val name0 = prefix + f.getName
val name = if (f.isDirectory) name0 + "/" else name0
val entry = new JarEntry(name)
entry.setTime(f.lastModified())
out.putNextEntry(entry)
if (f.isFile) {
val in = new BufferedInputStream(new FileInputStream(f))
try {
val buf = new Array[Byte](1024)
@tailrec def loop(): Unit = {
val count = in.read(buf)
if (count >= 0) {
out.write(buf, 0, count)
loop()
}
}
loop()
} finally {
in.close()
}
}
out.closeEntry()
if (f.isDirectory) f.listFiles.foreach(add(name, _))
}
base.listFiles().foreach(add("", _))
out.close()
bs.toByteArray
}
一个实用程序函数,它接受反序列化中找到的字节数组,并创建一个从类名到类字节代码的映射:
def unpackJar(bytes: Array[Byte]): Map[String, Array[Byte]] = {
import java.util.jar._
import scala.annotation.tailrec
val in = new JarInputStream(new ByteArrayInputStream(bytes))
val b = Map.newBuilder[String, Array[Byte]]
@tailrec def loop(): Unit = {
val entry = in.getNextJarEntry
if (entry != null) {
if (!entry.isDirectory) {
val name = entry.getName
// cf. http://stackoverflow.com/questions/8909743
val bs = new ByteArrayOutputStream
var i = 0
while (i >= 0) {
i = in.read()
if (i >= 0) bs.write(i)
}
val bytes = bs.toByteArray
b += mkClassName(name) -> bytes
}
loop()
}
}
loop()
in.close()
b.result()
}
def mkClassName(path: String): String = {
require(path.endsWith(".class"))
path.substring(0, path.length - 6).replace("/", ".")
}
合适的类加载器:
class MemoryClassLoader(map: Map[String, Array[Byte]]) extends ClassLoader {
override protected def findClass(name: String): Class[_] =
map.get(name).map { bytes =>
println(s"defineClass($name, ...)")
defineClass(name, bytes, 0, bytes.length)
} .getOrElse(super.findClass(name)) // throws exception
}
包含其他类(闭包)的测试用例:
val exampleSource =
"""val xs = List("hello", "world")
|println(xs.map(_.capitalize).mkString(" "))
|""".stripMargin
def test(fun: String, cl: ClassLoader): Unit = {
val clName = s"$packageName.$fun"
println(s"Resolving class '$clName'...")
val clazz = Class.forName(clName, true, cl)
println("Instantiating...")
val x = clazz.newInstance().asInstanceOf[() => Unit]
println("Invoking 'apply':")
x()
}
locally {
println("Compiling...")
val (fun, bytes) = compile(exampleSource)
val map = unpackJar(bytes)
println("Classes found:")
map.keys.foreach(k => println(s" '$k'"))
val cl = new MemoryClassLoader(map)
test(fun, cl) // should call `defineClass`
test(fun, cl) // should find cached class
}