diff options
Diffstat (limited to 'kgramstats.cpp')
-rw-r--r-- | kgramstats.cpp | 442 |
1 files changed, 165 insertions, 277 deletions
diff --git a/kgramstats.cpp b/kgramstats.cpp index b0ec68a..c88d83c 100644 --- a/kgramstats.cpp +++ b/kgramstats.cpp | |||
@@ -5,237 +5,176 @@ | |||
5 | #include <algorithm> | 5 | #include <algorithm> |
6 | #include "malaprop.h" | 6 | #include "malaprop.h" |
7 | 7 | ||
8 | query wildcardQuery(querytype_sentence); | ||
9 | |||
8 | std::string canonize(std::string f); | 10 | std::string canonize(std::string f); |
9 | 11 | ||
10 | // runs in O(t^2) time where t is the number of tokens in the input corpus | 12 | // runs in O(t^2) time where t is the number of tokens in the input corpus |
11 | // We consider maxK to be fairly constant | 13 | // We consider maxK to be fairly constant |
12 | kgramstats::kgramstats(std::string corpus, int maxK) | 14 | kgramstats::kgramstats(std::string corpus, int maxK) |
13 | { | 15 | { |
14 | this->maxK = maxK; | 16 | this->maxK = maxK; |
15 | 17 | ||
16 | std::vector<std::string> tokens; | 18 | std::vector<std::string> tokens; |
17 | size_t start = 0; | 19 | size_t start = 0; |
18 | int end = 0; | 20 | int end = 0; |
19 | 21 | ||
20 | while (end != std::string::npos) | 22 | while (end != std::string::npos) |
21 | { | 23 | { |
22 | end = corpus.find(" ", start); | 24 | end = corpus.find(" ", start); |
23 | 25 | ||
24 | std::string token = corpus.substr(start, (end == std::string::npos) ? std::string::npos : end - start); | 26 | std::string token = corpus.substr(start, (end == std::string::npos) ? std::string::npos : end - start); |
25 | if (token[token.length()-1] == '\n') | 27 | if (token[token.length()-1] == '\n') |
26 | { | 28 | { |
27 | if ((token[token.length()-2] != '.') && (token[token.length()-2] != '!') && (token[token.length()-2] != '?')) | 29 | if ((token[token.length()-2] != '.') && (token[token.length()-2] != '!') && (token[token.length()-2] != '?') && (token[token.length()-2] != ',')) |
28 | { | 30 | { |
29 | token.insert(token.length()-1, "."); | 31 | token.insert(token.length()-1, "."); |
30 | } | 32 | } |
31 | 33 | ||
32 | token.resize(token.length()-1); | 34 | token.resize(token.length()-1); |
33 | } | 35 | } |
34 | 36 | ||
35 | if (token.compare("") && token.compare(".")) | 37 | if (token.compare("") && token.compare(".")) |
36 | { | 38 | { |
37 | mstats.addWord(token); | 39 | mstats.addWord(token); |
38 | tokens.push_back(token); | 40 | tokens.push_back(token); |
39 | } | 41 | } |
40 | 42 | ||
41 | start = ((end > (std::string::npos - 1) ) ? std::string::npos : end + 1); | 43 | start = ((end > (std::string::npos - 1) ) ? std::string::npos : end + 1); |
42 | } | 44 | } |
43 | 45 | ||
44 | std::map<kgram, std::map<std::string, token_data*>* > tstats; | 46 | std::map<kgram, std::map<token, token_data> > tstats; |
45 | bool newSentence = true; | 47 | std::map<token, std::map<termstats, int> > tendings; |
46 | bool newClause = false; | 48 | for (int k=1; k<maxK; k++) |
47 | for (int k=0; k<maxK; k++) | 49 | { |
48 | { | 50 | for (int i=0; i<(tokens.size() - k); i++) |
49 | for (int i=0; i<(tokens.size() - k); i++) | 51 | { |
50 | { | 52 | std::list<std::string> seq(tokens.begin()+i, tokens.begin()+i+k); |
51 | kgram seq(tokens.begin()+i, tokens.begin()+i+k); | 53 | kgram prefix; |
52 | std::transform(seq.begin(), seq.end(), seq.begin(), canonize); | ||
53 | std::string f = tokens[i+k]; | ||
54 | |||
55 | |||
56 | |||
57 | std::string canonical = canonize(f); | ||
58 | |||
59 | if (tstats[seq] == NULL) | ||
60 | { | ||
61 | tstats[seq] = new std::map<std::string, token_data*>(); | ||
62 | } | ||
63 | |||
64 | if ((*tstats[seq])[canonical] == NULL) | ||
65 | { | ||
66 | (*tstats[seq])[canonical] = (token_data*) calloc(1, sizeof(token_data)); | ||
67 | } | ||
68 | |||
69 | token_data* td = tstats[seq]->at(canonical); | ||
70 | td->token = new std::string(canonical); | ||
71 | td->all++; | ||
72 | 54 | ||
73 | /*if (newSentence) | 55 | for (std::list<std::string>::iterator it = seq.begin(); it != seq.end(); it++) |
74 | { | 56 | { |
75 | kgram newKgram(1, "."); | 57 | token word(canonize(*it)); |
76 | if (tstats[newKgram] == NULL) | 58 | |
59 | if (it->find_first_of(".?!,") != std::string::npos) | ||
77 | { | 60 | { |
78 | tstats[newKgram] = new std::map<std::string, token_data*>(); | 61 | word.terminating = true; |
79 | } | 62 | } |
80 | 63 | ||
81 | (*tstats[newKgram])[canonical] = td; | 64 | prefix.push_back(word); |
82 | |||
83 | newSentence = false; | ||
84 | } | 65 | } |
85 | 66 | ||
86 | if (newClause) | 67 | std::string f = tokens[i+k]; |
68 | std::string canonical = canonize(f); | ||
69 | |||
70 | token word(canonical); | ||
71 | if (f.find_first_of(".?!,") != std::string::npos) | ||
87 | { | 72 | { |
88 | kgram commaKgram(1, ","); | 73 | word.terminating = true; |
89 | if (tstats[commaKgram] == NULL) | ||
90 | { | ||
91 | tstats[commaKgram] = new std::map<std::string, token_data*>(); | ||
92 | } | ||
93 | 74 | ||
94 | (*tstats[commaKgram])[canonical] = td; | 75 | char terminator = f[f.find_last_of(".?!,")]; |
76 | int occurrences = std::count(f.begin(), f.end(), terminator); | ||
95 | 77 | ||
96 | newClause = false; | 78 | tendings[word][termstats(terminator, occurrences)]++; |
97 | } | ||
98 | |||
99 | if ((f.length() > 0) && (f[f.length()-1] == '\n')) | ||
100 | { | ||
101 | td->period++; | ||
102 | newSentence = true; | ||
103 | f.resize(f.length()-1); | ||
104 | } | 79 | } |
105 | 80 | ||
106 | if (f.length() > 0) | 81 | token_data& td = tstats[prefix][word]; |
82 | td.word = word; | ||
83 | td.all++; | ||
84 | |||
85 | if (std::find_if(f.begin(), f.end(), ::islower) == f.end()) | ||
107 | { | 86 | { |
108 | if ((f[f.length()-1] == '.') || (f[f.length()-1] == '!') || (f[f.length()-1] == '?')) | 87 | td.uppercase++; |
109 | { | 88 | } else if (isupper(f[0])) |
110 | if (!newSentence) | 89 | { |
111 | { | 90 | td.titlecase++; |
112 | td->period++; | ||
113 | newSentence = true; | ||
114 | } | ||
115 | |||
116 | f.resize(f.length()-1); | ||
117 | } else if (f[f.length()-1] == ',') | ||
118 | { | ||
119 | if (!newSentence) | ||
120 | { | ||
121 | td->comma++; | ||
122 | newClause = true; | ||
123 | } | ||
124 | |||
125 | f.resize(f.length()-1); | ||
126 | } | ||
127 | } | 91 | } |
128 | 92 | ||
129 | if (f.length() > 0) | 93 | if (prefix.front().word.terminating) |
130 | { | 94 | { |
131 | if (f[0] == '"') | 95 | prefix.front() = wildcardQuery; |
132 | { | ||
133 | td->startquote++; | ||
134 | } | ||
135 | 96 | ||
136 | if (f[0] == '(') | 97 | token_data& td2 = tstats[prefix][word]; |
98 | td2.word = word; | ||
99 | td2.all++; | ||
100 | |||
101 | if (std::find_if(f.begin(), f.end(), ::islower) == f.end()) | ||
137 | { | 102 | { |
138 | td->startparen++; | 103 | td2.uppercase++; |
139 | } | 104 | } else if (isupper(f[0])) |
140 | |||
141 | if ((f[f.length()-1] == '"') || (f[f.length()-1] == ')')) | ||
142 | { | 105 | { |
143 | if (f[f.length()-1] == '"') | 106 | td2.titlecase++; |
144 | { | ||
145 | td->endquote++; | ||
146 | } else if (f[f.length()-1] == ')') | ||
147 | { | ||
148 | td->endparen++; | ||
149 | } | ||
150 | |||
151 | f.resize(f.length()-1); | ||
152 | |||
153 | if (f.length() > 0) | ||
154 | { | ||
155 | if ((f[f.length()-1] == '.') || (f[f.length()-1] == '!') || (f[f.length()-1] == '?')) | ||
156 | { | ||
157 | if (!newSentence) | ||
158 | { | ||
159 | td->period++; | ||
160 | newSentence = true; | ||
161 | } | ||
162 | } else if (f[f.length()-1] == ',') | ||
163 | { | ||
164 | if (!newSentence && !newClause) | ||
165 | { | ||
166 | td->comma++; | ||
167 | newClause = true; | ||
168 | } | ||
169 | } | ||
170 | } | ||
171 | } | ||
172 | }*/ | ||
173 | |||
174 | if (std::find_if(f.begin(), f.end(), ::islower) == f.end()) | ||
175 | { | ||
176 | td->uppercase++; | ||
177 | } else if (isupper(f[0])) | ||
178 | { | ||
179 | td->titlecase++; | ||
180 | } | ||
181 | |||
182 | /*if (k != 0) | ||
183 | { | ||
184 | if (newSentence) | ||
185 | { | ||
186 | i += k; | ||
187 | } | 107 | } |
188 | 108 | } | |
189 | newSentence = false; | 109 | } |
190 | newClause = false; | 110 | } |
191 | }*/ | ||
192 | } | ||
193 | } | ||
194 | 111 | ||
195 | stats = new std::map<kgram, std::map<int, token_data*>* >(); | 112 | for (std::map<kgram, std::map<token, token_data> >::iterator it = tstats.begin(); it != tstats.end(); it++) |
196 | for (std::map<kgram, std::map<std::string, token_data*>* >::iterator it = tstats.begin(); it != tstats.end(); it++) | 113 | { |
197 | { | 114 | kgram klist = it->first; |
198 | kgram klist = it->first; | 115 | std::map<token, token_data>& probtable = it->second; |
199 | std::map<std::string, token_data*>* probtable = it->second; | 116 | std::map<int, token_data>& distribution = stats[klist]; |
200 | std::map<int, token_data*>* distribution = new std::map<int, token_data*>(); | 117 | int max = 0; |
201 | int max = 0; | ||
202 | 118 | ||
203 | for (std::map<std::string, token_data*>::iterator kt = probtable->begin(); kt != probtable->end(); kt++) | 119 | for (std::map<token, token_data>::iterator kt = probtable.begin(); kt != probtable.end(); kt++) |
204 | { | 120 | { |
205 | max += kt->second->all; | 121 | max += kt->second.all; |
206 | 122 | ||
207 | (*distribution)[max] = kt->second; | 123 | distribution[max] = kt->second; |
208 | } | 124 | } |
209 | 125 | } | |
210 | (*stats)[klist] = distribution; | 126 | |
211 | } | 127 | for (std::map<token, std::map<termstats, int> >::iterator it = tendings.begin(); it != tendings.end(); it++) |
128 | { | ||
129 | token word = it->first; | ||
130 | std::map<termstats, int>& probtable = it->second; | ||
131 | std::map<int, termstats>& distribution = endings[word]; | ||
132 | int max = 0; | ||
133 | |||
134 | for (std::map<termstats, int>::iterator kt = probtable.begin(); kt != probtable.end(); kt++) | ||
135 | { | ||
136 | max += kt->second; | ||
137 | |||
138 | distribution[max] = kt->first; | ||
139 | } | ||
140 | } | ||
212 | } | 141 | } |
213 | 142 | ||
214 | void printKgram(kgram k) | 143 | void printKgram(kgram k) |
215 | { | 144 | { |
216 | for (kgram::iterator it = k.begin(); it != k.end(); it++) | 145 | for (kgram::iterator it = k.begin(); it != k.end(); it++) |
217 | { | 146 | { |
218 | std::cout << *it << " "; | 147 | query& q = *it; |
219 | } | 148 | if (q.type == querytype_sentence) |
149 | { | ||
150 | std::cout << "#.# "; | ||
151 | } else if (q.type == querytype_literal) | ||
152 | { | ||
153 | if (q.word.terminating) | ||
154 | { | ||
155 | std::cout << q.word.canon << ". "; | ||
156 | } else { | ||
157 | std::cout << q.word.canon << " "; | ||
158 | } | ||
159 | } | ||
160 | } | ||
220 | } | 161 | } |
221 | 162 | ||
222 | // runs in O(n log t) time where n is the input number of sentences and t is the number of tokens in the input corpus | 163 | // runs in O(n log t) time where n is the input number of sentences and t is the number of tokens in the input corpus |
223 | std::vector<std::string> kgramstats::randomSentence(int n) | 164 | std::vector<std::string> kgramstats::randomSentence(int n) |
224 | { | 165 | { |
225 | std::vector<std::string> result; | 166 | std::vector<std::string> result; |
226 | kgram newKgram(1, "."); | 167 | kgram cur(1, wildcardQuery); |
227 | kgram commaKgram(1, ","); | ||
228 | std::list<std::string> cur; | ||
229 | int cuts = 0; | 168 | int cuts = 0; |
230 | 169 | ||
231 | for (int i=0; i<n; i++) | 170 | for (int i=0; i<n; i++) |
232 | { | 171 | { |
233 | if (cur.size() == maxK) | 172 | if (cur.size() == maxK) |
234 | { | 173 | { |
235 | cur.pop_front(); | 174 | cur.pop_front(); |
236 | } | 175 | } |
237 | 176 | ||
238 | if ((cur.size() > 0) && (cur != newKgram)) | 177 | if (cur.size() > 0) |
239 | { | 178 | { |
240 | if (rand() % (maxK - cur.size() + 1) == 0) | 179 | if (rand() % (maxK - cur.size() + 1) == 0) |
241 | { | 180 | { |
@@ -253,20 +192,19 @@ std::vector<std::string> kgramstats::randomSentence(int n) | |||
253 | 192 | ||
254 | cuts++; | 193 | cuts++; |
255 | } | 194 | } |
195 | |||
196 | // Gotta circumvent the last line of the input corpus | ||
197 | // https://twitter.com/starla4444/status/684222271339237376 | ||
198 | if (stats.count(cur) == 0) | ||
199 | { | ||
200 | cur = kgram(1, wildcardQuery); | ||
201 | } | ||
256 | 202 | ||
257 | std::map<int, token_data*> distribution = *(*stats)[cur]; | 203 | std::map<int, token_data>& distribution = stats[cur]; |
258 | int max = distribution.rbegin()->first; | 204 | int max = distribution.rbegin()->first; |
259 | int r = rand() % max; | 205 | int r = rand() % max; |
260 | token_data* next = distribution.upper_bound(r)->second; | 206 | token_data& next = distribution.upper_bound(r)->second; |
261 | 207 | std::string nextToken(next.word.canon); | |
262 | std::string nextToken(*(next->token)); | ||
263 | int casing = rand() % next->all; | ||
264 | /*int period = rand() % next->all; | ||
265 | int startparen = rand() % next->all; | ||
266 | int endparen = rand() % next->all; | ||
267 | int startquote = rand() % next->all; | ||
268 | int endquote = rand() % next->all; | ||
269 | int comma = rand() % next->all;*/ | ||
270 | 208 | ||
271 | bool mess = (rand() % 100) == 0; | 209 | bool mess = (rand() % 100) == 0; |
272 | if (mess) | 210 | if (mess) |
@@ -274,114 +212,64 @@ std::vector<std::string> kgramstats::randomSentence(int n) | |||
274 | nextToken = mstats.alternate(nextToken); | 212 | nextToken = mstats.alternate(nextToken); |
275 | } | 213 | } |
276 | 214 | ||
277 | if (casing < next->uppercase) | 215 | // Determine the casing of the next token. We randomly make the token all |
278 | { | 216 | // caps based on the markov chain. Otherwise, we check if the previous |
279 | std::transform(nextToken.begin(), nextToken.end(), nextToken.begin(), ::toupper); | 217 | // token is the end of a sentence (terminating token or a wildcard query). |
280 | } | 218 | int casing = rand() % next.all; |
281 | 219 | if (casing < next.uppercase) | |
282 | if ((cur == newKgram) && (rand() % 15 > 0)) | 220 | { |
221 | std::transform(nextToken.begin(), nextToken.end(), nextToken.begin(), ::toupper); | ||
222 | } else if ((((cur.rbegin()->type == querytype_sentence) | ||
223 | || ((cur.rbegin()->type == querytype_literal) | ||
224 | && (cur.rbegin()->word.terminating))) | ||
225 | && (rand() % 2 > 0)) | ||
226 | || (casing - next.uppercase < next.titlecase)) | ||
283 | { | 227 | { |
284 | nextToken[0] = toupper(nextToken[0]); | 228 | nextToken[0] = toupper(nextToken[0]); |
285 | } | 229 | } |
286 | 230 | ||
287 | /*if (startquote < next->startquote) | 231 | if (next.word.terminating) |
288 | { | ||
289 | nextToken = "\"" + nextToken; | ||
290 | } else if (startparen < next->startparen) | ||
291 | { | 232 | { |
292 | nextToken = "(" + nextToken; | 233 | std::map<int, termstats>& ending = endings[next.word]; |
234 | int emax = ending.rbegin()->first; | ||
235 | int er = rand() % emax; | ||
236 | termstats& nextend = ending.upper_bound(er)->second; | ||
237 | |||
238 | nextToken.append(std::string(nextend.occurrences, nextend.terminator)); | ||
293 | } | 239 | } |
294 | |||
295 | if (period < next->period) | ||
296 | { | ||
297 | if (endquote < next->endquote) | ||
298 | { | ||
299 | nextToken += "\""; | ||
300 | } else if (endparen < next->endparen) | ||
301 | { | ||
302 | nextToken += ")"; | ||
303 | } | ||
304 | |||
305 | int type = rand() % 6; | ||
306 | |||
307 | if (type < 3) | ||
308 | { | ||
309 | nextToken += "."; | ||
310 | } else if (type < 5) | ||
311 | { | ||
312 | nextToken += "!"; | ||
313 | } else { | ||
314 | nextToken += "?"; | ||
315 | } | ||
316 | } else if (comma < next->comma) | ||
317 | { | ||
318 | if (endquote < next->endquote) | ||
319 | { | ||
320 | nextToken += "\""; | ||
321 | } else if (endparen < next->endparen) | ||
322 | { | ||
323 | nextToken += ")"; | ||
324 | } | ||
325 | |||
326 | nextToken += ","; | ||
327 | }*/ | ||
328 | 240 | ||
329 | /* DEBUG */ | 241 | /* DEBUG */ |
330 | for (kgram::iterator it = cur.begin(); it != cur.end(); it++) | 242 | printKgram(cur); |
331 | { | ||
332 | std::cout << *it << " "; | ||
333 | } | ||
334 | 243 | ||
335 | std::cout << "-> \"" << nextToken << "\" (" << next->all << "/" << max << ")"; | 244 | std::cout << "-> \"" << nextToken << "\" (" << next.all << "/" << max << ")"; |
336 | 245 | ||
337 | if (mess) | 246 | if (mess) |
338 | { | 247 | { |
339 | std::cout << " mala " << *(next->token); | 248 | std::cout << " mala " << next.word.canon; |
340 | } | 249 | } |
341 | 250 | ||
342 | std::cout << std::endl; | 251 | std::cout << std::endl; |
343 | 252 | ||
344 | /*if ((cur == newKgram) || (cur == commaKgram)) | 253 | cur.push_back(next.word); |
345 | { | ||
346 | cur.pop_front(); | ||
347 | } | ||
348 | |||
349 | if (period < next->period)// && ((rand() % 3) != 0)) | ||
350 | { | ||
351 | cur = newKgram; | ||
352 | } else if ((comma < next->comma) && ((rand() % 3) == 0)) | ||
353 | { | ||
354 | cur = commaKgram; | ||
355 | } else {*/ | ||
356 | //if (mess && (rand() % 2 == 0)) | ||
357 | if (false) | ||
358 | { | ||
359 | // This doesn't work because sometimes the alternate token isn't actually present in the original corpus | ||
360 | cur.clear(); | ||
361 | cur.push_back(nextToken); | ||
362 | } else { | ||
363 | cur.push_back(*(next->token)); | ||
364 | } | ||
365 | //} | ||
366 | 254 | ||
367 | result.push_back(nextToken); | 255 | result.push_back(nextToken); |
368 | } | 256 | } |
369 | 257 | ||
370 | return result; | 258 | return result; |
371 | } | 259 | } |
372 | 260 | ||
373 | bool removeIf(char c) | 261 | bool removeIf(char c) |
374 | { | 262 | { |
375 | return !((c != '.') && (c != '?') && (c != '!') && (c != '"') && (c != '(') && (c != ')') && (c != ',') && (c != '\n')); | 263 | return !((c != '.') && (c != '?') && (c != '!') && (c != ',') /*&& (c != '"') && (c != '(') && (c != ')') && (c != '\n')*/); |
376 | } | 264 | } |
377 | 265 | ||
378 | std::string canonize(std::string f) | 266 | std::string canonize(std::string f) |
379 | { | 267 | { |
380 | std::string canonical(f); | 268 | std::string canonical(f); |
381 | std::transform(canonical.begin(), canonical.end(), canonical.begin(), ::tolower); | 269 | std::transform(canonical.begin(), canonical.end(), canonical.begin(), ::tolower); |
382 | 270 | ||
383 | std::string result; | 271 | std::string result; |
384 | std::remove_copy_if(canonical.begin(), canonical.end(), std::back_inserter(result), removeIf); | 272 | std::remove_copy_if(canonical.begin(), canonical.end(), std::back_inserter(result), removeIf); |
385 | 273 | ||
386 | return canonical; | 274 | return result; |
387 | } | 275 | } |