Skip to content

Commit

Permalink
Modernize DECREF in py_alltoall_type (psrc).
Browse files Browse the repository at this point in the history
  • Loading branch information
1uc committed Nov 14, 2024
1 parent bbab9a4 commit 42cd898
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions src/nrnpython/nrnpy_p2h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,16 +773,16 @@ static PyObject* py_broadcast(PyObject* psrc, int root) {
// size for 3, 4, 5 refer to rootrank.
static Object* py_alltoall_type(int size, int type) {
int np = nrnmpi_numprocs; // of subworld communicator
PyObject* psrc = NULL;
nb::object psrc;

if (type == 1 || type == 5) { // alltoall, scatter
Object* o = *hoc_objgetarg(1);
if (type == 1 || nrnmpi_myid == size) { // if scatter only root must be a list
psrc = nrnpy_hoc2pyobject(o);
if (!PyList_Check(psrc)) {
psrc = nb::borrow(nrnpy_hoc2pyobject(o));
if (!PyList_Check(psrc.ptr())) {
hoc_execerror("Argument must be a Python list", 0);
}
if (PyList_Size(psrc) != np) {
if (PyList_Size(psrc.ptr()) != np) {
if (type == 1) {
hoc_execerror("py_alltoall list size must be nhost", 0);
} else {
Expand All @@ -794,7 +794,7 @@ static Object* py_alltoall_type(int size, int type) {
if (type == 1) {
return o;
} else { // return psrc[0]
auto pdest = nb::borrow(PyList_GetItem(psrc, 0));
auto pdest = nb::borrow(PyList_GetItem(psrc.ptr(), 0));
Object* ho = nrnpy_po2ho(pdest.ptr());

Check warning on line 798 in src/nrnpython/nrnpy_p2h.cpp

View check run for this annotation

Codecov / codecov/patch

src/nrnpython/nrnpy_p2h.cpp#L797-L798

Added lines #L797 - L798 were not covered by tests
if (ho) {
--ho->refcount;
Expand All @@ -804,16 +804,15 @@ static Object* py_alltoall_type(int size, int type) {
}
} else {
// Get the raw PyObject* arg. So things like None, int, bool are preserved.
psrc = hocobj_call_arg(0);
Py_INCREF(psrc);
psrc = nb::borrow(hocobj_call_arg(0));

if (np == 1) {
nb::object pdest;
if (type == 4) { // broadcast is just the PyObject
pdest = nb::steal(psrc);
pdest = psrc;
} else { // allgather and gather must wrap psrc in list
pdest = nb::steal(PyList_New(1));
PyList_SetItem(pdest.ptr(), 0, psrc);
PyList_SetItem(pdest.ptr(), 0, psrc.release().ptr());
}
Object* ho = nrnpy_po2ho(pdest.ptr());
if (ho) {
Expand All @@ -829,19 +828,17 @@ static Object* py_alltoall_type(int size, int type) {
PyObject* pdest = NULL;

if (type == 2) {
pdest = py_allgather(psrc);
Py_DECREF(psrc);
pdest = py_allgather(psrc.ptr());
} else if (type != 1 && type != 5) {
root = size;
if (root < 0 || root >= np) {
hoc_execerror("root rank must be >= 0 and < nhost", 0);
}
if (type == 3) {
pdest = py_gather(psrc, root);
pdest = py_gather(psrc.ptr(), root);
} else if (type == 4) {
pdest = py_broadcast(psrc, root);
pdest = py_broadcast(psrc.ptr(), root);
}
Py_DECREF(psrc);
} else {
if (type == 5) { // scatter
root = size;
Expand Down

0 comments on commit 42cd898

Please sign in to comment.