diff +time/CdiffNonlin.m @ 4:8e14b5a577a6

Attempt att optimization of CdiffNonlin.
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 18 Sep 2015 15:12:44 +0200
parents 5ae4f23d9130
children b18d3d201a71
line wrap: on
line diff
--- a/+time/CdiffNonlin.m	Fri Sep 18 13:30:19 2015 +0200
+++ b/+time/CdiffNonlin.m	Fri Sep 18 15:12:44 2015 +0200
@@ -13,16 +13,9 @@
 
     methods
         function obj = CdiffNonlin(D, E, S, k, t0, v, v_prev)
-            default_arg('S',0);
-            default_arg('E',0);
-
-            if isnumeric(S) && S == 0
-                S = @(v)0;
-            end
-
-            if isnumeric(E) && E == 0
-                E = @(v)0;
-            end
+            m = size(D(v),1);
+            default_arg('E',@(v)sparse(m,m));
+            default_arg('S',@(v,t)sparse(m,1));
 
 
             % m = size(D,1);
@@ -49,7 +42,35 @@
         end
 
         function obj = step(obj)
-            [obj.v, obj.v_prev] = time.cdiff.cdiff(obj.v, obj.v_prev, obj.k, obj.D(obj.v), obj.E(obj.v), obj.S(obj.v));
+            D = obj.D(obj.v);
+            E = obj.E(obj.v);
+            S = obj.S(obj.v);
+
+            m = size(D,1);
+            I = speye(m);
+
+            %% Calculate for which indices we need to solve system of equations
+            [rows,cols] = find(E);
+            j = union(rows,cols);
+            i = setdiff(1:m,j);
+
+
+            %% Calculate matrices need for the timestep
+            % Before optimization:  A =  1/k^2 * I - 1/(2*k)*E;
+            k = obj.k;
+            Aj = 1/k^2 * I(j,j) - 1/(2*k)*E(j,j);
+            B =  2/k^2 * I + D;
+            C = -1/k^2 * I - 1/(2*k)*E;
+
+            %% Take the timestep
+            v = obj.v;
+
+            % Before optimization:  obj.v = A\(B*v + C*v_prev + S);
+            obj.v(i) = k^2*(B(i,i)*v(i)   -1/k^2*obj.v_prev(i) + S(i));
+            obj.v(j) =  Aj\(B(j,j)*v(j) + C(j,j)*obj.v_prev(j) + S(j));
+            obj.v_prev = v;
+
+            %% Update state of the timestepper
             obj.t = obj.t + obj.k;
             obj.n = obj.n + 1;
         end