Source code for sourced.ml.core.models.topics

from typing import Union

from modelforge import assemble_sparse_matrix, disassemble_sparse_matrix, merge_strings, \
    Model, register_model, split_strings

from sourced.ml.core.models.license import DEFAULT_LICENSE


@register_model
class Topics(Model):
    """
    Topic model over source code tokens.

    Holds a sparse weight matrix together with the list of tokens and the \
    optional list of human-readable topic labels. Matrix rows correspond to \
    topics and matrix columns correspond to tokens.
    """

    NAME = "topics"
    VENDOR = "source{d}"
    DESCRIPTION = "Model that is used to identify topics of source code repositories."
    LICENSE = DEFAULT_LICENSE

    @property
    def tokens(self):
        """
        The tokens, one per matrix column.
        """
        return self._tokens

    @property
    def topics(self):
        """
        The topic labels, one per matrix row. May be None if no topics are labeled.
        """
        return self._topics

    @property
    def matrix(self):
        """
        The sparse weight matrix.

        Rows: topics
        Columns: tokens
        """
        return self._matrix

    def construct(self, tokens: list, topics: Union[list, None], matrix):
        """
        Initialize the model from its parts.

        :param tokens: Tokens; the length must equal the number of matrix columns.
        :param topics: Topic labels or None if the topics are unlabeled; when given, \
                       the length must equal the number of matrix rows.
        :param matrix: Sparse matrix of shape (number of topics, number of tokens).
        :return: self
        :raise ValueError: if the sizes of `tokens` or `topics` disagree with `matrix`.
        """
        if len(tokens) != matrix.shape[1]:
            raise ValueError("Tokens and matrix do not match.")
        # label_topics() enforces the same invariant, so reject mismatched labels early.
        if topics is not None and len(topics) != matrix.shape[0]:
            raise ValueError("Topics and matrix do not match.")
        self._tokens = tokens
        self._topics = topics
        self._matrix = matrix
        return self

    def _load_tree(self, tree: dict) -> None:
        """
        Restore the model from the tree produced by :func:`_generate_tree`.

        :param tree: Dictionary with "tokens", "topics" and "matrix" keys.
        """
        # _generate_tree() stores False under "topics" when there are no labels.
        topics = split_strings(tree["topics"]) if tree["topics"] else None
        self.construct(split_strings(tree["tokens"]),
                       topics,
                       assemble_sparse_matrix(tree["matrix"]))

    def dump(self) -> str:
        """
        Describe the model: sizes, the first tokens and topics and the matrix density.

        :return: Multi-line human-readable summary.
        """
        n_topics, n_tokens = self.matrix.shape
        res = "%d topics, %d tokens\nFirst 10 tokens: %s\nTopics: " % (
            n_topics, n_tokens, self.tokens[:10])
        if self.topics is not None:
            # Wrap in a 1-tuple so that tuple-typed topics are not unpacked by "%".
            res += "labeled, first 10: %s\n" % (self.topics[:10],)
        else:
            res += "unlabeled\n"
        nnz = self.matrix.getnnz()
        cells = n_topics * n_tokens
        # An empty matrix has no cells; report zero density instead of dividing by zero.
        density = nnz / cells if cells else 0.0
        res += "non-zero elements: %d (%f)" % (nnz, density)
        return res

    def _generate_tree(self) -> dict:
        """
        Serialize the model into a tree which :func:`_load_tree` can read back.

        :return: Dictionary with "tokens", "topics" and "matrix" keys; "topics" is False \
                 when the topics are unlabeled.
        """
        return {"tokens": merge_strings(self.tokens),
                "topics": merge_strings(self.topics) if self.topics is not None else False,
                "matrix": disassemble_sparse_matrix(self.matrix)}

    def __len__(self):
        """
        Returns the number of topics.
        """
        return self.matrix.shape[0]

    def __getitem__(self, item):
        """
        Returns the keywords sorted by significance from topic index.

        :param item: Index of the topic (matrix row).
        :return: List of (token, weight) pairs for the non-zero weights of the row, \
                 ordered by descending weight; ties are ordered by ascending token index.
        """
        row = self.matrix[item]
        nnz = row.nonzero()[1]
        # Negated weights make the ascending sort yield the heaviest tokens first.
        pairs = sorted((-row[0, i], i) for i in nnz)
        return [(self.tokens[index], -neg_weight) for neg_weight, index in pairs]

    def label_topics(self, labels):
        """
        Assign human-readable labels to the topics.

        :param labels: Sized iterable of strings, one per topic.
        :raise ValueError: if the number of labels differs from the number of topics.
        :raise TypeError: if any of the labels is not a string.
        """
        if len(labels) != len(self):
            raise ValueError("Sizes do not match: %d != %d" % (len(labels), len(self)))
        # Check every label, not just the first; this is also safe for empty input.
        if not all(isinstance(label, str) for label in labels):
            raise TypeError("Labels must be strings")
        self._topics = list(labels)