%% gmres_nr.m  for Multi-level Wavelet Schur (Method I)
%% [xo,res,iters] = gmres_nr(A,x,b, restart,maxit,tol)   (levs is passed as a global)
% ______________________
% Features of gmres_nr.m
% ----------------------
% (1) Finest level J=levs:  Flexible GMRES preconditioned by the equation
%                           T_j x = R_j x + b_j
% (2) The preconditioning equation T_j x = R_j x + b_j is solved by
%     Richardson iteration, reducing it to the next coarser level recursively.
% (3) GMRES preconditioned by Multilevel Wavelet Schur

function [xo,res,git] = nr_gmres(A,x,b, m,max_it,tol)
% NR_GMRES  Flexible GMRES(m) for A*x = b with a multi-level wavelet Schur
% (Richardson) preconditioner applied on the right.  The function lives in
% gmres_nr.m, so it is called (and documented by "help") as gmres_nr.
%
% Input ---- Mr Right Preconditioner AMy=b
%            A: Linear system to be solved (square, N x N)
%            x: Initial guess
%            b: Right-hand side
%            m: Restart number for GMRES
%       max_it: Max number of restarts
%          tol: Stopping tolerance on the relative residual (see res)
% Globals that must be set by the caller (they are not arguments):
%         levs: Number of levels of DWT requested (best = L-5) Band=16/Daub4
%        order: Wavelet order passed to fwts/iwts
%           mu: Upper limit on the per-level sweep counter in gmres_pr
%         prob: Problem number, only printed in the banner
% Output
%           xo: Final solution, mapped back from the DWT space by iwts
%          res: norm(b-A*x)/norm(r0) after each inner iteration, where r0 is
%               the initial residual in the DWT space
%          git: Number of inner iterations performed (0 on immediate exit)
% Subfunctions of this file:
%   gmres_pr: Multi-reso Richardson DWT solver for \bar{T}x=y (precond)
%   gmres_up: Used by gmres_pr to do restrictions (i=1 to i=levs+1)
%   givens, banded, ba_cut: small helpers
% fwts/iwts are external functions (not defined in this file).
  if nargin<1, help gmres_nr, return; end

%----------------- Setup ------------------------------ | | |
%%%%%%% Declare Global Variables to save on dummy Variables %
        global mu n_o n_e levs order prob reject
        global RM AM LM UM BM CM bM xM zM
[N,n] = size(A);
% The per-level matrices/vectors are globals named AM1, AM2, ..., RM2, ...
% The char matrices RM, AM, ... hold those names (one per row, blank-padded
% to a common width) so the variables can be created and used via eval.
  maxi = floor(log(N)/log(10));
space='               ';

    RM=[]; AM=[];LM=[];UM=[];BM=[];CM=[];
    bM=[];xM=[];zM=[];  %M-level matrices/vectors
for i=1:levs+2
    id = floor(log(i)/log(10));
  if i<levs+2
    s=sprintf('AM%d%s',i,space(1:maxi-id)); AM=[AM;s];
    s=sprintf('BM%d%s',i,space(1:maxi-id)); BM=[BM;s];
    s=sprintf('CM%d%s',i,space(1:maxi-id)); CM=[CM;s];
  end
  if i==1
    s=sprintf('  A%s',space(1:maxi-id)); RM=[RM;s];    % Row 1 of RM names the input matrix A itself
  else
    s=sprintf('RM%d%s',i,space(1:maxi-id)); RM=[RM;s];
  end
    s=sprintf('LM%d%s',i,space(1:maxi-id)); LM=[LM;s]; % Inc coarsest lev k+2
    s=sprintf('UM%d%s',i,space(1:maxi-id)); UM=[UM;s]; % Inc coarsest lev k+2
    s=sprintf('bM%d%s',i,space(1:maxi-id)); bM=[bM;s];
    s=sprintf('xM%d%s',i,space(1:maxi-id)); xM=[xM;s];
    s=sprintf('zM%d%s',i,space(1:maxi-id)); zM=[zM;s];
end % No more making arrays

