diff --git a/src/atom/AtomicCenteredNonLocalOperator.t.cc b/src/atom/AtomicCenteredNonLocalOperator.t.cc index a7cc505c1..214c394b3 100644 --- a/src/atom/AtomicCenteredNonLocalOperator.t.cc +++ b/src/atom/AtomicCenteredNonLocalOperator.t.cc @@ -1642,88 +1642,96 @@ namespace dftfe & sphericalFunctionKetTimesVectorParFlattened, const bool flagCopyResultsToMatrix) { - if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace) + if (d_totalNonLocalEntries > 0) { - const std::vector &atomicNumber = - d_atomCenteredSphericalFunctionContainer->getAtomicNumbers(); - const std::vector atomIdsInProc = - d_atomCenteredSphericalFunctionContainer - ->getAtomIdsInCurrentProcess(); - if (couplingtype == CouplingStructure::diagonal) + if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace) { - unsigned int startIndex = 0; - const unsigned int inc = 1; - for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++) + const std::vector &atomicNumber = + d_atomCenteredSphericalFunctionContainer->getAtomicNumbers(); + const std::vector atomIdsInProc = + d_atomCenteredSphericalFunctionContainer + ->getAtomIdsInCurrentProcess(); + if (couplingtype == CouplingStructure::diagonal) { - const unsigned int atomId = atomIdsInProc[iAtom]; - const unsigned int Znum = atomicNumber[atomId]; - const unsigned int numberSphericalFunctions = - d_atomCenteredSphericalFunctionContainer - ->getTotalNumberOfSphericalFunctionsPerAtom(Znum); + unsigned int startIndex = 0; + const unsigned int inc = 1; + for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++) + { + const unsigned int atomId = atomIdsInProc[iAtom]; + const unsigned int Znum = atomicNumber[atomId]; + const unsigned int numberSphericalFunctions = + d_atomCenteredSphericalFunctionContainer + ->getTotalNumberOfSphericalFunctionsPerAtom(Znum); - for (unsigned int alpha = 0; alpha < numberSphericalFunctions; - alpha++) - { - ValueType nonlocalConstantV = couplingMatrix[startIndex++]; - const unsigned int localId = - sphericalFunctionKetTimesVectorParFlattened - .getMPIPatternP2P() - ->globalToLocal( - d_sphericalFunctionIdsNumberingMapCurrentProcess - .find(std::make_pair(atomId, alpha)) - ->second); - if (flagCopyResultsToMatrix) - { - std::transform( - sphericalFunctionKetTimesVectorParFlattened.begin() + - localId * d_numberWaveFunctions, - sphericalFunctionKetTimesVectorParFlattened.begin() + - localId * d_numberWaveFunctions + - d_numberWaveFunctions, - d_sphericalFnTimesWavefunMatrix[atomId].begin() + - d_numberWaveFunctions * alpha, - [&nonlocalConstantV](auto &a) { - return nonlocalConstantV * a; - }); - } - else + for (unsigned int alpha = 0; + alpha < numberSphericalFunctions; + alpha++) { - d_BLASWrapperPtr->xscal( - sphericalFunctionKetTimesVectorParFlattened.begin() + - localId * d_numberWaveFunctions, - nonlocalConstantV, - d_numberWaveFunctions); + ValueType nonlocalConstantV = + couplingMatrix[startIndex++]; + const unsigned int localId = + sphericalFunctionKetTimesVectorParFlattened + .getMPIPatternP2P() + ->globalToLocal( + d_sphericalFunctionIdsNumberingMapCurrentProcess + .find(std::make_pair(atomId, alpha)) + ->second); + if (flagCopyResultsToMatrix) + { + std::transform( + sphericalFunctionKetTimesVectorParFlattened + .begin() + + localId * d_numberWaveFunctions, + sphericalFunctionKetTimesVectorParFlattened + .begin() + + localId * d_numberWaveFunctions + + d_numberWaveFunctions, + d_sphericalFnTimesWavefunMatrix[atomId].begin() + + d_numberWaveFunctions * alpha, + [&nonlocalConstantV](auto &a) { + return nonlocalConstantV * a; + }); + } + else + { + d_BLASWrapperPtr->xscal( + sphericalFunctionKetTimesVectorParFlattened + .begin() + + localId * d_numberWaveFunctions, + nonlocalConstantV, + d_numberWaveFunctions); + } } } } } - } #if defined(DFTFE_WITH_DEVICE) - else - { - if (couplingtype == CouplingStructure::diagonal) + else { - d_BLASWrapperPtr->stridedBlockScale( - d_numberWaveFunctions, - d_totalNonLocalEntries, - ValueType(1.0), - couplingMatrix.begin(), - sphericalFunctionKetTimesVectorParFlattened.begin()); - } + if (couplingtype == CouplingStructure::diagonal) + { + d_BLASWrapperPtr->stridedBlockScale( + d_numberWaveFunctions, + d_totalNonLocalEntries, + ValueType(1.0), + couplingMatrix.begin(), + sphericalFunctionKetTimesVectorParFlattened.begin()); + } - if (flagCopyResultsToMatrix) - dftfe::AtomicCenteredNonLocalOperatorKernelsDevice:: - copyFromParallelNonLocalVecToAllCellsVec( - d_numberWaveFunctions, - d_totalNonlocalElems, - d_maxSingleAtomContribution, - sphericalFunctionKetTimesVectorParFlattened.begin(), - d_sphericalFnTimesVectorAllCellsDevice.begin(), - d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice - .begin()); - } + if (flagCopyResultsToMatrix) + dftfe::AtomicCenteredNonLocalOperatorKernelsDevice:: + copyFromParallelNonLocalVecToAllCellsVec( + d_numberWaveFunctions, + d_totalNonlocalElems, + d_maxSingleAtomContribution, + sphericalFunctionKetTimesVectorParFlattened.begin(), + d_sphericalFnTimesVectorAllCellsDevice.begin(), + d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice + .begin()); + } #endif + } } template @@ -1734,75 +1742,81 @@ namespace dftfe & sphericalFunctionKetTimesVectorParFlattened, const bool skipComm) { - if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace) + if (d_totalNonLocalEntries > 0) { - const std::vector atomIdsInProc = - d_atomCenteredSphericalFunctionContainer - ->getAtomIdsInCurrentProcess(); - const std::vector &atomicNumber = - d_atomCenteredSphericalFunctionContainer->getAtomicNumbers(); - for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++) + if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace) { - const unsigned int atomId = atomIdsInProc[iAtom]; - unsigned int Znum = atomicNumber[atomId]; - const unsigned int numberSphericalFunctions = + const std::vector atomIdsInProc = d_atomCenteredSphericalFunctionContainer - ->getTotalNumberOfSphericalFunctionsPerAtom(Znum); - for (unsigned int alpha = 0; alpha < numberSphericalFunctions; - alpha++) + ->getAtomIdsInCurrentProcess(); + const std::vector &atomicNumber = + d_atomCenteredSphericalFunctionContainer->getAtomicNumbers(); + for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++) { - const unsigned int id = - d_sphericalFunctionIdsNumberingMapCurrentProcess - .find(std::make_pair(atomId, alpha)) - ->second; - std::memcpy(sphericalFunctionKetTimesVectorParFlattened.data() + - sphericalFunctionKetTimesVectorParFlattened - .getMPIPatternP2P() - ->globalToLocal(id) * - d_numberWaveFunctions, - d_sphericalFnTimesWavefunMatrix[atomId].begin() + - d_numberWaveFunctions * alpha, - d_numberWaveFunctions * sizeof(ValueType)); - - - // d_BLASWrapperPtr->xcopy( - // d_numberWaveFunctions, - // &d_sphericalFnTimesWavefunMatrix[atomId] - // [d_numberWaveFunctions * - // alpha], - // inc, - // sphericalFunctionKetTimesVectorParFlattened.data() + - // sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P() - // ->globalToLocal(id) *d_numberWaveFunctions, - // inc); + const unsigned int atomId = atomIdsInProc[iAtom]; + unsigned int Znum = atomicNumber[atomId]; + const unsigned int numberSphericalFunctions = + d_atomCenteredSphericalFunctionContainer + ->getTotalNumberOfSphericalFunctionsPerAtom(Znum); + for (unsigned int alpha = 0; alpha < numberSphericalFunctions; + alpha++) + { + const unsigned int id = + d_sphericalFunctionIdsNumberingMapCurrentProcess + .find(std::make_pair(atomId, alpha)) + ->second; + std::memcpy( + sphericalFunctionKetTimesVectorParFlattened.data() + + sphericalFunctionKetTimesVectorParFlattened + .getMPIPatternP2P() + ->globalToLocal(id) * + d_numberWaveFunctions, + d_sphericalFnTimesWavefunMatrix[atomId].begin() + + d_numberWaveFunctions * alpha, + d_numberWaveFunctions * sizeof(ValueType)); + + + // d_BLASWrapperPtr->xcopy( + // d_numberWaveFunctions, + // &d_sphericalFnTimesWavefunMatrix[atomId] + // [d_numberWaveFunctions * + // alpha], + // inc, + // sphericalFunctionKetTimesVectorParFlattened.data() + + // sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P() + // ->globalToLocal(id) *d_numberWaveFunctions, + // inc); + } + } + if (!skipComm) + { + sphericalFunctionKetTimesVectorParFlattened + .accumulateAddLocallyOwned(1); + sphericalFunctionKetTimesVectorParFlattened.updateGhostValues( + 1); } } - if (!skipComm) - { - sphericalFunctionKetTimesVectorParFlattened - .accumulateAddLocallyOwned(1); - sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(1); - } - } #if defined(DFTFE_WITH_DEVICE) - else - { - dftfe::AtomicCenteredNonLocalOperatorKernelsDevice:: - copyToDealiiParallelNonLocalVec( - d_numberWaveFunctions, - d_totalNonLocalEntries, - d_sphericalFnTimesWavefunctionMatrix.begin(), - sphericalFunctionKetTimesVectorParFlattened.begin(), - d_sphericalFnIdsParallelNumberingMapDevice.begin()); - - if (!skipComm) + else { - sphericalFunctionKetTimesVectorParFlattened - .accumulateAddLocallyOwned(1); - sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(1); + dftfe::AtomicCenteredNonLocalOperatorKernelsDevice:: + copyToDealiiParallelNonLocalVec( + d_numberWaveFunctions, + d_totalNonLocalEntries, + d_sphericalFnTimesWavefunctionMatrix.begin(), + sphericalFunctionKetTimesVectorParFlattened.begin(), + d_sphericalFnIdsParallelNumberingMapDevice.begin()); + + if (!skipComm) + { + sphericalFunctionKetTimesVectorParFlattened + .accumulateAddLocallyOwned(1); + sphericalFunctionKetTimesVectorParFlattened.updateGhostValues( + 1); + } } - } #endif + } } template void