diff options
Diffstat (limited to 'kgramstats.cpp')
-rw-r--r-- | kgramstats.cpp | 43 |
1 files changed, 25 insertions, 18 deletions
diff --git a/kgramstats.cpp b/kgramstats.cpp index b0a83dc..6148dd3 100644 --- a/kgramstats.cpp +++ b/kgramstats.cpp | |||
@@ -590,7 +590,7 @@ void rawr::setMinCorpora(int _arg) | |||
590 | } | 590 | } |
591 | 591 | ||
592 | // runs in O(n log t) time where n is the input number of sentences and t is the number of tokens in the input corpus | 592 | // runs in O(n log t) time where n is the input number of sentences and t is the number of tokens in the input corpus |
593 | std::string rawr::randomSentence(int maxL) const | 593 | std::string rawr::randomSentence(int maxL, std::mt19937& rng) const |
594 | { | 594 | { |
595 | if (!_compiled) | 595 | if (!_compiled) |
596 | { | 596 | { |
@@ -610,16 +610,13 @@ std::string rawr::randomSentence(int maxL) const | |||
610 | cur.pop_front(); | 610 | cur.pop_front(); |
611 | } | 611 | } |
612 | 612 | ||
613 | do | 613 | while (cur.size() > 2 && |
614 | cuts > 0 && | ||
615 | !std::bernoulli_distribution(1.0 / static_cast<double>(cuts))(rng)) | ||
614 | { | 616 | { |
615 | if ((cur.size() > 2) && (cuts > 0) && ((rand() % cuts) > 0)) | ||
616 | { | ||
617 | cur.pop_front(); | 617 | cur.pop_front(); |
618 | cuts--; | 618 | cuts--; |
619 | } else { | 619 | } |
620 | break; | ||
621 | } | ||
622 | } while ((cur.size() > 2) && (cuts > 0) && ((rand() % cuts) > 0)); | ||
623 | 620 | ||
624 | // Gotta circumvent the last line of the input corpus | 621 | // Gotta circumvent the last line of the input corpus |
625 | // https://twitter.com/starla4444/status/684222271339237376 | 622 | // https://twitter.com/starla4444/status/684222271339237376 |
@@ -627,7 +624,8 @@ std::string rawr::randomSentence(int maxL) const | |||
627 | { | 624 | { |
628 | // The end of a corpus should probably be treated like a terminator, so | 625 | // The end of a corpus should probably be treated like a terminator, so |
629 | // maybe we should just end here. | 626 | // maybe we should just end here. |
630 | if ((result.length() > maxL) || (rand() % 4 == 0)) | 627 | if (result.length() > maxL || |
628 | std::bernoulli_distribution(1.0 / 4.0)(rng)) | ||
631 | { | 629 | { |
632 | break; | 630 | break; |
633 | } | 631 | } |
@@ -637,10 +635,11 @@ std::string rawr::randomSentence(int maxL) const | |||
637 | 635 | ||
638 | auto& distribution = _stats.at(cur); | 636 | auto& distribution = _stats.at(cur); |
639 | int max = distribution.rbegin()->first; | 637 | int max = distribution.rbegin()->first; |
640 | int r = rand() % max; | 638 | std::uniform_int_distribution<int> randDist(0, max - 1); |
639 | int r = randDist(rng); | ||
641 | const token_data& next = distribution.upper_bound(r)->second; | 640 | const token_data& next = distribution.upper_bound(r)->second; |
642 | const token& interned = _tokenstore.get(next.tok); | 641 | const token& interned = _tokenstore.get(next.tok); |
643 | std::string nextToken = interned.w.forms.next(); | 642 | std::string nextToken = interned.w.forms.next(rng); |
644 | 643 | ||
645 | // Apply user-specified transforms | 644 | // Apply user-specified transforms |
646 | if (_transform) | 645 | if (_transform) |
@@ -651,10 +650,16 @@ std::string rawr::randomSentence(int maxL) const | |||
651 | // Determine the casing of the next token. We randomly make the token all | 650 | // Determine the casing of the next token. We randomly make the token all |
652 | // caps based on the markov chain. Otherwise, we check if the previous | 651 | // caps based on the markov chain. Otherwise, we check if the previous |
653 | // token is the end of a sentence (terminating token or a wildcard query). | 652 | // token is the end of a sentence (terminating token or a wildcard query). |
654 | int casing = rand() % next.all; | 653 | std::uniform_int_distribution<int> caseDist(0, next.all - 1); |
654 | int casing = caseDist(rng); | ||
655 | |||
655 | if (casing < next.uppercase) | 656 | if (casing < next.uppercase) |
656 | { | 657 | { |
657 | std::transform(nextToken.begin(), nextToken.end(), nextToken.begin(), ::toupper); | 658 | std::transform( |
659 | std::begin(nextToken), | ||
660 | std::end(nextToken), | ||
661 | std::begin(nextToken), | ||
662 | ::toupper); | ||
658 | } else { | 663 | } else { |
659 | bool capitalize = false; | 664 | bool capitalize = false; |
660 | 665 | ||
@@ -663,7 +668,7 @@ std::string rawr::randomSentence(int maxL) const | |||
663 | capitalize = true; | 668 | capitalize = true; |
664 | } else if (cur.rbegin()->type == querytype::sentence) | 669 | } else if (cur.rbegin()->type == querytype::sentence) |
665 | { | 670 | { |
666 | if (rand() % 2 > 0) | 671 | if (std::bernoulli_distribution(1.0 / 2.0)(rng)) |
667 | { | 672 | { |
668 | capitalize = true; | 673 | capitalize = true; |
669 | } | 674 | } |
@@ -671,7 +676,7 @@ std::string rawr::randomSentence(int maxL) const | |||
671 | const token& lastTok = _tokenstore.get(cur.rbegin()->tok); | 676 | const token& lastTok = _tokenstore.get(cur.rbegin()->tok); |
672 | 677 | ||
673 | if (lastTok.suffix == suffixtype::terminating && | 678 | if (lastTok.suffix == suffixtype::terminating && |
674 | rand() % 2 > 0) | 679 | std::bernoulli_distribution(1.0 / 2.0)(rng)) |
675 | { | 680 | { |
676 | capitalize = true; | 681 | capitalize = true; |
677 | } | 682 | } |
@@ -753,7 +758,7 @@ std::string rawr::randomSentence(int maxL) const | |||
753 | // Terminators | 758 | // Terminators |
754 | if (interned.suffix == suffixtype::terminating) | 759 | if (interned.suffix == suffixtype::terminating) |
755 | { | 760 | { |
756 | auto term = interned.w.terms.next(); | 761 | auto term = interned.w.terms.next(rng); |
757 | nextToken.append(term.form); | 762 | nextToken.append(term.form); |
758 | 763 | ||
759 | if (term.newline) | 764 | if (term.newline) |
@@ -794,7 +799,9 @@ std::string rawr::randomSentence(int maxL) const | |||
794 | cur.push_back(next.tok); | 799 | cur.push_back(next.tok); |
795 | result.append(nextToken); | 800 | result.append(nextToken); |
796 | 801 | ||
797 | if ((interned.suffix == suffixtype::terminating) && ((result.length() > maxL) || (rand() % 4 == 0))) | 802 | if (interned.suffix == suffixtype::terminating && |
803 | (result.length() > maxL || | ||
804 | std::bernoulli_distribution(1.0 / 4.0)(rng))) | ||
798 | { | 805 | { |
799 | break; | 806 | break; |
800 | } | 807 | } |
@@ -803,7 +810,7 @@ std::string rawr::randomSentence(int maxL) const | |||
803 | // Ensure that enough corpora are used | 810 | // Ensure that enough corpora are used |
804 | if (used_corpora.size() < _min_corpora) | 811 | if (used_corpora.size() < _min_corpora) |
805 | { | 812 | { |
806 | return randomSentence(maxL); | 813 | return randomSentence(maxL, rng); |
807 | } | 814 | } |
808 | 815 | ||
809 | // Remove the trailing space | 816 | // Remove the trailing space |