# Source code for sourced.ml.core.models.id2vec

from modelforge import merge_strings, Model, register_model, split_strings

from sourced.ml.core.models.license import DEFAULT_LICENSE


@register_model
class Id2Vec(Model):
    """
    id2vec model - source code identifier embeddings.

    The model keeps three pieces of state, all set by :meth:`construct`:

    * ``_embeddings`` - the embeddings object handed to :meth:`construct`;
    * ``_tokens`` - the sequence of processed source code identifiers;
    * ``_token2index`` - a dict mapping every token to its position in \
      ``_tokens``.
    """

    NAME = "id2vec"
    VENDOR = "source{d}"
    DESCRIPTION = "Model that contains information on source code as identifier embeddings."
    LICENSE = DEFAULT_LICENSE

    def construct(self, embeddings, tokens):
        """
        Initialize the model from the embeddings and the corresponding tokens.

        :param embeddings: Embeddings object; stored as is, without copying. \
                           :meth:`dump` reads its ``shape`` attribute.
        :param tokens: Sequence of processed source code identifiers; stored as is. \
                       :meth:`dump` slices it, so it has to support slicing.
        :return: self, so that the calls can be chained.
        """
        self._embeddings = embeddings
        self._tokens = tokens
        # `_log` is not defined in this class - presumably provided by the `Model` base.
        self._log.info("Building the token index...")
        # If a token occurs more than once, the index of its last occurrence wins.
        self._token2index = {w: i for i, w in enumerate(self._tokens)}
        return self

    def _load_tree(self, tree):
        """
        Restore the model state from the deserialized tree.

        :param tree: Mapping with the "embeddings" and "tokens" keys, the same layout \
                     which :meth:`_generate_tree` produces.
        """
        # The embeddings are copied so that the model does not keep a reference to the
        # object owned by `tree`.
        self.construct(embeddings=tree["embeddings"].copy(),
                       tokens=split_strings(tree["tokens"]))

    def dump(self) -> str:
        """
        Return the short text summary: the embeddings shape and the first (at most) 10 tokens.
        """
        # NOTE(review): the two fields may have been separated by a line break in the
        # upstream file; the literal is kept as it appears here - confirm.
        return """Shape: %s First 10 words: %s""" % (
            self.embeddings.shape, self.tokens[:10])

    @property
    def embeddings(self):
        """
        :class:`numpy.ndarray` with the embeddings of shape (N tokens x embedding dims).
        """
        return self._embeddings

    @property
    def tokens(self):
        """
        List with the processed source code identifiers.
        """
        return self._tokens

    def items(self):
        """
        Returns the tuples belonging to token -> index mapping.

        :return: The ``items()`` view of the internal token -> index dict.
        """
        return self._token2index.items()

    def __getitem__(self, item) -> int:
        """
        Returns the index of the specified processed source code identifier.

        :param item: The token to look up.
        :raises KeyError: if the token is not in the model.
        """
        return self._token2index[item]

    def __len__(self) -> int:
        """
        Returns the number of tokens in the model.
        """
        return len(self._tokens)

    def _generate_tree(self):
        """
        Build the tree to serialize: the embeddings as is and the tokens packed with \
        :func:`merge_strings`.

        :return: dict with the "embeddings" and "tokens" keys.
        """
        return {"embeddings": self.embeddings, "tokens": merge_strings(self.tokens)}