szha commented on a change in pull request #10074: Add vocabulary and embedding
URL: https://github.com/apache/incubator-mxnet/pull/10074#discussion_r174643108
 
 

 ##########
 File path: python/mxnet/gluon/text/embedding.py
 ##########
 @@ -0,0 +1,582 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=consider-iterating-dictionary
+
+"""Text token embedding."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+import io
+import logging
+import os
+import tarfile
+import warnings
+import zipfile
+
+from . import _constants as C
+from ... import nd
+from ... import registry
+from ..utils import check_sha1, download, _get_repo_file_url
+
+
+def register(embedding_cls):
+    """Registers a new token embedding.
+
+
+    Once an embedding is registered, we can create an instance of this embedding with
+    :func:`~mxnet.gluon.text.embedding.create`.
+
+
+    Examples
+    --------
+    >>> @mxnet.gluon.text.embedding.register
+    ... class MyTokenEmbed(mxnet.gluon.text.embedding.TokenEmbedding):
+    ...     def __init__(self, file_name='my_pretrain_file'):
+    ...         pass
+    >>> embed = mxnet.gluon.text.embedding.create('MyTokenEmbed')
+    >>> print(type(embed))
+    <class '__main__.MyTokenEmbed'>
+    """
+
+    register_text_embedding = registry.get_register_func(TokenEmbedding, 'token embedding')
+    return register_text_embedding(embedding_cls)
+
+
+def create(embedding_name, **kwargs):
+    """Creates an instance of token embedding.
+
+
+    Creates a token embedding instance by loading embedding vectors from an externally hosted
+    pre-trained token embedding file, such as those of GloVe and FastText. To get all the valid
+    `embedding_name` and `file_name`, use `mxnet.gluon.text.embedding.get_file_names()`.
+
+
+    Parameters
+    ----------
+    embedding_name : str
+        The token embedding name (case-insensitive).
+
+
+    Returns
+    -------
+    An instance of `mxnet.gluon.text.embedding.TokenEmbedding`:
+        A token embedding instance that loads embedding vectors from an externally hosted
+        pre-trained token embedding file.
+    """
+
+    create_text_embedding = registry.get_create_func(TokenEmbedding, 'token embedding')
+    return create_text_embedding(embedding_name, **kwargs)
+
+
+def get_file_names(embedding_name=None):
+    """Get valid token embedding names and their pre-trained file names.
+
+
+    To load token embedding vectors from an externally hosted pre-trained token embedding file,
+    such as those of GloVe and FastText, one should use
+    `mxnet.gluon.text.embedding.create(embedding_name, file_name)`. This method returns all the
+    valid names of `file_name` for the specified `embedding_name`. If `embedding_name` is set to
+    None, this method returns all the valid names of `embedding_name` with their associated
+    `file_name`.
+
+
+    Parameters
+    ----------
+    embedding_name : str or None, default None
+        The pre-trained token embedding name.
+
+
+    Returns
+    -------
+    dict or list:
+        A list of all the valid pre-trained token embedding file names (`file_name`) for the
+        specified token embedding name (`embedding_name`). If the token embedding name is set to None,
+        returns a dict mapping each valid token embedding name to a list of valid pre-trained files
+        (`file_name`). They can be plugged into
+        `mxnet.gluon.text.embedding.create(embedding_name, file_name)`.
+    """
+
+    text_embedding_reg = registry.get_registry(TokenEmbedding)
+
+    if embedding_name is not None:
+        if embedding_name not in text_embedding_reg:
+            raise KeyError('Cannot find `embedding_name` %s. Use '
+                           '`get_file_names(embedding_name=None).keys()` to get all the valid '
+                           'embedding names.' % embedding_name)
+        return list(text_embedding_reg[embedding_name].pretrained_file_name_sha1.keys())
+    else:
+        return {embedding_name: list(embedding_cls.pretrained_file_name_sha1.keys())
+                for embedding_name, embedding_cls in registry.get_registry(TokenEmbedding).items()}
+
+
+class TokenEmbedding(object):
+    """Token embedding base class.
+
+
+    To load token embedding from an externally hosted pre-trained token embedding file, such as
+    those of GloVe and FastText, use :func:`~mxnet.gluon.text.embedding.create(embedding_name,
+    file_name)`. To get all the available `embedding_name` and `file_name`, use
+    :func:`~mxnet.gluon.text.embedding.get_file_names()`.
+
+    Alternatively, to load embedding vectors from a custom pre-trained token embedding file, use
+    :func:`~mxnet.gluon.text.embedding.from_file()`.
+
+    For every unknown token, if its representation `self.unknown_token` is encountered in the
+    pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token
+    embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the
+    token embedding vector initialized by `init_unknown_vec`.
+
+    If a token is encountered multiple times in the pre-trained token embedding file, only the
+    first-encountered token embedding vector will be loaded and the rest will be skipped.
+
+
+    Parameters
+    ----------
+    unknown_token : hashable object, default '<unk>'
+        The representation for any unknown token. In other words, any unknown token will be indexed
+        as the same representation.
+
+
+    Properties
+    ----------
+    idx_to_vec : mxnet.ndarray.NDArray
+        For all the indexed tokens in this embedding, this NDArray maps each token's index to an
+        embedding vector.
+    unknown_token : hashable object
+        The representation for any unknown token. In other words, any unknown token will be indexed
+        as the same representation.
+    """
+
+    def __init__(self, unknown_token='<unk>'):
+        self._unknown_token = unknown_token
+        self._idx_to_token = [unknown_token]
+        self._token_to_idx = {token: idx for idx, token in enumerate(self._idx_to_token)}
+        self._idx_to_vec = None
+
+    @classmethod
+    def _get_download_file_name(cls, file_name):
+        return file_name
+
+    @classmethod
+    def _get_pretrained_file_url(cls, pretrained_file_name):
+        cls_name = cls.__name__.lower()
+
+        namespace = 'gluon/embeddings/{}'.format(cls_name)
+        return _get_repo_file_url(namespace, cls._get_download_file_name(pretrained_file_name))
+
+    @classmethod
+    def _get_pretrained_file(cls, embedding_root, pretrained_file_name):
+        cls_name = cls.__name__.lower()
+        embedding_root = os.path.expanduser(embedding_root)
+        url = cls._get_pretrained_file_url(pretrained_file_name)
+
+        embedding_dir = os.path.join(embedding_root, cls_name)
+        pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)
+        downloaded_file = os.path.basename(url)
+        downloaded_file_path = os.path.join(embedding_dir, downloaded_file)
+
+        expected_file_hash = cls.pretrained_file_name_sha1[pretrained_file_name]
+
+        if hasattr(cls, 'pretrained_archive_name_sha1'):
+            expected_downloaded_hash = \
+                cls.pretrained_archive_name_sha1[downloaded_file]
+        else:
+            expected_downloaded_hash = expected_file_hash
+
+        if not os.path.exists(pretrained_file_path) \
+           or not check_sha1(pretrained_file_path, expected_file_hash):
+            download(url, downloaded_file_path, sha1_hash=expected_downloaded_hash)
+
+            ext = os.path.splitext(downloaded_file)[1]
+            if ext == '.zip':
+                with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
+                    zf.extractall(embedding_dir)
+            elif ext == '.gz':
+                with tarfile.open(downloaded_file_path, 'r:gz') as tar:
+                    tar.extractall(path=embedding_dir)
+        return pretrained_file_path
+
+    def _load_embedding(self, pretrained_file_path, elem_delim, init_unknown_vec, encoding='utf8'):
+        """Load embedding vectors from a pre-trained token embedding file.
+
+
+        For every unknown token, if its representation `self.unknown_token` is encountered in the
+        pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token
+        embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the
+        token embedding vector initialized by `init_unknown_vec`.
+
+        If a token is encountered multiple times in the pre-trained token embedding file, only the
+        first-encountered token embedding vector will be loaded and the rest will be skipped.
+        """
+
+        pretrained_file_path = os.path.expanduser(pretrained_file_path)
+
+        if not os.path.isfile(pretrained_file_path):
+            raise ValueError('`pretrained_file_path` must be a valid path to the pre-trained '
+                             'token embedding file.')
+
+        logging.info('Loading pre-trained token embedding vectors from %s', pretrained_file_path)
+        vec_len = None
+        all_elems = []
+        tokens = set()
+        loaded_unknown_vec = None
+        line_num = 0
+        with io.open(pretrained_file_path, 'r', encoding=encoding) as f:
+            for line in f:
+                line_num += 1
+                elems = line.rstrip().split(elem_delim)
+
+                assert len(elems) > 1, 'At line %d of the pre-trained token embedding file: the ' \
+                                       'data format of the pre-trained token embedding file %s ' \
+                                       'is unexpected.' % (line_num, pretrained_file_path)
+
+                token, elems = elems[0], [float(i) for i in elems[1:]]
+
+                if token == self.unknown_token and loaded_unknown_vec is None:
+                    loaded_unknown_vec = elems
+                    tokens.add(self.unknown_token)
+                elif token in tokens:
+                    warnings.warn('At line %d of the pre-trained token embedding file: the '
+                                  'embedding vector for token %s has been loaded and a duplicate '
+                                  'embedding for the same token is seen and skipped.' %
+                                  (line_num, token))
+                elif len(elems) == 1:
+                    warnings.warn('At line %d of the pre-trained token embedding file: token %s '
+                                  'with 1-dimensional vector %s is likely a header and is '
+                                  'skipped.' % (line_num, token, elems))
+                else:
+                    if vec_len is None:
+                        vec_len = len(elems)
+                        # Reserve a vector slot for the unknown token at the very beginning because
+                        # the unknown token index is 0.
+                        all_elems.extend([0] * vec_len)
+                    else:
+                        assert len(elems) == vec_len, \
+                            'At line %d of the pre-trained token embedding file: the dimension ' \
+                            'of token %s is %d but the dimension of previous tokens is %d. ' \
+                            'Dimensions of all the tokens must be the same.' \
+                            % (line_num, token, len(elems), vec_len)
+                    all_elems.extend(elems)
+                    self._idx_to_token.append(token)
+                    self._token_to_idx[token] = len(self._idx_to_token) - 1
+                    tokens.add(token)
+
+        self._idx_to_vec = nd.array(all_elems).reshape((-1, vec_len))
+
+        if loaded_unknown_vec is None:
+            self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(shape=vec_len)
+        else:
+            self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
+
+    @property
+    def idx_to_vec(self):
+        return self._idx_to_vec
+
+    @property
+    def unknown_token(self):
+        return self._unknown_token
+
+    def __contains__(self, x):
+        return x in self._token_to_idx
+
+    def __getitem__(self, tokens):
+        """Looks up embedding vectors of text tokens.
+
+
+        Parameters
+        ----------
+        tokens : str or list of strs
+            A token or a list of tokens.
+
+
+        Returns
+        -------
+        mxnet.ndarray.NDArray:
+            The embedding vector(s) of the token(s). According to numpy conventions, if `tokens` is
+            a string, returns a 1-D NDArray (vector); if `tokens` is a list of
+            strings, returns a 2-D NDArray (matrix) of shape=(len(tokens), vec_len).
+        """
+
+        to_reduce = not isinstance(tokens, (list, tuple))
+        if to_reduce:
+            tokens = [tokens]
+
+        indices = [self._token_to_idx.get(token, C.UNKNOWN_IDX) for token in tokens]
+
+        vecs = nd.Embedding(nd.array(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
+                            self.idx_to_vec.shape[1])
+
+        return vecs[0] if to_reduce else vecs
+
+    def __setitem__(self, tokens, new_vectors):
+        """Updates embedding vectors for tokens.
+
+
+        Parameters
+        ----------
+        tokens : str or a list of strs
+            A token or a list of tokens whose embedding vectors are to be updated.
+        new_vectors : mxnet.ndarray.NDArray
+            An NDArray to be assigned to the embedding vectors of `tokens`. Its length must be equal
+            to the number of `tokens` and its width must be equal to the dimension of this
+            embedding. If `tokens` is a singleton, it must be 1-D or 2-D. If `tokens` is a list
+            of multiple strings, it must be 2-D.
+        """
+
+        assert self._idx_to_vec is not None, '`idx_to_vec` has not been initialized.'
+
+        if not isinstance(tokens, list) or len(tokens) == 1:
+            assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) in [1, 2], \
+                '`new_vectors` must be a 1-D or 2-D NDArray if `tokens` is a singleton.'
+            if not isinstance(tokens, list):
+                tokens = [tokens]
+            if len(new_vectors.shape) == 1:
+                new_vectors = new_vectors.expand_dims(0)
+
+        else:
+            assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) == 2, \
+                '`new_vectors` must be a 2-D NDArray if `tokens` is a list of multiple strings.'
+        assert new_vectors.shape == (len(tokens), self._idx_to_vec.shape[1]), \
+            'The length of new_vectors must be equal to the number of tokens and the width of ' \
+            'new_vectors must be equal to the dimension of this embedding.'
+
+        indices = []
+        for token in tokens:
+            if token in self._token_to_idx:
+                indices.append(self._token_to_idx[token])
+            else:
+                raise ValueError('Token %s is unknown. To update the embedding vector for an '
+                                 'unknown token, please specify it explicitly as the '
+                                 '`unknown_token` %s in `tokens`. This is to avoid unintended '
+                                 'updates.' % (token, self._idx_to_token[C.UNKNOWN_IDX]))
+
+        self._idx_to_vec[nd.array(indices)] = new_vectors
+
+    @classmethod
+    def _check_pretrained_file_names(cls, file_name):
+        """Checks if a pre-trained token embedding file name is valid.
+
+
+        Parameters
+        ----------
+        file_name : str
+            The pre-trained token embedding file.
+        """
+
+        embedding_name = cls.__name__.lower()
+        if file_name not in cls.pretrained_file_name_sha1:
+            raise KeyError('Cannot find pre-trained file %s for token embedding %s. Valid '
+                           'pre-trained file names for embedding %s: %s' %
+                           (file_name, embedding_name, embedding_name,
+                            ', '.join(cls.pretrained_file_name_sha1.keys())))
+
+    @staticmethod
+    def from_file(file_path, elem_delim=' ', encoding='utf8', init_unknown_vec=nd.zeros, **kwargs):
+        """Creates a user-defined token embedding from a pre-trained embedding file.
+
+
+        This is to load embedding vectors from a user-defined pre-trained token embedding file.
+        Denote by '(ed)' the argument `elem_delim`. Denote by (v_ij) the j-th element of the token
+        embedding vector for (token_i). The expected format of a custom pre-trained token embedding
+        file is:
+
+        '(token_1)(ed)(v_11)(ed)(v_12)(ed)...(ed)(v_1k)\\\\n
+        (token_2)(ed)(v_21)(ed)(v_22)(ed)...(ed)(v_2k)\\\\n...'
 
 Review comment:
   Use an example of the file format inside a code block so that it's easier to understand. Currently it looks confusing.
   http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-10074/10/api/python/gluon/text.html#mxnet.gluon.text.embedding.TokenEmbedding.from_file
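   For instance, something along these lines could go into the docstring (just a sketch, not final wording; the file name, tokens, and vector values below are invented for illustration):

       # Sketch: write a tiny space-delimited embedding file, then load it with the
       # API added in this PR. The file name, tokens, and values are made up.
       from mxnet.gluon.text.embedding import TokenEmbedding

       with open('my_embedding.txt', 'w') as f:
           f.write('hello 0.1 0.2 0.3 0.4 0.5\n')
           f.write('world 0.6 0.7 0.8 0.9 1.0\n')

       embed = TokenEmbedding.from_file('my_embedding.txt', elem_delim=' ')
       print(embed['hello'])  # 1-D NDArray holding the five values above
       print(embed['foo'])    # unknown token, falls back to init_unknown_vec (zeros by default)

   A concrete sample like this shows both the expected file layout and how `from_file` consumes it.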
