[Mlir-commits] [mlir] 79f6f92 - [mlir][spirv] Enhance folding capability of spirv::CompositeExtractOp::fold

Jakub Kuderski llvmlistbot at llvm.org
Fri May 26 16:23:19 PDT 2023


Author: Nishant Patel
Date: 2023-05-26T19:23:02-04:00
New Revision: 79f6f92e10b7cebe2d73f3bbe493ee1392779e26

URL: https://github.com/llvm/llvm-project/commit/79f6f92e10b7cebe2d73f3bbe493ee1392779e26
DIFF: https://github.com/llvm/llvm-project/commit/79f6f92e10b7cebe2d73f3bbe493ee1392779e26.diff

LOG: [mlir][spirv] Enhance folding capability of spirv::CompositeExtractOp::fold

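This patch improves `spirv::CompositeExtractOp::fold` by walking back through a chain of `CompositeInsertOp`s.
Inserts whose indices differ from the extracted indices are skipped until one with matching indices is found; if the chain ends at a `CompositeConstructOp`, the corresponding constituent is returned instead.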

Patch By: nbpatel
Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D151536

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index cb4ae4efe6062..9219e31f11692 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -141,19 +141,24 @@ OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
-  if (auto insertOp =
-          getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
+  Value compositeOp = getComposite();
+
+  while (auto insertOp =
+             compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
     if (getIndices() == insertOp.getIndices())
       return insertOp.getObject();
+    compositeOp = insertOp.getComposite();
   }
 
   if (auto constructOp =
-          getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
+          compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
     auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
     if (getIndices().size() == 1 &&
         constructOp.getConstituents().size() == type.getNumElements()) {
-      auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
-      return constructOp.getConstituents()[i.getValue().getSExtValue()];
+      auto i = getIndices().begin()->cast<IntegerAttr>();
+      if (static_cast<size_t>(i.getValue().getSExtValue()) <
+          constructOp.getConstituents().size())
+        return constructOp.getConstituents()[i.getValue().getSExtValue()];
     }
   }
 

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index f543ed44f649d..d5dc97a6245b1 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -185,6 +185,31 @@ func.func @extract_construct(%val1: vector<2xf32>, %val2: vector<2xf32>) -> (vec
   return %1, %2 : vector<2xf32>, vector<2xf32>
 }
 
+// -----
+// Extract skips a CompositeInsert at another index and folds to the insert that wrote [0].
+// CHECK-LABEL: @fold_extract_from_insert_chain
+//  CHECK-SAME: (%[[COMP:.+]]: !spirv.struct<(f32, f32)>, %[[VAL1:.+]]: f32, %[[VAL2:.+]]: f32)
+func.func @fold_extract_from_insert_chain(%composite: !spirv.struct<(f32, f32)>, %val1: f32, %val2: f32) -> f32 {
+  %insert = spirv.CompositeInsert %val1, %composite[0 : i32] : f32 into !spirv.struct<(f32, f32)>
+  %1 = spirv.CompositeInsert %val2, %insert[1 : i32] : f32 into !spirv.struct<(f32, f32)>
+  %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(f32, f32)>
+  // CHECK-NEXT: return %[[VAL1]]
+  return %2 : f32
+}
+
+// -----
+// Extract walks past non-matching inserts to the CompositeConstruct; distinct constituents catch a wrong index.
+// CHECK-LABEL: @fold_extract_from_construct_under_inserts
+//  CHECK-SAME: (%[[VAL1:.+]]: f32, %[[VAL2:.+]]: f32, %[[VAL3:.+]]: f32)
+func.func @fold_extract_from_construct_under_inserts(%val1: f32, %val2: f32, %val3: f32) -> f32 {
+  %composite = spirv.CompositeConstruct %val1, %val2, %val3 : (f32, f32, f32) -> !spirv.struct<(f32, f32, f32)>
+  %insert = spirv.CompositeInsert %val3, %composite[1 : i32] : f32 into !spirv.struct<(f32, f32, f32)>
+  %1 = spirv.CompositeInsert %val2, %insert[2 : i32] : f32 into !spirv.struct<(f32, f32, f32)>
+  %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(f32, f32, f32)>
+  // CHECK-NEXT: return %[[VAL1]]
+  return %2 : f32
+}
+
 // -----
 
 // Not yet implemented case


        


More information about the Mlir-commits mailing list