[Mlir-commits] [mlir] 6af81ea - [mlir][std] Fold load(tensor_to_memref) into extract_element

Stephan Herhut llvmlistbot at llvm.org
Fri Nov 20 04:42:23 PST 2020


Author: Stephan Herhut
Date: 2020-11-20T13:42:11+01:00
New Revision: 6af81ea1d6d36c7151a61f65e21b5c4ad9cf859d

URL: https://github.com/llvm/llvm-project/commit/6af81ea1d6d36c7151a61f65e21b5c4ad9cf859d
DIFF: https://github.com/llvm/llvm-project/commit/6af81ea1d6d36c7151a61f65e21b5c4ad9cf859d.diff

LOG: [mlir][std] Fold load(tensor_to_memref) into extract_element

This canonicalization is useful to resolve loads into scalar values when
doing partial bufferization.

Differential Revision: https://reviews.llvm.org/D91855

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 8512c933e424..1ad3df63c1c9 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2234,6 +2234,7 @@ def LoadOp : Std_Op<"load",
     operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 6e755daa2669..04efc25a92ee 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2293,6 +2293,30 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
   return OpFoldResult();
 }
 
+namespace {
+/// Canonicalization: a load whose memref comes straight from tensor_to_memref
+/// is rewritten into an extract_element reading the source tensor directly.
+struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
+  using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    // NOTE(review): assumes the buffer is not written between the two ops.
+    if (auto source = loadOp.memref().getDefiningOp<TensorToMemrefOp>()) {
+      rewriter.replaceOpWithNewOp<ExtractElementOp>(loadOp, source.tensor(),
+                                                    loadOp.indices());
+      return success();
+    }
+    return failure(); // The memref is not produced by tensor_to_memref.
+  }
+};
+} // end anonymous namespace
+
+void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<LoadOfTensorToMemref>(context); // load(tensor_to_memref) fold
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 51475371244b..ebc59c8dbeac 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -45,6 +45,20 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
   return %1 : index
 }
 
+// Test case: Folding of load(tensor_to_memref(%v), %idxs)
+//            -> extract_element(%v, %idxs)
+// CHECK-LABEL: func @load_from_tensor_to_memref(
+//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+//  CHECK-SAME:     %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+//       CHECK:   %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+//   CHECK-NOT:   load
+//       CHECK:   return %[[RES]] : f32
+func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = tensor_to_memref %arg2 : memref<?x?xf32>
+  %1 = load %0[%arg0, %arg1] : memref<?x?xf32>  // Expected: extract_element on %arg2.
+  return %1 : f32
+}
+
 // Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
 // CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
 //  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index


        


More information about the Mlir-commits mailing list