[Mlir-commits] [mlir] ce90957 - [mlir][spirv] Fold noop `BitcastsOp`s
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 4 14:38:11 PDT 2022
Author: Jakub Kuderski
Date: 2022-11-04T17:37:30-04:00
New Revision: ce90957461d5d5e4290a61267b4726d3842483d7
URL: https://github.com/llvm/llvm-project/commit/ce90957461d5d5e4290a61267b4726d3842483d7
DIFF: https://github.com/llvm/llvm-project/commit/ce90957461d5d5e4290a61267b4726d3842483d7.diff
LOG: [mlir][spirv] Fold noop `BitcastsOp`s
This allows for bitcast conversion to roundtrip.
Fixes: https://github.com/llvm/llvm-project/issues/58801
Reviewed By: antiagainst, Hardcode84, mravishankar
Differential Revision: https://reviews.llvm.org/D137459
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index c985c6e94e19e..8975fa01df403 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -88,7 +88,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
- let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index b3444d8b210a6..57e6475548642 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -116,9 +116,34 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
// spirv.BitcastOp
//===----------------------------------------------------------------------===//
-void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<ConvertChainedBitcast>(context);
+/// Folds `spirv.Bitcast` ops:
+///   * a bitcast whose result type equals its operand type is a no-op and
+///     folds to the operand;
+///   * `bitcast(bitcast(x))` folds to `x` when this op's result type equals
+///     the type of `x`;
+///   * otherwise `bitcast(bitcast(x))` is rewritten in place to `bitcast(x)`.
+/// Each invocation removes at most one link of a bitcast chain.
+OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  Value arg = getOperand();
+  if (getType() == arg.getType())
+    return arg;
+
+  // Look through nested bitcasts.
+  if (auto bitcast = arg.getDefiningOp<spirv::BitcastOp>()) {
+    Value nestedArg = bitcast.getOperand();
+    if (nestedArg.getType() == getType())
+      return nestedArg;
+
+    // Bypass the intermediate cast. Returning this op's own result signals
+    // that the op was updated in place.
+    getOperandMutable().assign(nestedArg);
+    return getResult();
+  }
+
+  // TODO(kuhar): Consider constant-folding the operand attribute.
+  // Nothing was changed: return a null result. Returning getResult() here
+  // would falsely report a successful in-place fold.
+  return {};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
index 12c41fcaf0f00..e8d2274d29aa0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
@@ -13,13 +13,6 @@
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/SPIRV/IR/SPIRVOps.td"
-//===----------------------------------------------------------------------===//
-// spirv.Bitcast
-//===----------------------------------------------------------------------===//
-
-def ConvertChainedBitcast : Pat<(SPIRV_BitcastOp (SPIRV_BitcastOp $operand)),
- (SPIRV_BitcastOp $operand)>;
-
//===----------------------------------------------------------------------===//
// spirv.LogicalNot
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index b13d6443850c9..e65f92e66bb47 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -86,6 +86,30 @@ func.func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spirv.ptr<i
// -----
+// CHECK-LABEL: @convert_bitcast_roundtrip
+// CHECK-SAME: %[[ARG:.+]]: i64
+func.func @convert_bitcast_roundtrip(%arg0 : i64) -> i64 {
+ // CHECK: spirv.ReturnValue %[[ARG]]
+ %0 = spirv.Bitcast %arg0 : i64 to f64
+ %1 = spirv.Bitcast %0 : f64 to i64
+ spirv.ReturnValue %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bitcast_chained_roundtrip
+// CHECK-SAME: %[[ARG:.+]]: i64
+func.func @convert_bitcast_chained_roundtrip(%arg0 : i64) -> i64 {
+ // CHECK: spirv.ReturnValue %[[ARG]]
+ %0 = spirv.Bitcast %arg0 : i64 to f64
+ %1 = spirv.Bitcast %0 : f64 to vector<2xi32>
+ %2 = spirv.Bitcast %1 : vector<2xi32> to vector<2xf32>
+ %3 = spirv.Bitcast %2 : vector<2xf32> to i64
+ spirv.ReturnValue %3 : i64
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.CompositeExtract
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list