Skip to content

Commit

Permalink
Merged in nonLocalOperatorBugFix (pull request #593)
Browse files Browse the repository at this point in the history
Bug fix in nonLocalOperator when nonLocalEntries = 0

Approved-by: Sambit Das
  • Loading branch information
phanimotamarri authored and dsambit committed May 3, 2024
2 parents 538a294 + 5c74937 commit 1bfaec9
Showing 1 changed file with 141 additions and 127 deletions.
268 changes: 141 additions & 127 deletions src/atom/AtomicCenteredNonLocalOperator.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1642,88 +1642,96 @@ namespace dftfe
& sphericalFunctionKetTimesVectorParFlattened,
const bool flagCopyResultsToMatrix)
{
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
if (d_totalNonLocalEntries > 0)
{
const std::vector<unsigned int> &atomicNumber =
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
const std::vector<unsigned int> atomIdsInProc =
d_atomCenteredSphericalFunctionContainer
->getAtomIdsInCurrentProcess();
if (couplingtype == CouplingStructure::diagonal)
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
{
unsigned int startIndex = 0;
const unsigned int inc = 1;
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
const std::vector<unsigned int> &atomicNumber =
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
const std::vector<unsigned int> atomIdsInProc =
d_atomCenteredSphericalFunctionContainer
->getAtomIdsInCurrentProcess();
if (couplingtype == CouplingStructure::diagonal)
{
const unsigned int atomId = atomIdsInProc[iAtom];
const unsigned int Znum = atomicNumber[atomId];
const unsigned int numberSphericalFunctions =
d_atomCenteredSphericalFunctionContainer
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);
unsigned int startIndex = 0;
const unsigned int inc = 1;
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
{
const unsigned int atomId = atomIdsInProc[iAtom];
const unsigned int Znum = atomicNumber[atomId];
const unsigned int numberSphericalFunctions =
d_atomCenteredSphericalFunctionContainer
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);


for (unsigned int alpha = 0; alpha < numberSphericalFunctions;
alpha++)
{
ValueType nonlocalConstantV = couplingMatrix[startIndex++];
const unsigned int localId =
sphericalFunctionKetTimesVectorParFlattened
.getMPIPatternP2P()
->globalToLocal(
d_sphericalFunctionIdsNumberingMapCurrentProcess
.find(std::make_pair(atomId, alpha))
->second);
if (flagCopyResultsToMatrix)
{
std::transform(
sphericalFunctionKetTimesVectorParFlattened.begin() +
localId * d_numberWaveFunctions,
sphericalFunctionKetTimesVectorParFlattened.begin() +
localId * d_numberWaveFunctions +
d_numberWaveFunctions,
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
d_numberWaveFunctions * alpha,
[&nonlocalConstantV](auto &a) {
return nonlocalConstantV * a;
});
}
else
for (unsigned int alpha = 0;
alpha < numberSphericalFunctions;
alpha++)
{
d_BLASWrapperPtr->xscal(
sphericalFunctionKetTimesVectorParFlattened.begin() +
localId * d_numberWaveFunctions,
nonlocalConstantV,
d_numberWaveFunctions);
ValueType nonlocalConstantV =
couplingMatrix[startIndex++];
const unsigned int localId =
sphericalFunctionKetTimesVectorParFlattened
.getMPIPatternP2P()
->globalToLocal(
d_sphericalFunctionIdsNumberingMapCurrentProcess
.find(std::make_pair(atomId, alpha))
->second);
if (flagCopyResultsToMatrix)
{
std::transform(
sphericalFunctionKetTimesVectorParFlattened
.begin() +
localId * d_numberWaveFunctions,
sphericalFunctionKetTimesVectorParFlattened
.begin() +
localId * d_numberWaveFunctions +
d_numberWaveFunctions,
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
d_numberWaveFunctions * alpha,
[&nonlocalConstantV](auto &a) {
return nonlocalConstantV * a;
});
}
else
{
d_BLASWrapperPtr->xscal(
sphericalFunctionKetTimesVectorParFlattened
.begin() +
localId * d_numberWaveFunctions,
nonlocalConstantV,
d_numberWaveFunctions);
}
}
}
}
}
}
#if defined(DFTFE_WITH_DEVICE)
else
{
if (couplingtype == CouplingStructure::diagonal)
else
{
d_BLASWrapperPtr->stridedBlockScale(
d_numberWaveFunctions,
d_totalNonLocalEntries,
ValueType(1.0),
couplingMatrix.begin(),
sphericalFunctionKetTimesVectorParFlattened.begin());
}
if (couplingtype == CouplingStructure::diagonal)
{
d_BLASWrapperPtr->stridedBlockScale(
d_numberWaveFunctions,
d_totalNonLocalEntries,
ValueType(1.0),
couplingMatrix.begin(),
sphericalFunctionKetTimesVectorParFlattened.begin());
}

if (flagCopyResultsToMatrix)
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
copyFromParallelNonLocalVecToAllCellsVec(
d_numberWaveFunctions,
d_totalNonlocalElems,
d_maxSingleAtomContribution,
sphericalFunctionKetTimesVectorParFlattened.begin(),
d_sphericalFnTimesVectorAllCellsDevice.begin(),
d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice
.begin());
}
if (flagCopyResultsToMatrix)
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
copyFromParallelNonLocalVecToAllCellsVec(
d_numberWaveFunctions,
d_totalNonlocalElems,
d_maxSingleAtomContribution,
sphericalFunctionKetTimesVectorParFlattened.begin(),
d_sphericalFnTimesVectorAllCellsDevice.begin(),
d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice
.begin());
}
#endif
}
}

