Python/C API の PyLongObject 型用平方根(拡張モジュール)
Python での多倍長整数 int 型(C では PyLongObject 型) の平方根を整数の範囲で求めます。
64ビット版 Python では $2^{30}$ 進数(PyLongObject の 1 桁 (digit) の単位)の開平法で計算しています。32ビット版 Python では $2^{15}$ 進数になるはずですが、動作確認をしていないため、コンパイルできなかったり正しく動作しなかったりする可能性があります。
除算は C 言語の __int128(コンパイラ拡張)で扱える範囲に限定し、多倍長整数での除算は行っていないので、それなりの速さで計算できると思います。
intsqrt(x: int) -> int: int(√x)
intsqrt(x: int, True) -> tuple: (int(√x), x - int(√x)**2)
オプションで余りを取得できます。
ビルドと実行
distutils の setup.py でモジュールを作ります。
拡張モジュールとしては、メソッドが1つだけという簡単なものになってます。
動作確認している環境
OS
macOS 12.2.1
C コンパイラ (Xcode 13.2.1)
Apple clang version 13.0.0 (clang-1300.0.29.30)
Target: x86_64-apple-darwin21.3.0
Python : /usr/bin/python3
Python 3.8.9 (default, Oct 26 2021, 07:25:54)
[Clang 13.0.0 (clang-1300.0.29.30)] on darwin
intsqrt モジュール (intsqrt.c と setup.py : 折りたたみ)
/* -*- coding: utf-8; -*- */
#ifndef PY_SSIZE_T_CLEAN
#define PY_SSIZE_T_CLEAN
#endif /* PY_SSIZE_T_CLEAN */
#include <Python.h>
#include <stdint.h>
/*
*
*/
/*
 * Integer types used by the square-root code, named by how many CPython
 * digits (PyLong_SHIFT bits each) they hold:
 *   digit1_t / sdigit1_t  one digit   (CPython's digit / sdigit)
 *   digit2_t / sdigit2_t  two digits  (CPython's twodigits / stwodigits)
 *   digit4_t              four digits, for the 4-by-2-digit divisions
 */
typedef digit digit1_t;
typedef sdigit sdigit1_t;
typedef twodigits digit2_t;
typedef stwodigits sdigit2_t;
#if defined(__GNUC__)
# if PyLong_SHIFT == 15
/* 4 x 15 = 60 bits: a signed 64-bit integer is enough.  int64_t is the
 * portable spelling (__int64_t is a libc-internal typedef). */
typedef int64_t digit4_t;
# elif PyLong_SHIFT == 30
/* 4 x 30 = 120 bits need the GCC/Clang 128-bit integer extension, which
 * is only available where the compiler defines __SIZEOF_INT128__. */
#  if !defined(__SIZEOF_INT128__)
#   error "PyLong_SHIFT == 30 requires compiler support for __int128"
#  endif
typedef __int128_t digit4_t;
# else
#  error "unknown PyLong_SHIFT"
# endif
/* Quotient and remainder of one digit4_t division. */
typedef struct digit4_div_t {
    digit4_t quot;
    digit4_t rem;
} digit4_div_t;
#else
# error "unknown compiler"
#endif
/*
*
*/
/* Take a new strong reference to the given object and return it. */
static inline PyObject *
IncRef(PyObject *obj)
{
    Py_INCREF(obj);
    return obj;
}
/*
 * Allocate an int object with room for `len` digits, all set to zero.
 *
 * ob_size is set to `len`; a PyLong whose top digit is zero is not
 * normalised, so callers are expected to store a non-zero most significant
 * digit or shrink ob_size before the object is exposed to Python.
 * Returns a new reference, or NULL with an exception set on failure.
 */
static inline PyLongObject *
NewLong(Py_ssize_t len)
{
    /* PyObject_NewVar() already initialises the header (type, refcount,
     * ob_size).  Calling PyObject_InitVar() on it again would run
     * _Py_NewReference() a second time, double-registering the object in
     * Py_TRACE_REFS / Py_REF_DEBUG builds. */
    PyLongObject *obj = PyObject_NewVar(PyLongObject, &PyLong_Type, len);
    if (obj == NULL)
        return NULL;
    memset(obj->ob_digit, 0, (size_t) len * sizeof(digit1_t));
    return obj;
}
/*
 * Number of significant bits of the int x, which must have at least one
 * digit and a non-zero most significant digit: every lower digit counts
 * PyLong_SHIFT bits, plus the bit length of the top digit.
 * Defines INTSQRT_FAST_BIT_LENGTH when the compiler builtin is used.
 */
static Py_ssize_t
bit_length(PyLongObject *x)
{
    const Py_ssize_t ndigits = Py_SIZE(x);
    digit1_t top = x->ob_digit[ndigits - 1];
    int topbits;

#if defined(__GNUC__)
#define INTSQRT_FAST_BIT_LENGTH
    topbits = 32 - __builtin_clz(top);
#else
    /* Portable fallback: binary search for the highest set bit. */
    {
        int shift;
        topbits = 1;
        for (shift = (int) sizeof(digit1_t) * 4; shift > 0; shift >>= 1) {
            digit1_t upper = top >> shift;
            if (upper) {
                top = upper;
                topbits += shift;
            }
        }
    }
#endif
    return (ndigits - 1) * PyLong_SHIFT + topbits;
}
/*
 * Bit count used to seed the one/two-digit square roots: exact when the
 * fast bit_length() is compiled in, otherwise the upper bound
 * Py_SIZE(x) * PyLong_SHIFT (whole digits).
 */
static inline int
lazy_bit_length(PyLongObject *x)
{
#ifndef INTSQRT_FAST_BIT_LENGTH
    return Py_SIZE(x) * PyLong_SHIFT;
#else
    return bit_length(x);
#endif
}
/*
* Method: intsqrt
*/
/*
 * INTSQRT_xDIGIT(intsqrt_digit, digit) defines, for an unsigned type `digit`,
 *
 *   typedef struct { digit root; digit rem; } intsqrt_digit##_result;
 *   static inline intsqrt_digit##_result intsqrt_digit(digit x, int bits);
 *
 * The function returns root = floor(sqrt(x)) and rem = x - root*root using
 * the binary digit-by-digit method.  `bits` is the bit length of x or an
 * upper bound on it (lazy_bit_length() may pass either).
 */
