From 27235df3615da3e0ee83b3a350f6b4ac121def71 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sun, 10 Dec 2023 11:03:32 +0100 Subject: [PATCH] Add 'name' and 'mode' arguments to LuaRuntime.{eval,execute,compile} to allow finer control of input and debug output. Closes https://github.com/scoder/lupa/issues/248 --- lupa/_lupa.pyx | 75 ++++++++++++++++++++++++++++++++++++---------- lupa/luaapi.pxd | 8 +++-- lupa/tests/test.py | 22 ++++++++++++++ 3 files changed, 87 insertions(+), 18 deletions(-) diff --git a/lupa/_lupa.pyx b/lupa/_lupa.pyx index 263b36ae..f355c624 100644 --- a/lupa/_lupa.pyx +++ b/lupa/_lupa.pyx @@ -393,36 +393,74 @@ cdef class LuaRuntime: raise return 0 - def eval(self, lua_code, *args): + @cython.final + cdef bytes _source_encode(self, string): + if isinstance(string, unicode): + return (string).encode(self._source_encoding) + elif isinstance(string, bytes): + return string + elif isinstance(string, bytearray): + return bytes(string) + + raise TypeError(f"Expected string, got {type(string)}") + + def eval(self, lua_code, *args, name=None, mode=None): """Evaluate a Lua expression passed in a string. + + The 'name' argument can be used to override the name printed in error messages. + + The 'mode' argument specifies the input type. By default, both source code and + pre-compiled byte code is allowed (mode='bt'). It can be restricted to source + code with mode='t' and to byte code with mode='b'. This has no effect on Lua 5.1. """ assert self._state is not NULL - if isinstance(lua_code, unicode): - lua_code = (lua_code).encode(self._source_encoding) - return run_lua(self, b'return ' + lua_code, args) + name_b = self._source_encode(name) if name is not None else None + mode_b = _asciiOrNone(mode) + return run_lua(self, b'return ' + self._source_encode(lua_code), name_b, mode_b, args) - def execute(self, lua_code, *args): + def execute(self, lua_code, *args, name=None, mode=None): """Execute a Lua program passed in a string. + + The 'name' argument can be used to override the name printed in error messages. + + The 'mode' argument specifies the input type. By default, both source code and + pre-compiled byte code is allowed (mode='bt'). It can be restricted to source + code with mode='t' and to byte code with mode='b'. This has no effect on Lua 5.1. """ assert self._state is not NULL - if isinstance(lua_code, unicode): - lua_code = (lua_code).encode(self._source_encoding) - return run_lua(self, lua_code, args) + name_b = self._source_encode(name) if name is not None else None + mode_b = _asciiOrNone(mode) + return run_lua(self, self._source_encode(lua_code), name_b, mode_b, args) - def compile(self, lua_code): + def compile(self, lua_code, name=None, mode=None): """Compile a Lua program into a callable Lua function. + + The 'name' argument can be used to override the name printed in error messages. + + The 'mode' argument specifies the input type. By default, both source code and + pre-compiled byte code is allowed (mode='bt'). It can be restricted to source + code with mode='t' and to byte code with mode='b'. This has no effect on Lua 5.1. """ assert self._state is not NULL - cdef const char *err - if isinstance(lua_code, unicode): - lua_code = (lua_code).encode(self._source_encoding) + cdef const char * c_name = b'' + cdef const char * c_mode = NULL + + lua_code_bytes = self._source_encode(lua_code) + if name is not None: + name_b = self._source_encode(name) + c_name = name_b + if mode is not None: + mode_b = _asciiOrNone(mode) + c_mode = mode_b + L = self._state lock_runtime(self) old_top = lua.lua_gettop(L) cdef size_t size + cdef const char *err try: check_lua_stack(L, 1) - status = lua.luaL_loadbuffer(L, lua_code, len(lua_code), b'') + status = lua.luaL_loadbufferx(L, lua_code_bytes, len(lua_code_bytes), c_name, c_mode) if status == 0: return py_from_lua(self, L, -1) else: @@ -1719,14 +1757,21 @@ cdef build_lua_error_message(LuaRuntime runtime, lua_State* L, int stack_index=- # calling into Lua -cdef run_lua(LuaRuntime runtime, bytes lua_code, tuple args): +cdef run_lua(LuaRuntime runtime, bytes lua_code, bytes name, bytes mode, tuple args): """Run Lua code with arguments""" cdef lua_State* L = runtime._state + cdef const char* c_name = b'' + cdef const char* c_mode = NULL + if name is not None: + c_name = name + if mode is not None: + c_mode = mode + lock_runtime(runtime) old_top = lua.lua_gettop(L) try: check_lua_stack(L, 1) - if lua.luaL_loadbuffer(L, lua_code, len(lua_code), ''): + if lua.luaL_loadbufferx(L, lua_code, len(lua_code), c_name, c_mode): error = build_lua_error_message(runtime, L) if error.startswith("not enough memory"): raise LuaMemoryError(error) diff --git a/lupa/luaapi.pxd b/lupa/luaapi.pxd index 35d4b6e1..03aad5a1 100644 --- a/lupa/luaapi.pxd +++ b/lupa/luaapi.pxd @@ -316,9 +316,10 @@ cdef extern from "lauxlib.h" nogil: int luaL_ref (lua_State *L, int t) void luaL_unref (lua_State *L, int t, int ref) - int luaL_loadfile (lua_State *L, char *filename) - int luaL_loadbuffer (lua_State *L, char *buff, size_t sz, char *name) - int luaL_loadstring (lua_State *L, char *s) + int luaL_loadfile (lua_State *L, const char *filename) + int luaL_loadbuffer (lua_State *L, const char *buff, size_t sz, const char *name) + int luaL_loadbufferx (lua_State *L, const char *buff, size_t sz, const char *name, const char *mode) + int luaL_loadstring (lua_State *L, const char *s) lua_State *luaL_newstate () @@ -450,6 +451,7 @@ cdef extern from * nogil: #if LUA_VERSION_NUM < 502 #define lua_tointegerx(L, i, isnum) (*(isnum) = lua_isnumber(L, i), lua_tointeger(L, i)) + #define luaL_loadbufferx(L, buff, sz, name, mode) (((void)mode), luaL_loadbuffer(L, buff, sz, name)) #endif #if LUA_VERSION_NUM >= 504 diff --git a/lupa/tests/test.py b/lupa/tests/test.py index 623f4d35..cbb13580 100644 --- a/lupa/tests/test.py +++ b/lupa/tests/test.py @@ -129,6 +129,14 @@ def test_eval_args(self): def test_eval_args_multi(self): self.assertEqual((1, 2, 3), self.lua.eval('...', 1, 2, 3)) + def test_eval_name_mode(self): + self.assertEqual(2, self.lua.eval('1+1', name='plus', mode='t')) + + def test_eval_mode_error(self): + if self.lupa.LUA_VERSION < (5, 2): + raise unittest.SkipTest("needs lua 5.2+") + self.assertRaises(self.lupa.LuaSyntaxError, self.lua.eval, '1+1', name='plus', mode='b') + def test_eval_error(self): self.assertRaises(self.lupa.LuaError, self.lua.eval, '') @@ -156,6 +164,14 @@ def test_eval_error_message_decoding(self): def test_execute(self): self.assertEqual(2, self.lua.execute('return 1+1')) + def test_execute_mode(self): + self.assertEqual(2, self.lua.execute('return 1+1', name='return_plus', mode='t')) + + def test_execute_mode_error(self): + if self.lupa.LUA_VERSION < (5, 2): + raise unittest.SkipTest("needs lua 5.2+") + self.assertRaises(self.lupa.LuaSyntaxError, self.lua.execute, 'return 1+1', name='plus', mode='b') + def test_execute_function(self): self.assertEqual(3, self.lua.execute('f = function(i) return i+1 end; return f(2)')) @@ -919,6 +935,12 @@ def f(*args, **kwargs): def test_compile(self): lua_func = self.lua.compile('return 1 + 2') self.assertEqual(lua_func(), 3) + lua_func = self.lua.compile('return 3 + 2', mode='t') + self.assertEqual(lua_func(), 5) + lua_func = self.lua.compile('return 1 + 3', name='huhu') + self.assertEqual(lua_func(), 4) + lua_func = self.lua.compile('return 2 + 3', name='huhu', mode='t') + self.assertEqual(lua_func(), 5) self.assertRaises(self.lupa.LuaSyntaxError, self.lua.compile, 'function awd()')