% Declare every per-level variable global in this workspace.  RM row 1 (A)
% is skipped, so A stays a local variable.
for i=1:levs+2
 eval(['global ' LM(i,:) ' ' UM(i,:) ' ' bM(i,:) ' ' xM(i,:) ' ' zM(i,:)])
 if i<levs+2
 eval([' global ' BM(i,:) ' ' CM(i,:) ' '])
 if i>1
  eval(['global ' RM(i,:) ])
 end
 end
end
warning off MATLAB:flops:UnavailableFunction

% Matrix one-off transformations ++++++++++++++++++++++

  thr=0.02; % relative threshold passed to banded()
  n_e=[];n_o=[]; n_e(1)=N;  n_o(1)=N;
% Level sizes: level i works on an n_o(i) x n_o(i) matrix whose leading
% block has size n_e(i+1) = floor(n_o(i)/2).
for i=2:levs+2
  n_e(i) = floor( n_o(i-1)/2 );  %% n x n is the full matrix at LEV = i
  n_o(i) = n_o(i-1) - n_e(i);
end
order   % (no semicolon) echoes the wavelet order to the console
   x = fwts(x,order,1,0,1,1); b = fwts(b,order,1,0,1,1);  % New DWT space
for i=1:levs+1
    ni = n_e(i+1); nn = n_o(i);      % nxn the full matrix & ni=nn/2
 % Transform the level-i matrix in place (for i=1 this is A itself).
 eval([RM(i,:) ' = fwts(' RM(i,:) ',order,1,1,1,1);'])
 if i==1
   % Bandwidth estimates, computed once on the finest level and clamped.
   clear ba bb
   ba = banded(A(1:ni,1:ni),thr); ba=min(ba,19); ba=max(ba,3);
   bb = banded(A(1:ni,1+ni:nn),thr); bb=min(bb,19); bb=max(bb,5);
   fprintf('   Lev=%d mu=%d Thresh=%4.1e give bandw=ba/bb=%d/%d\n',...
    levs,mu,thr,ba,bb);
 end

 % Split the level-i matrix into [A_bar B_bar; C_bar T_{i+1}].  ba was
 % clamped to >= 3 above, so -ba < 0 and ba_cut returns each block as is.
 eval([AM(i,:) ' = ba_cut(' RM(i,:) '(1:ni,1:ni),   -ba);'])  % A_bar
 eval([BM(i,:) ' = ba_cut(' RM(i,:) '(1:ni,ni+1:nn),-ba);'])  % B_bar
 eval([CM(i,:) ' = ba_cut(' RM(i,:) '(1+ni:nn,1:ni),-ba);'])  % C_bar

%eval([AM(i,:) ' = ba_cut(' RM(i,:) '(1:ni,1:ni),   bb,1);'])  % A_bar
%eval([BM(i,:) ' = ba_cut(' RM(i,:) '(1:ni,ni+1:nn),bb,1);'])  % B_bar
%eval([CM(i,:) ' = ba_cut(' RM(i,:) '(1+ni:nn,1:ni),bb,1);'])  % C_bar

 eval([RM(i+1,:) '=' RM(i,:) '(1+ni:nn,1+ni:nn);'])  % Mr T_i+1 (full)

%eval(['figure;subplot(221); spyc(' AM(i,:) ')'])
%eval(['subplot(222); spyc(' BM(i,:) ')'])
%eval(['subplot(223); spyc(' CM(i,:) ')'])
%eval(['subplot(224); spyc(' RM(i+1,:) ')'])

 if i>1  %-----! Do not reset RM1 = \tilde{A} on Level 1
 % R_i = (assembled block approximation) - T_i.  The explicit ' ' keeps the
 % names separated even when the padding width is zero (N < 10).
 eval([RM(i,:) '= [' AM(i,:) ' ' BM(i,:) ';' CM(i,:) ' ' RM(i+1,:) ']-' RM(i,:) ';'])
 end

 eval(['[' LM(i,:) ' ' UM(i,:) ']=lu(' AM(i,:) ');'])  % A's factorised
 eval(['clear ' AM(i,:) ])  % A_i no longer needed cos of L U's