#define INTSQRT_xDIGIT(intsqrt_digit, digit)                              \
                                                                          \
typedef struct intsqrt_digit##_result {                                   \
    digit root;                                                           \
    digit rem;                                                            \
} intsqrt_digit##_result;                                                 \
                                                                          \
static inline intsqrt_digit##_result                                      \
intsqrt_digit(digit x, int bits)                                          \
{                                                                         \
    intsqrt_digit##_result res;                                           \
    digit s, t, u, v, y;                                                  \
    int b;                                                                \
                                                                          \
    s = 0;                                                                \
    y = 0;                                                                \
    /* Shift a `digit` one, not an int one: an int `1 << b` overflows      \
     * (undefined behaviour) once b reaches 31 on a 32-bit int. */        \
    for (b = bits >> 1, t = (digit) 1 << b; b >= 0; b--, t >>= 1) {       \
        y <<= 1;                                                          \
        u = s + t;                                                        \
        v = u << b;                                                       \
        if (x < v)                                                        \
            continue;                                                     \
        x -= v;                                                           \
        y |= 1;                                                           \
        s = t + u;                                                        \
    }                                                                     \
    res.root = y;                                                         \
    res.rem = x;                                                          \
    return res;                                                           \
}

INTSQRT_xDIGIT(intsqrt_digit1, digit1_t)
INTSQRT_xDIGIT(intsqrt_digit2, digit2_t)
/* Shift widths and masks of one-digit and two-digit quantities. */
#define DIGIT1_SHIFT PyLong_SHIFT
#define DIGIT2_SHIFT (DIGIT1_SHIFT * 2)
#define DIGIT1_MASK PyLong_MASK
#define DIGIT2_MASK (((digit2_t) DIGIT1_MASK << DIGIT1_SHIFT) | DIGIT1_MASK)
/* Radix of one / two digits (the corresponding mask plus one). */
#define DIGIT1_MAX (DIGIT1_MASK + 1)
#define DIGIT2_MAX (DIGIT2_MASK + 1)
/*
*
*/
/*
 * Working state of intsqrt_large().  The digit pointers are cursors into
 * little-endian digit arrays; they start near the most significant end and
 * move towards index 0 as intsqrt_large_next() advances one root digit.
 */
typedef struct intsqrt_large_state {
    digit1_t *w1;           /* low end of w = 2*root; w1[0] receives the
                             * trial digit during a step */
    digit1_t *x1;           /* next input digits of x to bring down */
    digit1_t *y1;           /* slot for the root digit being produced */
    digit1_t *z1, *z2;      /* current remainder / scratch buffer that a
                             * trial subtraction writes into */
    unsigned long z_swap;   /* number of z1/z2 exchanges; odd means the
                             * remainder ended in the buffer allocated as z2 */
    Py_ssize_t wsize, wlen; /* allocated / in-use digits of w */
    Py_ssize_t xsize;       /* number of digits of x */
    Py_ssize_t ysize, ylen; /* allocated / produced digits of the root */
    Py_ssize_t zsize, zlen; /* allocated / significant digits of remainder */
} intsqrt_large_state;
/* Join a high and a low digit into one two-digit value, h*B + l. */
static inline digit2_t
intsqrt_digit2_t(digit1_t h, digit1_t l)
{
    digit2_t hi = h;
    return (hi << DIGIT1_SHIFT) | l;
}
/* Read the two-digit value stored little-endian at p[0] (low), p[1] (high). */
static inline digit2_t
intsqrt_digit2_get(digit1_t *p)
{
    const digit1_t lo = p[0];
    const digit1_t hi = p[1];
    return ((digit2_t) hi << DIGIT1_SHIFT) | lo;
}
/* Two-digit value p[1]*B, i.e. p[1..0] with the low digit taken as zero. */
static inline digit2_t
intsqrt_digit2_get_high(digit1_t *p)
{
    digit2_t hi = p[1];
    return hi << DIGIT1_SHIFT;
}
/* Store a two-digit value little-endian: low digit in p[0], the rest in p[1]. */
static inline void
intsqrt_digit2_set(digit1_t *p, digit2_t d)
{
    digit1_t lo = (digit1_t) (d & DIGIT1_MASK);
    digit1_t hi = (digit1_t) (d >> DIGIT1_SHIFT);
    p[0] = lo;
    p[1] = hi;
}
/* Join two two-digit halves into one four-digit value, h*B^2 + l. */
static inline digit4_t
intsqrt_digit2to4(digit2_t h, digit2_t l)
{
    digit4_t wide = (digit4_t) h;
    wide <<= DIGIT2_SHIFT;
    return wide | l;
}
/* Divide num by den, returning quotient and remainder together. */
static inline digit4_div_t
intsqrt_digit4_div(digit4_t num, digit4_t den)
{
    digit4_div_t out = { .quot = num / den, .rem = num % den };
    return out;
}
/*
 * Normalise after adding to p[0]: move any bits above DIGIT1_MASK upward as
 * a carry, visiting at most `len` digits and stopping as soon as the carry
 * is absorbed.  Returns 1 if a carry is still pending after the last digit
 * visited, otherwise 0.
 */
static inline int
intsqrt_adjust_after_add1(digit1_t *p, Py_ssize_t len)
{
    digit1_t carry = 0;
    Py_ssize_t i = 0;

    for (;;) {
        digit1_t sum = p[i] + carry;
        carry = sum >> DIGIT1_SHIFT;
        p[i] = sum & DIGIT1_MASK;
        i++;
        if (carry == 0 || i >= len)
            break;
    }
    return carry != 0;
}
/*
 * Normalise `len` digits after a digit-wise subtraction.  A digit that went
 * "negative" (wrapped) is read as signed; its bits above DIGIT1_SHIFT
 * become a signed borrow for the next digit.  All `len` digits are
 * visited.  Returns 1 if a borrow is left after the top digit (the
 * difference was negative), otherwise 0.
 */
static inline int
intsqrt_adjust_after_sub(digit1_t *p, Py_ssize_t len)
{
    digit1_t *q = p;
    digit1_t *const end = p + len;
    digit1_t borrow = 0;

    do {
        digit1_t v = *q + borrow;
        borrow = (digit1_t) ((sdigit1_t) v >> DIGIT1_SHIFT);
        *q = v & DIGIT1_MASK;
    } while (++q < end);
    return borrow != 0 ? 1 : 0;
}
/*
 * Advance all cursors by one root digit.  w and the root grow by one
 * digit, the remainder by two, and x moves to its next pair of digits.
 */
