1
0

N進数の開平法を使った多倍長整数の平方根(拡張モジュール)

Last updated at Posted at 2022-03-12

Python/C API の PyLongObject 型用平方根(拡張モジュール)

Python での多倍長整数 int 型(C では PyLongObject 型) の平方根を整数の範囲で求めます。

64ビット版 Python では $2^{30}$ 進数(PyLongObject での単位)の開平法で計算しています。32ビット版 Python では $2^{15}$ 進数となるハズですが動作確認していないので、恐らくコンパイルできなかったりバグったりします。

除算は C 言語の __int128 (コンパイラ拡張) の範囲に絞り、多倍長整数では行っていませんから、それなりの速さで計算してくれるかと思います。

拡張モジュールのメソッドは1つだけ
intsqrt(x: int) -> int:           int(√x)
intsqrt(x: int, True) -> tuple:   (int(√x), x - int(√x)**2)

オプションで余りを取得できます。

ビルドと実行

distutils の setup.py でモジュールを作ります。

拡張モジュールとしては、メソッドが1つだけという簡単なものになってます。

動作確認している環境

OS
  macOS 12.2.1

C コンパイラ (Xcode 13.2.1)

  Apple clang version 13.0.0 (clang-1300.0.29.30)
  Target: x86_64-apple-darwin21.3.0

Python : /usr/bin/python3

  Python 3.8.9 (default, Oct 26 2021, 07:25:54)
  [Clang 13.0.0 (clang-1300.0.29.30)] on darwin

intsqrt モジュール (intsqrt.c と setup.py : 折りたたみ)
intsqrt.c
/* -*- coding: utf-8; -*- */

#ifndef PY_SSIZE_T_CLEAN
#define PY_SSIZE_T_CLEAN
#endif /* PY_SSIZE_T_CLEAN */

#include <Python.h>

/*
 *
 */

/* Local aliases for the CPython long-integer digit types. */
typedef digit digit1_t;         /* one digit, unsigned */
typedef sdigit sdigit1_t;       /* one digit, signed */
typedef twodigits digit2_t;     /* two digits, unsigned */
typedef stwodigits sdigit2_t;   /* two digits, signed */

/*
 * digit4_t must be able to hold four digits (4 * PyLong_SHIFT bits):
 * 64 bits are used for 15-bit digits and 128 bits for 30-bit digits.
 * Both typedefs are compiler extensions, hence the __GNUC__ guard.
 * Note that the chosen types are signed.
 */
#if defined(__GNUC__)
#  if PyLong_SHIFT == 15
typedef __int64_t digit4_t;
#  elif PyLong_SHIFT == 30
typedef __int128_t digit4_t;
#  else
#    error "unknown PyLong_SHIFT"
#  endif

/* Quotient/remainder pair produced by intsqrt_digit4_div(). */
typedef struct digit4_div_t {
    digit4_t quot;
    digit4_t rem;
} digit4_div_t;

#else
#  error "unknown compiler"
#endif

/*
 *
 */

/* Add one reference to `object` and return the very same pointer,
 * so the call can be used inside an expression. */
inline static PyObject *
IncRef(PyObject *object)
{
    return Py_INCREF(object), object;
}

/*
 * Allocate an int object with room for `len` digits, all set to zero.
 *
 * Returns a new reference, or NULL when the allocation fails (in which
 * case PyObject_NewVar has already set the exception).
 *
 * PyObject_NewVar() hands back an object whose header (type, size and
 * reference count) is already initialised, so it must not be passed to
 * PyObject_InitVar() a second time: doing so registers the object twice
 * on builds that track references.  Only the digit storage, which the
 * allocator leaves uninitialised, is cleared here.
 */
inline static PyLongObject *
NewLong(Py_ssize_t len)
{
    PyLongObject *obj;

    obj = PyObject_NewVar(PyLongObject, &PyLong_Type, len);
    if (obj == NULL)
        return NULL;
    memset(obj->ob_digit, 0, len * sizeof(digit1_t));
    return obj;
}

/*
 * Number of significant bits of `x`: the bit length of its most
 * significant digit plus PyLong_SHIFT for every lower digit.
 *
 * The caller must pass a value with at least one digit.
 * With a GNU-compatible compiler the top digit is measured with
 * __builtin_clz (the constant 32 presumes a 32-bit unsigned int -
 * TODO confirm for other targets); otherwise a portable binary search
 * over the digit is used.  INTSQRT_FAST_BIT_LENGTH is defined in the
 * fast case so that lazy_bit_length() can tell which one is active.
 */
static Py_ssize_t
bit_length(PyLongObject *x)
{
    const Py_ssize_t ndigits = Py_SIZE(x);
    digit1_t top = x->ob_digit[ndigits - 1];
    int top_bits;

#if defined(__GNUC__)
#define INTSQRT_FAST_BIT_LENGTH

    top_bits = 32 - __builtin_clz(top);

#else
    {
        digit1_t upper;
        int shift;

        /* Halve the window each round, keeping the non-zero upper part. */
        top_bits = 1;
        for (shift = sizeof(digit1_t) * 4; shift > 0; shift >>= 1) {
            upper = top >> shift;
            if (upper) {
                top = upper;
                top_bits += shift;
            }
        }
    }
#endif

    return top_bits + (ndigits - 1) * PyLong_SHIFT;
}

/*
 * Upper bound on the bit length of `x` that is cheap to obtain:
 * the exact bit length when the fast bit_length() is available,
 * otherwise simply the capacity of all digits of `x`.
 */
inline static int
lazy_bit_length(PyLongObject *x)
{
    Py_ssize_t nbits;

#ifdef INTSQRT_FAST_BIT_LENGTH
    nbits = bit_length(x);
#else
    nbits = Py_SIZE(x) * PyLong_SHIFT;
#endif
    return nbits;
}

/*
 * Method: intsqrt
 */

