[Mlir-commits] [mlir] 6af81ea - [mlir][std] Fold load(tensor_to_memref) into extract_element
Stephan Herhut
llvmlistbot at llvm.org
Fri Nov 20 04:42:23 PST 2020
Author: Stephan Herhut
Date: 2020-11-20T13:42:11+01:00
New Revision: 6af81ea1d6d36c7151a61f65e21b5c4ad9cf859d
URL: https://github.com/llvm/llvm-project/commit/6af81ea1d6d36c7151a61f65e21b5c4ad9cf859d
DIFF: https://github.com/llvm/llvm-project/commit/6af81ea1d6d36c7151a61f65e21b5c4ad9cf859d.diff
LOG: [mlir][std] Fold load(tensor_to_memref) into extract_element
This canonicalization is useful to resolve loads into scalar values when
doing partial bufferization.
Differential Revision: https://reviews.llvm.org/D91855
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 8512c933e424..1ad3df63c1c9 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2234,6 +2234,7 @@ def LoadOp : Std_Op<"load",
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 6e755daa2669..04efc25a92ee 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2293,6 +2293,30 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
return OpFoldResult();
}
+namespace {
+/// Rewrites a `load` whose memref operand is produced directly by a
+/// `tensor_to_memref` op into an `extract_element` that reads the same
+/// indices from the original tensor, e.g.
+///
+///   %m = tensor_to_memref %t : memref<?xf32>
+///   %v = load %m[%i] : memref<?xf32>
+///
+/// becomes
+///
+///   %v = extract_element %t[%i] : tensor<?xf32>
+///
+/// NOTE(review): this assumes the memref produced by tensor_to_memref is not
+/// written between the cast and the load; confirm against the op's contract.
+struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
+  using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadOp load,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when the loaded buffer comes straight from a tensor.
+    if (auto castOp = load.memref().getDefiningOp<TensorToMemrefOp>()) {
+      // Same indices, but read from the tensor operand of the cast.
+      rewriter.replaceOpWithNewOp<ExtractElementOp>(load, castOp.tensor(),
+                                                    load.indices());
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns for `load`: currently only the
+/// load(tensor_to_memref) -> extract_element rewrite defined above.
+void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<LoadOfTensorToMemref>(context);
+}
+
//===----------------------------------------------------------------------===//
// MemRefCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 51475371244b..ebc59c8dbeac 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -45,6 +45,20 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
return %1 : index
}
+// Test case: Folding of load(tensor_to_memref(%v), %idxs)
+//              -> extract_element(%v, %idxs)
+// CHECK-LABEL: func @load_from_tensor_to_memref(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+// The two CHECK-NOTs together cover the whole body: no load may survive
+// either before or after the extract_element that replaces it.
+// CHECK-NOT: load
+// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+// CHECK-NOT: load
+// CHECK: return %[[RES]] : f32
+func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = tensor_to_memref %arg2 : memref<?x?xf32>
+  %1 = load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+
// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
More information about the Mlir-commits
mailing list