1
0

More than 1 year has passed since last update.

x86: SSE4.2 文字列検索用 PCMPESTRI / PCMPESTRM / PCMPISTRI / PCMPISTRM 命令

Last updated at Posted at 2021-09-15

SSE4.2 にある、文字(列)検索に適した 4 つの命令

  PCMPESTRI , PCMPESTRM , PCMPISTRI , PCMPISTRM

の動作をマニュアルを見つつ調べてみます。

4つの PCMPxSTRy 命令

まとめて PCMP<x>STR<y> 命令とすると、アセンブリ言語では

  PCMP<x>STR<y>    xmm1, xmm2, imm8
  PCMP<x>STR<y>    xmm1, m128, imm8

の記述になっています。 xmm1 が検索内容で、xmm2/m128 が検索対象です。

<x> と <y> は入力形式(x:E または I)と出力形式(y:I または M)の指定です。

入力形式<x> と出力形式<y>

指定したオペランドとフラグ レジスタのほかに、入力形式 <x> と出力形式 <y> により決まったレジスタが使用されます。

- y = I : Index y = M : Mask 入力用レジスタ
x = E : Explicit PCMPESTRI PCMPESTRM EAX(RAX) と EDX(RDX)
x = I : Implicit PCMPISTRI PCMPISTRM なし
出力用レジスタ ECX(RCX) XMM0 -

文字の種類と数

オペランドに指定する xmm1 および xmm2/m128 のデータ形式は、第3オペランド imm8 の下位 2 ビット

  imm8[0] = 文字のビット幅
  imm8[1] = 符号の有無

で指定します。

imm8 +0(8 ビット) +1(16 ビット)
+0(符号なし) 符号なし 8 ビット 16 文字
uint8_t [16]
符号なし 16 ビット 8 文字
uint16_t [8]
+2(符号あり) 符号あり 8 ビット 16 文字
int8_t [16]
符号あり 16 ビット 8 文字
int16_t [8]

以下、文字数(8 ビットなら 16 文字、16 ビットなら 8 文字)を PLEN とします。

有効な文字数

各オペランドと有効な文字数は以下のとおりです。

命令 \ オペランド xmm1 xmm2/m128 内容
x=E 形式
PCMPESTR<y>
EAX(RAX) EDX(RDX) レジスタによる指定
x=I 形式
PCMPISTR<y>
strlen(xmm1) strlen(xmm2/m128) NULL(を含まない)文字まで

レジスタによる指定では、負の値も絶対値として正の値に変換されるので有効です。実際に有効な文字数の上限は imm8 に指定した文字数 PLEN になります。

以下、有効な文字数を

  XLEN1 を第 1 オペランド xmm1 に対するもの
  XLEN2 を第 2 オペランド xmm2/m128 に対するもの

とします。

動作モード(Mode)

検索の目的を imm8 の第 2 および第 3 ビットで指定します。

Mode
imm8[3:2]
比較方法 内容
0 一致 xmm1 を文字セットとします
1 範囲 xmm1 を文字の範囲情報とします
2 一致 xmm1 と xmm2/m128 の配列一致
3 一致 xmm1 を部分文字列として検索します

比較方法は、"一致" と "範囲" の 2 種類になります。

比較判定表(BoolRes)

xmm2/m128 の 1 文字に対して xmm1 の PLEN 文字と比較し、比較判定表(PLEN×PLEN)が作られます。

以下の比較判定表では、横軸を xmm1、縦軸を xmm2/m128 とします。

比較方法:一致(eq)

BoolRes 1 文字目 PLEN 文字目
1 文字目 eq eq
PLEN 文字目 eq eq

比較方法:範囲(ge,le)

xmm1 は 2 文字で 1 つの範囲を構成します。よって、最大で (PLEN÷2) 個の範囲を指定できます。

BoolRes 1 文字目
範囲開始
2 文字目
範囲終了
PLEN-1 文字目
範囲開始
PLEN 文字目
範囲終了
1 文字目 ge le ge le
PLEN 文字目 ge le ge le

有効文字数による修正

xmm1 では XLEN1 文字より後を、xmm2/m128 では XLEN2 文字より後を固定の真偽値として比較判定表(BoolRes)が修正されます。(以下の表では「文字目」を省きます)

Mode 0

BoolRes 1 XLEN1 XLEN1+1 PLEN
1 eq eq false false
XLEN2 eq eq false false
XLEN2+1 false false false false
PLEN false false false false

Mode 1

BoolRes 1 XLEN1 XLEN1+1 PLEN
1 ge ge/le false false
XLEN2 ge ge/le false false
XLEN2+1 false false false false
PLEN false false false false

Mode 0 と 1 の修正は同じです。

Mode 2

BoolRes 1 XLEN1 XLEN1+1 PLEN
1 eq eq false false
XLEN2 eq eq false false
XLEN2+1 false false true true
PLEN false false true true

Mode 3

BoolRes 1 XLEN1 XLEN1+1 PLEN
1 eq eq true true
XLEN2 eq eq true true
XLEN2+1 false false true true
PLEN false false true true

結果生成(その1:IntRes1)

修正された比較判定表(BoolRes)から xmm2/m128 の各文字に対する真偽値表(IntRes1)が生成されます。値は PLEN ビット列で表現されます。IntRes1 があるということは、加工された IntRes2 も生成されます。

