Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
#include "nanobind/typing.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

Expand Down Expand Up @@ -1482,7 +1483,11 @@ class PyConcreteValue : public PyValue {

/// Binds the Python module objects to functions of this class.
static void bind(nb::module_ &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName);
auto cls = ClassTy(
m, DerivedTy::pyClassName, nb::is_generic(),
nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
.str()
.c_str()));
cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
cls.def_static(
"isinstance",
Expand Down Expand Up @@ -4605,7 +4610,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of Value.
//----------------------------------------------------------------------------
nb::class_<PyValue>(m, "Value")
m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));

nb::class_<PyValue>(m, "Value", nb::is_generic(),
nb::sig("class Value(Generic[_T])"))
.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
"Creates a Value reference from another `Value`.")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
Expand Down Expand Up @@ -4737,7 +4745,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyValue &self, const PyType &type) {
mlirValueSetType(self.get(), type);
},
nb::arg("type"), "Sets the type of the value.")
nb::arg("type"), "Sets the type of the value.",
nb::sig("def set_type(self, type: _T)"))
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
Expand Down
14 changes: 7 additions & 7 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)

// CHECK: @builtins.property
// CHECK: def f32(self) -> _ods_ir.Value:
// CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32, F32:$f32, I64);

// CHECK: @builtins.property
// CHECK: def i32(self) -> _ods_ir.OpResult:
// CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
// CHECK: def i64(self) -> _ods_ir.OpResult:
// CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[2]
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
Expand Down Expand Up @@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)

// CHECK: @builtins.property
// CHECK: def i32(self) -> _ods_ir.Value:
// CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
// CHECK: return self.operation.operands[0]
//
// CHECK: @builtins.property
// CHECK: def f32(self) -> _ods_ir.Value:
// CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
// CHECK: return self.operation.operands[1]
let arguments = (ins I32:$i32, F32:$f32);

// CHECK: @builtins.property
// CHECK: def i64(self) -> _ods_ir.OpResult:
// CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
// CHECK: return self.operation.results[0]
//
// CHECK: @builtins.property
// CHECK: def f64(self) -> _ods_ir.OpResult:
// CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]:
// CHECK: return self.operation.results[1]
let results = (outs I64:$i64, AnyFloat:$f64);
}
Expand Down
9 changes: 8 additions & 1 deletion mlir/test/python/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def testOptionalOperandOp():
)
assert (
typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
is OpResult
== OpResult[IntegerType]
)
assert type(op1.result) is OpResult

Expand Down Expand Up @@ -662,6 +662,13 @@ def testCustomType():
raise


@run
# CHECK-LABEL: TEST: testValue
def testValue():
# Check that Value is a generic class at runtime.
assert hasattr(Value, "__class_getitem__")


@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
Expand Down
31 changes: 30 additions & 1 deletion mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) {
StringRef(kind).drop_front());
}

static StringRef getPythonType(StringRef cppType) {
return llvm::StringSwitch<StringRef>(cppType)
.Case("::mlir::MemRefType", "_ods_ir.MemRefType")
.Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
.Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
.Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
.Case("::mlir::VectorType", "_ods_ir.VectorType")
.Case("::mlir::IntegerType", "_ods_ir.IntegerType")
.Case("::mlir::FloatType", "_ods_ir.FloatType")
.Case("::mlir::IndexType", "_ods_ir.IndexType")
.Case("::mlir::ComplexType", "_ods_ir.ComplexType")
.Case("::mlir::TupleType", "_ods_ir.TupleType")
.Case("::mlir::NoneType", "_ods_ir.NoneType")
.Default(StringRef());
}

/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
Expand Down Expand Up @@ -370,8 +386,11 @@ static void emitElementAccessors(
seenVariableLength = true;
if (element.name.empty())
continue;
const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
if (StringRef pythonType = getPythonType(element.constraint.getCppType());
!pythonType.empty())
type = llvm::formatv("{0}[{1}]", type, pythonType);
if (element.isVariableLength()) {
if (element.isOptional()) {
os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
Expand Down Expand Up @@ -418,6 +437,11 @@ static void emitElementAccessors(
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
}
if (std::strcmp(kind, "operand") == 0) {
StringRef pythonType = getPythonType(element.constraint.getCppType());
if (!pythonType.empty())
type += "[" + pythonType.str() + "]";
}
Comment on lines +440 to +444
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is an error here. type may be "_ods_ir.OpOperandList" here which is not Generic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed here #167930

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
kind, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic, type);
Expand Down Expand Up @@ -449,6 +473,11 @@ static void emitElementAccessors(
if (!element.isVariableLength() || element.isOptional()) {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
if (std::strcmp(kind, "operand") == 0) {
StringRef pythonType = getPythonType(element.constraint.getCppType());
if (!pythonType.empty())
type += "[" + pythonType.str() + "]";
}
if (!element.isVariableLength()) {
trailing = "[0]";
} else if (element.isOptional()) {
Expand Down
Loading