[Mlir-commits] [mlir] 7c984be - [mlir] Propagate arith.index_cast past tensor.extract

Rob Suderman llvmlistbot at llvm.org
Tue Jan 25 22:19:14 PST 2022


Author: Rob Suderman
Date: 2022-01-25T22:16:07-08:00
New Revision: 7c984be21a350d2fd227fd95f36a70165a523b99

URL: https://github.com/llvm/llvm-project/commit/7c984be21a350d2fd227fd95f36a70165a523b99
DIFF: https://github.com/llvm/llvm-project/commit/7c984be21a350d2fd227fd95f36a70165a523b99.diff

LOG: [mlir] Propagate arith.index_cast past tensor.extract

If we are extracting, it is more useful to push the index_cast past the
extraction. This increases the chance the tensor.extract can be evaluated at
compile time.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D118204

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5ae13a613c427..bba7f9770f21a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -425,11 +425,51 @@ struct ExtractElementFromTensorFromElements
   }
 };
 
+// Pushes the index_casts that occur before extractions to after the extract.
+// This minimizes type conversion in some cases and enables the extract
+// canonicalizer. This changes:
+//
+// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
+// %extract = tensor.extract %cast[%index] : tensor<1xindex>
+//
+// to the following:
+//
+// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
+// %cast = arith.index_cast %extract : i32 to index
+//
+// i.e. the cast is applied to the extracted element instead of the tensor.
+//
+// Consider expanding this to a template and handle all tensor cast operations.
+struct ExtractElementFromIndexCast
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    Location loc = extract.getLoc();
+    auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
+    if (!indexCast)
+      return failure();
+
+    Type elementTy = getElementTypeOrSelf(indexCast.getIn());
+
+    auto newExtract = rewriter.create<tensor::ExtractOp>(
+        loc, elementTy, indexCast.getIn(), extract.indices());
+
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
+                                                    newExtract);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<ExtractElementFromTensorFromElements>(context);
+  results
+      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e0ea5d777acb8..3084a262af7dd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1200,3 +1200,17 @@ func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
   %1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
   return %1 : tensor<1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @propogate_index_cast
+func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
+  // CHECK: %[[IDX:.+]] = arith.constant 0
+  // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
+  // CHECK: return %[[CAST]] : index
+  %c0 = arith.constant 0 : index
+  %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
+  %1 = tensor.extract %0[%c0] : tensor<1xindex>
+  return %1 : index
+}


        


More information about the Mlir-commits mailing list