[Mlir-commits] [mlir] be4e9db - [mlir][Linalg] NFC: Clean up for 0-D abstraction.

Hanhan Wang llvmlistbot at llvm.org
Fri Mar 20 13:08:48 PDT 2020


Author: Hanhan Wang
Date: 2020-03-20T13:07:14-07:00
New Revision: be4e9db5799a4d1c2350859b4582bbf25e39fff9

URL: https://github.com/llvm/llvm-project/commit/be4e9db5799a4d1c2350859b4582bbf25e39fff9
DIFF: https://github.com/llvm/llvm-project/commit/be4e9db5799a4d1c2350859b4582bbf25e39fff9.diff

LOG: [mlir][Linalg] NFC: Clean up for 0-D abstraction.

Summary:
After D75831 landed, both the generic op and the indexed_generic op can
handle the 0-D edge case. The previous patch only updated the generic op.
This patch updates the lowering to loops for the indexed_generic op. Since the
two lowerings are almost the same, the patch also refactors the common part.

Differential Revision: https://reviews.llvm.org/D76413

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 817f0235ccff..316a5a75617a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -33,6 +33,7 @@ using namespace mlir::linalg;
 
 using edsc::op::operator+;
 using edsc::op::operator==;
+using mlir::edsc::intrinsics::detail::ValueHandleArray;
 
 static SmallVector<ValueHandle, 8>
 makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
@@ -81,6 +82,46 @@ SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   return res;
 }
 
+/// Inlines the entry block of `op`'s region at the insertion point of the
+/// ScopedContext builder, then stores every value yielded by the block's
+/// terminator into the matching output buffer.
+///
+/// The values in `indexedValues` are bound, in order, to the block arguments
+/// before the operations are cloned. Operand `i` of the terminator, which
+/// must be a `YieldOp`, is written with `std_store` to `outputBuffers[i]` at
+/// the indices `indexing[i]`. Operations in the block must not carry nested
+/// regions.
+template <typename OpType>
+static void inlineRegionAndEmitStdStore(OpType op,
+                                        ArrayRef<Value> indexedValues,
+                                        ArrayRef<ValueHandleArray> indexing,
+                                        ArrayRef<Value> outputBuffers) {
+  auto &b = ScopedContext::getBuilder();
+  auto &block = op.region().front();
+  BlockAndValueMapping map;
+  map.map(block.getArguments(), indexedValues);
+  // Clone every operation except the terminator. Recording the result mapping
+  // lets later clones, and the stores below, refer to the cloned values.
+  for (auto &bodyOp : block.without_terminator()) {
+    assert(bodyOp.getNumRegions() == 0 && "expected a non-nested region");
+    auto *newOp = b.clone(bodyOp, map);
+    map.map(bodyOp.getResults(), newOp->getResults());
+  }
+
+  Operation &terminator = block.back();
+  assert(isa<YieldOp>(terminator) &&
+         "expected a yield op at the end of the region");
+  // Every yielded value needs exactly one output buffer and one indexing.
+  assert(terminator.getNumOperands() == outputBuffers.size() &&
+         "expected one output buffer per yielded value");
+  assert(indexing.size() == outputBuffers.size() &&
+         "expected one indexing per output buffer");
+  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
+    std_store(map.lookup(terminator.getOperand(i)), outputBuffers[i],
+              indexing[i]);
+  }
+}
+
 namespace {
 template <typename IndexedValueType, typename LinalgOpType>
 class LinalgScopedEmitter {};
@@ -300,6 +325,8 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
     }
 
     // 1.b. Emit std_load from output views.
+    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
+    // region has no uses.
     for (unsigned i = 0; i < nOutputs; ++i) {
       Value output = genericOp.getOutputBuffer(i);
       ValueHandleArray indexing(makeCanonicalAffineApplies(
@@ -324,24 +351,16 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
     }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
-    BlockAndValueMapping map;
-    auto &block = genericOp.region().front();
-    map.map(block.getArguments(), indexedValues);
-    for (auto &op : block.without_terminator()) {
-      assert(op.getNumRegions() == 0);
-      auto *newOp = b.clone(op, map);
-      map.map(op.getResults(), newOp->getResults());
-    }
-
     // 3. Emit std_store.
-    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
-    assert(yieldOp->getNumOperands() == nOutputs);
+    SmallVector<ValueHandleArray, 8> indexing;
+    SmallVector<Value, 8> outputBuffers;
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
+      indexing.emplace_back(makeCanonicalAffineApplies(
           b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                genericOp.getOutputBuffer(i), indexing);
+      outputBuffers.push_back(genericOp.getOutputBuffer(i));
     }
+    inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing,
+                                outputBuffers);
   }
 };
 
@@ -397,25 +416,17 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = indexedGenericOp.getInput(i);
-      if (input.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
-        indexedValues[nLoops + i] = std_load(input, indexing);
-      } else {
-        indexedValues[nLoops + i] = std_load(input);
-      }
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
+      indexedValues[nLoops + i] = std_load(input, indexing);
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
       Value output = indexedGenericOp.getOutputBuffer(i);
-      if (output.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
-      } else {
-        indexedValues[nLoops + nInputs + i] = std_load(output);
-      }
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+      indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
     }
 
     if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -426,40 +437,24 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
       // 3. Emit std_store.
       for (unsigned i = 0; i < nOutputs; ++i) {
         Value output = indexedGenericOp.getOutputBuffer(i);
-        if (output.getType().cast<ShapedType>().getRank()) {
-          ValueHandleArray indexing(makeCanonicalAffineApplies(
-              b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-          std_store(callOp->getResult(i), output, indexing);
-        } else {
-          std_store(callOp->getResult(i), output);
-        }
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        std_store(callOp->getResult(i), output, indexing);
       }
       return;
     }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
-    BlockAndValueMapping map;
-    auto &block = indexedGenericOp.region().front();
-    map.map(block.getArguments(), indexedValues);
-    for (auto &op : block.without_terminator()) {
-      assert(op.getNumRegions() == 0);
-      auto *newOp = b.clone(op, map);
-      map.map(op.getResults(), newOp->getResults());
-    }
-
     // 3. Emit std_store.
-    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
-    assert(yieldOp->getNumOperands() == nOutputs);
+    SmallVector<ValueHandleArray, 8> indexing;
+    SmallVector<Value, 8> outputBuffers;
     for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = indexedGenericOp.getOutputBuffer(i);
-      if (output.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
-      } else {
-        std_store(map.lookup(yieldOp->getOperand(i)), output);
-      }
+      indexing.emplace_back(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+      outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
     }
+    inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing,
+                                outputBuffers);
   }
 };
 


        


More information about the Mlir-commits mailing list