/*
 * INTSQRT_xDIGIT(name, type) defines
 *
 *   - struct `name_result` { type root; type rem; }
 *   - `name_result name(type x, int bits)`
 *
 * which computes the integer square root of `x` bit by bit (binary
 * digit-by-digit method).  `bits` is an upper bound on the bit length
 * of `x`; the loop starts at bit (bits >> 1) of the root and works
 * down to bit 0.  On return `root` is floor(sqrt(x)) and `rem` is
 * x - root * root.
 *
 * Inside the loop `s` holds twice the root found so far (aligned to
 * the current bit), `t` is the bit being tried and `v` is the amount
 * that has to be subtracted from `x` if that bit is accepted.
 *
 * The initial value of `t` is shifted in the digit type rather than as
 * a plain `int` constant, so the shift stays well defined even when
 * the macro is instantiated for a wide type and `bits` approaches the
 * width of that type.
 */
#define INTSQRT_xDIGIT(intsqrt_digit, digit)                \
                                                            \
typedef struct intsqrt_digit##_result {                     \
    digit root;                                             \
    digit rem;                                              \
} intsqrt_digit##_result;                                   \
                                                            \
inline static intsqrt_digit##_result                        \
intsqrt_digit(digit x, int bits)                            \
{                                                           \
    intsqrt_digit##_result res;                             \
    digit s, t, u, v, y;                                    \
    int b;                                                  \
                                                            \
    s = 0;                                                  \
    y = 0;                                                  \
    for (b = bits >> 1, t = (digit) 1 << b; b >= 0; b--, t >>= 1) { \
        y <<= 1;                                            \
        u = s + t;                                          \
        v = u << b;                                         \
        if (x < v)                                          \
            continue;                                       \
        x -= v;                                             \
        y |= 1;                                             \
        s = t + u;                                          \
    }                                                       \
    res.root = y;                                           \
    res.rem = x;                                            \
    return res;                                             \
}

/* Instantiate the bitwise square root for one-digit and two-digit values. */
INTSQRT_xDIGIT(intsqrt_digit1, digit1_t);
INTSQRT_xDIGIT(intsqrt_digit2, digit2_t);

/* Number of value bits in one digit / in a pair of digits. */
#define DIGIT1_SHIFT  PyLong_SHIFT
#define DIGIT2_SHIFT  (DIGIT1_SHIFT * 2)

/* All-ones masks covering one digit / a pair of digits. */
#define DIGIT1_MASK   PyLong_MASK
#define DIGIT2_MASK   (((digit2_t) DIGIT1_MASK << DIGIT1_SHIFT) | DIGIT1_MASK)

/* One more than the corresponding mask. */
#define DIGIT1_MAX    (DIGIT1_MASK + 1)
#define DIGIT2_MAX    (DIGIT2_MASK + 1)

/*
 *
 */

/*
 * Working state of the multi-digit square root (see intsqrt_large()).
 *
 * All pointers are cursors into digit arrays; intsqrt_large_next()
 * moves them towards the least significant digit as the computation
 * proceeds.
 */
typedef struct intsqrt_large_state {
    digit1_t *w1;           /* w: initialised to twice the root found so far */
    digit1_t *x1;           /* x: input digits, consumed from the top */
    digit1_t *y1;           /* y: digits of the root (the result) */
    digit1_t *z1, *z2;      /* z: current remainder (z1) and scratch copy (z2) */
    unsigned long z_swap;   /* how many times z1 and z2 have been exchanged */

    Py_ssize_t wsize, wlen; /* allocated / currently used digits of w */
    Py_ssize_t xsize;       /* number of digits of x */
    Py_ssize_t ysize, ylen; /* allocated / currently used digits of y */
    Py_ssize_t zsize, zlen; /* allocated / currently used digits of z */
} intsqrt_large_state;

/* Combine a high digit `h` and a low digit `l` into one two-digit value. */
inline static digit2_t
intsqrt_digit2_t(digit1_t h, digit1_t l)
{
    digit2_t wide = (digit2_t) h;

    wide <<= DIGIT1_SHIFT;
    wide |= l;
    return wide;
}

/* Read the two-digit value stored little-endian at p[0] (low), p[1] (high). */
inline static digit2_t
intsqrt_digit2_get(digit1_t *p)
{
    const digit1_t low = p[0];
    const digit1_t high = p[1];

    return intsqrt_digit2_t(high, low);
}

/* Like intsqrt_digit2_get(), but with the low digit p[0] taken as zero. */
inline static digit2_t
intsqrt_digit2_get_high(digit1_t *p)
{
    digit2_t high = p[1];

    return high << DIGIT1_SHIFT;
}

/* Store `d` at p[0] (low digit) and p[1] (everything above the low digit). */
inline static void
intsqrt_digit2_set(digit1_t *p, digit2_t d)
{
    const digit2_t low = d & DIGIT1_MASK;
    const digit2_t high = d >> DIGIT1_SHIFT;

    p[0] = (digit1_t) low;
    p[1] = (digit1_t) high;
}

/* Combine two two-digit values into a four-digit one: `h` above `l`. */
inline static digit4_t
intsqrt_digit2to4(digit2_t h, digit2_t l)
{
    digit4_t wide = (digit4_t) h;

    wide <<= DIGIT2_SHIFT;
    return wide | l;
}

/* Divide `num` by `den`, returning quotient and remainder together. */
inline static digit4_div_t
intsqrt_digit4_div(digit4_t num, digit4_t den)
{
    digit4_div_t res = { .quot = num / den, .rem = num % den };

    return res;
}

/*
 * Normalise the digits p[0..len-1] after an addition left a value
 * larger than DIGIT1_MASK in one of them: the excess is carried into
 * the next digit.  Processing stops as soon as no carry is pending.
 * At least one digit is always processed.
 *
 * Returns 1 when a carry is left over after the last digit, else 0.
 */
inline static int
intsqrt_adjust_after_add1(digit1_t *p, Py_ssize_t len)
{
    digit1_t carry = 0;
    digit1_t sum;
    Py_ssize_t i = 0;

    for (;;) {
        sum = p[i] + carry;
        carry = sum >> DIGIT1_SHIFT;
        p[i] = sum & DIGIT1_MASK;
        i++;
        if (carry == 0 || i >= len)
            break;
    }

    return carry != 0;
}

