[Mlir-commits] [mlir] 7db18ef - [mlir][SPIR-V] Fold IAddCarry/[SU]MulExtended constants into spirv.Constant struct (#198633)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 20 02:08:44 PDT 2026


Author: Arseniy Obolenskiy
Date: 2026-05-20T11:08:37+02:00
New Revision: 7db18ef27472ff7d26f7029f7bf87511c63fda88

URL: https://github.com/llvm/llvm-project/commit/7db18ef27472ff7d26f7029f7bf87511c63fda88
DIFF: https://github.com/llvm/llvm-project/commit/7db18ef27472ff7d26f7029f7bf87511c63fda88.diff

LOG: [mlir][SPIR-V] Fold IAddCarry/[SU]MulExtended constants into spirv.Constant struct (#198633)

Emit a single struct-typed spirv.Constant instead of the previous workaround
of a spirv.Undef followed by two spirv.CompositeInsert ops.

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 2bae0d23c9f00..e4b48ed680d52 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -125,17 +125,13 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
 // spirv.IAddCarry
 //===----------------------------------------------------------------------===//
 
-// We are required to use CompositeConstructOp to create a constant struct as
-// they are not yet implemented as constant, hence we can not do so in a fold.
 struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
   using Base::Base;
 
   LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
     Value lhs = op.getOperand1();
     Value rhs = op.getOperand2();
-    Type constituentType = lhs.getType();
 
     // iaddcarry (x, 0) = <0, x>
     if (matchPattern(rhs, m_Zero())) {
@@ -177,20 +173,8 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
     if (!carrys)
       return failure();
 
-    Value addsVal =
-        spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
-
-    Value carrysVal =
-        spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
-
-    // Create empty struct
-    Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
-    // Fill in adds at id 0
-    Value intermediate =
-        spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
-    // Fill in carrys at id 1
-    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
-                                                          intermediate, 1);
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), rewriter.getArrayAttr({adds, carrys}));
     return success();
   }
 };
@@ -204,8 +188,6 @@ void spirv::IAddCarryOp::getCanonicalizationPatterns(
 // spirv.[S|U]MulExtended
 //===----------------------------------------------------------------------===//
 
-// We are required to use CompositeConstructOp to create a constant struct as
-// they are not yet implemented as constant, hence we can not do so in a fold.
 template <typename MulOp, bool IsSigned>
 struct MulExtendedFold final : OpRewritePattern<MulOp> {
   using OpRewritePattern<MulOp>::OpRewritePattern;
@@ -258,20 +240,8 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
     if (!highBits)
       return failure();
 
-    Value lowBitsVal =
-        spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
-
-    Value highBitsVal =
-        spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
-
-    // Create empty struct
-    Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
-    // Fill in lowBits at id 0
-    Value intermediate =
-        spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
-    // Fill in highBits at id 1
-    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
-                                                          intermediate, 1);
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), rewriter.getArrayAttr({lowBits, highBits}));
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 722c27586aa61..235ca15c08d7b 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -356,20 +356,12 @@ func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.s
   %cn5 = spirv.Constant -5 : i32
   %cn8 = spirv.Constant -8 : i32
 
-  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
-  // CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
-  // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN3]], %[[UNDEF1]][0 : i32]
-  // CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER1]][1 : i32]
-  // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
-  // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
-  // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[CN13]], %[[UNDEF2]][0 : i32]
-  // CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeInsert %[[C1]], %[[INTER2]][1 : i32]
+  // CHECK-DAG: %[[CST_CN3_C0:.*]] = spirv.Constant [-3 : i32, 0 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG: %[[CST_CN13_C1:.*]] = spirv.Constant [-13 : i32, 1 : i32] : !spirv.struct<(i32, i32)>
   %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
-  // CHECK: return %[[CC_CN3_C0]], %[[CC_CN13_C1]]
+  // CHECK: return %[[CST_CN3_C0]], %[[CST_CN13_C1]]
   return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
 }
 
@@ -378,14 +370,10 @@ func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector
   %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
   %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
 
-  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]>
-  // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]>
-  // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
-  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
+  // CHECK: %[[CST:.*]] = spirv.Constant [dense<[-3, -11, 0]> : vector<3xi32>, dense<[0, 1, 1]> : vector<3xi32>] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
   %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
-  // CHECK: return %[[CC_CV1_CV2]]
+  // CHECK: return %[[CST]]
   return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 }
 
@@ -476,20 +464,12 @@ func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spir
   %cn5 = spirv.Constant -5 : i32
   %cn8 = spirv.Constant -8 : i32
 
-  // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
-  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
-  // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
-  // CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeInsert %[[CN1]], %[[INTER1]]
-  // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
-  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
-  // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
-  // CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER2]][1 : i32]
+  // CHECK-DAG: %[[CST_CN40_CN1:.*]] = spirv.Constant [-40 : i32, -1 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG: %[[CST_C40_C0:.*]] = spirv.Constant [40 : i32, 0 : i32] : !spirv.struct<(i32, i32)>
   %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
-  // CHECK: return %[[CC_CN40_CN1]], %[[CC_C40_C0]]
+  // CHECK: return %[[CST_CN40_CN1]], %[[CST_C40_C0]]
   return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
 }
 
@@ -498,14 +478,10 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
   %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
   %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
 
-  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
-  // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]>
-  // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
-  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
+  // CHECK: %[[CST:.*]] = spirv.Constant [dense<[2147483643, 40, -1]> : vector<3xi32>, dense<[2, 0, -1]> : vector<3xi32>] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
   %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
-  // CHECK: return %[[CC_CV1_CV2]]
+  // CHECK: return %[[CST]]
   return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
 }
@@ -545,21 +521,12 @@ func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spir
   %cn5 = spirv.Constant -5 : i32
   %cn8 = spirv.Constant -8 : i32
 
-
-  // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
-  // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
-  // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
-  // CHECK-DAG: %[[C4:.*]] = spirv.Constant 4
-  // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
-  // CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeInsert %[[C4]], %[[INTER1]][1 : i32]
-  // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
-  // CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeInsert %[[CN13]], %[[INTER2]][1 : i32]
+  // CHECK-DAG: %[[CST_CN40_C4:.*]] = spirv.Constant [-40 : i32, 4 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG: %[[CST_C40_CN13:.*]] = spirv.Constant [40 : i32, -13 : i32] : !spirv.struct<(i32, i32)>
   %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
-  // CHECK: return %[[CC_CN40_C4]], %[[CC_C40_CN13]]
+  // CHECK: return %[[CST_CN40_C4]], %[[CST_C40_CN13]]
   return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
 }
 
@@ -568,14 +535,10 @@ func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vec
   %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
   %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
 
-  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
-  // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]>
-  // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
-  // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]]
-  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]]
+  // CHECK: %[[CST:.*]] = spirv.Constant [dense<[2147483643, 40, -1]> : vector<3xi32>, dense<[2, -13, 0]> : vector<3xi32>] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
   %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
-  // CHECK: return %[[CC_CV1_CV2]]
+  // CHECK: return %[[CST]]
   return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 }
 


        


More information about the Mlir-commits mailing list