Skip to content

Commit

Permalink
refactor LuaRender
Browse files Browse the repository at this point in the history
  • Loading branch information
kmike committed Feb 2, 2015
1 parent da6b43a commit e3ab12c
Show file tree
Hide file tree
Showing 3 changed files with 340 additions and 237 deletions.
158 changes: 158 additions & 0 deletions splash/lua_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import abc

import lupa

from splash.render_options import BadOption
from splash.utils import truncated


class ImmediateResult(object):
def __init__(self, value):
self.value = value


class AsyncCommand(object):
def __init__(self, id, name, kwargs):
self.id = id
self.name = name
self.kwargs = kwargs


class ScriptError(BadOption):

def enrich_from_lua_error(self, e):
if not isinstance(e, lupa.LuaError):
return

print("enrich_from_lua_error", self, e)

self_repr = repr(self.args[0])
if self_repr in e.args[0]:
self.args = (e.args[0],) + self.args[1:]
else:
self.args = (e.args[0] + "; " + self_repr,) + self.args[1:]



class BaseScriptRunner(object):
"""
An utility class for running Lua coroutines.
"""
__metaclass__ = abc.ABCMeta

default_min_log_level = 2
result = ''
_START_CMD = '__START__'
_waiting_for_result_id = _START_CMD


def __init__(self, lua, log, sandboxed):
"""
:param splash.lua_runtime.SplashLuaRuntime lua: Lua runtime wrapper
:param log: log function
:param bool sandboxed: True if the execution should use sandbox
"""
self.log = log
self.sandboxed = sandboxed
self.lua = lua

def start(self, coro_func, coro_args):
"""
Run the script.
:param callable coro_func: Lua coroutine to start
:param list coro_args: arguments to pass to coro_func
"""
self.coro = coro_func(*coro_args)
self.dispatch(self._START_CMD)

@abc.abstractmethod
def on_result(self, result):
""" This method is called when the coroutine exits. """
pass

@abc.abstractmethod
def on_async_command(self, cmd):
""" This method is called when AsyncCommand instance is received. """
pass

def on_lua_error(self, lua_exception):
"""
This method is called when an exception happens in a Lua script.
It is called with a lupa.LuaError instance and can raise a custom
ScriptError.
"""
pass

def dispatch(self, cmd_id, *args):
""" Execute the script """
args_repr = truncated("{!r}".format(args), max_length=400, msg="...[long arguments truncated]")
self.log("[lua] dispatch cmd_id={}, args={}".format(cmd_id, args_repr))

self.log(
"[lua] arguments are for command %s, waiting for result of %s" % (cmd_id, self._waiting_for_result_id),
min_level=3,
)
if cmd_id != self._waiting_for_result_id:
self.log("[lua] skipping an out-of-order result {}".format(args_repr), min_level=1)
return

while True:
try:
args = args or None

# Got arguments from an async command; send them to coroutine
# and wait for the next async command.
self.log("[lua] send %s" % args_repr)
cmd = self.coro.send(args) # cmd is a next async command

args = None # don't re-send the same value
cmd_repr = truncated(repr(cmd), max_length=400, msg='...[long result truncated]')
self.log("[lua] got {}".format(cmd_repr))
self._print_instructions_used()

except StopIteration:
# "main" coroutine is stopped;
# previous result is a final result returned from "main"
self.log("[lua] returning result")
try:
res = self.lua.lua2python(self.result)
except ValueError as e:
# can't convert result to a Python object
raise ScriptError("'main' returned bad result. {!s}".format(e))

self._print_instructions_used()
self.on_result(res)
return
except lupa.LuaError as lua_ex:
# Lua script raised an error
self._print_instructions_used()
self.log("[lua] caught LuaError %r" % lua_ex)
self.on_lua_error(lua_ex) # this can also raise a ScriptError

# XXX: are Lua errors bad requests?
raise ScriptError("unhandled Lua error: {!s}".format(lua_ex))

if isinstance(cmd, AsyncCommand):
self.log("[lua] executing {!r}".format(cmd))
self._waiting_for_result_id = cmd.id
self.on_async_command(cmd)
return
elif isinstance(cmd, ImmediateResult):
self.log("[lua] got result {!r}".format(cmd))
args = cmd.value
continue
else:
self.log("[lua] got non-command")