/*
 * Normalise the digits p[0..len-1] after a digit-wise subtraction
 * that may have wrapped around: each digit is reinterpreted as signed,
 * its (possibly negative) overflow is moved into the next digit and
 * the digit itself is masked back into range.  Every digit is visited,
 * and at least one digit is always processed.
 *
 * Returns 1 when a borrow/carry is left over after the last digit,
 * i.e. the subtraction as a whole did not fit, else 0.
 */
inline static int
intsqrt_adjust_after_sub(digit1_t *p, Py_ssize_t len)
{
    digit1_t borrow = 0;
    digit1_t diff;
    Py_ssize_t i = 0;

    do {
        diff = p[i] + borrow;
        borrow = (sdigit1_t) diff >> DIGIT1_SHIFT;
        p[i] = diff & DIGIT1_MASK;
        i++;
    } while (i < len);

    return borrow != 0;
}

/*
 * Advance the state to the next root digit: w and y grow by one digit
 * towards the low end, while x and both remainder buffers move down by
 * two digits (one digit of the root consumes two digits of the input).
 */
inline static void
intsqrt_large_next(intsqrt_large_state *s)
{
    s->w1--;
    s->wlen++;

    s->x1 -= 2;

    s->y1--;
    s->ylen++;

    s->z1 -= 2;
    s->z2 -= 2;
    s->zlen += 2;
}

/* Shrink zlen so that it no longer counts leading zero digits of z1. */
inline static void
intsqrt_large_fix_zlen(intsqrt_large_state *s)
{
    Py_ssize_t used = s->zlen;

    while (used > 0 && s->z1[used - 1] == 0)
        used--;
    s->zlen = used;
}

/*
 * First step: take the leading digit pair of x (or the single leading
 * digit when x has an odd number of digits), compute its square root
 * with the bitwise routine, and seed the state with it:
 *   y <- root, z <- remainder, w <- 2 * root.
 * wlen is reduced by one when 2 * root fits in a single digit.
 */
inline static void
intsqrt_large_first(intsqrt_large_state *s)
{
    const int even_digits = ((s->xsize & 1) == 0);
    intsqrt_digit2_result lead_sqrt;
    digit2_t lead;
    int lead_bits;

    if (even_digits) {
        s->x1 -= 2;
        lead = intsqrt_digit2_get(s->x1);
        lead_bits = DIGIT2_SHIFT;
    }
    else {
        s->x1 -= 1;
        lead = s->x1[0];
        lead_bits = DIGIT1_SHIFT;
    }

    lead_sqrt = intsqrt_digit2(lead, lead_bits);

    /* w = 2 * root, spread over two digits */
    intsqrt_digit2_set(s->w1, lead_sqrt.root << 1);
    if (!s->w1[1])
        s->wlen--;

    s->y1[0] = lead_sqrt.root;
    s->z1[0] = lead_sqrt.rem;
    s->z1[1] = 0;
}

/*
 * Second step: find the second root digit using four-digit arithmetic.
 *
 * After bringing down the next two digits of x, z (four digits) is the
 * running remainder and w = (msw, ssw) is twice the first root digit
 * shifted up by one digit.  The wanted digit k is the largest value
 * with (w + k) * k <= z.
 *
 * k is bracketed by ws = z / (w + DIGIT1_MASK) and we = z / (w + 1)
 * and then located by bisection.  On return y[0] = k, w has k * 2
 * added and z holds z - (w + k) * k.
 *
 * Two details matter for correctness:
 *   - the bisection must keep a probe wp when (w + wp) * wp == z
 *     (remainder exactly zero); comparing with `<` instead of `<=`
 *     would discard the exact hit and yield a root digit that is too
 *     small for some perfect squares;
 *   - a root digit never exceeds DIGIT1_MASK, so the upper bound is
 *     clamped to it; this also guarantees that ssw + wp cannot spill
 *     into the bits that intsqrt_digit2to4() ORs with msw.
 */
inline static void
intsqrt_large_second(intsqrt_large_state *s)
{
    digit2_t msw, ssw;
    digit2_t msz, ssz;
    digit2_t ws, we, wp;

    digit4_t w4s, w4e;
    digit4_t w4;
    digit4_t z4;

    /* bring down the next two digits of x */
    s->z1[0] = s->x1[0];
    s->z1[1] = s->x1[1];

    msw = s->w1[2];
    ssw = intsqrt_digit2_get_high(s->w1);

    msz = intsqrt_digit2_get(s->z1 + 2);
    ssz = intsqrt_digit2_get(s->z1 + 0);

    /* largest and smallest possible divisor w + k */
    w4s = intsqrt_digit2to4(msw, ssw + DIGIT1_MASK);
    w4e = intsqrt_digit2to4(msw, ssw + 1);

    z4 = intsqrt_digit2to4(msz, ssz);
    ws = intsqrt_digit4_div(z4, w4s).quot;
    we = intsqrt_digit4_div(z4, w4e).quot;
    if (we == 0)
        return;         /* the digit is 0; nothing to subtract */

    if (ws >= DIGIT1_MASK) {
        /* even the largest digit fits */
        ws = we = DIGIT1_MASK;
        z4 -= w4s * DIGIT1_MASK;
    }
    else {
        if (we > DIGIT1_MASK)
            we = DIGIT1_MASK;

        /* invariant: ws <= k <= we */
        while ((we - ws) > 1) {
            wp = (ws + we) >> 1;
            w4 = intsqrt_digit2to4(msw, ssw + wp) * wp;
            if (w4 <= z4)
                ws = wp;
            else
                we = wp - 1;
        }

        /* at most two candidates are left; prefer the larger one */
        w4s = intsqrt_digit2to4(msw, ssw + ws) * ws;
        if (ws != we) {
            w4e = intsqrt_digit2to4(msw, ssw + we) * we;
            if (z4 >= w4e) {
                ws = we;
                w4s = w4e;
            }
        }
        z4 -= w4s;
    }

    msz = (digit2_t) (z4 >> DIGIT2_SHIFT);
    ssz = (digit2_t) z4;

    /* w += 2 * k, carrying the top bit of k * 2 into the next digit */
    s->w1[0] = (ws * 2) & DIGIT1_MASK;
    s->w1[1] += (ws >> (DIGIT1_SHIFT -1));

    s->y1[0] = ws;
    intsqrt_digit2_set(s->z1 + 0, ssz);
    intsqrt_digit2_set(s->z1 + 2, msz);
}