以下の表では XLEN2 文字目より後の結果が固定になっていますが、これは XLEN1 や XLEN2 が PLEN より小さい場合で、比較判定表が修正された場合に限ります。

Mode 0:文字セット

xmm2/m128 の文字が xmm1 に含まれていたら true になります。

BoolRes 1 XLEN1 XLEN1+1 PLEN 結果(IntRes1)
1 eq eq false false この行の論理和
XLEN2 eq eq false false この行の論理和
XLEN2+1 false false false false false
PLEN false false false false false
C
IntRes1 = 0;
for (x2 = 0; x2 < PLEN; x2++)
  for (x1 = 0; x1 < PLEN; x1++)
    IntRes1 |= BoolRes[x2][x1] << x2;

上記でのインデックスは 0 からです。

Mode 1:範囲

xmm2/m128 の文字は、xmm1 にある最大(PLEN÷2)個の範囲指定のいずれかならば true になります。

BoolRes 1 と 2 3 と 4 XLEN1 XLEN1+1 PLEN 結果(IntRes1)
1 ge and le ge and le ge/le false false この行の論理和
XLEN2 ge and le ge and le ge/le false false この行の論理和
XLEN2+1 false false false false false false
PLEN false false false false false false

各行は「開始と終了の対を論理積」とした後の(PLEN÷2)個の論理和になります。

C
IntRes1 = 0;
for (x2 = 0; x2 < PLEN; x2++)
  for (x1 = 0; x1 < PLEN; x1 += 2)
    IntRes1 |= (BoolRes[x2][x1] & BoolRes[x2][x1+1]) << x2;

上記でのインデックスは 0 からです。

Mode 2:配列一致

文字列を配列で比較し、一致する文字を選択します。よって、表の左上から右下への対角が結果になります。

BoolRes 1 XLEN1 XLEN1+1 PLEN 結果(IntRes1)
1 eq eq false false 1 文字目
XLEN2 eq eq false false XLEN2 文字目
XLEN2+1 false false true true XLEN1 と XLEN2 が同じ場合は true
PLEN false false true true true
C
IntRes1 = 0;
for (x2 = 0; x2 < PLEN; x2++)
  IntRes1 |= BoolRes[x2][x2] << x2;

上記でのインデックスは 0 からです。

Mode 3:部分一致

xmm1 の文字列が xmm2/m128 に含まれている場合の開始位置が true になります。

以下の比較判定表

BoolRes 1 XLEN1 XLEN1+1 PLEN
1 eq eq true true
XLEN2 eq eq true true
XLEN2+1 false false true true
PLEN false false true true

に対して

BoolRes 1 PLEN
1
N true
true
PLEN true

となっているならば N 文字目は true になります。

xmm1="ab", xmm2/m128="abcab" だと、2箇所の 'a' の位置が true です。

C
IntRes1 = ((1 << PLEN) - 1);
for (x2 = 0; x2 < PLEN; x2++)
  for (x1 = 0; x1 < (PLEN - x2); x1++)
    IntRes1 &= (BoolRes[x2+x1][x1] << x2) | ~(1 << x2);

上記でのインデックスは 0 からです。

結果生成(その2:IntRes2)

第 3 オペランド imm8 の第 4 および第 5 ビットの指定により IntRes1 を加工して IntRes2 が生成されます。

Polarity
imm8[5:4]
操作 内容
0 Positive Polarity IntRes2 = IntRes1
1 Negative Polarity IntRes2 = ~IntRes1
2 Masked (+) IntRes2 = IntRes1
3 Masked (-) IntRes2 = ~IntRes1 ^ ((1 << PLEN) - (1 << XLEN2))

結果出力

命令によって出力先と内容が異なります。

形式 レジスタ 内容 命令
y=I 形式 ECX(RCX) 文字位置 PCMPESTRI , PCMPISTRI
y=M 形式 XMM0 ビット列 PCMPESTRM , PCMPISTRM

y=I 形式 PCMP<x>STRI 命令

ECX(RCX) レジスタの内容は、第 3 オペランド imm8 第6 ビットの指定により変化します。

imm8[6] ECX(RCX) レジスタ
0 IntRes2 で "1" となっている最小ビット位置: BSF(IntRes2) 相当
1 IntRes2 で "1" となっている最大ビット位置: BSR(IntRes2) 相当

y=M 形式 PCMP<x>STRM 命令

XMM0 レジスタの内容は、第 3 オペランド imm8 第6 ビットの指定により変化します。

imm8[6] XMM0 レジスタ
0 IntRes2 を 128 ビットにゼロ拡張
1 IntRes2 の各ビットを文字単位に拡張する
(PCMPEQB や PCMPEQW と同じ形式)

フラグ・レジスタ

内容
CF IntRes2 != 0
ZF XLEN2 < PLEN
SF XLEN1 < PLEN
OF IntRes2 & 1
AF 0
PF 0

まとめ

PCMP<x>STR<y> xmm1, xmm2, imm8
PCMP<x>STR<y> xmm1, m128, imm8

命令 op1 op2 op3
PCMPESTRI xmm1 xmm2
m128
imm8
PCMPESTRM xmm1 xmm2
m128
imm8
PCMPISTRI xmm1 xmm2
m128
imm8
PCMPISTRM xmm1 xmm2
m128
imm8

