-
Notifications
You must be signed in to change notification settings - Fork 514
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
340 additions
and
237 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import absolute_import | ||
import abc | ||
|
||
import lupa | ||
|
||
from splash.render_options import BadOption | ||
from splash.utils import truncated | ||
|
||
|
||
class ImmediateResult(object): | ||
def __init__(self, value): | ||
self.value = value | ||
|
||
|
||
class AsyncCommand(object): | ||
def __init__(self, id, name, kwargs): | ||
self.id = id | ||
self.name = name | ||
self.kwargs = kwargs | ||
|
||
|
||
class ScriptError(BadOption): | ||
|
||
def enrich_from_lua_error(self, e): | ||
if not isinstance(e, lupa.LuaError): | ||
return | ||
|
||
print("enrich_from_lua_error", self, e) | ||
|
||
self_repr = repr(self.args[0]) | ||
if self_repr in e.args[0]: | ||
self.args = (e.args[0],) + self.args[1:] | ||
else: | ||
self.args = (e.args[0] + "; " + self_repr,) + self.args[1:] | ||
|
||
|
||
|
||
class BaseScriptRunner(object): | ||
""" | ||
An utility class for running Lua coroutines. | ||
""" | ||
__metaclass__ = abc.ABCMeta | ||
|
||
default_min_log_level = 2 | ||
result = '' | ||
_START_CMD = '__START__' | ||
_waiting_for_result_id = _START_CMD | ||
|
||
|
||
def __init__(self, lua, log, sandboxed): | ||
""" | ||
:param splash.lua_runtime.SplashLuaRuntime lua: Lua runtime wrapper | ||
:param log: log function | ||
:param bool sandboxed: True if the execution should use sandbox | ||
""" | ||
self.log = log | ||
self.sandboxed = sandboxed | ||
self.lua = lua | ||
|
||
def start(self, coro_func, coro_args): | ||
""" | ||
Run the script. | ||
:param callable coro_func: Lua coroutine to start | ||
:param list coro_args: arguments to pass to coro_func | ||
""" | ||
self.coro = coro_func(*coro_args) | ||
self.dispatch(self._START_CMD) | ||
|
||
@abc.abstractmethod | ||
def on_result(self, result): | ||
""" This method is called when the coroutine exits. """ | ||
pass | ||
|
||
@abc.abstractmethod | ||
def on_async_command(self, cmd): | ||
""" This method is called when AsyncCommand instance is received. """ | ||
pass | ||
|
||
def on_lua_error(self, lua_exception): | ||
""" | ||
This method is called when an exception happens in a Lua script. | ||
It is called with a lupa.LuaError instance and can raise a custom | ||
ScriptError. | ||
""" | ||
pass | ||
|
||
def dispatch(self, cmd_id, *args): | ||
""" Execute the script """ | ||
args_repr = truncated("{!r}".format(args), max_length=400, msg="...[long arguments truncated]") | ||
self.log("[lua] dispatch cmd_id={}, args={}".format(cmd_id, args_repr)) | ||
|
||
self.log( | ||
"[lua] arguments are for command %s, waiting for result of %s" % (cmd_id, self._waiting_for_result_id), | ||
min_level=3, | ||
) | ||
if cmd_id != self._waiting_for_result_id: | ||
self.log("[lua] skipping an out-of-order result {}".format(args_repr), min_level=1) | ||
return | ||
|
||
while True: | ||
try: | ||
args = args or None | ||
|
||
# Got arguments from an async command; send them to coroutine | ||
# and wait for the next async command. | ||
self.log("[lua] send %s" % args_repr) | ||
cmd = self.coro.send(args) # cmd is a next async command | ||
|
||
args = None # don't re-send the same value | ||
cmd_repr = truncated(repr(cmd), max_length=400, msg='...[long result truncated]') | ||
self.log("[lua] got {}".format(cmd_repr)) | ||
self._print_instructions_used() | ||
|
||
except StopIteration: | ||
# "main" coroutine is stopped; | ||
# previous result is a final result returned from "main" | ||
self.log("[lua] returning result") | ||
try: | ||
res = self.lua.lua2python(self.result) | ||
except ValueError as e: | ||
# can't convert result to a Python object | ||
raise ScriptError("'main' returned bad result. {!s}".format(e)) | ||
|
||
self._print_instructions_used() | ||
self.on_result(res) | ||
return | ||
except lupa.LuaError as lua_ex: | ||
# Lua script raised an error | ||
self._print_instructions_used() | ||
self.log("[lua] caught LuaError %r" % lua_ex) | ||
self.on_lua_error(lua_ex) # this can also raise a ScriptError | ||
|
||
# XXX: are Lua errors bad requests? | ||
raise ScriptError("unhandled Lua error: {!s}".format(lua_ex)) | ||
|
||
if isinstance(cmd, AsyncCommand): | ||
self.log("[lua] executing {!r}".format(cmd)) | ||
self._waiting_for_result_id = cmd.id | ||
self.on_async_command(cmd) | ||
return | ||
elif isinstance(cmd, ImmediateResult): | ||
self.log("[lua] got result {!r}".format(cmd)) | ||
args = cmd.value | ||
continue | ||
else: | ||
self.log("[lua] got non-command") | ||
|
||
if isinstance(cmd, tuple): | ||
raise ScriptError("'main' function must return a single result") | ||
|
||
self.result = cmd | ||
|
||
def _print_instructions_used(self): | ||
if self.sandboxed: | ||
self.log("[lua] instructions used: %d" % self.lua.instruction_count()) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import absolute_import | ||
import os | ||
from splash.lua import lua2python, python2lua, get_new_runtime | ||
|
||
|
||
class SplashLuaRuntime(object): | ||
""" | ||
Lua runtime wrapper, optionally with a sandbox. | ||
""" | ||
def __init__(self, sandboxed, lua_package_path, lua_sandbox_allowed_modules): | ||
""" | ||
:param bool sandboxed: whether the runtime should be sandboxed | ||
:param str lua_package_path: paths to add to Lua package.path | ||
:param iterable lua_sandbox_allowed_modules: a list of modules allowed | ||
to be required from a sandbox | ||
""" | ||
self._sandboxed = sandboxed | ||
self._lua = self._create_runtime(lua_package_path) | ||
self._setup_lua_sandbox(lua_sandbox_allowed_modules) | ||
self._allowed_object_attrs = {} | ||
|
||
def table_from(self, *args, **kwargs): | ||
return self._lua.table_from(*args, **kwargs) | ||
|
||
def eval(self, *args, **kwargs): | ||
return self._lua.eval(*args, **kwargs) | ||
|
||
def execute(self, *args, **kwargs): | ||
return self._lua.execute(*args, **kwargs) | ||
|
||
def globals(self, *args, **kwargs): | ||
return self._lua.globals(*args, **kwargs) | ||
|
||
def add_allowed_object(self, obj, attr_whitelist): | ||
""" Add a Python object to a list of objects the runtime can access """ | ||
self._allowed_object_attrs[obj] = attr_whitelist | ||
|
||
def lua2python(self, *args, **kwargs): | ||
kwargs.setdefault("binary", True) | ||
kwargs.setdefault("strict", True) | ||
return lua2python(self._lua, *args, **kwargs) | ||
|
||
def python2lua(self, *args, **kwargs): | ||
return python2lua(self._lua, *args, **kwargs) | ||
|
||
def instruction_count(self): | ||
if not self._sandboxed: | ||
return -1 | ||
try: | ||
return self._sandbox.instruction_count | ||
except Exception as e: | ||
print(e) | ||
return -1 | ||
|
||
def create_coroutine(self, func): | ||
""" | ||
Return a Python object which starts a coroutine when called. | ||
""" | ||
if self._sandboxed: | ||
return self._sandbox.create_coroutine(func) | ||
else: | ||
return func.coroutine | ||
|
||
def _create_runtime(self, lua_package_path): | ||
""" | ||
Return a restricted Lua runtime. | ||
Currently it only allows accessing attributes of this object. | ||
""" | ||
attribute_handlers=(self._attr_getter, self._attr_setter) | ||
runtime = get_new_runtime(attribute_handlers=attribute_handlers) | ||
self._setup_lua_paths(runtime, lua_package_path) | ||
return runtime | ||
|
||
def _setup_lua_paths(self, lua, lua_package_path): | ||
default_path = os.path.abspath( | ||
os.path.join( | ||
os.path.dirname(__file__), | ||
'lua_modules' | ||
) | ||
) + "/?.lua" | ||
if lua_package_path: | ||
packages_path = ";".join([default_path, lua_package_path]) | ||
else: | ||
packages_path = default_path | ||
|
||
lua.execute(""" | ||
package.path = "{packages_path};" .. package.path | ||
""".format(packages_path=packages_path)) | ||
|
||
@property | ||
def _sandbox(self): | ||
return self.eval("require('sandbox')") | ||
|
||
def _setup_lua_sandbox(self, allowed_modules): | ||
self._sandbox["allowed_require_names"] = self.python2lua( | ||
{name: True for name in allowed_modules} | ||
) | ||
|
||
def _attr_getter(self, obj, attr_name): | ||
|
||
if not isinstance(attr_name, basestring): | ||
raise AttributeError("Non-string lookups are not allowed (requested: %r)" % attr_name) | ||
|
||
if isinstance(attr_name, basestring) and attr_name.startswith("_"): | ||
raise AttributeError("Access to private attribute %r is not allowed" % attr_name) | ||
|
||
if obj not in self._allowed_object_attrs: | ||
raise AttributeError("Access to object %r is not allowed" % obj) | ||
|
||
if attr_name not in self._allowed_object_attrs[obj]: | ||
raise AttributeError("Access to private attribute %r is not allowed" % attr_name) | ||
|
||
value = getattr(obj, attr_name) | ||
return value | ||
|
||
def _attr_setter(self, obj, attr_name, value): | ||
raise AttributeError("Direct writing to Python objects is not allowed") |
Oops, something went wrong.