/* Exchange the roles of the two remainder buffers and count the swap. */
inline static void
intsqrt_large_step_zswap(intsqrt_large_state *s)
{
    digit1_t *const previous_z1 = s->z1;

    s->z1 = s->z2;
    s->z2 = previous_z1;
    s->z_swap += 1;
}

/*
 * Trial subtraction for a candidate root digit k = w1[0]:
 * writes z1 - w * k into z2, where w is the wlen-digit number at w1
 * (whose lowest digit is k itself).  z1 is left untouched.
 *
 * Returns non-zero when the result would be negative (the early
 * checks on the top digits fail, or a borrow is left after
 * normalisation), 0 otherwise.
 *
 * NOTE(review): only the cases zlen == wlen and zlen > wlen by one
 * digit appear to be intended here - confirm against the caller.
 */
inline static int
intsqrt_large_step_zsub1(intsqrt_large_state *s)
{
    Py_ssize_t wlen = s->wlen;
    Py_ssize_t zlen = s->zlen;
    Py_ssize_t dlen;
    Py_ssize_t i;

    digit1_t k;
    digit2_t wh, wl;

    dlen = zlen - wlen;

    /* product of the top digit of w with k, split into high/low digit */
    k = s->w1[0];
    wl = (digit2_t) s->w1[wlen - 1] * k;
    wh = wl >> DIGIT1_SHIFT;
    wl &= DIGIT1_MASK;

    if (dlen == 0) {
        /* z has no digit above w: a high part cannot be subtracted */
        if (wh != 0)
            return 1;
        s->z2[zlen - 1] = s->z1[zlen - 1] - wl;
    }
    else {
        wh = s->z1[zlen - 1] - wh;
        if ((sdigit1_t) wh < 0)
            return 1;
        wl = s->z1[zlen - 2] - wl;

        s->z2[zlen - 1] = wh;
        s->z2[zlen - 2] = wl;
    }

    /* remaining digits of w, low to high; wh carries the product overflow */
    wh = 0;
    for (i = 0; i < (wlen - 1); i++) {
        wl = (digit2_t) s->w1[i] * k + wh;
        wh = wl >> DIGIT1_SHIFT;
        wl &= DIGIT1_MASK;
        s->z2[i] = s->z1[i] - wl;
    }
    s->z2[i] -= wh;

    /* digits may have wrapped; propagate the borrows */
    return intsqrt_adjust_after_sub(s->z2, s->zlen);
}

/*
 * Incremental trial subtraction: writes into z2 the value of z1 minus
 * the number formed by the digits w1[wlen-1] .. w1[1] followed by
 * (2 * w1[0] + 1).
 *
 * NOTE(review): this looks like the amount by which (W + k) * k grows
 * when the candidate digit k = w1[0] is raised by one, W being w with
 * its lowest digit cleared - confirm.
 *
 * Returns non-zero when the result would be negative, 0 otherwise.
 */
inline static int
intsqrt_large_step_zsub2(intsqrt_large_state *s)
{
    Py_ssize_t wlen = s->wlen;
    Py_ssize_t zlen = s->zlen;
    Py_ssize_t dlen;
    Py_ssize_t i;

    digit1_t kl, kh;

    dlen = zlen - wlen;
    if (dlen == 0) {
        /* same length: quick reject by comparing from the top digit down */
        for (i = wlen - 1; i >= 2; i--) {
            if (s->z1[i] < s->w1[i])
                return 1;
            if (s->z1[i] > s->w1[i])
                break;
        }
    }

    /* 2 * k + 1, split into low digit and carry */
    kl = s->w1[0] * 2 + 1;
    kh = kl >> DIGIT1_SHIFT;
    kl &= DIGIT1_MASK;

    s->z2[0] = s->z1[0] - kl;
    s->z2[1] = s->z1[1] - kh - s->w1[1];
    for (i = 2; i < wlen; i++)
        s->z2[i] = s->z1[i] - s->w1[i];
    /* digits of z above w are copied unchanged */
    for (; i < zlen; i++)
        s->z2[i] = s->z1[i];

    /* digits may have wrapped; propagate the borrows */
    return intsqrt_adjust_after_sub(s->z2, zlen);
}

/*
 * General step (third root digit onwards).
 *
 * Brings down the next two digits of x, estimates the root digit by
 * dividing the leading digits of z by the two leading digits of w,
 * then corrects the estimate with trial subtractions.  On success
 * y[0] receives the digit, the remainder ends up in z1 (the buffers
 * are swapped as needed) and w has twice the digit added.
 *
 * When the digit is zero the function returns early and leaves y and
 * w as they are.
 *
 * NOTE(review): the return value of intsqrt_large_step_zsub1() is
 * ignored in the first and last branch, i.e. the code relies on the
 * starting candidate never being too large - confirm.
 */
inline static void
intsqrt_large_step(intsqrt_large_state *s)
{
    digit2_t msw, ssw;
    digit2_t msz, ssz;
    digit2_t zdw;
    digit4_t w4;
    digit4_t z4;
    Py_ssize_t dlen;

    /* bring down the next two digits of x */
    s->z1[0] = s->x1[0];
    s->z1[1] = s->x1[1];

    /* z shorter than w: the digit is 0 */
    dlen = s->zlen - s->wlen;
    if (dlen < 0)
        return;

    msw = 0;
    ssw = intsqrt_digit2_get(s->w1 + s->wlen - 2);  /* two leading digits of w */
    zdw = 0;

    /* estimate the digit from the leading digits of z and w */
    if (dlen == 0) {
        ssz = intsqrt_digit2_get(s->z1 + s->zlen - 2);
        msz = 0;
        zdw = ssz / ssw;
    }
    else {
        ssz = intsqrt_digit2_get(s->z1 + s->zlen - 3);
        msz = s->z1[s->zlen - 1];
        z4 = intsqrt_digit2to4(msz, ssz);
        w4 = intsqrt_digit4_div(z4, ssw).quot;
        zdw = (digit2_t) w4;
    }
    if (zdw == 0)
        return;

    if (zdw > DIGIT1_MASK) {
        /* estimate exceeds one digit: use the largest digit */
        s->w1[0] = zdw = DIGIT1_MASK;
        intsqrt_large_step_zsub1(s);
        intsqrt_large_step_zswap(s);
    }
    else if (zdw == 1) {
        /* try 1; fall back to 0 when it does not fit */
        s->w1[0] = zdw;
        if (intsqrt_large_step_zsub1(s)) {
            s->w1[0] = 0;
            return;
        }
        intsqrt_large_step_zswap(s);
    }
    else {
        /* start one below the estimate, then try to go up at most twice */
        s->w1[0] = --zdw;
        intsqrt_large_step_zsub1(s);
        intsqrt_large_step_zswap(s);
        if (!intsqrt_large_step_zsub2(s)) {
            s->w1[0] = ++zdw;
            intsqrt_large_step_zswap(s);
            if (!intsqrt_large_step_zsub2(s)) {
                s->w1[0] = ++zdw;
                intsqrt_large_step_zswap(s);
            }
        }
    }

    /* record the digit; w1[0] already holds it, so adding makes it 2 * digit */
    s->y1[0] = zdw;
    s->w1[0] += zdw;
    intsqrt_adjust_after_add1(s->w1, s->wlen);
}

