From dae82270f5ac453cd861a08127f4a9c77c789def Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 26 Jun 2023 09:20:20 +0000 Subject: [PATCH 1/2] merge --- .../executor/function_graph.py | 1 + .../executor/opcode_executor.py | 11 ++-- .../executor/pycode_generator.py | 52 ++++++++++++++++++- 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 08cecb5d4..dec0c8b5c 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -127,6 +127,7 @@ def start_compile(self, *ret_vars: VariableBase): for ret_var in ret_vars for ret_item in ret_var.flatten_items() ] + self.pycode_gen.gen_disable_eval_frame() tensor_items = self._find_tensor_outputs(ret_items) compiled_fn, statment_ir = self.sir_ctx.compile_fn( [Symbol(tensor_var.var_name) for tensor_var in tensor_items] diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 4e95c1506..c8293c201 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1118,7 +1118,7 @@ def _break_graph_in_jump(self, result, instr): for name in if_inputs: self.get_var(name).reconstruct(self._graph.pycode_gen) self._graph.pycode_gen.gen_call_function( - argc=if_fn.__code__.co_argcount + argc=if_fn.__code__.co_argcount, enable_evalframe=True ) self._graph.pycode_gen.gen_return() else: @@ -1135,7 +1135,7 @@ def _break_graph_in_jump(self, result, instr): for name in else_inputs: self.get_var(name).reconstruct(self._graph.pycode_gen) self._graph.pycode_gen.gen_call_function( - argc=else_fn.__code__.co_argcount + argc=else_fn.__code__.co_argcount, enable_evalframe=True ) self._graph.pycode_gen.gen_return() else: @@ -1197,7 +1197,7 @@ def _break_graph_in_call(self, origin_stack, instr, push_n): for name in resume_input_name: self._locals[name].reconstruct(self._graph.pycode_gen) self._graph.pycode_gen.gen_call_function( - argc=resume_fn.__code__.co_argcount + argc=resume_fn.__code__.co_argcount, enable_evalframe=True ) # gen RETURN_VALUE @@ -1299,7 +1299,8 @@ def update_locals(name, variable): # 5.4 call loop body self._graph.pycode_gen.gen_call_function( - argc=loop_body.__code__.co_argcount + argc=loop_body.__code__.co_argcount, + enable_evalframe=True, ) # 5.5 unpack and store retval, keep break_flag in stack @@ -1328,7 +1329,7 @@ def update_locals(name, variable): self._graph.pycode_gen.gen_load_fast(name) self._graph.pycode_gen.gen_call_function( - argc=after_loop_fn.__code__.co_argcount + argc=after_loop_fn.__code__.co_argcount, enable_evalframe=True ) self._graph.pycode_gen.gen_return() diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index d10c2b0d1..d723878cf 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -11,6 +11,8 @@ import opcode +import paddle + from ...utils import ( ResumeFnNameFactory, list_contain_by_id, @@ -352,6 +354,52 @@ def gen_load_const(self, value): idx = list_find_index_by_id(self._code_options["co_consts"], value) self._add_instr("LOAD_CONST", arg=idx, argval=value) + def gen_disable_eval_frame(self): + self.gen_load_object( + paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_function" + ) + self.gen_load_const(None) + self.gen_call_function(1) + self.gen_store_fast("paddle_old_eval_frame_fn") + + def gen_enable_eval_frame(self): + self.gen_load_object( + paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_function" + ) + self.gen_load_fast("paddle_old_eval_frame_fn") + self.gen_call_function(1) + self.gen_pop_top() + + def gen_print_log(self, message): + """print a log :""" + self.gen_disable_eval_frame() + self.gen_load_global("print") + self.gen_load_const(message) + self.gen_call_function(1) + self.gen_enable_eval_frame() + + def gen_dbg_function(self, dbg_fun): + """debug bytecode helper function. + Usage like: + def dbg_func(): + import inspect + import dis + print("dbg here.") + print(locals()) + dis.dis(inspect.currentframe().f_back.f_code) + frame = inspect.currentframe().f_back + code = (inspect.currentframe().f_back.f_code) + breakpoint() + print(inspect.currentframe().f_back.f_locals['y']) + + self.pycode_gen.gen_dbg_function(dbg_func) + """ + self.gen_disable_eval_frame() + self.gen_load_object(dbg_fun, "dbg1") + self.gen_call_function(0) + self.gen_pop_top() + self.gen_enable_eval_frame() + def gen_load_global(self, name): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) @@ -424,7 +472,9 @@ def gen_build_map(self, count): def gen_unpack_sequence(self, count): self._add_instr("UNPACK_SEQUENCE", arg=count, argval=count) - def gen_call_function(self, argc=0): + def gen_call_function(self, argc=0, enable_evalframe=False): + if enable_evalframe: + self.gen_enable_eval_frame() self._add_instr("CALL_FUNCTION", arg=argc, argval=argc) def gen_pop_top(self): From 4bf671709330618a5e5d2698b1313f19b36a89b2 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 26 Jun 2023 12:55:39 +0000 Subject: [PATCH 2/2] split all not implement exception to valid semantics --- .../executor/opcode_executor.py | 42 ++++++++----------- .../executor/variable_dispatch.py | 10 ++--- .../executor/variables/base.py | 6 +-- .../executor/variables/basic.py | 4 +- .../executor/variables/container.py | 10 +++-- sot/utils/exceptions.py | 8 ++++ 6 files changed, 41 insertions(+), 39 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index c8293c201..1377d16c7 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -13,6 +13,8 @@ BreakGraphError, InnerError, NotImplementException, + NotImplementFatal, + NotImplementInsignificant, Singleton, is_strict_mode, log, @@ -240,7 +242,7 @@ def inner(*args, **kwargs): try: return fn(*args, **kwargs) except Exception as e: - raise NotImplementException( + raise NotImplementFatal( f'An exception occurred when processing break graph, fallback to dygraph, error message is: \n{type(e)} : {e}\n' ) @@ -339,9 +341,7 @@ def step(self, instr: Instruction): if instr.starts_line is not None: self._current_line = instr.starts_line if not hasattr(self, instr.opname): - raise NotImplementException( - f"opcode: {instr.opname} is not supported." - ) + raise NotImplementFatal(f"opcode: {instr.opname} is not supported.") log( 3, f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n", @@ -679,7 +679,7 @@ def CALL_FUNCTION(self, instr): kwargs = {} fn = self.pop() if not isinstance(fn, CallableVariable): - raise NotImplementException(f"CALL_FUNCTION: {fn} is not callable") + raise NotImplementFatal(f"CALL_FUNCTION: {fn} is not callable") ret = fn(*args, **kwargs) self.push(ret) @@ -703,9 +703,7 @@ def CALL_FUNCTION_KW(self, instr): fn = self.pop() if not isinstance(fn, CallableVariable): - raise NotImplementException( - f"CALL_FUNCTION_KW: {fn} is not callable." - ) + raise NotImplementFatal(f"CALL_FUNCTION_KW: {fn} is not callable.") ret = fn(*args, **kwargs) self.push(ret) @@ -724,9 +722,7 @@ def CALL_FUNCTION_EX(self, instr): fn = self.pop() if not isinstance(fn, CallableVariable): - raise NotImplementException( - f"CALL_FUNCTION_EX: {fn} is not callable." - ) + raise NotImplementFatal(f"CALL_FUNCTION_EX: {fn} is not callable.") ret = fn(*args, **kwargs) self.push(ret) @@ -753,7 +749,7 @@ def COMPARE_OP(self, instr): ) return except Exception as e: - raise NotImplementException( + raise NotImplementFatal( f"{instr} is not support between {left} and {right}. may be not a supported compare op." ) @@ -790,7 +786,7 @@ def MAKE_FUNCTION(self, instr): related_list.append(self.pop()) if flag & MF.MF_HAS_KWDEFAULTS: - raise NotImplementException( + raise NotImplementFatal( "Found need func_kwdefaults when MAKE_FUNCTION." ) @@ -888,7 +884,7 @@ def JUMP_IF_FALSE_OR_POP(self, instr): else: self.pop() return - raise NotImplementException( + raise NotImplementFatal( "Currently don't support predicate a non-const / non-tensor obj." ) @@ -903,7 +899,7 @@ def JUMP_IF_TRUE_OR_POP(self, instr): else: self.pop() return - raise NotImplementException( + raise NotImplementFatal( "Currently don't support predicate a non-const / non-tensor obj." ) @@ -916,7 +912,7 @@ def POP_JUMP_IF_FALSE(self, instr): if is_jump: self._lasti = self.indexof(instr.jump_to) return - raise NotImplementException( + raise NotImplementFatal( "Currently don't support predicate a non-const / non-tensor obj." ) @@ -929,7 +925,7 @@ def POP_JUMP_IF_TRUE(self, instr): if is_jump: self._lasti = self.indexof(instr.jump_to) return - raise NotImplementException( + raise NotImplementFatal( "Currently don't support predicate a non-const / non-tensor obj." ) @@ -945,13 +941,11 @@ def UNPACK_SEQUENCE(self, instr): ''' if isinstance(sequence, TensorVariable): # TODO: If need to unpack a Tensor, should have different logic. - raise NotImplementException("Unpack a iterator is not implemented.") + raise NotImplementFatal("Unpack a tensor is not implemented.") elif isinstance(sequence, (ListVariable, TupleVariable)): seq = sequence.value else: - raise NotImplementException( - f"Unpack {sequence} is not implemented." - ) + raise NotImplementFatal(f"Unpack {sequence} is not implemented.") assert ( len(seq) == instr.arg @@ -1002,9 +996,7 @@ def FORMAT_VALUE(self, instr): ) ) else: - raise NotImplementException( - f"Do not support format {type(value)} now" - ) + raise NotImplementFatal(f"Do not support format {type(value)} now") # NOTE: This operation will generate SideEffects, and the mechanism has not been completed yet def DICT_UPDATE(self, instr): @@ -1360,7 +1352,7 @@ def FOR_ITER(self, instr): end = self.indexof(instr.jump_to) for i in range(start, end): if self._instructions[i].opname == "RETURN_VALUE": - raise NotImplementException( + raise NotImplementInsignificant( "Found RETURN_VALUE in for loop body." ) diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index e65a1c18d..b6a3d2963 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -6,7 +6,7 @@ import paddle -from ...utils import BreakGraphError, NotImplementException +from ...utils import BreakGraphError, InnerError, NotImplementInsignificant from ...utils.magic_methods import ( BINARY_OPS, UNARY_OPS, @@ -250,8 +250,8 @@ def tensor_mod_dispatcher( raise BreakGraphError( "(ConstantVariable % TensorVariable) raise a callback. " ) - raise NotImplementException( - "Tensor doesn't support __rmod__" + raise InnerError( + "TypeError: unsupported operand type(s) for %: 'int' and 'Tensor'" ) else: @@ -275,7 +275,7 @@ def tensor_mod_dispatcher( @Dispatcher.register_decorator(unary_fn) def numpy_unary_dispatcher(var: NumpyVariable): - raise NotImplementException( + raise NotImplementInsignificant( 'Numpy operator need fallback to dygraph' ) @@ -285,6 +285,6 @@ def numpy_unary_dispatcher(var: NumpyVariable): @Dispatcher.register_decorator(binary_fn) def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable): - raise NotImplementException( + raise NotImplementInsignificant( 'Numpy operator need fallback to dygraph' ) diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index 710a320c9..a6b0a60cd 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -7,7 +7,7 @@ import paddle from ....utils import NameGenerator, get_unbound_method, log, log_do -from ....utils.exceptions import InnerError, NotImplementException +from ....utils.exceptions import InnerError, NotImplementFatal from ..guard import StringifyExpression, union_free_vars from ..pycode_generator import PyCodeGen from ..tracker import DummyTracker, GetAttrTracker, GetItemTracker, Tracker @@ -224,7 +224,7 @@ def reconstruct(self, codegen: PyCodeGen): self._reconstruct(codegen) def _reconstruct(self, codegen: PyCodeGen): - raise NotImplementException() + raise NotImplementFatal("Not implement reconstruct.") def flatten_items(self) -> list[VariableBase]: from .container import ContainerVariable @@ -281,7 +281,7 @@ def getattr(self, name: str): ) def __setitem__(self, key, value): - raise NotImplementException(f"{self} is not support setitem.") + raise NotImplementFatal(f"{self} is not support setitem.") def __repr__(self): info = {**self.main_info, **self.debug_info} diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index a38603af4..b055b4317 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -14,7 +14,7 @@ from ....utils import ( BreakGraphError, NameGenerator, - NotImplementException, + NotImplementInsignificant, log_do, paddle_tensor_methods, ) @@ -453,7 +453,7 @@ def format_number(number: np.number): union_free_vars(frame_value_tracer.free_vars, {"np": np}), ) else: - raise NotImplementException( + raise NotImplementInsignificant( "We can not stringify numpy variable when value is np.ndarray" ) diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index 89db0fcba..80864ef03 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any from ....utils import log_do -from ....utils.exceptions import InnerError, NotImplementException +from ....utils.exceptions import InnerError, NotImplementFatal from ..guard import StringifyExpression from ..pycode_generator import PyCodeGen from ..tracker import ( @@ -24,10 +24,12 @@ class ContainerVariable(VariableBase): def get_items(self) -> list[VariableBase]: - raise NotImplementException() + raise NotImplementFatal( + "Not implement get_items for container variable." + ) def __len__(self): - raise NotImplementException() + raise NotImplementFatal("Not implement __len__ for container variable.") def len(self): return VariableFactory.from_value( @@ -401,7 +403,7 @@ def getattr(self, name): builtin_fn, self.graph, DanglingTracker() ).bind(self, name) else: - raise NotImplementException( + raise NotImplementFatal( f"attribute {name} for dict is not implemented" ) diff --git a/sot/utils/exceptions.py b/sot/utils/exceptions.py index f5d55bd6a..aeaf3ff35 100644 --- a/sot/utils/exceptions.py +++ b/sot/utils/exceptions.py @@ -10,6 +10,14 @@ class NotImplementException(FallbackErrorBase): pass +class NotImplementFatal(NotImplementException): + pass + + +class NotImplementInsignificant(NotImplementException): + pass + + # raise in inline function call strategy. class BreakGraphError(FallbackErrorBase): pass