end % for i -----------------------------------------------------------------

% pause

% Coarsest level k+2 = levs+2: LU-factorise the whole remaining matrix.
 i = levs+2;
 eval(['[' LM(i,:) ' ' UM(i,:) ']=lu(' RM(i,:) ');'])  % A's factorised
 eval(['clear ' RM(i,:) ])  % A_i no longer needed cos of L U's

%----------------- Setup End ------------------------------ | | |

r = b-A*x;  Res_at_start1 = norm(r)

bnrm2 = norm(r);
if  ( bnrm2 == 0.0 ), bnrm2 = 1.0; end

V=[]; H=[]; cs=[]; sn=[]; res=[]; res(1)=norm(r)/bnrm2; e1=eye(n,1);
if (res(1) < tol)
  % Already converged: every output must still be defined (x is in DWT space).
  xo = iwts(x,order,1,0,1,1); git = 0;
  return
end

p=1;flag=0; % 0=insufficient max_it, 2=NaN occured, 1=stagnation, 3=sucessful
fprintf('Solving Ax=b using gmres_nr/pr/up [GMRES(%d) for %d] ... P%d\n', ...
         m,max_it, prob);
for iter = 1:max_it, %%%------------------------ % Begin iteration
  x0 = x; r = b-A*x;
  V(:,1) = r / norm( r ); z=[];

  s = norm( r )*e1;
  for i = 1:m,
    z(:,i) = gmres_pr(V(:,i));                   % Position for flexible gmres
    w = A*z(:,i);
    for k = 1:i,                        % modified Gram-Schmidt against V
      H(k,i)= w'*V(:,k);
      w = w - H(k,i)*V(:,k);
    end
    H(i+1,i) = norm( w );
    V(:,i+1) = w / H(i+1,i);

    for k = 1:i-1,                      % apply Givens rotation to H for QR
      temp     =  cs(k)*H(k,i) + sn(k)*H(k+1,i);
      H(k+1,i) = -sn(k)*H(k,i) + cs(k)*H(k+1,i);
      H(k,i)   = temp;
    end
    [cs(i),sn(i)] = givens( H(i,i), H(i+1,i) );
    temp   = cs(i)*s(i);
    s(i+1) = -sn(i)*s(i);
    s(i)   = temp;
    H(i,i) = cs(i)*H(i,i) + sn(i)*H(i+1,i);
    H(i+1,i) = 0.0;

    % Flexible GMRES: the update combines the stored preconditioned vectors z.
    y = H(1:i,1:i) \ s(1:i);
    x = x0 + z(:,1:i)*y; r = b-A*x; p = p+1;
    res(p) = norm(r) / bnrm2;              % exit if small
fprintf('\n\t\t Residual(%3d)=%8.2e',p-1,res(p));
    if res(p) < tol, flag=3; break; end
    if isnan(res(p)), flag=2;break; end
    if res(p) >= 1.999*res(p-1) & p>5; flag=1;break; end % Probable Failure
  end % i inner loop complete ---------------------------------------------

    if flag>0,break;end
end % iter end  -----------------------------------------------------------
git=p-1;res=res(2:end);
fprintf('\nFinal # of iter: %3d',git)
if git>0
  fprintf(' of GMRES(%d) to get Residual = %9.2e [flag %d]\n',m,res(git),flag)
else
  fprintf(' of GMRES(%d) [flag %d]\n',m,flag)   % max_it < 1: nothing was run
end
    xo = iwts(x,order,1,0,1,1);  % WAS xo=x;

function [ c, s ] = givens( a, b ) %++++++++++++++++++++++++++++++++++++
% GIVENS  Cosine/sine pair of the plane rotation used in the GMRES QR update.
%   [c,s] = givens(a,b) returns c, s with c^2 + s^2 = 1 and -s*a + c*b = 0,
%   so [c s; -s c]*[a; b] = [c*a + s*b; 0].  The ratio is always formed as
%   (smaller magnitude)/(larger magnitude), so a^2 + b^2 is never computed.
   if b == 0.0
      % Nothing to eliminate: identity rotation.
      c = 1.0;  s = 0.0;
      return
   end
   if abs(b) > abs(a)
      ratio = a / b;
      s = 1.0 / sqrt( ratio^2 + 1.0 );
      c = s * ratio;
   else
      ratio = b / a;
      c = 1.0 / sqrt( ratio^2 + 1.0 );
      s = c * ratio;
   end

