summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'tesseract/src/classify/shapeclassifier.cpp')
-rw-r--r--tesseract/src/classify/shapeclassifier.cpp234
1 files changed, 234 insertions, 0 deletions
diff --git a/tesseract/src/classify/shapeclassifier.cpp b/tesseract/src/classify/shapeclassifier.cpp
new file mode 100644
index 00000000..b1091a53
--- /dev/null
+++ b/tesseract/src/classify/shapeclassifier.cpp
@@ -0,0 +1,234 @@
+// Copyright 2011 Google Inc. All Rights Reserved.
+// Author: rays@google.com (Ray Smith)
+///////////////////////////////////////////////////////////////////////
+// File: shapeclassifier.cpp
+// Description: Base interface class for classifiers that return a
+// shape index.
+// Author: Ray Smith
+//
+// (C) Copyright 2011, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+///////////////////////////////////////////////////////////////////////
+
+#ifdef HAVE_CONFIG_H
+#include "config_auto.h"
+#endif
+
+#include "shapeclassifier.h"
+
+#include "scrollview.h"
+#include "shapetable.h"
+#include "svmnode.h"
+#include "trainingsample.h"
+#include "tprintf.h"
+
+#include "genericvector.h"
+
+namespace tesseract {
+
+// Classifies the given [training] sample, writing to results.
+// See shapeclassifier.h for a full description.
+// Default implementation calls the ShapeRating version.
+int ShapeClassifier::UnicharClassifySample(
+ const TrainingSample& sample, Pix* page_pix, int debug,
+ UNICHAR_ID keep_this, std::vector<UnicharRating>* results) {
+ results->clear();
+ std::vector<ShapeRating> shape_results;
+ int num_shape_results = ClassifySample(sample, page_pix, debug, keep_this,
+ &shape_results);
+ const ShapeTable* shapes = GetShapeTable();
+ GenericVector<int> unichar_map;
+ unichar_map.init_to_size(shapes->unicharset().size(), -1);
+ for (int r = 0; r < num_shape_results; ++r) {
+ shapes->AddShapeToResults(shape_results[r], &unichar_map, results);
+ }
+ return results->size();
+}
+
+// Classifies the given [training] sample, writing to results.
+// See shapeclassifier.h for a full description.
+// Default implementation aborts.
+int ShapeClassifier::ClassifySample(const TrainingSample& sample, Pix* page_pix,
+ int debug, int keep_this,
+ std::vector<ShapeRating>* results) {
+ ASSERT_HOST("Must implement ClassifySample!" == nullptr);
+ return 0;
+}
+
+// Returns the shape that contains unichar_id that has the best result.
+// If result is not nullptr, it is set with the shape_id and rating.
+// Does not need to be overridden if ClassifySample respects the keep_this
+// rule.
+int ShapeClassifier::BestShapeForUnichar(const TrainingSample& sample,
+ Pix* page_pix, UNICHAR_ID unichar_id,
+ ShapeRating* result) {
+ std::vector<ShapeRating> results;
+ const ShapeTable* shapes = GetShapeTable();
+ int num_results = ClassifySample(sample, page_pix, 0, unichar_id, &results);
+ for (int r = 0; r < num_results; ++r) {
+ if (shapes->GetShape(results[r].shape_id).ContainsUnichar(unichar_id)) {
+ if (result != nullptr)
+ *result = results[r];
+ return results[r].shape_id;
+ }
+ }
+ return -1;
+}
+
+// Provides access to the UNICHARSET that this classifier works with.
+// Only needs to be overridden if GetShapeTable() can return nullptr.
+const UNICHARSET& ShapeClassifier::GetUnicharset() const {
+ return GetShapeTable()->unicharset();
+}
+
+#ifndef GRAPHICS_DISABLED
+
+// Visual debugger classifies the given sample, displays the results and
+// solicits user input to display other classifications. Returns when
+// the user has finished with debugging the sample.
+// Probably doesn't need to be overridden if the subclass provides
+// DisplayClassifyAs.
+void ShapeClassifier::DebugDisplay(const TrainingSample& sample,
+ Pix* page_pix,
+ UNICHAR_ID unichar_id) {
+ static ScrollView* terminator = nullptr;
+ if (terminator == nullptr) {
+ terminator = new ScrollView("XIT", 0, 0, 50, 50, 50, 50, true);
+ }
+ ScrollView* debug_win = CreateFeatureSpaceWindow("ClassifierDebug", 0, 0);
+ // Provide a right-click menu to choose the class.
+ auto* popup_menu = new SVMenuNode();
+ popup_menu->AddChild("Choose class to debug", 0, "x", "Class to debug");
+ popup_menu->BuildMenu(debug_win, false);
+ // Display the features in green.
+ const INT_FEATURE_STRUCT* features = sample.features();
+ uint32_t num_features = sample.num_features();
+ for (uint32_t f = 0; f < num_features; ++f) {
+ RenderIntFeature(debug_win, &features[f], ScrollView::GREEN);
+ }
+ debug_win->Update();
+ std::vector<UnicharRating> results;
+ // Debug classification until the user quits.
+ const UNICHARSET& unicharset = GetUnicharset();
+ SVEvent* ev;
+ SVEventType ev_type;
+ do {
+ PointerVector<ScrollView> windows;
+ if (unichar_id >= 0) {
+ tprintf("Debugging class %d = %s\n",
+ unichar_id, unicharset.id_to_unichar(unichar_id));
+ UnicharClassifySample(sample, page_pix, 1, unichar_id, &results);
+ DisplayClassifyAs(sample, page_pix, unichar_id, 1, &windows);
+ } else {
+ tprintf("Invalid unichar_id: %d\n", unichar_id);
+ UnicharClassifySample(sample, page_pix, 1, -1, &results);
+ }
+ if (unichar_id >= 0) {
+ tprintf("Debugged class %d = %s\n",
+ unichar_id, unicharset.id_to_unichar(unichar_id));
+ }
+ tprintf("Right-click in ClassifierDebug window to choose debug class,");
+ tprintf(" Left-click or close window to quit...\n");
+ UNICHAR_ID old_unichar_id;
+ do {
+ old_unichar_id = unichar_id;
+ ev = debug_win->AwaitEvent(SVET_ANY);
+ ev_type = ev->type;
+ if (ev_type == SVET_POPUP) {
+ if (unicharset.contains_unichar(ev->parameter)) {
+ unichar_id = unicharset.unichar_to_id(ev->parameter);
+ } else {
+ tprintf("Char class '%s' not found in unicharset", ev->parameter);
+ }
+ }
+ delete ev;
+ } while (unichar_id == old_unichar_id &&
+ ev_type != SVET_CLICK && ev_type != SVET_DESTROY);
+ } while (ev_type != SVET_CLICK && ev_type != SVET_DESTROY);
+ delete debug_win;
+}
+
+#endif // !GRAPHICS_DISABLED
+
+// Displays classification as the given shape_id. Creates as many windows
+// as it feels fit, using index as a guide for placement. Adds any created
+// windows to the windows output and returns a new index that may be used
+// by any subsequent classifiers. Caller waits for the user to view and
+// then destroys the windows by clearing the vector.
+int ShapeClassifier::DisplayClassifyAs(
+ const TrainingSample& sample, Pix* page_pix,
+ UNICHAR_ID unichar_id, int index,
+ PointerVector<ScrollView>* windows) {
+ // Does nothing in the default implementation.
+ return index;
+}
+
+// Prints debug information on the results.
+void ShapeClassifier::UnicharPrintResults(
+ const char* context, const std::vector<UnicharRating>& results) const {
+ tprintf("%s\n", context);
+ for (int i = 0; i < results.size(); ++i) {
+ tprintf("%g: c_id=%d=%s", results[i].rating, results[i].unichar_id,
+ GetUnicharset().id_to_unichar(results[i].unichar_id));
+ if (!results[i].fonts.empty()) {
+ tprintf(" Font Vector:");
+ for (int f = 0; f < results[i].fonts.size(); ++f) {
+ tprintf(" %d", results[i].fonts[f].fontinfo_id);
+ }
+ }
+ tprintf("\n");
+ }
+}
+void ShapeClassifier::PrintResults(
+ const char* context, const std::vector<ShapeRating>& results) const {
+ tprintf("%s\n", context);
+ for (int i = 0; i < results.size(); ++i) {
+ tprintf("%g:", results[i].rating);
+ if (results[i].joined)
+ tprintf("[J]");
+ if (results[i].broken)
+ tprintf("[B]");
+ tprintf(" %s\n", GetShapeTable()->DebugStr(results[i].shape_id).c_str());
+ }
+}
+
+// Removes any result that has all its unichars covered by a better choice,
+// regardless of font.
+void ShapeClassifier::FilterDuplicateUnichars(
+ std::vector<ShapeRating>* results) const {
+ std::vector<ShapeRating> filtered_results;
+ // Copy results to filtered results and knock out duplicate unichars.
+ const ShapeTable* shapes = GetShapeTable();
+ for (int r = 0; r < results->size(); ++r) {
+ if (r > 0) {
+ const Shape& shape_r = shapes->GetShape((*results)[r].shape_id);
+ int c;
+ for (c = 0; c < shape_r.size(); ++c) {
+ int unichar_id = shape_r[c].unichar_id;
+ int s;
+ for (s = 0; s < r; ++s) {
+ const Shape& shape_s = shapes->GetShape((*results)[s].shape_id);
+ if (shape_s.ContainsUnichar(unichar_id))
+ break; // We found unichar_id.
+ }
+ if (s == r)
+ break; // We didn't find unichar_id.
+ }
+ if (c == shape_r.size())
+ continue; // We found all the unichar ids in previous answers.
+ }
+ filtered_results.push_back((*results)[r]);
+ }
+ *results = filtered_results;
+}
+
+} // namespace tesseract.