From d7f7bed055b4fa7aeb3b4b376490198c6599a7b0 Mon Sep 17 00:00:00 2001
From: Jun Jia
Date: Fri, 31 Jul 2020 19:13:46 -0700
Subject: [PATCH] Update smart-arg to 0.1.1

* Update smart-arg to 0.1.1 (relevant changes:
  1. the argparse override fields now start with a double underscore: private,
     to discourage such usage and to avoid conflicts with built-in attributes
     such as _asdict;
  2. use __post_init__ instead of __late_init__, to align with dataclasses and
     to support in-place mutation within __post_init__ at initialization time;
  3. removed the namespace for simplicity);
* Move `if hparams.ftr_ext != 'cnn': hparams.filter_window_sizes = [0]` to
  `__post_init__`
---
 setup.py                       |  2 +-
 src/detext/run_detext.py       | 63 +++++++++++++++++-----------------
 src/detext/utils/misc_utils.py |  4 ---
 3 files changed, 33 insertions(+), 36 deletions(-)

diff --git a/setup.py b/setup.py
index a2945cb..dbff6b7 100644
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,7 @@
     include_package_data=True,
     install_requires=[
         'numpy<1.17',
-        'smart-arg==0.0.5',
+        'smart-arg==0.1.1',
         'tensorflow==1.14.0',
         'tensorflow_ranking==0.1.4',
         'gast==0.2.2'
diff --git a/src/detext/run_detext.py b/src/detext/run_detext.py
index c836ac8..df691fa 100644
--- a/src/detext/run_detext.py
+++ b/src/detext/run_detext.py
@@ -14,7 +14,7 @@
 from detext.train.data_fn import input_fn
 from detext.utils import misc_utils, logger
 from detext.utils.executor_utils import EVALUATOR, LOCAL_MODE, CHIEF, get_executor_task_type
-from linkedin.smart_arg import arg_suite
+from smart_arg import arg_suite
 from tensorflow.contrib.training import HParams
@@ -29,13 +29,14 @@ class DetextArg(NamedTuple):
     """
     # kwargs utility function for split a str to a List, provided for backward compatibilities
     # Please use built-in List parsing support when adding new arguments
-    __comma_split_list = lambda t: {'type': lambda s: [t(item) for item in s.split(',')] if ',' in s else t(s), 'nargs': None}  # noqa: E731
+    _comma_split_list = lambda t: {'type': lambda s: [t(item) for item in s.split(',')] if ',' in s else t(s), 'nargs': None}  # noqa: E731

     feature_names: List[str]  # Feature names.
+    __feature_names = _comma_split_list(str)

     # network
     ftr_ext: str  # NLP feature extraction module.
-    _ftr_ext = {'choices': ['cnn', 'bert', 'lstm_lm', 'lstm']}
+    __ftr_ext = {'choices': ['cnn', 'bert', 'lstm_lm', 'lstm']}
     num_units: int = 128  # word embedding size.
     num_units_for_id_ftr: int = 128  # id feature embedding size.
     sp_emb_size: int = 1  # Embedding size of sparse features
@@ -43,7 +44,7 @@ class DetextArg(NamedTuple):
     num_wide: int = 0  # number of wide features per doc.
     num_wide_sp: Optional[int] = None  # number of sparse wide features per doc
     use_deep: bool = True  # Whether to use deep features.
-    _use_deep = {'type': lambda s: True if s.lower() == 'true' else False if s.lower() == 'false' else s}
+    __use_deep = {'type': lambda s: True if s.lower() == 'true' else False if s.lower() == 'false' else s}
     elem_rescale: bool = True  # Whether to perform elementwise rescaling.
     use_doc_projection: bool = False  # whether to project multiple doc features to 1 vector.
     use_usr_projection: bool = False  # whether to project multiple usr features to 1 vector.
@@ -51,7 +52,7 @@ class DetextArg(NamedTuple):
     # Ranking specific
     ltr_loss_fn: str = 'pairwise'  # learning-to-rank method.
     emb_sim_func: List[str] = ['inner']  # Approach to computing query/doc similarity scores
-    _emb_sim_func = {"choices": ('inner', 'hadamard', 'concat', 'diff')}
+    __emb_sim_func = {"choices": ('inner', 'hadamard', 'concat', 'diff')}

     # Classification specific
     num_classes: int = 1  # Number of classes for multi-class classification tasks.
@@ -69,7 +70,7 @@ class DetextArg(NamedTuple):
     # LSTM related
     unit_type: str = 'lstm'  # RNN cell unit type. Currently only supports lstm. Will support other cell types in the future
-    _unit_type = {'choices': ['lstm']}
+    __unit_type = {'choices': ['lstm']}
     num_layers: int = 1  # RNN layers
     num_residual_layers: int = 0  # Number of residual layers from top to bottom. For example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN cells in the returned list will be wrapped with `ResidualWrapper`.  # noqa: E501
     forget_bias: float = 1.  # Forget bias of RNN cell
@@ -79,7 +80,7 @@ class DetextArg(NamedTuple):
     # Optimizer
     optimizer: str = 'sgd'  # Type of optimizer to use. bert_adam is similar to the optimizer implementation in bert.
-    _optimizer = {'choices': ['sgd', 'adam', 'bert_adam', 'bert_lamb']}
+    __optimizer = {'choices': ['sgd', 'adam', 'bert_adam', 'bert_lamb']}
     max_gradient_norm: float = 1.0  # Clip gradients to this norm.
     learning_rate: float = 1.0  # Learning rate. Adam: 0.001 | 0.0001
     num_train_steps: int = 1  # Num steps to train.
@@ -122,7 +123,6 @@ class DetextArg(NamedTuple):
     num_eval_rounds: Optional[int] = None  # number of evaluation round, this param will override steps_per_eval as max(1, num_train_steps / num_eval_rounds)
     steps_per_eval: int = 1000  # training steps to evaluate datasets.
     keep_checkpoint_max: int = 5  # The maximum number of recent checkpoint files to keep. If 0, all checkpoint files are kept. Defaults to 5
-    _feature_names = __comma_split_list(str)
     lambda_metric: Optional[str] = None  # only support ndcg.
     init_weight: float = 0.1  # weight initialization value.
     pmetric: Optional[str] = None  # Primary metric.
@@ -133,7 +133,7 @@ class DetextArg(NamedTuple):
     add_first_dim_for_usr_placeholder: bool = False  # Whether to add a batch dimension for query and usr_* placeholders. This shall be set to True if usr fields are used document feature in model serving.# noqa: E501

     tokenization: str = 'punct'  # The tokenzation performed for data preprocessing. Currently support: punct/plain(no split). Note that this should be set correctly to ensure consistency for savedmodel.# noqa: E501
-    _tokenization = {'choices': ['plain', 'punct']}
+    __tokenization = {'choices': ['plain', 'punct']}
     resume_training: bool = False  # Whether to resume training from checkpoint in out_dir.
     metadata_path: Optional[str] = None  # The metadata_path for converted avro2tf avro data.

@@ -141,45 +141,46 @@ class DetextArg(NamedTuple):
     # tf-ranking related
     use_tfr_loss: bool = False  # whether to use tf-ranking loss.
     tfr_loss_fn: str = tfr.losses.RankingLossKey.SOFTMAX_LOSS  # tf-ranking loss
-    _tfr_loss_fn = {'choices': [tfr.losses.RankingLossKey.SOFTMAX_LOSS, tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS]}
+    __tfr_loss_fn = {'choices': [tfr.losses.RankingLossKey.SOFTMAX_LOSS, tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS]}
     tfr_lambda_weights: Optional[str] = None  #

     use_horovod: bool = False  # whether to use horovod for sync distributed training
     hvd_info: Optional[Dict[str, int]] = None

     # multitask training related
     task_ids: Optional[List[int]] = None  # All types of task IDs for multitask training. E.g., 1,2,3
-    _task_ids = __comma_split_list(int)
+    __task_ids = _comma_split_list(int)
     task_weights: Optional[List[float]] = None  # Weights for each task specified in task_ids. E.g., 0.5,0.3,0.2
-    _task_weights = __comma_split_list(float)
+    __task_weights = _comma_split_list(float)

     # This method is automatically called by smart-arg once the argument is created by parsing cli or the constructor
     # It's used to late-initialize some fields after other fields are created.
-    def __late_init__(self):
-        arg = self
+    def __post_init__(self):
+        logging.info(f"Start __post_init__ the argument now: {self}")
+
+        # if not using cnn models, then disable cnn parameters
+        if self.ftr_ext != 'cnn':
+            self.filter_window_sizes = [0]

         # if epoch is set, overwrite training steps
-        if arg.num_epochs is not None:
-            arg = arg._replace(num_train_steps=misc_utils.estimate_train_steps(
-                arg.train_file,
-                arg.num_epochs,
-                arg.train_batch_size,
-                arg.metadata_path is None))
+        if self.num_epochs is not None:
+            self.num_train_steps = misc_utils.estimate_train_steps(
+                self.train_file,
+                self.num_epochs,
+                self.train_batch_size,
+                self.metadata_path is None)

         # if num_eval_rounds is set, override steps_per_eval
-        if arg.num_eval_rounds is not None:
-            arg = arg._replace(steps_per_eval=max(1, int(arg.num_train_steps / arg.num_eval_rounds)))
+        if self.num_eval_rounds is not None:
+            self.steps_per_eval = max(1, int(self.num_train_steps / self.num_eval_rounds))

         # For data sharding when using horovod
-        if arg.use_horovod:
+        if self.use_horovod:
             import horovod.tensorflow as hvd
             hvd.init()
-            arg = arg._replace(
-                num_train_steps=arg.num_train_steps // hvd.size(),
-                num_warmup_steps=arg.num_warmup_steps // hvd.size(),
-                steps_per_eval=arg.steps_per_eval // hvd.size(),
-                steps_per_stats=arg.steps_per_stats // hvd.size(),
-                hvd_info={'rank': hvd.rank(), 'size': hvd.size()})
-
-        return arg
+            self.num_train_steps = self.num_train_steps // hvd.size()
+            self.num_warmup_steps = self.num_warmup_steps // hvd.size()
+            self.steps_per_eval = self.steps_per_eval // hvd.size()
+            self.steps_per_stats = self.steps_per_stats // hvd.size()
+            self.hvd_info = {'rank': hvd.rank(), 'size': hvd.size()}


 def get_hparams(argument: DetextArg = None):
diff --git a/src/detext/utils/misc_utils.py b/src/detext/utils/misc_utils.py
index 0b914d0..b4f30e7 100644
--- a/src/detext/utils/misc_utils.py
+++ b/src/detext/utils/misc_utils.py
@@ -99,10 +99,6 @@ def extend_hparams(hparams):
     tok2regex_pattern = {'plain': None, 'punct': r'(\pP)'}
     hparams.regex_replace_pattern = tok2regex_pattern[hparams.tokenization]

-    # if not using cnn models, then disable cnn parameters
-    if hparams.ftr_ext != 'cnn':
-        hparams.filter_window_sizes = [0]
-
     assert hparams.pmetric is not None, "Please set your primary evaluation metric using --pmetric option"
     assert hparams.pmetric != 'confusion_matrix', 'confusion_matrix cannot be used as primary evaluation metric.'
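
Note for reviewers: a minimal sketch of the two smart-arg 0.1.1 conventions this patch adopts. `DemoArg` and its fields are hypothetical, for illustration only; the sketch assumes smart-arg behaves as the commit message describes (double-underscore argparse overrides; `__post_init__` run on construction or CLI parsing, with in-place mutation allowed at initialization time).

```python
from typing import List, NamedTuple

from smart_arg import arg_suite  # smart-arg==0.1.1, as pinned in setup.py


@arg_suite
class DemoArg(NamedTuple):  # hypothetical demo class, not part of DeText
    """A minimal argument suite."""
    ftr_ext: str  # NLP feature extraction module
    # A double-underscore class attribute overrides the argparse kwargs of the
    # matching public field; name mangling keeps it private and away from
    # built-in attributes such as _asdict (change 1 in the commit message).
    __ftr_ext = {'choices': ['cnn', 'bert', 'lstm_lm', 'lstm']}
    filter_window_sizes: List[int] = [3]

    def __post_init__(self):
        # Runs once after CLI parsing or direct construction; 0.1.1 supports
        # in-place mutation here, so no `arg._replace(...)` dance is needed
        # (change 2 in the commit message).
        if self.ftr_ext != 'cnn':
            self.filter_window_sizes = [0]


# Direct construction triggers __post_init__ the same way CLI parsing does,
# which is why the cnn check could move out of extend_hparams:
arg = DemoArg(ftr_ext='bert')
assert arg.filter_window_sizes == [0]
```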
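The `_comma_split_list` helper kept above exists only for backward compatibility with the old comma-separated CLI syntax; the kwargs dict it returns is exactly what a double-underscore override hands to argparse. A self-contained sketch with plain argparse (names here are illustrative, not DeText code):

```python
import argparse

# Same shape as _comma_split_list(int) in the patch: 'type' splits "1,2,3"
# into [1, 2, 3] (or returns a bare int when there is no comma), and
# 'nargs': None makes argparse consume a single token rather than a
# space-separated list.
comma_split_int = {
    'type': lambda s: [int(item) for item in s.split(',')] if ',' in s else int(s),
    'nargs': None,
}

parser = argparse.ArgumentParser()
parser.add_argument('--task_ids', **comma_split_int)

print(parser.parse_args(['--task_ids', '1,2,3']).task_ids)  # [1, 2, 3]
```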