static inline void
intsqrt_large_next(intsqrt_large_state *s)
{
    s->x1 -= 2;
    s->w1--;
    s->y1--;
    s->z1 -= 2;
    s->z2 -= 2;
    s->wlen++;
    s->ylen++;
    s->zlen += 2;
}
/* Drop leading (most significant) zero digits from the remainder length. */
static inline void
intsqrt_large_fix_zlen(intsqrt_large_state *s)
{
    Py_ssize_t n = s->zlen;
    while (n > 0 && s->z1[n - 1] == 0)
        n--;
    s->zlen = n;
}
/*
 * Start the digit-by-digit square root with the leading group of x.
 *
 * With an even digit count the top two digits form the first group;
 * otherwise the single top digit does.  Its integer square root r becomes
 * the first root digit (y1[0]).  2*r is stored in w1[0..1], and the
 * remainder goes to z1[0] with z1[1] = 0.
 */
inline static void
intsqrt_large_first(intsqrt_large_state *s)
{
    intsqrt_digit2_result msr;
    digit2_t msz;
    int hbits;
    /* Consume two digits of x (even count) or one (odd count). */
    if ((s->xsize & 1) == 0) {
        s->x1 -= 2;
        msz = intsqrt_digit2_get(s->x1);
        hbits = DIGIT2_SHIFT;
    }
    else {
        s->x1 -= 1;
        msz = s->x1[0];
        hbits = DIGIT1_SHIFT;
    }
    msr = intsqrt_digit2(msz, hbits);
    /* w = 2*r; if its high digit is zero, w occupies only one digit. */
    intsqrt_digit2_set(s->w1, (msr.root << 1));
    if (s->w1[1] == 0)
        s->wlen--;
    s->y1[0] = msr.root;
    /* The remainder can be as large as 2*r, which may exceed DIGIT1_MASK,
     * so z1[0] can temporarily hold an over-wide digit.  z1[1] is 0, so a
     * later intsqrt_digit2_get() on these two digits still yields the exact
     * value. */
    s->z1[0] = msr.rem;
    s->z1[1] = 0;
}
/*
 * Compute the second root digit (runs once, right after the first).
 *
 * On entry, after intsqrt_large_next():
 *   - w1[1..2] hold 2*r, where r is the first root digit;
 *   - w1[0] is the empty slot for the new digit;
 *   - z1[2..3] hold the remainder left by intsqrt_large_first().
 * The next two input digits are brought down into z1[0..1].  The new digit
 * d is the largest value in [0, DIGIT1_MASK] with (2*r*B + d) * d <= z,
 * where B = 2**DIGIT1_SHIFT and z is the four-digit remainder z1[0..3].
 * On return:
 *   - z1[0..3] holds z - (2*r*B + d) * d as normalised digits;
 *   - y1[0] = d;
 *   - w has been increased by 2*d.
 */
inline static void
intsqrt_large_second(intsqrt_large_state *s)
{
    digit2_t msw, ssw;
    digit2_t msz, ssz;
    digit2_t ws, we, wp;
    digit4_t w4s, w4e;
    digit4_t w4;
    digit4_t z4;

    /* Bring down the next two digits of x. */
    s->z1[0] = s->x1[0];
    s->z1[1] = s->x1[1];
    /* (msw, ssw) are the high/low halves of 2*r*B; z4 is the remainder. */
    msw = s->w1[2];
    ssw = intsqrt_digit2_get_high(s->w1);
    msz = intsqrt_digit2_get(s->z1 + 2);
    ssz = intsqrt_digit2_get(s->z1 + 0);
    w4s = intsqrt_digit2to4(msw, ssw + DIGIT1_MASK);
    w4e = intsqrt_digit2to4(msw, ssw + 1);
    z4 = intsqrt_digit2to4(msz, ssz);
    /* Bracket d:  z / (2rB + MASK) <= d <= z / (2rB + 1). */
    ws = intsqrt_digit4_div(z4, w4s).quot;
    we = intsqrt_digit4_div(z4, w4e).quot;
    if (we == 0)
        return;
    if (ws >= DIGIT1_MASK) {
        /* Even the largest possible digit fits. */
        ws = we = DIGIT1_MASK;
        z4 -= w4s * DIGIT1_MASK;
    }
    else {
        /* A digit never exceeds DIGIT1_MASK.  Clamping also keeps
         * ssw + wp below B^2, so intsqrt_digit2to4() never loses a carry. */
        if (we > DIGIT1_MASK)
            we = DIGIT1_MASK;
        /* Invariant: ws is feasible ((2rB + ws) * ws <= z4) and d <= we. */
        while ((we - ws) > 1) {
            wp = (ws + we) >> 1;
            w4 = intsqrt_digit2to4(msw, ssw + wp) * wp;
            /* w4 == z4 is an exact fit, so wp is feasible and must not be
             * excluded; only w4 > z4 rules wp out. */
            if (w4 <= z4)
                ws = wp;
            else
                we = wp - 1;
        }
        w4s = intsqrt_digit2to4(msw, ssw + ws) * ws;
        if (ws != we) {
            w4e = intsqrt_digit2to4(msw, ssw + we) * we;
            if (z4 >= w4e) {
                ws = we;
                w4s = w4e;
            }
        }
        z4 -= w4s;
    }
    /* Split the new remainder into two-digit halves.  The remainder can
     * reach 2**61, so the low half must be masked.  Otherwise bits >= 60
     * would be written into z1[1] as an over-wide digit and counted a
     * second time through msz in z1[2]. */
    msz = (digit2_t) (z4 >> DIGIT2_SHIFT);
    ssz = (digit2_t) z4 & DIGIT2_MASK;
    /* w += 2*d: low digit of 2*d into w1[0], its carry into w1[1]. */
    s->w1[0] = (ws * 2) & DIGIT1_MASK;
    s->w1[1] += (ws >> (DIGIT1_SHIFT - 1));
    s->y1[0] = ws;
    intsqrt_digit2_set(s->z1 + 0, ssz);
    intsqrt_digit2_set(s->z1 + 2, msz);
}
/* Exchange the roles of the two remainder buffers and count the exchange. */
static inline void
intsqrt_large_step_zswap(intsqrt_large_state *s)
{
    digit1_t *const prev = s->z1;
    s->z_swap += 1;
    s->z1 = s->z2;
    s->z2 = prev;
}
/*
 * Trial subtraction for the candidate root digit k = w1[0]:
 *     z2 = z1 - w * k
 * Here w (digits w1[0..wlen-1], k already in the low digit) equals
 * 2*root*B + k, so w * k is what a digit k removes from the remainder.
 *
 * Only zlen - wlen == 0 or 1 is handled: the top one or two digits of z2
 * are written explicitly, and the lower wlen - 1 digits in the loop.
 * z1 is only read.  Returns 1 if the difference would be negative.  In the
 * two early exits z2 is not written at all; otherwise z2 then holds
 * garbage.  Returns 0 when z2[0..zlen-1] holds the normalised difference.
 */
