Skip to content

Commit 08d1f58

Browse files
Optional output of tensor product computations (#919)
* Change to optional return type of tensor product functionality * Remove not needed wrappers * Fix * Include nanobind/optional.
1 parent beadcef commit 08d1f58

File tree

4 files changed

+92
-119
lines changed

4 files changed

+92
-119
lines changed

cpp/basix/finite-element.cpp

Lines changed: 49 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include <concepts>
2424
#include <limits>
2525
#include <numeric>
26+
#include <optional>
27+
#include <stdexcept>
2628

2729
#define str_macro(X) #X
2830
#define str(X) str_macro(X)
@@ -328,10 +330,15 @@ basix::create_tp_element(element::family family, cell::type cell, int degree,
328330
element::lagrange_variant lvariant,
329331
element::dpc_variant dvariant, bool discontinuous)
330332
{
331-
std::vector<int> dof_ordering = tp_dof_ordering(
333+
std::optional<std::vector<int>> dof_ordering = tp_dof_ordering(
332334
family, cell, degree, lvariant, dvariant, discontinuous);
335+
336+
if (!dof_ordering.has_value())
337+
throw std::runtime_error(
338+
"Element does not have tensor product factorisation.");
339+
333340
return create_element<T>(family, cell, degree, lvariant, dvariant,
334-
discontinuous, dof_ordering);
341+
discontinuous, *dof_ordering);
335342
}
336343
//-----------------------------------------------------------------------------
337344
template basix::FiniteElement<float>
@@ -342,60 +349,54 @@ basix::create_tp_element(element::family, cell::type, int,
342349
element::lagrange_variant, element::dpc_variant, bool);
343350
//-----------------------------------------------------------------------------
344351
template <std::floating_point T>
345-
std::vector<std::vector<FiniteElement<T>>>
352+
std::optional<std::vector<std::vector<FiniteElement<T>>>>
346353
basix::tp_factors(element::family family, cell::type cell, int degree,
347354
element::lagrange_variant lvariant,
348355
element::dpc_variant dvariant, bool discontinuous,
349356
const std::vector<int>& dof_ordering)
350357
{
351-
std::vector<int> tp_dofs = tp_dof_ordering(family, cell, degree, lvariant,
352-
dvariant, discontinuous);
353-
if (!tp_dofs.empty() && tp_dofs == dof_ordering)
358+
std::optional<std::vector<int>> tp_dofs = tp_dof_ordering(
359+
family, cell, degree, lvariant, dvariant, discontinuous);
360+
if (!tp_dofs.has_value() || tp_dofs->empty()
361+
|| tp_dofs.value() != dof_ordering)
362+
return std::nullopt;
363+
364+
switch (family)
354365
{
355-
switch (family)
356-
{
357-
case element::family::P:
366+
case element::family::P:
367+
{
368+
FiniteElement<T> sub_element
369+
= create_element<T>(element::family::P, cell::type::interval, degree,
370+
lvariant, dvariant, true);
371+
switch (cell)
358372
{
359-
FiniteElement<T> sub_element
360-
= create_element<T>(element::family::P, cell::type::interval, degree,
361-
lvariant, dvariant, true);
362-
switch (cell)
363-
{
364-
case cell::type::quadrilateral:
365-
{
366-
return {{sub_element, sub_element}};
367-
}
368-
case cell::type::hexahedron:
369-
{
370-
return {{sub_element, sub_element, sub_element}};
371-
}
372-
default:
373-
{
374-
throw std::runtime_error("Invalid celltype.");
375-
}
376-
}
377-
break;
378-
}
373+
case cell::type::quadrilateral:
374+
return {{{sub_element, sub_element}}};
375+
case cell::type::hexahedron:
376+
return {{{sub_element, sub_element, sub_element}}};
379377
default:
380-
{
381-
throw std::runtime_error("Invalid family.");
382-
}
378+
return std::nullopt;
383379
}
380+
break;
384381
}
385-
throw std::runtime_error(
386-
"Element does not have tensor product factorisation.");
382+
default:
383+
return std::nullopt;
384+
}
385+
// C++ 23:
386+
// std::unreachable()
387+
return std::nullopt;
387388
}
388389
//-----------------------------------------------------------------------------
389-
template std::vector<std::vector<basix::FiniteElement<float>>>
390+
template std::optional<std::vector<std::vector<basix::FiniteElement<float>>>>
390391
basix::tp_factors(element::family, cell::type, int, element::lagrange_variant,
391392
element::dpc_variant, bool, const std::vector<int>&);
392-
template std::vector<std::vector<basix::FiniteElement<double>>>
393+
template std::optional<std::vector<std::vector<basix::FiniteElement<double>>>>
393394
basix::tp_factors(element::family, cell::type, int, element::lagrange_variant,
394395
element::dpc_variant, bool, const std::vector<int>&);
395396
//-----------------------------------------------------------------------------
396-
std::vector<int> basix::tp_dof_ordering(element::family family, cell::type cell,
397-
int degree, element::lagrange_variant,
398-
element::dpc_variant, bool)
397+
std::optional<std::vector<int>>
398+
basix::tp_dof_ordering(element::family family, cell::type cell, int degree,
399+
element::lagrange_variant, element::dpc_variant, bool)
399400
{
400401
std::vector<int> dof_ordering;
401402
std::vector<int> perm;
@@ -488,21 +489,17 @@ std::vector<int> basix::tp_dof_ordering(element::family family, cell::type cell,
488489
break;
489490
}
490491
default:
491-
{
492-
}
492+
return std::nullopt;
493493
}
494494
break;
495495
}
496496
default:
497-
{
498-
}
497+
return std::nullopt;
499498
}
500499