template <typename ValueType, dftfe::utils::MemorySpace memorySpace>
Expand All @@ -1734,75 +1742,81 @@ namespace dftfe
& sphericalFunctionKetTimesVectorParFlattened,
const bool skipComm)
{
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
if (d_totalNonLocalEntries > 0)
{
const std::vector<unsigned int> atomIdsInProc =
d_atomCenteredSphericalFunctionContainer
->getAtomIdsInCurrentProcess();
const std::vector<unsigned int> &atomicNumber =
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
{
const unsigned int atomId = atomIdsInProc[iAtom];
unsigned int Znum = atomicNumber[atomId];
const unsigned int numberSphericalFunctions =
const std::vector<unsigned int> atomIdsInProc =
d_atomCenteredSphericalFunctionContainer
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);
for (unsigned int alpha = 0; alpha < numberSphericalFunctions;
alpha++)
->getAtomIdsInCurrentProcess();
const std::vector<unsigned int> &atomicNumber =
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
{
const unsigned int id =
d_sphericalFunctionIdsNumberingMapCurrentProcess
.find(std::make_pair(atomId, alpha))
->second;
std::memcpy(sphericalFunctionKetTimesVectorParFlattened.data() +
sphericalFunctionKetTimesVectorParFlattened
.getMPIPatternP2P()
->globalToLocal(id) *
d_numberWaveFunctions,
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
d_numberWaveFunctions * alpha,
d_numberWaveFunctions * sizeof(ValueType));


// d_BLASWrapperPtr->xcopy(
// d_numberWaveFunctions,
// &d_sphericalFnTimesWavefunMatrix[atomId]
// [d_numberWaveFunctions *
// alpha],
// inc,
// sphericalFunctionKetTimesVectorParFlattened.data() +
// sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P()
// ->globalToLocal(id) *d_numberWaveFunctions,
// inc);
const unsigned int atomId = atomIdsInProc[iAtom];
unsigned int Znum = atomicNumber[atomId];
const unsigned int numberSphericalFunctions =
d_atomCenteredSphericalFunctionContainer
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);
for (unsigned int alpha = 0; alpha < numberSphericalFunctions;
alpha++)
{
const unsigned int id =
d_sphericalFunctionIdsNumberingMapCurrentProcess
.find(std::make_pair(atomId, alpha))
->second;
std::memcpy(
sphericalFunctionKetTimesVectorParFlattened.data() +
sphericalFunctionKetTimesVectorParFlattened
.getMPIPatternP2P()
->globalToLocal(id) *
d_numberWaveFunctions,
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
d_numberWaveFunctions * alpha,
d_numberWaveFunctions * sizeof(ValueType));


// d_BLASWrapperPtr->xcopy(
// d_numberWaveFunctions,
// &d_sphericalFnTimesWavefunMatrix[atomId]
// [d_numberWaveFunctions *
// alpha],
// inc,
// sphericalFunctionKetTimesVectorParFlattened.data() +
// sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P()
// ->globalToLocal(id) *d_numberWaveFunctions,
// inc);
}
}
if (!skipComm)
{
sphericalFunctionKetTimesVectorParFlattened
.accumulateAddLocallyOwned(1);
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(
1);
}
}
if (!skipComm)
{
sphericalFunctionKetTimesVectorParFlattened
.accumulateAddLocallyOwned(1);
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(1);
}
}
#if defined(DFTFE_WITH_DEVICE)
else
{
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
copyToDealiiParallelNonLocalVec(
d_numberWaveFunctions,
d_totalNonLocalEntries,
d_sphericalFnTimesWavefunctionMatrix.begin(),
sphericalFunctionKetTimesVectorParFlattened.begin(),
d_sphericalFnIdsParallelNumberingMapDevice.begin());

if (!skipComm)
else
{
sphericalFunctionKetTimesVectorParFlattened
.accumulateAddLocallyOwned(1);
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(1);
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
copyToDealiiParallelNonLocalVec(
d_numberWaveFunctions,
d_totalNonLocalEntries,
d_sphericalFnTimesWavefunctionMatrix.begin(),
sphericalFunctionKetTimesVectorParFlattened.begin(),
d_sphericalFnIdsParallelNumberingMapDevice.begin());

if (!skipComm)
{
sphericalFunctionKetTimesVectorParFlattened
.accumulateAddLocallyOwned(1);
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(
1);
}
}
}
#endif
}
}
template <typename ValueType, dftfe::utils::MemorySpace memorySpace>
void
Expand Down

0 comments on commit 1bfaec9

Please sign in to comment.