diff options
Diffstat (limited to 'tesseract/src/lstm/stridemap.h')
-rw-r--r-- | tesseract/src/lstm/stridemap.h | 135 |
1 files changed, 135 insertions, 0 deletions
diff --git a/tesseract/src/lstm/stridemap.h b/tesseract/src/lstm/stridemap.h new file mode 100644 index 00000000..54369aff --- /dev/null +++ b/tesseract/src/lstm/stridemap.h @@ -0,0 +1,135 @@ +/////////////////////////////////////////////////////////////////////// +// File: stridemap.h +// Description: Indexing into a 4-d tensor held in a 2-d Array. +// Author: Ray Smith +// +// (C) Copyright 2016, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////// +#ifndef TESSERACT_LSTM_STRIDEMAP_H_ +#define TESSERACT_LSTM_STRIDEMAP_H_ + +#include <cstring> +#include <vector> + +namespace tesseract { + +// Enum describing the dimensions of the 'Tensor' in a NetworkIO. +// A NetworkIO is analogous to a TF Tensor, except that the number of dimensions +// is fixed (4), and they always have the same meaning. The underlying +// representation is a 2-D array, for which the product batch*height*width +// is always dim1 and depth is always dim2. FlexDimensions is used only for +// batch, height, width with the StrideMap, and therefore represents the runtime +// shape. The build-time shape is defined by StaticShape. +enum FlexDimensions { + FD_BATCH, // Index of multiple images. + FD_HEIGHT, // y-coordinate in image. + FD_WIDTH, // x-coordinate in image. + FD_DIMSIZE, // Number of flexible non-depth dimensions. +}; + +// Encapsulation of information relating to the mapping from [batch][y][x] to +// the first index into the 2-d array underlying a NetworkIO. +class StrideMap { + public: + // Class holding the non-depth indices. + class Index { + public: + explicit Index(const StrideMap& stride_map) : stride_map_(&stride_map) { + InitToFirst(); + } + Index(const StrideMap& stride_map, int batch, int y, int x) + : stride_map_(&stride_map) { + indices_[FD_BATCH] = batch; + indices_[FD_HEIGHT] = y; + indices_[FD_WIDTH] = x; + SetTFromIndices(); + } + // Accesses the index to the underlying array. + int t() const { return t_; } + int index(FlexDimensions dimension) const { return indices_[dimension]; } + // Initializes the indices to the first valid location. + void InitToFirst() { + memset(indices_, 0, sizeof(indices_)); + t_ = 0; + } + // Initializes the indices to the last valid location. + void InitToLast() { InitToLastOfBatch(MaxIndexOfDim(FD_BATCH)); } + // Returns true if *this is a valid index. + bool IsValid() const; + // Returns true if the index of the given dimension is the last. + bool IsLast(FlexDimensions dimension) const; + // Given that the dimensions up to and including dim-1 are valid, returns + // the maximum index for dimension dim. + int MaxIndexOfDim(FlexDimensions dim) const; + // Adds the given offset to the given dimension. Returns true if the result + // makes a valid index. + bool AddOffset(int offset, FlexDimensions dimension); + // Increments the index in some encapsulated way that guarantees to remain + // valid until it returns false, meaning that the iteration is complete. + bool Increment(); + // Decrements the index in some encapsulated way that guarantees to remain + // valid until it returns false, meaning that the iteration (that started + // with InitToLast()) is complete. + bool Decrement(); + + private: + // Initializes the indices to the last valid location in the given batch + // index. + void InitToLastOfBatch(int batch); + // Computes and sets t_ from the current indices_. + void SetTFromIndices(); + + // Map into which *this is an index. + const StrideMap* stride_map_; + // Index to the first dimension of the underlying array. + int t_; + // Indices into the individual dimensions. + int indices_[FD_DIMSIZE]; + }; + + StrideMap() { + memset(shape_, 0, sizeof(shape_)); + memset(t_increments_, 0, sizeof(t_increments_)); + } + // Default copy constructor and operator= are OK to use here! + + // Sets up the stride for the given array of height, width pairs. + void SetStride(const std::vector<std::pair<int, int>>& h_w_pairs); + // Scales width and height dimensions by the given factors. + void ScaleXY(int x_factor, int y_factor); + // Reduces width to 1, across the batch, whatever the input size. + void ReduceWidthTo1(); + // Transposes the width and height dimensions. + void TransposeXY(); + // Returns the size of the given dimension. + int Size(FlexDimensions dimension) const { return shape_[dimension]; } + // Returns the total width required. + int Width() const { return t_increments_[FD_BATCH] * shape_[FD_BATCH]; } + + private: + // Computes t_increments_ from shape_. + void ComputeTIncrements(); + + // The size of each non-depth dimension. + int shape_[FD_DIMSIZE]; + // Precomputed 't' increments for each dimension. This is the value of + // the given dimension in the packed 3-d array that the shape_ represents. + int t_increments_[FD_DIMSIZE]; + // Vector of size shape_[FD_BATCH] holds the height of each image in a batch. + std::vector<int> heights_; + // Vector of size shape_[FD_BATCH] holds the width of each image in a batch. + std::vector<int> widths_; +}; + +} // namespace tesseract + +#endif // TESSERACT_LSTM_STRIDEMAP_H_ |