m128 のアドレスは 16 バイト境界でなくてもよいハズ。

命令形式と固定レジスタ

y = I : Index y = M : Mask 入力用レジスタ
x = E : Explicit PCMPESTRI PCMPESTRM EAX(RAX) と EDX(RDX)
x = I : Implicit PCMPISTRI PCMPISTRM なし
出力用レジスタ ECX(RCX) XMM0 -

x=E 形式の有効文字数

レジスタ 対象
EAX(RAX) xmm1
EDX(RDX) xmm2/m128

負の場合は、絶対値として正の値に変換されます。

第 3 オペランド imm8

ビット 7 6 5:4 3:2 1:0
内容 - 出力形式 Polarity Mode データ形式
データ形式
imm8[1:0]
+0(8 ビット) +1(16 ビット)
+0(符号なし) 符号なし 8 ビット 16 文字
uint8_t [16]
符号なし 16 ビット 8 文字
uint16_t [8]
+2(符号あり) 符号あり 8 ビット 16 文字
int8_t [16]
符号あり 16 ビット 8 文字
int16_t [8]
Mode
imm8[3:2]
比較方法 内容
0 一致 xmm1 を文字セットとします
1 範囲 xmm1 を文字の範囲情報とします
範囲は最小と最大の 2 文字で一組とする
2 一致 xmm1 と xmm2/m128 の配列一致
3 一致 xmm1 を部分文字列として検索します
Polarity
imm8[5:4]
操作 内容
0 Positive Polarity IntRes2 = IntRes1
1 Negative Polarity IntRes2 = ~IntRes1
2 Masked (+) IntRes2 = IntRes1
3 Masked (-) IntRes2 = ~IntRes1 ^ ((1 << PLEN) - (1 << XLEN2))

出力形式: y=I 形式 PCMP<x>STRI 命令

imm8[6] ECX(RCX) レジスタ
0 IntRes2 で "1" となっている最小ビット位置: BSF(IntRes2) 相当
1 IntRes2 で "1" となっている最大ビット位置: BSR(IntRes2) 相当

出力形式: y=M 形式 PCMP<x>STRM 命令

imm8[6] XMM0 レジスタ
0 IntRes2 を 128 ビットにゼロ拡張
1 IntRes2 の各ビットを文字単位に拡張する
(PCMPEQB や PCMPEQW と同じ形式)

フラグ・レジスタ

内容
CF IntRes2 != 0
ZF XLEN2 < PLEN
SF XLEN1 < PLEN
OF IntRes2 & 1
AF 0
PF 0

テスト・プログラム

動作環境
  macOS BigSur 11.6
コンパイラ
  Apple clang version 12.0.5 (clang-1205.0.22.11)

C++ ソース・コード
test_pcmpxstry.cpp
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <nmmintrin.h>

/*
 *
 */

static const char *program = NULL;
static bool flag_verbose = false;

/*
 *
 */

union __x128
{
    __m128 v;
    __m128i vi;
    __m128d vd;

    int8_t i8[16];
    int16_t i16[8];
    int32_t i32[4];
    int64_t i64[2];

    uint8_t u8[16];
    uint16_t u16[8];
    uint32_t u32[4];
    uint64_t u64[2];

    void clear()
    {
        i64[0] = 0;
        i64[1] = 0;
    }
};

/*
 *
 */

struct res_pcmpxstrx
{
    __m128 m;
    uint64_t i;
    uint64_t f;

    void clear()
    {
        m = _mm_setzero_si128();
        i = 0;
        f = 0;
    }
};

#define CASE_IMM8_1(name, imm8)  case imm8: return name<imm8>(xmm1, xmm2, regA, regD)
#define CASE_IMM8_2(name, imm8)  CASE_IMM8_1(name, imm8 + 0); CASE_IMM8_1(name, imm8 + 1)
#define CASE_IMM8_4(name, imm8)  CASE_IMM8_2(name, imm8 + 0); CASE_IMM8_2(name, imm8 + 2)
#define CASE_IMM8_8(name, imm8)  CASE_IMM8_4(name, imm8 + 0); CASE_IMM8_4(name, imm8 + 4)
#define CASE_IMM8_16(name, imm8)  CASE_IMM8_8(name, imm8 + 0); CASE_IMM8_8(name, imm8 + 8)
#define CASE_IMM8_32(name, imm8)  CASE_IMM8_16(name, imm8 + 0); CASE_IMM8_16(name, imm8 + 16)
#define CASE_IMM8_64(name, imm8)  CASE_IMM8_32(name, imm8 + 0); CASE_IMM8_32(name, imm8 + 32)
#define CASE_IMM8_128(name, imm8)  CASE_IMM8_64(name, imm8 + 0); CASE_IMM8_64(name, imm8 + 64)
#define CASE_IMM8_256(name)        CASE_IMM8_128(name, 0); CASE_IMM8_128(name, 128)
#define SWITCH_CASE_256(name, imm8)  switch (imm8) { CASE_IMM8_256(name); default: break; }

