Loading [MathJax]/extensions/tex2jax.js
PISM, A Parallel Ice Sheet Model 2.2.1-cd005eec8 committed by Constantine Khrulev on 2025-03-07
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
IP_SSATaucTikhonovGNSolver.cc
Go to the documentation of this file.
1// Copyright (C) 2012--2024 David Maxwell and Constantine Khroulev
2//
3// This file is part of PISM.
4//
5// PISM is free software; you can redistribute it and/or modify it under the
6// terms of the GNU General Public License as published by the Free Software
7// Foundation; either version 3 of the License, or (at your option) any later
8// version.
9//
10// PISM is distributed in the hope that it will be useful, but WITHOUT ANY
11// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13// details.
14//
15// You should have received a copy of the GNU General Public License
16// along with PISM; if not, write to the Free Software
17// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18
19#include "pism/inverse/IP_SSATaucTikhonovGNSolver.hh"
20#include "pism/util/TerminationReason.hh"
21#include "pism/util/pism_options.hh"
22#include "pism/util/ConfigInterface.hh"
23#include "pism/util/Grid.hh"
24#include "pism/util/Context.hh"
25#include "pism/util/petscwrappers/Vec.hh"
26
27namespace pism {
28namespace inverse {
29
// Constructor of IP_SSATaucTikhonovGNSolver.
// NOTE(review): the first line of the signature and part of the parameter
// list (listing lines 30, 32-33) are missing from this extract.
//
// All work vectors are allocated on d0's grid. m_d0 and m_u_obs are
// initialized from the d0 and u_obs arguments. The initial penalty weight is
// m_alpha = 1/eta. m_target_misfit starts at 0.0, a value that solve() rejects
// with a RuntimeError.
31 DesignVec &d0, StateVec &u_obs, double eta,
34 : m_design_stencil_width(d0.stencil_width()),
35 m_state_stencil_width(u_obs.stencil_width()),
36 m_ssaforward(ssaforward),
37 m_x(d0.grid(), "x"),
38 m_tmp_D1Global(d0.grid(), "work vector"),
39 m_tmp_D2Global(d0.grid(), "work vector"),
40 m_tmp_D1Local(d0.grid(), "work vector"),
41 m_tmp_D2Local(d0.grid(), "work vector"),
42 m_tmp_S1Global(d0.grid(), "work vector"),
43 m_tmp_S2Global(d0.grid(), "work vector"),
44 m_tmp_S1Local(d0.grid(), "work vector"),
45 m_tmp_S2Local(d0.grid(), "work vector"),
46 m_GN_rhs(d0.grid(), "GN_rhs"),
47 m_d0(d0),
48 m_dGlobal(d0.grid(), "d (sans ghosts)"),
49 m_d_diff(d0.grid(), "d_diff"),
50 m_d_diff_lin(d0.grid(), "d_diff linearized"),
51 m_h(d0.grid(), "h"),
52 m_hGlobal(d0.grid(), "h (sans ghosts)"),
53 m_dalpha_rhs(d0.grid(), "dalpha rhs"),
54 m_dh_dalpha(d0.grid(), "dh_dalpha"),
55 m_dh_dalphaGlobal(d0.grid(), "dh_dalpha"),
56 m_grad_design(d0.grid(), "grad design"),
// NOTE(review): the next two members reuse the "grad design" label of
// m_grad_design; this looks like a copy-paste -- confirm whether distinct
// labels were intended.
57 m_grad_state(d0.grid(), "grad design"),
58 m_gradient(d0.grid(), "grad design"),
59 m_u_obs(u_obs),
60 m_u_diff(d0.grid(), "du"),
61 m_eta(eta),
62 m_designFunctional(designFunctional),
63 m_stateFunctional(stateFunctional),
64 m_target_misfit(0.0)
65{
66 PetscErrorCode ierr;
67 std::shared_ptr<const Grid> grid = m_d0.grid();
68 m_comm = grid->com;
69
70 m_d = std::make_shared<DesignVecGhosted>(grid, "d");
71
// Linear solver used for the Gauss-Newton systems, created on the grid's
// communicator. Its run-time options use the "inv_gn_" prefix.
72 ierr = KSPCreate(grid->com, m_ksp.rawptr());
73 PISM_CHK(ierr, "KSPCreate");
74
75 ierr = KSPSetOptionsPrefix(m_ksp, "inv_gn_");
76 PISM_CHK(ierr, "KSPSetOptionsPrefix");
77
78 double ksp_rtol = 1e-5; // Soft tolerance
79 ierr = KSPSetTolerances(m_ksp, ksp_rtol, PETSC_DEFAULT, PETSC_DEFAULT, PETSC_DEFAULT);
80 PISM_CHK(ierr, "KSPSetTolerances");
81
// Defaults: conjugate gradients (KSPCG) with no preconditioner (PCNONE).
// KSPSetFromOptions() is called only after these defaults have been set.
82 ierr = KSPSetType(m_ksp, KSPCG);
83 PISM_CHK(ierr, "KSPSetType");
84
85 PC pc;
86 ierr = KSPGetPC(m_ksp, &pc);
87 PISM_CHK(ierr, "KSPGetPC");
88
89 ierr = PCSetType(pc, PCNONE);
90 PISM_CHK(ierr, "PCSetType");
91
92 ierr = KSPSetFromOptions(m_ksp);
93 PISM_CHK(ierr, "KSPSetFromOptions");
94
// Shell matrix with local size xm*ym, global size Mx*My, and `this` as its
// context pointer.
95 int nLocalNodes = grid->xm()*grid->ym();
96 int nGlobalNodes = grid->Mx()*grid->My();
97 ierr = MatCreateShell(grid->com, nLocalNodes, nLocalNodes,
98 nGlobalNodes, nGlobalNodes, this, m_mat_GN.rawptr());
99 PISM_CHK(ierr, "MatCreateShell");
100
// Presumably registers the matrix-vector product callback for m_mat_GN --
// TODO confirm (listing lines 101-102 are not shown).
103 multCallback::connect(m_mat_GN);
104
105 m_alpha = 1./m_eta;
106 m_logalpha = log(m_alpha);
107
// Iteration cap: 1000 unless overridden with -inv_gn_iter_max.
108 m_iter_max = 1000;
109 m_iter_max = options::Integer("-inv_gn_iter_max", "", m_iter_max);
110
111 auto config = grid->ctx()->config();
112
// Adaptive-penalty flag and stopping tolerances come from the configuration.
113 m_tikhonov_adaptive = config->get_flag("inverse.tikhonov.adaptive");
114 m_tikhonov_atol = config->get_number("inverse.tikhonov.atol");
115 m_tikhonov_rtol = config->get_number("inverse.tikhonov.rtol");
116 m_tikhonov_ptol = config->get_number("inverse.tikhonov.ptol");
117
118 m_log = d0.grid()->ctx()->log();
119}
120
// NOTE(review): the body statement of init() (listing line 122) is missing
// from this extract, so what it does and returns cannot be confirmed here.
121std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::init() {
123}
124
128
// NOTE(review): the signature (listing line 130) and the declarations on
// listing lines 132 and 134 (presumably of Tx and GNx, which are used below)
// are missing from this extract. x and y are handed to DMGlobalToLocalBegin()
// and VecCopy() below, i.e. they are used as PETSc Vecs.
// NOTE(review): no return statement is visible in this body -- confirm that
// the @note below is still accurate.
129//! @note This function has to return PetscErrorCode (it is used as a callback).
131 StateVec &tmp_gS = m_tmp_S1Global;
133 DesignVec &tmp_gD = m_tmp_D1Global;
135
136 PetscErrorCode ierr;
137 // FIXME: Needless copies for now.
// Global-to-local transfer of the input x into the member vector m_x, using
// m_x's DM.
138 {
139 ierr = DMGlobalToLocalBegin(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
140 PISM_CHK(ierr, "DMGlobalToLocalBegin");
141
142 ierr = DMGlobalToLocalEnd(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
143 PISM_CHK(ierr, "DMGlobalToLocalEnd");
144 }
145
// Listing line 146 (not shown) presumably fills Tx before its ghosts are
// updated -- TODO confirm.
147 Tx.update_ghosts();
148
149 m_stateFunctional.interior_product(Tx,tmp_gS);
150
// Listing line 151 (not shown) presumably produces GNx from tmp_gS -- TODO
// confirm.
152
// Add the design-functional term, weighted by m_alpha, to GNx.
153 m_designFunctional.interior_product(m_x,tmp_gD);
154 GNx.add(m_alpha,tmp_gD);
155
// Copy the result into the output vector y.
156 ierr = VecCopy(GNx.vec(), y); PISM_CHK(ierr, "VecCopy");
157}
158
171
// Solves a linear system with m_ksp, using the shell matrix m_mat_GN as both
// the operator and the preconditioning matrix. The right-hand side is
// m_GN_rhs and the solution is written to m_hGlobal.
//
// Returns a KSPTerminationReason wrapping the KSP converged reason.
//
// NOTE(review): listing lines 175 and 187 are missing from this extract, so
// any setup of m_GN_rhs or post-processing of m_hGlobal cannot be confirmed.
172std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::solve_linearized() {
173 PetscErrorCode ierr;
174
176
177 ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
178 PISM_CHK(ierr, "KSPSetOperators");
179
180 ierr = KSPSolve(m_ksp,m_GN_rhs.vec(),m_hGlobal.vec());
181 PISM_CHK(ierr, "KSPSolve");
182
183 KSPConvergedReason ksp_reason;
184 ierr = KSPGetConvergedReason(m_ksp ,&ksp_reason);
185 PISM_CHK(ierr, "KSPGetConvergedReason");
186
188
189 return std::shared_ptr<TerminationReason>(new KSPTerminationReason(ksp_reason));
190}
191
193
197
// Tail of what is presumably evaluateGNFunctional(DesignVec &h, double *value).
// NOTE(review): the signature and listing lines 194-197 and 201 are missing
// from this extract, so how m_tmp_S1Local and m_tmp_D1Local are filled before
// use cannot be confirmed.
//
// Visible part: *value = m_alpha * (design functional at m_tmp_D1Local, after
// h has been added to it) + (state functional at m_tmp_S1Local).
198 double sValue;
199 m_stateFunctional.valueAt(m_tmp_S1Local,&sValue);
200
202 m_tmp_D1Local.add(1,h);
203
204 double dValue;
205 m_designFunctional.valueAt(m_tmp_D1Local,&dValue);
206
207 *value = m_alpha*dValue + sValue;
208}
209
210
// Logs the state of the current iteration (log level 2) and evaluates the
// stopping rules.
//
// Stopping rules visible here, in order:
//   - a discrepancy check comparing |sqrt(m_val_state)/m_target_misfit - 1|
//     against m_tikhonov_ptol;
//   - sumNorm < m_tikhonov_atol, returning reason 1, "TIKHONOV_ATOL";
//   - sumNorm < m_tikhonov_rtol * max(designNorm, stateNorm), returning
//     reason 1, "TIKHONOV_RTOL";
//   - m_iter > m_iter_max.
// Here designNorm is the 2-norm of m_grad_design scaled by m_alpha, stateNorm
// is the 2-norm of m_grad_state, and sumNorm is the 2-norm of m_gradient.
//
// NOTE(review): listing lines 230-231, 241, 244, 257 and 260 are missing from
// this extract. They include the statements executed by the discrepancy and
// iteration-cap branches and the fall-through return.
211std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::check_convergence() {
212
213 double designNorm, stateNorm, sumNorm;
214 double dWeight, sWeight;
215 dWeight = m_alpha;
216 sWeight = 1;
217
218 designNorm = m_grad_design.norm(NORM_2)[0];
219 stateNorm = m_grad_state.norm(NORM_2)[0];
220
221 designNorm *= dWeight;
222 stateNorm *= sWeight;
223
224 sumNorm = m_gradient.norm(NORM_2)[0];
225
226 m_log->message(2,
227 "----------------------------------------------------------\n");
228 m_log->message(2,
229 "IP_SSATaucTikhonovGNSolver Iteration %d: misfit %g; functional %g \n",
232 m_log->message(2, "alpha %g; log(alpha) %g\n", m_alpha, m_logalpha);
233 }
// NOTE(review): if both norms are zero this is 0/0; relsum is only used in
// the log message below.
234 double relsum = (sumNorm/std::max(designNorm,stateNorm));
235 m_log->message(2,
236 "design norm %g stateNorm %g sum %g; relative difference %g\n",
237 designNorm, stateNorm, sumNorm, relsum);
238
239 // If we have an adaptive tikhonov parameter, check if we have met
240 // this constraint first.
242 double disc_ratio = fabs((sqrt(m_val_state)/m_target_misfit) - 1.);
243 if (disc_ratio > m_tikhonov_ptol) {
245 }
246 }
247
248 if (sumNorm < m_tikhonov_atol) {
249 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(1,"TIKHONOV_ATOL"));
250 }
251
252 if (sumNorm < m_tikhonov_rtol*std::max(designNorm,stateNorm)) {
253 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(1,"TIKHONOV_RTOL"));
254 }
255
256 if (m_iter > m_iter_max) {
258 }
259
261}
262
264
// Body of what is presumably evaluate_objective_and_gradient().
// NOTE(review): the signature (listing line 263) and listing lines 270, 273,
// 276, 281 and 283-285 are missing from this extract. Those lines would
// include whatever fills m_d_diff and m_u_diff before the subtractions below,
// and any code that assembles the gradient members.
//
// Visible behavior: linearizes the forward problem at *m_d and returns that
// reason immediately if it failed. Otherwise stores m_val_design, m_val_state
// and m_value = m_val_design * m_alpha + m_val_state, and returns the reason
// from linearize_at().
265 std::shared_ptr<TerminationReason> reason = m_ssaforward.linearize_at(*m_d);
266 if (reason->failed()) {
267 return reason;
268 }
269
// Subtract the reference design m_d0 from m_d_diff.
271 m_d_diff.add(-1,m_d0);
272
// Subtract the observations m_u_obs from m_u_diff.
274 m_u_diff.add(-1,m_u_obs);
275
277
278 // The following computes the reduced gradient.
279 StateVec &adjointRHS = m_tmp_S1Global;
280 m_stateFunctional.gradientAt(m_u_diff,adjointRHS);
282
286
// Values of the two functionals at the current differences.
287 double valDesign, valState;
288 m_designFunctional.valueAt(m_d_diff,&valDesign);
289 m_stateFunctional.valueAt(m_u_diff,&valState);
290
291 m_val_design = valDesign;
292 m_val_state = valState;
293
294 m_value = valDesign * m_alpha + valState;
295
296 return reason;
297}
298
// Backtracking line search along m_h. Starts with step length 1 and halves it
// until the sufficient-decrease test
//   m_value <= old_value + 1e-3 * alpha * descent_derivative
// passes after a successful evaluate_objective_and_gradient(), or until the
// step length drops below 1e-20.
//
// Returns GenericTerminationReason(-1, "Not descent direction") when the dot
// product of m_gradient and m_tmp_D1Global is non-negative, and
// GenericTerminationReason(-1, "Too many step shrinks.") when the step length
// becomes too small.
//
// NOTE(review): listing lines 308, 319 and 339 are missing from this extract.
// They include the final return and, presumably, the code that fills
// m_tmp_D1Global and saves *m_d into m_tmp_D1Local.
// NOTE(review): diagnostics use printf rather than m_log. The cross-reference
// list for this file resolves printf to a function returning std::string; if
// that is the function called here, the text is formatted and discarded
// instead of printed -- confirm.
299std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::linesearch() {
300 PetscErrorCode ierr;
301
302 std::shared_ptr<TerminationReason> step_reason;
303
304 double old_value = m_val_design * m_alpha + m_val_state;
305
306 double descent_derivative;
307
309
310 ierr = VecDot(m_gradient.vec(), m_tmp_D1Global.vec(), &descent_derivative);
311 PISM_CHK(ierr, "VecDot");
312
313 if (descent_derivative >=0) {
314 printf("descent derivative: %g\n",descent_derivative);
315 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(-1, "Not descent direction"));
316 }
317
// Local step length; distinct from the penalty weight m_alpha.
318 double alpha = 1;
320 while(true) {
321 m_d->add(alpha,m_h); // Replace with line search.
322 step_reason = this->evaluate_objective_and_gradient();
323 if (step_reason->succeeded()) {
324 if (m_value <= old_value + 1e-3*alpha*descent_derivative) {
325 break;
326 }
327 }
328 else {
329 printf("forward solve failed in linsearch. Shrinking.\n");
330 }
331 alpha *=.5;
332 if (alpha<1e-20) {
333 printf("alpha= %g; derivative = %g\n",alpha,descent_derivative);
334 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(-1, "Too many step shrinks."));
335 }
// Presumably restores the design from before the rejected step -- TODO
// confirm against the missing listing line 319.
336 m_d->copy_from(m_tmp_D1Local);
337 }
338
340}
341
// Main driver of the iteration.
//
// Starts from m_d0, evaluates the objective and gradient once, then repeats:
// check_convergence(), an update of m_logalpha/m_alpha by dlogalpha,
// solve_linearized(), linesearch(), compute_dlogalpha(), ++m_iter.
// The loop exits by returning the reason from check_convergence() once it
// reports done(), or by returning a wrapped failure.
//
// A failing step is wrapped in GenericTerminationReason(-1, <stage name>),
// with the step's own reason attached as root cause.
//
// Throws RuntimeError if m_target_misfit is still 0 when this is called.
//
// NOTE(review): listing lines 370 and 390 are missing from this extract. The
// closing braces on listing lines 373 and 398 suggest each opened a
// conditional block (around the alpha update and around the
// compute_dlogalpha() call) -- confirm.
342std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::solve() {
343
344 if (m_target_misfit == 0) {
345 throw RuntimeError::formatted(PISM_ERROR_LOCATION, "Call set target misfit prior to calling"
346 " IP_SSATaucTikhonovGNSolver::solve.");
347 }
348
349 m_iter = 0;
350 m_d->copy_from(m_d0);
351
352 double dlogalpha = 0;
353
354 std::shared_ptr<TerminationReason> step_reason, reason;
355
356 step_reason = this->evaluate_objective_and_gradient();
357 if (step_reason->failed()) {
358 reason.reset(new GenericTerminationReason(-1,"Forward solve"));
359 reason->set_root_cause(step_reason);
360 return reason;
361 }
362
363 while(true) {
364
365 reason = this->check_convergence();
366 if (reason->done()) {
367 return reason;
368 }
369
// Apply the penalty update computed at the end of the previous pass;
// dlogalpha is 0 on the first pass.
371 m_logalpha += dlogalpha;
372 m_alpha = exp(m_logalpha);
373 }
374
375 step_reason = this->solve_linearized();
376 if (step_reason->failed()) {
377 reason.reset(new GenericTerminationReason(-1,"Gauss Newton solve"));
378 reason->set_root_cause(step_reason);
379 return reason;
380 }
381
382 step_reason = this->linesearch();
383 if (step_reason->failed()) {
// NOTE(review): `cause` is not used in the visible code.
384 std::shared_ptr<TerminationReason> cause = reason;
385 reason.reset(new GenericTerminationReason(-1,"Linesearch"));
386 reason->set_root_cause(step_reason);
387 return reason;
388 }
389
391 step_reason = this->compute_dlogalpha(&dlogalpha);
392 if (step_reason->failed()) {
// NOTE(review): `cause` is not used in the visible code.
393 std::shared_ptr<TerminationReason> cause = reason;
394 reason.reset(new GenericTerminationReason(-1,"Tikhonov penalty update"));
395 reason->set_root_cause(step_reason);
396 return reason;
397 }
398 }
399
400 m_iter++;
401 }
402
// NOTE(review): no break out of the while(true) loop is visible, so this
// return looks unreachable.
403 return reason;
404}
405
// Computes *dlogalpha, an update for log(m_alpha), with a Newton step on the
// squared linearized discrepancy:
//   (m_target_misfit^2 - disc_sq) / (ddisc_sq_dalpha * m_alpha).
// The step is limited to magnitude stepmax = 3, and negative steps are
// additionally halved.
//
// dh/dalpha comes from a second KSP solve with m_mat_GN, right-hand side
// m_dalpha_rhs, and solution in m_dh_dalphaGlobal. If that solve reports a
// negative converged reason, a KSPTerminationReason wrapping it is returned
// early.
//
// NOTE(review): listing lines 411, 413-414, 423, 434-436, 440, 464-465 and
// 497 are missing from this extract. They include the assignment of disc_sq
// (declared uninitialized below), the code filling m_dalpha_rhs, m_dh_dalpha
// and m_tmp_S2Local, and the final return statement.
406std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::compute_dlogalpha(double *dlogalpha) {
407
408 PetscErrorCode ierr;
409
410 // Compute the right-hand side for computing dh/dalpha.
412 m_d_diff_lin.add(1,m_h);
415
416 // Solve linear equation for dh/dalpha.
417 ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
418 PISM_CHK(ierr, "KSPSetOperators");
419
420 ierr = KSPSolve(m_ksp,m_dalpha_rhs.vec(),m_dh_dalphaGlobal.vec());
421 PISM_CHK(ierr, "KSPSolve");
422
424
425 KSPConvergedReason ksp_reason;
426 ierr = KSPGetConvergedReason(m_ksp,&ksp_reason);
427 PISM_CHK(ierr, "KSPGetConvergedReason");
428
429 if (ksp_reason<0) {
430 return std::shared_ptr<TerminationReason>(new KSPTerminationReason(ksp_reason));
431 }
432
433 // S1Local contains T(h) + F(x) - u_obs, i.e. the linearized misfit field.
437
438 // Compute linearized discrepancy.
439 double disc_sq;
441
442 // There are a number of equivalent ways to compute the derivative of the
443 // linearized discrepancy with respect to alpha, some of which are cheaper
444 // than others to compute. This equivalency relies, however, on having an
445 // exact solution in the Gauss-Newton step. Since we only solve this with
446 // a soft tolerance, we lose equivalency. We attempt a cheap computation,
447 // and then do a sanity check (namely that the derivative is positive).
448 // If this fails, we compute by a harder way that inherently yields a
449 // positive number.
450
// Cheap formula: -2 * m_alpha * <dh/dalpha, m_d_diff_lin> in the design
// functional's inner product.
451 double ddisc_sq_dalpha;
452 m_designFunctional.dot(m_dh_dalpha,m_d_diff_lin,&ddisc_sq_dalpha);
453 ddisc_sq_dalpha *= -2*m_alpha;
454
455 if (ddisc_sq_dalpha <= 0) {
456 // Try harder.
457
// NOTE(review): the message labels the value "dh/dalpha" but the number
// printed is ddisc_sq_dalpha.
458 m_log->message(3,
459 "Adaptive Tikhonov sanity check failed (dh/dalpha= %g <= 0)."
460 " Tighten inv_gn_ksp_rtol?\n",
461 ddisc_sq_dalpha);
462
463 // S2Local contains T(dh/dalpha)
466
// Fallback formula: 2 * m_alpha * (<S2, S2>_state + m_alpha * <dh, dh>_design),
// where S2 is m_tmp_S2Local and dh is m_dh_dalpha.
467 double ddisc_sq_dalpha_a;
468 m_stateFunctional.dot(m_tmp_S2Local,m_tmp_S2Local,&ddisc_sq_dalpha_a);
469 double ddisc_sq_dalpha_b;
470 m_designFunctional.dot(m_dh_dalpha,m_dh_dalpha,&ddisc_sq_dalpha_b);
471 ddisc_sq_dalpha = 2*m_alpha*(ddisc_sq_dalpha_a+m_alpha*ddisc_sq_dalpha_b);
472
473 m_log->message(3,
474 "Adaptive Tikhonov sanity check recovery attempt: dh/dalpha= %g. \n",
475 ddisc_sq_dalpha);
476
477 // This is yet another alternative formula.
478 // m_stateFunctional.dot(m_tmp_S1Local,m_tmp_S2Local,&ddisc_sq_dalpha);
479 // ddisc_sq_dalpha *= 2;
480 }
481
482 // Newton's method formula.
483 *dlogalpha = (m_target_misfit*m_target_misfit-disc_sq)/(ddisc_sq_dalpha*m_alpha);
484
485 // It's easy to take steps that are too big when we are far from the solution.
486 // So we limit the step size.
487 double stepmax = 3;
488 if (fabs(*dlogalpha)> stepmax) {
489 double sgn = *dlogalpha > 0 ? 1 : -1;
490 *dlogalpha = stepmax*sgn;
491 }
492
// Negative steps (decreasing alpha) are taken at half length.
493 if (*dlogalpha<0) {
494 *dlogalpha*=.5;
495 }
496
498}
499
500} // end of namespace inverse
501} // end of namespace pism
static std::shared_ptr< TerminationReason > max_iter()
static std::shared_ptr< TerminationReason > keep_iterating()
static std::shared_ptr< TerminationReason > success()
static RuntimeError formatted(const ErrorLocation &location, const char format[],...) __attribute__((format(printf
build a RuntimeError with a formatted message
T * rawptr()
Definition Wrapper.hh:39
void copy_from(const Array2D< T > &source)
Definition Array2D.hh:73
void add(double alpha, const Array2D< T > &x)
Definition Array2D.hh:65
petsc::Vec & vec() const
Definition Array.cc:310
void scale(double alpha)
Result: v <- v * alpha. Calls VecScale.
Definition Array.cc:224
std::shared_ptr< const Grid > grid() const
Definition Array.cc:131
void set(double c)
Result: v[j] <- c for all j.
Definition Array.cc:629
std::shared_ptr< petsc::DM > dm() const
Definition Array.cc:324
std::vector< double > norm(int n) const
Computes the norm of all the components of an Array.
Definition Array.cc:668
void update_ghosts()
Updates ghost points.
Definition Array.cc:615
Abstract base class for IPFunctionals arising from an inner product.
virtual std::shared_ptr< array::Vector > solution()
Returns the last solution of the SSA as computed by linearize_at.
virtual void apply_linearization(array::Scalar &dzeta, array::Vector &du)
Applies the linearization of the forward map (i.e. the reduced gradient described in the class-level...
virtual std::shared_ptr< TerminationReason > linearize_at(array::Scalar &zeta)
Sets the current value of the design variable $\zeta$ and solves the SSA to find the associated solution $u_{\rm SSA}(\zeta)$.
virtual void apply_linearization_transpose(array::Vector &du, array::Scalar &dzeta)
Applies the transpose of the linearization of the forward map (i.e. the transpose of the reduced grad...
Implements the forward problem of the map taking $\tau_c$ to the corresponding solution of the SSA.
virtual std::shared_ptr< TerminationReason > linesearch()
virtual void apply_GN(array::Scalar &h, array::Scalar &out)
virtual void evaluateGNFunctional(DesignVec &h, double *value)
virtual std::shared_ptr< TerminationReason > init()
virtual std::shared_ptr< TerminationReason > evaluate_objective_and_gradient()
IPInnerProductFunctional< StateVec > & m_stateFunctional
IPInnerProductFunctional< DesignVec > & m_designFunctional
virtual std::shared_ptr< TerminationReason > compute_dlogalpha(double *dalpha)
virtual std::shared_ptr< TerminationReason > solve()
virtual std::shared_ptr< TerminationReason > check_convergence()
virtual std::shared_ptr< TerminationReason > solve_linearized()
IP_SSATaucTikhonovGNSolver(IP_SSATaucForwardProblem &ssaforward, DesignVec &d0, StateVec &u_obs, double eta, IPInnerProductFunctional< DesignVec > &designFunctional, IPInnerProductFunctional< StateVec > &stateFunctional)
#define PISM_CHK(errcode, name)
#define PISM_ERROR_LOCATION
std::string printf(const char *format,...)