refactor vnni

This commit is contained in:
MITSUNARI Shigeo 2020-10-19 15:45:26 +09:00
parent 276d09bae4
commit f85b1100b5
7 changed files with 74 additions and 81 deletions

View file

@ -1,5 +1,5 @@
TARGET=../xbyak/xbyak_mnemonic.h
BIN=sortline gen_code gen_avx512 gen_vnni
BIN=sortline gen_code gen_avx512
CFLAGS=-I../ -O2 -DXBYAK_NO_OP_NAMES -Wall -Wextra -Wno-missing-field-initializers
all: $(TARGET)
sortline: sortline.cpp
@ -8,8 +8,6 @@ gen_code: gen_code.cpp ../xbyak/xbyak.h avx_type.hpp
$(CXX) $(CFLAGS) $< -o $@
gen_avx512: gen_avx512.cpp ../xbyak/xbyak.h avx_type.hpp
$(CXX) $(CFLAGS) $< -o $@
gen_vnni: gen_vnni.cpp ../xbyak/xbyak.h avx_type.hpp
$(CXX) $(CFLAGS) $< -o $@
$(TARGET): $(BIN)
./gen_code | ./sortline > $@
@ -23,11 +21,6 @@ $(TARGET): $(BIN)
./gen_avx512 64 | ./sortline >> $@
echo "#endif" >> $@
echo "#endif" >> $@
echo "#ifdef XBYAK_DISABLE_AVX512" >> $@
./gen_vnni vexOnly | ./sortline >> $@
echo "#else" >> $@
./gen_vnni | ./sortline >> $@
echo "#endif" >> $@
clean:
$(RM) $(BIN) $(TARGET)

View file

