Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- bump: minor
changes:
added:
- Added select_filing_status_value utility function for cleaner filing status selection with SINGLE as default.
changed:
- Refactored MD, CA, NY, AL, GA income tax calculations to use select_filing_status_value utility.
55 changes: 55 additions & 0 deletions policyengine_us/tests/tools/test_filing_status_utility.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Test that select_filing_status_value works correctly with MD income tax
- name: Test filing status utility with MD tax - single filer
period: 2024
input:
people:
person:
age: 30
tax_units:
tax_unit:
members: [person]
md_taxable_income: 50_000
households:
household:
members: [person]
state_code: MD
output:
md_income_tax_before_credits: 2_322.50

- name: Test filing status utility with MD tax - joint filers
period: 2024
input:
people:
head:
age: 30
spouse:
age: 28
tax_units:
tax_unit:
members: [head, spouse]
md_taxable_income: 100_000
households:
household:
members: [head, spouse]
state_code: MD
output:
md_income_tax_before_credits: 4_697.50

- name: Test filing status utility with MD tax - head of household
period: 2024
input:
people:
head:
age: 35
child:
age: 10
tax_units:
tax_unit:
members: [head, child]
md_taxable_income: 75_000
households:
household:
members: [head, child]
state_code: MD
output:
md_income_tax_before_credits: 3_510.00
78 changes: 78 additions & 0 deletions policyengine_us/tools/general.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the simplification. Question: does this only work with scale parameters (single amounts, marginal rates)? If not, can we have an example of a select function without a scale parameter

Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,81 @@ def get_previous_threshold(
return t[
max_((t <= values.reshape((1, len(values))).T).sum(axis=1) - 1, 0)
]


def select_filing_status_value(
filing_status: ArrayLike,
filing_status_values: dict,
input_value: ArrayLike = None,
**kwargs,
) -> ArrayLike:
"""
Select a value based on filing status, with SINGLE as the default.

This is a common pattern for selecting parameter values based on filing status.
According to IRS SOI data, SINGLE is the most common filing status.

Args:
filing_status: Array of filing status enum values
filing_status_values: Dict mapping filing status to values or functions
input_value: Optional input value to pass to functions (e.g., taxable income)

Returns:
Array of selected values based on filing status

Example:
# For parameter values
result = select_filing_status_value(
filing_status,
parameters.amount
)

# For calculated values (e.g., tax brackets)
result = select_filing_status_value(
filing_status,
parameters.rates,
taxable_income
)
"""
statuses = filing_status.possible_values

# Helper function to get value
def get_value(fs_value):
if input_value is not None and hasattr(fs_value, "calc"):
# It's a rate schedule or similar
return fs_value.calc(input_value, **kwargs)
elif hasattr(fs_value, "__call__"):
# It's a callable
return (
fs_value(input_value, **kwargs)
if input_value is not None
else fs_value(**kwargs)
)
else:
# It's a simple value
return fs_value

# Build conditions and values, excluding SINGLE
conditions = []
values = []

# Check each filing status except SINGLE
for status_name in [
"JOINT",
"SEPARATE",
"HEAD_OF_HOUSEHOLD",
"SURVIVING_SPOUSE",
]:
# Check if this enum value exists in this filing status enum
if hasattr(statuses, status_name):
status_enum = getattr(statuses, status_name)
if status_enum.name.lower() in filing_status_values:
conditions.append(filing_status == status_enum)
values.append(
get_value(filing_status_values[status_enum.name.lower()])
)

# SINGLE is the default
default_value = get_value(filing_status_values["single"])

return select(conditions, values, default=default_value)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from policyengine_us.model_api import *
from policyengine_us.tools.general import select_filing_status_value


class al_income_tax_before_non_refundable_credits(Variable):
Expand All @@ -16,21 +17,4 @@ def formula(tax_unit, period, parameters):
taxable_income = tax_unit("al_taxable_income", period)
p = parameters(period).gov.states.al.tax.income.rates

statuses = filing_status.possible_values

return select(
[
filing_status == statuses.SINGLE,
filing_status == statuses.SEPARATE,
filing_status == statuses.JOINT,
filing_status == statuses.SURVIVING_SPOUSE,
filing_status == statuses.HEAD_OF_HOUSEHOLD,
],
[
p.single.calc(taxable_income),
p.separate.calc(taxable_income),
p.joint.calc(taxable_income),
p.surviving_spouse.calc(taxable_income),
p.head_of_household.calc(taxable_income),
],
)
return select_filing_status_value(filing_status, p, taxable_income)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from policyengine_us.model_api import *
from policyengine_us.tools.general import select_filing_status_value


class ca_income_tax_before_credits(Variable):
Expand All @@ -15,21 +16,4 @@ def formula(tax_unit, period, parameters):
taxable_income = tax_unit("ca_taxable_income", period)
p = parameters(period).gov.states.ca.tax.income.rates

statuses = filing_status.possible_values

return select(
[
filing_status == statuses.SINGLE,
filing_status == statuses.SEPARATE,
filing_status == statuses.JOINT,
filing_status == statuses.SURVIVING_SPOUSE,
filing_status == statuses.HEAD_OF_HOUSEHOLD,
],
[
p.single.calc(taxable_income),
p.separate.calc(taxable_income),
p.joint.calc(taxable_income),
p.surviving_spouse.calc(taxable_income),
p.head_of_household.calc(taxable_income),
],
)
return select_filing_status_value(filing_status, p, taxable_income)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from policyengine_us.model_api import *
from policyengine_us.tools.general import select_filing_status_value


class ga_income_tax_before_non_refundable_credits(Variable):
Expand All @@ -12,21 +13,5 @@ class ga_income_tax_before_non_refundable_credits(Variable):
def formula(tax_unit, period, parameters):
p = parameters(period).gov.states.ga.tax.income.main
filing_status = tax_unit("filing_status", period)
status = filing_status.possible_values
income = tax_unit("ga_taxable_income", period)
return select(
[
filing_status == status.SINGLE,
filing_status == status.SEPARATE,
filing_status == status.JOINT,
filing_status == status.HEAD_OF_HOUSEHOLD,
filing_status == status.SURVIVING_SPOUSE,
],
[
p.single.calc(income),
p.separate.calc(income),
p.joint.calc(income),
p.head_of_household.calc(income),
p.surviving_spouse.calc(income),
],
)
return select_filing_status_value(filing_status, p, income)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from policyengine_us.model_api import *
from policyengine_us.tools.general import select_filing_status_value


class md_income_tax_before_credits(Variable):
Expand All @@ -11,26 +12,12 @@ class md_income_tax_before_credits(Variable):

def formula(tax_unit, period, parameters):
filing_status = tax_unit("filing_status", period)
filing_statuses = filing_status.possible_values
taxable_income = tax_unit("md_taxable_income", period)

# Calculate regular income tax based on filing status
p = parameters(period).gov.states.md.tax.income
regular_income_tax = select(
[
filing_status == filing_statuses.SINGLE,
filing_status == filing_statuses.SEPARATE,
filing_status == filing_statuses.JOINT,
filing_status == filing_statuses.HEAD_OF_HOUSEHOLD,
filing_status == filing_statuses.SURVIVING_SPOUSE,
],
[
p.rates.single.calc(taxable_income),
p.rates.separate.calc(taxable_income),
p.rates.joint.calc(taxable_income),
p.rates.head_of_household.calc(taxable_income),
p.rates.surviving_spouse.calc(taxable_income),
],
regular_income_tax = select_filing_status_value(
filing_status, p.rates, taxable_income
)

# Add capital gains surtax if applicable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from policyengine_us.model_api import *
from policyengine_us.tools.general import select_filing_status_value


class ny_main_income_tax(Variable):
Expand All @@ -12,28 +13,6 @@ class ny_main_income_tax(Variable):
def formula(tax_unit, period, parameters):
taxable_income = tax_unit("ny_taxable_income", period)
filing_status = tax_unit("filing_status", period)
status = filing_status.possible_values

rates = parameters(period).gov.states.ny.tax.income.main
single = rates.single
joint = rates.joint
hoh = rates.head_of_household
surviving_spouse = rates.surviving_spouse
separate = rates.separate

return select(
[
filing_status == status.SINGLE,
filing_status == status.JOINT,
filing_status == status.HEAD_OF_HOUSEHOLD,
filing_status == status.SURVIVING_SPOUSE,
filing_status == status.SEPARATE,
],
[
single.calc(taxable_income),
joint.calc(taxable_income),
hoh.calc(taxable_income),
surviving_spouse.calc(taxable_income),
separate.calc(taxable_income),
],
)
return select_filing_status_value(filing_status, rates, taxable_income)