diff options
author | Feffernoose <fefferburbia@gmail.com> | 2013-10-05 19:14:53 -0400 |
---|---|---|
committer | Feffernoose <fefferburbia@gmail.com> | 2013-10-05 19:14:53 -0400 |
commit | eb076ca2c6c8932fd251419563cf0078c5ee0914 (patch) | |
tree | bcd96acd0613fafa27b847cc5937420755b3d748 | |
parent | 92a4a0e7db8336f8ccc11c053dc29847a303ad88 (diff) | |
download | rawr-ebooks-eb076ca2c6c8932fd251419563cf0078c5ee0914.tar.gz rawr-ebooks-eb076ca2c6c8932fd251419563cf0078c5ee0914.tar.bz2 rawr-ebooks-eb076ca2c6c8932fd251419563cf0078c5ee0914.zip |
Rewrote weighted random number generator
The previous method of picking which token was the next one was flawed in some mysterious way that ended up picking various words that occurred only once in the input corpus as the first word of the generated output (most notably, "hysterically," "Anarchy," "Yorkshire," and "impunity.").
-rw-r--r-- | kgramstats.cpp | 70 | ||||
-rw-r--r-- | kgramstats.h | 3 |
2 files changed, 39 insertions, 34 deletions
diff --git a/kgramstats.cpp b/kgramstats.cpp index d196e8f..6c0e4ce 100644 --- a/kgramstats.cpp +++ b/kgramstats.cpp | |||
@@ -21,7 +21,7 @@ kgramstats::kgramstats(string corpus, int maxK) | |||
21 | start = ((end > (string::npos - 1) ) ? string::npos : end + 1); | 21 | start = ((end > (string::npos - 1) ) ? string::npos : end + 1); |
22 | } | 22 | } |
23 | 23 | ||
24 | stats = new map<kgram, map<string, token_data*>* >(); | 24 | map<kgram, map<string, token_data*>* > tstats; |
25 | for (int k=0; k<=maxK; k++) | 25 | for (int k=0; k<=maxK; k++) |
26 | { | 26 | { |
27 | for (int i=0; i<(tokens.size() - k); i++) | 27 | for (int i=0; i<(tokens.size() - k); i++) |
@@ -31,17 +31,18 @@ kgramstats::kgramstats(string corpus, int maxK) | |||
31 | string f = tokens[i+k]; | 31 | string f = tokens[i+k]; |
32 | string canonical = canonize(f); | 32 | string canonical = canonize(f); |
33 | 33 | ||
34 | if ((*stats)[seq] == NULL) | 34 | if (tstats[seq] == NULL) |
35 | { | 35 | { |
36 | (*stats)[seq] = new map<string, token_data*>(); | 36 | tstats[seq] = new map<string, token_data*>(); |
37 | } | 37 | } |
38 | 38 | ||
39 | if ((*(*stats)[seq])[canonical] == NULL) | 39 | if ((*tstats[seq])[canonical] == NULL) |
40 | { | 40 | { |
41 | (*(*stats)[seq])[canonical] = (token_data*) calloc(1, sizeof(token_data)); | 41 | (*tstats[seq])[canonical] = (token_data*) calloc(1, sizeof(token_data)); |
42 | } | 42 | } |
43 | 43 | ||
44 | token_data* td = stats->at(seq)->at(canonical); | 44 | token_data* td = tstats[seq]->at(canonical); |
45 | td->token = new string(canonical); | ||
45 | td->all++; | 46 | td->all++; |
46 | 47 | ||
47 | if ((f.length() > 0) && (f[f.length()-1] == '.')) | 48 | if ((f.length() > 0) && (f[f.length()-1] == '.')) |
@@ -58,6 +59,24 @@ kgramstats::kgramstats(string corpus, int maxK) | |||
58 | } | 59 | } |
59 | } | 60 | } |
60 | } | 61 | } |
62 | |||
63 | stats = new map<kgram, map<int, token_data*>* >(); | ||
64 | for (map<kgram, map<string, token_data*>* >::iterator it = tstats.begin(); it != tstats.end(); it++) | ||
65 | { | ||
66 | kgram klist = it->first; | ||
67 | map<string, token_data*>* probtable = it->second; | ||
68 | map<int, token_data*>* distribution = new map<int, token_data*>(); | ||
69 | int max = 0; | ||
70 | |||
71 | for (map<string, token_data*>::iterator kt = probtable->begin(); kt != probtable->end(); kt++) | ||
72 | { | ||
73 | max += kt->second->all; | ||
74 | |||
75 | (*distribution)[max] = kt->second; | ||
76 | } | ||
77 | |||
78 | (*stats)[klist] = distribution; | ||
79 | } | ||
61 | } | 80 | } |
62 | 81 | ||
63 | void printKgram(kgram k) | 82 | void printKgram(kgram k) |
@@ -89,38 +108,23 @@ vector<string> kgramstats::randomSentence(int n) | |||
89 | } | 108 | } |
90 | } | 109 | } |
91 | 110 | ||
92 | map<string, token_data*>* probtable = (*stats)[cur]; | 111 | map<int, token_data*> distribution = *(*stats)[cur]; |
93 | int max = 0; | 112 | int max = distribution.rbegin()->first; |
94 | for (map<string, token_data*>::iterator it = probtable->begin(); it != probtable->end(); ++it) | 113 | int r = rand() % max; |
95 | { | 114 | token_data* next = distribution.upper_bound(r)->second; |
96 | max += it->second->all; | ||
97 | } | ||
98 | |||
99 | int r = rand() % (max+1); | ||
100 | map<string, token_data*>::iterator next = probtable->begin(); | ||
101 | for (map<string, token_data*>::iterator it = probtable->begin(); it != probtable->end(); ++it) | ||
102 | { | ||
103 | if (it->second->all > r) | ||
104 | { | ||
105 | break; | ||
106 | } else { | ||
107 | next = it; | ||
108 | r -= it->second->all; | ||
109 | } | ||
110 | } | ||
111 | 115 | ||
112 | string nextToken(next->first); | 116 | string nextToken(*(next->token)); |
113 | int casing = rand() % next->second->all; | 117 | int casing = rand() % next->all; |
114 | int period = rand() % next->second->all; | 118 | int period = rand() % next->all; |
115 | if (casing < next->second->uppercase) | 119 | if (casing < next->uppercase) |
116 | { | 120 | { |
117 | transform(nextToken.begin(), nextToken.end(), nextToken.begin(), ::toupper); | 121 | transform(nextToken.begin(), nextToken.end(), nextToken.begin(), ::toupper); |
118 | } else if ((casing - next->second->uppercase) < next->second->titlecase) | 122 | } else if ((casing - next->uppercase) < next->titlecase) |
119 | { | 123 | { |
120 | nextToken[0] = toupper(nextToken[0]); | 124 | nextToken[0] = toupper(nextToken[0]); |
121 | } | 125 | } |
122 | 126 | ||
123 | if (period < next->second->period) | 127 | if (period < next->period) |
124 | { | 128 | { |
125 | nextToken += "."; | 129 | nextToken += "."; |
126 | } | 130 | } |
@@ -136,9 +140,9 @@ vector<string> kgramstats::randomSentence(int n) | |||
136 | cout << *it << " "; | 140 | cout << *it << " "; |
137 | } | 141 | } |
138 | 142 | ||
139 | cout << "-> \"" << nextToken << "\" (" << next->second->all << "/" << max << ")" << endl; | 143 | cout << "-> \"" << nextToken << "\" (" << next->all << "/" << max << ")" << endl; |
140 | 144 | ||
141 | cur.push_back(next->first); | 145 | cur.push_back(*(next->token)); |
142 | result.push_back(nextToken); | 146 | result.push_back(nextToken); |
143 | } | 147 | } |
144 | 148 | ||
diff --git a/kgramstats.h b/kgramstats.h index 248b193..b40e1ab 100644 --- a/kgramstats.h +++ b/kgramstats.h | |||
@@ -23,9 +23,10 @@ private: | |||
23 | int titlecase; | 23 | int titlecase; |
24 | int uppercase; | 24 | int uppercase; |
25 | int period; | 25 | int period; |
26 | string* token; | ||
26 | } token_data; | 27 | } token_data; |
27 | int maxK; | 28 | int maxK; |
28 | map<kgram, map<string, token_data*>* >* stats; | 29 | map<kgram, map<int, token_data*>* >* stats; |
29 | }; | 30 | }; |
30 | 31 | ||
31 | void printKgram(kgram k); | 32 | void printKgram(kgram k); |