Skip to content

Commit

Permalink
Started testing gemm, need to work out some stuff.
Browse files Browse the repository at this point in the history
  • Loading branch information
corbett5 committed May 10, 2023
1 parent 3165719 commit 3725d1a
Show file tree
Hide file tree
Showing 13 changed files with 695 additions and 154 deletions.
6 changes: 1 addition & 5 deletions scripts/uberenv/packages/lvarray/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,10 @@ class Lvarray(CMakePackage, CudaPackage):
variant('addr2line', default=True,
description='Build support for addr2line.')

<<<<<<< HEAD
variant('tpl_build_type', default='none', description='TPL build type',
values=('Debug', 'Release', 'RelWithDebInfo', 'MinSizeRel', 'none'))


# conflicts('~lapack', when='+magma')
=======
conflicts('~lapack', when='+magma')
>>>>>>> cde43f2 (Building and compiling with MAGMA. GPU not yet working, think it's something to do with the new workspaces.)

depends_on('[email protected]:', when='@0.2.0:', type='build')

Expand Down Expand Up @@ -114,6 +109,7 @@ class Lvarray(CMakePackage, CudaPackage):
depends_on('umpire build_type={}'.format(bt))
depends_on('chai build_type={}'.format(bt), when='+chai')
depends_on('caliper build_type={}'.format(bt), when='+caliper')
depends_on('magma build_type={}'.format(bt), when='+magma')

phases = ['hostconfig', 'cmake', 'build', 'install']

Expand Down
10 changes: 5 additions & 5 deletions scripts/uberenv/spack_configs/toss_4_x86_64_ib/packages.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ packages:
target: [default]
compiler: [gcc, clang, intel]
providers:
blas: [intel-mkl]
lapack: [intel-mkl]
blas: [intel-oneapi-mkl]
lapack: [intel-oneapi-mkl]

intel-mkl:
intel-oneapi-mkl:
buildable: False
externals:
- spec: intel-mkl@2020.0.166 threads=openmp
prefix: /usr/tce/packages/mkl/mkl-2020.0/
- spec: intel-oneapi-mkl@2022.1.0
prefix: /usr/tce/backend/installations/linux-rhel8-x86_64/intel-19.0.4/intel-oneapi-mkl-2022.1.0-sksz67twjxftvwchnagedk36gf7plkrp/

cmake:
buildable: False
Expand Down
210 changes: 210 additions & 0 deletions src/dense/BlasLapackInterface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#include "BlasLapackInterface.hpp"
#include "backendHelpers.hpp"

extern "C"
{

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_SGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( sgemm )
void LVARRAY_SGEMM(
char const * TRANSA,
char const * TRANSB,
int const * M,
int const * N,
int const * K,
float const * ALPHA,
float const * A,
int const * LDA,
float const * B,
int const * LDB,
float const * BETA,
float * C,
int const * LDC );

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_DGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( dgemm )
void LVARRAY_DGEMM(
char const * TRANSA,
char const * TRANSB,
int const * M,
int const * N,
int const * K,
double const * ALPHA,
double const * A,
int const * LDA,
double const * B,
int const * LDB,
double const * BETA,
double * C,
int const * LDC );

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_CGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( cgemm )
void LVARRAY_CGEMM(
char const * TRANSA,
char const * TRANSB,
int const * M,
int const * N,
int const * K,
std::complex< float > const * ALPHA,
std::complex< float > const * A,
int const * LDA,
std::complex< float > const * B,
int const * LDB,
std::complex< float > const * BETA,
std::complex< float > * C,
int const * LDC );

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_ZGEMM LVARRAY_LAPACK_FORTRAN_MANGLE( zgemm )
void LVARRAY_ZGEMM(
char const * TRANSA,
char const * TRANSB,
int const * M,
int const * N,
int const * K,
std::complex< double > const * ALPHA,
std::complex< double > const * A,
int const * LDA,
std::complex< double > const * B,
int const * LDB,
std::complex< double > const * BETA,
std::complex< double > * C,
int const * LDC );

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_SGESV LVARRAY_LAPACK_FORTRAN_MANGLE( sgesv )
void LVARRAY_SGESV(
int const * N,
int const * NRHS,
float * A,
int const * LDA,
int * IPIV,
float * B,
int const * LDB,
int * INFO );

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_DGESV LVARRAY_LAPACK_FORTRAN_MANGLE( dgesv )
void LVARRAY_DGESV(
int const * N,
int const * NRHS,
double * A,
int const * LDA,
int * IPIV,
double * B,
int const * LDB,
int * INFO );

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_CGESV LVARRAY_LAPACK_FORTRAN_MANGLE( cgesv )
void LVARRAY_CGESV(
int const * N,
int const * NRHS,
std::complex< float > * A,
int const * LDA,
int * IPIV,
std::complex< float > * B,
int const * LDB,
int * INFO );

////////////////////////////////////////////////////////////////////////////////////////////////////
#define LVARRAY_ZGESV LVARRAY_LAPACK_FORTRAN_MANGLE( zgesv )
void LVARRAY_ZGESV(
int const * N,
int const * NRHS,
std::complex< double > * A,
int const * LDA,
int * IPIV,
std::complex< double > * B,
int const * LDB,
int * INFO );

} // extern "C"

namespace LvArray
{
namespace dense
{

char toLapackChar( Operation const op )
{
if( op == Operation::NO_OP ) return 'N';
if( op == Operation::TRANSPOSE ) return 'T';
if( op == Operation::ADJOINT ) return 'C';

LVARRAY_ERROR( "Unknown operation: " << int( op ) );
return '\0';
}


template< typename T >
void BlasLapackInterface< T >::gemm(
Operation opA,
Operation opB,
T const alpha,
Matrix< T const > const & A,
Matrix< T const > const & B,
T const beta,
Matrix< T > const & C )
{
char const TRANSA = toLapackChar( opA );
char const TRANSB = toLapackChar( opB );
int const M = C.sizes[ 0 ];
int const N = C.sizes[ 1 ];
int const K = opA == Operation::NO_OP ? A.sizes[ 1 ] : A.sizes[ 0 ];
int const LDA = std::max( std::ptrdiff_t{ 1 }, A.strides[ 1 ] );
int const LDB = std::max( std::ptrdiff_t{ 1 }, B.strides[ 1 ] );
int const LDC = std::max( std::ptrdiff_t{ 1 }, C.strides[ 1 ] );

TypeDispatch< T >::dispatch( LVARRAY_SGEMM, LVARRAY_DGEMM, LVARRAY_CGEMM, LVARRAY_ZGEMM,
&TRANSA,
&TRANSB,
&M,
&N,
&K,
&alpha,
A.data,
&LDA,
B.data,
&LDB,
&beta,
C.data,
&LDC );
}


template< typename T >
void BlasLapackInterface< T >::gesv(
Matrix< T > const & A,
Matrix< T > const & B,
Vector< int > const & pivots )
{
int const N = A.sizes[ 0 ];
int const NRHS = B.sizes[ 1 ];
int const LDA = A.strides[ 1 ];
int const LDB = B.strides[ 1 ];
int INFO = 0;

TypeDispatch< T >::dispatch( LVARRAY_SGESV, LVARRAY_DGESV, LVARRAY_CGESV, LVARRAY_ZGESV,
&N,
&NRHS,
A.data,
&LDA,
pivots.data,
B.data,
&LDB,
&INFO );

LVARRAY_ERROR_IF( INFO < 0, "The " << -INFO << "-th argument had an illegal value." );
LVARRAY_ERROR_IF( INFO > 0, "The factorization has been completed but U( " << INFO - 1 << ", " << INFO - 1 <<
" ) is exactly zero so the solution could not be computed." );
}

template class BlasLapackInterface< float >;
template class BlasLapackInterface< double >;
template class BlasLapackInterface< std::complex< float > >;
template class BlasLapackInterface< std::complex< double > >;

} // namespace dense
} // namespace LvArray
31 changes: 31 additions & 0 deletions src/dense/BlasLapackInterface.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include "common.hpp"

namespace LvArray
{
namespace dense
{

template< typename T >
struct BlasLapackInterface
{
static constexpr MemorySpace MEMORY_SPACE = MemorySpace::host;

static void gemm(
Operation opA,
Operation opB,
T const alpha,
Matrix< T const > const & A,
Matrix< T const > const & B,
T const beta,
Matrix< T > const & C );

static void gesv(
Matrix< T > const & A,
Matrix< T > const & B,
Vector< int > const & pivots );
};

} // namespace dense
} // namespace LvArray
6 changes: 2 additions & 4 deletions src/dense/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
set( lvarraydense_headers
common.hpp
backendHelpers.hpp
eigenDecomposition.hpp
linearSolve.hpp
BlasLapackInterface.hpp
)

set( lvarraydense_sources
common.cpp
eigenDecomposition.cpp
linearSolve.cpp
BlasLapackInterface.cpp
)

set( dependencies lvarray ${lvarray_dependencies} blas lapack )
Expand Down
78 changes: 74 additions & 4 deletions src/dense/backendHelpers.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,82 @@
#pragma once

#if defined( LVARRAY_USE_MAGMA )
#include <magma.h>
#endif
#include <complex>

/// This macro provide a flexible interface for Fortran naming convention for compiled objects
// #ifdef FORTRAN_MANGLE_NO_UNDERSCORE
#define LVARRAY_LAPACK_FORTRAN_MANGLE( name ) name
// #else
// #define LVARRAY_LAPACK_FORTRAN_MANGLE( name ) name ## _
// #endif
// #endif

namespace LvArray
{
namespace dense
{

template< typename T >
struct TypeDispatch
{};

template<>
struct TypeDispatch< float >
{
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
static constexpr auto dispatch(
F_FLOAT && fFloat,
F_DOUBLE &&,
F_CFLOAT &&,
F_CDOUBLE &&,
ARGS && ... args )
{
return fFloat( std::forward< ARGS >( args ) ... );
}
};

template<>
struct TypeDispatch< double >
{
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
static constexpr auto dispatch(
F_FLOAT &&,
F_DOUBLE && fDouble,
F_CFLOAT &&,
F_CDOUBLE &&,
ARGS && ... args )
{
return fDouble( std::forward< ARGS >( args ) ... );
}
};

template<>
struct TypeDispatch< std::complex< float > >
{
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
static constexpr auto dispatch(
F_FLOAT &&,
F_DOUBLE &&,
F_CFLOAT && fCFloat,
F_CDOUBLE &&,
ARGS && ... args )
{
return fCFloat( std::forward< ARGS >( args ) ... );
}
};

template<>
struct TypeDispatch< std::complex< double > >
{
template< typename F_FLOAT, typename F_DOUBLE, typename F_CFLOAT, typename F_CDOUBLE, typename ... ARGS >
static constexpr auto dispatch(
F_FLOAT &&,
F_DOUBLE &&,
F_CFLOAT &&,
F_CDOUBLE && fCDouble,
ARGS && ... args )
{
return fCDouble( std::forward< ARGS >( args ) ... );
}
};

} // namespace dense
} // namespace LvArray
Loading

0 comments on commit 3725d1a

Please sign in to comment.