function x_out = gmres_pr(b1); %+++++++++++++++++++++++++++++++++++++++++
% GMRES_PR  Apply the multi-level wavelet Schur preconditioner to b1.
%   x_out = gmres_pr(b1) runs level sweeps over levels 1..levs+2.  It uses
%   the per-level LU factors (LMi, UMi), blocks (BMi, CMi) and remainders
%   (RMi) that the main function stored in global variables.  The current
%   level index i, k (= levs) and the per-level counters check are shared
%   with gmres_up through globals.  Returns the level-1 iterate xM1.

%%%%%%% Declare Global Variables to save on dummy Variables %%%%%%%
        global mu n_o n_e levs order  i k iout check
        global RM AM LM UM BM CM bM xM zM
iout = 0; N =length(b1); k=levs;  %%% Communicate via i k to gmres_up
% iout = 0 disables the debug printing here and in gmres_up.
% Declare the per-level globals (names are the rows of LM, UM, ...).
for i=1:k+2
 eval(['global ' LM(i,:) ' ' UM(i,:) ' ' bM(i,:) ' ' xM(i,:) ' ' zM(i,:)])
 if i<k+2
 eval([' global ' BM(i,:) ' ' CM(i,:) ' '])
 if i>1
  eval(['global ' RM(i,:) ])
 end
 end
end

% Let
% Level-1 right-hand side is the input (bM1 was declared global above).
% N and first are not used further (first only by the commented-out check
% near the end).
  bM1 = b1;  first=norm(b1); if first<eps, first=1; end

%%% ________ %%% _______________________________________________________%
%%% Step (1) %%% Input ++++++++++++++++++++++++++++++++++++++++++++++++++

%disp(['------ N = ' int2str(N) ' = 2^Lev = 2^' int2str(Lev) ... \
%      '------'])

%%% Initialise xM on levels 1..levs+1 to zero vectors of length n_o(i),
%%% then seed the level-1 iterate with b1.

for i=1:levs+1
    ni = n_e(i+1); nn = n_o(i);    % nxn the full matrix & ni=nn/2
      eval([xM(i,:) ' = zeros(nn,1);']) % Initialize x_j's
end
    xM1 = b1; %% better than nothing

%%% ________ %%% _______________________________________________________%
%%% Step (4) %%% Main Loops
% check(i) counts the sweeps done on level i during its current visit.

      check=zeros(levs+2,1);

      i = 1;
      gmres_up;   % go up from level 1 to the coarsest level; leaves i = k+2

  while i>1  %%% Level Schur Method starts  * * * * * * * * * * * * * * *

   % Step down one level and refresh x_i from the coarser solution x_{i+1}.
   i = i - 1;
    ni = n_e(i+1); nn = n_o(i);
        eval(['oldx =' xM(i,:) '(1+ni:nn);'])

      if i>=levs+1                    % i=levs+1 from no DWT
        eval([xM(i,:) '(1+ni:nn)=' xM(i+1,:) ';'])
      else
%%%%% Save prev iterate on i just before it is over-written %%%%%
        eval([xM(i,:) '(1+ni:nn)=iwts(' xM(i+1,:) ',order,1,0,1,1);'])
      end
      % errx: change in the second block of x_i produced by this update.
      eval(['errx = norm( oldx -' xM(i,:) '(1+ni:nn) );'])
      % First block: x_i(1:ni) = z_i(1:ni) - U_i\(L_i\(B_i*x_i(ni+1:nn))).
      eval([xM(i,:) '(1:ni)=' zM(i,:) '(1:ni)-' UM(i,:) '\(' ... \
            LM(i,:) '\(' BM(i,:) '*' xM(i,:) '(1+ni:nn)) );'])

   % Debug print (note: the value printed after "mu=" is check(i)).
   if iout==1 & i<=3
