forked from MaartenGr/BERTopic
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpatch.diff
More file actions
201 lines (186 loc) · 10.2 KB
/
patch.diff
File metadata and controls
201 lines (186 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
diff --git a/.gitignore b/.gitignore
index a8e93cb..46b5c1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -75,4 +75,3 @@ venv.bak/
.idea
.idea/
-.vscode
diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index 408b8a6..e75c768 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -89,6 +89,7 @@ class BERTopic:
hdbscan_model: hdbscan.HDBSCAN = None,
vectorizer_model: CountVectorizer = None,
verbose: bool = False,
+ words_separator: str = ", ",
):
"""BERTopic initialization
@@ -159,6 +160,7 @@ class BERTopic:
self.diversity = diversity
self.verbose = verbose
self.seed_topic_list = seed_topic_list
+ self.words_separator = words_separator
# Embedding model
self.language = language if not embedding_model else None
@@ -535,10 +537,15 @@ class BERTopic:
index=documents_per_topic.Topic).to_dict()
# Fill dataframe with results
+ # KK_EDITED
topics_at_timestamp = [(topic,
- ", ".join([words[0] for words in values][:5]),
+ self.words_separator.join([words[0] for words in values][:10]),
topic_frequency[topic],
timestamp) for topic, values in words_per_topic.items()]
+ # topics_at_timestamp = [(topic,
+ # ", ".join([words[0] for words in values][:5]),
+ # topic_frequency[topic],
+ # timestamp) for topic, values in words_per_topic.items()]
topics_over_time.extend(topics_at_timestamp)
if evolution_tuning:
@@ -2164,9 +2171,13 @@ class BERTopic:
self.c_tf_idf, words = self._c_tf_idf(documents_per_topic)
self.topics = self._extract_words_per_topic(words)
self._create_topic_vectors()
- self.topic_names = {key: f"{key}_" + "_".join([word[0] for word in values[:4]])
+ # KK_EDITED
+ self.topic_names = {key: f"{key}_" + self.words_separator.join([word[0] for word in values[:10]])
for key, values in
self.topics.items()}
+ # self.topic_names = {key: f"{key}_" + "_".join([word[0] for word in values[:4]])
+ # for key, values in
+ # self.topics.items()}
def _save_representative_docs(self, documents: pd.DataFrame):
""" Save the most representative docs (3) per topic
diff --git a/bertopic/plotting/_distribution.py b/bertopic/plotting/_distribution.py
index 8fe4b67..e7cf50f 100644
--- a/bertopic/plotting/_distribution.py
+++ b/bertopic/plotting/_distribution.py
@@ -60,7 +60,9 @@ def visualize_distribution(topic_model,
for idx in labels_idx:
words = topic_model.get_topic(idx)
if words:
- label = [word[0] for word in words[:5]]
+ # KK_EDITED
+ label = [word[0] for word in words[:10]]
+ # label = [word[0] for word in words[:5]]
label = f"<b>Topic {idx}</b>: {'_'.join(label)}"
label = label[:40] + "..." if len(label) > 40 else label
labels.append(label)
diff --git a/bertopic/plotting/_heatmap.py b/bertopic/plotting/_heatmap.py
index fe84555..c4e329a 100644
--- a/bertopic/plotting/_heatmap.py
+++ b/bertopic/plotting/_heatmap.py
@@ -13,7 +13,8 @@ def visualize_heatmap(topic_model,
n_clusters: int = None,
custom_labels: bool = False,
width: int = 800,
- height: int = 800) -> go.Figure:
+ height: int = 800,
+ words_separator: str = ", ") -> go.Figure:
""" Visualize a heatmap of the topic's similarity matrix
Based on the cosine similarity matrix between topic embeddings,
@@ -99,7 +100,9 @@ def visualize_heatmap(topic_model,
else:
new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics]
new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
- new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+ # KK_EDITED
+ new_labels = [words_separator.join([label[0] for label in labels[:9]]) for labels in new_labels]
+ # new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
fig = px.imshow(distance_matrix,
labels=dict(color="Similarity Score"),
diff --git a/bertopic/plotting/_hierarchy.py b/bertopic/plotting/_hierarchy.py
index b1fef67..10c6cc0 100644
--- a/bertopic/plotting/_hierarchy.py
+++ b/bertopic/plotting/_hierarchy.py
@@ -19,7 +19,8 @@ def visualize_hierarchy(topic_model,
hierarchical_topics: pd.DataFrame = None,
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
distance_function: Callable[[csr_matrix], csr_matrix] = None,
- color_threshold: int = 1) -> go.Figure:
+ color_threshold: int = 1,
+ words_separator: str = ", ") -> go.Figure:
""" Visualize a hierarchical structure of the topics
A ward linkage function is used to perform the
@@ -135,7 +136,9 @@ def visualize_hierarchy(topic_model,
new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)])
for x in fig.layout[axis]["ticktext"]]
new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
- new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+ # KK_EDITED
+ new_labels = [words_separator.join([label[0] for label in labels[:9]]) for labels in new_labels]
+ # new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
# Stylize layout
fig.update_layout(
diff --git a/bertopic/plotting/_term_rank.py b/bertopic/plotting/_term_rank.py
index a66aea3..bd42594 100644
--- a/bertopic/plotting/_term_rank.py
+++ b/bertopic/plotting/_term_rank.py
@@ -78,7 +78,8 @@ def visualize_term_rank(topic_model,
label = topic_model.custom_labels[topic + topic_model._outliers]
else:
label = f"<b>Topic {topic}</b>:" + "_".join([word[0] for word in topic_model.get_topic(topic)])
- label = label[:50]
+ # KK_EDITED
+ # label = label[:50]
# line parameters
color = "red" if topic in topics else "black"
diff --git a/bertopic/plotting/_topics.py b/bertopic/plotting/_topics.py
index 84351ae..dc85e38 100644
--- a/bertopic/plotting/_topics.py
+++ b/bertopic/plotting/_topics.py
@@ -12,7 +12,8 @@ def visualize_topics(topic_model,
topics: List[int] = None,
top_n_topics: int = None,
width: int = 650,
- height: int = 650) -> go.Figure:
+ height: int = 650,
+ words_separator: str = ", ") -> go.Figure:
""" Visualize topics, their sizes, and their corresponding words
This visualization is highly inspired by LDAvis, a great visualization
@@ -55,7 +56,9 @@ def visualize_topics(topic_model,
# Extract topic words and their frequencies
topic_list = sorted(topics)
frequencies = [topic_model.topic_sizes[topic] for topic in topic_list]
- words = [" | ".join([word[0] for word in topic_model.get_topic(topic)[:5]]) for topic in topic_list]
+ # KK_EDITED
+ words = [words_separator.join([word[0] for word in topic_model.get_topic(topic)[:10]]) for topic in topic_list]
+ # words = [" | ".join([word[0] for word in topic_model.get_topic(topic)[:5]]) for topic in topic_list]
# Embed c-TF-IDF into 2D
all_topics = sorted(list(topic_model.get_topics().keys()))
diff --git a/bertopic/plotting/_topics_over_time.py b/bertopic/plotting/_topics_over_time.py
index a13282e..4b5cf1f 100644
--- a/bertopic/plotting/_topics_over_time.py
+++ b/bertopic/plotting/_topics_over_time.py
@@ -61,8 +61,11 @@ def visualize_topics_over_time(topic_model,
if topic_model.custom_labels is not None and custom_labels:
topic_names = {key: topic_model.custom_labels[key + topic_model._outliers] for key, _ in topic_model.topic_names.items()}
else:
- topic_names = {key: value[:40] + "..." if len(value) > 40 else value
- for key, value in topic_model.topic_names.items()}
+ # KK_EDITED
+ topic_names = {key: value
+ for key, value in topic_model.topic_names.items()}
+ # topic_names = {key: value[:40] + "..." if len(value) > 40 else value
+ # for key, value in topic_model.topic_names.items()}
topics_over_time["Name"] = topics_over_time.Topic.map(topic_names)
data = topics_over_time.loc[topics_over_time.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"])
diff --git a/bertopic/plotting/_topics_per_class.py b/bertopic/plotting/_topics_per_class.py
index 0894ea0..cb3d69b 100644
--- a/bertopic/plotting/_topics_per_class.py
+++ b/bertopic/plotting/_topics_per_class.py
@@ -61,8 +61,11 @@ def visualize_topics_per_class(topic_model,
if topic_model.custom_labels is not None and custom_labels:
topic_names = {key: topic_model.custom_labels[key + topic_model._outliers] for key, _ in topic_model.topic_names.items()}
else:
- topic_names = {key: value[:40] + "..." if len(value) > 40 else value
- for key, value in topic_model.topic_names.items()}
+ # KK_EDITED
+ topic_names = {key: value
+ for key, value in topic_model.topic_names.items()}
+ # topic_names = {key: value[:40] + "..." if len(value) > 40 else value
+ # for key, value in topic_model.topic_names.items()}
topics_per_class["Name"] = topics_per_class.Topic.map(topic_names)
data = topics_per_class.loc[topics_per_class.Topic.isin(selected_topics), :]