diff --git a/BLAS/SRC/CMakeLists.txt b/BLAS/SRC/CMakeLists.txt index c92bd2bad9..fa726c153e 100644 --- a/BLAS/SRC/CMakeLists.txt +++ b/BLAS/SRC/CMakeLists.txt @@ -58,7 +58,7 @@ set(ZB1AUX #--------------------------------------------------------------------- # Auxiliary routines needed by both the Level 2 and Level 3 BLAS #--------------------------------------------------------------------- -set(ALLBLAS lsame.f xerbla.f xerbla_array.f) +set(ALLBLAS lsame.f xerbla.f xerbla_array.f xerblai.f xerblai_array.f) #--------------------------------------------------------- # Level 2 BLAS @@ -82,15 +82,17 @@ set(ZBLAS2 zgemv.f zgbmv.f zhemv.f zhbmv.f zhpmv.f #--------------------------------------------------------- # Level 3 BLAS #--------------------------------------------------------- -set(SBLAS3 sgemm.f ssymm.f ssyrk.f ssyr2k.f strmm.f strsm.f) +set(SBLAS3 sgemm.f ssymm.f ssyrk.f ssyr2k.f strmm.f strsm.f + sgemm_batch.f90) set(CBLAS3 cgemm.f csymm.f csyrk.f csyr2k.f ctrmm.f ctrsm.f - chemm.f cherk.f cher2k.f) + chemm.f cherk.f cher2k.f cgemm_batch.f90) -set(DBLAS3 dgemm.f dsymm.f dsyrk.f dsyr2k.f dtrmm.f dtrsm.f) +set(DBLAS3 dgemm.f dsymm.f dsyrk.f dsyr2k.f dtrmm.f dtrsm.f + dgemm_batch.f90) set(ZBLAS3 zgemm.f zsymm.f zsyrk.f zsyr2k.f ztrmm.f ztrsm.f - zhemm.f zherk.f zher2k.f) + zhemm.f zherk.f zher2k.f zgemm_batch.f90) set(SOURCES) diff --git a/BLAS/SRC/Makefile b/BLAS/SRC/Makefile index 70534c8358..f4d96e1072 100644 --- a/BLAS/SRC/Makefile +++ b/BLAS/SRC/Makefile @@ -96,7 +96,7 @@ $(ZB1AUX): $(FRC) # Level 2 and Level 3 BLAS. Comment it out only if you already have # both the Level 2 and 3 BLAS. #--------------------------------------------------------------------- -ALLBLAS = lsame.o xerbla.o xerbla_array.o +ALLBLAS = lsame.o xerbla.o xerbla_array.o xerblai.o xerblai_array.o $(ALLBLAS): $(FRC) #--------------------------------------------------------- @@ -127,18 +127,20 @@ $(ZBLAS2): $(FRC) # Comment out the next 4 definitions if you already have # the Level 3 BLAS. #--------------------------------------------------------- -SBLAS3 = sgemm.o ssymm.o ssyrk.o ssyr2k.o strmm.o strsm.o +SBLAS3 = sgemm.o ssymm.o ssyrk.o ssyr2k.o strmm.o strsm.o \ + sgemm_batch.o $(SBLAS3): $(FRC) CBLAS3 = cgemm.o csymm.o csyrk.o csyr2k.o ctrmm.o ctrsm.o \ - chemm.o cherk.o cher2k.o + chemm.o cherk.o cher2k.o cgemm_batch.o $(CBLAS3): $(FRC) -DBLAS3 = dgemm.o dsymm.o dsyrk.o dsyr2k.o dtrmm.o dtrsm.o +DBLAS3 = dgemm.o dsymm.o dsyrk.o dsyr2k.o dtrmm.o dtrsm.o \ + dgemm_batch.o $(DBLAS3): $(FRC) ZBLAS3 = zgemm.o zsymm.o zsyrk.o zsyr2k.o ztrmm.o ztrsm.o \ - zhemm.o zherk.o zher2k.o + zhemm.o zherk.o zher2k.o zgemm_batch.o $(ZBLAS3): $(FRC) ALLOBJ = $(SBLAS1) $(SBLAS2) $(SBLAS3) $(DBLAS1) $(DBLAS2) $(DBLAS3) \ diff --git a/BLAS/SRC/cgemm_batch.f90 b/BLAS/SRC/cgemm_batch.f90 new file mode 100644 index 0000000000..f47242706d --- /dev/null +++ b/BLAS/SRC/cgemm_batch.f90 @@ -0,0 +1,372 @@ +!> \brief \b CGEMM_BATCH +! +! =========== DOCUMENTATION =========== +! +! Online html documentation available at +! http://www.netlib.org/lapack/explore-html/ +! +! Definition: +! =========== +! +! SUBROUTINE CGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, +! M_ARRAY, N_ARRAY, K_ARRAY, +! ALPHA_ARRAY, +! A_ARRAY, LDA_ARRAY, +! B_ARRAY, LDB_ARRAY, +! BETA_ARRAY, +! C_ARRAY, LDC_ARRAY, +! GROUP_COUNT, GROUP_SIZE) +! +! .. Scalar Arguments .. +! INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. +! CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) +! INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) +! COMPLEX ALPHA_ARRAY(GROUP_COUNT),BETA_ARRAY(GROUP_COUNT) +! INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) +! INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. +! TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! +!> \par Purpose: +! ============= +!> +!> \verbatim +!> +!> CGEMM_BATCH performs a series of the matrix-matrix operations with each ji'th matrix: +!> +!> C_ji := alpha_i*op_i( A_ji )*op( B_ji ) + beta_i*C_ji, +!> +!> where op_i( X ) is one of +!> +!> op_i( X_ji ) = X_ji or op_i( X_ji ) = X_ji**T, +!> +!> alpha_i and beta_i are scalars, and A_ji, B_ji and C_ji are matrices, with op_i( A_ji ) +!> an m_i by k_i matrix, op_i( B_ji ) a k_i by n_i matrix and C_ji an m_i by n_i matrix. +!> Group count defines i and group_size(i) defines j. +!> +!> More generally, +!> +!> idx = 1 +!> for i in 1..group_count +!> alpha, beta = alpha(i), beta(i) +!> for j in 1..group_size(i) +!> A, B, C = A_ARRAY(idx), B_ARRAY(idx), C_ARRAY(idx) +!> C := alpha*op(A)*op(B) + beta*C +!> idx = idx + 1 +!> +!> +!> \endverbatim +! +! Arguments: +! ========== +! +!> \param[in] TRANSA_ARRAY +!> \verbatim +!> TRANSA_ARRAY is CHARACTER*1 array +!> On entry, TRANSA_ARRAY(i) specifies the form of op_i( A_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSA_ARRAY(i) = 'N' or 'n', op_i( A_ji ) = A_ji. +!> +!> TRANSA_ARRAY(i) = 'T' or 't', op_i( A_ji ) = A_ji**T. +!> +!> TRANSA_ARRAY(i) = 'C' or 'c', op_i( A_ji ) = A_ji**H. +!> \endverbatim +!> +!> \param[in] TRANSB_ARRAY +!> \verbatim +!> TRANSB_ARRAY is CHARACTER*1 array +!> On entry, TRANSB_ARRAY(i) specifies the form of op_i( B_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSB_ARRAY(i) = 'N' or 'n', op_i( B_ji ) = B_ji. +!> +!> TRANSB_ARRAY(i) = 'T' or 't', op_i( B_ji ) = B_ji**T. +!> +!> TRANSB_ARRAY(i) = 'C' or 'c', op_i( B_ji ) = B_ji**H. +!> \endverbatim +!> +!> \param[in] M_ARRAY +!> \verbatim +!> M_ARRAY is INTEGER array +!> On entry, M_ARRAY(i) specifies the number of rows of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes C_ji. +!> Each M_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] N_ARRAY +!> \verbatim +!> N_ARRAY is INTEGER array +!> On entry, N_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( B_ji ) and the number of columns of the matrixes C_ji. +!> Each N_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] K_ARRAY +!> \verbatim +!> K_ARRAY is INTEGER array +!> On entry, K_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes op_i( B_ji ). +!> Each K_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] ALPHA_ARRAY +!> \verbatim +!> ALPHA_ARRAY is COMPLEX array. +!> On entry, ALPHA_ARRAY(i) specifies the scalar alpha_i. +!> \endverbatim +!> +!> \param[in] A_ARRAY +!> \verbatim +!> A_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to COMPLEX arrays, dimension ( LDA_i, ka_i ), +!> where ka_i is k_i when TRANSA(i) = 'N' or 'n', and is m_i otherwise. +!> Before entry with TRANSA = 'N' or 'n', the leading m_i by k_i elements +!> at address A(ji) must contain the matrix A_ji, otherwise +!> the leading k_i by m_i elements at address A(ji) must contain the +!> matrix A_ji. +!> \endverbatim +!> +!> \param[in] LDA_ARRAY +!> \verbatim +!> LDA_ARRAY is INTEGER array +!> On entry, LDA_ARRAY(i) specifies the first dimension of A_ji as declared +!> in the calling (sub) program. When TRANSA = 'N' or 'n' then +!> LDA_ARRAY(i) must be at least max( 1, m_i ), otherwise LDA must be at +!> least max( 1, k_i ). +!> \endverbatim +!> +!> \param[in] B_ARRAY +!> \verbatim +!> B_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to COMPLEX arrays, dimension ( LDB_i, kb_i ), +!> where kb_i is n_i when TRANSB(i) = 'N' or 'n', and is k_i otherwise. +!> Before entry with TRANSB = 'N' or 'n', the leading k_i by n_i elements +!> at address B(ji) must contain the matrix B_ji, otherwise +!> the leading n_i by k_i elements at address B(ji) must contain the +!> matrix B_ji. +!> \endverbatim +!> +!> \param[in] LDB_ARRAY +!> \verbatim +!> LDB_ARRAY is INTEGER array +!> On entry, LDB_ARRAY(i) specifies the first dimension of B_ji as declared +!> in the calling (sub) program. When TRANSB = 'N' or 'n' then +!> LDB must be at least max( 1, k_i ), otherwise LDB must be at +!> least max( 1, n_i ). +!> \endverbatim +!> +!> \param[in] BETA_ARRAY +!> \verbatim +!> BETA_ARRAY is COMPLEX array. +!> On entry, BETA_ARRAY(i) specifies the scalar beta. When BETA_ARRAY(i) is +!> supplied as zero then C_ji need not be set on input. +!> \endverbatim +!> +!> \param[in,out] C_ARRAY +!> \verbatim +!> C_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to COMPLEX arrays, dimension ( LDC_i, n_i ). +!> Before entry, the leading m_i by n_i elements +!> at address C(ji) must contain the matrix C_ji, except when BETA_ARRAY(i) +!> is zero, in which case C_ji need not be set on entry. +!> On exit, the array C_ji is overwritten by the m_i by n_i matrix +!> ( alpha_i*op_i( A_ji )*op_i( B_ji ) + beta_i*C_ji ). +!> \endverbatim +!> +!> \param[in] LDC_ARRAY +!> \verbatim +!> LDC_ARRAY is INTEGER array +!> On entry, LDC_ARRAY(i) specifies the first dimension of C_ji as declared +!> in the calling (sub) program. LDC_ARRAY(i) must be at least +!> max( 1, m_i ). +!> \endverbatim +!> +!> \param[in] GROUP_COUNT +!> \verbatim +!> GROUP_COUNT is INTEGER +!> On entry, GROUP_COUNT specifies the number of groups that determines index i. +!> \endverbatim +!> +!> \param[in] GROUP_SIZE +!> \verbatim +!> GROUP_SIZE is INTEGER array +!> On entry, GROUP_SIZE specifies the number of elements in each groups that determines index j. +!> \endverbatim +! +! Authors: +! ======== +! +!> \author Igor S. Gerasimov +! +!> \ingroup gemm_batch +! +!> \par Further Details: +! ===================== +!> +!> \verbatim +!> +!> Level 3 Blas routine. +!> +!> Original API is taken from: +!> https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2023-2/gemm-batch.html +!> +!> -- Written on 23-October-2023. +!> +!> \endverbatim +!> +! ===================================================================== + SUBROUTINE CGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, & + M_ARRAY, N_ARRAY, K_ARRAY, & + ALPHA_ARRAY, & + A_ARRAY, LDA_ARRAY, & + B_ARRAY, LDB_ARRAY, & + BETA_ARRAY, & + C_ARRAY, LDC_ARRAY, & + GROUP_COUNT, GROUP_SIZE) + USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_PTR, C_F_POINTER, C_ASSOCIATED +! +! -- Reference BLAS level3 routine -- +! -- Reference BLAS is a software package provided by Univ. of Tennessee, -- +! -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..-- +! +! .. Scalar Arguments .. + INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. + CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) + INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) + COMPLEX ALPHA_ARRAY(GROUP_COUNT), BETA_ARRAY(GROUP_COUNT) + INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) + INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. + TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! ===================================================================== +! +! .. External Functions .. + LOGICAL LSAME + EXTERNAL LSAME +! .. +! .. External Subroutines .. + EXTERNAL XERBLA + EXTERNAL XERBLAI +! .. +! .. Intrinsic Functions .. + INTRINSIC MAX +! .. +! .. Local Scalars .. + INTEGER I, J, IDX, INFO + LOGICAL NOTA, NOTB + INTEGER NROWA, NROWB +! .. +! .. Local Addresses .. + COMPLEX, POINTER :: A, B, C +! .. +! +! Test the input parameters. +! + INFO = 0 + IF (GROUP_COUNT.LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLA('CGEMM_BATCH ', INFO) + RETURN + END IF + DO I = 1, GROUP_COUNT + INFO = 0 +! +! Set NOTA and NOTB as true if A and B respectively are not +! transposed and set NROWA and NROWB as the number of rows of A +! and B respectively. +! + NOTA = LSAME(TRANSA_ARRAY(I),'N') + NOTB = LSAME(TRANSB_ARRAY(I),'N') + IF (NOTA) THEN + NROWA = M_ARRAY(I) + ELSE + NROWA = K_ARRAY(I) + END IF + IF (NOTB) THEN + NROWB = K_ARRAY(I) + ELSE + NROWB = N_ARRAY(I) + END IF + IF ((.NOT.NOTA) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'T'))) THEN + INFO = 1 + ELSE IF ((.NOT.NOTB) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'T'))) THEN + INFO = 2 + ELSE IF (M_ARRAY(I).LT.0) THEN + INFO = 3 + ELSE IF (N_ARRAY(I).LT.0) THEN + INFO = 4 + ELSE IF (K_ARRAY(I).LT.0) THEN + INFO = 5 + ELSE IF (LDA_ARRAY(I).LT.MAX(1,NROWA)) THEN + INFO = 8 + ELSE IF (LDB_ARRAY(I).LT.MAX(1,NROWB)) THEN + INFO = 10 + ELSE IF (LDC_ARRAY(I).LT.MAX(1,M_ARRAY(I))) THEN + INFO = 13 + ELSE IF (GROUP_SIZE(I).LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLA('CGEMM_BATCH ',INFO,I) + RETURN + END IF + END DO + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + INFO = 0 + IF (.NOT.C_ASSOCIATED(A_ARRAY(IDX))) THEN + INFO = 7 + ELSE IF (.NOT.C_ASSOCIATED(B_ARRAY(IDX))) THEN + INFO = 9 + ELSE IF (.NOT.C_ASSOCIATED(C_ARRAY(IDX))) THEN + INFO = 12 + END IF + IF (INFO.NE.0) THEN + CALL XERBLAI('CGEMM_BATCH ',INFO,IDX) + RETURN + END IF + IDX = IDX + 1 + END DO + END DO +! +! Do computations. +! + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + CALL C_F_POINTER(A_ARRAY(IDX), A) + CALL C_F_POINTER(B_ARRAY(IDX), B) + CALL C_F_POINTER(C_ARRAY(IDX), C) + CALL CGEMM(TRANSA_ARRAY(I), TRANSB_ARRAY(I), & + M_ARRAY(I), N_ARRAY(I), K_ARRAY(I), & + ALPHA_ARRAY(I), & + A, LDA_ARRAY(I), & + B, LDB_ARRAY(I), & + BETA_ARRAY(I), & + C, LDC_ARRAY(I)) + IDX = IDX + 1 + END DO + END DO + RETURN +! +! End of CGEMM_BATCH. +! + END diff --git a/BLAS/SRC/dgemm_batch.f90 b/BLAS/SRC/dgemm_batch.f90 new file mode 100644 index 0000000000..6d783dd3e3 --- /dev/null +++ b/BLAS/SRC/dgemm_batch.f90 @@ -0,0 +1,372 @@ +!> \brief \b DGEMM_BATCH +! +! =========== DOCUMENTATION =========== +! +! Online html documentation available at +! http://www.netlib.org/lapack/explore-html/ +! +! Definition: +! =========== +! +! SUBROUTINE DGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, +! M_ARRAY, N_ARRAY, K_ARRAY, +! ALPHA_ARRAY, +! A_ARRAY, LDA_ARRAY, +! B_ARRAY, LDB_ARRAY, +! BETA_ARRAY, +! C_ARRAY, LDC_ARRAY, +! GROUP_COUNT, GROUP_SIZE) +! +! .. Scalar Arguments .. +! INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. +! CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) +! INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) +! DOUBLE PRECISION ALPHA_ARRAY(GROUP_COUNT),BETA_ARRAY(GROUP_COUNT) +! INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) +! INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. +! TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! +!> \par Purpose: +! ============= +!> +!> \verbatim +!> +!> DGEMM_BATCH performs a series of the matrix-matrix operations with each ji'th matrix: +!> +!> C_ji := alpha_i*op_i( A_ji )*op( B_ji ) + beta_i*C_ji, +!> +!> where op_i( X ) is one of +!> +!> op_i( X_ji ) = X_ji or op_i( X_ji ) = X_ji**T, +!> +!> alpha_i and beta_i are scalars, and A_ji, B_ji and C_ji are matrices, with op_i( A_ji ) +!> an m_i by k_i matrix, op_i( B_ji ) a k_i by n_i matrix and C_ji an m_i by n_i matrix. +!> Group count defines i and group_size(i) defines j. +!> +!> More generally, +!> +!> idx = 1 +!> for i in 1..group_count +!> alpha, beta = alpha(i), beta(i) +!> for j in 1..group_size(i) +!> A, B, C = A_ARRAY(idx), B_ARRAY(idx), C_ARRAY(idx) +!> C := alpha*op(A)*op(B) + beta*C +!> idx = idx + 1 +!> +!> +!> \endverbatim +! +! Arguments: +! ========== +! +!> \param[in] TRANSA_ARRAY +!> \verbatim +!> TRANSA_ARRAY is CHARACTER*1 array +!> On entry, TRANSA_ARRAY(i) specifies the form of op_i( A_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSA_ARRAY(i) = 'N' or 'n', op_i( A_ji ) = A_ji. +!> +!> TRANSA_ARRAY(i) = 'T' or 't', op_i( A_ji ) = A_ji**T. +!> +!> TRANSA_ARRAY(i) = 'C' or 'c', op_i( A_ji ) = A_ji**T. +!> \endverbatim +!> +!> \param[in] TRANSB_ARRAY +!> \verbatim +!> TRANSB_ARRAY is CHARACTER*1 array +!> On entry, TRANSB_ARRAY(i) specifies the form of op_i( B_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSB_ARRAY(i) = 'N' or 'n', op_i( B_ji ) = B_ji. +!> +!> TRANSB_ARRAY(i) = 'T' or 't', op_i( B_ji ) = B_ji**T. +!> +!> TRANSB_ARRAY(i) = 'C' or 'c', op_i( B_ji ) = B_ji**T. +!> \endverbatim +!> +!> \param[in] M_ARRAY +!> \verbatim +!> M_ARRAY is INTEGER array +!> On entry, M_ARRAY(i) specifies the number of rows of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes C_ji. +!> Each M_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] N_ARRAY +!> \verbatim +!> N_ARRAY is INTEGER array +!> On entry, N_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( B_ji ) and the number of columns of the matrixes C_ji. +!> Each N_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] K_ARRAY +!> \verbatim +!> K_ARRAY is INTEGER array +!> On entry, K_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes op_i( B_ji ). +!> Each K_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] ALPHA_ARRAY +!> \verbatim +!> ALPHA_ARRAY is DOUBLE PRECISION array. +!> On entry, ALPHA_ARRAY(i) specifies the scalar alpha_i. +!> \endverbatim +!> +!> \param[in] A_ARRAY +!> \verbatim +!> A_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to DOUBLE PRECISION arrays, dimension ( LDA_i, ka_i ), +!> where ka_i is k_i when TRANSA(i) = 'N' or 'n', and is m_i otherwise. +!> Before entry with TRANSA = 'N' or 'n', the leading m_i by k_i elements +!> at address A(ji) must contain the matrix A_ji, otherwise +!> the leading k_i by m_i elements at address A(ji) must contain the +!> matrix A_ji. +!> \endverbatim +!> +!> \param[in] LDA_ARRAY +!> \verbatim +!> LDA_ARRAY is INTEGER array +!> On entry, LDA_ARRAY(i) specifies the first dimension of A_ji as declared +!> in the calling (sub) program. When TRANSA = 'N' or 'n' then +!> LDA_ARRAY(i) must be at least max( 1, m_i ), otherwise LDA must be at +!> least max( 1, k_i ). +!> \endverbatim +!> +!> \param[in] B_ARRAY +!> \verbatim +!> B_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to DOUBLE PRECISION arrays, dimension ( LDB_i, kb_i ), +!> where kb_i is n_i when TRANSB(i) = 'N' or 'n', and is k_i otherwise. +!> Before entry with TRANSB = 'N' or 'n', the leading k_i by n_i elements +!> at address B(ji) must contain the matrix B_ji, otherwise +!> the leading n_i by k_i elements at address B(ji) must contain the +!> matrix B_ji. +!> \endverbatim +!> +!> \param[in] LDB_ARRAY +!> \verbatim +!> LDB_ARRAY is INTEGER array +!> On entry, LDB_ARRAY(i) specifies the first dimension of B_ji as declared +!> in the calling (sub) program. When TRANSB = 'N' or 'n' then +!> LDB must be at least max( 1, k_i ), otherwise LDB must be at +!> least max( 1, n_i ). +!> \endverbatim +!> +!> \param[in] BETA_ARRAY +!> \verbatim +!> BETA_ARRAY is DOUBLE PRECISION array. +!> On entry, BETA_ARRAY(i) specifies the scalar beta. When BETA_ARRAY(i) is +!> supplied as zero then C_ji need not be set on input. +!> \endverbatim +!> +!> \param[in,out] C_ARRAY +!> \verbatim +!> C_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to DOUBLE PRECISION arrays, dimension ( LDC_i, n_i ). +!> Before entry, the leading m_i by n_i elements +!> at address C(ji) must contain the matrix C_ji, except when BETA_ARRAY(i) +!> is zero, in which case C_ji need not be set on entry. +!> On exit, the array C_ji is overwritten by the m_i by n_i matrix +!> ( alpha_i*op_i( A_ji )*op_i( B_ji ) + beta_i*C_ji ). +!> \endverbatim +!> +!> \param[in] LDC_ARRAY +!> \verbatim +!> LDC_ARRAY is INTEGER array +!> On entry, LDC_ARRAY(i) specifies the first dimension of C_ji as declared +!> in the calling (sub) program. LDC_ARRAY(i) must be at least +!> max( 1, m_i ). +!> \endverbatim +!> +!> \param[in] GROUP_COUNT +!> \verbatim +!> GROUP_COUNT is INTEGER +!> On entry, GROUP_COUNT specifies the number of groups that determines index i. +!> \endverbatim +!> +!> \param[in] GROUP_SIZE +!> \verbatim +!> GROUP_SIZE is INTEGER array +!> On entry, GROUP_SIZE specifies the number of elements in each groups that determines index j. +!> \endverbatim +! +! Authors: +! ======== +! +!> \author Igor S. Gerasimov +! +!> \ingroup gemm_batch +! +!> \par Further Details: +! ===================== +!> +!> \verbatim +!> +!> Level 3 Blas routine. +!> +!> Original API is taken from: +!> https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2023-2/gemm-batch.html +!> +!> -- Written on 23-October-2023. +!> +!> \endverbatim +!> +! ===================================================================== + SUBROUTINE DGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, & + M_ARRAY, N_ARRAY, K_ARRAY, & + ALPHA_ARRAY, & + A_ARRAY, LDA_ARRAY, & + B_ARRAY, LDB_ARRAY, & + BETA_ARRAY, & + C_ARRAY, LDC_ARRAY, & + GROUP_COUNT, GROUP_SIZE) + USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_PTR, C_F_POINTER, C_ASSOCIATED +! +! -- Reference BLAS level3 routine -- +! -- Reference BLAS is a software package provided by Univ. of Tennessee, -- +! -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..-- +! +! .. Scalar Arguments .. + INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. + CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) + INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) + DOUBLE PRECISION ALPHA_ARRAY(GROUP_COUNT), BETA_ARRAY(GROUP_COUNT) + INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) + INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. + TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! ===================================================================== +! +! .. External Functions .. + LOGICAL LSAME + EXTERNAL LSAME +! .. +! .. External Subroutines .. + EXTERNAL XERBLA + EXTERNAL XERBLAI +! .. +! .. Intrinsic Functions .. + INTRINSIC MAX +! .. +! .. Local Scalars .. + INTEGER I, J, IDX, INFO + LOGICAL NOTA, NOTB + INTEGER NROWA, NROWB +! .. +! .. Local Addresses .. + DOUBLE PRECISION, POINTER :: A, B, C +! .. +! +! Test the input parameters. +! + INFO = 0 + IF (GROUP_COUNT.LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLA('DGEMM_BATCH ', INFO) + RETURN + END IF + DO I = 1, GROUP_COUNT + INFO = 0 +! +! Set NOTA and NOTB as true if A and B respectively are not +! transposed and set NROWA and NROWB as the number of rows of A +! and B respectively. +! + NOTA = LSAME(TRANSA_ARRAY(I),'N') + NOTB = LSAME(TRANSB_ARRAY(I),'N') + IF (NOTA) THEN + NROWA = M_ARRAY(I) + ELSE + NROWA = K_ARRAY(I) + END IF + IF (NOTB) THEN + NROWB = K_ARRAY(I) + ELSE + NROWB = N_ARRAY(I) + END IF + IF ((.NOT.NOTA) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'T'))) THEN + INFO = 1 + ELSE IF ((.NOT.NOTB) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'T'))) THEN + INFO = 2 + ELSE IF (M_ARRAY(I).LT.0) THEN + INFO = 3 + ELSE IF (N_ARRAY(I).LT.0) THEN + INFO = 4 + ELSE IF (K_ARRAY(I).LT.0) THEN + INFO = 5 + ELSE IF (LDA_ARRAY(I).LT.MAX(1,NROWA)) THEN + INFO = 8 + ELSE IF (LDB_ARRAY(I).LT.MAX(1,NROWB)) THEN + INFO = 10 + ELSE IF (LDC_ARRAY(I).LT.MAX(1,M_ARRAY(I))) THEN + INFO = 13 + ELSE IF (GROUP_SIZE(I).LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLAI('DGEMM_BATCH ',INFO,I) + RETURN + END IF + END DO + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + INFO = 0 + IF (.NOT.C_ASSOCIATED(A_ARRAY(IDX))) THEN + INFO = 7 + ELSE IF (.NOT.C_ASSOCIATED(B_ARRAY(IDX))) THEN + INFO = 9 + ELSE IF (.NOT.C_ASSOCIATED(C_ARRAY(IDX))) THEN + INFO = 12 + END IF + IF (INFO.NE.0) THEN + CALL XERBLAI('DGEMM_BATCH ',INFO,IDX) + RETURN + END IF + IDX = IDX + 1 + END DO + END DO +! +! Do computations. +! + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + CALL C_F_POINTER(A_ARRAY(IDX), A) + CALL C_F_POINTER(B_ARRAY(IDX), B) + CALL C_F_POINTER(C_ARRAY(IDX), C) + CALL DGEMM(TRANSA_ARRAY(I), TRANSB_ARRAY(I), & + M_ARRAY(I), N_ARRAY(I), K_ARRAY(I), & + ALPHA_ARRAY(I), & + A, LDA_ARRAY(I), & + B, LDB_ARRAY(I), & + BETA_ARRAY(I), & + C, LDC_ARRAY(I)) + IDX = IDX + 1 + END DO + END DO + RETURN +! +! End of DGEMM_BATCH. +! + END diff --git a/BLAS/SRC/sgemm_batch.f90 b/BLAS/SRC/sgemm_batch.f90 new file mode 100644 index 0000000000..748d3d22ba --- /dev/null +++ b/BLAS/SRC/sgemm_batch.f90 @@ -0,0 +1,372 @@ +!> \brief \b SGEMM_BATCH +! +! =========== DOCUMENTATION =========== +! +! Online html documentation available at +! http://www.netlib.org/lapack/explore-html/ +! +! Definition: +! =========== +! +! SUBROUTINE SGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, +! M_ARRAY, N_ARRAY, K_ARRAY, +! ALPHA_ARRAY, +! A_ARRAY, LDA_ARRAY, +! B_ARRAY, LDB_ARRAY, +! BETA_ARRAY, +! C_ARRAY, LDC_ARRAY, +! GROUP_COUNT, GROUP_SIZE) +! +! .. Scalar Arguments .. +! INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. +! CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) +! INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) +! REAL ALPHA_ARRAY(GROUP_COUNT),BETA_ARRAY(GROUP_COUNT) +! INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) +! INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. +! TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! +!> \par Purpose: +! ============= +!> +!> \verbatim +!> +!> SGEMM_BATCH performs a series of the matrix-matrix operations with each ji'th matrix: +!> +!> C_ji := alpha_i*op_i( A_ji )*op( B_ji ) + beta_i*C_ji, +!> +!> where op_i( X ) is one of +!> +!> op_i( X_ji ) = X_ji or op_i( X_ji ) = X_ji**T, +!> +!> alpha_i and beta_i are scalars, and A_ji, B_ji and C_ji are matrices, with op_i( A_ji ) +!> an m_i by k_i matrix, op_i( B_ji ) a k_i by n_i matrix and C_ji an m_i by n_i matrix. +!> Group count defines i and group_size(i) defines j. +!> +!> More generally, +!> +!> idx = 1 +!> for i in 1..group_count +!> alpha, beta = alpha(i), beta(i) +!> for j in 1..group_size(i) +!> A, B, C = A_ARRAY(idx), B_ARRAY(idx), C_ARRAY(idx) +!> C := alpha*op(A)*op(B) + beta*C +!> idx = idx + 1 +!> +!> +!> \endverbatim +! +! Arguments: +! ========== +! +!> \param[in] TRANSA_ARRAY +!> \verbatim +!> TRANSA_ARRAY is CHARACTER*1 array +!> On entry, TRANSA_ARRAY(i) specifies the form of op_i( A_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSA_ARRAY(i) = 'N' or 'n', op_i( A_ji ) = A_ji. +!> +!> TRANSA_ARRAY(i) = 'T' or 't', op_i( A_ji ) = A_ji**T. +!> +!> TRANSA_ARRAY(i) = 'C' or 'c', op_i( A_ji ) = A_ji**T. +!> \endverbatim +!> +!> \param[in] TRANSB_ARRAY +!> \verbatim +!> TRANSB_ARRAY is CHARACTER*1 array +!> On entry, TRANSB_ARRAY(i) specifies the form of op_i( B_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSB_ARRAY(i) = 'N' or 'n', op_i( B_ji ) = B_ji. +!> +!> TRANSB_ARRAY(i) = 'T' or 't', op_i( B_ji ) = B_ji**T. +!> +!> TRANSB_ARRAY(i) = 'C' or 'c', op_i( B_ji ) = B_ji**T. +!> \endverbatim +!> +!> \param[in] M_ARRAY +!> \verbatim +!> M_ARRAY is INTEGER array +!> On entry, M_ARRAY(i) specifies the number of rows of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes C_ji. +!> Each M_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] N_ARRAY +!> \verbatim +!> N_ARRAY is INTEGER array +!> On entry, N_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( B_ji ) and the number of columns of the matrixes C_ji. +!> Each N_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] K_ARRAY +!> \verbatim +!> K_ARRAY is INTEGER array +!> On entry, K_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes op_i( B_ji ). +!> Each K_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] ALPHA_ARRAY +!> \verbatim +!> ALPHA_ARRAY is REAL array. +!> On entry, ALPHA_ARRAY(i) specifies the scalar alpha_i. +!> \endverbatim +!> +!> \param[in] A_ARRAY +!> \verbatim +!> A_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to REAL arrays, dimension ( LDA_i, ka_i ), +!> where ka_i is k_i when TRANSA(i) = 'N' or 'n', and is m_i otherwise. +!> Before entry with TRANSA = 'N' or 'n', the leading m_i by k_i elements +!> at address A(ji) must contain the matrix A_ji, otherwise +!> the leading k_i by m_i elements at address A(ji) must contain the +!> matrix A_ji. +!> \endverbatim +!> +!> \param[in] LDA_ARRAY +!> \verbatim +!> LDA_ARRAY is INTEGER array +!> On entry, LDA_ARRAY(i) specifies the first dimension of A_ji as declared +!> in the calling (sub) program. When TRANSA = 'N' or 'n' then +!> LDA_ARRAY(i) must be at least max( 1, m_i ), otherwise LDA must be at +!> least max( 1, k_i ). +!> \endverbatim +!> +!> \param[in] B_ARRAY +!> \verbatim +!> B_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to REAL arrays, dimension ( LDB_i, kb_i ), +!> where kb_i is n_i when TRANSB(i) = 'N' or 'n', and is k_i otherwise. +!> Before entry with TRANSB = 'N' or 'n', the leading k_i by n_i elements +!> at address B(ji) must contain the matrix B_ji, otherwise +!> the leading n_i by k_i elements at address B(ji) must contain the +!> matrix B_ji. +!> \endverbatim +!> +!> \param[in] LDB_ARRAY +!> \verbatim +!> LDB_ARRAY is INTEGER array +!> On entry, LDB_ARRAY(i) specifies the first dimension of B_ji as declared +!> in the calling (sub) program. When TRANSB = 'N' or 'n' then +!> LDB must be at least max( 1, k_i ), otherwise LDB must be at +!> least max( 1, n_i ). +!> \endverbatim +!> +!> \param[in] BETA_ARRAY +!> \verbatim +!> BETA_ARRAY is REAL array. +!> On entry, BETA_ARRAY(i) specifies the scalar beta. When BETA_ARRAY(i) is +!> supplied as zero then C_ji need not be set on input. +!> \endverbatim +!> +!> \param[in,out] C_ARRAY +!> \verbatim +!> C_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to REAL arrays, dimension ( LDC_i, n_i ). +!> Before entry, the leading m_i by n_i elements +!> at address C(ji) must contain the matrix C_ji, except when BETA_ARRAY(i) +!> is zero, in which case C_ji need not be set on entry. +!> On exit, the array C_ji is overwritten by the m_i by n_i matrix +!> ( alpha_i*op_i( A_ji )*op_i( B_ji ) + beta_i*C_ji ). +!> \endverbatim +!> +!> \param[in] LDC_ARRAY +!> \verbatim +!> LDC_ARRAY is INTEGER array +!> On entry, LDC_ARRAY(i) specifies the first dimension of C_ji as declared +!> in the calling (sub) program. LDC_ARRAY(i) must be at least +!> max( 1, m_i ). +!> \endverbatim +!> +!> \param[in] GROUP_COUNT +!> \verbatim +!> GROUP_COUNT is INTEGER +!> On entry, GROUP_COUNT specifies the number of groups that determines index i. +!> \endverbatim +!> +!> \param[in] GROUP_SIZE +!> \verbatim +!> GROUP_SIZE is INTEGER array +!> On entry, GROUP_SIZE specifies the number of elements in each groups that determines index j. +!> \endverbatim +! +! Authors: +! ======== +! +!> \author Igor S. Gerasimov +! +!> \ingroup gemm_batch +! +!> \par Further Details: +! ===================== +!> +!> \verbatim +!> +!> Level 3 Blas routine. +!> +!> Original API is taken from: +!> https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2023-2/gemm-batch.html +!> +!> -- Written on 23-October-2023. +!> +!> \endverbatim +!> +! ===================================================================== + SUBROUTINE SGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, & + M_ARRAY, N_ARRAY, K_ARRAY, & + ALPHA_ARRAY, & + A_ARRAY, LDA_ARRAY, & + B_ARRAY, LDB_ARRAY, & + BETA_ARRAY, & + C_ARRAY, LDC_ARRAY, & + GROUP_COUNT, GROUP_SIZE) + USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_PTR, C_F_POINTER, C_ASSOCIATED +! +! -- Reference BLAS level3 routine -- +! -- Reference BLAS is a software package provided by Univ. of Tennessee, -- +! -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..-- +! +! .. Scalar Arguments .. + INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. + CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) + INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) + REAL ALPHA_ARRAY(GROUP_COUNT), BETA_ARRAY(GROUP_COUNT) + INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) + INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. + TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! ===================================================================== +! +! .. External Functions .. + LOGICAL LSAME + EXTERNAL LSAME +! .. +! .. External Subroutines .. + EXTERNAL XERBLA + EXTERNAL XERBLAI +! .. +! .. Intrinsic Functions .. + INTRINSIC MAX +! .. +! .. Local Scalars .. + INTEGER I, J, IDX, INFO + LOGICAL NOTA, NOTB + INTEGER NROWA, NROWB +! .. +! .. Local Addresses .. + REAL, POINTER :: A, B, C +! .. +! +! Test the input parameters. +! + INFO = 0 + IF (GROUP_COUNT.LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLA('SGEMM_BATCH ', INFO) + RETURN + END IF + DO I = 1, GROUP_COUNT + INFO = 0 +! +! Set NOTA and NOTB as true if A and B respectively are not +! transposed and set NROWA and NROWB as the number of rows of A +! and B respectively. +! + NOTA = LSAME(TRANSA_ARRAY(I),'N') + NOTB = LSAME(TRANSB_ARRAY(I),'N') + IF (NOTA) THEN + NROWA = M_ARRAY(I) + ELSE + NROWA = K_ARRAY(I) + END IF + IF (NOTB) THEN + NROWB = K_ARRAY(I) + ELSE + NROWB = N_ARRAY(I) + END IF + IF ((.NOT.NOTA) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'T'))) THEN + INFO = 1 + ELSE IF ((.NOT.NOTB) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'T'))) THEN + INFO = 2 + ELSE IF (M_ARRAY(I).LT.0) THEN + INFO = 3 + ELSE IF (N_ARRAY(I).LT.0) THEN + INFO = 4 + ELSE IF (K_ARRAY(I).LT.0) THEN + INFO = 5 + ELSE IF (LDA_ARRAY(I).LT.MAX(1,NROWA)) THEN + INFO = 8 + ELSE IF (LDB_ARRAY(I).LT.MAX(1,NROWB)) THEN + INFO = 10 + ELSE IF (LDC_ARRAY(I).LT.MAX(1,M_ARRAY(I))) THEN + INFO = 13 + ELSE IF (GROUP_SIZE(I).LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLAI('SGEMM_BATCH ',INFO,I) + RETURN + END IF + END DO + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + INFO = 0 + IF (.NOT.C_ASSOCIATED(A_ARRAY(IDX))) THEN + INFO = 7 + ELSE IF (.NOT.C_ASSOCIATED(B_ARRAY(IDX))) THEN + INFO = 9 + ELSE IF (.NOT.C_ASSOCIATED(C_ARRAY(IDX))) THEN + INFO = 12 + END IF + IF (INFO.NE.0) THEN + CALL XERBLAI('SGEMM_BATCH ',INFO,IDX) + RETURN + END IF + IDX = IDX + 1 + END DO + END DO +! +! Do computations. +! + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + CALL C_F_POINTER(A_ARRAY(IDX), A) + CALL C_F_POINTER(B_ARRAY(IDX), B) + CALL C_F_POINTER(C_ARRAY(IDX), C) + CALL SGEMM(TRANSA_ARRAY(I), TRANSB_ARRAY(I), & + M_ARRAY(I), N_ARRAY(I), K_ARRAY(I), & + ALPHA_ARRAY(I), & + A, LDA_ARRAY(I), & + B, LDB_ARRAY(I), & + BETA_ARRAY(I), & + C, LDC_ARRAY(I)) + IDX = IDX + 1 + END DO + END DO + RETURN +! +! End of SGEMM_BATCH. +! + END diff --git a/BLAS/SRC/xerblai.f b/BLAS/SRC/xerblai.f new file mode 100644 index 0000000000..e3078bf77a --- /dev/null +++ b/BLAS/SRC/xerblai.f @@ -0,0 +1,94 @@ +*> \brief \b XERBLAI +* +* =========== DOCUMENTATION =========== +* +* Online html documentation available at +* http://www.netlib.org/lapack/explore-html/ +* +* Definition: +* =========== +* +* SUBROUTINE XERBLAI( SRNAME, INFO, INDX ) +* +* .. Scalar Arguments .. +* CHARACTER*(*) SRNAME +* INTEGER INFO, INDX +* .. +* +* +*> \par Purpose: +* ============= +*> +*> \verbatim +*> +*> XERBLAI is an error handler for the BLAS/LAPACK routines. +*> It is called by an BLAS/LAPACK routine if an input parameter has an +*> invalid value. A message is printed and execution stops. +*> +*> Installers may consider modifying the STOP statement in order to +*> call system-specific exception-handling facilities. +*> \endverbatim +* +* Arguments: +* ========== +* +*> \param[in] SRNAME +*> \verbatim +*> SRNAME is CHARACTER*(*) +*> The name of the routine which called XERBLAI. +*> \endverbatim +*> +*> \param[in] INFO +*> \verbatim +*> INFO is INTEGER +*> The position of the invalid parameter in the parameter list +*> of the calling routine. +*> \endverbatim +*> +*> \param[in] INDX +*> \verbatim +*> INDX is INTEGER +*> The position at the invalid parameter in the parameter list +*> of the calling routine. +*> \endverbatim +* +* Authors: +* ======== +* +*> \author Univ. of Tennessee +*> \author Univ. of California Berkeley +*> \author Univ. of Colorado Denver +*> \author NAG Ltd. +*> \author Igor S. Gerasimov +* +*> \ingroup xerblai +* +* ===================================================================== + SUBROUTINE XERBLAI( SRNAME, INFO, INDX ) +* +* -- Reference BLAS level1 routine -- +* -- Reference BLAS is a software package provided by Univ. of Tennessee, -- +* -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..-- +* +* .. Scalar Arguments .. + CHARACTER*(*) SRNAME + INTEGER INFO, INDX +* .. +* +* ===================================================================== +* +* .. Intrinsic Functions .. + INTRINSIC LEN_TRIM +* .. +* .. Executable Statements .. +* + WRITE( *, FMT = 9999 )SRNAME( 1:LEN_TRIM( SRNAME ) ), INFO, INDX +* + STOP +* + 9999 FORMAT( ' ** On entry to ', A, ' parameter number ', I2, ' had ', + $ 'an illegal value at position ', I0, '.' ) +* +* End of XERBLAI +* + END diff --git a/BLAS/SRC/xerblai_array.f b/BLAS/SRC/xerblai_array.f new file mode 100644 index 0000000000..d5dffb27f7 --- /dev/null +++ b/BLAS/SRC/xerblai_array.f @@ -0,0 +1,127 @@ +*> \brief \b XERBLAI_ARRAY +* +* =========== DOCUMENTATION =========== +* +* Online html documentation available at +* http://www.netlib.org/lapack/explore-html/ +* +* Definition: +* =========== +* +* SUBROUTINE XERBLAI_ARRAY(SRNAME_ARRAY, SRNAME_LEN, INFO, INDX) +* +* .. Scalar Arguments .. +* INTEGER SRNAME_LEN, INFO, INDX +* .. +* .. Array Arguments .. +* CHARACTER(1) SRNAME_ARRAY(SRNAME_LEN) +* .. +* +* +*> \par Purpose: +* ============= +*> +*> \verbatim +*> +*> XERBLAI_ARRAY assists other languages in calling XERBLAI, the LAPACK +*> and BLAS error handler. Rather than taking a Fortran string argument +*> as the function's name, XERBLAI_ARRAY takes an array of single +*> characters along with the array's length. XERBLAI_ARRAY then copies +*> up to 32 characters of that array into a Fortran string and passes +*> that to XERBLAI. If called with a non-positive SRNAME_LEN, +*> XERBLAI_ARRAY will call XERBLAI with a string of all blank characters. +*> +*> Say some macro or other device makes XERBLAI_ARRAY available to C99 +*> by a name lapack_xerbla and with a common Fortran calling convention. +*> Then a C99 program could invoke XERBLAI via: +*> { +*> int flen = strlen(__func__); +*> lapack_xerblai(__func__, &flen, &info, &indx); +*> } +*> +*> Providing XERBLAI_ARRAY is not necessary for intercepting LAPACK +*> errors. XERBLAI_ARRAY calls XERBLAI. +*> \endverbatim +* +* Arguments: +* ========== +* +*> \param[in] SRNAME_ARRAY +*> \verbatim +*> SRNAME_ARRAY is CHARACTER(1) array, dimension (SRNAME_LEN) +*> The name of the routine which called XERBLAI_ARRAY. +*> \endverbatim +*> +*> \param[in] SRNAME_LEN +*> \verbatim +*> SRNAME_LEN is INTEGER +*> The length of the name in SRNAME_ARRAY. +*> \endverbatim +*> +*> \param[in] INFO +*> \verbatim +*> INFO is INTEGER +*> The position of the invalid parameter in the parameter list +*> of the calling routine. +*> \endverbatim +*> +*> \param[in] INDX +*> \verbatim +*> INDX is INTEGER +*> The position at the invalid parameter in the parameter list +*> of the calling routine. +*> \endverbatim +* +* Authors: +* ======== +* +*> \author Univ. of Tennessee +*> \author Univ. of California Berkeley +*> \author Univ. of Colorado Denver +*> \author NAG Ltd. +*> \author Igor S. Gerasimov +* +*> \ingroup xerbla_array +* +* ===================================================================== + SUBROUTINE XERBLAI_ARRAY(SRNAME_ARRAY, SRNAME_LEN, INFO, INDX) +* +* -- Reference BLAS level1 routine -- +* -- Reference BLAS is a software package provided by Univ. of Tennessee, -- +* -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..-- +* +* .. Scalar Arguments .. + INTEGER SRNAME_LEN, INFO, INDX +* .. +* .. Array Arguments .. + CHARACTER(1) SRNAME_ARRAY(SRNAME_LEN) +* .. +* +* ===================================================================== +* +* .. +* .. Local Scalars .. + INTEGER I +* .. +* .. Local Arrays .. + CHARACTER*32 SRNAME +* .. +* .. Intrinsic Functions .. + INTRINSIC MIN, LEN +* .. +* .. External Functions .. + EXTERNAL XERBLAI +* .. +* .. Executable Statements .. + SRNAME = ' ' + DO I = 1, MIN( SRNAME_LEN, LEN( SRNAME ) ) + SRNAME( I:I ) = SRNAME_ARRAY( I ) + END DO + + CALL XERBLAI( SRNAME, INFO, INDX ) + + RETURN +* +* End of XERBLAI_ARRAY +* + END diff --git a/BLAS/SRC/zgemm_batch.f90 b/BLAS/SRC/zgemm_batch.f90 new file mode 100644 index 0000000000..c89757b6d4 --- /dev/null +++ b/BLAS/SRC/zgemm_batch.f90 @@ -0,0 +1,372 @@ +!> \brief \b ZGEMM_BATCH +! +! =========== DOCUMENTATION =========== +! +! Online html documentation available at +! http://www.netlib.org/lapack/explore-html/ +! +! Definition: +! =========== +! +! SUBROUTINE ZGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, +! M_ARRAY, N_ARRAY, K_ARRAY, +! ALPHA_ARRAY, +! A_ARRAY, LDA_ARRAY, +! B_ARRAY, LDB_ARRAY, +! BETA_ARRAY, +! C_ARRAY, LDC_ARRAY, +! GROUP_COUNT, GROUP_SIZE) +! +! .. Scalar Arguments .. +! INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. +! CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) +! INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) +! COMPLEX*16 ALPHA_ARRAY(GROUP_COUNT),BETA_ARRAY(GROUP_COUNT) +! INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) +! INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. +! TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! +!> \par Purpose: +! ============= +!> +!> \verbatim +!> +!> ZGEMM_BATCH performs a series of the matrix-matrix operations with each ji'th matrix: +!> +!> C_ji := alpha_i*op_i( A_ji )*op( B_ji ) + beta_i*C_ji, +!> +!> where op_i( X ) is one of +!> +!> op_i( X_ji ) = X_ji or op_i( X_ji ) = X_ji**T, +!> +!> alpha_i and beta_i are scalars, and A_ji, B_ji and C_ji are matrices, with op_i( A_ji ) +!> an m_i by k_i matrix, op_i( B_ji ) a k_i by n_i matrix and C_ji an m_i by n_i matrix. +!> Group count defines i and group_size(i) defines j. +!> +!> More generally, +!> +!> idx = 1 +!> for i in 1..group_count +!> alpha, beta = alpha(i), beta(i) +!> for j in 1..group_size(i) +!> A, B, C = A_ARRAY(idx), B_ARRAY(idx), C_ARRAY(idx) +!> C := alpha*op(A)*op(B) + beta*C +!> idx = idx + 1 +!> +!> +!> \endverbatim +! +! Arguments: +! ========== +! +!> \param[in] TRANSA_ARRAY +!> \verbatim +!> TRANSA_ARRAY is CHARACTER*1 array +!> On entry, TRANSA_ARRAY(i) specifies the form of op_i( A_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSA_ARRAY(i) = 'N' or 'n', op_i( A_ji ) = A_ji. +!> +!> TRANSA_ARRAY(i) = 'T' or 't', op_i( A_ji ) = A_ji**T. +!> +!> TRANSA_ARRAY(i) = 'C' or 'c', op_i( A_ji ) = A_ji**H. +!> \endverbatim +!> +!> \param[in] TRANSB_ARRAY +!> \verbatim +!> TRANSB_ARRAY is CHARACTER*1 array +!> On entry, TRANSB_ARRAY(i) specifies the form of op_i( B_ji ) to be used in +!> the matrix multiplication as follows: +!> +!> TRANSB_ARRAY(i) = 'N' or 'n', op_i( B_ji ) = B_ji. +!> +!> TRANSB_ARRAY(i) = 'T' or 't', op_i( B_ji ) = B_ji**T. +!> +!> TRANSB_ARRAY(i) = 'C' or 'c', op_i( B_ji ) = B_ji**H. +!> \endverbatim +!> +!> \param[in] M_ARRAY +!> \verbatim +!> M_ARRAY is INTEGER array +!> On entry, M_ARRAY(i) specifies the number of rows of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes C_ji. +!> Each M_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] N_ARRAY +!> \verbatim +!> N_ARRAY is INTEGER array +!> On entry, N_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( B_ji ) and the number of columns of the matrixes C_ji. +!> Each N_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] K_ARRAY +!> \verbatim +!> K_ARRAY is INTEGER array +!> On entry, K_ARRAY(i) specifies the number of columns of the matrixes +!> op_i( A_ji ) and the number of rows of the matrixes op_i( B_ji ). +!> Each K_ARRAY(i) must be at least zero. +!> \endverbatim +!> +!> \param[in] ALPHA_ARRAY +!> \verbatim +!> ALPHA_ARRAY is COMPLEX*16 array. +!> On entry, ALPHA_ARRAY(i) specifies the scalar alpha_i. +!> \endverbatim +!> +!> \param[in] A_ARRAY +!> \verbatim +!> A_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to COMPLEX*16 arrays, dimension ( LDA_i, ka_i ), +!> where ka_i is k_i when TRANSA(i) = 'N' or 'n', and is m_i otherwise. +!> Before entry with TRANSA = 'N' or 'n', the leading m_i by k_i elements +!> at address A(ji) must contain the matrix A_ji, otherwise +!> the leading k_i by m_i elements at address A(ji) must contain the +!> matrix A_ji. +!> \endverbatim +!> +!> \param[in] LDA_ARRAY +!> \verbatim +!> LDA_ARRAY is INTEGER array +!> On entry, LDA_ARRAY(i) specifies the first dimension of A_ji as declared +!> in the calling (sub) program. When TRANSA = 'N' or 'n' then +!> LDA_ARRAY(i) must be at least max( 1, m_i ), otherwise LDA must be at +!> least max( 1, k_i ). +!> \endverbatim +!> +!> \param[in] B_ARRAY +!> \verbatim +!> B_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to COMPLEX*16 arrays, dimension ( LDB_i, kb_i ), +!> where kb_i is n_i when TRANSB(i) = 'N' or 'n', and is k_i otherwise. +!> Before entry with TRANSB = 'N' or 'n', the leading k_i by n_i elements +!> at address B(ji) must contain the matrix B_ji, otherwise +!> the leading n_i by k_i elements at address B(ji) must contain the +!> matrix B_ji. +!> \endverbatim +!> +!> \param[in] LDB_ARRAY +!> \verbatim +!> LDB_ARRAY is INTEGER array +!> On entry, LDB_ARRAY(i) specifies the first dimension of B_ji as declared +!> in the calling (sub) program. When TRANSB = 'N' or 'n' then +!> LDB must be at least max( 1, k_i ), otherwise LDB must be at +!> least max( 1, n_i ). +!> \endverbatim +!> +!> \param[in] BETA_ARRAY +!> \verbatim +!> BETA_ARRAY is COMPLEX*16 array. +!> On entry, BETA_ARRAY(i) specifies the scalar beta. When BETA_ARRAY(i) is +!> supplied as zero then C_ji need not be set on input. +!> \endverbatim +!> +!> \param[in,out] C_ARRAY +!> \verbatim +!> C_ARRAY is POINTER array, dimension ( sum( GROUP_SIZE ) ), +!> to COMPLEX*16 arrays, dimension ( LDC_i, n_i ). +!> Before entry, the leading m_i by n_i elements +!> at address C(ji) must contain the matrix C_ji, except when BETA_ARRAY(i) +!> is zero, in which case C_ji need not be set on entry. +!> On exit, the array C_ji is overwritten by the m_i by n_i matrix +!> ( alpha_i*op_i( A_ji )*op_i( B_ji ) + beta_i*C_ji ). +!> \endverbatim +!> +!> \param[in] LDC_ARRAY +!> \verbatim +!> LDC_ARRAY is INTEGER array +!> On entry, LDC_ARRAY(i) specifies the first dimension of C_ji as declared +!> in the calling (sub) program. LDC_ARRAY(i) must be at least +!> max( 1, m_i ). +!> \endverbatim +!> +!> \param[in] GROUP_COUNT +!> \verbatim +!> GROUP_COUNT is INTEGER +!> On entry, GROUP_COUNT specifies the number of groups that determines index i. +!> \endverbatim +!> +!> \param[in] GROUP_SIZE +!> \verbatim +!> GROUP_SIZE is INTEGER array +!> On entry, GROUP_SIZE specifies the number of elements in each groups that determines index j. +!> \endverbatim +! +! Authors: +! ======== +! +!> \author Igor S. Gerasimov +! +!> \ingroup gemm_batch +! +!> \par Further Details: +! ===================== +!> +!> \verbatim +!> +!> Level 3 Blas routine. +!> +!> Original API is taken from: +!> https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2023-2/gemm-batch.html +!> +!> -- Written on 23-October-2023. +!> +!> \endverbatim +!> +! ===================================================================== + SUBROUTINE ZGEMM_BATCH(TRANSA_ARRAY, TRANSB_ARRAY, & + M_ARRAY, N_ARRAY, K_ARRAY, & + ALPHA_ARRAY, & + A_ARRAY, LDA_ARRAY, & + B_ARRAY, LDB_ARRAY, & + BETA_ARRAY, & + C_ARRAY, LDC_ARRAY, & + GROUP_COUNT, GROUP_SIZE) + USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_PTR, C_F_POINTER, C_ASSOCIATED +! +! -- Reference BLAS level3 routine -- +! -- Reference BLAS is a software package provided by Univ. of Tennessee, -- +! -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..-- +! +! .. Scalar Arguments .. + INTEGER GROUP_COUNT +! .. +! .. Array Arguments .. + CHARACTER TRANSA_ARRAY(GROUP_COUNT), TRANSB_ARRAY(GROUP_COUNT) + INTEGER M_ARRAY(GROUP_COUNT), N_ARRAY(GROUP_COUNT), K_ARRAY(GROUP_COUNT) + COMPLEX*16 ALPHA_ARRAY(GROUP_COUNT), BETA_ARRAY(GROUP_COUNT) + INTEGER LDA_ARRAY(GROUP_COUNT), LDB_ARRAY(GROUP_COUNT), LDC_ARRAY(GROUP_COUNT) + INTEGER GROUP_SIZE(GROUP_COUNT) +! .. +! .. Pointer Arguments .. + TYPE(C_PTR) A_ARRAY(*), B_ARRAY(*), C_ARRAY(*) +! .. +! +! ===================================================================== +! +! .. External Functions .. + LOGICAL LSAME + EXTERNAL LSAME +! .. +! .. External Subroutines .. + EXTERNAL XERBLA + EXTERNAL XERBLAI +! .. +! .. Intrinsic Functions .. + INTRINSIC MAX +! .. +! .. Local Scalars .. + INTEGER I, J, IDX, INFO + LOGICAL NOTA, NOTB + INTEGER NROWA, NROWB +! .. +! .. Local Addresses .. + COMPLEX*16, POINTER :: A, B, C +! .. +! +! Test the input parameters. +! + INFO = 0 + IF (GROUP_COUNT.LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLA('ZGEMM_BATCH ', INFO) + RETURN + END IF + DO I = 1, GROUP_COUNT + INFO = 0 +! +! Set NOTA and NOTB as true if A and B respectively are not +! transposed and set NROWA and NROWB as the number of rows of A +! and B respectively. +! + NOTA = LSAME(TRANSA_ARRAY(I),'N') + NOTB = LSAME(TRANSB_ARRAY(I),'N') + IF (NOTA) THEN + NROWA = M_ARRAY(I) + ELSE + NROWA = K_ARRAY(I) + END IF + IF (NOTB) THEN + NROWB = K_ARRAY(I) + ELSE + NROWB = N_ARRAY(I) + END IF + IF ((.NOT.NOTA) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSA_ARRAY(I),'T'))) THEN + INFO = 1 + ELSE IF ((.NOT.NOTB) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'C')) .AND. & + (.NOT.LSAME(TRANSB_ARRAY(I),'T'))) THEN + INFO = 2 + ELSE IF (M_ARRAY(I).LT.0) THEN + INFO = 3 + ELSE IF (N_ARRAY(I).LT.0) THEN + INFO = 4 + ELSE IF (K_ARRAY(I).LT.0) THEN + INFO = 5 + ELSE IF (LDA_ARRAY(I).LT.MAX(1,NROWA)) THEN + INFO = 8 + ELSE IF (LDB_ARRAY(I).LT.MAX(1,NROWB)) THEN + INFO = 10 + ELSE IF (LDC_ARRAY(I).LT.MAX(1,M_ARRAY(I))) THEN + INFO = 13 + ELSE IF (GROUP_SIZE(I).LT.0) THEN + INFO = 15 + END IF + IF (INFO.NE.0) THEN + CALL XERBLAI('ZGEMM_BATCH ',INFO,I) + RETURN + END IF + END DO + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + INFO = 0 + IF (.NOT.C_ASSOCIATED(A_ARRAY(IDX))) THEN + INFO = 7 + ELSE IF (.NOT.C_ASSOCIATED(B_ARRAY(IDX))) THEN + INFO = 9 + ELSE IF (.NOT.C_ASSOCIATED(C_ARRAY(IDX))) THEN + INFO = 12 + END IF + IF (INFO.NE.0) THEN + CALL XERBLAI('ZGEMM_BATCH ',INFO,IDX) + RETURN + END IF + IDX = IDX + 1 + END DO + END DO +! +! Do computations. +! + IDX = 1 + DO I = 1, GROUP_COUNT + DO J = 1, GROUP_SIZE(I) + CALL C_F_POINTER(A_ARRAY(IDX), A) + CALL C_F_POINTER(B_ARRAY(IDX), B) + CALL C_F_POINTER(C_ARRAY(IDX), C) + CALL ZGEMM(TRANSA_ARRAY(I), TRANSB_ARRAY(I), & + M_ARRAY(I), N_ARRAY(I), K_ARRAY(I), & + ALPHA_ARRAY(I), & + A, LDA_ARRAY(I), & + B, LDB_ARRAY(I), & + BETA_ARRAY(I), & + C, LDC_ARRAY(I)) + IDX = IDX + 1 + END DO + END DO + RETURN +! +! End of ZGEMM_BATCH. +! + END