[Mlir-commits] [mlir] ce90957 - [mlir][spirv] Fold noop `BitcastOp`s

Jakub Kuderski llvmlistbot at llvm.org
Fri Nov 4 14:38:11 PDT 2022


Author: Jakub Kuderski
Date: 2022-11-04T17:37:30-04:00
New Revision: ce90957461d5d5e4290a61267b4726d3842483d7

URL: https://github.com/llvm/llvm-project/commit/ce90957461d5d5e4290a61267b4726d3842483d7
DIFF: https://github.com/llvm/llvm-project/commit/ce90957461d5d5e4290a61267b4726d3842483d7.diff

LOG: [mlir][spirv] Fold noop `BitcastOp`s

This allows chains of bitcasts that round-trip back to the original type (e.g. `i64 -> f64 -> i64`) to fold away to the original value.

Fixes: https://github.com/llvm/llvm-project/issues/58801

Reviewed By: antiagainst, Hardcode84, mravishankar

Differential Revision: https://reviews.llvm.org/D137459

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index c985c6e94e19e..8975fa01df403 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -88,7 +88,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)
   }];
-  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index b3444d8b210a6..57e6475548642 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -116,9 +116,34 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
 
-void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                   MLIRContext *context) {
-  results.add<ConvertChainedBitcast>(context);
+/// Folds away `spirv.Bitcast` ops that do not change anything:
+///   * a cast to the operand's own type folds to the operand;
+///   * `bitcast(bitcast(x))` folds to `x` when the outer result type equals
+///     the type of `x`;
+///   * otherwise, a cast of a cast is rewired in place to read `x` directly.
+/// Following the MLIR fold contract, the op's own result is returned only for
+/// the in-place update; when nothing applies, a null OpFoldResult is returned
+/// so that the fold is not reported as a change on every invocation.
+OpFoldResult spirv::BitcastOp::fold(ArrayRef<Attribute> /*operands*/) {
+  Value arg = getOperand();
+  if (getType() == arg.getType())
+    return arg;
+
+  // Look through nested bitcasts.
+  if (auto bitcast = arg.getDefiningOp<spirv::BitcastOp>()) {
+    Value nestedArg = bitcast.getOperand();
+    if (nestedArg.getType() == getType())
+      return nestedArg;
+
+    // Bypass the intermediate cast; returning our own result signals an
+    // in-place update to the caller.
+    getOperandMutable().assign(nestedArg);
+    return getResult();
+  }
+
+  // TODO(kuhar): Consider constant-folding the operand attribute.
+  // No fold applies: return a null result rather than getResult().
+  return {};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
index 12c41fcaf0f00..e8d2274d29aa0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
@@ -13,13 +13,6 @@
 include "mlir/IR/PatternBase.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVOps.td"
 
-//===----------------------------------------------------------------------===//
-// spirv.Bitcast
-//===----------------------------------------------------------------------===//
-
-def ConvertChainedBitcast : Pat<(SPIRV_BitcastOp (SPIRV_BitcastOp $operand)),
-                                (SPIRV_BitcastOp $operand)>;
-
 //===----------------------------------------------------------------------===//
 // spirv.LogicalNot
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index b13d6443850c9..e65f92e66bb47 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -86,6 +86,32 @@ func.func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spirv.ptr<i
 
 // -----
 
+// A bitcast followed by its inverse folds back to the original value.
+// CHECK-LABEL: @convert_bitcast_roundtrip
+// CHECK-SAME:    %[[ARG:.+]]: i64
+func.func @convert_bitcast_roundtrip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to i64
+  spirv.ReturnValue %1 : i64
+}
+
+// -----
+
+// A longer chain of bitcasts that ends at the original type folds away too.
+// CHECK-LABEL: @convert_bitcast_chained_roundtrip
+// CHECK-SAME:    %[[ARG:.+]]: i64
+func.func @convert_bitcast_chained_roundtrip(%arg0 : i64) -> i64 {
+  // CHECK: spirv.ReturnValue %[[ARG]]
+  %0 = spirv.Bitcast %arg0 : i64 to f64
+  %1 = spirv.Bitcast %0 : f64 to vector<2xi32>
+  %2 = spirv.Bitcast %1 : vector<2xi32> to vector<2xf32>
+  %3 = spirv.Bitcast %2 : vector<2xf32> to i64
+  spirv.ReturnValue %3 : i64
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.CompositeExtract
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list