tra updated this revision to Diff 358041.
tra marked an inline comment as done.
tra edited the summary of this revision.
tra added a comment.
Addressed review comments.

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D105384/new/

https://reviews.llvm.org/D105384

Files:
  clang/include/clang/Basic/BuiltinsNVPTX.def
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/test/CodeGen/builtins-nvptx-mma.cu
  clang/test/CodeGen/builtins-nvptx-mma.py
  llvm/include/llvm/IR/IntrinsicsNVVM.td
  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
  llvm/test/CodeGen/NVPTX/wmma.py
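For reviewers who want to try the new entry point: below is a minimal device-side sketch of the .and.popc variant added by this patch, modeled on the generated test in clang/test/CodeGen/builtins-nvptx-mma.cu. The function and variable names are illustrative only; the builtin's argument list and its sm_75 / PTX 7.1 requirement come from the BuiltinsNVPTX.def entry in this diff.

  // Sketch: D = popc(A AND B) + C over one m8n8k128 single-bit tile.
  // The last argument is the integer-constant layout selector, passed the
  // same way as in the generated test for the existing xor.popc builtin.
  __device__ void bmma_and_popc_example(int *dst, const int *a, const int *b,
                                        const int *c) {
    __bmma_m8n8k128_mma_and_popc_b1(dst, a, b, c, 1);
  }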
Index: llvm/test/CodeGen/NVPTX/wmma.py
===================================================================
--- llvm/test/CodeGen/NVPTX/wmma.py
+++ llvm/test/CodeGen/NVPTX/wmma.py
@@ -55,14 +55,14 @@
 # RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
 # RUN:           | FileCheck %t-ptx65-sm_75.ll
 
-# Check all variants of instructions supported by PTX70 on SM80+
-# RUN: %python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
-# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# Check all variants of instructions supported by PTX71 on SM80+
+# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
 # RUN:           --check-prefixes=INTRINSICS
-# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
-# RUN:           | FileCheck %t-ptx70-sm_80.ll
+# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
+# RUN:           | FileCheck %t-ptx71-sm_80.ll
 
 from __future__ import print_function
 
@@ -649,9 +649,16 @@
     print(Template(mma_template).substitute(test_params))
 
   return (test_params["intrinsic"], test_params["instruction"])
 
+def get_b1_ops(ptx_type):
+  if ptx_type != "b1":
+    return [""]
+  if ptx_version >= 71:
+    return [".xor.popc", ".and.popc"]
+  return [".xor.popc"]
+
 def gen_wmma_mma_tests():
-  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
-  wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
+  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
+  wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
 
   generated_items=[]
 
@@ -665,29 +672,30 @@
       if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
         continue
 
-      params = {
-          "aligned" : ".aligned" if ptx_version >= 63 else "",
-          "alayout" : alayout,
-          "blayout" : blayout,
-          "intrinsic_signature" : wmma_signature(op),
-          "ptx_signature" : wmma_ptx_signature(op),
-          "satf"  : satf,
-          "rnd" : rnd,
-          "geom"  : op.a.geom,
-          "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
-      }
+      for b1op in get_b1_ops(op.a.mma_type.ptx_type):
+        params = {
+            "aligned" : ".aligned" if ptx_version >= 63 else "",
+            "alayout" : alayout,
+            "blayout" : blayout,
+            "intrinsic_signature" : wmma_signature(op),
+            "ptx_signature" : wmma_ptx_signature(op),
+            "satf"  : satf,
+            "rnd" : rnd,
+            "geom"  : op.a.geom,
+            "b1op" : b1op
+        }
 
-      intrinsic_template = wmma_intrinsic_template
-      instruction_template = wmma_instruction_template
+        intrinsic_template = wmma_intrinsic_template
+        instruction_template = wmma_instruction_template
 
-      generated_items.append(common_mma_test_gen(params, op,
-        intrinsic_template, instruction_template))
+        generated_items.append(common_mma_test_gen(params, op,
+          intrinsic_template, instruction_template))
 
   return generated_items
 
 def gen_mma_tests():
-  mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
-  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"
+  mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
+  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
 
   generated_items=[]
 
@@ -700,22 +708,23 @@
       if not is_mma_variant_supported(op, alayout, blayout, satf):
         continue
 
-      params = {
-          "aligned" : ".aligned" if ptx_version >= 63 else "",
-          "alayout" : alayout,
-          "blayout" : blayout,
-          "intrinsic_signature" : mma_signature(op),
-          "ptx_signature" : mma_ptx_signature(op),
-          "satf"  : satf,
-          "geom"  : op.a.geom,
-          "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
-      }
+      for b1op in get_b1_ops(op.a.mma_type.ptx_type):
+        params = {
+            "aligned" : ".aligned" if ptx_version >= 63 else "",
+            "alayout" : alayout,
+            "blayout" : blayout,
+            "intrinsic_signature" : mma_signature(op),
+            "ptx_signature" : mma_ptx_signature(op),
+            "satf"  : satf,
+            "geom"  : op.a.geom,
+            "b1op" : b1op
+        }
 
-      intrinsic_template = mma_intrinsic_template
-      instruction_template = mma_instruction_template
+        intrinsic_template = mma_intrinsic_template
+        instruction_template = mma_instruction_template
 
-      generated_items.append(common_mma_test_gen(params, op,
-        intrinsic_template, instruction_template))
+        generated_items.append(common_mma_test_gen(params, op,
+          intrinsic_template, instruction_template))
 
   return generated_items
 
@@ -810,32 +819,35 @@
 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
 
-; PTX70MMA-DAG: mma.m8n8k4.row.col.f64
-; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32
-; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32
-; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16
-; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16
-; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16
-; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
-; PTX70MMA-DAG: mma.m8n8k128.row.col.b1
-; PTX70MMA-DAG: mma.m16n8k128.row.col.b1
-; PTX70MMA-DAG: mma.m16n8k256.row.col.b1
+; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
+; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
+; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
+; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
+; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
+; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
+; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
+; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
+; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
+; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
 ;
 """)
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7796,15 +7796,24 @@
 } // layout
 } // defset
 
+// B1 instruction variants need extra constraints.
+class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> {
+  string Op = b1op;
+  WMMA_REGINFO Frag = FragA;
+  list<Predicate> ret = !listconcat(
+    FragA.Predicates,
+    !if(!eq(b1op, ".and.popc"), [hasSM80,hasPTX71],[])
+  );
+}
 // WMMA.MMA
 class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite, string rnd>
-  : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
+               string ALayout, string BLayout, int Satfinite, string rnd, string b1op>
+  : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record,
                [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
-    Requires<FragA.Predicates> {
+    Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
   let OutOperandList = FragD.Outs;
   let InOperandList = !con(Args, (ins MmaCode:$ptx));
   string TypeList = !cond(
@@ -7816,7 +7825,7 @@
        # "." # FragC.ptx_elt_type,
   );
   let AsmString = "wmma.mma"
-                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+                  # b1op
                   # ".sync"
                   # "${ptx:aligned}"
                   # "." # ALayout
@@ -7837,13 +7846,15 @@
   foreach satf = [0, 1] in {
     foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
       foreach op = NVVM_MMA_OPS.all_wmma_ops in {
-        if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
-          def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
-                         WMMA_REGINFO<op[1], "wmma.mma">,
-                         WMMA_REGINFO<op[2], "wmma.mma">,
-                         WMMA_REGINFO<op[3], "wmma.mma">,
-                         layout_a, layout_b, satf, rnd>;
-        }
+        foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+          if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+            def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
+                           WMMA_REGINFO<op[1], "wmma.mma">,
+                           WMMA_REGINFO<op[2], "wmma.mma">,
+                           WMMA_REGINFO<op[3], "wmma.mma">,
+                           layout_a, layout_b, satf, rnd, b1op>;
+          }
+        } // b1op
       } // op
     } // rnd
   } // satf
@@ -7854,12 +7865,12 @@
 
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
           WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-          string ALayout, string BLayout, int Satfinite>
-  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
+          string ALayout, string BLayout, int Satfinite, string b1op>
+  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
               [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
-    Requires<FragA.Predicates> {
+    Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
   let OutOperandList = FragD.Outs;
   let InOperandList = !con(Args, (ins MmaCode:$ptx));
   string TypeList = "." # FragD.ptx_elt_type
@@ -7872,7 +7883,7 @@
                   # "." # BLayout
                   # !if(Satfinite, ".satfinite", "")
                   # TypeList
-                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
+                  # b1op # "\n\t\t"
                   # FragD.regstring # ",\n\t\t"
                   # FragA.regstring # ",\n\t\t"
                   # FragB.regstring # ",\n\t\t"
@@ -7884,13 +7895,15 @@
 foreach layout_a = ["row", "col"] in {
   foreach layout_b = ["row", "col"] in {
     foreach satf = [0, 1] in {
      foreach op = NVVM_MMA_OPS.all_mma_ops in {
-        if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-          def : MMA<WMMA_REGINFO<op[0], "mma">,
-                    WMMA_REGINFO<op[1], "mma">,
-                    WMMA_REGINFO<op[2], "mma">,
-                    WMMA_REGINFO<op[3], "mma">,
-                    layout_a, layout_b, satf>;
-        }
+        foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+            def : MMA<WMMA_REGINFO<op[0], "mma">,
+                      WMMA_REGINFO<op[1], "mma">,
+                      WMMA_REGINFO<op[2], "mma">,
+                      WMMA_REGINFO<op[3], "mma">,
+                      layout_a, layout_b, satf, b1op>;
+          }
+        } // b1op
      } // op
     } // satf
   } // layout_b
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -146,6 +146,7 @@
 def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
 def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
 def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
+def hasPTX71 : Predicate<"Subtarget->getPTXVersion() >= 71">;
 
 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -225,12 +225,13 @@
   string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
 }
 
-class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
+class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, string b1op,
                 WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
   string llvm = "llvm.nvvm.wmma."
                 # A.geom
                 # ".mma"
+                # b1op
                 # "." # ALayout
                 # "." # BLayout
                 # !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
@@ -241,11 +242,12 @@
                          !subst("llvm.", "int_", llvm));
 }
 
-class MMA_NAME<string ALayout, string BLayout, int Satfinite,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
-  string llvm = "llvm.nvvm.mma."
-                # A.geom
+  string llvm = "llvm.nvvm.mma"
+                # b1op
+                # "." # A.geom
                 # "." # ALayout
                 # "." # BLayout
                 # !if(Satfinite, ".satfinite", "")
@@ -430,6 +432,13 @@
   );
 }
 
+class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
+  list<string> ret = !cond(
+    !eq(frags[0].ptx_elt_type, "b1") : [".xor.popc", ".and.popc"],
+    true: [""]
+  );
+}
+
 // Returns true if this combination of layout/satf for MMA ops is supported;
 // false otherwise.
 // E.g.
@@ -4460,25 +4469,27 @@
 }
 
 // WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd, string b1op,
                    WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
  : Intrinsic<D.regs,
              !listconcat(A.regs, B.regs, C.regs),
              [IntrNoMem],
-             WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
+             WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, A, B, C, D>.llvm>;
 
 foreach layout_a = ["row", "col"] in {
  foreach layout_b = ["row", "col"] in {
    foreach satf = [0, 1] in {
      foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
        foreach op = NVVM_MMA_OPS.all_wmma_ops in {
-          if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
-            def WMMA_NAME<layout_a, layout_b, satf, rnd,
-                          op[0], op[1], op[2], op[3]>.record
-              : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
-                              op[0], op[1], op[2], op[3]>;
-          }
+          foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+            if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+              def WMMA_NAME<layout_a, layout_b, satf, rnd, b1op,
+                            op[0], op[1], op[2], op[3]>.record
+                : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd, b1op,
+                                op[0], op[1], op[2], op[3]>;
+            }
+          } // b1op
        } // op
      } // rnd
    } // satf
@@ -4486,21 +4497,23 @@
 } // layout_a
 
 // MMA
-class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
+class NVVM_MMA<string ALayout, string BLayout, int Satfinite, string b1op,
               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
  : Intrinsic<D.regs,
              !listconcat(A.regs, B.regs, C.regs),
              [IntrNoMem],
-             MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
+             MMA_NAME<ALayout, BLayout, Satfinite, b1op, A, B, C, D>.llvm>;
 
 foreach layout_a = ["row", "col"] in {
  foreach layout_b = ["row", "col"] in {
    foreach satf = [0, 1] in {
      foreach op = NVVM_MMA_OPS.all_mma_ops in {
-        if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-          def MMA_NAME<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>.record
-            : NVVM_MMA<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>;
-        }
+        foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+            def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
+              : NVVM_MMA<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>;
+          }
+        } // b1op
      } // op
    } // satf
  } // layout_b
Index: clang/test/CodeGen/builtins-nvptx-mma.py
===================================================================
--- clang/test/CodeGen/builtins-nvptx-mma.py
+++ clang/test/CodeGen/builtins-nvptx-mma.py
@@ -22,24 +22,29 @@
     return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)
 
 class MMAOp:
-  def __init__(self, a, b, c, d):
+  def __init__(self, a, b, c, d, b1op=""):
     self.a = a
     self.b = b
     self.c = c
     self.d = d
+    self.b1op = b1op
 
   def __repr__(self):
     return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d ))
 
-def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
   ops = []
+  if b1ops is None:
+    b1ops = [""]
   for geom, type_a, type_c in product( geoms, types_a, types_c):
     for type_b, type_d in product(types_b if types_b else [type_a],
                                   types_d if types_d else [type_c]):
-      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
-                       MMAFrag(geom, "b", type_b),
-                       MMAFrag(geom, "c", type_c),
-                       MMAFrag(geom, "d", type_d)))
+      ops += [
+          MMAOp(MMAFrag(geom, "a", type_a),
+                MMAFrag(geom, "b", type_b),
+                MMAFrag(geom, "c", type_c),
+                MMAFrag(geom, "d", type_d), b1op)
+          for b1op in b1ops]
   return ops
 
 def make_ldst_ops(geoms, frags, types):
@@ -60,9 +65,12 @@
           make_mma_ops(["m8n8k32"],
                        ["s4", "u4"], [], ["s32"], []) +
           make_mma_ops(["m8n8k128"],
-                       ["b1"], [], ["s32"], []))
+                       ["b1"], [], ["s32"], [],
+                       [".xor.popc", ".and.popc"]))
 
 def get_ldst_ops():
+  # NOTE: fragemts are from the point of view of PTX.
+  # fragment `d` is only for store ops, others for both loads and stores.
   return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
           make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
@@ -71,8 +79,9 @@
           make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
           make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
           make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
-          make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
-          make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
+          # TF32 m16n16k8 is odd.
+          # It uses __mma_tf32_m16n16k8_ld_c but __mma_m16n16k8_st_c_f32.
+          make_ldst_ops(["m16n16k8"], ["a", "b", "c", "d"], ["tf32", "f32"]))
 
 def is_geom_supported(geom):
   # geometries for FP and ints.
@@ -117,6 +126,12 @@
   if not (is_type_supported(frag.ptx_type)
          and is_geom_supported(frag.geom)):
     return False
+  if frag.geom == "m16n16k8":
+    if frag.frag in ["a", "b"]:
+      return frag.ptx_type == "tf32"
+    if frag.frag == "c":
+      return frag.ptx_type == "tf32";
+    return frag.ptx_type == "f32";
   if frag.ptx_type in ["s4", "u4", "b1"]:
     # sub-integer types require sm_75 and ptx63, row/col layout for a/b.
     return ((frag.frag == "a" and layout == "row")
@@ -180,15 +195,19 @@
   else:
     suffix = op.a.ptx_type
 
-  name = "%s_%s_mma%s_%s" % (prefix, op.a.geom,
-                             "_xor_popc" if op.a.ptx_type == "b1" else "",
-                             suffix)
+  name = "{prefix}_{geom}_mma{b1op}_{suffix}".format(
+      prefix = prefix,
+      geom = op.a.geom,
+      b1op = op.b1op.replace(".","_"),
+      suffix = suffix)
   return name
 
-def get_required_sm(frag):
+def get_required_sm(frag, b1op=""):
   if frag.ptx_type in ["f64", "bf16", "tf32"]:
     return 80
   if frag.ptx_type in ["u4", "s4", "b1"]:
+    if b1op == "_and_popc":
+      return 80
     return 75
   if frag.ptx_type in ["s8", "u8"]:
     return 72
@@ -204,7 +223,9 @@
     return 70
   assert(False)
 
-def get_required_ptx(frag):
+def get_required_ptx(frag, b1op=""):
+  if frag.ptx_type == "b1" and b1op == ".and.popc":
+    return 71
   if frag.ptx_type in ["f64", "bf16", "tf32"]:
     return 70
   if frag.ptx_type in ["f16", "f32"]:
@@ -215,11 +236,13 @@
     return 61
   return 63
 
-def get_src_dst_prefix(ptx_type):
-  if ptx_type == "f32":
+def get_src_dst_prefix(frag):
+  if frag.ptx_type == "f32":
     return "f"
-  if ptx_type == "f64":
+  if frag.ptx_type == "f64":
     return "d"
+  if frag.ptx_type == "tf32" and frag.frag in ["c", "d"]:
+    return "f"
   return ""
 
 def gen_wmma_ldst_tests():
@@ -235,9 +258,17 @@
       if not is_ldst_variant_supported(frag, layout):
         continue
 
-      src_dst_prefix = get_src_dst_prefix(frag.ptx_type)
+      src_dst_prefix = get_src_dst_prefix(frag)
+
       min_sm = get_required_sm(frag)
       min_ptx = get_required_ptx(frag)
+      # TF32 uses f32 for accumulator loads.
+      if frag.geom == "m16n16k8" and frag.frag =="c":
+        assert frag.ptx_type == "tf32"
+        itype = "f32"
+      else:
+        itype = frag.ptx_type
+
       params = {
          "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
          "builtin" : get_ldst_builtin_name(frag),
@@ -250,7 +281,7 @@
              "frag" : frag.frag,
              "geom" : frag.geom,
              "ilayout" : layout,
-             "itype" : frag.ptx_type,
+             "itype" : itype,
              "op" : "store" if frag.frag == "d" else "load",
          })
      }
@@ -283,7 +314,7 @@
   // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
   ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
 """.rstrip()
-  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"
 
   for op, alayout, blayout, satf in sorted(product(
       get_mma_ops(),
      ["row","col"],
@@ -294,15 +325,15 @@
     if not is_mma_variant_supported(op, alayout, blayout, satf):
       continue
 
-    asrc_prefix = get_src_dst_prefix(op.a.ptx_type)
-    csrc_prefix = get_src_dst_prefix(op.c.ptx_type)
-    ddst_prefix = get_src_dst_prefix(op.d.ptx_type)
-    min_sm = get_required_sm(op.a)
-    min_ptx = get_required_ptx(op.a)
+    asrc_prefix = get_src_dst_prefix(op.a)
+    csrc_prefix = get_src_dst_prefix(op.c)
+    ddst_prefix = get_src_dst_prefix(op.d)
     if op.a.ptx_type == "b1": # .b1 MMA has no satf argument.
      isatf_arg = ""
    else:
      isatf_arg = ", 1" if satf else ", 0"
+    min_sm = get_required_sm(op.a, op.b1op)
+    min_ptx = get_required_ptx(op.a, op.b1op)
    params = {
        "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
        "builtin" : get_mma_builtin_name(op),
@@ -319,6 +350,7 @@
            "blayout" : blayout,
            "intrinsic_signature" : mma_signature(op),
            "satf" : satf,
+            "b1op" : op.b1op
        })
    }
    results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)
Index: clang/test/CodeGen/builtins-nvptx-mma.cu
===================================================================
--- clang/test/CodeGen/builtins-nvptx-mma.cu
+++ clang/test/CodeGen/builtins-nvptx-mma.cu
@@ -3,20 +3,21 @@
 // *** DO NOT EDIT ***
 //
 // This test has been automatically generated by
-// builtins-nvtx-mma.py --ptx=70 --gpu-arch=80
+// builtins-nvtx-mma.py --ptx=71 --gpu-arch=80
 //
-// Make sure we can handle all builtins available on sm_80 with PTX70
+// Make sure we can handle all builtins available on sm_80 with PTX71
 // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_80 \
-// RUN:            -fcuda-is-device -target-feature +ptx70 \
-// RUN:            -DPTX=70 -DSM=80 \
+// RUN:            -fcuda-is-device -target-feature +ptx71 \
+// RUN:            -DPTX=71 -DSM=80 \
 // RUN:            -S -emit-llvm -o - -x cuda %s \
-// RUN:   | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75 %s
+// RUN:   | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75,CHECK_PTX71_SM75 %s
 // Verify that all builtins have correct constraints.
 // RUN: %clang_cc1 -triple nvptx-unknown-unknown \
 // RUN:            -target-cpu sm_60 -target-feature +ptx42 \
-// RUN:            -DPTX=70 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
+// RUN:            -DPTX=71 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
 // RUN:            -verify %s
+
 #if !defined(CUDA_VERSION)
 #define __device__ __attribute__((device))
 #define __global__ __attribute__((global))
@@ -31,6 +32,7 @@
                                  float *fsrc, float *fdst,
                                  double *dsrc, double *ddst,
                                  int ldm) {
+
 #if (PTX >= 60) && (SM >= 70)
 
   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16
@@ -735,7 +737,7 @@
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.store.d.row.stride.s32
   // expected-error-re@+1 {{'__imma_m8n8k32_st_c_i32' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __imma_m8n8k32_st_c_i32(dst, src, ldm, 0);
-  // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.row.col.b1
+  // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1
   // expected-error-re@+1 {{'__bmma_m8n8k128_mma_xor_popc_b1' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __bmma_m8n8k128_mma_xor_popc_b1(dst, src, src, src, 1);
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.s4
@@ -750,7 +752,7 @@
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.u4.satfinite
   // expected-error-re@+1 {{'__imma_m8n8k32_mma_u4' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __imma_m8n8k32_mma_u4(dst, src, src, src, 1, 1);
-#endif // (PTX >= 63) && (SM >= 75) 
+#endif // (PTX >= 63) && (SM >= 75)
 
 #if (PTX >= 70) && (SM >= 80)
 
@@ -898,5 +900,12 @@
   // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64
   // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
   __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 0, 0);
-#endif // (PTX >= 70) && (SM >= 80) 
+#endif // (PTX >= 70) && (SM >= 80)
+
+#if (PTX >= 71) && (SM >= 75)
+
+  // CHECK_PTX71_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1
+  // expected-error-re@+1 {{'__bmma_m8n8k128_mma_and_popc_b1' needs target feature (sm_75{{.*}},(ptx71{{.*}}}}
+  __bmma_m8n8k128_mma_and_popc_b1(dst, src, src, src, 1);
+#endif // (PTX >= 71) && (SM >= 75)
 }
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -16556,9 +16556,18 @@
       0, \
       0 // b1 MMA does not support .satfinite.
-#define MMA_VARIANTS_B1(geom, type) \
+#define MMA_VARIANTS_B1_XOR(geom, type) \
       0, \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+      Intrinsic::nvvm_wmma_##geom##_mma_xor_popc_row_col_##type, \
+      0, \
+      0, \
+      0, \
+      0, \
+      0, \
+      0
+#define MMA_VARIANTS_B1_AND(geom, type) \
+      0, \
+      Intrinsic::nvvm_wmma_##geom##_mma_and_popc_row_col_##type, \
       0, \
       0, \
       0, \
@@ -16615,7 +16624,9 @@
   case NVPTX::BI__imma_m8n8k32_mma_u4:
     return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, u4)}}};
   case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
-    return {1, 1, 2, 2, {{MMA_VARIANTS_B1(m8n8k128, b1)}}};
+    return {1, 1, 2, 2, {{MMA_VARIANTS_B1_XOR(m8n8k128, b1)}}};
+  case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
+    return {1, 1, 2, 2, {{MMA_VARIANTS_B1_AND(m8n8k128, b1)}}};
 
   // Double MMA
   case NVPTX::BI__dmma_m8n8k4_mma_f64:
@@ -16636,7 +16647,8 @@
 #undef MMA_VARIANTS
 #undef MMA_SATF_VARIANTS
 #undef MMA_VARIANTS_I4
-#undef MMA_VARIANTS_B1
+#undef MMA_VARIANTS_B1_AND
+#undef MMA_VARIANTS_B1_XOR
 
 }
 } // namespace
@@ -17045,6 +17057,7 @@
   case NVPTX::BI__imma_m8n8k32_mma_s4:
   case NVPTX::BI__imma_m8n8k32_mma_u4:
   case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+  case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
   case NVPTX::BI__dmma_m8n8k4_mma_f64:
   case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
  case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
@@ -17062,7 +17075,8 @@
     if (Layout < 0 || Layout > 3)
       return nullptr;
     llvm::APSInt SatfArg;
-    if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1)
+    if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1 ||
+        BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1)
       SatfArg = 0; // .b1 does not have satf argument.
     else if (Optional<llvm::APSInt> OptSatfArg =
                  E->getArg(5)->getIntegerConstantExpr(getContext()))
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -724,6 +724,7 @@
 TARGET_BUILTIN(__bmma_m8n8k128_ld_a_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_ld_b_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_ld_c, "vi*iC*UiIi", "", AND(SM_75,PTX63))
+TARGET_BUILTIN(__bmma_m8n8k128_mma_and_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX71))
 TARGET_BUILTIN(__bmma_m8n8k128_mma_xor_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_st_c_i32, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__imma_m16n16k16_ld_a_s8, "vi*iC*UiIi", "", AND(SM_72,PTX63))