[Mlir-commits] [mlir] 8544523 - [mlir][tensor] Promote extract(from_elements(...)) to folding pattern

Matthias Springer llvmlistbot at llvm.org
Wed Apr 20 07:48:12 PDT 2022


Author: Matthias Springer
Date: 2022-04-20T23:47:42+09:00
New Revision: 8544523dcb6249bf3055c3a6ab0cb48586999a30

URL: https://github.com/llvm/llvm-project/commit/8544523dcb6249bf3055c3a6ab0cb48586999a30
DIFF: https://github.com/llvm/llvm-project/commit/8544523dcb6249bf3055c3a6ab0cb48586999a30.diff

LOG: [mlir][tensor] Promote extract(from_elements(...)) to folding pattern

Differential Revision: https://reviews.llvm.org/D123617

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f1181e63ec8f2..1f9f97769646b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -361,16 +361,13 @@ LogicalResult ExtractOp::verify() {
 }
 
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
-  // The tensor operand must be a known constant.
-  Attribute tensor = operands.front();
-  if (!tensor)
-    return {};
   // If this is a splat elements attribute, simply return the value. All of the
   // elements of a splat attribute are the same.
-  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
-    return splatTensor.getSplatValue<Attribute>();
+  if (Attribute tensor = operands.front())
+    if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
+      return splatTensor.getSplatValue<Attribute>();
 
-  // Otherwise, collect the constant indices into the tensor.
+  // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
   for (Attribute indice : llvm::drop_begin(operands, 1)) {
     if (!indice || !indice.isa<IntegerAttr>())
@@ -378,10 +375,34 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
     indices.push_back(indice.cast<IntegerAttr>().getInt());
   }
 
+  // Fold extract(from_elements(...)).
+  if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
+    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+    auto rank = tensorType.getRank();
+    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
+           "rank mismatch");
+    int flatIndex = 0;
+    int stride = 1;
+    for (int i = rank - 1; i >= 0; --i) {
+      if (i < rank - 1)
+        stride *= tensorType.getDimSize(i);
+      flatIndex += indices[i] * stride;
+    }
+    // Prevent out of bounds accesses. This can happen in invalid code that will
+    // never execute.
+    if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
+        flatIndex < 0)
+      return {};
+    return fromElementsOp.elements()[flatIndex];
+  }
+
   // If this is an elements attribute, query the value at the given indices.
-  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
-  if (elementsAttr && elementsAttr.isValidIndex(indices))
-    return elementsAttr.getValues<Attribute>()[indices];
+  if (Attribute tensor = operands.front()) {
+    auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
+    if (elementsAttr && elementsAttr.isValidIndex(indices))
+      return elementsAttr.getValues<Attribute>()[indices];
+  }
+
   return {};
 }
 
@@ -411,47 +432,6 @@ OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
 
 namespace {
 
-// Canonicalizes the pattern of the form
-//
-// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
-//
-// to just %element.
-struct ExtractElementFromTensorFromElements
-    : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
-                                PatternRewriter &rewriter) const final {
-    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
-    if (!tensorFromElements)
-      return failure();
-    auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
-    auto rank = tensorType.getRank();
-    if (rank == 0) {
-      rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
-      return success();
-    }
-    SmallVector<APInt, 3> indices(rank);
-    int64_t flatIndex = 0;
-    int64_t stride = 1;
-    for (int i = rank - 1; i >= 0; --i) {
-      APInt index;
-      if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
-        return failure();
-      if (i < rank - 1)
-        stride *= tensorType.getDimSize(i);
-      flatIndex += index.getSExtValue() * stride;
-    }
-    // Prevent out of bounds accesses. This can happen in invalid code that will
-    // never execute.
-    if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
-      return failure();
-    rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
-    return success();
-  }
-};
-
 // Pushes the index_casts that occur before extractions to after the extract.
 // This minimizes type conversion in some cases and enables the extract
 // canonicalizer. This changes:
@@ -494,9 +474,7 @@ struct ExtractElementFromIndexCast
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results
-      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
-          context);
+  results.add<ExtractElementFromIndexCast>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index ced1ca525546c..a3bd3d0d26796 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -22,12 +22,8 @@ func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
 //   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //   CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
 //   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
-//   CHECK-DAG:   %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
-//   CHECK-DAG:   %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
-//   CHECK-DAG:   %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
-//   CHECK-DAG:   %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
-//       CHECK:   return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
+//       CHECK:   return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
 
 // -----
 


        


More information about the Mlir-commits mailing list