inline static int
intsqrt_large_step_zsub1(intsqrt_large_state *s)
{
    Py_ssize_t wlen = s->wlen;
    Py_ssize_t zlen = s->zlen;
    Py_ssize_t dlen;
    Py_ssize_t i;
    digit1_t k;
    digit2_t wh, wl;
    dlen = zlen - wlen;
    k = s->w1[0];
    /* Top digit of w times k, split into a high and a low digit. */
    wl = (digit2_t) s->w1[wlen - 1] * k;
    wh = wl >> DIGIT1_SHIFT;
    wl &= DIGIT1_MASK;
    if (dlen == 0) {
        /* Same length: a carry out of the top product digit already makes
         * w * k longer than z. */
        if (wh != 0)
            return 1;
        s->z2[zlen - 1] = s->z1[zlen - 1] - wl;
    }
    else {
        /* z has one extra digit; a negative top digit proves z < w * k. */
        wh = s->z1[zlen - 1] - wh;
        if ((sdigit1_t) wh < 0)
            return 1;
        wl = s->z1[zlen - 2] - wl;
        s->z2[zlen - 1] = wh;
        s->z2[zlen - 2] = wl;
    }
    /* Lower digits: subtract w1[i] * k plus the running carry.  Borrows are
     * left in the (wrapped) digits for intsqrt_adjust_after_sub(). */
    wh = 0;
    for (i = 0; i < (wlen - 1); i++) {
        wl = (digit2_t) s->w1[i] * k + wh;
        wh = wl >> DIGIT1_SHIFT;
        wl &= DIGIT1_MASK;
        s->z2[i] = s->z1[i] - wl;
    }
    /* The final carry lands on digit wlen - 1, already written above. */
    s->z2[i] -= wh;
    return intsqrt_adjust_after_sub(s->z2, s->zlen);
}
/*
 * Try raising the current trial digit k = w1[0] by one.
 *
 * z1 holds the remainder for digit k.  Going from k to k + 1 additionally
 * removes (2*root*B + k + 1)*(k + 1) - (2*root*B + k)*k = 2*root*B + 2k + 1.
 * So this computes z2 = z1 - (value of w with its low digit replaced by
 * 2k + 1).
 * Returns 1 if the result would be negative (z2 must then be ignored),
 * otherwise 0 with the normalised difference in z2[0..zlen-1].
 */
inline static int
intsqrt_large_step_zsub2(intsqrt_large_state *s)
{
    Py_ssize_t wlen = s->wlen;
    Py_ssize_t zlen = s->zlen;
    Py_ssize_t dlen;
    Py_ssize_t i;
    digit1_t kl, kh;
    dlen = zlen - wlen;
    if (dlen == 0) {
        /* Same length: compare the top digits first for a quick reject. */
        for (i = wlen - 1; i >= 2; i--) {
            if (s->z1[i] < s->w1[i])
                return 1;
            if (s->z1[i] > s->w1[i])
                break;
        }
    }
    /* 2k + 1 may carry into the next digit. */
    kl = s->w1[0] * 2 + 1;
    kh = kl >> DIGIT1_SHIFT;
    kl &= DIGIT1_MASK;
    s->z2[0] = s->z1[0] - kl;
    s->z2[1] = s->z1[1] - kh - s->w1[1];
    for (i = 2; i < wlen; i++)
        s->z2[i] = s->z1[i] - s->w1[i];
    /* Digits of z above w are copied unchanged. */
    for (; i < zlen; i++)
        s->z2[i] = s->z1[i];
    return intsqrt_adjust_after_sub(s->z2, zlen);
}
/*
 * Produce one more root digit (the third and later ones).
 *
 * Brings the next two input digits down into z1[0..1].  Estimates the digit
 * by dividing the top of z by the top two digits of w: a two-by-two digit
 * division when z and w have the same length, three-by-two when z is one
 * digit longer.  The estimate is then settled with trial subtractions:
 *   - estimate > DIGIT1_MASK: the digit is DIGIT1_MASK;
 *   - estimate == 1: the digit is 1 if z - w*1 is non-negative, else 0;
 *   - otherwise: start from estimate - 1 and add one (at most twice) while
 *     the remainder stays non-negative.
 * When a digit is accepted:
 *   - the new remainder is swapped into z1;
 *   - the digit is stored in y1[0];
 *   - w is increased by 2*digit.
 * If z has fewer digits than w, or the estimate is 0, the digit stays 0 and
 * z1 is left as is.
 *
 * NOTE(review): the return value of intsqrt_large_step_zsub1() is ignored
 * in the "> DIGIT1_MASK" and general branches, i.e. those paths assume the
 * starting digit never overshoots.  Likewise, estimates of 0 and 1 are never
 * tried one higher, although the general branch allows for an estimate that
 * is one too low.  Confirm these bounds on the estimate.
 */
