# Source code for sourced.ml.core.models.topics
from typing import Union
from modelforge import assemble_sparse_matrix, disassemble_sparse_matrix, merge_strings, \
Model, register_model, split_strings
from sourced.ml.core.models.license import DEFAULT_LICENSE
@register_model
class Topics(Model):
    """
    Topic model: a sparse topic-token weight matrix plus optional topic labels.
    """

    DESCRIPTION = "Model that is used to identify topics of source code repositories."
    LICENSE = DEFAULT_LICENSE

    @property
    def tokens(self):
        """
        Tokens of the model; token ``i`` corresponds to column ``i`` of :attr:`matrix`.
        """
        return self._tokens

    @property
    def topics(self):
        """
        Topic labels, one per row of :attr:`matrix`.

        May be None if no topics are labeled.
        """
        return self._topics

    @property
    def matrix(self):
        """
        Sparse topic-token weight matrix.

        Rows: topics
        Columns: tokens
        """
        return self._matrix

    def construct(self, tokens: list, topics: Union[list, None], matrix):
        """
        Initialize the model in place.

        :param tokens: one token per column of ``matrix``.
        :param topics: one label per row of ``matrix``, or None if the topics are unlabeled.
        :param matrix: sparse matrix with shape (number of topics, number of tokens).
        :return: self, to allow chaining.
        :raises ValueError: if ``tokens`` or ``topics`` do not match the matrix shape.
        """
        if len(tokens) != matrix.shape[1]:
            raise ValueError("Tokens and matrix do not match.")
        # Same invariant that label_topics() enforces for labels assigned later.
        if topics is not None and len(topics) != matrix.shape[0]:
            raise ValueError("Topics and matrix do not match.")
        self._tokens = tokens
        self._topics = topics
        self._matrix = matrix
        return self

    def _load_tree(self, tree: dict) -> None:
        """
        Restore the model from the tree produced by :meth:`_generate_tree`.
        """
        self.construct(split_strings(tree["tokens"]),
                       # _generate_tree() stores False when the topics are unlabeled.
                       split_strings(tree["topics"]) if tree["topics"] else None,
                       assemble_sparse_matrix(tree["matrix"]))

    def dump(self) -> str:
        """
        Return a human-readable summary of the model.

        The summary covers the shape, the first 10 tokens, the labeling status and the
        share of non-zero elements.
        """
        n_topics, n_tokens = self.matrix.shape
        res = "%d topics, %d tokens\nFirst 10 tokens: %s\nTopics: " % (
            n_topics, n_tokens, self.tokens[:10])
        if self.topics is not None:
            res += "labeled, first 10: %s\n" % self.topics[:10]
        else:
            res += "unlabeled\n"
        nnz = self.matrix.getnnz()
        size = n_topics * n_tokens
        # An empty matrix would otherwise raise ZeroDivisionError.
        density = nnz / size if size else 0.0
        res += "non-zero elements: %d (%f)" % (nnz, density)
        return res

    def _generate_tree(self):
        """
        Serialize the model into a tree; unlabeled topics are stored as False.
        """
        return {"tokens": merge_strings(self.tokens),
                "topics": merge_strings(self.topics) if self.topics is not None else False,
                "matrix": disassemble_sparse_matrix(self.matrix)}

    def __len__(self):
        """
        Returns the number of topics.
        """
        return self.matrix.shape[0]

    def __getitem__(self, item):
        """
        Returns the keywords of a topic, sorted by significance.

        :param item: topic (row) index.
        :return: list of (token, weight) pairs for the non-zero weights, heaviest first; \
                 equal weights are ordered by ascending token index.
        """
        row = self.matrix[item]
        nnz = row.nonzero()[1]
        # Negate the weights so that an ascending sort puts the heaviest tokens first.
        pairs = sorted((-row[0, i], i) for i in nnz)
        return [(self.tokens[i], -neg_weight) for neg_weight, i in pairs]

    def label_topics(self, labels):
        """
        Assign a label to every topic.

        :param labels: iterable of strings, one per topic (row of :attr:`matrix`).
        :raises ValueError: if the number of labels differs from the number of topics.
        :raises TypeError: if any label is not a string.
        """
        # Materialize once: this accepts any iterable and takes a private copy.
        labels = list(labels)
        if len(labels) != len(self):
            raise ValueError("Sizes do not match: %d != %d" % (len(labels), len(self)))
        # Check every label, not only the first; also avoids IndexError for zero topics.
        if not all(isinstance(label, str) for label in labels):
            raise TypeError("Labels must be strings")
        self._topics = labels