#define IMPL_PCMPxSTRx(name)                                            \
template <int imm8>                                                     \
inline static res_pcmpxstrx name(__m128i xmm1, __m128i xmm2, int regA, int regD) \
{                                                                       \
    res_pcmpxstrx r;                                                    \
                                                                        \
    __m128i mask;                                                       \
    uint64_t index;                                                     \
    uint64_t flags;                                                     \
                                                                        \
    __asm__ volatile ("xorq %%rcx, %%rcx": "=c" (index));               \
    __asm__ volatile ("pxor %%xmm0, %%xmm0": "=Yz" (mask));             \
                                                                        \
    __asm__ volatile (#name "\t" "%[imm8],%[xmm2],%[xmm1]\n"            \
                      "pushfq\n" "popq" "\t" "%[freg]"                  \
                      : [freg] "=r" (flags), "=c" (index), "=Yz" (mask) \
                      : [xmm1] "x" (xmm1), [xmm2] "x" (xmm2), [imm8] "N" (imm8), \
                        "a" (regA), "d" (regD));                        \
                                                                        \
    r.m = mask;                                                         \
    r.i = index;                                                        \
    r.f = flags & 0x08d5; /* 0000_1000_1101_0101 */                     \
    return r;                                                           \
}                                                                       \
static res_pcmpxstrx name(__m128i xmm1, int regA, __m128i xmm2, int regD, int imm8) \
{                                                                       \
    SWITCH_CASE_256(name, imm8);                                        \
    res_pcmpxstrx r;                                                    \
    r.clear();                                                          \
    return r;                                                           \
}

IMPL_PCMPxSTRx(pcmpestri)
IMPL_PCMPxSTRx(pcmpestrm)
IMPL_PCMPxSTRx(pcmpistri)
IMPL_PCMPxSTRx(pcmpistrm)

typedef res_pcmpxstrx (*pcmpxstrx)(__m128i, int, __m128i, int, int);

static __m128i setbxmm(const char *s)
{
    __x128 u;

    int l = (s == NULL) ?  0 : strlen(s);
    for (int i = 0; i < 16; i++)
        u.i8[i] = (i < l) ? s[i] : 0;
    return u.vi;
}

static __m128i setwxmm(const char *s)
{
    __x128 u;

    int l = (s == NULL) ?  0 : strlen(s);
    for (int i = 0; i < 8; i++)
        u.i16[i] = (i < l) ? s[i] : 0;
    return u.vi;
}

static int strlen_x128b(const __x128 &s)
{
    for (int i = 0; i < 16; i++)
        if (s.i8[i] == 0)
            return i;
    return 16;
}

static int strlen_x128w(const __x128 &s)
{
    for (int i = 0; i < 8; i++)
        if (s.i16[i] == 0)
            return i;
    return 8;
}

typedef bool BoolRes[16][16];