inline static void
intsqrt_large_step(intsqrt_large_state *s)
{
    digit2_t msw, ssw;
    digit2_t msz, ssz;
    digit2_t zdw;
    digit4_t w4;
    digit4_t z4;
    Py_ssize_t dlen;
    /* Bring down the next pair of input digits. */
    s->z1[0] = s->x1[0];
    s->z1[1] = s->x1[1];
    dlen = s->zlen - s->wlen;
    /* z has fewer digits than w: no non-zero digit fits. */
    if (dlen < 0)
        return;
    msw = 0; /* assigned but not used below */
    ssw = intsqrt_digit2_get(s->w1 + s->wlen - 2);
    zdw = 0;
    if (dlen == 0) {
        ssz = intsqrt_digit2_get(s->z1 + s->zlen - 2);
        msz = 0;
        zdw = ssz / ssw;
    }
    else {
        ssz = intsqrt_digit2_get(s->z1 + s->zlen - 3);
        msz = s->z1[s->zlen - 1];
        z4 = intsqrt_digit2to4(msz, ssz);
        w4 = intsqrt_digit4_div(z4, ssw).quot;
        zdw = (digit2_t) w4;
    }
    if (zdw == 0)
        return;
    if (zdw > DIGIT1_MASK) {
        s->w1[0] = zdw = DIGIT1_MASK;
        intsqrt_large_step_zsub1(s);
        intsqrt_large_step_zswap(s);
    }
    else if (zdw == 1) {
        s->w1[0] = zdw;
        if (intsqrt_large_step_zsub1(s)) {
            /* 1 does not fit: the digit is 0 and z1 stays unchanged. */
            s->w1[0] = 0;
            return;
        }
        intsqrt_large_step_zswap(s);
    }
    else {
        /* Start one below the estimate, then try +1 twice. */
        s->w1[0] = --zdw;
        intsqrt_large_step_zsub1(s);
        intsqrt_large_step_zswap(s);
        if (!intsqrt_large_step_zsub2(s)) {
            s->w1[0] = ++zdw;
            intsqrt_large_step_zswap(s);
            if (!intsqrt_large_step_zsub2(s)) {
                s->w1[0] = ++zdw;
                intsqrt_large_step_zswap(s);
            }
        }
    }
    s->y1[0] = zdw;
    /* w1[0] already holds the digit; adding it again gives 2*digit. */
    s->w1[0] += zdw;
    intsqrt_adjust_after_add1(s->w1, s->wlen);
}
/*
 * Integer square root for x with three or more digits (called from
 * intsqrt()), computed digit by digit in base B = 2**PyLong_SHIFT.
 *
 * Returns a new reference to floor(sqrt(x)) or, when remobj is true, to the
 * tuple (root, x - root**2).  Returns NULL with an exception set if an
 * allocation fails.  Reads x->ob_digit directly, so it depends on the
 * PyLongObject layout that keeps ob_digit and the digit count in ob_size
 * (CPython <= 3.11).
 */
static PyObject *
intsqrt_large(PyLongObject *x, int remobj)
{
    PyObject *retval = NULL;
    PyObject *tuple = NULL;
    PyLongObject *w1 = NULL;    /* 2 * root, grown one digit per step */
    PyLongObject *x1 = x;
    PyLongObject *y1 = NULL;    /* the root */
    PyLongObject *z1 = NULL;    /* remainder; z1/z2 alternate as scratch */
    PyLongObject *z2 = NULL;
    PyLongObject *zp = NULL;
    intsqrt_large_state st, *s = &st;

    /* x is consumed in pairs of digits.  The remainder buffers are x's size
     * rounded up to even, the root gets one digit per pair, and w one more
     * digit than the root. */
    s->xsize = Py_SIZE(x);
    s->zsize = (s->xsize + 1) & ~1;
    s->ysize = s->zsize >> 1;
    s->wsize = s->ysize + 1;
    if (!(w1 = NewLong(s->wsize))) goto error;
    if (!(y1 = NewLong(s->ysize))) goto error;
    if (!(z1 = NewLong(s->zsize))) goto error;
    if (!(z2 = NewLong(s->zsize))) goto error;
    /* All cursors start at the most significant end of their buffers. */
    s->wlen = 2;
    s->ylen = 1;
    s->zlen = 2;
    s->w1 = w1->ob_digit + s->wsize - s->wlen;
    s->x1 = x1->ob_digit + s->xsize;
    s->y1 = y1->ob_digit + s->ysize - s->ylen;
    s->z1 = z1->ob_digit + s->zsize - s->zlen;
    s->z2 = z2->ob_digit + s->zsize - s->zlen;
    s->z_swap = 0;
    intsqrt_large_first(s);
    intsqrt_large_fix_zlen(s);
    if (s->ylen < s->ysize) {
        intsqrt_large_next(s);
        intsqrt_large_second(s);
        intsqrt_large_fix_zlen(s);
        while (s->ylen < s->ysize) {
            intsqrt_large_next(s);
            intsqrt_large_step(s);
            intsqrt_large_fix_zlen(s);
        }
    }
    if (!remobj) {
        retval = (PyObject *) y1;
        goto success1;
    }
    /* After an odd number of swaps the remainder lives in the buffer that
     * was allocated as z2. */
    if ((s->z_swap & 1)) {
        zp = z1;
        z1 = z2;
        z2 = zp;
    }
    /* Trim the remainder to its significant digits.  Py_SIZE() can no
     * longer be assigned to since Python 3.11; Py_SET_SIZE() exists since
     * 3.9. */
#ifdef Py_SET_SIZE
    Py_SET_SIZE(z1, s->zlen);
#else
    Py_SIZE(z1) = s->zlen;
#endif
    if (!(tuple = PyTuple_New(2)))
        goto error;
    /* The tuple steals both references. */
    PyTuple_SET_ITEM(tuple, 0, (PyObject *) y1);
    PyTuple_SET_ITEM(tuple, 1, (PyObject *) z1);
    retval = tuple;
    goto success2;
error:
    Py_XDECREF(tuple);
    Py_XDECREF(y1);
success1:
    /* The remainder is not returned on this path. */
    Py_XDECREF(z1);
success2:
    Py_XDECREF(w1);
    Py_XDECREF(z2);
    return retval;
}
/*
 * Square root of the non-negative exact int x (callers reject negatives).
 *
 * Returns floor(sqrt(x)) as an int or, when remobj is true, the tuple
 * (root, x - root**2).  Values of one or two digits are solved directly;
 * anything longer is handed to intsqrt_large().  Returns NULL with an
 * exception set on failure.
 */
