[Mlir-commits] [mlir] [mlir][affine] Support vector types in `affine.apply` (PR #129442)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 15 08:42:08 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/129442

>From a33196cd92cb1744ff0ef58e18a8d6302950c17f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 2 Mar 2025 13:26:20 +0100
Subject: [PATCH] [mlir][affine] Support vector types in `affine.apply`

`affine.apply` is generally useful outside of affine to generate various index computations.
Add support for vectors of index to enable vectorized code generation.

All operand and result types must match.
The type is optional in the assembly format and defaults to `index` when omitted, so the change is backward compatible with existing textual IR and avoids churn.
---
 .../mlir/Dialect/Affine/IR/AffineOps.td       | 13 +++--
 mlir/include/mlir/Dialect/Affine/Utils.h      | 10 ++--
 mlir/include/mlir/IR/CommonTypeConstraints.td |  4 ++
 .../AffineToStandard/AffineToStandard.cpp     |  2 +-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 57 +++++++++++++++----
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 40 ++++++++-----
 .../AffineToStandard/lower-affine.mlir        | 24 +++++---
 mlir/test/Dialect/Affine/invalid.mlir         | 14 ++++-
 8 files changed, 120 insertions(+), 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 6cd3408e2b2e9..8be819323fd6f 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -13,12 +13,13 @@
 #ifndef AFFINE_OPS
 #define AFFINE_OPS
 
-include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
+include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/CommonTypeConstraints.td"
 
 def Affine_Dialect : Dialect {
   let name = "affine";
@@ -57,18 +58,30 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
     %2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n]
     ```
+
+    Operands and results may also be vectors of `index`. All operands and the
+    result must have the same type, which is spelled after a trailing colon and
+    defaults to `index` when omitted:
+
+    ```mlir
+    %3 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%v)[%w] : vector<4xindex>
+    ```
   }];
-  let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands);
-  let results = (outs Index);
+  let arguments = (ins AffineMapAttr:$map, Variadic<IndexLike>:$mapOperands);
+  let results = (outs IndexLike);
 
   // TODO: The auto-generated builders should check to see if the return type
   // has a constant builder. That way we wouldn't need to explicitly specify the
   // result types here.
   let builders = [
-    OpBuilder<(ins "ArrayRef<AffineExpr> ":$exprList,"ValueRange":$mapOperands),
+    OpBuilder<(ins "ArrayRef<AffineExpr>":$exprList,"ValueRange":$mapOperands),
     [{
       build($_builder, $_state, $_builder.getIndexType(),
             AffineMap::inferFromExprList(exprList, $_builder.getContext())
                                         .front(), mapOperands);
+    }]>,
+    OpBuilder<(ins "AffineMap":$map,"ValueRange":$mapOperands),
+    [{
+      build($_builder, $_state, $_builder.getIndexType(), map, mapOperands);
     }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ff1900bc8f2eb..ab39239a77312 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -294,14 +294,18 @@ void createAffineComputationSlice(Operation *opInst,
 /// Emit code that computes the given affine expression using standard
 /// arithmetic operations applied to the provided dimension and symbol values.
+/// `type` is the type of the computed value and is expected to match the type
+/// of `dimValues` and `symbolValues`; a null `type` means `index`.
 Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
-                       ValueRange dimValues, ValueRange symbolValues);
+                       ValueRange dimValues, ValueRange symbolValues,
+                       Type type = {});
 
 /// Create a sequence of operations that implement the `affineMap` applied to
 /// the given `operands` (as it it were an AffineApplyOp).
+/// Every result has type `type`, which is expected to match the type of
+/// `operands`; a null `type` means `index`.
-std::optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
-                                                     Location loc,
-                                                     AffineMap affineMap,
-                                                     ValueRange operands);
+std::optional<SmallVector<Value, 8>>
+expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap,
+                ValueRange operands, Type type = {});
 
 /// Holds the result of (div a, b)  and (mod a, b).
 struct DivModValue {
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..65d2ddb6a5c85 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -892,6 +892,10 @@ class TypeOrValueSemanticsContainer<Type allowedType, string name>
 // bools.
 def BoolLike : TypeOrValueSemanticsContainer<I1, "bool-like">;
 
+// Type constraint for index-like types: index or value-semantics containers
+// (e.g. vectors, tensors) of index.
+def IndexLike : TypeOrValueSemanticsContainer<Index, "index-like">;
+
 // Type constraint for signless-integer-like types: signless integers or
 // value-semantics containers of signless integers.
 def SignlessIntegerLike : TypeOrValueSemanticsContainer<
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 4fbe6a03f6bad..e9fb745a068a4 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -336,7 +336,7 @@ class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
                                PatternRewriter &rewriter) const override {
     auto maybeExpandedMap =
         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
-                        llvm::to_vector<8>(op.getOperands()));
+                        llvm::to_vector<8>(op.getOperands()), op.getType());
     if (!maybeExpandedMap)
       return failure();
     rewriter.replaceOp(op, *maybeExpandedMap);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8acb21d5074b4..717d3bbd8e3e9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -491,20 +491,37 @@ static void printDimAndSymbolList(Operation::operand_iterator begin,
     printer << '[' << operands.drop_front(numDims) << ']';
 }
 
-/// Parses dimension and symbol list and returns true if parsing failed.
-ParseResult mlir::affine::parseDimAndSymbolList(
-    OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
+/// Parse dimension and symbol list without resolving the operands yet, as
+/// their types may not be known at this point.
+static ParseResult parseDimAndSymbolListImpl(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &opInfos,
+    unsigned &numDims) {
   if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
     return failure();
+
   // Store number of dimensions for validation by caller.
   numDims = opInfos.size();
 
   // Parse the optional symbol operands.
+  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare))
+    return failure();
+
+  return success();
+}
+
+/// Parses dimension and symbol list and resolves all operands as `index`.
+ParseResult mlir::affine::parseDimAndSymbolList(
+    OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
+  if (parseDimAndSymbolListImpl(parser, opInfos, numDims))
+    return failure();
+
   auto indexTy = parser.getBuilder().getIndexType();
-  return failure(parser.parseOperandList(
-                     opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
-                 parser.resolveOperands(opInfos, indexTy, operands));
+  if (parser.resolveOperands(opInfos, indexTy, operands))
+    return failure();
+
+  return success();
 }
 
 /// Utility function to verify that a set of operands are valid dimension and
@@ -538,14 +555,25 @@ AffineValueMap AffineApplyOp::getAffineValueMap() {
 
 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
   auto &builder = parser.getBuilder();
-  auto indexTy = builder.getIndexType();
 
   AffineMapAttr mapAttr;
   unsigned numDims;
+  SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
   if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
-      parseDimAndSymbolList(parser, result.operands, numDims) ||
+      parseDimAndSymbolListImpl(parser, opInfos, numDims) ||
       parser.parseOptionalAttrDict(result.attributes))
     return failure();
+
+  Type type;
+  if (parser.parseOptionalColon()) {
+    type = builder.getIndexType();
+  } else if (parser.parseType(type)) {
+    return failure();
+  }
+
+  if (parser.resolveOperands(opInfos, type, result.operands))
+    return failure();
+
   auto map = mapAttr.getValue();
 
   if (map.getNumDims() != numDims ||
@@ -554,7 +582,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
                             "dimension or symbol index mismatch");
   }
 
-  result.types.append(map.getNumResults(), indexTy);
+  result.types.append(map.getNumResults(), type);
   return success();
 }
 
@@ -563,9 +591,18 @@ void AffineApplyOp::print(OpAsmPrinter &p) {
   printDimAndSymbolList(operand_begin(), operand_end(),
                         getAffineMap().getNumDims(), p);
   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
+  Type resType = getType();
+  if (!isa<IndexType>(resType))
+    p << " : " << resType;
 }
 
 LogicalResult AffineApplyOp::verify() {
+  // Check all operand and result types are the same.
+  // We cannot use `SameOperandsAndResultType` as it expects at least 1 operand.
+  if (!llvm::all_equal(
+          llvm::concat<Type>(getOperandTypes(), (*this)->getResultTypes())))
+    return emitOpError("requires the same type for all operands and results");
+
   // Check input and output dimensions match.
   AffineMap affineMap = getMap();
 
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2723cff6900d0..0342ae3ac6908 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -46,9 +46,9 @@ class AffineApplyExpander
   /// This internal class expects arguments to be non-null, checks must be
   /// performed at the call site.
   AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
-                      ValueRange symbolValues, Location loc)
+                      ValueRange symbolValues, Location loc, Type type)
       : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
-        loc(loc) {}
+        loc(loc), type(type) {}
 
   template <typename OpTy>
   Value buildBinaryExpr(AffineBinaryOpExpr expr,
@@ -189,8 +189,16 @@ class AffineApplyExpander
   }
 
   Value visitConstantExpr(AffineConstantExpr expr) {
-    auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
-    return op.getResult();
+    int64_t value = expr.getValue();
+    if (isa<IndexType>(type))
+      return builder.create<arith::ConstantIndexOp>(loc, value);
+    // Shaped types (e.g. vectors) get a splat; likely needs a static shape.
+    if (auto shaped = dyn_cast<ShapedType>(type)) {
+      auto elements = DenseIntElementsAttr::get(shaped, value);
+      return builder.create<arith::ConstantOp>(loc, elements);
+    }
+
+    llvm_unreachable("AffineApplyExpander: Unsupported type");
   }
 
   Value visitDimExpr(AffineDimExpr expr) {
@@ -211,6 +219,7 @@ class AffineApplyExpander
   ValueRange symbolValues;
 
   Location loc;
+  Type type;
 };
 } // namespace
 
@@ -219,23 +228,28 @@ class AffineApplyExpander
 mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
                                            AffineExpr expr,
                                            ValueRange dimValues,
-                                           ValueRange symbolValues) {
-  return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
+                                           ValueRange symbolValues, Type type) {
+  if (!type)
+    type = builder.getIndexType();
+
+  return AffineApplyExpander(builder, dimValues, symbolValues, loc, type)
+      .visit(expr);
 }
 
 /// Create a sequence of operations that implement the `affineMap` applied to
 /// the given `operands` (as it it were an AffineApplyOp).
 std::optional<SmallVector<Value, 8>>
 mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
-                              AffineMap affineMap, ValueRange operands) {
+                              AffineMap affineMap, ValueRange operands,
+                              Type type) {
   auto numDims = affineMap.getNumDims();
   auto expanded = llvm::to_vector<8>(
-      llvm::map_range(affineMap.getResults(),
-                      [numDims, &builder, loc, operands](AffineExpr expr) {
-                        return expandAffineExpr(builder, loc, expr,
-                                                operands.take_front(numDims),
-                                                operands.drop_front(numDims));
-                      }));
+      llvm::map_range(affineMap.getResults(), [numDims, &builder, loc, operands,
+                                               type](AffineExpr expr) {
+        return expandAffineExpr(builder, loc, expr,
+                                operands.take_front(numDims),
+                                operands.drop_front(numDims), type);
+      }));
   if (llvm::all_of(expanded, [](Value v) { return v; }))
     return expanded;
   return std::nullopt;
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 550ea71882e14..4f01f05dfb6b4 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -429,8 +429,9 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index
 #map5 = affine_map<(d0,d1,d2) -> (d0,d1,d2)>
 #map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)>
 
-// CHECK-LABEL: func @affine_applies(
-func.func @affine_applies(%arg0 : index) {
+// CHECK-LABEL: func @affine_applies
+//  CHECK-SAME:   (%[[ARG0:.*]]: index, %[[ARG1:.*]]: vector<4xindex>)
+func.func @affine_applies(%arg0 : index, %arg1 : vector<4xindex>) {
 // CHECK: %[[c0:.*]] = arith.constant 0 : index
   %zero = affine.apply #map0()
 
@@ -448,24 +449,29 @@ func.func @affine_applies(%arg0 : index) {
   %one = affine.apply #map3(%symbZero)[%zero]
 
 // CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw> : index
-// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
+// CHECK-NEXT: %[[v2:.*]] = arith.muli %[[ARG0]], %[[c2]] overflow<nsw> : index
+// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[ARG0]], %[[v2]] : index
 // CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw> : index
+// CHECK-NEXT: %[[v4:.*]] = arith.muli %[[ARG0]], %[[c3]] overflow<nsw> : index
 // CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
 // CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw> : index
+// CHECK-NEXT: %[[v6:.*]] = arith.muli %[[ARG0]], %[[c4]] overflow<nsw> : index
 // CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
 // CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw> : index
+// CHECK-NEXT: %[[v8:.*]] = arith.muli %[[ARG0]], %[[c5]] overflow<nsw> : index
 // CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
 // CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
-// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw> : index
+// CHECK-NEXT: %[[v10:.*]] = arith.muli %[[ARG0]], %[[c6]] overflow<nsw> : index
 // CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
 // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw> : index
+// CHECK-NEXT: %[[v12:.*]] = arith.muli %[[ARG0]], %[[c7]] overflow<nsw> : index
 // CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
   %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
+
+// CHECK-NEXT: %[[v14:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : vector<4xindex>
+// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : vector<4xindex>
+// CHECK-NEXT: %[[v15:.*]] = arith.addi %[[v14]], %[[cst]] : vector<4xindex>
+  %vec = affine.apply #map3(%arg1)[%arg1] : vector<4xindex>
   return
 }
 
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 9bbd19c381163..af948aa56eef5 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -5,7 +5,7 @@
 func.func @affine_apply_operand_non_index(%arg0 : i32) {
   // Custom parser automatically assigns all arguments the `index` so we must
   // use the generic syntax here to exercise the verifier.
-  // expected-error at +1 {{op operand #0 must be variadic of index, but got 'i32'}}
+  // expected-error at +1 {{op operand #0 must be variadic of index-like, but got 'i32'}}
   %0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (i32) -> (index)
   return
 }
@@ -15,11 +15,21 @@ func.func @affine_apply_operand_non_index(%arg0 : i32) {
 func.func @affine_apply_resul_non_index(%arg0 : index) {
   // Custom parser automatically assigns `index` as the result type so we must
   // use the generic syntax here to exercise the verifier.
-  // expected-error at +1 {{op result #0 must be index, but got 'i32'}}
+  // expected-error at +1 {{op result #0 must be index-like, but got 'i32'}}
   %0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (i32)
   return
 }
 
+// -----
+
+func.func @affine_apply_types_match(%arg0 : index) {
+  // Vectors of index are supported, but the operand and result types must
+  // all match.
+  // expected-error at +1 {{op requires the same type for all operands and results}}
+  %0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (vector<4xindex>)
+  return
+}
+
 // -----
 func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
   "unknown"() ({



More information about the Mlir-commits mailing list