%  fprintf('gmres_pr.m  i k+1 = %d %d [%d / %d]\n', i,levs+1, check(i),mu)
   fprintf('\ngmres_pr.m  i=%2d, mu=%2d, er=%f\n', i, check(i), errx)
   end

   % Count this sweep (both branches currently do the same thing).
   if i==1
      %%check(i) = mu;  % treat i=1 specially (not necessory)
      check(i) = check(i) + 1;
   else
      check(i) = check(i) + 1;
   end

   % Sweep level i again (back up via gmres_up) while fewer than mu sweeps
   % were done and either x_i still moved by more than 0.05 or fewer than 2
   % sweeps were done.  Otherwise reset the counter (except on level 1)
   % and keep descending.
   if check(i)<mu & (errx>0.05 | check(i)<2)  %  *   *   *   *   *   *
      gmres_up;
   else
      if i>1,
        check(i) = 0;
      end
   end % check  *   *   *   *   *   *   *   *   *   *   *   *   *   *   *

  end % while i * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

x_out = xM1;
   %% second = norm(x_out);
   %% if second/first > 1.e+10, x_out=b1; disp('reject preconditioner'); end
%-----------------------------------------------------------------------%
   % Debug output only; iout is set to 0 above, so this does not run as written.
   if iout==1
   Level_dimensions=[1:levs+2; n_o; n_e]
   var=who('RM*')', var=who('AM*')', var=who('BM*')', var=who('CM*')'
   end

function gmres_up
%>> gmres_up.m = tupward.m (treat i=1 differently as DWTed) >>>>>
% GMRES_UP  Upward (restriction) pass of the multi-level Schur solver.
%   Starts at the global level index i (set by gmres_pr) and, for each
%   level i..k+1 (k = levs), with ni = n_e(i+1), nn = n_o(i):
%     - on the first visit of a level i>1 (check(i)==0), transforms bM_i by fwts;
%     - z_i = R_i*x_i + b_i on levels i>1, and z_1 = b_1 on level 1;
%     - z_i(1:ni)    <- U_i\(L_i\z_i(1:ni));
%     - z_i(ni+1:nn) <- z_i(ni+1:nn) - C_i*z_i(1:ni);
%     - b_{i+1} = C_i*(U_i\(L_i\(B_i*x_i(ni+1:nn)))) + z_i(ni+1:nn).
%   It then solves the coarsest level directly, x_{k+2} = U\(L\b_{k+2}),
%   and returns with the global i = k+2.
%%%%%%% Declare Global Variables to save on dummy Variables %%%%%%%
        global mu n_o n_e levs order  i k iout check
        global RM AM LM UM BM CM bM xM zM
% Declare the per-level globals.  j is used so the shared global i is untouched.
for j=1:levs+2
 eval(['global ' LM(j,:) ' ' UM(j,:) ' ' bM(j,:) ' ' xM(j,:) ' ' zM(j,:)])
 if j<levs+2
 eval([' global ' BM(j,:) ' ' CM(j,:) ' '])
 if j>1
  eval(['global ' RM(j,:) ])
 end
 end
end