static PyObject *
intsqrt(PyLongObject *x, int remobj)
{
    PyObject *root = NULL;
    PyObject *rem = NULL;
    PyObject *pair;

    switch (Py_SIZE(x)) {
    case 0:
        /* sqrt(0) = 0 remainder 0: one object serves as both. */
        if ((root = PyLong_FromUnsignedLong(0)) == NULL)
            goto fail;
        if (remobj)
            rem = IncRef(root);
        break;
    case 1: {
        int nbits = lazy_bit_length(x);
        intsqrt_digit1_result r1 = intsqrt_digit1(x->ob_digit[0], nbits);
        if ((root = PyLong_FromUnsignedLong(r1.root)) == NULL)
            goto fail;
        if (remobj && (rem = PyLong_FromUnsignedLong(r1.rem)) == NULL)
            goto fail;
        break;
    }
    case 2: {
        int nbits = lazy_bit_length(x);
        digit2_t value = intsqrt_digit2_get(x->ob_digit);
        intsqrt_digit2_result r2 = intsqrt_digit2(value, nbits);
        if ((root = PyLong_FromUnsignedLongLong(r2.root)) == NULL)
            goto fail;
        if (remobj && (rem = PyLong_FromUnsignedLong(r2.rem)) == NULL)
            goto fail;
        break;
    }
    default:
        return intsqrt_large(x, remobj);
    }

    if (!remobj)
        return root;
    if ((pair = PyTuple_New(2)) == NULL)
        goto fail;
    PyTuple_SET_ITEM(pair, 0, root);
    PyTuple_SET_ITEM(pair, 1, rem);
    return pair;

fail:
    Py_XDECREF(root);
    Py_XDECREF(rem);
    return NULL;
}
/*
 * Python entry point: intsqrt(x[, with_remainder]) (METH_FASTCALL).
 *
 * x is converted through nb_int (__int__), repeatedly if needed (at most
 * 256 times), until an exact int results.  Any failure there raises
 * TypeError(x).  A negative x raises ArithmeticError.  If the optional
 * second argument is true under Python's truth test, the result is
 * (root, x - root**2); otherwise it is root.
 */
static PyObject *
method_intsqrt(PyObject *unused, PyObject *const *args, Py_ssize_t size)
{
    PyObject *x, *y;
    PyNumberMethods *as_number;
    unaryfunc nb_int;
    int t, remobj;
    (void)unused;
    remobj = 0;
    if (size < 1) {
        PyErr_SetString(PyExc_TypeError, "intsqrt expected at least 1 argument, got 0");
        return NULL;
    }
    if (size > 2) {
        PyErr_SetString(PyExc_TypeError, "intsqrt expected at most 2 arguments");
        return NULL;
    }
    if (size == 2) {
        /* Use the standard truth test (bool, __len__, ...).  Calling
         * nb_bool directly treated its -1 error result as "true" and went
         * on with an exception pending, and it ignored types without
         * nb_bool. */
        remobj = PyObject_IsTrue(args[1]);
        if (remobj < 0)
            return NULL;
    }
    x = IncRef(args[0]);
    /* Follow __int__ until an exact int appears; the bound stops a
     * pathological conversion chain from looping forever. */
    for (t = 0;; t++) {
        if (PyLong_CheckExact(x))
            break;
        if (t >= 256)
            goto type_error;
        if (!(as_number = Py_TYPE(x)->tp_as_number))
            goto type_error;
        if (!(nb_int = as_number->nb_int))
            goto type_error;
        if (!(y = nb_int(x)))
            goto type_error;
        Py_DECREF(x);
        x = y;
    }
    if (Py_SIZE(x) < 0) {
        PyErr_SetString(PyExc_ArithmeticError, "(x < 0)");
        Py_DECREF(x);
        return NULL;
    }
    y = intsqrt((PyLongObject *) x, remobj);
    Py_DECREF(x);
    return y;
type_error:
    PyErr_SetObject(PyExc_TypeError, args[0]);
    Py_DECREF(x);
    return NULL;
}
/*
* Module: intsqrt
*/
static PyMethodDef intsqrt_methods[] = {
{"intsqrt", (PyCFunction) method_intsqrt, METH_FASTCALL,
"intsqrt(x: int) -> int: int(√x)\n"
"intsqrt(x: int, True) -> tuple: (int(√x), x - int(√x)**2)\n"},
{NULL, NULL, 0, NULL}, /* end */
};
/* Module definition for "intsqrt" (no docstring, m_size = -1). */
static PyModuleDef intsqrt_def = {
    PyModuleDef_HEAD_INIT,
    "intsqrt",          /* m_name */
    NULL,               /* m_doc */
    -1,                 /* m_size */
    intsqrt_methods,    /* m_methods */
};
PyMODINIT_FUNC
PyInit_intsqrt(void)
{
return PyModule_Create(&intsqrt_def);
}
/*
* Local Variables:
* c-file-style: "PEP7"
* End:
*/
#!/usr/bin/env python3
"""Build script for the intsqrt C extension.

Usage: python3 setup.py build
Set DEBUG=true (or yes) in the environment for an unoptimised build with
assert() enabled.
"""
import os

try:
    # distutils is deprecated since Python 3.10 and removed in 3.12
    # (PEP 632); setuptools provides compatible setup()/Extension.
    from setuptools import setup, Extension
except ImportError:
    from distutils.core import setup, Extension


def getenv(name, defval=None):
    """Return the environment variable `name`, or `defval` if it is unset."""
    return os.environ.get(name, defval)


DEBUG = getenv('DEBUG') in ('true', 'yes')

MAJOR_VERSION = 0
MINOR_VERSION = 1
DEBUG_VERSION = 0
VERSION = '%d.%d.%d' % (MAJOR_VERSION, MINOR_VERSION, DEBUG_VERSION)

# Passed to the compiler as -DMAJOR_VERSION=... etc.
DEFINE_MACROS = [
    ('MAJOR_VERSION', MAJOR_VERSION),
    ('MINOR_VERSION', MINOR_VERSION),
    ('DEBUG_VERSION', DEBUG_VERSION),
]
UNDEF_MACROS = []
EXTRA_COMPILE_ARGS = [
    '-W',
    '-Wall',
    '-Wno-invalid-offsetof',
    '-Wno-deprecated-declarations',
]

if DEBUG:
    # Debug build: define DEBUG, undefine NDEBUG (re-enables assert) and
    # turn optimisation off.
    DEFINE_MACROS.append(('DEBUG', 1))
    UNDEF_MACROS.append('NDEBUG')
    EXTRA_COMPILE_ARGS.append('-O0')

setup(name='intsqrt',
      version=VERSION,
      description='',
      ext_modules=[Extension(
          name='intsqrt',
          define_macros=DEFINE_MACROS,
          undef_macros=UNDEF_MACROS,
          extra_compile_args=EXTRA_COMPILE_ARGS,
          sources=['intsqrt.c'])])