static res_pcmpxstrx PCMPxSTRy(bool implicit /* <x> */,
                               bool mask /* <y> */,
                               __m128i xmm1, __m128i xmm2, uint8_t imm8,
                               int64_t rax, int64_t rdx,
                               BoolRes b_res,
                               uint16_t &i_res1,
                               uint16_t &i_res2,
                               int64_t &rcx)
{
    res_pcmpxstrx res;

    __x128 s1;
    __x128 s2;

    bool f_word = !!(imm8 & 0x01);
    bool f_signed = !!(imm8 & 0x02);
    int mode = (imm8 >> 2) & 3;
    bool f_negative = !!(imm8 & 0x10);
    bool f_masked = !!(imm8 & 0x20);
    bool f_outsel = !!(imm8 & 0x40);
    int ecnt = !f_word ? 16 : 8;
    // bool b_res[16][16];

    res.clear();
    s1.v = xmm1;
    s2.v = xmm2;

    for (int i2 = 0; i2 < 16; i2++)
        for (int i1 = 0; i1 < 16; i1++)
            b_res[i2][i1] = false;

    int32_t c1[16];
    int32_t c2[16];

    for (int i = 0; i < 16; i++)
        c1[i] = c2[i] = 0;

    if (!f_signed)
    {
        if (!f_word)
        {
            for (int i = 0; i < 16; i++)
            {
                c1[i] = s1.u8[i];
                c2[i] = s2.u8[i];
            }
        }
        else
        {
            for (int i = 0; i < 8; i++)
            {
                c1[i] = s1.u16[i];
                c2[i] = s2.u16[i];
            }
        }
    }
    else
    {
        if (!f_word)
        {
            for (int i = 0; i < 16; i++)
            {
                c1[i] = s1.i8[i];
                c2[i] = s2.i8[i];
            }
        }
        else
        {
            for (int i = 0; i < 8; i++)
            {
                c1[i] = s1.i16[i];
                c2[i] = s2.i16[i];
            }
        }
    }

    if (!implicit)
    {
        if (rax < 0) rax = -rax;
        if (rdx < 0) rdx = -rdx;
    }
    else
    {
        rax = !f_word ? strlen_x128b(s1) : strlen_x128w(s1);
        rdx = !f_word ? strlen_x128b(s2) : strlen_x128w(s2);
    }
    if (rax > ecnt) rax = ecnt;
    if (rdx > ecnt) rdx = ecnt;

    int l1 = int(rax);
    int l2 = int(rdx);

    if (mode != 1)
    {
        for (int i2 = 0; i2 < ecnt; i2++)
            for (int i1 = 0; i1 < ecnt; i1++)
                b_res[i2][i1] = (c1[i1] == c2[i2]);
    }
    else
    {
        for (int i2 = 0; i2 < ecnt; i2++)
        {
            for (int i1 = 0; i1 < ecnt; i1 += 2)
            {
                b_res[i2][i1 + 0] = (c1[i1 + 0] <= c2[i2]);
                b_res[i2][i1 + 1] = (c1[i1 + 1] >= c2[i2]);
            }
        }
    }

    bool o1 = false;
    bool o2 = false;
    bool o3 = false;

    switch (mode)
    {
    case 3: o1 = true;
    case 2: o3 = true;
    default: break;
    }
    for (int i1 = l1; i1 < ecnt; i1++)
        for (int i2 = 0; i2 < l2; i2++)
            b_res[i2][i1] = o1;
    for (int i2 = l2; i2 < ecnt; i2++)
    {
        for (int i1 = 0; i1 < l1; i1++)
            b_res[i2][i1] = o2;
        for (int i1 = l1; i1 < ecnt; i1++)
            b_res[i2][i1] = o3;
    }

    bool i_res[16];

    for (int i = 0; i < 16; i++)
        i_res[i] = false;

    switch (mode)
    {
    case 0:
        for (int i2 = 0; i2 < ecnt; i2++)
            for (int i1 = 0; i1 < ecnt; i1++)
                i_res[i2] |= b_res[i2][i1];
        break;

    case 1:
        for (int i2 = 0; i2 < ecnt; i2++)
            for (int i1 = 0; i1 < ecnt; i1 += 2)
                i_res[i2] |= (b_res[i2][i1 + 0] & b_res[i2][i1 + 1]);
        break;

    case 2:
        for (int i = 0; i < ecnt; i++)
            i_res[i] = b_res[i][i];
        break;

    case 3:
        for (int i2 = 0; i2 < ecnt; i2++)
        {
            i_res[i2] = true;
            for (int i1 = 0; i1 < (ecnt - i2); i1++)
                i_res[i2] &= b_res[i2 + i1][i1];
        }
        break;

    default:
        break;
    }

    // uint16_t i_res1;

    i_res1 = 0;
    for (int i = 0; i < ecnt; i++)
        i_res1 |= (i_res[i] << i);

    if (f_negative)
    {
        int ncnt = !f_masked ? ecnt : l2;

        for (int i = 0; i < ncnt; i++)
            i_res[i] = !i_res[i];
    }

    // uint16_t i_res2;

    i_res2 = 0;
    for (int i = 0; i < ecnt; i++)
        i_res2 |= (i_res[i] << i);

    // int64_t rcx;

    rcx = ecnt;
    if (i_res2 != 0)
    {
        int i;

        if (!f_outsel)
        {
            for (i = 0; i < ecnt; i++)
                if (i_res[i])
                    break;
        }
        else
        {
            for (i = ecnt - 1; i >= 0; i--)
                if (i_res[i])
                    break;
        }
        rcx = i;
    }

    __x128 xmm0;

    xmm0.clear();
    if (!f_outsel)
        xmm0.u16[0] = i_res2;
    else if (!f_word)
        for (int i = 0; i < 16; i++)
            xmm0.i8[i] = (i_res[i] ? -1 : 0);
    else
        for (int i = 0; i < 8; i++)
            xmm0.i16[i] = (i_res[i] ? -1 : 0);

    if (!mask)
        res.i = rcx;
    else
        res.m = xmm0.v;

    if (i_res2 != 0)
        res.f |= (1 << 0);
    if (l2 < ecnt)
        res.f |= (1 << 6);
    if (l1 < ecnt)
        res.f |= (1 << 7);
    if ((i_res2 & 1))
        res.f |= (1 << 11);

    return res;
}

static const char *format_flag(uint64_t flags)
{
    static char msg[16];
    char *d = msg;

    *d++ = (flags & (1 <<  0)) ? 'C' : '-';
    *d++ = (flags & (1 <<  6)) ? 'Z' : '-';
    *d++ = (flags & (1 <<  7)) ? 'S' : '-';
    *d++ = (flags & (1 << 11)) ? 'O' : '-';
    *d++ = (flags & (1 <<  4)) ? 'A' : '-';
    *d++ = (flags & (1 <<  2)) ? 'P' : '-';
    *d++ = 0;
    return msg;
}

static void print_xmm_char(__m128i xmm)
{
    __x128 u;

    u.vi = xmm;
    for (int i = 0; i < 16; i++)
        printf("%c", isprint(u.u8[i]) ? u.u8[i] : '.');
}

static void print_xmm_hex(__m128i xmm)
{
    __x128 u;

    u.vi = xmm;

    const char *sep = "";
    for (int i = 0; i < 4; i++)
    {
        printf("%s%08x", sep, u.u32[3 - i]);
        sep = "_";
    }
}