/*
 * Square root of an int `x` that has more than two digits.
 *
 * Works digit by digit in base 2**PyLong_SHIFT on four scratch ints:
 *   y - the root, one digit per step;
 *   w - twice the root found so far;
 *   z1, z2 - the running remainder and a scratch copy of it.
 *
 * Returns a new reference to the root, or - when `remobj` is non-zero -
 * to a 2-tuple (root, remainder).  Returns NULL with an exception set
 * when an allocation fails.
 *
 * The size of the remainder object is set with Py_SET_SIZE() where it
 * exists: Py_SIZE() can no longer be assigned to from CPython 3.11 on.
 */
static PyObject *
intsqrt_large(PyLongObject *x, int remobj)
{
    PyObject *retval = NULL;
    PyObject *tuple = NULL;
    PyLongObject *w1 = NULL;
    PyLongObject *x1 = x;
    PyLongObject *y1 = NULL;
    PyLongObject *z1 = NULL;
    PyLongObject *z2 = NULL;
    PyLongObject *zp = NULL;
    intsqrt_large_state st, *s = &st;

    /* z covers x rounded up to an even digit count; y has half of that */
    s->xsize = Py_SIZE(x);
    s->zsize = (s->xsize + 1) & ~1;
    s->ysize = s->zsize >> 1;
    s->wsize = s->ysize + 1;

    if (!(w1 = NewLong(s->wsize))) goto error;
    if (!(y1 = NewLong(s->ysize))) goto error;
    if (!(z1 = NewLong(s->zsize))) goto error;
    if (!(z2 = NewLong(s->zsize))) goto error;

    s->wlen = 2;
    s->ylen = 1;
    s->zlen = 2;

    /* all cursors start at the most significant end of their arrays */
    s->w1 = w1->ob_digit + s->wsize - s->wlen;
    s->x1 = x1->ob_digit + s->xsize;
    s->y1 = y1->ob_digit + s->ysize - s->ylen;
    s->z1 = z1->ob_digit + s->zsize - s->zlen;
    s->z2 = z2->ob_digit + s->zsize - s->zlen;
    s->z_swap = 0;

    intsqrt_large_first(s);
    intsqrt_large_fix_zlen(s);
    if (s->ylen < s->ysize) {
        intsqrt_large_next(s);
        intsqrt_large_second(s);
        intsqrt_large_fix_zlen(s);
        while (s->ylen < s->ysize) {
            intsqrt_large_next(s);
            intsqrt_large_step(s);
            intsqrt_large_fix_zlen(s);
        }
    }

    if (!remobj) {
        retval = (PyObject *) y1;
        goto success1;
    }

    /* after an odd number of swaps the remainder lives in the other object */
    if ((s->z_swap & 1)) {
        zp = z1;
        z1 = z2;
        z2 = zp;
    }
    /* drop the leading zero digits of the remainder */
#if PY_VERSION_HEX >= 0x030900A4
    Py_SET_SIZE(z1, s->zlen);
#else
    Py_SIZE(z1) = s->zlen;
#endif

    if (!(tuple = PyTuple_New(2)))
        goto error;
    /* the tuple takes over the references to y1 and z1 */
    PyTuple_SET_ITEM(tuple, 0, (PyObject *) y1);
    PyTuple_SET_ITEM(tuple, 1, (PyObject *) z1);
    retval = tuple;
    goto success2;
error:
    Py_XDECREF(tuple);
    Py_XDECREF(y1);
success1:
    Py_XDECREF(z1);
success2:
    Py_XDECREF(w1);
    Py_XDECREF(z2);
    return retval;
}

/*
 * Integer square root of the non-negative int `x`.
 *
 * Values of zero, one or two digits are handled directly with the
 * bitwise routines; anything else is passed on to intsqrt_large().
 *
 * Returns a new reference to the root, or to the tuple
 * (root, remainder) when `remobj` is non-zero; NULL on failure.
 */
static PyObject *
intsqrt(PyLongObject *x, int remobj)
{
    PyObject *root = NULL;
    PyObject *rem = NULL;
    PyObject *pair;

    switch (Py_SIZE(x)) {
    case 0:
        root = PyLong_FromUnsignedLong(0);
        if (root == NULL)
            goto error;
        /* sqrt(0) == 0 with remainder 0: share the same object */
        if (remobj)
            rem = IncRef(root);
        break;

    case 1: {
        int nbits = lazy_bit_length(x);
        intsqrt_digit1_result one = intsqrt_digit1(x->ob_digit[0], nbits);

        root = PyLong_FromUnsignedLong(one.root);
        if (root == NULL)
            goto error;
        if (remobj) {
            rem = PyLong_FromUnsignedLong(one.rem);
            if (rem == NULL)
                goto error;
        }
        break;
    }

    case 2: {
        int nbits = lazy_bit_length(x);
        digit2_t value = intsqrt_digit2_get(x->ob_digit);
        intsqrt_digit2_result two = intsqrt_digit2(value, nbits);

        root = PyLong_FromUnsignedLongLong(two.root);
        if (root == NULL)
            goto error;
        if (remobj) {
            rem = PyLong_FromUnsignedLong(two.rem);
            if (rem == NULL)
                goto error;
        }
        break;
    }

    default:
        return intsqrt_large(x, remobj);
    }

    if (!remobj)
        return root;

    pair = PyTuple_New(2);
    if (pair == NULL)
        goto error;
    /* the tuple takes over both references */
    PyTuple_SET_ITEM(pair, 0, root);
    PyTuple_SET_ITEM(pair, 1, rem);
    return pair;

error:
    Py_XDECREF(root);
    Py_XDECREF(rem);
    return NULL;
}

