summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'tesseract/unittest/mastertrainer_test.cc')
-rw-r--r--tesseract/unittest/mastertrainer_test.cc298
1 files changed, 298 insertions, 0 deletions
diff --git a/tesseract/unittest/mastertrainer_test.cc b/tesseract/unittest/mastertrainer_test.cc
new file mode 100644
index 00000000..0f93e221
--- /dev/null
+++ b/tesseract/unittest/mastertrainer_test.cc
@@ -0,0 +1,298 @@
+// (C) Copyright 2017, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Although this is a trivial-looking test, it exercises a lot of code:
+// SampleIterator has to correctly iterate over the correct characters, or
+// it will fail.
+// The canonical and cloud features computed by TrainingSampleSet need to
+// be correct, along with the distance caches, organizing samples by font
+// and class, indexing of features, distance calculations.
+// IntFeatureDist has to work, or the canonical samples won't work.
+// Mastertrainer has ability to read tr files and set itself up tested.
+// Finally the serialize/deserialize test ensures that MasterTrainer,
+// TrainingSampleSet, TrainingSample can all serialize/deserialize correctly
+// enough to reproduce the same results.
+
+#include "include_gunit.h"
+
+#include "log.h" // for LOG
+#include "unicharset.h"
+#include "errorcounter.h"
+#include "mastertrainer.h"
+#include "shapeclassifier.h"
+#include "shapetable.h"
+#include "trainingsample.h"
+#include "commontraining.h"
+
+#include "absl/strings/numbers.h" // for safe_strto32
+#include "absl/strings/str_split.h" // for absl::StrSplit
+
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace tesseract;
+
+// Specs of the MockClassifier.
+static const int kNumTopNErrs = 10;
+static const int kNumTop2Errs = kNumTopNErrs + 20;
+static const int kNumTop1Errs = kNumTop2Errs + 30;
+static const int kNumTopTopErrs = kNumTop1Errs + 25;
+static const int kNumNonReject = 1000;
+static const int kNumCorrect = kNumNonReject - kNumTop1Errs;
+// The total number of answers is given by the number of non-rejects plus
+// all the multiple answers.
+static const int kNumAnswers = kNumNonReject + 2 * (kNumTop2Errs - kNumTopNErrs) +
+ (kNumTop1Errs - kNumTop2Errs) +
+ (kNumTopTopErrs - kNumTop1Errs);
+
+#ifndef DISABLED_LEGACY_ENGINE
+static bool safe_strto32(const std::string& str, int* pResult)
+{
+ long n = strtol(str.c_str(), nullptr, 0);
+ *pResult = n;
+ return true;
+}
+#endif
+
+// Mock ShapeClassifier that cheats by looking at the correct answer, and
+// creates a specific pattern of errors that can be tested.
+class MockClassifier : public ShapeClassifier {
+ public:
+ explicit MockClassifier(ShapeTable* shape_table)
+ : shape_table_(shape_table), num_done_(0), done_bad_font_(false) {
+ // Add a false font answer to the shape table. We pick a random unichar_id,
+ // add a new shape for it with a false font. Font must actually exist in
+ // the font table, but not match anything in the first 1000 samples.
+ false_unichar_id_ = 67;
+ false_shape_ = shape_table_->AddShape(false_unichar_id_, 25);
+ }
+ virtual ~MockClassifier() {}
+
+ // Classifies the given [training] sample, writing to results.
+ // If debug is non-zero, then various degrees of classifier dependent debug
+ // information is provided.
+ // If keep_this (a shape index) is >= 0, then the results should always
+ // contain keep_this, and (if possible) anything of intermediate confidence.
+ // The return value is the number of classes saved in results.
+ int ClassifySample(const TrainingSample& sample, Pix* page_pix,
+ int debug, UNICHAR_ID keep_this,
+ std::vector<ShapeRating>* results) override {
+ results->clear();
+ // Everything except the first kNumNonReject is a reject.
+ if (++num_done_ > kNumNonReject) return 0;
+
+ int class_id = sample.class_id();
+ int font_id = sample.font_id();
+ int shape_id = shape_table_->FindShape(class_id, font_id);
+ // Get ids of some wrong answers.
+ int wrong_id1 = shape_id > 10 ? shape_id - 1 : shape_id + 1;
+ int wrong_id2 = shape_id > 10 ? shape_id - 2 : shape_id + 2;
+ if (num_done_ <= kNumTopNErrs) {
+ // The first kNumTopNErrs are top-n errors.
+ results->push_back(ShapeRating(wrong_id1, 1.0f));
+ } else if (num_done_ <= kNumTop2Errs) {
+ // The next kNumTop2Errs - kNumTopNErrs are top-2 errors.
+ results->push_back(ShapeRating(wrong_id1, 1.0f));
+ results->push_back(ShapeRating(wrong_id2, 0.875f));
+ results->push_back(ShapeRating(shape_id, 0.75f));
+ } else if (num_done_ <= kNumTop1Errs) {
+ // The next kNumTop1Errs - kNumTop2Errs are top-1 errors.
+ results->push_back(ShapeRating(wrong_id1, 1.0f));
+ results->push_back(ShapeRating(shape_id, 0.8f));
+ } else if (num_done_ <= kNumTopTopErrs) {
+ // The next kNumTopTopErrs - kNumTop1Errs are cases where the actual top
+ // is not correct, but do not count as a top-1 error because the rating
+ // is close enough to the top answer.
+ results->push_back(ShapeRating(wrong_id1, 1.0f));
+ results->push_back(ShapeRating(shape_id, 0.99f));
+ } else if (!done_bad_font_ && class_id == false_unichar_id_) {
+ // There is a single character with a bad font.
+ results->push_back(ShapeRating(false_shape_, 1.0f));
+ done_bad_font_ = true;
+ } else {
+ // Everything else is correct.
+ results->push_back(ShapeRating(shape_id, 1.0f));
+ }
+ return results->size();
+ }
+ // Provides access to the ShapeTable that this classifier works with.
+ const ShapeTable* GetShapeTable() const override { return shape_table_; }
+
+ private:
+ // Borrowed pointer to the ShapeTable.
+ ShapeTable* shape_table_;
+ // Unichar_id of a random character that occurs after the first 60 samples.
+ int false_unichar_id_;
+ // Shape index of prepared false answer for false_unichar_id.
+ int false_shape_;
+ // The number of classifications we have processed.
+ int num_done_;
+ // True after the false font has been emitted.
+ bool done_bad_font_;
+};
+
+const double kMin1lDistance = 0.25;
+
+// The fixture for testing Tesseract.
+class MasterTrainerTest : public testing::Test {
+#ifndef DISABLED_LEGACY_ENGINE
+ protected:
+ void SetUp() {
+ std::locale::global(std::locale(""));
+ file::MakeTmpdir();
+ }
+
+ std::string TestDataNameToPath(const std::string& name) {
+ return file::JoinPath(TESTING_DIR, name);
+ }
+ std::string TmpNameToPath(const std::string& name) {
+ return file::JoinPath(FLAGS_test_tmpdir, name);
+ }
+
+ MasterTrainerTest() {
+ shape_table_ = nullptr;
+ master_trainer_ = nullptr;
+ }
+ ~MasterTrainerTest() {
+ delete shape_table_;
+ }
+
+ // Initializes the master_trainer_ and shape_table_.
+ // if load_from_tmp, then reloads a master trainer that was saved by a
+ // previous call in which it was false.
+ void LoadMasterTrainer() {
+ FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
+ FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
+ FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
+ FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
+ std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
+ const char* argv[] = {tr_file_name.c_str()};
+ int argc = 1;
+ STRING file_prefix;
+ delete shape_table_;
+ shape_table_ = nullptr;
+ master_trainer_ =
+ LoadTrainingData(argc, argv, false, &shape_table_, &file_prefix);
+ EXPECT_TRUE(master_trainer_ != nullptr);
+ EXPECT_TRUE(shape_table_ != nullptr);
+ }
+
+ // EXPECTs that the distance between I and l in Arial is 0 and that the
+ // distance to 1 is significantly not 0.
+ void VerifyIl1() {
+ // Find the font id for Arial.
+ int font_id = master_trainer_->GetFontInfoId("Arial");
+ EXPECT_GE(font_id, 0);
+ // Track down the characters we are interested in.
+ int unichar_I = master_trainer_->unicharset().unichar_to_id("I");
+ EXPECT_GT(unichar_I, 0);
+ int unichar_l = master_trainer_->unicharset().unichar_to_id("l");
+ EXPECT_GT(unichar_l, 0);
+ int unichar_1 = master_trainer_->unicharset().unichar_to_id("1");
+ EXPECT_GT(unichar_1, 0);
+ // Now get the shape ids.
+ int shape_I = shape_table_->FindShape(unichar_I, font_id);
+ EXPECT_GE(shape_I, 0);
+ int shape_l = shape_table_->FindShape(unichar_l, font_id);
+ EXPECT_GE(shape_l, 0);
+ int shape_1 = shape_table_->FindShape(unichar_1, font_id);
+ EXPECT_GE(shape_1, 0);
+
+ float dist_I_l =
+ master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_l);
+ // No tolerance here. We expect that I and l should match exactly.
+ EXPECT_EQ(0.0f, dist_I_l);
+ float dist_l_I =
+ master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_I);
+ // BOTH ways.
+ EXPECT_EQ(0.0f, dist_l_I);
+
+ // l/1 on the other hand should be distinct.
+ float dist_l_1 =
+ master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_1);
+ EXPECT_GT(dist_l_1, kMin1lDistance);
+ float dist_1_l =
+ master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_l);
+ EXPECT_GT(dist_1_l, kMin1lDistance);
+
+ // So should I/1.
+ float dist_I_1 =
+ master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_1);
+ EXPECT_GT(dist_I_1, kMin1lDistance);
+ float dist_1_I =
+ master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_I);
+ EXPECT_GT(dist_1_I, kMin1lDistance);
+ }
+
+ // Objects declared here can be used by all tests in the test case for Foo.
+ ShapeTable* shape_table_;
+ std::unique_ptr<MasterTrainer> master_trainer_;
+#endif
+};
+
+// Tests that the MasterTrainer correctly loads its data and reaches the correct
+// conclusion over the distance between Arial I l and 1.
+TEST_F(MasterTrainerTest, Il1Test) {
+#ifdef DISABLED_LEGACY_ENGINE
+ // Skip test because LoadTrainingData is missing.
+ GTEST_SKIP();
+#else
+ // Initialize the master_trainer_ and load the Arial tr file.
+ LoadMasterTrainer();
+ VerifyIl1();
+#endif
+}
+
+// Tests the ErrorCounter using a MockClassifier to check that it counts
+// error categories correctly.
+TEST_F(MasterTrainerTest, ErrorCounterTest) {
+#ifdef DISABLED_LEGACY_ENGINE
+ // Skip test because LoadTrainingData is missing.
+ GTEST_SKIP();
+#else
+ // Initialize the master_trainer_ from the saved tmp file.
+ LoadMasterTrainer();
+ // Add the space character to the shape_table_ if not already present to
+ // count junk.
+ if (shape_table_->FindShape(0, -1) < 0) shape_table_->AddShape(0, 0);
+ // Make a mock classifier.
+ auto shape_classifier = std::make_unique<MockClassifier>(shape_table_);
+ // Get the accuracy report.
+ STRING accuracy_report;
+ master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0,
+ false, shape_classifier.get(),
+ &accuracy_report);
+ LOG(INFO) << accuracy_report.c_str();
+ std::string result_string = accuracy_report.c_str();
+ std::vector<std::string> results =
+ absl::StrSplit(result_string, '\t', absl::SkipEmpty());
+ EXPECT_EQ(tesseract::CT_SIZE + 1, results.size());
+ int result_values[tesseract::CT_SIZE];
+ for (int i = 0; i < tesseract::CT_SIZE; ++i) {
+ EXPECT_TRUE(safe_strto32(results[i + 1], &result_values[i]));
+ }
+ // These tests are more-or-less immune to additions to the number of
+ // categories or changes in the training data.
+ int num_samples = master_trainer_->GetSamples()->num_raw_samples();
+ EXPECT_EQ(kNumCorrect, result_values[tesseract::CT_UNICHAR_TOP_OK]);
+ EXPECT_EQ(1, result_values[tesseract::CT_FONT_ATTR_ERR]);
+ EXPECT_EQ(kNumTopTopErrs, result_values[tesseract::CT_UNICHAR_TOPTOP_ERR]);
+ EXPECT_EQ(kNumTop1Errs, result_values[tesseract::CT_UNICHAR_TOP1_ERR]);
+ EXPECT_EQ(kNumTop2Errs, result_values[tesseract::CT_UNICHAR_TOP2_ERR]);
+ EXPECT_EQ(kNumTopNErrs, result_values[tesseract::CT_UNICHAR_TOPN_ERR]);
+ // Each of the TOPTOP errs also counts as a multi-unichar.
+ EXPECT_EQ(kNumTopTopErrs - kNumTop1Errs,
+ result_values[tesseract::CT_OK_MULTI_UNICHAR]);
+ EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
+ EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
+#endif
+}