diff --git a/experiments/ClimaEarth/Manifest-v1.11.toml b/experiments/ClimaEarth/Manifest-v1.11.toml index b9c553721a..7023d019ae 100644 --- a/experiments/ClimaEarth/Manifest-v1.11.toml +++ b/experiments/ClimaEarth/Manifest-v1.11.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.11.5" +julia_version = "1.11.6" manifest_format = "2.0" -project_hash = "f84997fb003cca7320c93d5448b91701d73e02c0" +project_hash = "b5f4dacdb351c0782e8ab711936dbbcf80091b48" [[deps.ADTypes]] git-tree-sha1 = "8b2b045b22740e4be20654175cc38291d48539db" @@ -449,9 +449,11 @@ weakdeps = ["CUDA", "MPI"] [[deps.ClimaCore]] deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LazyBroadcast", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "UnrolledUtilities"] -git-tree-sha1 = "ca717446978d2815b4fa23a62a2131861e44d1e8" +git-tree-sha1 = "ecf3fd72077213622731dc76d45897b66f6aa55b" +repo-rev = "js/kp/conservative" +repo-url = "https://github.com/CliMA/ClimaCore.jl.git" uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.14.42" +version = "0.14.43" weakdeps = ["CUDA", "Krylov"] [deps.ClimaCore.extensions] @@ -689,6 +691,14 @@ git-tree-sha1 = "d9d26935a0bcffc87d2613ce14c527c99fc543fd" uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" version = "2.5.0" +[[deps.ConservativeRegridding]] +deps = ["DocStringExtensions", "Extents", "GeoInterface", "GeometryOps", "GeometryOpsCore", "LinearAlgebra", "ProgressMeter", "SortTileRecursiveTree", "SparseArrays"] +git-tree-sha1 = "bb51e642ddfd67cf6cac0344fabd9c3fe36cd9a3" +repo-rev = "main" +repo-url = "https://github.com/JuliaGeo/ConservativeRegridding.jl" +uuid = "8e50ac2c-eb48-49bc-a402-07c87b949343" +version = "0.1.0" + [[deps.ConstructionBase]] git-tree-sha1 = "b4b092499347b18a015186eae3042f72267106cb" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/experiments/ClimaEarth/Project.toml b/experiments/ClimaEarth/Project.toml index eafdc92955..c20da36f31 100644 --- a/experiments/ClimaEarth/Project.toml +++ b/experiments/ClimaEarth/Project.toml @@ -16,12 +16,14 @@ ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c" ClimaSeaIce = "6ba0ff68-24e6-4315-936c-2e99227c95a4" ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513" +ConservativeRegridding = "8e50ac2c-eb48-49bc-a402-07c87b949343" EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" GeoMakie = "db073c08-6b98-4ee5-b6a4-5efafb3259c6" Insolation = "e98cc03f-d57e-4e3c-b70c-8d51efe9e0d8" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab" Oceananigans = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09" @@ -50,6 +52,7 @@ ClimaUtilities = "0.1" EnsembleKalmanProcesses = "2" Insolation = "0.10.2" Interpolations = "0.14, 0.15" +LinearAlgebra = "1.11" Oceananigans = "0.100" StaticArrays = "1" YAML = "0.4" diff --git a/experiments/ClimaEarth/components/ocean/climaocean_helpers.jl b/experiments/ClimaEarth/components/ocean/climaocean_helpers.jl index 5927ab5327..8178e18fed 100644 --- a/experiments/ClimaEarth/components/ocean/climaocean_helpers.jl +++ b/experiments/ClimaEarth/components/ocean/climaocean_helpers.jl @@ -1,44 +1,3 @@ -""" - to_node(pt::CC.Geometry.LatLongPoint) - -Transform `LatLongPoint` into a tuple (long, lat, 0), where the 0 is needed because we only -care about the surface. -""" -@inline to_node(pt::CC.Geometry.LatLongPoint) = pt.long, pt.lat, zero(pt.lat) -# This next one is needed if we have "LevelGrid" -@inline to_node(pt::CC.Geometry.LatLongZPoint) = pt.long, pt.lat, zero(pt.lat) - -""" - map_interpolate(points, oc_field::OC.Field) - -Interpolate the given 3D field onto the target points. - -If the underlying grid does not contain a given point, return 0 instead. - -Note: `map_interpolate` does not support interpolation from `Field`s defined on -`OrthogononalSphericalShellGrids` such as the `TripolarGrid`. - -TODO: Use a non-allocating version of this function (simply replace `map` with `map!`) -""" -function map_interpolate(points, oc_field::OC.Field) - loc = map(L -> L(), OC.Fields.location(oc_field)) - grid = oc_field.grid - data = oc_field.data - - # TODO: There has to be a better way - min_lat, max_lat = extrema(OC.φnodes(grid, OC.Center(), OC.Center(), OC.Center())) - - map(points) do pt - FT = eltype(pt) - - # The oceananigans grid does not cover the entire globe, so we should not - # interpolate outside of its latitude bounds. Instead we return 0 - min_lat < pt.lat < max_lat || return FT(0) - - fᵢ = OC.Fields.interpolate(to_node(pt), data, loc, grid) - convert(FT, fᵢ)::FT - end -end """ surface_flux(f::OC.AbstractField) @@ -54,16 +13,6 @@ function surface_flux(f::OC.AbstractField) end end -function Interfacer.remap(field::OC.Field, target_space) - return map_interpolate(CC.Fields.coordinate_field(target_space), field) -end - -function Interfacer.remap(operation::OC.AbstractOperations.AbstractOperation, target_space) - evaluated_field = OC.Field(operation) - OC.compute!(evaluated_field) - return Interfacer.remap(evaluated_field, target_space) -end - """ set_from_extrinsic_vector!(vector, grid, u_cc, v_cc) diff --git a/experiments/ClimaEarth/components/ocean/oceananigans.jl b/experiments/ClimaEarth/components/ocean/oceananigans.jl index c90d477a4d..964c493e28 100644 --- a/experiments/ClimaEarth/components/ocean/oceananigans.jl +++ b/experiments/ClimaEarth/components/ocean/oceananigans.jl @@ -7,8 +7,11 @@ import Thermodynamics as TD import ClimaParams as CP import ClimaOcean.EN4: download_dataset using KernelAbstractions: @kernel, @index, @inbounds +import ConservativeRegridding as CR +import LinearAlgebra as LA include("climaocean_helpers.jl") +include("remapping.jl") """ OceananigansSimulation{SIM, A, OPROP, REMAP, SIC} @@ -51,7 +54,7 @@ function OceananigansSimulation( ice_model, Δt = nothing, comms_ctx = ClimaComms.context(), - coupled_param_dict = CP.create_toml_dict(eltype(area_fraction)), + coupled_param_dict = CP.create_toml_dict(CC.Spaces.undertype(boundary_space)), ) arch = comms_ctx.device isa ClimaComms.CUDADevice ? OC.GPU() : OC.CPU() @@ -82,7 +85,7 @@ function OceananigansSimulation( bottom_height = CO.regrid_bathymetry( underlying_grid; minimum_depth = 30, - interpolation_passes = 20, + interpolation_passes = 2, # TODO revert major_basins = 1, ) @@ -137,33 +140,8 @@ function OceananigansSimulation( # Set initial condition to EN4 state estimate at start_date OC.set!(ocean.model, T = en4_temperature[1], S = en4_salinity[1]) - long_cc = OC.λnodes(grid, OC.Center(), OC.Center(), OC.Center()) - lat_cc = OC.φnodes(grid, OC.Center(), OC.Center(), OC.Center()) - - # TODO: Go from 0 to Nx+1, Ny+1 (for halos) (for LatLongGrid) - - # Construct a remapper from the exchange grid to `Center, Center` fields - long_cc = reshape(long_cc, length(long_cc), 1) - lat_cc = reshape(lat_cc, 1, length(lat_cc)) - target_points_cc = @. CC.Geometry.LatLongPoint(lat_cc, long_cc) - # TODO: We can remove the `nothing` after CC > 0.14.33 - remapper_cc = CC.Remapping.Remapper(boundary_space, target_points_cc, nothing) - - # Construct two 2D Center/Center fields to use as scratch space while remapping - scratch_cc1 = OC.Field{OC.Center, OC.Center, Nothing}(grid) - scratch_cc2 = OC.Field{OC.Center, OC.Center, Nothing}(grid) - - # Construct two scratch arrays to use while remapping - # We get the array type, float type, and dimensions from the remapper object to maintain consistency - ArrayType = ClimaComms.array_type(remapper_cc.space) - FT = CC.Spaces.undertype(remapper_cc.space) - interpolated_values_dim..., _buffer_length = size(remapper_cc._interpolated_values) - scratch_arr1 = ArrayType(zeros(FT, interpolated_values_dim...)) - scratch_arr2 = ArrayType(zeros(FT, interpolated_values_dim...)) - scratch_arr3 = ArrayType(zeros(FT, interpolated_values_dim...)) - - remapping = - (; remapper_cc, scratch_cc1, scratch_cc2, scratch_arr1, scratch_arr2, scratch_arr3) + # Construct the remappers for remapping between the Oceananigans grid and the ClimaCore boundary space + remapping = construct_remappers(grid, boundary_space) # Get some ocean properties and parameters ocean_properties = (; @@ -281,6 +259,19 @@ Interfacer.get_field(sim::OceananigansSimulation, ::Val{:surface_diffuse_albedo} Interfacer.get_field(sim::OceananigansSimulation, ::Val{:surface_temperature}) = sim.ocean.model.tracers.T + sim.ocean_properties.C_to_K # convert from Celsius to Kelvin +# Extend Interfacer.get_field to make sure we provide the remapping object to the remap functions +function Interfacer.get_field(sim::OceananigansSimulation, quantity, target_space) + return Interfacer.remap( + Interfacer.get_field(sim, quantity), + sim.remapping, + target_space, + ) +end +function Interfacer.get_field!(target_field, sim::OceananigansSimulation, quantity) + Interfacer.remap!(target_field, Interfacer.get_field(sim, quantity), sim.remapping) + return nothing +end + """ FluxCalculator.update_turbulent_fluxes!(sim::OceananigansSimulation, fields) @@ -305,29 +296,19 @@ function FluxCalculator.update_turbulent_fluxes!(sim::OceananigansSimulation, fi grid = sim.ocean.model.grid ice_concentration = sim.ice_concentration - # Remap momentum fluxes onto reduced 2D Center, Center fields using scratch arrays and fields - CC.Remapping.interpolate!( - sim.remapping.scratch_arr1, - sim.remapping.remapper_cc, - F_turb_ρτxz, - ) - OC.set!(sim.remapping.scratch_cc1, sim.remapping.scratch_arr1) # zonal momentum flux - CC.Remapping.interpolate!( - sim.remapping.scratch_arr2, - sim.remapping.remapper_cc, - F_turb_ρτyz, - ) - OC.set!(sim.remapping.scratch_cc2, sim.remapping.scratch_arr2) # meridional momentum flux + # Remap momentum fluxes onto scratch 2D Center, Center fields + Interfacer.remap!(sim.remapping.scratch_field_oc1, F_turb_ρτxz, sim.remapping) # zonal momentum flux + Interfacer.remap!(sim.remapping.scratch_field_oc2, F_turb_ρτyz, sim.remapping) # meridional momentum flux # Rename for clarity; these are now Center, Center Oceananigans fields - F_turb_ρτxz_cc = sim.remapping.scratch_cc1 - F_turb_ρτyz_cc = sim.remapping.scratch_cc2 + oc_F_turb_ρτxz = sim.remapping.scratch_field_oc1 + oc_F_turb_ρτyz = sim.remapping.scratch_field_oc2 # Weight by (1 - sea ice concentration) - OC.interior(F_turb_ρτxz_cc, :, :, 1) .= - OC.interior(F_turb_ρτxz_cc, :, :, 1) .* (1.0 .- ice_concentration) - OC.interior(F_turb_ρτyz_cc, :, :, 1) .= - OC.interior(F_turb_ρτyz_cc, :, :, 1) .* (1.0 .- ice_concentration) + OC.interior(oc_F_turb_ρτxz, :, :, 1) .= + OC.interior(oc_F_turb_ρτxz, :, :, 1) .* (1.0 .- ice_concentration) + OC.interior(oc_F_turb_ρτyz, :, :, 1) .= + OC.interior(oc_F_turb_ρτyz, :, :, 1) .* (1.0 .- ice_concentration) # Set the momentum flux BCs at the correct locations using the remapped scratch fields oc_flux_u = surface_flux(sim.ocean.model.velocities.u) @@ -335,19 +316,19 @@ function FluxCalculator.update_turbulent_fluxes!(sim::OceananigansSimulation, fi set_from_extrinsic_vector!( (; u = oc_flux_u, v = oc_flux_v), grid, - F_turb_ρτxz_cc, - F_turb_ρτyz_cc, + oc_F_turb_ρτxz, + oc_F_turb_ρτyz, ) (; reference_density, heat_capacity, fresh_water_density) = sim.ocean_properties # Remap the latent and sensible heat fluxes using scratch arrays - CC.Remapping.interpolate!(sim.remapping.scratch_arr1, sim.remapping.remapper_cc, F_lh) # latent heat flux - CC.Remapping.interpolate!(sim.remapping.scratch_arr2, sim.remapping.remapper_cc, F_sh) # sensible heat flux + Interfacer.remap!(sim.remapping.scratch_field_oc1, F_lh, sim.remapping) # latent heat flux + Interfacer.remap!(sim.remapping.scratch_field_oc2, F_sh, sim.remapping) # sensible heat flux # Rename for clarity; recall F_turb_energy = F_lh + F_sh - remapped_F_lh = sim.remapping.scratch_arr1 - remapped_F_sh = sim.remapping.scratch_arr2 + oc_F_lh = OC.interior(sim.remapping.scratch_field_oc1, :, :, 1) + oc_F_sh = OC.interior(sim.remapping.scratch_field_oc2, :, :, 1) # TODO: Note, SW radiation penetrates the surface. Right now, we just put # everything on the surface, but later we will need to account for this. @@ -355,19 +336,16 @@ function FluxCalculator.update_turbulent_fluxes!(sim::OceananigansSimulation, fi oc_flux_T = surface_flux(sim.ocean.model.tracers.T) OC.interior(oc_flux_T, :, :, 1) .= OC.interior(oc_flux_T, :, :, 1) .+ - (1.0 .- ice_concentration) .* (remapped_F_lh .+ remapped_F_sh) ./ + (1.0 .- ice_concentration) .* (oc_F_lh .+ oc_F_sh) ./ (reference_density * heat_capacity) # Add the part of the salinity flux that comes from the moisture flux, we also need to # add the component due to precipitation (that was done with the radiative fluxes) - CC.Remapping.interpolate!( - sim.remapping.scratch_arr1, - sim.remapping.remapper_cc, - F_turb_moisture, - ) - moisture_fresh_water_flux = sim.remapping.scratch_arr1 ./ fresh_water_density + Interfacer.remap!(sim.remapping.scratch_field_oc1, F_turb_moisture, sim.remapping) # moisture flux + moisture_fresh_water_flux = + OC.interior(sim.remapping.scratch_field_oc1, :, :, 1) ./ fresh_water_density oc_flux_S = surface_flux(sim.ocean.model.tracers.S) - surface_salinity = OC.interior(sim.ocean.model.tracers.S, :, :, 1) + surface_salinity = OC.interior(sim.ocean.model.tracers.S, :, :, grid.Nz) OC.interior(oc_flux_S, :, :, 1) .= OC.interior(oc_flux_S, :, :, 1) .- (1.0 .- ice_concentration) .* surface_salinity .* moisture_fresh_water_flux @@ -401,21 +379,14 @@ so a sign change is needed when we convert from precipitation to salinity flux. function FieldExchanger.update_sim!(sim::OceananigansSimulation, csf) (; reference_density, heat_capacity, fresh_water_density) = sim.ocean_properties ice_concentration = sim.ice_concentration + grid = sim.ocean.model.grid - # Remap radiative flux onto scratch array; rename for clarity - CC.Remapping.interpolate!( - sim.remapping.scratch_arr1, - sim.remapping.remapper_cc, - csf.SW_d, - ) - remapped_SW_d = sim.remapping.scratch_arr1 + # Remap radiative flux onto scratch fields; rename for clarity + Interfacer.remap!(sim.remapping.scratch_field_oc1, csf.SW_d, sim.remapping) # shortwave radiation + oc_SW_d = OC.interior(sim.remapping.scratch_field_oc1, :, :, 1) - CC.Remapping.interpolate!( - sim.remapping.scratch_arr2, - sim.remapping.remapper_cc, - csf.LW_d, - ) - remapped_LW_d = sim.remapping.scratch_arr2 + Interfacer.remap!(sim.remapping.scratch_field_oc2, csf.LW_d, sim.remapping) # longwave radiation + oc_LW_d = OC.interior(sim.remapping.scratch_field_oc2, :, :, 1) # Update only the part due to radiative fluxes. For the full update, the component due # to latent and sensible heat is missing and will be updated in update_turbulent_fluxes. @@ -425,33 +396,26 @@ function FieldExchanger.update_sim!(sim::OceananigansSimulation, csf) ϵ = Interfacer.get_field(sim, Val(:emissivity)) # scalar OC.interior(oc_flux_T, :, :, 1) .= (1.0 .- ice_concentration) .* ( - -(1 - α) .* remapped_SW_d .- + -(1 - α) .* oc_SW_d .- ϵ * ( - remapped_LW_d .- - σ .* (C_to_K .+ OC.interior(sim.ocean.model.tracers.T, :, :, 1)) .^ 4 + oc_LW_d .- + σ .* (C_to_K .+ OC.interior(sim.ocean.model.tracers.T, :, :, grid.Nz)) .^ 4 ) ) ./ (reference_density * heat_capacity) # Remap precipitation fields onto scratch arrays; rename for clarity - CC.Remapping.interpolate!( - sim.remapping.scratch_arr1, - sim.remapping.remapper_cc, - csf.P_liq, - ) - CC.Remapping.interpolate!( - sim.remapping.scratch_arr2, - sim.remapping.remapper_cc, - csf.P_snow, - ) - remapped_P_liq = sim.remapping.scratch_arr1 - remapped_P_snow = sim.remapping.scratch_arr2 + Interfacer.remap!(sim.remapping.scratch_field_oc1, csf.P_liq, sim.remapping) # liquid precipitation + oc_P_liq = OC.interior(sim.remapping.scratch_field_oc1, :, :, 1) + + Interfacer.remap!(sim.remapping.scratch_field_oc2, csf.P_snow, sim.remapping) # snow precipitation + oc_P_snow = OC.interior(sim.remapping.scratch_field_oc2, :, :, 1) # Virtual salt flux oc_flux_S = surface_flux(sim.ocean.model.tracers.S) OC.interior(oc_flux_S, :, :, 1) .= OC.interior(oc_flux_S, :, :, 1) .- - OC.interior(sim.ocean.model.tracers.S, :, :, 1) .* (1.0 .- ice_concentration) .* - (remapped_P_liq .+ remapped_P_snow) ./ fresh_water_density + OC.interior(sim.ocean.model.tracers.S, :, :, grid.Nz) .* + (1.0 .- ice_concentration) .* (oc_P_liq .+ oc_P_snow) ./ fresh_water_density return nothing end diff --git a/experiments/ClimaEarth/components/ocean/remapping.jl b/experiments/ClimaEarth/components/ocean/remapping.jl new file mode 100644 index 0000000000..7f7bb2b2f5 --- /dev/null +++ b/experiments/ClimaEarth/components/ocean/remapping.jl @@ -0,0 +1,214 @@ +### Helper functions to use ConservativeRemapping.jl with Oceananigans.jl +""" + compute_cell_matrix(grid::Union{OC.OrthogonalSphericalShellGrid, OC.LatitudeLongitudeGrid}) + +Get a vector of vector of coordinate tuples, of the format expected by the +ConservativeRemapping.jl regridder. +""" +function compute_cell_matrix( + grid::Union{OC.OrthogonalSphericalShellGrid, OC.LatitudeLongitudeGrid}, +) + Fx, Fy, _ = size(grid) + # TODO is it ok to hardcode Center? Regridder is specifically for Center, Center fields so I think it's ok + ℓx, ℓy = OC.Center(), OC.Center() + + if isnothing(ℓx) || isnothing(ℓy) + error( + "cell_matrix can only be computed for fields with non-nothing horizontal location.", + ) + end + + arch = grid.architecture + FT = eltype(grid) + + vertices_per_cell = 5 # convention: [sw, nw, ne, se, sw] + ArrayType = OC.Architectures.array_type(arch) + cell_matrix = ArrayType{Tuple{FT, FT}}(undef, vertices_per_cell, Fx * Fy) + + OC.Utils.launch!( + arch, + grid, + (Fx, Fy), + _compute_cell_matrix!, + cell_matrix, + Fx, + ℓx, + ℓy, + grid, + ) + + return cell_matrix +end + +flip(::OC.Face) = OC.Center() +flip(::OC.Center) = OC.Face() + +left_index(i, ::OC.Center) = i +left_index(i, ::OC.Face) = i - 1 +right_index(i, ::OC.Center) = i + 1 +right_index(i, ::OC.Face) = i + +@kernel function _compute_cell_matrix!(cell_matrix, Fx, ℓx, ℓy, grid) + i, j = @index(Global, NTuple) + + vx = flip(ℓx) + vy = flip(ℓy) + + isw = left_index(i, ℓx) + jsw = left_index(j, ℓy) + + inw = left_index(i, ℓx) + jnw = right_index(j, ℓy) + + ine = right_index(i, ℓx) + jne = right_index(j, ℓy) + + ise = right_index(i, ℓx) + jse = left_index(j, ℓy) + + xsw = OC.ξnode(isw, jsw, 1, grid, vx, vy, nothing) + ysw = OC.ηnode(isw, jsw, 1, grid, vx, vy, nothing) + + xnw = OC.ξnode(inw, jnw, 1, grid, vx, vy, nothing) + ynw = OC.ηnode(inw, jnw, 1, grid, vx, vy, nothing) + + xne = OC.ξnode(ine, jne, 1, grid, vx, vy, nothing) + yne = OC.ηnode(ine, jne, 1, grid, vx, vy, nothing) + + xse = OC.ξnode(ise, jse, 1, grid, vx, vy, nothing) + yse = OC.ηnode(ise, jse, 1, grid, vx, vy, nothing) + + linear_idx = i + (j - 1) * Fx + @inbounds begin + cell_matrix[1, linear_idx] = (xsw, ysw) + cell_matrix[2, linear_idx] = (xnw, ynw) + cell_matrix[3, linear_idx] = (xne, yne) + cell_matrix[4, linear_idx] = (xse, yse) + cell_matrix[5, linear_idx] = (xsw, ysw) + end +end + +### Extensions of Interfacer.jl functions for Oceananigans fields/grids +# Non-allocating ClimaCore -> Oceananigans remap +function Interfacer.remap!(dst_field::OC.Field, src_field::CC.Fields.Field, remapping) + CC.Remapping.get_value_per_element!( + remapping.value_per_element_cc, + src_field, + remapping.field_ones_cc, + ) + + # Get the index of the top level (surface); 1 for 2D fields, Nz for 3D fields + z = size(dst_field, 3) + dst = vec(OC.interior(dst_field, :, :, z)) + src = remapping.value_per_element_cc + + # Multiply by transpose of the matrix of intersection areas + LA.mul!(dst, transpose(remapping.remapper_oc_to_cc.intersections), src) + + # Normalize by the destination (Oceananigans) element areas + dst ./= remapping.remapper_oc_to_cc.src_areas # Oceananigans areas are source areas + return nothing +end +# Allocating ClimaCore -> Oceananigans remap +function Interfacer.remap( + src_field::CC.Fields.Field, + remapping, + dst_space::Union{OC.OrthogonalSphericalShellGrid, OC.LatitudeLongitudeGrid}, +) + dst_field = OC.Field{OC.Center, OC.Center, Nothing}(dst_space) + remap!(dst_field, src_field, remapping) + return dst_field +end + +# Non-allocating Oceananigans -> ClimaCore remap +function Interfacer.remap!(dst_field::CC.Fields.Field, src_field::OC.Field, remapping) + # Get the index of the top level (surface); 1 for 2D fields, Nz for 3D fields + z = size(src_field, 3) + # Store the remapped FV values in a vector of length equal to the number of elements in the target space + dst = remapping.value_per_element_cc + src = vec(OC.interior(src_field, :, :, z)) + LA.mul!(dst, remapping.remapper_oc_to_cc.intersections, src) + + # Normalize by the destination (ClimaCore) element areas + dst ./= remapping.remapper_oc_to_cc.dst_areas # ClimaCore areas are destination areas + + # Convert the vector of remapped values to a ClimaCore Field with one value per element + CC.Remapping.set_value_per_element!(dst_field, dst) + return nothing +end +# Handle the case of remapping the area fraction field, which is a ClimaCore Field +Interfacer.remap!(dst_field::CC.Fields.Field, src_field::CC.Fields.Field, remapping) = + Interfacer.remap!(dst_field, src_field) +Interfacer.remap!(dst_field::CC.Fields.Field, src_field::Number, remapping) = + Interfacer.remap!(dst_field, src_field) +# Allocating Oceananigans -> ClimaCore remap +function Interfacer.remap( + src_field::OC.Field, + remapping, + dst_space::CC.Spaces.AbstractSpace, +) + dst_field = CC.Fields.zeros(dst_space) + Interfacer.remap!(dst_field, src_field, remapping) + return dst_field +end + +# Handle the case of remapping a scalar number to a ClimaCore space +Interfacer.remap(num::Number, remapping, target_space::CC.Spaces.AbstractSpace) = + Interfacer.remap(num, target_space) + + +function Interfacer.remap( + operation::OC.AbstractOperations.AbstractOperation, + remapping, + target_space, +) + evaluated_field = OC.Field(operation) + OC.compute!(evaluated_field) + return Interfacer.remap(evaluated_field, remapping, target_space) +end + +function Interfacer.remap!( + target_field, + operation::OC.AbstractOperations.AbstractOperation, + remapping, +) + evaluated_field = OC.Field(operation) + OC.compute!(evaluated_field) + return Interfacer.remap!(target_field, evaluated_field, remapping) +end + +""" + construct_remappers(oc_grid, space_cc) + +Given an Oceananigans LatitudeLongitudeGrid and a ClimaCore space, construct the +remappers needed to remap between the two grids in both directions. + +Returns a remapper from the Oceananigans grid to the ClimaCore boundary space. +To regrid from Oceananigans to ClimaCore, use `LA.mul!(dest_vector, remapper_oc_to_cc, src_vector)`. +To regrid from ClimaCore to Oceananigans, use `LA.mul!(dest_vector, transpose(remapper_oc_to_cc), src_vector)`. +""" +function construct_remappers(grid_oc, space_cc) + # Get the vector of polygons for Oceananigans and ClimaCore spaces + vertices_oc = compute_cell_matrix(grid_oc.underlying_grid) + vertices_cc = CC.Remapping.get_element_vertices(space_cc) + + remapper_oc_to_cc = CR.Regridder(vertices_cc, vertices_oc; normalize = false) + + # Create a field of ones on the boundary space so we can compute element areas + field_ones_cc = CC.Fields.ones(space_cc) + + # Allocate a vector with length equal to the number of elements in the target space + # To be used as a temp field for remapping + value_per_element_cc = zeros(Float64, CC.Meshes.nelements(space_cc.grid.topology.mesh)) + + # Construct two 2D Oceananigans Center/Center fields to use as scratch space while remapping + scratch_field_oc1 = OC.Field{OC.Center, OC.Center, Nothing}(grid_oc) + scratch_field_oc2 = OC.Field{OC.Center, OC.Center, Nothing}(grid_oc) + return (; + remapper_oc_to_cc, + field_ones_cc, + value_per_element_cc, + scratch_field_oc1, + scratch_field_oc2, + ) +end