Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 75 additions & 4 deletions quaddtype/numpy_quaddtype/src/scalar.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "scalar.h"
#include "scalar_ops.h"
#include "dragon4.h"
#include "dtype.h"

// For IEEE 754 binary128 (quad precision), we need 36 decimal digits
// to guarantee round-trip conversion (string -> parse -> equals original value)
Expand Down Expand Up @@ -42,7 +43,77 @@ QuadPrecision_raw_new(QuadBackendType backend)

QuadPrecisionObject *
QuadPrecision_from_object(PyObject *value, QuadBackendType backend)
{
{
// Handle numpy scalars (np.int32, np.float32, etc.) before arrays
// We need to check this before PySequence_Check because some numpy scalars are sequences
if (PyArray_CheckScalar(value)) {
QuadPrecisionObject *self = QuadPrecision_raw_new(backend);
if (!self)
return NULL;

// Try as floating point first
if (PyArray_IsScalar(value, Floating)) {
PyObject *py_float = PyNumber_Float(value);
if (py_float == NULL) {
Py_DECREF(self);
return NULL;
}
double dval = PyFloat_AsDouble(py_float);
Py_DECREF(py_float);

if (backend == BACKEND_SLEEF) {
self->value.sleef_value = Sleef_cast_from_doubleq1(dval);
}
else {
self->value.longdouble_value = (long double)dval;
}
return self;
}
// Try as integer
else if (PyArray_IsScalar(value, Integer)) {
PyObject *py_int = PyNumber_Long(value);
if (py_int == NULL) {
Py_DECREF(self);
return NULL;
}
long long lval = PyLong_AsLongLong(py_int);
Py_DECREF(py_int);

if (backend == BACKEND_SLEEF) {
self->value.sleef_value = Sleef_cast_from_int64q1(lval);
}
else {
self->value.longdouble_value = (long double)lval;
}
return self;
}
// For other scalar types, fall through to error handling
Py_DECREF(self);
}

// this checks arrays and sequences (array, tuple)
// rejects strings; they're parsed below
if (PyArray_Check(value) || (PySequence_Check(value) && !PyUnicode_Check(value) && !PyBytes_Check(value)))
{
QuadPrecDTypeObject *dtype_descr = new_quaddtype_instance(backend);
if (dtype_descr == NULL) {
return NULL;
}

// steals reference to the descriptor
PyObject *result = PyArray_FromAny(
value,
(PyArray_Descr *)dtype_descr,
0,
0,
NPY_ARRAY_ENSUREARRAY, // this should handle the casting if possible
NULL
);

// PyArray_FromAny steals the reference to dtype_descr, so no need to DECREF
return (QuadPrecisionObject *)result;
}

QuadPrecisionObject *self = QuadPrecision_raw_new(backend);
if (!self)
return NULL;
Expand Down Expand Up @@ -105,21 +176,21 @@ QuadPrecision_from_object(PyObject *value, QuadBackendType backend)
const char *type_cstr = PyUnicode_AsUTF8(type_str);
if (type_cstr != NULL) {
PyErr_Format(PyExc_TypeError,
"QuadPrecision value must be a quad, float, int or string, but got %s "
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got %s "
"instead",
type_cstr);
}
else {
PyErr_SetString(
PyExc_TypeError,
"QuadPrecision value must be a quad, float, int or string, but got an "
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got an "
"unknown type instead");
}
Py_DECREF(type_str);
}
else {
PyErr_SetString(PyExc_TypeError,
"QuadPrecision value must be a quad, float, int or string, but got an "
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got an "
"unknown type instead");
}
Py_DECREF(self);
Expand Down
185 changes: 185 additions & 0 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,191 @@ def test_create_scalar_simple():
assert isinstance(QuadPrecision(1), QuadPrecision)


class TestQuadPrecisionArrayCreation:
"""Test suite for QuadPrecision array creation from sequences and arrays."""

def test_create_array_from_list(self):
"""Test that QuadPrecision can create arrays from lists."""
# Test with simple list
result = QuadPrecision([3, 4, 5])
assert isinstance(result, np.ndarray)
assert result.dtype.name == "QuadPrecDType128"
assert result.shape == (3,)
np.testing.assert_array_equal(result, np.array([3, 4, 5], dtype=QuadPrecDType(backend='sleef')))

# Test with float list
result = QuadPrecision([1.5, 2.5, 3.5])
assert isinstance(result, np.ndarray)
assert result.dtype.name == "QuadPrecDType128"
assert result.shape == (3,)
np.testing.assert_array_equal(result, np.array([1.5, 2.5, 3.5], dtype=QuadPrecDType(backend='sleef')))

def test_create_array_from_tuple(self):
"""Test that QuadPrecision can create arrays from tuples."""
result = QuadPrecision((10, 20, 30))
assert isinstance(result, np.ndarray)
assert result.dtype.name == "QuadPrecDType128"
assert result.shape == (3,)
np.testing.assert_array_equal(result, np.array([10, 20, 30], dtype=QuadPrecDType(backend='sleef')))

def test_create_array_from_ndarray(self):
"""Test that QuadPrecision can create arrays from numpy arrays."""
arr = np.array([1, 2, 3, 4])
result = QuadPrecision(arr)
assert isinstance(result, np.ndarray)
assert result.dtype.name == "QuadPrecDType128"
assert result.shape == (4,)
np.testing.assert_array_equal(result, arr.astype(QuadPrecDType(backend='sleef')))

def test_create_2d_array_from_nested_list(self):
"""Test that QuadPrecision can create 2D arrays from nested lists."""
result = QuadPrecision([[1, 2], [3, 4]])
assert isinstance(result, np.ndarray)
assert result.dtype.name == "QuadPrecDType128"
assert result.shape == (2, 2)
expected = np.array([[1, 2], [3, 4]], dtype=QuadPrecDType(backend='sleef'))
np.testing.assert_array_equal(result, expected)

def test_create_array_with_backend(self):
"""Test that QuadPrecision respects backend parameter for arrays."""
# Test with sleef backend (default)
result_sleef = QuadPrecision([1, 2, 3], backend='sleef')
assert isinstance(result_sleef, np.ndarray)
assert result_sleef.dtype == QuadPrecDType(backend='sleef')

# Test with longdouble backend
result_ld = QuadPrecision([1, 2, 3], backend='longdouble')
assert isinstance(result_ld, np.ndarray)
assert result_ld.dtype == QuadPrecDType(backend='longdouble')

def test_quad_precision_array_vs_astype_equivalence(self):
"""Test that QuadPrecision(array) is equivalent to array.astype(QuadPrecDType)."""
test_arrays = [
[1, 2, 3],
[1.5, 2.5, 3.5],
[[1, 2], [3, 4]],
np.array([10, 20, 30]),
]

for arr in test_arrays:
result_quad = QuadPrecision(arr)
result_astype = np.array(arr).astype(QuadPrecDType(backend='sleef'))
np.testing.assert_array_equal(result_quad, result_astype)
assert result_quad.dtype == result_astype.dtype

def test_create_empty_array(self):
"""Test that QuadPrecision can create arrays from empty sequences."""
result = QuadPrecision([])
assert isinstance(result, np.ndarray)
assert result.dtype.name == "QuadPrecDType128"
assert result.shape == (0,)
expected = np.array([], dtype=QuadPrecDType(backend='sleef'))
np.testing.assert_array_equal(result, expected)

def test_create_from_numpy_int_scalars(self):
"""Test that QuadPrecision can create scalars from numpy integer types."""
# Test np.int32
result = QuadPrecision(np.int32(42))
assert isinstance(result, QuadPrecision)
assert float(result) == 42.0

# Test np.int64
result = QuadPrecision(np.int64(100))
assert isinstance(result, QuadPrecision)
assert float(result) == 100.0

# Test np.uint32
result = QuadPrecision(np.uint32(255))
assert isinstance(result, QuadPrecision)
assert float(result) == 255.0

# Test np.int8
result = QuadPrecision(np.int8(-128))
assert isinstance(result, QuadPrecision)
assert float(result) == -128.0

def test_create_from_numpy_float_scalars(self):
"""Test that QuadPrecision can create scalars from numpy floating types."""
# Test np.float64
result = QuadPrecision(np.float64(3.14))
assert isinstance(result, QuadPrecision)
assert abs(float(result) - 3.14) < 1e-10

# Test np.float32
result = QuadPrecision(np.float32(2.71))
assert isinstance(result, QuadPrecision)
# Note: float32 has limited precision, so we use a looser tolerance
assert abs(float(result) - 2.71) < 1e-5

# Test np.float16
result = QuadPrecision(np.float16(1.5))
assert isinstance(result, QuadPrecision)
assert abs(float(result) - 1.5) < 1e-3

def test_create_from_zero_dimensional_array(self):
"""Test that QuadPrecision can create from 0-d numpy arrays."""
# 0-d array from scalar
arr_0d = np.array(5.5)
result = QuadPrecision(arr_0d)
assert isinstance(result, np.ndarray)
assert result.shape == () # 0-d array
assert result.dtype.name == "QuadPrecDType128"
expected = np.array(5.5, dtype=QuadPrecDType(backend='sleef'))
np.testing.assert_array_equal(result, expected)

# Another test with integer
arr_0d = np.array(42)
result = QuadPrecision(arr_0d)
assert isinstance(result, np.ndarray)
assert result.shape == ()
expected = np.array(42, dtype=QuadPrecDType(backend='sleef'))
np.testing.assert_array_equal(result, expected)

def test_numpy_scalar_with_backend(self):
"""Test that numpy scalars respect the backend parameter."""
# Test with sleef backend
result = QuadPrecision(np.int32(10), backend='sleef')
assert isinstance(result, QuadPrecision)
assert "backend='sleef'" in repr(result)

# Test with longdouble backend
result = QuadPrecision(np.float64(3.14), backend='longdouble')
assert isinstance(result, QuadPrecision)
assert "backend='longdouble'" in repr(result)

def test_numpy_scalar_types_coverage(self):
"""Test a comprehensive set of numpy scalar types."""
# Integer types
int_types = [
(np.int8, 10),
(np.int16, 1000),
(np.int32, 100000),
(np.int64, 10000000),
(np.uint8, 200),
(np.uint16, 50000),
(np.uint32, 4000000000),
]

for dtype, value in int_types:
result = QuadPrecision(dtype(value))
assert isinstance(result, QuadPrecision), f"Failed for {dtype.__name__}"
assert float(result) == float(value), f"Value mismatch for {dtype.__name__}"

# Float types
float_types = [
(np.float16, 1.5),
(np.float32, 2.5),
(np.float64, 3.5),
]

for dtype, value in float_types:
result = QuadPrecision(dtype(value))
assert isinstance(result, QuadPrecision), f"Failed for {dtype.__name__}"
# Use appropriate tolerance based on dtype precision
expected = float(dtype(value))
assert abs(float(result) - expected) < 1e-5, f"Value mismatch for {dtype.__name__}"


def test_string_roundtrip():
# Test with various values that require full quad precision
test_values = [
Expand Down