diff --git a/src/include/proxy.h b/src/include/proxy.h index eab6930fe..b79aa6600 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -194,7 +194,7 @@ struct ncclProxyProgressState { pthread_t thread; volatile int stop; struct ncclProxyPeer** localPeers; - struct ncclSharedNetComms* netComms[NCCL_MAX_NETDEVS]; + struct ncclSharedNetComms* sharedNetComms[NCCL_MAX_NETDEVS]; struct ncclProxyArgs* active; struct ncclProxyArgs* pool; struct ncclProxyPool* pools; diff --git a/src/transport/net.cc b/src/transport/net.cc index d5a585d42..e99134fda 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -681,10 +681,10 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { // Connect or reuse connection for a netdev/remote rank. - if (progressState->netComms[resources->netDev] == NULL) { - NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks)); + if (progressState->sharedNetComms[resources->netDev] == NULL) { + NCCLCHECK(ncclCalloc(progressState->sharedNetComms + resources->netDev, proxyState->tpnRanks)); } - struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank; + struct ncclSharedNetComms* comms = progressState->sharedNetComms[resources->netDev] + resources->tpRemoteRank; if (comms->sendComm[resources->channelId] == NULL) ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle); resources->netSendComm = comms->sendComm[resources->channelId]; if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++; @@ -831,10 +831,10 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { // Connect or reuse connection for a netdev/remote rank. - if (progressState->netComms[resources->netDev] == NULL) { - NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks)); + if (progressState->sharedNetComms[resources->netDev] == NULL) { + NCCLCHECK(ncclCalloc(progressState->sharedNetComms + resources->netDev, proxyState->tpnRanks)); } - struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank; + struct ncclSharedNetComms* comms = progressState->sharedNetComms[resources->netDev] + resources->tpRemoteProxyRank; if (comms->recvComm[resources->channelId] == NULL) ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle); resources->netRecvComm = comms->recvComm[resources->channelId]; if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++; @@ -981,7 +981,7 @@ static ncclResult_t sendProxyFree(struct ncclProxyConnection* connection, struct if (resources->shared) { NCCLCHECK(sharedNetBuffersDestroy(proxyState, resources->tpLocalRank, 0, connection)); if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { - struct ncclSharedNetComms* comms = proxyState->progressState.netComms[resources->netDev]+resources->tpRemoteRank; + struct ncclSharedNetComms* comms = proxyState->progressState.sharedNetComms[resources->netDev]+resources->tpRemoteRank; comms->sendRefCount[resources->channelId]--; if (comms->sendRefCount[resources->channelId] == 0) NCCLCHECK(proxyState->ncclNet->closeSend(comms->sendComm[resources->channelId])); } else { @@ -1022,7 +1022,7 @@ static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct if (resources->shared) { NCCLCHECK(sharedNetBuffersDestroy(proxyState, resources->tpLocalRank, 1, connection)); if (resources->maxRecvs > 1 && ncclParamNetSharedComms()) { - struct ncclSharedNetComms* comms = proxyState->progressState.netComms[resources->netDev] + resources->tpRemoteProxyRank; + struct ncclSharedNetComms* comms = proxyState->progressState.sharedNetComms[resources->netDev] + resources->tpRemoteProxyRank; comms->recvRefCount[resources->channelId]--; if (comms->recvRefCount[resources->channelId] == 0) NCCLCHECK(proxyState->ncclNet->closeRecv(comms->recvComm[resources->channelId])); } else {