@ -1729,6 +1729,24 @@ void put()
printf("void %s(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W%d, 0x%x, %d); }\n", p.name, p.w, p.code, p.mode);
}
}
// vnni
{
const struct Tbl {
uint8_t code;
const char *name;
int type;
} tbl[] = {
{ 0x50, "vpdpbusd", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
{ 0x51, "vpdpbusds", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
{ 0x52, "vpdpwssd", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
{ 0x53, "vpdpwssds", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
};
for (size_t i = 0; i < NUM_OF_ARRAY(tbl); i++) {
const Tbl *p = &tbl[i];
std::string type = type2String(p->type);
printf("void %s(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opVnni(x1, x2, op, %s, 0x%02X, encoding); }\n", p->name, type.c_str(), p->code);
}
}
}
void put32()

View file

@ -1,41 +0,0 @@
#define XBYAK_DONT_READ_LIST
#include <stdio.h>
#include <string.h>
#include "../xbyak/xbyak.h"
#define NUM_OF_ARRAY(x) (sizeof(x) / sizeof(x[0]))
using namespace Xbyak;
#ifdef _MSC_VER
#pragma warning(disable : 4996) // scanf
#define snprintf _snprintf_s
#endif
#include "avx_type.hpp"
void putVNNI(bool vexEncodingOnly)
{
const struct Tbl {
uint8_t code;
const char *name;
int type;
} tbl[] = {
{ 0x50, "vpdpbusd", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
{ 0x51, "vpdpbusds", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
{ 0x52, "vpdpwssd", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
{ 0x53, "vpdpwssds", T_66 | T_0F38 | T_YMM | T_EW0 | T_SAE_Z | T_B32},
};
for (size_t i = 0; i < NUM_OF_ARRAY(tbl); i++) {
const Tbl *p = &tbl[i];
std::string type = type2String(p->type);
printf("void %s(const Xmm& x1, const Xmm& x2, const Operand& op%s) { opAVX_X_X_XM(x1, x2, op, %s%s, 0x%02X, NONE%s); }\n"
, p->name, !vexEncodingOnly ? ", preferred_encoding_t encoding = DEFAULT" : "", type.c_str()
, !vexEncodingOnly ? " | T_PREF_EVEX" : "", p->code, !vexEncodingOnly ? ", encoding" : "");
}
}
int main(int argc, char *[])
{
bool vexEncodingOnly = argc == 2;
putVNNI(vexEncodingOnly);
}

View file

@ -15,9 +15,3 @@ echo #ifdef XBYAK64>> %TARGET%
gen_avx512 64 | %SORT% >> %TARGET%
echo #endif>> %TARGET%
echo #endif>> %TARGET%
cl gen_vnni.cpp %OPT%
echo #ifdef XBYAK_DISABLE_AVX512>> %TARGET%
gen_vnni vexOnly | %SORT% >> %TARGET%
echo #else>> %TARGET%
gen_vnni | %SORT% >> %TARGET%
echo #endif>> %TARGET%

View file

@ -815,4 +815,29 @@ CYBOZU_TEST_AUTO(tileloadd)
CYBOZU_TEST_EXCEPTION(c.notSupported(), std::exception);
CYBOZU_TEST_EXCEPTION(c.notSupported2(), std::exception);
}
CYBOZU_TEST_AUTO(vnni)
{
struct Code : Xbyak::CodeGenerator {
Code()
{
vpdpbusd(xm0, xm1, xm2); // EVEX
vpdpbusd(xm0, xm1, xm2, VexEncoding); // VEX
}
void badVex()
{
vpdpbusd(xm0, xm1, xm31, VexEncoding);
}
} c;
const uint8_t tbl[] = {
0x62, 0xF2, 0x75, 0x08, 0x50, 0xC2,
0xC4, 0xE2, 0x71, 0x50, 0xC2,
};
const size_t n = sizeof(tbl) / sizeof(tbl[0]);
CYBOZU_TEST_EQUAL(c.getSize(), n);
CYBOZU_TEST_EQUAL_ARRAY(c.getCode(), tbl, n);
CYBOZU_TEST_EXCEPTION(c.badVex(), std::exception);
}
#endif

View file

@ -1543,7 +1543,11 @@ inline const uint8_t* Label::getAddress() const
return mgr->getCode() + offset;
}
typedef enum preferred_encoding_t_ { VEX, DEFAULT } preferred_encoding_t;
typedef enum {
DefaultEncoding,
VexEncoding,
EvexEncoding
} PreferredEncoding;
class CodeGenerator : public CodeArray {
public:
@ -1654,7 +1658,6 @@ private:
T_M_K = 1 << 28, // mem{k}
T_VSIB = 1 << 29,
T_MEM_EVEX = 1 << 30, // use evex if mem
T_PREF_EVEX = 1 << 31, // generate EVEX if preferred_encoding = DEFAULT for AVX512
T_XXX
};
void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false)
@ -1694,7 +1697,7 @@ private:
}
int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32_t VL = 0, bool Hi16Vidx = false)
{
if (!(type & (T_EVEX | T_MUST_EVEX | T_PREF_EVEX))) XBYAK_THROW_RET(ERR_EVEX_IS_INVALID, 0)
if (!(type & (T_EVEX | T_MUST_EVEX))) XBYAK_THROW_RET(ERR_EVEX_IS_INVALID, 0)
int w = (type & T_EW1) ? 1 : 0;
uint32_t mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0;
uint32_t pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0;
@ -2131,15 +2134,8 @@ private:
{
db(code1); db(code2 | reg.getIdx());
}
void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE, preferred_encoding_t encoding_ = DEFAULT)
void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE)
{
#ifdef XBYAK_DISABLE_AVX512
preferred_encoding_t encoding = VEX;
#else
preferred_encoding_t encoding = encoding_;
#endif
if ((encoding == VEX) && ((type & T_MUST_EVEX) || (r.hasEvex() || p1->hasEvex() || op2.hasEvex()))) XBYAK_THROW(ERR_BAD_COMBINATION);
if (op2.isMEM()) {
const Address& addr = op2.getAddress();
const RegExp& regExp = addr.getRegExp();
@ -2148,7 +2144,7 @@ private:
if (BIT == 64 && addr.is32bit()) db(0x67);
int disp8N = 0;
bool x = index.isExtIdx();
if ((encoding == DEFAULT) && ((type & (T_MUST_EVEX | T_PREF_EVEX | T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx())) {
if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) {
int aaa = addr.getOpmaskIdx();
if (aaa && !(type & T_M_K)) XBYAK_THROW(ERR_INVALID_OPMASK_WITH_MEMORY)
bool b = false;
@ -2164,7 +2160,7 @@ private:
opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0);
} else {
const Reg& base = op2.getReg();
if ((encoding == DEFAULT) && ((type & (T_MUST_EVEX | T_PREF_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex())) {
if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) {
evex(r, base, p1, type, code);
} else {
vex(r, base, p1, type, code);
@ -2185,7 +2181,7 @@ private:
type |= (bit == 64) ? T_W1 : T_W0;
opVex(r, p1, *p2, type, code, imm8);
}
void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE, preferred_encoding_t encoding = DEFAULT)
void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE)
{
const Xmm *x2 = static_cast<const Xmm*>(&op1);
const Operand *op = &op2;
@ -2195,7 +2191,7 @@ private:
}
// (x1, x2, op)
if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) XBYAK_THROW(ERR_BAD_COMBINATION)
opVex(x1, x2, *op, type, code0, imm8, encoding);
opVex(x1, x2, *op, type, code0, imm8);
}
void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE)
{
@ -2307,6 +2303,21 @@ private:
if (addr.getRegExp().getIndex().getKind() != kind) XBYAK_THROW(ERR_BAD_VSIB_ADDRESSING)
opVex(x, 0, addr, type, code);
}
void opVnni(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int code0, PreferredEncoding encoding)
{
if (encoding == DefaultEncoding) {
#ifdef XBYAK_DISABLE_AVX512
encoding = VexEncoding;
#else
encoding = EvexEncoding;
#endif
}
if (encoding == EvexEncoding) {
type |= T_MUST_EVEX;
}
opAVX_X_X_XM(x1, x2, op, type, code0);
}
void opInOut(const Reg& a, const Reg& d, uint8_t code)
{
if (a.getIdx() == Operand::AL && d.getIdx() == Operand::DX && d.getBit() == 16) {

View file

@ -1180,6 +1180,10 @@ void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1
void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); }
void vpcmpistri(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); }
void vpcmpistrm(const Xmm& xm, const Operand& op, uint8_t imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); }
void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opVnni(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x50, encoding); }
void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opVnni(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x51, encoding); }
void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opVnni(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x52, encoding); }
void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op, PreferredEncoding encoding = DefaultEncoding) { opVnni(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x53, encoding); }
void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); }
void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8_t imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) XBYAK_THROW(ERR_BAD_COMBINATION) opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); }
void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); }
@ -2044,14 +2048,3 @@ void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2
void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); }
#endif
#endif
#ifdef XBYAK_DISABLE_AVX512
void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x50, NONE); }
void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x51, NONE); }
void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x52, NONE); }
void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32, 0x53, NONE); }
#else
void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op, preferred_encoding_t encoding = DEFAULT) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32 | T_PREF_EVEX, 0x50, NONE, encoding); }
void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op, preferred_encoding_t encoding = DEFAULT) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32 | T_PREF_EVEX, 0x51, NONE, encoding); }
void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op, preferred_encoding_t encoding = DEFAULT) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32 | T_PREF_EVEX, 0x52, NONE, encoding); }
void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op, preferred_encoding_t encoding = DEFAULT) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_B32 | T_PREF_EVEX, 0x53, NONE, encoding); }
#endif