[Mlir-commits] [mlir] 7c984be - [mlir] Propagate arith.index_cast past tensor.extract
Rob Suderman
llvmlistbot at llvm.org
Tue Jan 25 22:19:14 PST 2022
Author: Rob Suderman
Date: 2022-01-25T22:16:07-08:00
New Revision: 7c984be21a350d2fd227fd95f36a70165a523b99
URL: https://github.com/llvm/llvm-project/commit/7c984be21a350d2fd227fd95f36a70165a523b99
DIFF: https://github.com/llvm/llvm-project/commit/7c984be21a350d2fd227fd95f36a70165a523b99.diff
LOG: [mlir] Propagate arith.index_cast past tensor.extract
If we are extracting it is more useful to push the index_cast past the
extraction. This increases the chance the tensor.extract can be evaluated at
compile time.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D118204
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5ae13a613c427..bba7f9770f21a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -425,11 +425,51 @@ struct ExtractElementFromTensorFromElements
}
};
+// Pushes the index_casts that occur before extractions to after the extract.
+// This minimizes type conversion in some cases and enables the extract
+// canonicalizer. This changes:
+//
+// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
+// %extract = tensor.extract %cast[%index] : tensor<1xindex>
+//
+// to the following:
+//
+// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
+// %cast = arith.index_cast %extract : i32 to index
+//
+// so the cast is applied to the extracted scalar rather than the whole tensor.
+//
+// Consider expanding this to a template and handle all tensor cast operations.
+struct ExtractElementFromIndexCast
+ : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+ PatternRewriter &rewriter) const final {
+ Location loc = extract.getLoc();
+ auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>(); // Producer of the extracted tensor, if it is an arith.index_cast.
+ if (!indexCast)
+ return failure(); // Tensor operand is not produced by an index_cast: nothing to rewrite.
+
+ Type elementTy = getElementTypeOrSelf(indexCast.getIn()); // Element type of the cast's input; becomes the new extract's result type.
+
+ auto newExtract = rewriter.create<tensor::ExtractOp>(
+ loc, elementTy, indexCast.getIn(), extract.indices()); // Extract from the un-cast tensor, reusing the original indices.
+
+ rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
+ newExtract); // Cast the extracted scalar to the original extract's result type.
+
+ return success();
+ }
+};
+
} // namespace
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractElementFromTensorFromElements>(context);
+ results
+ .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e0ea5d777acb8..3084a262af7dd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1200,3 +1200,17 @@ func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
%1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
return %1 : tensor<1xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @propogate_index_cast
+func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
+ // CHECK: %[[IDX:.+]] = arith.constant 0
+ // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
+ // CHECK: return %[[CAST]] : index
+ %c0 = arith.constant 0 : index
+ %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
+ %1 = tensor.extract %0[%c0] : tensor<1xindex>
+ return %1 : index
+}
More information about the Mlir-commits
mailing list