[Mlir-commits] [mlir] fbe6113 - [mlir][sparse] refactored codegen environment into its own file

Aart Bik llvmlistbot at llvm.org
Tue Dec 20 16:59:07 PST 2022


Author: Aart Bik
Date: 2022-12-20T16:58:59-08:00
New Revision: fbe611309edf9ab8d9c26dcc82bc1ef38e99e637

URL: https://github.com/llvm/llvm-project/commit/fbe611309edf9ab8d9c26dcc82bc1ef38e99e637
DIFF: https://github.com/llvm/llvm-project/commit/fbe611309edf9ab8d9c26dcc82bc1ef38e99e637.diff

LOG: [mlir][sparse] refactored codegen environment into its own file

Also, as a proof of concept, all functionality related to reductions
has been refactored into private fields behind a clean public API. In
the process, some dead code was found and removed as well. This
approach also makes it easier to assert a proper environment state
for each call.

NOTE: making all other fields private and migrating more methods into
      this new class is still TBD in a follow-up revision!
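
For illustration, the reduction state now moves through a small
protocol. Roughly (eliding the surrounding codegen logic), the updated
call sites in Sparsification.cpp shown in the diff below have this
shape:

    env.startReduc(exp, load);        // begin a scalarized reduction
    if (env.isReduc())
      env.updateReduc(reduc[i++]);    // refresh the loop-carried value
    ...
    genTensorStore(env, builder, exp, env.endReduc()); // finish, consume value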

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140443

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index a44d2117f5891..cfa73bf246ecc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSparseTensorTransforms
   BufferizableOpInterfaceImpl.cpp
+  CodegenEnv.cpp
   CodegenUtils.cpp
   SparseBufferRewriting.cpp
   SparseTensorCodegen.cpp

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
new file mode 100644
index 0000000000000..52d6d5822e577
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -0,0 +1,69 @@
+//===- CodegenEnv.cpp - Code generation environment class -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenEnv.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+//===----------------------------------------------------------------------===//
+// Code generation environment constructor and setup
+//===----------------------------------------------------------------------===//
+
+CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
+                       unsigned numTensors, unsigned numLoops,
+                       unsigned numFilterLoops)
+    : linalgOp(linop), options(opts), topSort(),
+      merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
+      sparseOut(nullptr), redVal(nullptr), redExp(-1u), redCustom(-1u) {}
+
+void CodegenEnv::startEmit(SparseTensorLoopEmitter *le) {
+  assert(!loopEmitter && "must only start emitting once");
+  loopEmitter = le;
+  if (sparseOut) {
+    insChain = sparseOut->get();
+    merger.setHasSparseOut(true);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation environment methods
+//===----------------------------------------------------------------------===//
+
+void CodegenEnv::startReduc(unsigned exp, Value val) {
+  assert(redExp == -1u && exp != -1u);
+  redExp = exp;
+  updateReduc(val);
+}
+
+void CodegenEnv::updateReduc(Value val) {
+  assert(redExp != -1u);
+  redVal = exp(redExp).val = val;
+}
+
+Value CodegenEnv::endReduc() {
+  Value val = redVal;
+  updateReduc(Value());
+  redExp = -1u;
+  return val;
+}
+
+void CodegenEnv::startCustomReduc(unsigned exp) {
+  assert(redCustom == -1u && exp != -1u);
+  redCustom = exp;
+}
+
+Value CodegenEnv::getCustomRedId() {
+  assert(redCustom != -1u);
+  return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
+}
+
+void CodegenEnv::endCustomReduc() {
+  assert(redCustom != -1u);
+  redCustom = -1u;
+}
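
For context, -1u serves above as an "inactive" sentinel for both
redExp and redCustom. A standalone C++ sketch of this state-tracking
pattern (illustrative names and types, not code from the patch):

    #include <cassert>

    // Sketch: a reduction is "active" only between start() and end();
    // the kNone sentinel (-1u) mirrors redExp == -1u in CodegenEnv.
    class ReducState {
    public:
      void start(unsigned exp, int val) {
        assert(redExp == kNone && exp != kNone); // must not be active yet
        redExp = exp;
        update(val);
      }
      void update(int val) {
        assert(redExp != kNone); // only legal while active
        redVal = val;
      }
      int end() {
        int val = redVal; // consume the final value and deactivate
        redVal = 0;
        redExp = kNone;
        return val;
      }
      bool active() const { return redExp != kNone; }

    private:
      static constexpr unsigned kNone = -1u;
      unsigned redExp = kNone;
      int redVal = 0;
    };

    int main() {
      ReducState r;
      r.start(3, 0); // enter reduction for expression #3
      r.update(42);  // running value updated during loop emission
      assert(r.end() == 42 && !r.active());
    }

Funneling every state transition through one class is what makes the
misuse asserts mentioned in the log cheap to add.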

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
new file mode 100644
index 0000000000000..1b0362203f2fd
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -0,0 +1,136 @@
+//===- CodegenEnv.h - Code generation environment class ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines the code generation environment class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENENV_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENENV_H_
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+/// The code generation environment class aggregates a number of data
+/// structures that are needed during the code generation phase of
+/// sparsification. This environment simplifies passing around such
+/// data during sparsification (rather than passing around all the
+/// individual components where needed). Furthermore, it provides
+/// a number of delegate and convenience methods that keep some of the
+/// implementation details transparent to sparsification.
+class CodegenEnv {
+public:
+  CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
+             unsigned numTensors, unsigned numLoops, unsigned numFilterLoops);
+
+  // Start emitting.
+  void startEmit(SparseTensorLoopEmitter *le);
+
+  // Delegate methods to merger.
+  TensorExp &exp(unsigned e) { return merger.exp(e); }
+  LatPoint &lat(unsigned l) { return merger.lat(l); }
+  SmallVector<unsigned> &set(unsigned s) { return merger.set(s); }
+  DimLevelType dimLevelType(unsigned t, unsigned i) const {
+    return merger.getDimLevelType(t, i);
+  }
+  DimLevelType dimLevelType(unsigned b) const {
+    return merger.getDimLevelType(b);
+  }
+  bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); }
+
+  // Delegate methods to loop emitter.
+  Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); }
+  const std::vector<Value> &getValBuffer() const {
+    return loopEmitter->getValBuffer();
+  }
+
+  // Convenience method to slice topsort.
+  ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const {
+    return ArrayRef<unsigned>(topSort).slice(n, m);
+  }
+
+  // Convenience method to get current loop stack.
+  ArrayRef<unsigned> getLoopCurStack() const {
+    return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+  }
+
+  // Convenience method to get the IV of the given loop index.
+  Value getLoopIdxValue(size_t loopIdx) const {
+    for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
+      if (topSort[lv] == loopIdx)
+        return getLoopIV(lv);
+    llvm_unreachable("invalid loop index");
+  }
+
+  //
+  // Reductions.
+  //
+
+  void startReduc(unsigned exp, Value val);
+  void updateReduc(Value val);
+  bool isReduc() const { return redExp != -1u; }
+  Value getReduc() const { return redVal; }
+  Value endReduc();
+
+  void startCustomReduc(unsigned exp);
+  bool isCustomReduc() const { return redCustom != -1u; }
+  Value getCustomRedId();
+  void endCustomReduc();
+
+public:
+  //
+  // TODO make this section private too, using similar refactoring as for reduc
+  //
+
+  // Linalg operation.
+  linalg::GenericOp linalgOp;
+
+  // Sparsification options.
+  SparsificationOptions options;
+
+  // Topological sort.
+  std::vector<unsigned> topSort;
+
+  // Merger helper class.
+  Merger merger;
+
+  // Loop emitter helper class (keep reference in scope!).
+  // TODO: move emitter constructor up in time?
+  SparseTensorLoopEmitter *loopEmitter;
+
+  // Sparse tensor as output. Implemented either through direct injective
+  // insertion in lexicographic index order or through access pattern expansion
+  // in the innermost loop nest (`expValues` through `expCount`).
+  OpOperand *sparseOut;
+  unsigned outerParNest;
+  Value insChain; // bookkeeping for insertion chain
+  Value expValues;
+  Value expFilled;
+  Value expAdded;
+  Value expCount;
+
+private:
+  // Bookkeeping for reductions (up-to-date value of the reduction, and indices
+  // into the merger's expression tree). When the indices of a tensor reduction
+  // expression are exhausted, all inner loops can use a scalarized reduction.
+  Value redVal;
+  unsigned redExp;
+  unsigned redCustom;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENENV_H_
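
As an aside, getLoopIdxValue() above resolves a loop index by scanning
topSort, which maps emitted loop levels back to original loop indices.
A standalone sketch of that lookup (illustrative, not code from the
patch):

    #include <cassert>
    #include <vector>

    // Sketch: topSort[lv] holds the original loop index emitted at
    // level lv, so mapping an index back to its level is a linear scan.
    int findLevel(const std::vector<unsigned> &topSort, unsigned loopIdx) {
      for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
        if (topSort[lv] == loopIdx)
          return static_cast<int>(lv);
      return -1; // invalid loop index
    }

    int main() {
      std::vector<unsigned> topSort = {2, 0, 1}; // emission order of loops
      assert(findLevel(topSort, 0) == 1); // loop index 0 runs at level 1
      assert(findLevel(topSort, 3) == -1);
    }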

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3f67f66a93097..a6c031486c7e3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CodegenEnv.h"
 #include "CodegenUtils.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -36,7 +37,7 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// Declarations of data structures.
+// Declarations
 //===----------------------------------------------------------------------===//
 
 namespace {
@@ -49,100 +50,6 @@ enum SortMask {
   kIncludeAll = 0x3
 };
 
-/// Reduction kinds.
-enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
-
-/// Code generation environment. This structure aggregates a number
-/// of data structures needed during code generation. Such an environment
-/// simplifies passing around data during sparsification (rather than
-/// passing around all the individual compoments where needed).
-//
-// TODO: refactor further, move into own file
-//
-struct CodeGenEnv {
-  CodeGenEnv(linalg::GenericOp linop, SparsificationOptions opts,
-             unsigned numTensors, unsigned numLoops, unsigned numFilterLoops)
-      : linalgOp(linop), options(opts), topSort(),
-        merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
-        redExp(-1u), redKind(kNoReduc), redCustom(-1u), sparseOut(nullptr) {}
-
-  // Start emitting.
-  void startEmit(SparseTensorLoopEmitter *le) {
-    assert(!loopEmitter && "must only start emitting once");
-    loopEmitter = le;
-    if (sparseOut) {
-      insChain = sparseOut->get();
-      merger.setHasSparseOut(true);
-    }
-  }
-
-  // Delegate methods to merger.
-  TensorExp &exp(unsigned e) { return merger.exp(e); }
-  LatPoint &lat(unsigned l) { return merger.lat(l); }
-  SmallVector<unsigned> &set(unsigned s) { return merger.set(s); }
-  DimLevelType dimLevelType(unsigned t, unsigned i) const {
-    return merger.getDimLevelType(t, i);
-  }
-  DimLevelType dimLevelType(unsigned b) const {
-    return merger.getDimLevelType(b);
-  }
-  bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); }
-
-  // Delegate methods to loop emitter.
-  Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); }
-  const std::vector<Value> &getValBuffer() const {
-    return loopEmitter->getValBuffer();
-  }
-
-  // Convenience method to slice topsort.
-  ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const {
-    return ArrayRef<unsigned>(topSort).slice(n, m);
-  }
-
-  // Convenience method to get current loop stack.
-  ArrayRef<unsigned> getLoopCurStack() const {
-    return getTopSortSlice(0, loopEmitter->getCurrentDepth());
-  }
-
-  // Convenience method to get the IV of the given loop index.
-  Value getLoopIdxValue(size_t loopIdx) const {
-    for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
-      if (topSort[lv] == loopIdx)
-        return getLoopIV(lv);
-    llvm_unreachable("invalid loop index");
-  }
-
-  // TODO: make private
-
-  /// Linalg operation.
-  linalg::GenericOp linalgOp;
-  /// Sparsification options.
-  SparsificationOptions options;
-  // Topological sort.
-  std::vector<unsigned> topSort;
-  /// Merger helper class.
-  Merger merger;
-  /// Loop emitter helper class (keep reference in scope!).
-  /// TODO: move emitter constructor up in time?
-  SparseTensorLoopEmitter *loopEmitter;
-  /// Current reduction, updated during code generation. When indices of a
-  /// reduction are exhausted, all inner loops can use a scalarized reduction.
-  unsigned redExp;
-  Value redVal;
-  Reduction redKind;
-  unsigned redCustom;
-  /// Sparse tensor as output. Implemented either through direct injective
-  /// insertion in lexicographic index order or through access pattern expansion
-  /// in the innermost loop nest (`expValues` through `expCount`).
-  OpOperand *sparseOut;
-  unsigned outerParNest;
-  Value insChain; // bookkeeping for insertion chain
-  Value expValues;
-  Value expFilled;
-  Value expAdded;
-  Value expCount;
-};
-
 /// A helper class that visits an affine expression and tries to find an
 /// AffineDimExpr to which the corresponding iterator from a GenericOp matches
 /// the desired iterator type.
@@ -212,14 +119,14 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
 }
 
 /// Determines if affine expression is invariant.
-static bool isInvariantAffine(CodeGenEnv &env, AffineExpr a, unsigned ldx,
+static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, unsigned ldx,
                               bool &atLevel) {
   return isInvariantAffine(a, env.getLoopCurStack(), ldx, atLevel);
 }
 
 /// Helper method to construct a permuted dimension ordering
 /// that adheres to the given topological sort.
-static AffineMap permute(CodeGenEnv &env, AffineMap m) {
+static AffineMap permute(CodegenEnv &env, AffineMap m) {
   assert(m.getNumDims() + env.merger.getNumFilterLoops() ==
              env.topSort.size() &&
          "size mismatch");
@@ -346,7 +253,7 @@ static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
 /// Returns true if the sparse annotations and affine subscript
 /// expressions of all tensors are admissible. Returns false if
 /// no annotations are found or inadmissible constructs occur.
-static bool findSparseAnnotations(CodeGenEnv &env) {
+static bool findSparseAnnotations(CodegenEnv &env) {
   bool annotated = false;
   unsigned filterLdx = env.merger.getFilterLoopStartingIdx();
   for (OpOperand &t : env.linalgOp->getOpOperands()) {
@@ -371,7 +278,7 @@ static bool findSparseAnnotations(CodeGenEnv &env) {
 /// as we use adj matrix for the graph.
 /// The sorted result will put the first Reduction iterator to the
 /// latest possible index.
-static bool topSortOptimal(CodeGenEnv &env, unsigned n,
+static bool topSortOptimal(CodegenEnv &env, unsigned n,
                            ArrayRef<utils::IteratorType> iteratorTypes,
                            std::vector<unsigned> &inDegree,
                            std::vector<std::vector<bool>> &adjM) {
@@ -517,7 +424,7 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
 /// along fixed dimensions. Even for dense storage formats, however, the
 /// natural index order yields innermost unit-stride access with better
 /// spatial locality.
-static bool computeIterationGraph(CodeGenEnv &env, unsigned mask,
+static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
                                   OpOperand *skip = nullptr) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
@@ -614,12 +521,12 @@ static bool isMaterializing(Value val) {
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
-/// Returns true when the tensor expression is admissible for env.
+/// Returns true when the tensor expression is admissible for codegen.
 /// Since all sparse input tensors are admissible, we just need to check
 /// whether the out tensor in the tensor expression codegen is admissible.
 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
-static bool isAdmissibleTensorExp(CodeGenEnv &env, unsigned exp) {
+static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
   // We reject any expression that makes a reduction from `-outTensor`, as those
   // expressions create a dependency between the current iteration (i) and the
   // previous iteration (i-1). It would require iterating over the whole
@@ -693,48 +600,6 @@ static bool isAdmissibleTensorExp(CodeGenEnv &env, unsigned exp) {
   return false;
 }
 
-//===----------------------------------------------------------------------===//
-// Sparse compiler synthesis methods (reductions).
-//===----------------------------------------------------------------------===//
-
-/// Maps operation to reduction.
-static Reduction getReduction(Kind kind) {
-  switch (kind) {
-  case Kind::kAddF:
-  case Kind::kAddC:
-  case Kind::kAddI:
-  case Kind::kSubF:
-  case Kind::kSubC:
-  case Kind::kSubI:
-    return kSum;
-  case Kind::kMulF:
-  case Kind::kMulC:
-  case Kind::kMulI:
-    return kProduct;
-  case Kind::kAndI:
-    return kAnd;
-  case Kind::kOrI:
-    return kOr;
-  case Kind::kXorI:
-    return kXor;
-  case Kind::kReduce:
-    return kCustom;
-  default:
-    llvm_unreachable("unexpected reduction operator");
-  }
-}
-
-/// Updates scalarized reduction value.
-static void updateReduc(CodeGenEnv &env, Value reduc) {
-  assert(env.redKind != kNoReduc);
-  env.redVal = env.exp(env.redExp).val = reduc;
-}
-
-/// Extracts identity from custom reduce.
-static Value getCustomRedId(Operation *op) {
-  return dyn_cast<sparse_tensor::ReduceOp>(op).getIdentity();
-}
-
 //===----------------------------------------------------------------------===//
 // Sparse compiler synthesis methods (statements and expressions).
 //===----------------------------------------------------------------------===//
@@ -742,12 +607,12 @@ static Value getCustomRedId(Operation *op) {
 /// Generates loop boundary statements (entering/exiting loops). The function
 /// passes and updates the reduction value.
 static Optional<Operation *> genLoopBoundary(
-    CodeGenEnv &env,
+    CodegenEnv &env,
     function_ref<Optional<Operation *>(MutableArrayRef<Value> reduc)>
         callback) {
   SmallVector<Value> reduc;
-  if (env.redVal)
-    reduc.push_back(env.redVal);
+  if (env.isReduc())
+    reduc.push_back(env.getReduc());
   if (env.expValues)
     reduc.push_back(env.expCount);
   if (env.insChain)
@@ -757,8 +622,8 @@ static Optional<Operation *> genLoopBoundary(
 
   // Callback should do in-place update on reduction value vector.
   unsigned i = 0;
-  if (env.redVal)
-    updateReduc(env, reduc[i++]);
+  if (env.isReduc())
+    env.updateReduc(reduc[i++]);
   if (env.expValues)
     env.expCount = reduc[i++];
   if (env.insChain)
@@ -768,7 +633,7 @@ static Optional<Operation *> genLoopBoundary(
 }
 
 /// Local bufferization of all dense and sparse data structures.
-static void genBuffers(CodeGenEnv &env, OpBuilder &builder) {
+static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
   linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
@@ -810,7 +675,7 @@ static void genBuffers(CodeGenEnv &env, OpBuilder &builder) {
 }
 
 /// Generates index for load/store on sparse tensor.
-static Value genIndex(CodeGenEnv &env, OpOperand *t) {
+static Value genIndex(CodegenEnv &env, OpOperand *t) {
   auto map = env.linalgOp.getMatchingIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
   AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
@@ -820,7 +685,7 @@ static Value genIndex(CodeGenEnv &env, OpOperand *t) {
 }
 
 /// Generates subscript for load/store on a dense or sparse tensor.
-static Value genSubscript(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
+static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                           SmallVectorImpl<Value> &args) {
   linalg::GenericOp op = env.linalgOp;
   unsigned tensor = t->getOperandNumber();
@@ -841,7 +706,7 @@ static Value genSubscript(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
 }
 
 /// Generates insertion code to implement dynamic tensor load.
-static Value genInsertionLoad(CodeGenEnv &env, OpBuilder &builder,
+static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                               OpOperand *t) {
   linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
@@ -856,15 +721,14 @@ static Value genInsertionLoad(CodeGenEnv &env, OpBuilder &builder,
 }
 
 /// Generates insertion code to implement dynamic tensor load for reduction.
-static Value genInsertionLoadReduce(CodeGenEnv &env, OpBuilder &builder,
+static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                     OpOperand *t) {
   linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
-  Value identity = getCustomRedId(env.exp(env.redCustom).op);
+  Value identity = env.getCustomRedId();
   // Direct lexicographic index order, tensor loads as identity.
-  if (!env.expValues) {
+  if (!env.expValues)
     return identity;
-  }
   // Load from expanded access pattern if filled, identity otherwise.
   Value index = genIndex(env, t);
   Value isFilled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
@@ -873,7 +737,7 @@ static Value genInsertionLoadReduce(CodeGenEnv &env, OpBuilder &builder,
 }
 
 /// Generates insertion code to implement dynamic tensor store.
-static void genInsertionStore(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
+static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
   linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
@@ -920,7 +784,7 @@ static void genInsertionStore(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
 }
 
 /// Generates a load on a dense or sparse tensor.
-static Value genTensorLoad(CodeGenEnv &env, OpBuilder &builder, unsigned exp) {
+static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, unsigned exp) {
   // Test if the load was hoisted to a higher loop nest.
   Value val = env.exp(exp).val;
   if (val)
@@ -930,7 +794,7 @@ static Value genTensorLoad(CodeGenEnv &env, OpBuilder &builder, unsigned exp) {
   linalg::GenericOp op = env.linalgOp;
   OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
   if (&t == env.sparseOut) {
-    if (env.redCustom != -1u)
+    if (env.isCustomReduc())
       return genInsertionLoadReduce(env, builder, &t);
     return genInsertionLoad(env, builder, &t);
   }
@@ -941,13 +805,13 @@ static Value genTensorLoad(CodeGenEnv &env, OpBuilder &builder, unsigned exp) {
 }
 
 /// Generates a store on a dense or sparse tensor.
-static void genTensorStore(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                            Value rhs) {
   linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
   // Test if this is a scalarized reduction.
-  if (env.redVal) {
-    updateReduc(env, rhs);
+  if (env.isReduc()) {
+    env.updateReduc(rhs);
     return;
   }
   // Store during insertion.
@@ -989,12 +853,12 @@ static void genTensorStore(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
 }
 
 /// Generates an invariant value.
-inline static Value genInvariantValue(CodeGenEnv &env, unsigned exp) {
+inline static Value genInvariantValue(CodegenEnv &env, unsigned exp) {
   return env.exp(exp).val;
 }
 
 /// Generates an index value.
-inline static Value genIndexValue(CodeGenEnv &env, unsigned idx) {
+inline static Value genIndexValue(CodegenEnv &env, unsigned idx) {
   return env.getLoopIdxValue(idx);
 }
 
@@ -1003,7 +867,7 @@ inline static Value genIndexValue(CodeGenEnv &env, unsigned idx) {
 /// branch or otherwise invariantly defined outside the loop nest, with the
 /// exception of index computations, which need to be relinked to actual
 /// inlined cloned code.
-static Value relinkBranch(CodeGenEnv &env, RewriterBase &rewriter, Block *block,
+static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                           Value e, unsigned ldx) {
   if (Operation *def = e.getDefiningOp()) {
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
@@ -1018,7 +882,7 @@ static Value relinkBranch(CodeGenEnv &env, RewriterBase &rewriter, Block *block,
 }
 
 /// Recursively generates tensor expression.
-static Value genExp(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
                     unsigned ldx) {
   linalg::GenericOp op = env.linalgOp;
   Location loc = op.getLoc();
@@ -1032,11 +896,8 @@ static Value genExp(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
   if (env.exp(exp).kind == Kind::kIndex)
     return genIndexValue(env, env.exp(exp).index);
 
-  // Make custom reduction identity accessible for expanded access pattern.
-  if (env.exp(exp).kind == Kind::kReduce) {
-    assert(env.redCustom == -1u);
-    env.redCustom = exp;
-  }
+  if (env.exp(exp).kind == Kind::kReduce)
+    env.startCustomReduc(exp); // enter custom
 
   Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx);
   Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx);
@@ -1048,20 +909,20 @@ static Value genExp(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
              env.exp(exp).kind == Kind::kSelect))
     ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
 
+  if (env.exp(exp).kind == Kind::kReduce)
+    env.endCustomReduc(); // exit custom
+
   if (env.exp(exp).kind == kSelect) {
     assert(!env.exp(exp).val);
     env.exp(exp).val = v0; // Preserve value for later use.
-  } else if (env.exp(exp).kind == Kind::kReduce) {
-    assert(env.redCustom != -1u);
-    env.redCustom = -1u;
   }
 
   return ee;
 }
 
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
-static void genInvariants(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
-                          unsigned ldx, bool atStart, unsigned last = -1u) {
+static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
+                          unsigned ldx, bool atStart) {
   if (exp == -1u)
     return;
   if (env.exp(exp).kind == Kind::kTensor) {
@@ -1090,18 +951,11 @@ static void genInvariants(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
     if (lhs == &t) {
       // Start or end a scalarized reduction
       if (atStart) {
-        Kind kind = env.exp(last).kind;
-        Value load = kind == Kind::kReduce ? getCustomRedId(env.exp(last).op)
-                                           : genTensorLoad(env, builder, exp);
-        env.redKind = getReduction(kind);
-        env.redExp = exp;
-        updateReduc(env, load);
+        Value load = env.isCustomReduc() ? env.getCustomRedId()
+                                         : genTensorLoad(env, builder, exp);
+        env.startReduc(exp, load);
       } else {
-        Value redVal = env.redVal;
-        updateReduc(env, Value());
-        env.redExp = -1u;
-        env.redKind = kNoReduc;
-        genTensorStore(env, builder, exp, redVal);
+        genTensorStore(env, builder, exp, env.endReduc());
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
@@ -1112,21 +966,25 @@ static void genInvariants(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
+    if (env.exp(exp).kind == Kind::kReduce)
+      env.startCustomReduc(exp); // enter custom
     unsigned e0 = env.exp(exp).children.e0;
     unsigned e1 = env.exp(exp).children.e1;
-    genInvariants(env, builder, e0, ldx, atStart, exp);
-    genInvariants(env, builder, e1, ldx, atStart, exp);
+    genInvariants(env, builder, e0, ldx, atStart);
+    genInvariants(env, builder, e1, ldx, atStart);
+    if (env.exp(exp).kind == Kind::kReduce)
+      env.endCustomReduc(); // exit custom
   }
 }
 
 /// Generates an expanded access pattern in innermost dimension.
-static void genExpansion(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
                          bool atStart) {
   linalg::GenericOp op = env.linalgOp;
   OpOperand *lhs = env.sparseOut;
   if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest)
     return; // not needed at this level
-  assert(env.redVal == nullptr);
+  assert(!env.isReduc());
   // Generate start or end of an expanded access pattern. Note that because
   // an expansion does not rely on the ongoing contents of the sparse storage
   // scheme, we can use the original tensor as incoming SSA value (which
@@ -1166,7 +1024,7 @@ static void genExpansion(CodeGenEnv &env, OpBuilder &builder, unsigned at,
 /// Returns parallelization strategy. Any implicit loop in the Linalg
 /// operation that is marked "parallel" is a candidate. Whether it is actually
 /// converted to a parallel operation depends on the requested strategy.
-static bool isParallelFor(CodeGenEnv &env, bool isOuter, bool isSparse) {
+static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
   // Reject parallelization of sparse output.
   if (env.sparseOut)
     return false;
@@ -1190,7 +1048,7 @@ static bool isParallelFor(CodeGenEnv &env, bool isOuter, bool isSparse) {
 }
 
 /// Generates a for-loop on a single index.
-static Operation *genFor(CodeGenEnv &env, OpBuilder &builder, bool isOuter,
+static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
                          bool isInner, unsigned idx, size_t tid, size_t dim,
                          ArrayRef<size_t> extraTids,
                          ArrayRef<size_t> extraDims) {
@@ -1222,13 +1080,14 @@ static Operation *genFor(CodeGenEnv &env, OpBuilder &builder, bool isOuter,
 }
 
 /// Emit a while-loop for co-iteration over multiple indices.
-static Operation *genWhile(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
+static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                            bool needsUniv, ArrayRef<size_t> condTids,
                            ArrayRef<size_t> condDims,
                            ArrayRef<size_t> extraTids,
                            ArrayRef<size_t> extraDims) {
   Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
-    // Construct the while-loop with a parameter for each index.
+    // Construct the while-loop with a parameter for each
+    // index.
     return env.loopEmitter->enterCoIterationOverTensorsAtDims(
         builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc,
         extraTids, extraDims);
@@ -1239,7 +1098,7 @@ static Operation *genWhile(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
 
 /// Generates a for-loop or a while-loop, depending on whether it implements
 /// singleton iteration or co-iteration over the given conjunction.
-static Operation *genLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
                           bool needsUniv, ArrayRef<size_t> condTids,
                           ArrayRef<size_t> condDims, ArrayRef<size_t> extraTids,
                           ArrayRef<size_t> extraDims) {
@@ -1257,19 +1116,19 @@ static Operation *genLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
 }
 
 /// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                             bool needsUniv, BitVector &induction,
                             scf::WhileOp whileOp) {
   Location loc = env.linalgOp.getLoc();
   // Finalize each else branch of all if statements.
-  if (env.redVal || env.expValues || env.insChain) {
+  if (env.isReduc() || env.expValues || env.insChain) {
     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
                builder.getInsertionBlock()->getParentOp())) {
       unsigned y = 0;
       SmallVector<Value> yields;
-      if (env.redVal) {
-        yields.push_back(env.redVal);
-        updateReduc(env, ifOp.getResult(y++));
+      if (env.isReduc()) {
+        yields.push_back(env.getReduc());
+        env.updateReduc(ifOp.getResult(y++));
       }
       if (env.expValues) {
         yields.push_back(env.expCount);
@@ -1288,7 +1147,7 @@ static void finalizeWhileOp(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
 }
 
 /// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                        BitVector &conditions) {
   Location loc = env.linalgOp.getLoc();
   SmallVector<Type> types;
@@ -1313,8 +1172,8 @@ static scf::IfOp genIf(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
   }
-  if (env.redVal)
-    types.push_back(env.redVal.getType());
+  if (env.isReduc())
+    types.push_back(env.getReduc().getType());
   if (env.expValues)
     types.push_back(builder.getIndexType());
   if (env.insChain)
@@ -1325,13 +1184,13 @@ static scf::IfOp genIf(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
 }
 
 /// Generates end of true branch of if-statement within a while-loop.
-static void endIf(CodeGenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
+static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                   Operation *loop, Value redInput, Value cntInput,
                   Value insInput) {
   SmallVector<Value> operands;
-  if (env.redVal) {
-    operands.push_back(env.redVal);
-    updateReduc(env, redInput);
+  if (env.isReduc()) {
+    operands.push_back(env.getReduc());
+    env.updateReduc(redInput);
   }
   if (env.expValues) {
     operands.push_back(env.expCount);
@@ -1352,7 +1211,7 @@ static void endIf(CodeGenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                          unsigned at, unsigned idx, unsigned ldx,
                          unsigned lts) {
   assert(!env.getLoopIdxValue(idx));
@@ -1394,7 +1253,7 @@ static bool startLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
   return false;
 }
 
-static void genConstantDenseAddressFromLevel(CodeGenEnv &env,
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, unsigned tid,
                                              unsigned lvl) {
   // TODO: Handle affine expression on output tensor.
@@ -1416,7 +1275,7 @@ static void genConstantDenseAddressFromLevel(CodeGenEnv &env,
   }
 }
 
-static void genInitConstantDenseAddress(CodeGenEnv &env,
+static void genInitConstantDenseAddress(CodegenEnv &env,
                                         RewriterBase &rewriter) {
   // We can generate address for constant affine expression before any loops
   // starting from the first level as they do not depend on anything.
@@ -1427,7 +1286,7 @@ static void genInitConstantDenseAddress(CodeGenEnv &env,
 }
 
 static void translateBitsToTidDimPairs(
-    CodeGenEnv &env, unsigned li, unsigned idx,
+    CodegenEnv &env, unsigned li, unsigned idx,
     SmallVectorImpl<size_t> &condTids, SmallVectorImpl<size_t> &condDims,
     SmallVectorImpl<size_t> &extraTids, SmallVectorImpl<size_t> &extraDims,
     SmallVectorImpl<size_t> &affineTids, SmallVectorImpl<size_t> &affineDims,
@@ -1513,7 +1372,7 @@ static void translateBitsToTidDimPairs(
 }
 
 /// Starts a single loop in current sequence.
-static Operation *startLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
                             unsigned li, bool needsUniv) {
   // The set of tensors + dims to generate loops on
   SmallVector<size_t> condTids, condDims;
@@ -1551,7 +1410,7 @@ static Operation *startLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
 }
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
-static bool endLoop(CodeGenEnv &env, RewriterBase &rewriter, Operation *loop,
+static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                     unsigned idx, unsigned li, bool needsUniv) {
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
@@ -1569,7 +1428,7 @@ static bool endLoop(CodeGenEnv &env, RewriterBase &rewriter, Operation *loop,
 }
 
 /// Ends a loop sequence at given level.
-static void endLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                        unsigned at, unsigned idx, unsigned ldx) {
   assert(env.getLoopIdxValue(idx) == nullptr);
   env.loopEmitter->exitCurrentLoopSeq();
@@ -1582,7 +1441,7 @@ static void endLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
 /// Recursively generates code while computing iteration lattices in order
 /// to manage the complexity of implementing co-iteration over unions
 /// and intersections of sparse iterations spaces.
-static void genStmt(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
+static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
                     unsigned at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.topSort.size()) {
@@ -1612,7 +1471,7 @@ static void genStmt(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
-    Value redInput = env.redVal;
+    Value redInput = env.getReduc();
     Value cntInput = env.expCount;
     Value insInput = env.insChain;
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
@@ -1640,7 +1499,7 @@ static void genStmt(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
-static void genResult(CodeGenEnv &env, RewriterBase &rewriter) {
+static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
   linalg::GenericOp op = env.linalgOp;
   OpOperand *lhs = op.getDpsInitOperand(0);
   Value tensor = lhs->get();
@@ -1683,7 +1542,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     unsigned numTensors = op->getNumOperands();
     unsigned numLoops = op.getNumLoops();
     unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op);
-    CodeGenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+    CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops);
 
     // Detects sparse annotations and translates the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
@@ -1744,7 +1603,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
 private:
   // Last resort cycle resolution.
-  LogicalResult resolveCycle(CodeGenEnv &env, PatternRewriter &rewriter) const {
+  LogicalResult resolveCycle(CodegenEnv &env, PatternRewriter &rewriter) const {
     // Compute topological sort while leaving out every
     // sparse input tensor in succession until an acyclic
     // iteration graph results.
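
Finally, a note on genLoopBoundary(), which several hunks above funnel
through: the live loop-carried values (reduction value, expansion
count, insertion chain) are packed into one vector in a fixed order,
updated in place by a callback that emits the loop boundary, and then
unpacked in the same order. A standalone sketch of that protocol
(illustrative types and names, not code from the patch):

    #include <cassert>
    #include <functional>
    #include <vector>

    // Sketch: pack optional loop-carried values, let a callback rewrite
    // them in place, then unpack in the same fixed order.
    struct Env {
      bool hasRed = true, hasExp = false;
      int redVal = 0, expCount = 0;
    };

    void loopBoundary(Env &env,
                      const std::function<void(std::vector<int> &)> &cb) {
      std::vector<int> carried;
      if (env.hasRed)
        carried.push_back(env.redVal);
      if (env.hasExp)
        carried.push_back(env.expCount);
      cb(carried); // callback updates the values in place
      unsigned i = 0;
      if (env.hasRed)
        env.redVal = carried[i++];
      if (env.hasExp)
        env.expCount = carried[i++];
    }

    int main() {
      Env env;
      env.redVal = 1;
      loopBoundary(env, [](std::vector<int> &c) { c[0] += 41; });
      assert(env.redVal == 42);
    }

The fixed pack/unpack order is what lets the new isReduc()/getReduc()
accessors replace direct redVal reads without changing behavior.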

