Loading [MathJax]/extensions/tex2jax.js
PISM, A Parallel Ice Sheet Model 2.2.2-d6b3a29ca committed by Constantine Khrulev on 2025-03-28
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
SSAFD_SNES.cc
Go to the documentation of this file.
1/* Copyright (C) 2024 PISM Authors
2 *
3 * This file is part of PISM.
4 *
5 * PISM is free software; you can redistribute it and/or modify it under the
6 * terms of the GNU General Public License as published by the Free Software
7 * Foundation; either version 3 of the License, or (at your option) any later
8 * version.
9 *
10 * PISM is distributed in the hope that it will be useful, but WITHOUT ANY
11 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13 * details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with PISM; if not, write to the Free Software
17 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20#include "pism/stressbalance/ssa/SSAFD_SNES.hh"
21#include "pism/stressbalance/StressBalance.hh" // Inputs
22#include "pism/util/petscwrappers/Vec.hh"
23#include <algorithm> // std::max()
24
25namespace pism {
26namespace stressbalance {
27
28PetscErrorCode SSAFDSNESConvergenceTest(SNES snes, PetscInt it, PetscReal xnorm, PetscReal gnorm,
29 PetscReal f, SNESConvergedReason *reason, void *ctx) {
30 PetscErrorCode ierr;
31
32 SSAFD_SNES *solver = reinterpret_cast<SSAFD_SNES *>(ctx);
33 double tolerance = solver->tolerance();
34
35 ierr = SNESConvergedDefault(snes, it, xnorm, gnorm, f, reason, ctx); CHKERRQ(ierr);
36 if (*reason >= 0 and tolerance > 0) {
37 // converged or iterating
38 Vec residual;
39 ierr = SNESGetFunction(snes, &residual, NULL, NULL);
40 CHKERRQ(ierr);
41
42 PetscReal norm;
43 ierr = VecNorm(residual, NORM_INFINITY, &norm);
44 CHKERRQ(ierr);
45
46 if (norm <= tolerance) {
47 *reason = SNES_CONVERGED_FNORM_ABS;
48 }
49 }
50
51 return 0;
52}
53
54double SSAFD_SNES::tolerance() const {
55 return m_config->get_number("stress_balance.ssa.fd.absolute_tolerance");
56}
57
//! Create the SNES-based finite-difference SSA solver.
//!
//! Obtains the DMDA used by the solver from the grid, creates the SNES object, registers
//! function_callback() (residual) and jacobian_callback() (Jacobian) as DMDA local
//! callbacks with m_callback_data as the application context, sets the options prefix
//! "ssafd_", installs SSAFDSNESConvergenceTest() and sets solver tolerances. Every PETSc
//! error code is checked with PISM_CHK.
//!
//! @param[in] grid          computational grid
//! @param[in] regional_mode forwarded to SSAFDBase
58SSAFD_SNES::SSAFD_SNES(std::shared_ptr<const Grid> grid, bool regional_mode)
59 : SSAFDBase(grid, regional_mode), m_residual(grid, "_ssa_residual") {
60
61 PetscErrorCode ierr;
62
  // NOTE(review): the first argument of get_dm() is presumably the number of degrees of
  // freedom per grid point (two velocity components); confirm against Grid::get_dm().
63 int stencil_width=2;
64 m_DA = m_grid->get_dm(2, stencil_width);
65
66 // ierr = DMCreateGlobalVector(*m_DA, m_X.rawptr());
67 // PISM_CHK(ierr, "DMCreateGlobalVector");
68
69 ierr = SNESCreate(m_grid->com, m_snes.rawptr());
70 PISM_CHK(ierr, "SNESCreate");
71
72 // Set the SNES callbacks: function_callback() and jacobian_callback() forward to compute_residual() and compute_jacobian()
  // Inputs are only available while solve() runs; solve() sets and later resets this pointer.
75 m_callback_data.inputs = nullptr;
76
77 ierr = DMDASNESSetFunctionLocal(*m_DA, INSERT_VALUES,
78#if PETSC_VERSION_LT(3,21,0)
79 (DMDASNESFunction)SSAFD_SNES::function_callback,
80#else
81 (DMDASNESFunctionFn*)SSAFD_SNES::function_callback,
82#endif
84 PISM_CHK(ierr, "DMDASNESSetFunctionLocal");
85
86 ierr = DMDASNESSetJacobianLocal(*m_DA,
87#if PETSC_VERSION_LT(3,21,0)
88 (DMDASNESJacobian)SSAFD_SNES::jacobian_callback,
89#else
90 (DMDASNESJacobianFn*)SSAFD_SNES::jacobian_callback,
91#endif
93 PISM_CHK(ierr, "DMDASNESSetJacobianLocal");
94
95 // ierr = DMSetMatType(*m_DA, "baij");
96 // PISM_CHK(ierr, "DMSetMatType");
97
  // The callbacks receive m_callback_data as their CallbackData argument.
98 ierr = DMSetApplicationContext(*m_DA, &m_callback_data);
99 PISM_CHK(ierr, "DMSetApplicationContext");
100
101 ierr = SNESSetOptionsPrefix(m_snes, "ssafd_");
102 PISM_CHK(ierr, "SNESSetOptionsPrefix");
103
104 ierr = SNESSetDM(m_snes, *m_DA);
105 PISM_CHK(ierr, "SNESSetDM");
106
  // `this` is the context; SSAFDSNESConvergenceTest() uses it to call tolerance().
107 ierr = SNESSetConvergenceTest(m_snes, SSAFDSNESConvergenceTest, this, NULL);
108 PISM_CHK(ierr, "SNESSetConvergenceTest");
109
  // Zero absolute, relative and step tolerances and at most 500 iterations. With these
  // zeros, convergence is presumably decided by the residual-norm check in
  // SSAFDSNESConvergenceTest(); see SNESSetTolerances() for the meaning of -1.
110 ierr = SNESSetTolerances(m_snes, 0.0, 0.0, 0.0, 500, -1);
111 PISM_CHK(ierr, "SNESSetTolerances");
112
  // Called last so that command-line options with the prefix set above can override the
  // settings made here.
113 ierr = SNESSetFromOptions(m_snes);
114 PISM_CHK(ierr, "SNESSetFromOptions");
115}
116
//! Solve the SSA using the SNES solver.
//!
//! The steps are:
//! - make `inputs` available to the residual and Jacobian callbacks through m_callback_data;
//! - call initialize_iterations();
//! - run SNESSolve() on m_velocity_global;
//! - report an error if SNES diverged (negative converged reason);
//! - log the number of nonlinear iterations, the average number of linear iterations per
//!   nonlinear iteration, and the SNES converged reason.
//!
//! NOTE(review): m_callback_data.inputs is reset to nullptr only on the normal exit path.
//! If an error is raised in between (presumably by PISM_CHK, or on divergence), it keeps
//! pointing at `inputs`. Consider resetting it with an RAII guard.
117void SSAFD_SNES::solve(const Inputs &inputs) {
118 m_callback_data.inputs = &inputs;
119 initialize_iterations(inputs);
120 {
121 PetscErrorCode ierr;
122
123 // Solve:
124 // ierr = SNESSolve(m_snes, NULL, m_X);
125 ierr = SNESSolve(m_snes, NULL, m_velocity_global.vec());
126 PISM_CHK(ierr, "SNESSolve");
127
128 // See if it worked.
129 SNESConvergedReason reason;
130 ierr = SNESGetConvergedReason(m_snes, &reason);
131 PISM_CHK(ierr, "SNESGetConvergedReason");
  // a negative converged reason indicates divergence
132 if (reason < 0) {
134 "SSAFD_SNES solve failed to converge (SNES reason %s)",
135 SNESConvergedReasons[reason]);
136 }
137
138 PetscInt snes_iterations = 0;
139 ierr = SNESGetIterationNumber(m_snes, &snes_iterations);
140 PISM_CHK(ierr, "SNESGetIterationNumber");
141
142 PetscInt ksp_iterations = 0;
143 ierr = SNESGetLinearSolveIterations(m_snes, &ksp_iterations);
144 PISM_CHK(ierr, "SNESGetLinearSolveIterations");
145
  // Log SNES iterations and the average number of KSP iterations per SNES iteration;
  // std::max(..., 1) avoids division by zero when no SNES iterations were taken.
146 m_log->message(1, "SSA: %d*%d its, %s\n", (int)snes_iterations,
147 (int)(ksp_iterations / std::max((int)snes_iterations, 1)),
148 SNESConvergedReasons[reason]);
149 }
150 m_callback_data.inputs = nullptr;
151
152 // copy from m_velocity_global to provide m_velocity with ghosts:
154
156}
157
158
//! Residual callback registered with DMDASNESSetFunctionLocal().
//!
//! Forwards to compute_residual() for the solver and inputs stored in `data` (the
//! application context set to m_callback_data in the constructor). PETSc calls this from
//! C code, so C++ exceptions must not propagate out of it. They are caught and converted
//! into a PETSc error on the communicator of data->da.
159PetscErrorCode SSAFD_SNES::function_callback(DMDALocalInfo * /*unused*/,
160 Vector2d const *const *velocity, Vector2d **result,
161 CallbackData *data) {
162 try {
163 data->solver->compute_residual(*data->inputs, velocity, result);
164 } catch (...) {
  // MPI_COMM_SELF is only a placeholder until the DM's communicator is retrieved.
165 MPI_Comm com = MPI_COMM_SELF;
166 PetscErrorCode ierr = PetscObjectGetComm((PetscObject)data->da, &com);
167 CHKERRQ(ierr);
169 SETERRQ(com, 1, "A PISM callback failed");
170 }
171 return 0;
172}
173
//! Compute the Jacobian matrix J of the residual at the given velocity.
//!
//! Called by jacobian_callback() with the inputs stored in m_callback_data during solve().
174void SSAFD_SNES::compute_jacobian(const Inputs &inputs, Vector2d const *const *const velocity,
175 Mat J) {
178}
179
//! Jacobian callback registered with DMDASNESSetJacobianLocal().
//!
//! Forwards to compute_jacobian(), which fills J; the first matrix argument (A) is not
//! used. NOTE(review): this relies on J being the matrix PETSc expects to be assembled.
//! Confirm this still holds if a separate preconditioning matrix is ever configured.
//!
//! Like function_callback(), this is called from C code. C++ exceptions are therefore
//! caught and converted into a PETSc error on the communicator of data->da.
180PetscErrorCode SSAFD_SNES::jacobian_callback(DMDALocalInfo * /*unused*/,
181 Vector2d const *const *const velocity, Mat /* A */,
182 Mat J, CallbackData *data) {
183 try {
184 data->solver->compute_jacobian(*data->inputs, velocity, J);
185 } catch (...) {
  // MPI_COMM_SELF is only a placeholder until the DM's communicator is retrieved.
186 MPI_Comm com = MPI_COMM_SELF;
187 PetscErrorCode ierr = PetscObjectGetComm((PetscObject)data->da, &com);
188 CHKERRQ(ierr);
190 SETERRQ(com, 1, "A PISM callback failed");
191 }
192 return 0;
193}
194
196 return m_residual;
197}
198
199//! @brief Computes the magnitude of the SSAFD SNES solver's residual
200//! (diagnostic "ssa_residual_mag", units of "Pa" per its metadata).
201class SSAFD_residual_mag : public Diag<SSAFD_SNES> {
202public:
204
205 // set metadata:
206 m_vars = { { m_sys, "ssa_residual_mag" } };
207
208 m_vars[0].long_name("magnitude of the SSAFD solver's residual").units("Pa");
209 }
210
211protected:
  // Allocates a scalar field, attaches this diagnostic's metadata and fills it with the
  // magnitude of model->residual() computed by compute_magnitude().
