[Mlir-commits] [mlir] 217b5c5 - [mlir][spirv] Add some folders for spv.CompositeExtract

Lei Zhang llvmlistbot at llvm.org
Fri Sep 2 14:21:09 PDT 2022


Author: Lei Zhang
Date: 2022-09-02T17:20:58-04:00
New Revision: 217b5c50b9037da13ee83a76c68f0f20a86eadb0

URL: https://github.com/llvm/llvm-project/commit/217b5c50b9037da13ee83a76c68f0f20a86eadb0
DIFF: https://github.com/llvm/llvm-project/commit/217b5c50b9037da13ee83a76c68f0f20a86eadb0.diff

LOG: [mlir][spirv] Add some folders for spv.CompositeExtract

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D133167

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4396c96766880..8db35960dc668 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -126,7 +126,21 @@ void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
+  if (auto insertOp = composite().getDefiningOp<spirv::CompositeInsertOp>()) {
+    if (indices() == insertOp.indices())
+      return insertOp.object();
+  }
+
+  if (auto constructOp =
+          composite().getDefiningOp<spirv::CompositeConstructOp>()) {
+    auto type = constructOp.getType().cast<spirv::CompositeType>();
+    if (indices().size() == 1 &&
+        constructOp.constituents().size() == type.getNumElements()) {
+      auto i = indices().begin()->cast<IntegerAttr>();
+      return constructOp.constituents()[i.getValue().getSExtValue()];
+    }
+  }
+
   auto indexVector =
       llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
         return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 6d1c0dceed59b..41fa122f99124 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -137,6 +137,47 @@ func.func @extract_from_not_constant() -> i32 {
 
 // -----
 
+// CHECK-LABEL: extract_insert
+//  CHECK-SAME: (%[[COMP:.+]]: !spv.array<1 x vector<2xf32>>, %[[VAL:.+]]: f32)
+func.func @extract_insert(%composite: !spv.array<1xvector<2xf32>>, %val: f32) -> (f32, f32) {
+  // CHECK: %[[INSERT:.+]] = spv.CompositeInsert %[[VAL]], %[[COMP]]
+  %insert = spv.CompositeInsert %val, %composite[0 : i32, 1 : i32] : f32 into !spv.array<1xvector<2xf32>>
+  %1 = spv.CompositeExtract %insert[0 : i32, 0 : i32] : !spv.array<1xvector<2xf32>>  // Indices differ from the insert's, so this extract is kept.
+  // CHECK: %[[S:.+]] = spv.CompositeExtract %[[INSERT]][0 : i32, 0 : i32]
+  %2 = spv.CompositeExtract %insert[0 : i32, 1 : i32] : !spv.array<1xvector<2xf32>>  // Same indices as the insert, so this folds to %val.
+  // CHECK: return %[[S]], %[[VAL]]
+  return %1, %2 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_construct
+//  CHECK-SAME: (%[[VAL1:.+]]: vector<2xf32>, %[[VAL2:.+]]: vector<2xf32>)
+func.func @extract_construct(%val1: vector<2xf32>, %val2: vector<2xf32>) -> (vector<2xf32>, vector<2xf32>) {
+  %construct = spv.CompositeConstruct %val1, %val2 : (vector<2xf32>, vector<2xf32>) -> !spv.array<2xvector<2xf32>>
+  %1 = spv.CompositeExtract %construct[0 : i32] : !spv.array<2xvector<2xf32>>  // One constituent per element: folds to %val1.
+  %2 = spv.CompositeExtract %construct[1 : i32] : !spv.array<2xvector<2xf32>>  // One constituent per element: folds to %val2.
+  // CHECK: return %[[VAL1]], %[[VAL2]]
+  return %1, %2 : vector<2xf32>, vector<2xf32>
+}
+
+// -----
+
+// Not yet implemented case: the number of constituents differs from the result's element count, so extracts are not folded.
+
+// CHECK-LABEL: extract_construct
+func.func @extract_construct(%val1: vector<3xf32>, %val2: f32) -> (f32, f32) {
+  // CHECK: spv.CompositeConstruct
+  %construct = spv.CompositeConstruct %val1, %val2 : (vector<3xf32>, f32) -> vector<4xf32>  // 2 constituents for 4 result elements, so the extracts below stay.
+  // CHECK: spv.CompositeExtract
+  %1 = spv.CompositeExtract %construct[0 : i32] : vector<4xf32>
+  // CHECK: spv.CompositeExtract
+  %2 = spv.CompositeExtract %construct[1 : i32] : vector<4xf32>
+  return %1, %2 : f32, f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.Constant
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list