From eb076ca2c6c8932fd251419563cf0078c5ee0914 Mon Sep 17 00:00:00 2001 From: Feffernoose Date: Sat, 5 Oct 2013 19:14:53 -0400 Subject: Rewrote weighted random number generator The previous method of picking which token was the next one was flawed in some mysterious way that ended up picking various words that occurred only once in the input corpus as the first word of the generated output (most notably, "hysterically," "Anarchy," "Yorkshire," and "impunity."). --- kgramstats.cpp | 70 +++++++++++++++++++++++++++++++--------------------------- kgramstats.h | 3 ++- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/kgramstats.cpp b/kgramstats.cpp index d196e8f..6c0e4ce 100644 --- a/kgramstats.cpp +++ b/kgramstats.cpp @@ -21,7 +21,7 @@ kgramstats::kgramstats(string corpus, int maxK) start = ((end > (string::npos - 1) ) ? string::npos : end + 1); } - stats = new map* >(); + map* > tstats; for (int k=0; k<=maxK; k++) { for (int i=0; i<(tokens.size() - k); i++) @@ -31,17 +31,18 @@ kgramstats::kgramstats(string corpus, int maxK) string f = tokens[i+k]; string canonical = canonize(f); - if ((*stats)[seq] == NULL) + if (tstats[seq] == NULL) { - (*stats)[seq] = new map(); + tstats[seq] = new map(); } - if ((*(*stats)[seq])[canonical] == NULL) + if ((*tstats[seq])[canonical] == NULL) { - (*(*stats)[seq])[canonical] = (token_data*) calloc(1, sizeof(token_data)); + (*tstats[seq])[canonical] = (token_data*) calloc(1, sizeof(token_data)); } - token_data* td = stats->at(seq)->at(canonical); + token_data* td = tstats[seq]->at(canonical); + td->token = new string(canonical); td->all++; if ((f.length() > 0) && (f[f.length()-1] == '.')) @@ -58,6 +59,24 @@ kgramstats::kgramstats(string corpus, int maxK) } } } + + stats = new map* >(); + for (map* >::iterator it = tstats.begin(); it != tstats.end(); it++) + { + kgram klist = it->first; + map* probtable = it->second; + map* distribution = new map(); + int max = 0; + + for (map::iterator kt = probtable->begin(); kt != probtable->end(); kt++) + { + max += kt->second->all; + + (*distribution)[max] = kt->second; + } + + (*stats)[klist] = distribution; + } } void printKgram(kgram k) @@ -89,38 +108,23 @@ vector kgramstats::randomSentence(int n) } } - map* probtable = (*stats)[cur]; - int max = 0; - for (map::iterator it = probtable->begin(); it != probtable->end(); ++it) - { - max += it->second->all; - } - - int r = rand() % (max+1); - map::iterator next = probtable->begin(); - for (map::iterator it = probtable->begin(); it != probtable->end(); ++it) - { - if (it->second->all > r) - { - break; - } else { - next = it; - r -= it->second->all; - } - } + map distribution = *(*stats)[cur]; + int max = distribution.rbegin()->first; + int r = rand() % max; + token_data* next = distribution.upper_bound(r)->second; - string nextToken(next->first); - int casing = rand() % next->second->all; - int period = rand() % next->second->all; - if (casing < next->second->uppercase) + string nextToken(*(next->token)); + int casing = rand() % next->all; + int period = rand() % next->all; + if (casing < next->uppercase) { transform(nextToken.begin(), nextToken.end(), nextToken.begin(), ::toupper); - } else if ((casing - next->second->uppercase) < next->second->titlecase) + } else if ((casing - next->uppercase) < next->titlecase) { nextToken[0] = toupper(nextToken[0]); } - if (period < next->second->period) + if (period < next->period) { nextToken += "."; } @@ -136,9 +140,9 @@ vector kgramstats::randomSentence(int n) cout << *it << " "; } - cout << "-> \"" << nextToken << "\" (" << next->second->all << "/" << max << ")" << endl; + cout << "-> \"" << nextToken << "\" (" << next->all << "/" << max << ")" << endl; - cur.push_back(next->first); + cur.push_back(*(next->token)); result.push_back(nextToken); } diff --git a/kgramstats.h b/kgramstats.h index 248b193..b40e1ab 100644 --- a/kgramstats.h +++ b/kgramstats.h @@ -23,9 +23,10 @@ private: int titlecase; int uppercase; int period; + string* token; } token_data; int maxK; - map* >* stats; + map* >* stats; }; void printKgram(kgram k); -- cgit 1.4.1