/*
 * Python entry point: intsqrt(x[, with_remainder]).
 *
 * args[0] is converted to an exact int by repeatedly calling its
 * nb_int slot (at most 256 times); TypeError is raised when that is
 * not possible.  ArithmeticError is raised for negative values.
 *
 * The optional second argument is evaluated with the regular Python
 * truth test (PyObject_IsTrue), so every truthy object requests the
 * remainder and an exception raised while testing it is propagated
 * instead of being left pending behind a normal return value.
 */
static PyObject *
method_intsqrt(PyObject *unused, PyObject *const *args, Py_ssize_t size)
{
    PyObject *x, *y;
    PyNumberMethods *as_number;
    unaryfunc nb_int;
    int t, remobj;

    (void)unused;
    remobj = 0;

    if (size < 1) {
        PyErr_SetString(PyExc_TypeError, "intsqrt expected at least 1 argument, got 0");
        return NULL;
    }
    if (size > 2) {
        PyErr_SetString(PyExc_TypeError, "intsqrt expected at most 2 arguments");
        return NULL;
    }
    if (size == 2) {
        remobj = PyObject_IsTrue(args[1]);
        if (remobj < 0)
            return NULL;    /* the truth test itself raised */
    }

    /* x always holds a reference owned by this function */
    x = IncRef(args[0]);
    for (t = 0;; t++) {
        if (PyLong_CheckExact(x))
            break;
        if (t >= 256)
            goto type_error;
        if (!(as_number = Py_TYPE(x)->tp_as_number))
            goto type_error;
        if (!(nb_int = as_number->nb_int))
            goto type_error;
        if (!(y = nb_int(x)))
            goto type_error;
        Py_DECREF(x);
        x = y;
    }
    if (Py_SIZE(x) < 0) {
        PyErr_SetString(PyExc_ArithmeticError, "(x < 0)");
        Py_DECREF(x);
        return NULL;
    }
    y = intsqrt((PyLongObject *) x, remobj);
    Py_DECREF(x);
    return y;

type_error:
    /* the offending argument becomes the exception value */
    PyErr_SetObject(PyExc_TypeError, args[0]);
    Py_DECREF(x);
    return NULL;
}

/*
 * Module: intsqrt
 */

/* Method table: a single function, called with the FASTCALL convention. */
static PyMethodDef intsqrt_methods[] = {
    {"intsqrt", (PyCFunction) method_intsqrt, METH_FASTCALL,
     "intsqrt(x: int) -> int: int(√x)\n"
     "intsqrt(x: int, True) -> tuple: (int(√x), x - int(√x)**2)\n"},
    {NULL, NULL, 0, NULL}, /* end */
};

/* Module definition: no docstring, no per-module state (m_size = -1). */
static PyModuleDef intsqrt_def = {
    PyModuleDef_HEAD_INIT,
    .m_name = "intsqrt",
    .m_doc = NULL,
    .m_size = -1,
    .m_methods = intsqrt_methods,
};

/* Module initialisation function, looked up by name on `import intsqrt`. */
PyMODINIT_FUNC
PyInit_intsqrt(void)
{
    return PyModule_Create(&intsqrt_def);
}

/*
 * Local Variables:
 * c-file-style: "PEP7"
 * End:
 */
setup.py
#!/usr/bin/env python3

import os
from distutils.core import setup, Extension


def getenv(name, defval=None):
    """Return the environment variable `name`, or `defval` when it is unset."""
    return os.environ.get(name, defval)


# A debug build is requested with DEBUG=true or DEBUG=yes in the environment.
DEBUG = getenv('DEBUG') in ('true', 'yes')

MAJOR_VERSION, MINOR_VERSION, DEBUG_VERSION = 0, 1, 0
VERSION = '%d.%d.%d' % (MAJOR_VERSION, MINOR_VERSION, DEBUG_VERSION)

# The version numbers are also passed to the C compiler as macros.
DEFINE_MACROS = [
    ('MAJOR_VERSION', MAJOR_VERSION),
    ('MINOR_VERSION', MINOR_VERSION),
    ('DEBUG_VERSION', DEBUG_VERSION),
] + ([('DEBUG', 1)] if DEBUG else [])

# Debug builds keep assertions (NDEBUG undefined) and disable optimisation.
UNDEF_MACROS = ['NDEBUG'] if DEBUG else []

EXTRA_COMPILE_ARGS = [
    '-W',
    '-Wall',
    '-Wno-invalid-offsetof',
    '-Wno-deprecated-declarations',
] + (['-O0'] if DEBUG else [])

INTSQRT_EXTENSION = Extension(
    name='intsqrt',
    define_macros=DEFINE_MACROS,
    undef_macros=UNDEF_MACROS,
    extra_compile_args=EXTRA_COMPILE_ARGS,
    sources=['intsqrt.c'])

setup(name='intsqrt',
      version=VERSION,
      description='',
      ext_modules=[INTSQRT_EXTENSION])