$ python3 setup.py build
running build
running build_ext
building 'intsqrt' extension
creating build
creating build/temp.macosx-10.14-x86_64-3.8
clang -Wno-unused-result -Wsign-compare -Wunreachable-code -fno-common -dynamic -DNDEBUG -g -fwrapv -O3 -Wall -iwithsysroot/System/Library/Frameworks/System.framework/PrivateHeaders -iwithsysroot/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/Headers -arch arm64 -arch x86_64 -Werror=implicit-function-declaration -DMAJOR_VERSION=0 -DMINOR_VERSION=1 -DDEBUG_VERSION=0 -I/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/include/python3.8 -c intsqrt.c -o build/temp.macosx-10.14-x86_64-3.8/intsqrt.o -W -Wall -Wno-invalid-offsetof -Wno-deprecated-declarations
creating build/lib.macosx-10.14-x86_64-3.8
clang -bundle -undefined dynamic_lookup -arch arm64 -arch x86_64 -Wl,-headerpad,0x1000 build/temp.macosx-10.14-x86_64-3.8/intsqrt.o -o build/lib.macosx-10.14-x86_64-3.8/intsqrt.cpython-38-darwin.so
$
$ (cd build/lib.macosx-10.14-x86_64-3.8 ; python3)
>>> from intsqrt import intsqrt
>>> intsqrt(99)
9
>>> intsqrt(99, True)
(9, 18)
>>> 1<<201
3213876088517980551083924184682325205044405987565585670602752
>>> intsqrt(1<<201)
1792728671193156477399422023278
>>> r = intsqrt(1<<201, True)
>>> r
(1792728671193156477399422023278, 2371767103687091674094496737468)
>>> r[0]**2+r[1]
3213876088517980551083924184682325205044405987565585670602752
検証・確認に用いた Python プログラム
開平法と計算を可視化するプログラム(折りたたみ)
# 開平法 2進数
def bsqrt(x):
    """Integer square root by the binary digit-by-digit method.

    Returns (r, x - r*r) with r = floor(sqrt(x)).
    Raises ArithmeticError for negative x.
    """
    if x < 0:
        raise ArithmeticError
    l, r, s = x.bit_length() >> 1, 0, 0
    for b in range(l, -1, -1):
        r <<= 1
        t = 1 << b
        u = s + t
        v = u << b
        if x < v:
            continue
        x -= v
        r |= 1
        s = u + t
    return (r, x)