static int test(const char *insn, pcmpxstrx func, const char *s1, const char *s2, int imm8, int len1, int len2)
{
    bool flag_wchar = (imm8 & 1);

    __m128i xmm1 = (flag_wchar ? setwxmm(s1) : setbxmm(s1));
    __m128i xmm2 = (flag_wchar ? setwxmm(s2) : setbxmm(s2));

    bool implicit = (insn[4] == 'i');
    bool mask = (insn[8] == 'm');

    printf("%s:\n", insn);

    printf("  <-- xmm1 = ");
    print_xmm_hex(xmm1);
    if (!implicit)
        printf(", EAX = %d", len1);
    printf("\n");
    printf("  <-- xmm2 = ");
    print_xmm_hex(xmm2);
    if (!implicit)
        printf(", EDX = %d", len2);
    printf("\n");
    printf("  <-- imm8 = 0x%02x\n", imm8);

    res_pcmpxstrx res = func(xmm1, len1, xmm2, len2, imm8);

    printf("  --> FLAG = %s\n", format_flag(res.f));
    if (!mask)
        printf("  --> ECX(RCX) = %2"PRIu64"\n", res.i);
    else
        printf("  --> XMM0 = "), print_xmm_hex(res.m), printf("\n");

    BoolRes b_res;
    uint16_t i_res1, i_res2;
    int64_t rcx;

    res_pcmpxstrx emu = PCMPxSTRy(implicit, mask, xmm1, xmm2, imm8, len1, len2, b_res, i_res1, i_res2, rcx);

    printf("PCMPxSTRy:\n");
    printf("  --> FLAG = %s\n", format_flag(emu.f));
    if (!mask)
        printf("  --> ECX(RCX) = %2"PRIu64"\n", emu.i);
    else
        printf("  --> XMM0 = "), print_xmm_hex(emu.m), printf("\n");

    int ecnt = (imm8 & 1) ? 8 : 16;

    __x128 x1;
    __x128 x2;

    x1.v = xmm1;
    x2.v = xmm2;

    printf("  BoolRes: ");
    for (int i = 0; i < ecnt; i++)
    {
        int c = flag_wchar ? x1.i16[i] : x1.i8[i];

        printf("%c", isprint(c) ? c : ' ');
    }
    printf("\n");
    for (int i2 = 0; i2 < ecnt; i2++)
    {
        char b[32];
        for (int i1 = 0; i1 < ecnt; i1++)
            b[i1] = (b_res[i2][i1] ? '1' : '.');
        b[ecnt] = 0;

        int c = flag_wchar ? x2.i16[i2] : x2.i8[i2];

        if (!isprint(c))
            c = ' ';

        printf("   '%c'[%d]: %s\n", c, bool(i_res1 & (1 << i2)), b);
    }
    printf("  IntRes1: 0x%08x\n", i_res1);
    printf("  IntRes2: 0x%08x\n", i_res2);
    printf("  Index  : %"PRId64"\n", rcx);

    return 0;
}