ビルドと実行(ターミナル上)
$ python3 setup.py build
running build
running build_ext
building 'intsqrt' extension
creating build
creating build/temp.macosx-10.14-x86_64-3.8
clang -Wno-unused-result -Wsign-compare -Wunreachable-code -fno-common -dynamic -DNDEBUG -g -fwrapv -O3 -Wall -iwithsysroot/System/Library/Frameworks/System.framework/PrivateHeaders -iwithsysroot/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/Headers -arch arm64 -arch x86_64 -Werror=implicit-function-declaration -DMAJOR_VERSION=0 -DMINOR_VERSION=1 -DDEBUG_VERSION=0 -I/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/include/python3.8 -c intsqrt.c -o build/temp.macosx-10.14-x86_64-3.8/intsqrt.o -W -Wall -Wno-invalid-offsetof -Wno-deprecated-declarations
creating build/lib.macosx-10.14-x86_64-3.8
clang -bundle -undefined dynamic_lookup -arch arm64 -arch x86_64 -Wl,-headerpad,0x1000 build/temp.macosx-10.14-x86_64-3.8/intsqrt.o -o build/lib.macosx-10.14-x86_64-3.8/intsqrt.cpython-38-darwin.so
$
$ (cd build/lib.macosx-10.14-x86_64-3.8 ; python3)
>>> from intsqrt import intsqrt
>>> intsqrt(99)
9
>>> intsqrt(99, True)
(9, 18)
>>> 1<<201
3213876088517980551083924184682325205044405987565585670602752
>>> intsqrt(1<<201)
1792728671193156477399422023278
>>> r = intsqrt(1<<201, True)
>>> r
(1792728671193156477399422023278, 2371767103687091674094496737468)
>>> r[0]**2+r[1]
3213876088517980551083924184682325205044405987565585670602752

検証・確認に用いた Python プログラム

開平法と計算を可視化するプログラム(折りたたみ)
# Digit-by-digit square root, base 2
def bsqrt(x):
    """Return (root, remainder) with root = floor(sqrt(x)), remainder = x - root**2.

    Works one bit of the root at a time, from the top bit down.
    Raises ArithmeticError when x is negative.
    """
    if x < 0:
        raise ArithmeticError
    l, r, s = x.bit_length() >> 1, 0, 0
    for b in range(l, -1, -1):
        r <<= 1
        t = 1 << b
        u = s + t       # (twice the root so far) + candidate bit
        v = u << b      # amount to subtract if the bit is accepted
        if x < v:
            continue
        x -= v
        r |= 1
        s = u + t
    return (r, x)


# Digit-by-digit square root, base N
def esqrt(x, N=10):
    """Square root of x by the base-N long-hand method, with a trace of the steps.

    Returns [[root, root_digits], [rem, rem_digits], [x_digits, steps]] where
    the digit lists are most-significant first and each entry of `steps` is
    [root, remainder, 2 * root, remainder before subtracting, amount
    subtracted], all as digit lists, for one digit pair of x.

    Raises ArithmeticError when N < 2 or x < 0, and Exception when the result
    disagrees with bsqrt() (internal consistency check).
    """
    def to_dig(x):
        # base-N digits of x, least significant first
        dig = [x % N]
        c = x // N
        while c > 0:
            dig.append(c % N)
            c //= N
        return dig

    def to_rdig(x):
        # base-N digits of x, most significant first
        return list(reversed(to_dig(x)))

    if N < 2:
        raise ArithmeticError
    if x < 0:
        # to_dig() would silently produce digits of a different number
        raise ArithmeticError

    NN = N * N
    NL = N - 1      # largest digit

    def find(rsub, rrem):
        # Largest digit k with (rsub + k) * k <= rrem, and that product.
        ns = rrem // (rsub + NL)    # lower bound for k
        ne = rrem // (rsub + 1)     # upper bound for k
        if ne == 0:
            return (0, 0)
        if ns >= NL:
            return (NL, (rsub + NL) * NL)
        # bisection; invariant: ns <= k <= ne
        while (ne - ns) > 1:
            np = (ns + ne) >> 1
            pp = (rsub + np) * np
            # "<=": a probe that leaves remainder 0 is a valid digit
            if pp <= rrem:
                ns = np
            else:
                ne = np - 1
        ts = (rsub + ns) * ns
        if ns != ne:
            te = (rsub + ne) * ne
            if rrem >= te:
                ns = ne
                ts = te
        return (ns, ts)

    # pad to an even number of digits, most significant first
    xdig = to_dig(x)
    if len(xdig) & 1:
        xdig.append(0)
    xdig = list(reversed(xdig))

    step = []

    # first digit pair

    d = xdig[:2]
    dig = xdig[2:]
    s = d[0] * N + d[1]

    root, rrem = bsqrt(s)  # use the base-2 version
    rsub = root * 2

    rstep = [to_rdig(root), to_rdig(rrem), to_rdig(rsub), to_rdig(s), to_rdig(root**2)]
    step.append(rstep)

    if not len(dig):
        return [[root, to_rdig(root)], [rrem, to_rdig(rrem)], [xdig, step]]

    # remaining digit pairs

    col = 1
    while len(dig):
        col += 1

        root *= N
        rrem *= NN
        rsub *= N

        # bring down the next two digits
        d = dig[:2]
        dig = dig[2:]
        rrem += d[0] * N + d[1]
        prem = rrem

        ns, ts = find(rsub, rrem)

        if ns:
            root += ns
            rrem -= ts
            rsub += ns * 2

        rstep = [to_rdig(root), to_rdig(rrem), to_rdig(rsub), to_rdig(prem), to_rdig(ts)]
        step.append(rstep)

    res = [[root, to_rdig(root)], [rrem, to_rdig(rrem)], [xdig, step]]
    if root != bsqrt(x)[0]:
        raise Exception('Bug!!: esqrt(%d,%d)' % (x, N))
    return res


