From ecbaf94b5ec94e893442e3197ed1de526cd6e41b Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <frank@sauerburger.com>
Date: Thu, 9 Nov 2017 16:48:17 +0100
Subject: [PATCH] Clean up the code and add null pointer checks

The commit cleans up the code, such that it compiles without any warnings.
Additionally it is ensured that no segmentation faults occurs. There have
been several different types of changes.

 - Unused variables have been removed.
 - A lot of cleanup and documentation has been added in the course of issue #2.
 - A numpy deprecation flag for version 1.12 has been set.
 - Attention has been paid to the difference between PyArrayObject and PyObject
   pointers.
 - Checks to detect memory allocation errors have been added, which would have
   otherwise resulted in segmentation faults.

Closes #7.
---
 sortednpmodule.c | 74 +++++++++++++++++++++++++++++-------------------
 1 file changed, 45 insertions(+), 29 deletions(-)

diff --git a/sortednpmodule.c b/sortednpmodule.c
index 48988af..a95bb7f 100644
--- a/sortednpmodule.c
+++ b/sortednpmodule.c
@@ -1,6 +1,9 @@
 
 #include <stdbool.h>
 #include <Python.h>
+
+#define NPY_NO_DEPRECATED_API NPY_1_12_API_VERSION
+
 #include <numpy/arrayobject.h>
 
 /*
@@ -47,10 +50,11 @@ static PyObject *sortednp_intersect(PyObject *self, PyObject *args) {
         return NULL;
     }
 
-    int nd_a = PyArray_NDIM(a);
-    int nd_b = PyArray_NDIM(b);
+    // Some methods need a PyObject* other nee a PyArrayObject*.
+    PyArrayObject *a_array = (PyArrayObject*) a;
+    PyArrayObject *b_array = (PyArrayObject*) b;
 
-    if (PyArray_NDIM(a) != 1 || PyArray_NDIM(b) != 1) {
+    if (PyArray_NDIM(a_array) != 1 || PyArray_NDIM(b_array) != 1) {
       PyErr_SetString(PyExc_ValueError, "Arguments can not be multi-dimensional.");
       // Reference counter of input arrays have been fixed. It is safe to exit.
       return NULL;
@@ -58,30 +62,35 @@ static PyObject *sortednp_intersect(PyObject *self, PyObject *args) {
 
     // Since the size of the intersection array can not be known in advance we
     // need to create an array of at least the size of the smaller array.
-    npy_intp len_a = PyArray_DIMS(a)[0];
-    npy_intp len_b = PyArray_DIMS(b)[0];
+    npy_intp len_a = PyArray_DIMS(a_array)[0];
+    npy_intp len_b = PyArray_DIMS(b_array)[0];
     npy_intp new_dim[1] = {len_a < len_b ? len_a : len_b};
 
     // Creating the new array sets the reference counter to 1 and passes the
     // ownership of the returned reference to the caller. The method steals the
     // type descriptor, which is why we have to increment its count before
     // calling the method.
-    PyArray_Descr* type = PyArray_DESCR(a);
+    PyArray_Descr* type = PyArray_DESCR(a_array);
     Py_INCREF(type);
-    PyArrayObject *out;
+    PyObject *out;
     out = PyArray_SimpleNewFromDescr(1, new_dim, type);
+    if (out == NULL) {
+        // Probably a memory error occurred.
+        return NULL;
+    }
+    PyArrayObject* out_array = (PyArrayObject*) out;
 
     npy_intp i_a = 0;
     npy_intp i_b = 0;
     npy_intp i_o = 0;
-    double v_a = *((double*) PyArray_GETPTR1(a, i_a));
-    double v_b = *((double*) PyArray_GETPTR1(b, i_b));
+    double v_a = *((double*) PyArray_GETPTR1(a_array, i_a));
+    double v_b = *((double*) PyArray_GETPTR1(b_array, i_b));
 
     // Actual computation of the intersection.
     while (i_a < len_a && i_b < len_b) {
         bool matched = false;
         if (v_a == v_b) {
-          double *t = (double*) PyArray_GETPTR1(out, i_o);
+          double *t = (double*) PyArray_GETPTR1(out_array, i_o);
           *t = v_a;
 
           i_o++;
@@ -90,12 +99,12 @@ static PyObject *sortednp_intersect(PyObject *self, PyObject *args) {
         
         if (v_a < v_b || matched) {
             i_a++;
-            v_a = *((double*) PyArray_GETPTR1(a, i_a));
+            v_a = *((double*) PyArray_GETPTR1(a_array, i_a));
         }
 
         if (v_b < v_a || matched) {
             i_b++;
-            v_b = *((double*) PyArray_GETPTR1(b, i_b));
+            v_b = *((double*) PyArray_GETPTR1(b_array, i_b));
         }
     }
 
@@ -104,7 +113,7 @@ static PyObject *sortednp_intersect(PyObject *self, PyObject *args) {
     PyArray_Dims dims;
     dims.ptr = new_dim;
     dims.len = 1;
-    PyArray_Resize(out, &dims, 0, NPY_CORDER);
+    PyArray_Resize(out_array, &dims, 0, NPY_CORDER);
 
     // Passes ownership of the returned reference to the  caller.
     return out;
@@ -153,10 +162,12 @@ static PyObject *sortednp_merge(PyObject *self, PyObject *args) {
         return NULL;
     }
 
-    int nd_a = PyArray_NDIM(a);
-    int nd_b = PyArray_NDIM(b);
+    // Some methods need a PyObject* other nee a PyArrayObject*.
+    PyArrayObject *a_array = (PyArrayObject*) a;
+    PyArrayObject *b_array = (PyArrayObject*) b;
 
-    if (PyArray_NDIM(a) != 1 || PyArray_NDIM(b) != 1) {
+
+    if (PyArray_NDIM(a_array) != 1 || PyArray_NDIM(b_array) != 1) {
       PyErr_SetString(PyExc_ValueError, "Arguments can not be multi-dimensional.");
       // Reference counter of input arrays have been fixed. It is safe to exit.
       return NULL;
@@ -165,54 +176,59 @@ static PyObject *sortednp_merge(PyObject *self, PyObject *args) {
     // Since the size of the merged array can not be known in advance we
     // need to create an array of at least the size of the concatenation of both
     // arrays.
-    npy_intp len_a = PyArray_DIMS(a)[0];
-    npy_intp len_b = PyArray_DIMS(b)[0];
+    npy_intp len_a = PyArray_DIMS(a_array)[0];
+    npy_intp len_b = PyArray_DIMS(b_array)[0];
     npy_intp new_dim[1] = {len_a + len_b};
 
     // Creating the new array sets the reference counter to 1 and passes the
     // ownership of the returned reference to the caller. The method steals the
     // type descriptor, which is why we have to increment its count before
     // calling the method.
-    PyArray_Descr* type = PyArray_DESCR(a);
+    PyArray_Descr* type = PyArray_DESCR(a_array);
     Py_INCREF(type);
-    PyArrayObject *out;
+    PyObject *out;
     out = PyArray_SimpleNewFromDescr(1, new_dim, type);
+    if (out == NULL) {
+        // Probably a memory error occurred.
+        return NULL;
+    }
+    PyArrayObject* out_array = (PyArrayObject*) out;
 
     npy_intp i_a = 0;
     npy_intp i_b = 0;
     npy_intp i_o = 0;
-    double v_a = *((double*) PyArray_GETPTR1(a, i_a));
-    double v_b = *((double*) PyArray_GETPTR1(b, i_b));
+    double v_a = *((double*) PyArray_GETPTR1(a_array, i_a));
+    double v_b = *((double*) PyArray_GETPTR1(b_array, i_b));
 
     // Actually merging the arrays.
     while (i_a < len_a && i_b < len_b) {
-        double *t = (double*) PyArray_GETPTR1(out, i_o);
+        double *t = (double*) PyArray_GETPTR1(out_array, i_o);
 
         if (v_a < v_b) {
             *t = v_a;
             i_a++;
             i_o++;
-            v_a = *((double*) PyArray_GETPTR1(a, i_a));
+            v_a = *((double*) PyArray_GETPTR1(a_array, i_a));
         } else {
             *t = v_b;
             i_b++;
             i_o++;
-            v_b = *((double*) PyArray_GETPTR1(b, i_b));
+            v_b = *((double*) PyArray_GETPTR1(b_array, i_b));
         }
     }
 
     // If the end of one of the two arrays has been reached in the above loop,
     // we need to copy all the elements left the array to the output.
     while (i_a < len_a) {
-        double v_a = *((double*) PyArray_GETPTR1(a, i_a));
-        double *t = (double*) PyArray_GETPTR1(out, i_o);
+        double v_a = *((double*) PyArray_GETPTR1(a_array, i_a));
+        double *t = (double*) PyArray_GETPTR1(out_array, i_o);
         *t = v_a;
         i_a++;
         i_o++;
     }
     while (i_b < len_b) {
-        double v_b = *((double*) PyArray_GETPTR1(b, i_b));
-        double *t = (double*) PyArray_GETPTR1(out, i_o);
+        double v_b = *((double*) PyArray_GETPTR1(b_array, i_b));
+        double *t = (double*) PyArray_GETPTR1(out_array, i_o);
         *t = v_b;
         i_b++;
         i_o++;
-- 
GitLab