#define IMPL_TEST(name)                                                 \
    static int test_##name(const char *s1, const char *s2, int imm8, int len1, int len2) \
    {                                                                   \
        return test(#name, name, s1, s2, imm8, len1, len2);             \
    }

IMPL_TEST(pcmpestri);
IMPL_TEST(pcmpestrm);
IMPL_TEST(pcmpistri);
IMPL_TEST(pcmpistrm);

/*
 * main / usage
 */

static int usage()
{
    printf("Usage: %s [options] instruction str1 str2 imm8 [len1 len2]\n"
           "\n"
           "imm8:\n"
           "   binary - 0bBBBBBBBBBB, xBBBBBBBB\n"
           "   hex    - 0xHH, xHH\n"
           "   %%[format]\n"
           "       b - byte\n"
           "       w - word\n"
           "       u - unsigned\n"
           "       s - signed\n"
           "       0 - Mode 0\n"
           "       1 - Mode 1\n"
           "       2 - Mode 2\n"
           "       3 - Mode 3\n"
           "       P - Positive Polarity\n"
           "       p - Negative Polarity\n"
           "       M - Masked (+)\n"
           "       m - Masked (-)\n"
           "       O - least significant index / bit mask\n"
           "       o - most significant index / byte/word mask\n"
           "\n"
           , program);
    return 1;
}

static char *shift_arg(int &argc, char **&argv)
{
    char *arg = NULL;

    if (argc > 0)
    {
        arg = *argv;
        --argc;
        ++argv;
    }
    return arg;
}

#define shift()  (shift_arg(argc, argv))

int main(int argc, char **argv)
{
    int new_argc = 0;
    char **new_argv = argv;

    program = shift();
    while (argc > 0)
    {
        char *arg = shift();

        if ((arg[0] != '-') || (arg[1] == 0))
        {
            new_argv[new_argc++] = arg;
            continue;
        }

        int s_opt;

        while ((s_opt = *(++arg)) != 0)
        {
            switch (s_opt)
            {
            case 'v':
                flag_verbose = true;
                continue;
            default:
                return usage();
            }
        }
    }

    argc = new_argc;
    argv = new_argv;

    if (argc < 4)
        return usage();

    const char *instruction = argv[0];
    const char *str1 = argv[1];
    const char *str2 = argv[2];
    const char *imm8s = argv[3];

    int len1 = strlen(str1);
    int len2 = strlen(str2);

    if (argc > 4)
    {
        if (argc != 6)
            return usage();
        len1 = strtol(argv[4], NULL, 10);
        len2 = strtol(argv[5], NULL, 10);
    }

    int imm8 = 0;
    {
        if (imm8s[0] == '0' &&
            (imm8s[1] == 'b' || imm8s[1] == 'x'))
            ++imm8s;

        switch (imm8s[0])
        {
        case '%':
            while ((++imm8s)[0])
            {
                switch (imm8s[0])
                {
                    // Source Data Format
                case 'b': // byte
                    imm8 &= ~1;
                    continue;
                case 'w': // word
                    imm8 |= 1;
                    continue;
                case 'u': // unsigned
                    imm8 &= ~2;
                    continue;
                case 's': // signed
                    imm8 |= 2;
                    continue;

                    // Mode
                case '0': // Mode 0
                    imm8 &= ~0x0c;
                    continue;
                case '1': // Mode 1
                    imm8 &= ~0x0c;
                    imm8 |= 0x04;
                    continue;
                case '2': // Mode 2
                    imm8 &= ~0x0c;
                    imm8 |= 0x08;
                    continue;
                case '3': // Mode 3
                    imm8 |= 0x0C;
                    continue;

                    // Polarity
                case 'P': // Positive Polarity
                    imm8 &= ~0x30;
                    continue;
                case 'p': // Negative Polarity
                    imm8 &= ~0x30;
                    imm8 |= 0x10;
                    continue;
                case 'M': // Masked (+)
                    imm8 &= ~0x30;
                    imm8 |= 0x20;
                    continue;
                case 'm': // Masked (-)
                    imm8 |= 0x30;
                    continue;

                    // Output Selection
                case 'O':
                    imm8 &= ~0x40;
                    continue;
                case 'o':
                    imm8 |= 0x40;
                    continue;

                default:
                    return usage();
                }
            }
            break;
        case 'b':
            imm8 = strtol(imm8s + 1, NULL, 2);
            break;
        case 'x':
            imm8 = strtol(imm8s + 1, NULL, 16);
            break;
        default:
            imm8 = strtol(imm8s, NULL, 10);
            break;
        }
    }

    if (!strcasecmp(instruction, "PCMPESTRI")) return test_pcmpestri(str1, str2, imm8, len1, len2);
    if (!strcasecmp(instruction, "PCMPESTRM")) return test_pcmpestrm(str1, str2, imm8, len1, len2);
    if (!strcasecmp(instruction, "PCMPISTRI")) return test_pcmpistri(str1, str2, imm8, len1, len2);
    if (!strcasecmp(instruction, "PCMPISTRM")) return test_pcmpistrm(str1, str2, imm8, len1, len2);
    printf("unknown instruction: %s\n", instruction);
    return 2;
}

実験結果

モード 0:文字セット

文字列に含まれている文字 "aeiou" を判定する。

$ ./test_pcmpxstry pcmpistrm "aeiou" "honjitsuhaseiten" %0
pcmpistrm:
  <-- xmm1 = 00000000_00000000_00000075_6f696561
  <-- xmm2 = 6e657469_65736168_75737469_6a6e6f68
  <-- imm8 = 0x00
  --> FLAG = C-S---
  --> XMM0 = 00000000_00000000_00000000_00005a92
PCMPxSTRy:
  --> FLAG = C-S---
  --> XMM0 = 00000000_00000000_00000000_00005a92
  BoolRes: aeiou           
   'h'[0]: ................
   'o'[1]: ...1............
   'n'[0]: ................
   'j'[0]: ................
   'i'[1]: ..1.............
   't'[0]: ................
   's'[0]: ................
   'u'[1]: ....1...........
   'h'[0]: ................
   'a'[1]: 1...............
   's'[0]: ................
   'e'[1]: .1..............
   'i'[1]: ..1.............
   't'[0]: ................
   'e'[1]: .1..............
   'n'[0]: ................
  IntRes1: 0x00005a92
  IntRes2: 0x00005a92
  Index  : 1

モード 1:文字範囲

文字列に含まれている文字を regex の "[0-9A-Za-z]" 相当で判定する。'_' は文字数を明示して意図的に外した。

$ ./test_pcmpxstry pcmpestrm "09AZaz__" "int sample_1234;" %1 6 16
pcmpestrm:
  <-- xmm1 = 00000000_00000000_5f5f7a61_5a413930, EAX = 6
  <-- xmm2 = 3b343332_315f656c_706d6173_20746e69, EDX = 16
  <-- imm8 = 0x04
  --> FLAG = C-SO--
  --> XMM0 = 00000000_00000000_00000000_00007bf7
PCMPxSTRy:
  --> FLAG = C-SO--
  --> XMM0 = 00000000_00000000_00000000_00007bf7
  BoolRes: 09AZaz__        
   'i'[1]: 1.1.11..........
   'n'[1]: 1.1.11..........
   't'[1]: 1.1.11..........
   ' '[0]: .1.1.1..........
   's'[1]: 1.1.11..........
   'a'[1]: 1.1.11..........
   'm'[1]: 1.1.11..........
   'p'[1]: 1.1.11..........
   'l'[1]: 1.1.11..........
   'e'[1]: 1.1.11..........
   '_'[0]: 1.1..1..........
   '1'[1]: 11.1.1..........
   '2'[1]: 11.1.1..........
   '3'[1]: 11.1.1..........
   '4'[1]: 11.1.1..........
   ';'[0]: 1..1.1..........
  IntRes1: 0x00007bf7
  IntRes2: 0x00007bf7
  Index  : 0

モード 2:各文字の比較

文字列の各文字を比較する。モードに Negative Polarity をセットすると CF=0 で文字列の一致を判定できる。

$ ./test_pcmpxstry pcmpistri "instruction" "instruction" %2p
pcmpistri:
  <-- xmm1 = 00000000_006e6f69_74637572_74736e69
  <-- xmm2 = 00000000_006e6f69_74637572_74736e69
  <-- imm8 = 0x18
  --> FLAG = -ZS---
  --> ECX(RCX) = 16
PCMPxSTRy:
  --> FLAG = -ZS---
  --> ECX(RCX) = 16
  BoolRes: instruction     
   'i'[1]: 1.......1.......
   'n'[1]: .1........1.....
   's'[1]: ..1.............
   't'[1]: ...1...1........
   'r'[1]: ....1...........
   'u'[1]: .....1..........
   'c'[1]: ......1.........
   't'[1]: ...1...1........
   'i'[1]: 1.......1.......
   'o'[1]: .........1......
   'n'[1]: .1........1.....
   ' '[1]: ...........11111
   ' '[1]: ...........11111
   ' '[1]: ...........11111
   ' '[1]: ...........11111
   ' '[1]: ...........11111
  IntRes1: 0x0000ffff
  IntRes2: 0x00000000
  Index  : 16

モード 3:文字列の部分一致

文字列から部分文字列の先頭を見つける。以下の例では、"abcdef" が2箇所ある。

$ ./test_pcmpxstry pcmpistrm "abcdef" "01abcdefabcdefgh" %3
pcmpistrm:
  <-- xmm1 = 00000000_00000000_00006665_64636261
  <-- xmm2 = 68676665_64636261_66656463_62613130
  <-- imm8 = 0x0c
  --> FLAG = C-S---
  --> XMM0 = 00000000_00000000_00000000_00000104
PCMPxSTRy:
  --> FLAG = C-S---
  --> XMM0 = 00000000_00000000_00000000_00000104
  BoolRes: abcdef          
   '0'[0]: ......1111111111
   '1'[0]: ......1111111111
   'a'[1]: 1.....1111111111
   'b'[0]: .1....1111111111
   'c'[0]: ..1...1111111111
   'd'[0]: ...1..1111111111
   'e'[0]: ....1.1111111111
   'f'[0]: .....11111111111
   'a'[1]: 1.....1111111111
   'b'[0]: .1....1111111111
   'c'[0]: ..1...1111111111
   'd'[0]: ...1..1111111111
   'e'[0]: ....1.1111111111
   'f'[0]: .....11111111111
   'g'[0]: ......1111111111
   'h'[0]: ......1111111111
  IntRes1: 0x00000104
  IntRes2: 0x00000104
  Index  : 2

PCMPxSTRy を使ったサンプル・プログラム

コンパイラ: Apple clang version 13.0.0 (clang-1300.0.29.3)
実行環境: macOS Big Sur (11.6)

メモリ・アクセスを 16 バイト境界の 16 バイト単位行う strlen 関数(x86:64ビット)

最初の 16 バイト内の調査では PCMPESTRM を、後続は PCMPESTRI を使用する。

strlen_sse4_2.c
#include <inttypes.h>
#include <x86intrin.h>

size_t strlen_sse4_2(const char *s)
{
    size_t len;

    intptr_t is = (intptr_t)s;
    int os = (int)(is & 15);

    const __m128i *p;
    __m128i x0, xm;
    int im;
    int n;

    /* とりあえず prefetch を発行してみる */
    _mm_prefetch(s, _MM_HINT_T0);

    /* 16 バイト境界で処理しないと SEGV の恐れあり */
    p = (const __m128i *)(is - os);

    /* 先頭 16 バイトの調査 */
    x0 = _mm_setzero_si128();
    xm = _mm_cmpestrm(x0, 16, _mm_load_si128(p++), 16, _SIDD_CMP_EQUAL_EACH);
    im = _mm_cvtsi128_si32(xm) & (-1 << os);
    if (im != 0)
        return (_bit_scan_forward(im) - os);

    /* 後続の調査 */
    len = 16 - os;
    do
    {
        /* 16 文字の中で 0 の部分を検索: 16 で次へ */
        n = _mm_cmpestri(x0, 16, _mm_load_si128(p++), 16, _SIDD_CMP_EQUAL_EACH);
        len += n;
    }
    while (n >= 16);
    return len;
}

アセンブラ版

strlen_sse4_2.S
        .section        __TEXT,__text,regular,pure_instructions
        .build_version  macos, 11, 0    sdk_version 11, 3
        .globl          _strlen_sse4_2
_strlen_sse4_2:
        .cfi_startproc
        prefetcht0      (%rdi)
        movq            %rdi, %rcx
        movq            %rdi, %rsi
        andq            $15, %rcx
        pxor            %xmm2, %xmm2
        subq            %rcx, %rdi
        movq            $16, %rax
        movq            $16, %rdx
        pcmpestrm       $8, (%rdi), %xmm2
        movq            $-1, %r8
        shlq            %cl, %r8
        movq            %xmm0, %rcx
        andq            %r8, %rcx
        bsfq            %rcx, %rcx
        jnz             2f
1:      pcmpestri       $8, 16(%rdi), %xmm2
        lea             16(%rdi), %rdi
        jnc             1b
2:      lea             (%rcx, %rdi), %rax
        subq            %rsi, %rax
        retq
        .cfi_endproc
.subsections_via_symbols
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0