501500
if (perm.size() == 0)
502-
{
503-
throw std::runtime_error(
504-
"Element does not have tensor product factorisation.");
505-
}
501+
return std::nullopt;
502+
506503
dof_ordering.resize(perm.size());
507504
for (std::size_t i = 0; i < perm.size(); ++i)
508505
dof_ordering[perm[i]] = i;
@@ -1037,7 +1034,7 @@ FiniteElement<F>::FiniteElement(
10371034
_embedded_superdegree(embedded_superdegree),
10381035
_embedded_subdegree(embedded_subdegree), _value_shape(value_shape),
10391036
_map_type(map_type), _sobolev_space(sobolev_space),
1040-
_discontinuous(discontinuous), _dof_ordering(dof_ordering)
1037+
_discontinuous(discontinuous), _dof_ordering(std::move(dof_ordering))
10411038
{
10421039
// Check that discontinuous elements only have DOFs on interior
10431040
if (discontinuous)
@@ -1055,14 +1052,10 @@ FiniteElement<F>::FiniteElement(
10551052
}
10561053
}
10571054

1058-
try
1059-
{
1060-
_tensor_factors = tp_factors<F>(family, cell_type, degree, lvariant,
1061-
dvariant, discontinuous, dof_ordering);
1062-
}
1063-
catch (...)
1064-
{
1065-
}
1055+
auto factors = tp_factors<F>(family, cell_type, degree, lvariant, dvariant,
1056+
discontinuous, _dof_ordering);
1057+
if (factors.has_value())
1058+
_tensor_factors = factors.value();
10661059

10671060
std::vector<F> wcoeffs_b(wcoeffs.extent(0) * wcoeffs.extent(1));
10681061
std::copy(wcoeffs.data_handle(), wcoeffs.data_handle() + wcoeffs.size(),

cpp/basix/finite-element.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <functional>
1818
#include <map>
1919
#include <numeric>
20+
#include <optional>
2021
#include <span>
2122
#include <string>
2223
#include <tuple>
@@ -1623,11 +1624,12 @@ FiniteElement<T> create_element(element::family family, cell::type cell,
16231624
/// @param[in] discontinuous Indicates whether the element is discontinuous
16241625
/// between cells points of the element. The discontinuous element will have the
16251626
/// same DOFs, but they will all be associated with the interior of the cell.
1626-
/// @return A vector containing the dof ordering
1627-
std::vector<int> tp_dof_ordering(element::family family, cell::type cell,
1628-
int degree, element::lagrange_variant lvariant,
1629-
element::dpc_variant dvariant,
1630-
bool discontinuous);
1627+
/// @return An optional vector containing the dof ordering, if has tensor
1628+
/// structure
1629+
std::optional<std::vector<int>>
1630+
tp_dof_ordering(element::family family, cell::type cell, int degree,
1631+
element::lagrange_variant lvariant,
1632+
element::dpc_variant dvariant, bool discontinuous);
16311633

16321634
/// Get the lexicographic DOF ordering for an element
16331635
/// @param[in] family The element family
@@ -1655,9 +1657,10 @@ std::vector<int> lex_dof_ordering(element::family family, cell::type cell,
16551657
/// between cells points of the element. The discontinuous element will have the
16561658
/// same DOFs, but they will all be associated with the interior of the cell.
16571659
/// @param[in] dof_ordering The ordering of the DOFs
1658-
/// @return A list of lists of finite element factors
1660+
/// @return An optioanl list of lists of finite element factors if family has
1661+
/// tensor structure
16591662
template <std::floating_point T>
1660-
std::vector<std::vector<FiniteElement<T>>>
1663+
std::optional<std::vector<std::vector<FiniteElement<T>>>>
16611664
tp_factors(element::family family, cell::type cell, int degree,
16621665
element::lagrange_variant lvariant, element::dpc_variant dvariant,
16631666
bool discontinuous, const std::vector<int>& dof_ordering);

python/basix/finite_element.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def tp_factors(
778778
discontinuous: bool = False,
779779
dof_ordering: typing.Optional[list[int]] = None,
780780
dtype: npt.DTypeLike = np.float64,
781-
) -> list[list[FiniteElement]]:
781+
) -> typing.Optional[list[list[FiniteElement]]]:
782782
"""Elements in the tensor product factorisation of an element.
783783
784784
If the element has no factorisation, raises a RuntimeError.
@@ -799,19 +799,21 @@ def tp_factors(
799799
Returns:
800800
A list of finite elements.
801801
"""
802-
return [
803-
[FiniteElement(e) for e in elements]
804-
for elements in _tp_factors(
805-
family,
806-
celltype,
807-
degree,
808-
lagrange_variant,
809-
dpc_variant,
810-
discontinuous,
811-
dof_ordering if dof_ordering is not None else [],
812-
np.dtype(dtype).char,
813-
)
814-
]
802+
factors = _tp_factors(
803+
family,
804+
celltype,
805+
degree,
806+
lagrange_variant,
807+
dpc_variant,
808+
discontinuous,
809+
dof_ordering if dof_ordering is not None else [],
810+
np.dtype(dtype).char,
811+
)
812+
813+
if factors is None:
814+
return None
815+
816+
return [[FiniteElement(e) for e in elements] for elements in factors]
815817

816818

817819
def tp_dof_ordering(
@@ -821,7 +823,7 @@ def tp_dof_ordering(
821823
lagrange_variant: LagrangeVariant = LagrangeVariant.unset,
822824
dpc_variant: DPCVariant = DPCVariant.unset,
823825
discontinuous: bool = False,
824-
) -> list[int]:
826+
) -> typing.Optional[list[int]]:
825827
"""Tensor product DOF ordering for an element.
826828
827829
This DOF ordering can be passed into create_element to create the

python/wrapper.cpp

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <memory>
1818
#include <nanobind/nanobind.h>
1919
#include <nanobind/ndarray.h>
20+
#include <nanobind/stl/optional.h>
2021
#include <nanobind/stl/pair.h>
2122
#include <nanobind/stl/string.h>
2223
#include <nanobind/stl/tuple.h>
@@ -522,8 +523,7 @@ NB_MODULE(_basixcpp, m)
522523
.value("prism", cell::type::prism)
523524
.value("pyramid", cell::type::pyramid);
524525

525-
m.def("cell_volume", [](cell::type cell_type) -> double
526-
{ return cell::volume<double>(cell_type); });
526+
m.def("cell_volume", &cell::volume<double>);
527527
m.def("cell_facet_normals", [](cell::type cell_type)
528528
{ return as_nbarrayp(cell::facet_normals<double>(cell_type)); });
529529
m.def(
@@ -641,8 +641,9 @@ NB_MODULE(_basixcpp, m)
641641
element::lagrange_variant lagrange_variant,
642642
element::dpc_variant dpc_variant, bool discontinuous,
643643
const std::vector<int>& dof_ordering, char dtype)
644-
-> std::variant<std::vector<std::vector<FiniteElement<float>>>,
645-
std::vector<std::vector<FiniteElement<double>>>>
644+
-> std::optional<
645+
std::variant<std::vector<std::vector<FiniteElement<float>>>,
646+
std::vector<std::vector<FiniteElement<double>>>>>
646647
{
647648
if (dtype == 'd')
648649
{
@@ -660,40 +661,16 @@ NB_MODULE(_basixcpp, m)
660661
throw std::runtime_error("Unsupported finite element dtype.");
661662
});
662663

663-
m.def("tp_dof_ordering",
664-
[](element::family family_name, cell::type cell, int degree,
665-
element::lagrange_variant lagrange_variant,
666-
element::dpc_variant dpc_variant,
667-
bool discontinuous) -> std::vector<int>
668-
{
669-
return basix::tp_dof_ordering(family_name, cell, degree,
670-
lagrange_variant, dpc_variant,
671-
discontinuous);
672-
});
673-
674-
m.def("lex_dof_ordering",
675-
[](element::family family_name, cell::type cell, int degree,
676-
element::lagrange_variant lagrange_variant,
677-
element::dpc_variant dpc_variant,
678-
bool discontinuous) -> std::vector<int>
679-
{
680-
return basix::lex_dof_ordering(family_name, cell, degree,
681-
lagrange_variant, dpc_variant,
682-
discontinuous);
683-
});
664+
m.def("tp_dof_ordering", &basix::tp_dof_ordering);
665+
m.def("lex_dof_ordering", &basix::lex_dof_ordering);
684666

685667
nb::enum_<polyset::type>(m, "PolysetType", nb::is_arithmetic(),
686668
"Polyset type.")
687669
.value("standard", polyset::type::standard)
688670
.value("macroedge", polyset::type::macroedge);
689671

690-
m.def("superset",
691-
[](cell::type cell, polyset::type type1, polyset::type type2)
692-
{ return polyset::superset(cell, type1, type2); });
693-
694-
m.def("restriction",
695-
[](polyset::type ptype, cell::type cell, cell::type restriction_cell)
696-
{ return polyset::restriction(ptype, cell, restriction_cell); });
672+
m.def("superset", &polyset::superset);
673+
m.def("restriction", &polyset::restriction);
697674

698675
m.def(
699676
"make_quadrature",
@@ -707,15 +684,13 @@ NB_MODULE(_basixcpp, m)
707684
as_nbarray(std::move(w)));
708685
});
709686

710-
m.def(
711-
"gauss_jacobi_rule",
712-
[](double a, int m)
713-
{
714-
auto [pts, w]
715-
= quadrature::gauss_jacobi_rule<double>(a, m);
716-
return std::pair(as_nbarray(std::move(pts)),
717-
as_nbarray(std::move(w)));
718-
});
687+
m.def("gauss_jacobi_rule",
688+
[](double a, int m)
689+
{
690+
auto [pts, w] = quadrature::gauss_jacobi_rule<double>(a, m);
691+
return std::pair(as_nbarray(std::move(pts)),
692+
as_nbarray(std::move(w)));
693+
});
719694

720695
m.def("index", nb::overload_cast<int>(&basix::indexing::idx));
721696
m.def("index", nb::overload_cast<int, int>(&basix::indexing::idx));

0 commit comments

Comments
 (0)