Skip to content

Commit

Permalink
Add support for pretrained tok2vec to ud-train
Browse files Browse the repository at this point in the history
  • Loading branch information
honnibal committed Nov 29, 2018
1 parent 93be3ad commit 681258e
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions spacy/cli/ud_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,28 @@ def initialize_pipeline(nlp, docs, golds, config, device):
nlp.tagger.add_label(tag)
if torch is not None and device != -1:
torch.set_default_tensor_type('torch.cuda.FloatTensor')
return nlp.begin_training(
optimizer = nlp.begin_training(
lambda: golds_to_gold_tuples(docs, golds), device=device,
subword_features=config.subword_features, conv_depth=config.conv_depth,
bilstm_depth=config.bilstm_depth)
if config.pretrained_tok2vec:
_load_pretrained_tok2vec(nlp, config.pretrained_tok2vec)
return optimizer


def _load_pretrained_tok2vec(nlp, loc):
"""Load pre-trained weights for the 'token-to-vector' part of the component
models, which is typically a CNN. See 'spacy pretrain'. Experimental.
"""
with Path(loc).open('rb') as file_:
weights_data = file_.read()
loaded = []
for name, component in nlp.pipeline:
if hasattr(component, 'model') and hasattr(component.model, 'tok2vec'):
component.tok2vec.from_bytes(weights_data)
loaded.append(name)
return loaded



########################
Expand All @@ -318,9 +336,9 @@ def initialize_pipeline(nlp, docs, golds, config, device):
class Config(object):
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=False,
multitask_sent=False, multitask_dep=False, multitask_vectors=None,
bilstm_depth=0, nr_epoch=30, min_batch_size=100, max_batch_size=1000,
batch_by_words=True, dropout=0.2, conv_depth=4, subword_features=True,
vectors_dir=None):
bilstm_depth=0, nr_epoch=30, min_batch_size=750, max_batch_size=750,
batch_by_words=True, dropout=0.1, conv_depth=4, subword_features=True,
vectors_dir=None, pretrained_tok2vec=None):
if vectors_dir is not None:
if vectors is None:
vectors = True
Expand Down

0 comments on commit 681258e

Please sign in to comment.