212 virtual std::shared_ptr<array::Array> compute_impl() const {
213 auto result = allocate<array::Scalar>("ssa_residual_mag");
214 result->metadata(0) = m_vars[0];
215
216 compute_magnitude(model->residual(), *result);
217
218 return result;
219 }
220};
221
224
225 result["ssa_residual"] = Diagnostic::wrap(m_residual);
226 result["ssa_residual_mag"] = Diagnostic::Ptr(new SSAFD_residual_mag(this));
227
228 return result;
229}
230
231
232} // namespace stressbalance
233} // namespace pism
const Config::ConstPtr m_config
configuration database used by this component
Definition Component.hh:158
const Logger::ConstPtr m_log
logger (for easy access)
Definition Component.hh:162
const std::shared_ptr< const Grid > m_grid
grid used by this component
Definition Component.hh:156
const SSAFD_SNES * model
A template derived from Diagnostic, adding a "Model".
static Ptr wrap(const T &input)
const units::System::Ptr m_sys
the unit system
std::vector< SpatialVariableMetadata > m_vars
metadata corresponding to NetCDF variables
std::shared_ptr< Diagnostic > Ptr
Definition Diagnostic.hh:65
static RuntimeError formatted(const ErrorLocation &location, const char format[],...) __attribute__((format(printf
build a RuntimeError with a formatted message
This class represents a 2D vector field (such as ice velocity) at a certain grid point.
Definition Vector2d.hh:29
T * rawptr()
Definition Wrapper.hh:39
void copy_from(const Array2D< T > &source)
Definition Array2D.hh:73
petsc::Vec & vec() const
Definition Array.cc:310
const array::Scalar * basal_yield_stress
const array::Scalar * bc_mask
void fd_operator(const Geometry &geometry, const array::Scalar *bc_mask, double bc_scaling, const array::Scalar &basal_yield_stress, IceBasalResistancePlasticLaw *basal_sliding_law, const pism::Vector2d *const *velocity, const array::Staggered1 &nuH, const array::CellType1 &cell_type, Mat *A, Vector2d **Ax) const
Assemble the left-hand side matrix for the KSP-based, Picard iteration, and finite difference impleme...
Definition SSAFDBase.cc:549
void initialize_iterations(const Inputs &inputs)
array::Staggered1 m_nuH
viscosity times thickness
Definition SSAFDBase.hh:119
const double m_bc_scaling
scaling used for diagonal matrix elements at Dirichlet BC locations
Definition SSAFDBase.hh:130
DiagnosticList diagnostics_impl() const
void compute_residual(const Inputs &inputs, const array::Vector2 &velocity, array::Vector &result)
const array::Vector & residual() const
DiagnosticList diagnostics_impl() const
void solve(const Inputs &inputs)
SSAFD_SNES(std::shared_ptr< const Grid > grid, bool regional_mode)
Definition SSAFD_SNES.cc:58
static PetscErrorCode jacobian_callback(DMDALocalInfo *info, Vector2d const *const *velocity, Mat A, Mat J, CallbackData *data)
static PetscErrorCode function_callback(DMDALocalInfo *info, Vector2d const *const *velocity, Vector2d **result, CallbackData *)
void compute_jacobian(const Inputs &inputs, Vector2d const *const *velocity, Mat J)
array::Vector m_residual
residual (diagnostic)
Definition SSAFD_SNES.hh:45
virtual std::shared_ptr< array::Array > compute_impl() const
Computes the magnitude of the SSAFD solver's residual (diagnostically).
array::Vector m_velocity_global
Definition SSA.hh:133
const array::Vector1 & velocity() const
Get the thickness-advective 2D velocity.
IceBasalResistancePlasticLaw * m_basal_sliding_law
#define PISM_CHK(errcode, name)
#define PISM_ERROR_LOCATION
PetscErrorCode SSAFDSNESConvergenceTest(SNES snes, PetscInt it, PetscReal xnorm, PetscReal gnorm, PetscReal f, SNESConvergedReason *reason, void *ctx)
Definition SSAFD_SNES.cc:28
std::map< std::string, Diagnostic::Ptr > DiagnosticList
void handle_fatal_errors(MPI_Comm com)