Update smart-arg to 0.1.1
* Update smart-arg to 0.1.1. Relevant changes:
  1. Argparser override fields now start with a double underscore: they are private to discourage direct use and to avoid conflicts with built-in attributes such as _asdict.
  2. __post_init__ replaces __late_init__, aligning with dataclasses and supporting in-place mutation at initialization time.
  3. The namespace feature was removed for simplicity.
* Move `if hparams.ftr_ext != 'cnn': hparams.filter_window_sizes = [0]` from `extend_hparams` to `__post_init__`.
jakiejj authored and StarWang committed Aug 22, 2020
1 parent ec44761 commit d7f7bed
Showing 3 changed files with 33 additions and 36 deletions.
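
For context, here is a minimal sketch of the 0.1.1 conventions described in the commit message (the class and default values are illustrative, not from DeText; the pattern mirrors the DetextArg diff below):

    from typing import List, NamedTuple
    from smart_arg import arg_suite

    @arg_suite
    class MyArg(NamedTuple):
        """An illustrative smart-arg argument suite under the 0.1.1 conventions."""
        ftr_ext: str = 'cnn'  # a regular argument field
        # argparse override now uses a double underscore (was `_ftr_ext`)
        __ftr_ext = {'choices': ['cnn', 'lstm']}
        filter_window_sizes: List[int] = [3]

        # replaces __late_init__; fields may be mutated in place here
        def __post_init__(self):
            if self.ftr_ext != 'cnn':
                self.filter_window_sizes = [0]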
setup.py (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@
     include_package_data=True,
     install_requires=[
         'numpy<1.17',
-        'smart-arg==0.0.5',
+        'smart-arg==0.1.1',
         'tensorflow==1.14.0',
         'tensorflow_ranking==0.1.4',
         'gast==0.2.2'
src/detext/run_detext.py (63 changes: 32 additions & 31 deletions)
@@ -14,7 +14,7 @@
 from detext.train.data_fn import input_fn
 from detext.utils import misc_utils, logger
 from detext.utils.executor_utils import EVALUATOR, LOCAL_MODE, CHIEF, get_executor_task_type
-from linkedin.smart_arg import arg_suite
+from smart_arg import arg_suite
 from tensorflow.contrib.training import HParams


@@ -29,29 +29,30 @@ class DetextArg(NamedTuple):
     """
     # kwargs utility function for split a str to a List, provided for backward compatibilities
     # Please use built-in List parsing support when adding new arguments
-    __comma_split_list = lambda t: {'type': lambda s: [t(item) for item in s.split(',')] if ',' in s else t(s), 'nargs': None}  # noqa: E731
+    _comma_split_list = lambda t: {'type': lambda s: [t(item) for item in s.split(',')] if ',' in s else t(s), 'nargs': None}  # noqa: E731

     feature_names: List[str]  # Feature names.
+    __feature_names = _comma_split_list(str)

     # network
     ftr_ext: str  # NLP feature extraction module.
-    _ftr_ext = {'choices': ['cnn', 'bert', 'lstm_lm', 'lstm']}
+    __ftr_ext = {'choices': ['cnn', 'bert', 'lstm_lm', 'lstm']}
     num_units: int = 128  # word embedding size.
     num_units_for_id_ftr: int = 128  # id feature embedding size.
     sp_emb_size: int = 1  # Embedding size of sparse features
     num_hidden: List[int] = [0]  # hidden size.
     num_wide: int = 0  # number of wide features per doc.
     num_wide_sp: Optional[int] = None  # number of sparse wide features per doc
     use_deep: bool = True  # Whether to use deep features.
-    _use_deep = {'type': lambda s: True if s.lower() == 'true' else False if s.lower() == 'false' else s}
+    __use_deep = {'type': lambda s: True if s.lower() == 'true' else False if s.lower() == 'false' else s}
     elem_rescale: bool = True  # Whether to perform elementwise rescaling.
     use_doc_projection: bool = False  # whether to project multiple doc features to 1 vector.
     use_usr_projection: bool = False  # whether to project multiple usr features to 1 vector.

     # Ranking specific
     ltr_loss_fn: str = 'pairwise'  # learning-to-rank method.
     emb_sim_func: List[str] = ['inner']  # Approach to computing query/doc similarity scores
-    _emb_sim_func = {"choices": ('inner', 'hadamard', 'concat', 'diff')}
+    __emb_sim_func = {"choices": ('inner', 'hadamard', 'concat', 'diff')}

     # Classification specific
     num_classes: int = 1  # Number of classes for multi-class classification tasks.
@@ -69,7 +70,7 @@ class DetextArg(NamedTuple):

     # LSTM related
     unit_type: str = 'lstm'  # RNN cell unit type. Currently only supports lstm. Will support other cell types in the future
-    _unit_type = {'choices': ['lstm']}
+    __unit_type = {'choices': ['lstm']}
     num_layers: int = 1  # RNN layers
     num_residual_layers: int = 0  # Number of residual layers from top to bottom. For example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN cells in the returned list will be wrapped with `ResidualWrapper`.  # noqa: E501
     forget_bias: float = 1.  # Forget bias of RNN cell
@@ -79,7 +80,7 @@ class DetextArg(NamedTuple):

     # Optimizer
     optimizer: str = 'sgd'  # Type of optimizer to use. bert_adam is similar to the optimizer implementation in bert.
-    _optimizer = {'choices': ['sgd', 'adam', 'bert_adam', 'bert_lamb']}
+    __optimizer = {'choices': ['sgd', 'adam', 'bert_adam', 'bert_lamb']}
     max_gradient_norm: float = 1.0  # Clip gradients to this norm.
     learning_rate: float = 1.0  # Learning rate. Adam: 0.001 | 0.0001
     num_train_steps: int = 1  # Num steps to train.
@@ -122,7 +123,6 @@ class DetextArg(NamedTuple):
     num_eval_rounds: Optional[int] = None  # number of evaluation round, this param will override steps_per_eval as max(1, num_train_steps / num_eval_rounds)
     steps_per_eval: int = 1000  # training steps to evaluate datasets.
     keep_checkpoint_max: int = 5  # The maximum number of recent checkpoint files to keep. If 0, all checkpoint files are kept. Defaults to 5
-    _feature_names = __comma_split_list(str)
     lambda_metric: Optional[str] = None  # only support ndcg.
     init_weight: float = 0.1  # weight initialization value.
     pmetric: Optional[str] = None  # Primary metric.
@@ -133,53 +133,54 @@ class DetextArg(NamedTuple):
     add_first_dim_for_usr_placeholder: bool = False  # Whether to add a batch dimension for query and usr_* placeholders. This shall be set to True if usr fields are used document feature in model serving.  # noqa: E501

     tokenization: str = 'punct'  # The tokenzation performed for data preprocessing. Currently support: punct/plain(no split). Note that this should be set correctly to ensure consistency for savedmodel.  # noqa: E501
-    _tokenization = {'choices': ['plain', 'punct']}
+    __tokenization = {'choices': ['plain', 'punct']}

     resume_training: bool = False  # Whether to resume training from checkpoint in out_dir.
     metadata_path: Optional[str] = None  # The metadata_path for converted avro2tf avro data.

     # tf-ranking related
     use_tfr_loss: bool = False  # whether to use tf-ranking loss.
     tfr_loss_fn: str = tfr.losses.RankingLossKey.SOFTMAX_LOSS  # tf-ranking loss
-    _tfr_loss_fn = {'choices': [tfr.losses.RankingLossKey.SOFTMAX_LOSS, tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS]}
+    __tfr_loss_fn = {'choices': [tfr.losses.RankingLossKey.SOFTMAX_LOSS, tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS]}
     tfr_lambda_weights: Optional[str] = None  #

     use_horovod: bool = False  # whether to use horovod for sync distributed training
     hvd_info: Optional[Dict[str, int]] = None
     # multitask training related
     task_ids: Optional[List[int]] = None  # All types of task IDs for multitask training. E.g., 1,2,3
-    _task_ids = __comma_split_list(int)
+    __task_ids = _comma_split_list(int)
     task_weights: Optional[List[float]] = None  # Weights for each task specified in task_ids. E.g., 0.5,0.3,0.2
-    _task_weights = __comma_split_list(float)
+    __task_weights = _comma_split_list(float)

     # This method is automatically called by smart-arg once the argument is created by parsing cli or the constructor
     # It's used to late-initialize some fields after other fields are created.
-    def __late_init__(self):
-        arg = self
+    def __post_init__(self):
+        logging.info(f"Start __post_init__ the argument now: {self}")

+        # if not using cnn models, then disable cnn parameters
+        if self.ftr_ext != 'cnn':
+            self.filter_window_sizes = [0]
         # if epoch is set, overwrite training steps
-        if arg.num_epochs is not None:
-            arg = arg._replace(num_train_steps=misc_utils.estimate_train_steps(
-                arg.train_file,
-                arg.num_epochs,
-                arg.train_batch_size,
-                arg.metadata_path is None))
+        if self.num_epochs is not None:
+            self.num_train_steps = misc_utils.estimate_train_steps(
+                self.train_file,
+                self.num_epochs,
+                self.train_batch_size,
+                self.metadata_path is None)

         # if num_eval_rounds is set, override steps_per_eval
-        if arg.num_eval_rounds is not None:
-            arg = arg._replace(steps_per_eval=max(1, int(arg.num_train_steps / arg.num_eval_rounds)))
+        if self.num_eval_rounds is not None:
+            self.steps_per_eval = max(1, int(self.num_train_steps / self.num_eval_rounds))

         # For data sharding when using horovod
-        if arg.use_horovod:
+        if self.use_horovod:
             import horovod.tensorflow as hvd
             hvd.init()
-            arg = arg._replace(
-                num_train_steps=arg.num_train_steps // hvd.size(),
-                num_warmup_steps=arg.num_warmup_steps // hvd.size(),
-                steps_per_eval=arg.steps_per_eval // hvd.size(),
-                steps_per_stats=arg.steps_per_stats // hvd.size(),
-                hvd_info={'rank': hvd.rank(), 'size': hvd.size()})
-
-        return arg
+            self.num_train_steps = self.num_train_steps // hvd.size()
+            self.num_warmup_steps = self.num_warmup_steps // hvd.size()
+            self.steps_per_eval = self.steps_per_eval // hvd.size()
+            self.steps_per_stats = self.steps_per_stats // hvd.size()
+            self.hvd_info = {'rank': hvd.rank(), 'size': hvd.size()}


 def get_hparams(argument: DetextArg = None):
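As an aside, the `_comma_split_list` helper above simply builds argparse kwargs whose `type` callable splits a comma-separated string into a typed list. A self-contained sketch with plain argparse (smart-arg not required; the `--task_ids` flag mirrors the field above) shows the backward-compatible behavior, including the quirk that a single value comes back bare rather than as a one-element list:

    import argparse

    # Same helper as in DetextArg: returns argparse kwargs whose `type`
    # splits 'a,b,c' into a typed list but leaves a single value bare.
    _comma_split_list = lambda t: {'type': lambda s: [t(item) for item in s.split(',')] if ',' in s else t(s), 'nargs': None}  # noqa: E731

    parser = argparse.ArgumentParser()
    parser.add_argument('--task_ids', **_comma_split_list(int))

    print(parser.parse_args(['--task_ids', '1,2,3']).task_ids)  # -> [1, 2, 3]
    print(parser.parse_args(['--task_ids', '7']).task_ids)      # -> 7 (bare int, not [7])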
src/detext/utils/misc_utils.py (4 changes: 0 additions & 4 deletions)
@@ -99,10 +99,6 @@ def extend_hparams(hparams):
     tok2regex_pattern = {'plain': None, 'punct': r'(\pP)'}
     hparams.regex_replace_pattern = tok2regex_pattern[hparams.tokenization]

-    # if not using cnn models, then disable cnn parameters
-    if hparams.ftr_ext != 'cnn':
-        hparams.filter_window_sizes = [0]
-
     assert hparams.pmetric is not None, "Please set your primary evaluation metric using --pmetric option"
     assert hparams.pmetric != 'confusion_matrix', 'confusion_matrix cannot be used as primary evaluation metric.'
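With this removal, the CNN guard runs when the argument object is constructed rather than in `extend_hparams`. A hypothetical check of the new behavior (field values here are illustrative; per the comment in the diff, smart-arg invokes `__post_init__` on construction as well as on CLI parsing):

    # Illustrative only: construct DetextArg directly; __post_init__ runs
    # at construction time, so the CNN guard fires immediately.
    args = DetextArg(feature_names=['query', 'doc_title'], ftr_ext='lstm')
    assert args.filter_window_sizes == [0]  # disabled because ftr_ext != 'cnn'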
