summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'tesseract/src/lstm/stridemap.h')
-rw-r--r--tesseract/src/lstm/stridemap.h135
1 files changed, 135 insertions, 0 deletions
diff --git a/tesseract/src/lstm/stridemap.h b/tesseract/src/lstm/stridemap.h
new file mode 100644
index 00000000..54369aff
--- /dev/null
+++ b/tesseract/src/lstm/stridemap.h
@@ -0,0 +1,135 @@
+///////////////////////////////////////////////////////////////////////
+// File: stridemap.h
+// Description: Indexing into a 4-d tensor held in a 2-d Array.
+// Author: Ray Smith
+//
+// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+#ifndef TESSERACT_LSTM_STRIDEMAP_H_
+#define TESSERACT_LSTM_STRIDEMAP_H_
+
+#include <cstring>
+#include <vector>
+
+namespace tesseract {
+
+// Enum describing the dimensions of the 'Tensor' in a NetworkIO.
+// A NetworkIO is analogous to a TF Tensor, except that the number of dimensions
+// is fixed (4), and they always have the same meaning. The underlying
+// representation is a 2-D array, for which the product batch*height*width
+// is always dim1 and depth is always dim2. FlexDimensions is used only for
+// batch, height, width with the StrideMap, and therefore represents the runtime
+// shape. The build-time shape is defined by StaticShape.
+enum FlexDimensions {
+ FD_BATCH, // Index of multiple images.
+ FD_HEIGHT, // y-coordinate in image.
+ FD_WIDTH, // x-coordinate in image.
+ FD_DIMSIZE, // Number of flexible non-depth dimensions.
+};
+
+// Encapsulation of information relating to the mapping from [batch][y][x] to
+// the first index into the 2-d array underlying a NetworkIO.
+class StrideMap {
+ public:
+ // Class holding the non-depth indices.
+ class Index {
+ public:
+ explicit Index(const StrideMap& stride_map) : stride_map_(&stride_map) {
+ InitToFirst();
+ }
+ Index(const StrideMap& stride_map, int batch, int y, int x)
+ : stride_map_(&stride_map) {
+ indices_[FD_BATCH] = batch;
+ indices_[FD_HEIGHT] = y;
+ indices_[FD_WIDTH] = x;
+ SetTFromIndices();
+ }
+ // Accesses the index to the underlying array.
+ int t() const { return t_; }
+ int index(FlexDimensions dimension) const { return indices_[dimension]; }
+ // Initializes the indices to the first valid location.
+ void InitToFirst() {
+ memset(indices_, 0, sizeof(indices_));
+ t_ = 0;
+ }
+ // Initializes the indices to the last valid location.
+ void InitToLast() { InitToLastOfBatch(MaxIndexOfDim(FD_BATCH)); }
+ // Returns true if *this is a valid index.
+ bool IsValid() const;
+ // Returns true if the index of the given dimension is the last.
+ bool IsLast(FlexDimensions dimension) const;
+ // Given that the dimensions up to and including dim-1 are valid, returns
+ // the maximum index for dimension dim.
+ int MaxIndexOfDim(FlexDimensions dim) const;
+ // Adds the given offset to the given dimension. Returns true if the result
+ // makes a valid index.
+ bool AddOffset(int offset, FlexDimensions dimension);
+ // Increments the index in some encapsulated way that guarantees to remain
+ // valid until it returns false, meaning that the iteration is complete.
+ bool Increment();
+ // Decrements the index in some encapsulated way that guarantees to remain
+ // valid until it returns false, meaning that the iteration (that started
+ // with InitToLast()) is complete.
+ bool Decrement();
+
+ private:
+ // Initializes the indices to the last valid location in the given batch
+ // index.
+ void InitToLastOfBatch(int batch);
+ // Computes and sets t_ from the current indices_.
+ void SetTFromIndices();
+
+ // Map into which *this is an index.
+ const StrideMap* stride_map_;
+ // Index to the first dimension of the underlying array.
+ int t_;
+ // Indices into the individual dimensions.
+ int indices_[FD_DIMSIZE];
+ };
+
+ StrideMap() {
+ memset(shape_, 0, sizeof(shape_));
+ memset(t_increments_, 0, sizeof(t_increments_));
+ }
+ // Default copy constructor and operator= are OK to use here!
+
+ // Sets up the stride for the given array of height, width pairs.
+ void SetStride(const std::vector<std::pair<int, int>>& h_w_pairs);
+ // Scales width and height dimensions by the given factors.
+ void ScaleXY(int x_factor, int y_factor);
+ // Reduces width to 1, across the batch, whatever the input size.
+ void ReduceWidthTo1();
+ // Transposes the width and height dimensions.
+ void TransposeXY();
+ // Returns the size of the given dimension.
+ int Size(FlexDimensions dimension) const { return shape_[dimension]; }
+ // Returns the total width required.
+ int Width() const { return t_increments_[FD_BATCH] * shape_[FD_BATCH]; }
+
+ private:
+ // Computes t_increments_ from shape_.
+ void ComputeTIncrements();
+
+ // The size of each non-depth dimension.
+ int shape_[FD_DIMSIZE];
+ // Precomputed 't' increments for each dimension. This is the value of
+ // the given dimension in the packed 3-d array that the shape_ represents.
+ int t_increments_[FD_DIMSIZE];
+ // Vector of size shape_[FD_BATCH] holds the height of each image in a batch.
+ std::vector<int> heights_;
+ // Vector of size shape_[FD_BATCH] holds the width of each image in a batch.
+ std::vector<int> widths_;
+};
+
+} // namespace tesseract
+
+#endif // TESSERACT_LSTM_STRIDEMAP_H_