Blender V2.61 - r43446

sgstrs.c

Go to the documentation of this file.
00001 
00005 /*
00006  * -- SuperLU routine (version 3.0) --
00007  * Univ. of California Berkeley, Xerox Palo Alto Research Center,
00008  * and Lawrence Berkeley National Lab.
00009  * October 15, 2003
00010  *
00011  */
00012 /*
00013   Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
00014  
00015   THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
00016   EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
00017  
00018   Permission is hereby granted to use or copy this program for any
00019   purpose, provided the above notices are retained on all copies.
00020   Permission to modify the code and to distribute modified code is
00021   granted, provided the above notices are retained, and a notice that
00022   the code was modified is included with the above copyright notice.
00023 */
00024 
00025 #include "ssp_defs.h"
00026 
00027 
/*
 * Forward declarations of the helpers called from sgstrs().
 * sprint_soln() is defined at the end of this file; the other three are
 * presumably provided by another SuperLU source file.
 */
void susolve(int ldm, int ncol, float *mat, float *rhs);
void slsolve(int ldm, int ncol, float *mat, float *rhs);
void smatvec(int ldm, int nrow, int ncol, float *mat, float *vec, float *mxvec);
void sprint_soln(int n, float *soln);
00035 
00036 void
00037 sgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
00038         int *perm_c, int *perm_r, SuperMatrix *B,
00039         SuperLUStat_t *stat, int *info)
00040 {
00041 /*
00042  * Purpose
00043  * =======
00044  *
00045  * SGSTRS solves a system of linear equations A*X=B or A'*X=B
00046  * with A sparse and B dense, using the LU factorization computed by
00047  * SGSTRF.
00048  *
00049  * See supermatrix.h for the definition of 'SuperMatrix' structure.
00050  *
00051  * Arguments
00052  * =========
00053  *
00054  * trans   (input) trans_t
00055  *          Specifies the form of the system of equations:
00056  *          = NOTRANS: A * X = B  (No transpose)
00057  *          = TRANS:   A'* X = B  (Transpose)
00058  *          = CONJ:    A**H * X = B  (Conjugate transpose)
00059  *
00060  * L       (input) SuperMatrix*
00061  *         The factor L from the factorization Pr*A*Pc=L*U as computed by
00062  *         sgstrf(). Use compressed row subscripts storage for supernodes,
00063  *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_S, Mtype = SLU_TRLU.
00064  *
00065  * U       (input) SuperMatrix*
00066  *         The factor U from the factorization Pr*A*Pc=L*U as computed by
00067  *         sgstrf(). Use column-wise storage scheme, i.e., U has types:
00068  *         Stype = SLU_NC, Dtype = SLU_S, Mtype = SLU_TRU.
00069  *
00070  * perm_c  (input) int*, dimension (L->ncol)
00071  *     Column permutation vector, which defines the 
00072  *         permutation matrix Pc; perm_c[i] = j means column i of A is 
00073  *         in position j in A*Pc.
00074  *
00075  * perm_r  (input) int*, dimension (L->nrow)
00076  *         Row permutation vector, which defines the permutation matrix Pr; 
00077  *         perm_r[i] = j means row i of A is in position j in Pr*A.
00078  *
00079  * B       (input/output) SuperMatrix*
00080  *         B has types: Stype = SLU_DN, Dtype = SLU_S, Mtype = SLU_GE.
00081  *         On entry, the right hand side matrix.
00082  *         On exit, the solution matrix if info = 0;
00083  *
00084  * stat     (output) SuperLUStat_t*
00085  *          Record the statistics on runtime and floating-point operation count.
00086  *          See util.h for the definition of 'SuperLUStat_t'.
00087  *
00088  * info    (output) int*
00089  *     = 0: successful exit
00090  *     < 0: if info = -i, the i-th argument had an illegal value
00091  *
00092  */
00093 #ifdef _CRAY
00094     _fcd ftcs1, ftcs2, ftcs3, ftcs4;
00095 #endif
00096 #ifdef USE_VENDOR_BLAS
00097     float   alpha = 1.0, beta = 1.0;
00098     float   *work_col;
00099 #endif
00100     DNformat *Bstore;
00101     float   *Bmat;
00102     SCformat *Lstore;
00103     NCformat *Ustore;
00104     float   *Lval, *Uval;
00105     int      fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
00106     int      i, j, k, iptr, jcol, n, ldb, nrhs;
00107     float   *work, *rhs_work, *soln;
00108     flops_t  solve_ops;
00109     void sprint_soln();
00110 
00111     /* Test input parameters ... */
00112     *info = 0;
00113     Bstore = B->Store;
00114     ldb = Bstore->lda;
00115     nrhs = B->ncol;
00116     if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
00117     else if ( L->nrow != L->ncol || L->nrow < 0 ||
00118           L->Stype != SLU_SC || L->Dtype != SLU_S || L->Mtype != SLU_TRLU )
00119     *info = -2;
00120     else if ( U->nrow != U->ncol || U->nrow < 0 ||
00121           U->Stype != SLU_NC || U->Dtype != SLU_S || U->Mtype != SLU_TRU )
00122     *info = -3;
00123     else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
00124           B->Stype != SLU_DN || B->Dtype != SLU_S || B->Mtype != SLU_GE )
00125     *info = -6;
00126     if ( *info ) {
00127     i = -(*info);
00128     xerbla_("sgstrs", &i);
00129     return;
00130     }
00131 
00132     n = L->nrow;
00133     work = floatCalloc(n * nrhs);
00134     if ( !work ) ABORT("Malloc fails for local work[].");
00135     soln = floatMalloc(n);
00136     if ( !soln ) ABORT("Malloc fails for local soln[].");
00137 
00138     Bmat = Bstore->nzval;
00139     Lstore = L->Store;
00140     Lval = Lstore->nzval;
00141     Ustore = U->Store;
00142     Uval = Ustore->nzval;
00143     solve_ops = 0;
00144     
00145     if ( trans == NOTRANS ) {
00146     /* Permute right hand sides to form Pr*B */
00147     for (i = 0; i < nrhs; i++) {
00148         rhs_work = &Bmat[i*ldb];
00149         for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
00150         for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00151     }
00152     
00153     /* Forward solve PLy=Pb. */
00154     for (k = 0; k <= Lstore->nsuper; k++) {
00155         fsupc = L_FST_SUPC(k);
00156         istart = L_SUB_START(fsupc);
00157         nsupr = L_SUB_START(fsupc+1) - istart;
00158         nsupc = L_FST_SUPC(k+1) - fsupc;
00159         nrow = nsupr - nsupc;
00160 
00161         solve_ops += nsupc * (nsupc - 1) * nrhs;
00162         solve_ops += 2 * nrow * nsupc * nrhs;
00163         
00164         if ( nsupc == 1 ) {
00165         for (j = 0; j < nrhs; j++) {
00166             rhs_work = &Bmat[j*ldb];
00167                 luptr = L_NZ_START(fsupc);
00168             for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
00169             irow = L_SUB(iptr);
00170             ++luptr;
00171             rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
00172             }
00173         }
00174         } else {
00175             luptr = L_NZ_START(fsupc);
00176 #ifdef USE_VENDOR_BLAS
00177 #ifdef _CRAY
00178         ftcs1 = _cptofcd("L", strlen("L"));
00179         ftcs2 = _cptofcd("N", strlen("N"));
00180         ftcs3 = _cptofcd("U", strlen("U"));
00181         STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
00182                &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00183         
00184         SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
00185             &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
00186             &beta, &work[0], &n );
00187 #else
00188         strsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
00189                &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00190         
00191         sgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
00192             &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
00193             &beta, &work[0], &n );
00194 #endif
00195         for (j = 0; j < nrhs; j++) {
00196             rhs_work = &Bmat[j*ldb];
00197             work_col = &work[j*n];
00198             iptr = istart + nsupc;
00199             for (i = 0; i < nrow; i++) {
00200             irow = L_SUB(iptr);
00201             rhs_work[irow] -= work_col[i]; /* Scatter */
00202             work_col[i] = 0.0;
00203             iptr++;
00204             }
00205         }
00206 #else       
00207         for (j = 0; j < nrhs; j++) {
00208             rhs_work = &Bmat[j*ldb];
00209             slsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
00210             smatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
00211                 &rhs_work[fsupc], &work[0] );
00212 
00213             iptr = istart + nsupc;
00214             for (i = 0; i < nrow; i++) {
00215             irow = L_SUB(iptr);
00216             rhs_work[irow] -= work[i];
00217             work[i] = 0.0;
00218             iptr++;
00219             }
00220         }
00221 #endif          
00222         } /* else ... */
00223     } /* for L-solve */
00224 
00225 #ifdef DEBUG
00226     printf("After L-solve: y=\n");
00227     sprint_soln(n, Bmat);
00228 #endif
00229 
00230     /*
00231      * Back solve Ux=y.
00232      */
00233     for (k = Lstore->nsuper; k >= 0; k--) {
00234         fsupc = L_FST_SUPC(k);
00235         istart = L_SUB_START(fsupc);
00236         nsupr = L_SUB_START(fsupc+1) - istart;
00237         nsupc = L_FST_SUPC(k+1) - fsupc;
00238         luptr = L_NZ_START(fsupc);
00239 
00240         solve_ops += nsupc * (nsupc + 1) * nrhs;
00241 
00242         if ( nsupc == 1 ) {
00243         rhs_work = &Bmat[0];
00244         for (j = 0; j < nrhs; j++) {
00245             rhs_work[fsupc] /= Lval[luptr];
00246             rhs_work += ldb;
00247         }
00248         } else {
00249 #ifdef USE_VENDOR_BLAS
00250 #ifdef _CRAY
00251         ftcs1 = _cptofcd("L", strlen("L"));
00252         ftcs2 = _cptofcd("U", strlen("U"));
00253         ftcs3 = _cptofcd("N", strlen("N"));
00254         STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
00255                &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00256 #else
00257         strsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
00258                &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00259 #endif
00260 #else       
00261         for (j = 0; j < nrhs; j++)
00262             susolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
00263 #endif      
00264         }
00265 
00266         for (j = 0; j < nrhs; ++j) {
00267         rhs_work = &Bmat[j*ldb];
00268         for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
00269             solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
00270             for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
00271             irow = U_SUB(i);
00272             rhs_work[irow] -= rhs_work[jcol] * Uval[i];
00273             }
00274         }
00275         }
00276         
00277     } /* for U-solve */
00278 
00279 #ifdef DEBUG
00280     printf("After U-solve: x=\n");
00281     sprint_soln(n, Bmat);
00282 #endif
00283 
00284     /* Compute the final solution X := Pc*X. */
00285     for (i = 0; i < nrhs; i++) {
00286         rhs_work = &Bmat[i*ldb];
00287         for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
00288         for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00289     }
00290     
00291         stat->ops[SOLVE] = solve_ops;
00292 
00293     } else { /* Solve A'*X=B or CONJ(A)*X=B */
00294     /* Permute right hand sides to form Pc'*B. */
00295     for (i = 0; i < nrhs; i++) {
00296         rhs_work = &Bmat[i*ldb];
00297         for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
00298         for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00299     }
00300 
00301     stat->ops[SOLVE] = 0;
00302     for (k = 0; k < nrhs; ++k) {
00303         
00304         /* Multiply by inv(U'). */
00305         sp_strsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
00306         
00307         /* Multiply by inv(L'). */
00308         sp_strsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
00309         
00310     }
00311     /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
00312     for (i = 0; i < nrhs; i++) {
00313         rhs_work = &Bmat[i*ldb];
00314         for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
00315         for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00316     }
00317 
00318     }
00319 
00320     SUPERLU_FREE(work);
00321     SUPERLU_FREE(soln);
00322 }
00323 
/*
 * Debug helper: print a dense vector, one entry per line, formatted as
 * "<TAB><index>: <value to 4 decimal places>".
 *
 *   n     number of entries to print; nothing is printed when n <= 0
 *   soln  vector of at least n floats; only read, never modified
 */
void
sprint_soln(int n, float *soln)
{
    int idx = 0;

    while (idx < n) {
        printf("\t%d: %.4f\n", idx, soln[idx]);
        ++idx;
    }
}