diff --git a/numpower.c b/numpower.c index fcc4989..db93b55 100644 --- a/numpower.c +++ b/numpower.c @@ -2894,6 +2894,48 @@ PHP_METHOD(NDArray, variance) { RETURN_NDARRAY(rtn, return_value); } +/** + * NDArray::cov + * + * @param execute_data + * @param return_value + */ +ZEND_BEGIN_ARG_INFO_EX(arginfo_ndarray_cov, 0, 0, 1) +ZEND_ARG_INFO(0, array) +ZEND_ARG_INFO(0, rowvar) +ZEND_END_ARG_INFO() +PHP_METHOD(NDArray, cov) { + NDArray *rtn = NULL; + zval *array; + bool rowvar = true; + ZEND_PARSE_PARAMETERS_START(1, 2) + Z_PARAM_ZVAL(array) + Z_PARAM_OPTIONAL + Z_PARAM_BOOL(rowvar) + ZEND_PARSE_PARAMETERS_END(); + NDArray *nda = ZVAL_TO_NDARRAY(array); + if (nda == NULL) { + return; + } + + if (NDArray_DEVICE(nda) == NDARRAY_DEVICE_CPU) { + rtn = NDArray_cov(nda, rowvar); + } else { +#ifdef HAVE_CUBLAS + rtn = NDArray_cov(nda, rowvar); +#else + zend_throw_error(NULL, "GPU operations unavailable. CUBLAS not detected."); +#endif + } + if (rtn == NULL) { + return; + } + if (Z_TYPE_P(array) == IS_ARRAY) { + NDArray_FREE(nda); + } + RETURN_NDARRAY(rtn, return_value); +} + /** * NDArray::ceil * @@ -5180,6 +5222,7 @@ static const zend_function_entry class_NDArray_methods[] = { ZEND_ME(NDArray, average, arginfo_ndarray_average, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC) ZEND_ME(NDArray, std, arginfo_ndarray_std, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC) ZEND_ME(NDArray, quantile, arginfo_ndarray_quantile, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC) + ZEND_ME(NDArray, cov, arginfo_ndarray_cov, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC) // ARITHMETICS ZEND_ME(NDArray, add, arginfo_ndarray_add, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC) diff --git a/src/ndmath/statistics.c b/src/ndmath/statistics.c index 5bd8aff..a4ab217 100644 --- a/src/ndmath/statistics.c +++ b/src/ndmath/statistics.c @@ -3,6 +3,8 @@ #include "string.h" #include "../initializers.h" #include "arithmetics.h" +#include "../manipulation.h" +#include "linalg.h" // Comparison function for sorting int compare_quantile(const void* a, const void* b) { @@ -151,4 +153,68 @@ NDArray_Average(NDArray *a, NDArray *weights) { NDArray_FREE(m_weights); } return rtn; +} + +/** + * NDArray::cov + * + * @param a + * @return + */ +NDArray *NDArray_cov(NDArray *a, bool rowvar) +{ + if (!rowvar) { + a = NDArray_Transpose(a, NULL); + } + + if (a == NULL || NDArray_NUMELEMENTS(a) == 0) + { + zend_throw_error(NULL, "Input cannot be null or empty."); + return NULL; + } + if (NDArray_NDIM(a) != 2 || NDArray_SHAPE(a)[1] == 1) + { + zend_throw_error(NULL, "Input must be a 2D NDArray."); + return NULL; + } + + int cols = NDArray_SHAPE(a)[0]; + int rows = NDArray_SHAPE(a)[1]; + + int *indices_shape = emalloc(sizeof(int) * 2); + indices_shape[0] = 2; + indices_shape[1] = 1; + + NDArray** indices_axis = emalloc(sizeof(NDArray*) * 2); + indices_axis[0] = NDArray_Zeros(indices_shape, 1, NDArray_TYPE(a), NDArray_DEVICE(a)); + indices_axis[1] = NDArray_Zeros(indices_shape, 1, NDArray_TYPE(a), NDArray_DEVICE(a)); + + NDArray_FDATA(indices_axis[1])[0] = 0; + NDArray_FDATA(indices_axis[1])[1] = rows; + + NDArray **centered_vectors = emalloc(sizeof(NDArray *) * cols); + for (int i = 0; i < cols; i++) + { + NDArray_FDATA(indices_axis[0])[0] = i; + NDArray_FDATA(indices_axis[0])[1] = i + 1; + NDArray *col_vector = NDArray_Slice(a, indices_axis, 2); + NDArray *centered = NDArray_Subtract_Float(col_vector, NDArray_CreateFromFloatScalar(NDArray_Sum_Float(col_vector) / NDArray_NUMELEMENTS(col_vector))); + NDArray_FREE(col_vector); + centered_vectors[i] = centered; + } + efree(indices_shape); + efree(indices_axis[0]); + efree(indices_axis[1]); + efree(indices_axis); + NDArray *centered_a = NDArray_Reshape(NDArray_ConcatenateFlat(centered_vectors, cols), NDArray_SHAPE(a), NDArray_NDIM(a)); + for (int i = 0; i < cols; i++) + { + NDArray_FREE(centered_vectors[i]); + } + efree(centered_vectors); + NDArray *multiplied = NDArray_Dot(centered_a, NDArray_Transpose(centered_a, NULL)); + NDArray_FREE(centered_a); + NDArray *rtn = NDArray_Divide_Float(multiplied, NDArray_CreateFromFloatScalar((float)rows - 1)); + NDArray_FREE(multiplied); + return rtn; } \ No newline at end of file diff --git a/src/ndmath/statistics.h b/src/ndmath/statistics.h index bb4c643..d9c645a 100644 --- a/src/ndmath/statistics.h +++ b/src/ndmath/statistics.h @@ -7,5 +7,6 @@ NDArray* NDArray_Quantile(NDArray *target, NDArray *q); NDArray* NDArray_Std(NDArray *a); NDArray* NDArray_Variance(NDArray *a); NDArray* NDArray_Average(NDArray *a, NDArray *weights); +NDArray* NDArray_cov(NDArray *a, bool rowvar); #endif //NUMPOWER_STATISTICS_H diff --git a/tests/math/048-ndarray-cov.phpt b/tests/math/048-ndarray-cov.phpt new file mode 100644 index 0000000..0c501f5 --- /dev/null +++ b/tests/math/048-ndarray-cov.phpt @@ -0,0 +1,127 @@ +--TEST-- +NDArray::cov +--FILE-- +toArray()); +$b = \NDArray::array([[1, 2, 3, 4], [5, 4, 3, 2]]); +print_r(\NDArray::cov($b)->toArray()); +$c = \NDArray::array([[1, 2, 3, 4], [5, 6, 7, 8]]); +print_r(\NDArray::cov($c)->toArray()); +$d = \NDArray::array([[1, 2, 3, 4], [1, 2, 3, 4]]); +print_r(\NDArray::cov($d)->toArray()); +$e = \NDArray::array([[1, 2, 3, 4]]); +print_r(\NDArray::cov($e)->toArray()); +$f = \NDArray::array([[0, 0, 0, 0], [0, 0, 0, 0]]); +print_r(\NDArray::cov($f)->toArray()); +$g = \NDArray::array([[3, 7, 8], [2, 4, 3]]); +print_r(\NDArray::cov($g, False)->toArray()); +?> +--EXPECT-- +Array +( + [0] => Array + ( + [0] => 7 + [1] => 2 + ) + + [1] => Array + ( + [0] => 2 + [1] => 1 + ) + +) +Array +( + [0] => Array + ( + [0] => 1.6666666269302 + [1] => -1.6666666269302 + ) + + [1] => Array + ( + [0] => -1.6666666269302 + [1] => 1.6666666269302 + ) + +) +Array +( + [0] => Array + ( + [0] => 1.6666666269302 + [1] => 1.6666666269302 + ) + + [1] => Array + ( + [0] => 1.6666666269302 + [1] => 1.6666666269302 + ) + +) +Array +( + [0] => Array + ( + [0] => 1.6666666269302 + [1] => 1.6666666269302 + ) + + [1] => Array + ( + [0] => 1.6666666269302 + [1] => 1.6666666269302 + ) + +) +Array +( + [0] => Array + ( + [0] => 1.6666666269302 + ) + +) +Array +( + [0] => Array + ( + [0] => 0 + [1] => 0 + ) + + [1] => Array + ( + [0] => 0 + [1] => 0 + ) + +) +Array +( + [0] => Array + ( + [0] => 0.5 + [1] => 1.5 + [2] => 2.5 + ) + + [1] => Array + ( + [0] => 1.5 + [1] => 4.5 + [2] => 7.5 + ) + + [2] => Array + ( + [0] => 2.5 + [1] => 7.5 + [2] => 12.5 + ) + +) \ No newline at end of file