# Square root by the schoolbook digit-pair method in base N
def esqrt(x, N=10):
    """Integer square root of x by the schoolbook method in base N.

    Returns [[root, root_digits], [rem, rem_digits], [xdig, step]].
    - root_digits / rem_digits are base-N digits, most significant first.
    - xdig is x's digit list padded to even length.
    - step holds, for every root digit,
      [root digits, remainder digits, 2*root digits,
       remainder before the subtraction, amount subtracted].
    Raises ArithmeticError if x < 0 or N < 2.  As a self-check, the root is
    compared with bsqrt(x) and Exception is raised on a mismatch.
    """
    def to_dig(x):
        # Base-N digits of x, least significant first.
        dig = [x % N]
        c = x // N
        while c > 0:
            dig.append(c % N)
            c //= N
        return dig

    def to_rdig(x):
        # Base-N digits of x, most significant first.
        return list(reversed(to_dig(x)))

    if x < 0:
        raise ArithmeticError
    if N < 2:
        raise ArithmeticError
    NN = N * N
    NL = N - 1
    xdig = to_dig(x)
    if len(xdig) & 1:
        xdig.append(0)
    xdig = list(reversed(xdig))
    step = []

    def find(ns, ne, rsub, rrem):
        """Largest digit d (at most N-1) with (rsub + d) * d <= rrem.

        ns is a feasible lower bound and ne an upper bound.
        Returns (d, (rsub + d) * d).
        """
        if ne == 0:
            return (0, 0)
        if ns >= NL:
            return (NL, (rsub + NL) * NL)
        while (ne - ns) > 1:
            np = (ns + ne) >> 1
            pp = (rsub + np) * np
            # pp == rrem is an exact fit (remainder 0), so np is feasible;
            # only pp > rrem rules it out.
            if pp <= rrem:
                ns = np
            else:
                ne = np - 1
        ts = (rsub + ns) * ns
        if ns != ne:
            te = (rsub + ne) * ne
            if rrem >= te:
                ns = ne
                ts = te
        return (ns, ts)

    # First root digit, from the leading digit pair (binary method).
    d = xdig[:2]
    dig = xdig[2:]
    s = d[0] * N + d[1]
    root, rrem = bsqrt(s)
    rsub = root * 2
    step.append([to_rdig(root), to_rdig(rrem), to_rdig(rsub),
                 to_rdig(s), to_rdig(root**2)])
    if not dig:
        return [[root, to_rdig(root)], [rrem, to_rdig(rrem)], [xdig, step]]

    # Remaining digits: bring down one pair at a time.
    while dig:
        root *= N
        rrem *= NN
        rsub *= N
        d = dig[:2]
        dig = dig[2:]
        rrem += d[0] * N + d[1]
        prem = rrem
        # rrem // (rsub + N-1) <= next digit <= rrem // (rsub + 1)
        ns, ts = find(rrem // (rsub + NL), rrem // (rsub + 1), rsub, rrem)
        if ns:
            root += ns
            rrem -= ts
            rsub += ns * 2
        step.append([to_rdig(root), to_rdig(rrem), to_rdig(rsub),
                     to_rdig(prem), to_rdig(ts)])
    res = [[root, to_rdig(root)], [rrem, to_rdig(rrem)], [xdig, step]]
    if root != bsqrt(x)[0]:
        raise Exception('Bug!!: esqrt(%d,%d)' % (x, N))
    return res
# N進数開平法(筆算型)を可視化
def esqrt_format(x, N=10, xdec=False):
    """Print the schoolbook square-root layout for x in base N, using esqrt().

    The left column shows the long-division style work:
    - the root digits over the digit pairs of x;
    - for each step, the remainder brought down, the amount subtracted and
      a rule.
    The right column shows each step's trial divisor (2*root followed by
    the new digit) and finally 2*root.  For N <= 16 every digit prints as
    one hex character.  For larger N each digit is formatted with NUM_FMT
    (decimal, or hex when xdec is true) and digits are joined by '_'.
    Output goes to stdout; nothing is returned.
    """
    # NUM_FMT is only used by nfmt() when N > 16.
    if N < 10:
        NUM_FMT = '%%%dd' % (N - 1)
    elif N <= 16:
        NUM_FMT = '%%0%dx' % (N - 1)
    else:
        XDEC = 'dx'[int(xdec)]
        NUM_FMT = ('%%0%d' + XDEC) % len(('%' + XDEC) % (N - 1))
    def nfmt(n):
        # Format a list of digits (most significant first).
        if N <= 16:
            return ''.join('0123456789abcdef'[v] for v in n)
        return '_'.join(NUM_FMT % v for v in n)
    root, rrem, (x, step) = esqrt(x, N)
    # Per-step columns from esqrt(); srem and sdiff are not used below.
    mrem = [s[3] for s in step]
    msub = [s[4] for s in step]
    sans = [s[0][-1] for s in step]
    srem = [s[1] for s in step]
    ssub = [[0]] + [s[2] for s in step]
    sdiff = [s[3] for s in step]
    nans = len(sans)
    # Printed width of one digit and of each additional digit.
    wn1 = len(nfmt([0]))
    wn2 = len(nfmt([0, 0])) - wn1
    def nwidth(column):
        return wn1 + wn2 * (column - 1)
    def del0(n):
        # Strip leading zero digits, keeping at least one digit.
        while len(n) > 1 and n[0] == 0:
            n = n[1:]
        return n
    def dfmt(n):
        # Format digits in space-separated pairs, aligned to pair boundaries.
        nn = len(n)
        s = []
        if nn & 1:
            s = [(' ' * wn2) + nfmt(n[:1])]
            n = n[1:]
            nn -= 1
        s += [nfmt(n[i:i+2]) for i in range(0, nn, 2)]
        return ' '.join(s)
    # Left column: root row, rule, x row, then the per-step rows.
    mr = ' '.join(' ' * wn2 + nfmt([n]) for n in step[-1][0])
    mx = dfmt(del0(x))
    ml = '-' * len(mx)
    cmain = [mr, ml, mx]
    for i in range(nans):
        wc = i + nwidth(2) * (i + 1)
        mfmt = '%%%ds' % wc
        dr = dfmt(mrem[i])
        ds = dfmt(msub[i])
        dc = len(ds)
        if i:
            cmain.append(mfmt % dr)
            dc = max(dc, len(dr))
        cmain.append(mfmt % ds)
        if (i + 1) != nans:
            # Widen the rule to cover the next pair brought down.
            o = nwidth(2) + 1
            dc += o
            wc += o
            mfmt = '%%%ds' % wc
        cmain.append(mfmt % ('-' * dc))
    cmain.append(mfmt % dfmt(rrem[1]))
    mfmt = '%%-%ds' % len(cmain[0])
    # Find the first column that is non-blank in the root row, the x row or
    # the first subtraction row, and cut everything to its left.
    for n in range(nwidth(2)):
        if sum(int(cmain[i][n] != ' ') for i in (0, 2, 3)):
            break
    cmain = [' ' + (mfmt % s)[n:] for s in cmain]
    cmain[1] = '--' + cmain[1][2:]
    cmain[2] = ')' + cmain[2][1:]
    # Right column: trial divisor, new digit and a rule for each step.
    csub = []
    wrs = nwidth(len(ssub[-1]) + 1)
    for i in range(nans):
        ws = nwidth(i + 2)
        sfmt = ('%%%ds' % ws)
        da = [sans[i]]
        sa = sfmt % nfmt(da)
        ss = sfmt % nfmt(del0(ssub[i] + da))
        sl = '-' * nwidth(i + 2 + int((i + 1) < nans))
        csub.append(ss + ' ' * (wrs - len(ss)))
        csub.append(sa + ' ' * (wrs - len(sa)))
        csub.append(sl + ' ' * (wrs - len(sl)))
    csub.append(sfmt % nfmt(del0(ssub[-1])))
    for n in range(nwidth(2)):
        if sum(int(csub[i][n] != ' ') for i in (0, 1, 3)):
            break
    csub = [l[n:] for l in csub]
    # Print both columns side by side; the right one starts two rows lower.
    mspc = ' ' * len(cmain[0])
    sspc = ' ' * len(csub[0])
    csub = [sspc, sspc] + csub
    for i in range(max(len(cmain), len(csub))):
        m = mspc if i >= len(cmain) else cmain[i]
        s = sspc if i >= len(csub) else csub[i]
        print(m, ' ', s)
>>> esqrt_format(11223344) # デフォルトは10進数の開平法
3 3 5 0
-------------
) 11 22 33 44 3
9 3
----- --
2 22 63
1 89 3
-------- ---
33 33 665
33 25 5
-------- ----
8 44 6700
0 0
----- ----
8 44 6700
>>> esqrt_format(11223344, 2) # 2進数の開平法
1 1 0 1 0 0 0 1 0 1 1 0
-------------------------------------
) 10 10 10 11 01 00 00 01 00 11 00 00 1
1 1
----- ---
1 10 101
1 01 1
-------- ----
1 10 1100
0 0
-------- -----
1 10 11 11001
1 10 01 1
----------- ------
10 01 110100
0 0
-------- -------
10 01 00 1101000
0 0
----------- --------
10 01 00 00 11010000
0 0
-------------- ---------
10 01 00 00 01 110100001
1 10 10 00 01 1
----------------- ----------
10 10 00 00 00 1101000100
0 0
----------------- -----------
10 10 00 00 00 11 11010001001
1 10 10 00 10 01 1
-------------------- ------------
11 01 11 10 10 00 110100010101
11 01 00 01 01 01 1
-------------------- -------------
11 01 00 11 00 1101000101100
0 0
-------------- -------------
11 01 00 11 00 1101000101100
>>> esqrt_format(11223344, 16) # 16進数の開平法
d 1 6
----------
) ab 41 30 d
a9 d
----- ---
2 41 1a1
1 a1 1
-------- ----
a0 30 1a26
9c e4 6
----- ----
3 4c 1a2c
intsqrt(N) の結果が bsqrt(N)[0] と一致しない場合は、esqrt_format(N, 1<<30) が出力する計算手順と比較して、食い違う箇所を探しました。
計算速度を重視するなら gmpy2.isqrt を使った方が良いです(Python 3.8 以降であれば標準ライブラリの math.isqrt も使えます)。