# Visualise the base-N square root as a long-hand (pencil-and-paper) layout
def esqrt_format(x, N=10, xdec=False):
    """Print the long-hand working of esqrt(x, N).

    The main column (left) shows the root above x and the successive
    subtractions; the side column (right) shows the helper additions.
    For N <= 16 each digit is printed as one hexadecimal character; for
    larger N digits are zero-padded numbers joined by '_', in decimal by
    default or in hexadecimal when `xdec` is true.  Returns None.
    """
    # Format of a single digit (only used by nfmt() when N > 16).
    if N < 10:
        NUM_FMT = '%%%dd' % (N - 1)
    elif N <= 16:
        NUM_FMT = '%%0%dx' % (N - 1)
    else:
        XDEC = 'dx'[int(xdec)]
        NUM_FMT = ('%%0%d' + XDEC) % len(('%' + XDEC) % (N - 1))

    def nfmt(n):
        # digit list -> string
        if N <= 16:
            return ''.join('0123456789abcdef'[v] for v in n)
        return '_'.join(NUM_FMT % v for v in n)

    root, rrem, (x, step) = esqrt(x, N)

    # per-step values for the main column
    mrem = [s[3] for s in step]
    msub = [s[4] for s in step]

    # per-step values for the side column
    sans = [s[0][-1] for s in step]
    srem = [s[1] for s in step]
    ssub = [[0]] + [s[2] for s in step]
    sdiff = [s[3] for s in step]
    nans = len(sans)

    wn1 = len(nfmt([0]))                # width of one digit
    wn2 = len(nfmt([0, 0])) - wn1       # extra width of each further digit

    def nwidth(column):
        # printed width of a number with `column` digits
        return wn1 + wn2 * (column - 1)

    def del0(n):
        # strip leading zero digits, keeping at least one digit
        while len(n) > 1 and n[0] == 0:
            n = n[1:]
        return n

    def dfmt(n):
        # digit list -> digit pairs separated by spaces (a lone leading digit is padded)
        nn = len(n)
        s = []
        if nn & 1:
            s = [(' ' * wn2) + nfmt(n[:1])]
            n = n[1:]
            nn -= 1
        s += [nfmt(n[i:i+2]) for i in range(0, nn, 2)]
        return ' '.join(s)

    # header of the main column: root, rule, x
    mr = ' '.join(' ' * wn2 + nfmt([n]) for n in step[-1][0])
    mx = dfmt(del0(x))
    ml = '-' * len(mx)

    cmain = [mr, ml, mx]
    for i in range(nans):
        wc = i + nwidth(2) * (i + 1)    # right edge of step i
        mfmt = '%%%ds' % wc

        dr = dfmt(mrem[i])
        ds = dfmt(msub[i])
        dc = len(ds)
        if i:
            cmain.append(mfmt % dr)
            dc = max(dc, len(dr))
        cmain.append(mfmt % ds)

        # the rule under every step but the last extends over the next digit pair
        if (i + 1) != nans:
            o = nwidth(2) + 1
            dc += o
            wc += o
        mfmt = '%%%ds' % wc
        cmain.append(mfmt % ('-' * dc))
    cmain.append(mfmt % dfmt(rrem[1]))
    mfmt = '%%-%ds' % len(cmain[0])
    # find the first column that is used by the header rows, and cut there
    for n in range(nwidth(2)):
        if sum(int(cmain[i][n] != ' ') for i in (0, 2, 3)):
            break
    cmain = ['  ' + (mfmt % s)[n:] for s in cmain]
    cmain[1] = '--' + cmain[1][2:]
    cmain[2] = ')' + cmain[2][1:]

    # side column: three lines (sum operand, digit, rule) per step
    csub = []
    wrs = nwidth(len(ssub[-1]) + 1)
    for i in range(nans):
        ws = nwidth(i + 2)
        sfmt = ('%%%ds' % ws)
        da = [sans[i]]
        sa = sfmt % nfmt(da)
        ss = sfmt % nfmt(del0(ssub[i] + da))
        sl = '-' * nwidth(i + 2 + int((i + 1) < nans))
        csub.append(ss + ' ' * (wrs - len(ss)))
        csub.append(sa + ' ' * (wrs - len(sa)))
        csub.append(sl + ' ' * (wrs - len(sl)))
    csub.append(sfmt % nfmt(del0(ssub[-1])))

    # cut unused leading columns of the side column as well
    for n in range(nwidth(2)):
        if sum(int(csub[i][n] != ' ') for i in (0, 1, 3)):
            break
    csub = [l[n:] for l in csub]

    mspc = ' ' * len(cmain[0])
    sspc = ' ' * len(csub[0])

    # the side column starts two lines below the main one; print both side by side
    csub = [sspc, sspc] + csub
    for i in range(max(len(cmain), len(csub))):
        m = mspc if i >= len(cmain) else cmain[i]
        s = sspc if i >= len(csub) else csub[i]
        print(m, ' ', s)
実行結果
>>> esqrt_format(11223344)  # デフォルトは10進数の開平法
   3  3  5  0       
-------------       
) 11 22 33 44   3   
   9            3   
  -----         --  
   2 22         63  
   1 89          3  
  --------      --- 
     33 33      665 
     33 25        5 
     --------   ----
         8 44   6700
            0      0
        -----   ----
         8 44   6700
>>> esqrt_format(11223344, 2)  # 2進数の開平法
   1  1  0  1  0  0  0  1  0  1  1  0                 
-------------------------------------                 
) 10 10 10 11 01 00 00 01 00 11 00 00    1            
   1                                     1            
  -----                                 ---           
   1 10                                 101           
   1 01                                   1           
  --------                              ----          
      1 10                              1100          
         0                                 0          
     --------                           -----         
      1 10 11                           11001         
      1 10 01                               1         
     -----------                        ------        
           10 01                        110100        
               0                             0        
           --------                     -------       
           10 01 00                     1101000       
                  0                           0       
           -----------                  --------      
           10 01 00 00                  11010000      
                     0                         0      
           --------------               ---------     
           10 01 00 00 01               110100001     
            1 10 10 00 01                       1     
           -----------------            ----------    
              10 10 00 00 00            1101000100    
                           0                     0    
              -----------------         -----------   
              10 10 00 00 00 11         11010001001   
               1 10 10 00 10 01                   1   
              --------------------      ------------  
                 11 01 11 10 10 00      110100010101  
                 11 01 00 01 01 01                 1  
                 --------------------   ------------- 
                       11 01 00 11 00   1101000101100 
                                    0               0 
                       --------------   ------------- 
                       11 01 00 11 00   1101000101100
>>> esqrt_format(11223344, 16)  # 16進数の開平法
   d  1  6        
----------        
) ab 41 30    d   
  a9          d   
  -----      ---  
   2 41      1a1  
   1 a1        1  
  --------   ---- 
     a0 30   1a26 
     9c e4      6 
     -----   ---- 
      3 4c   1a2c

intsqrt(N) != bsqrt(N) のとき、esqrt_format(N,1<<30) で出てくる計算手順と違うところを探しました。


計算速度ならば gmpy2.isqrt を使った方が良いです。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0