if isinstance(cmd, tuple):
raise ScriptError("'main' function must return a single result")

self.result = cmd

def _print_instructions_used(self):
if self.sandboxed:
self.log("[lua] instructions used: %d" % self.lua.instruction_count())

118 changes: 118 additions & 0 deletions splash/lua_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import os
from splash.lua import lua2python, python2lua, get_new_runtime


class SplashLuaRuntime(object):
"""
Lua runtime wrapper, optionally with a sandbox.
"""
def __init__(self, sandboxed, lua_package_path, lua_sandbox_allowed_modules):
"""
:param bool sandboxed: whether the runtime should be sandboxed
:param str lua_package_path: paths to add to Lua package.path
:param iterable lua_sandbox_allowed_modules: a list of modules allowed
to be required from a sandbox
"""
self._sandboxed = sandboxed
self._lua = self._create_runtime(lua_package_path)
self._setup_lua_sandbox(lua_sandbox_allowed_modules)
self._allowed_object_attrs = {}

def table_from(self, *args, **kwargs):
return self._lua.table_from(*args, **kwargs)

def eval(self, *args, **kwargs):
return self._lua.eval(*args, **kwargs)

def execute(self, *args, **kwargs):
return self._lua.execute(*args, **kwargs)

def globals(self, *args, **kwargs):
return self._lua.globals(*args, **kwargs)

def add_allowed_object(self, obj, attr_whitelist):
""" Add a Python object to a list of objects the runtime can access """
self._allowed_object_attrs[obj] = attr_whitelist

def lua2python(self, *args, **kwargs):
kwargs.setdefault("binary", True)
kwargs.setdefault("strict", True)
return lua2python(self._lua, *args, **kwargs)

def python2lua(self, *args, **kwargs):
return python2lua(self._lua, *args, **kwargs)

def instruction_count(self):
if not self._sandboxed:
return -1
try:
return self._sandbox.instruction_count
except Exception as e:
print(e)
return -1

def create_coroutine(self, func):
"""
Return a Python object which starts a coroutine when called.
"""
if self._sandboxed:
return self._sandbox.create_coroutine(func)
else:
return func.coroutine

def _create_runtime(self, lua_package_path):
"""
Return a restricted Lua runtime.
Currently it only allows accessing attributes of this object.
"""
attribute_handlers=(self._attr_getter, self._attr_setter)
runtime = get_new_runtime(attribute_handlers=attribute_handlers)
self._setup_lua_paths(runtime, lua_package_path)
return runtime

def _setup_lua_paths(self, lua, lua_package_path):
default_path = os.path.abspath(
os.path.join(
os.path.dirname(__file__),
'lua_modules'
)
) + "/?.lua"
if lua_package_path:
packages_path = ";".join([default_path, lua_package_path])
else:
packages_path = default_path

lua.execute("""
package.path = "{packages_path};" .. package.path
""".format(packages_path=packages_path))

@property
def _sandbox(self):
return self.eval("require('sandbox')")

def _setup_lua_sandbox(self, allowed_modules):
self._sandbox["allowed_require_names"] = self.python2lua(
{name: True for name in allowed_modules}
)

def _attr_getter(self, obj, attr_name):

if not isinstance(attr_name, basestring):
raise AttributeError("Non-string lookups are not allowed (requested: %r)" % attr_name)

if isinstance(attr_name, basestring) and attr_name.startswith("_"):
raise AttributeError("Access to private attribute %r is not allowed" % attr_name)

if obj not in self._allowed_object_attrs:
raise AttributeError("Access to object %r is not allowed" % obj)

if attr_name not in self._allowed_object_attrs[obj]:
raise AttributeError("Access to private attribute %r is not allowed" % attr_name)

value = getattr(obj, attr_name)
return value

def _attr_setter(self, obj, attr_name, value):
raise AttributeError("Direct writing to Python objects is not allowed")
Loading

0 comments on commit e3ab12c

Please sign in to comment.