diff --git a/lupa/_lupa.pyx b/lupa/_lupa.pyx index 6e085afe..46b3c103 100644 --- a/lupa/_lupa.pyx +++ b/lupa/_lupa.pyx @@ -157,6 +157,10 @@ def lua_type(obj): lua.lua_settop(L, old_top) unlock_runtime(lua_object._runtime) +cdef int _len_as_int(Py_ssize_t obj) except -1: + if obj > LONG_MAX: + raise OverflowError + return obj @cython.no_gc_clear cdef class LuaRuntime: @@ -430,9 +434,9 @@ cdef class LuaRuntime: # FIXME: how to check for failure? and nested dict for obj in args: if isinstance(obj, dict): - for key, value in obj.iteritems(): # in python3, this is called items - py_to_lua(self, L, key, False, recursive) - py_to_lua(self, L, value, False, recursive) + for key, value in obj.iteritems(): + py_to_lua(self, L, key, wrap_none=False, recursive=recursive) + py_to_lua(self, L, value, wrap_none=False, recursive=recursive) lua.lua_rawset(L, -3) elif isinstance(obj, _LuaTable): @@ -1136,6 +1140,8 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): # already terminated raise StopIteration if args: + if len(args) > LONG_MAX: + raise OverflowError nargs = len(args) push_lua_arguments(thread._runtime, co, args) with nogil: @@ -1463,6 +1469,10 @@ cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=Fa elif isinstance(o, float): lua.lua_pushnumber(L, o) pushed_values_count = 1 + elif isinstance(o, _PyProtocolWrapper): + type_flags = (<_PyProtocolWrapper> o)._type_flags + o = (<_PyProtocolWrapper> o)._obj + pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags) elif recursive and isinstance(o, Sequence): lua.lua_createtable(L, len(o), 0) # create a table at the top of stack, with narr already known for i, v in enumerate(o): @@ -1471,18 +1481,14 @@ cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=Fa pushed_values_count = 1 elif recursive and isinstance(o, Mapping): lua.lua_createtable(L, 0, len(o)) # create a table at the top of stack, with nrec already known - for key, value in o.iteritems(): # to compatible with py2 + for key, value in o.items(): py_to_lua(runtime, L, key, wrap_none, recursive) py_to_lua(runtime, L, value, wrap_none, recursive) lua.lua_rawset(L, -3) pushed_values_count = 1 else: - if isinstance(o, _PyProtocolWrapper): - type_flags = (<_PyProtocolWrapper>o)._type_flags - o = (<_PyProtocolWrapper>o)._obj - else: - # prefer __getitem__ over __getattr__ by default - type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0 + # prefer __getitem__ over __getattr__ by default + type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0 pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags) return pushed_values_count diff --git a/lupa/tests/test.py b/lupa/tests/test.py index e95d3780..1945b4e9 100644 --- a/lupa/tests/test.py +++ b/lupa/tests/test.py @@ -24,7 +24,6 @@ def _next(o): if IS_PYTHON2: unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp - class SetupLuaRuntimeMixin(object): lua_runtime_kwargs = {} @@ -98,7 +97,7 @@ def test_eval(self): self.assertEqual(2, self.lua.eval('1+1')) def test_eval_multi(self): - self.assertEqual((1, 2, 3), self.lua.eval('1,2,3')) + self.assertEqual((1,2,3), self.lua.eval('1,2,3')) def test_eval_args(self): self.assertEqual(2, self.lua.eval('...', 2)) @@ -171,7 +170,7 @@ def test_recursive_function(self): return fac ''') self.assertNotEqual(None, fac) - self.assertEqual(6, fac(3)) + self.assertEqual(6, fac(3)) self.assertEqual(3628800, fac(10)) def test_double_recursive_function(self): @@ -186,8 +185,8 @@ def test_double_recursive_function(self): ''' calc = self.lua.execute(func_code) self.assertNotEqual(None, calc) - self.assertEqual(3, calc(3)) - self.assertEqual(109, calc(10)) + self.assertEqual(3, calc(3)) + self.assertEqual(109, calc(10)) self.assertEqual(13529, calc(20)) def test_double_recursive_function_pycallback(self): @@ -200,15 +199,14 @@ def test_double_recursive_function_pycallback(self): end return calc ''' - def pycallback(i): - return i ** 2 + return i**2 calc = self.lua.execute(func_code) self.assertNotEqual(None, calc) - self.assertEqual(12, calc(pycallback, 3)) - self.assertEqual(1342, calc(pycallback, 10)) + self.assertEqual(12, calc(pycallback, 3)) + self.assertEqual(1342, calc(pycallback, 10)) self.assertEqual(185925, calc(pycallback, 20)) def test_none(self): @@ -245,7 +243,6 @@ def test_call_str_py(self): def test_call_str_class(self): called = [False] - class test(object): def __str__(self): called[0] = True @@ -270,7 +267,7 @@ def test_len_table_array(self): def test_len_table_dict(self): table = self.lua.eval('{a=1, b=2, c=3}') - self.assertEqual(0, len(table)) # as returned by Lua's "#" operator + self.assertEqual(0, len(table)) # as returned by Lua's "#" operator def test_table_delattr(self): table = self.lua.eval('{a=1, b=2, c=3}') @@ -296,27 +293,27 @@ def test_len_table(self): def test_iter_table(self): table = self.lua.eval('{2,3,4,5,6}') - self.assertEqual([1, 2, 3, 4, 5], list(table)) + self.assertEqual([1,2,3,4,5], list(table)) def test_iter_table_list_repeat(self): table = self.lua.eval('{2,3,4,5,6}') - self.assertEqual([1, 2, 3, 4, 5], list(table)) # 1 - self.assertEqual([1, 2, 3, 4, 5], list(table)) # 2 - self.assertEqual([1, 2, 3, 4, 5], list(table)) # 3 + self.assertEqual([1,2,3,4,5], list(table)) # 1 + self.assertEqual([1,2,3,4,5], list(table)) # 2 + self.assertEqual([1,2,3,4,5], list(table)) # 3 def test_iter_array_table_values(self): table = self.lua.eval('{2,3,4,5,6}') - self.assertEqual([2, 3, 4, 5, 6], list(table.values())) + self.assertEqual([2,3,4,5,6], list(table.values())) def test_iter_array_table_repeat(self): table = self.lua.eval('{2,3,4,5,6}') - self.assertEqual([2, 3, 4, 5, 6], list(table.values())) # 1 - self.assertEqual([2, 3, 4, 5, 6], list(table.values())) # 2 - self.assertEqual([2, 3, 4, 5, 6], list(table.values())) # 3 + self.assertEqual([2,3,4,5,6], list(table.values())) # 1 + self.assertEqual([2,3,4,5,6], list(table.values())) # 2 + self.assertEqual([2,3,4,5,6], list(table.values())) # 3 def test_iter_multiple_tables(self): count = 10 - table_values = [self.lua.eval('{%s}' % ','.join(map(str, range(2, count + 2)))).values() + table_values = [self.lua.eval('{%s}' % ','.join(map(str, range(2, count+2)))).values() for _ in range(4)] # round robin @@ -325,11 +322,11 @@ def test_iter_multiple_tables(self): for table in table_values: sublist.append(_next(table)) - self.assertEqual([[i] * len(table_values) for i in range(2, count + 2)], l) + self.assertEqual([[i]*len(table_values) for i in range(2, count+2)], l) def test_iter_table_repeat(self): count = 10 - table_values = [self.lua.eval('{%s}' % ','.join(map(str, range(2, count + 2)))).values() + table_values = [self.lua.eval('{%s}' % ','.join(map(str, range(2, count+2)))).values() for _ in range(4)] # one table after the other @@ -338,7 +335,7 @@ def test_iter_table_repeat(self): for sublist in l: sublist.append(_next(table)) - self.assertEqual([[i] * len(table_values) for i in range(2, count + 2)], l) + self.assertEqual([[i]*len(table_values) for i in range(2,count+2)], l) def test_iter_table_refcounting(self): lua_func = self.lua.eval(''' @@ -390,20 +387,20 @@ def test_iter_table_values_int_keys(self): table = self.lua.eval('{%s}' % ','.join('[%d]=%d' % (i, -i) for i in range(10))) l = list(table.values()) l.sort() - self.assertEqual(list(range(-9, 1)), l) + self.assertEqual(list(range(-9,1)), l) def test_iter_table_items(self): keys = list('abcdefg') table = self.lua.eval('{%s}' % ','.join('%s=%d' % (c, i) for i, c in enumerate(keys))) l = list(table.items()) l.sort() - self.assertEqual(list(zip(keys, range(len(keys)))), l) + self.assertEqual(list(zip(keys,range(len(keys)))), l) def test_iter_table_items_int_keys(self): table = self.lua.eval('{%s}' % ','.join('[%d]=%d' % (i, -i) for i in range(10))) l = list(table.items()) l.sort() - self.assertEqual(list(zip(range(10), range(0, -10, -1))), l) + self.assertEqual(list(zip(range(10), range(0,-10,-1))), l) def test_iter_table_values_mixed(self): keys = list('abcdefg') @@ -436,7 +433,7 @@ def test_string_values(self): def test_int_values(self): function = self.lua.eval('function(i) return i + 5 end') - self.assertEqual(3 + 5, function(3)) + self.assertEqual(3+5, function(3)) def test_long_values(self): try: @@ -444,11 +441,11 @@ def test_long_values(self): except NameError: _long = int function = self.lua.eval('function(i) return i + 5 end') - self.assertEqual(3 + 5, function(_long(3))) + self.assertEqual(3+5, function(_long(3))) def test_float_values(self): function = self.lua.eval('function(i) return i + 5 end') - self.assertEqual(float(3) + 5, function(float(3))) + self.assertEqual(float(3)+5, function(float(3))) def test_str_function(self): func = self.lua.eval('function() return 1 end') @@ -459,7 +456,7 @@ def test_str_table(self): self.assertEqual('