%%% ________ %%% ___________________________________
%%% Step (1) %%% upward to i=k+1
while i <= k+1 % *****************************************************
   if iout==1
   fprintf('Upward i k+1 = %d %d [%d / %d]\n', i,k+1, check(i),mu)
   end

   if check(i)==0 & i>1 %% cos i=1 is done in MAIN
      eval([bM(i,:) '=fwts(' bM(i,:) ',order,1,0,1,1);']);
   end
    ni = n_e(i+1); nn = n_o(i);

   if i>1
   eval([zM(i,:) '=' RM(i,:) '*' xM(i,:) '+' bM(i,:) ';']);
    else
   eval([zM(i,:) '=' bM(i,:) ';']);
   end

   eval([zM(i,:) '(1:ni)=' UM(i,:) '\(' LM(i,:) '\' zM(i,:) '(1:ni) );']);
   eval([zM(i,:) '(1+ni:nn)=' zM(i,:) '(1+ni:nn) -' ...
                  CM(i,:) '*' zM(i,:) '(1:ni);']);
   eval([bM(i+1,:) '=' UM(i,:) '\(' LM(i,:) '\(' ...
      BM(i,:) '*' xM(i,:) '(1+ni:nn) )) ;'])
   eval([bM(i+1,:) '=' CM(i,:) '*' bM(i+1,:) '+' zM(i,:) '(1+ni:nn) ;'])
   i = i + 1;
end % while i ********************************************************
%%% ________ %%% __________________________________
%%% Step (2) %%% Direct Solver on Level i=k+2 %%%
   eval([xM(i,:) '=' UM(i,:) '\(' LM(i,:) '\' bM(i,:) ');']);
   if iout==1
   % 'iter' is not defined in this workspace, so it is not printed.
   fprintf('DIRECT i k+1 = %d %d [%d / %d]\n', i,k+1, check(i),mu)
   end
%<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
% banded.m --- estimate the lower semi-bandwidth of a (nearly) banded matrix
function b=banded(A,tol);
% BANDED  Estimate the lower semi-bandwidth of A.
%   b = banded(A,tol) scans rows 2..floor(n/2), where n = max(size(A)).
%   In each row it finds the first column col <= row with
%   |A(row,col)| >= tol*max(|A(:)|).  b is the largest row-col found, or 0
%   when no entry qualifies or the scanned range is empty (n < 4, incl. 1x1).
%   tol defaults to 1.0e-6.
   if nargin < 2
      tol=1.0e-6;
   end
   thresh = tol*max(max(abs(A)));   % threshold relative to the largest entry
   n = max(size(A));
   b = 0;
   for row = 2:floor(n/2)           % only the lower triangle of the first half of the rows
      col = find(abs(A(row,1:row)) >= thresh, 1);
      if ~isempty(col)
         b = max(b, row-col);
      end
   end

% ba_cut.m --- Cut out a semi-band ba out of A plus border ba if id=1
% -------- ||| If ba is negative, A is returned unchanged (no banding).
function B=ba_cut(A,ba,id);
% BA_CUT  Sparse band (+ optional border) approximation of A (n x m).
%   B = ba_cut(A,ba,id) keeps the entries of A with max(row-ba,1) <= col <=
%   min(row+ba,m).  With id==1 (default) it also keeps the trailing border
%   columns mb+1:m in every other row and the whole of the "bottom" rows
%   (row >= nb+1).  mb and nb are m-ba and n-ba, clipped at 0 and adjusted
%   for rectangular A.  The bottom rows always keep columns row-ba:m.
%   All other entries of B are zero; B is sparse.
%   ba defaults to 1 and id to 1; a negative ba returns A itself.
   % Defaults must be applied before ba/id are first used.
   if nargin < 2, ba=1; end
   if nargin < 3, id=1; end  % border off = 0
[n m] = size(A); B=speye(n,m);

if ba<0 %% negative ba: no banding, return A unchanged
B=A;
return
end

mb=m-ba; nb = n-ba; if mb<0; mb=0; end,  if nb<0; nb=0; end
if n>m
 nb=nb-n+m;
elseif m>n
 mb=mb-m+n;
end
for row=1:n
 %_________________________diag bands
 ok1 = row-ba; ok2=row+ba; %% left right - usual  col = row-ba:row+ba;
    if ok1<1, ok1=1; end
    if ok2>m, ok2=m; end
    col = ok1:ok2;
 %_________________________bott rows
 if row>=nb+1
    col = row-ba:m;
    if row-ba<1, col=1:m; end
    if id==1
    colr = 1:m;  B(row,colr)=A(row,colr);
    end
 %_________________________righ bands
 else
    if id==1
    colr = mb+1:m;  B(row,colr)=A(row,colr);
    end
 end
    if ~isempty(row) & ~isempty(col)
    B(row,col)=A(row,col);
    end
end
