From: Junyan He <junyan...@linux.intel.com> Signed-off-by: Junyan He <junyan...@linux.intel.com> Reviewed-by: Yang Rong <rong.r.y...@intel.com> --- backend/src/backend/gen_context.cpp | 277 ++++++++++++++++++++++++++++++++++++ 1 file changed, 277 insertions(+)
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp
index ed6c9f0..fd5503c 100644
--- a/backend/src/backend/gen_context.cpp
+++ b/backend/src/backend/gen_context.cpp
@@ -2345,7 +2345,284 @@ namespace gbe
     p->TYPED_WRITE(header, true, bti);
   }
 
+  static void workgroupOpBetweenThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+                                       uint32_t simd, uint32_t wg_op, GenEncoder *p) { // msgData = min/max(msgData, threadData)
+    p->push();
+    p->curr.predicate = GEN_PREDICATE_NONE;
+    p->curr.noMask = 1;
+    p->curr.execWidth = 1;
+
+    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+      uint32_t cond;
+      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN)
+        cond = GEN_CONDITIONAL_LE;
+      else
+        cond = GEN_CONDITIONAL_GE;
+
+      p->SEL_CMP(cond, msgData, threadData, msgData);
+    }
+    p->pop();
+  }
+
+  static void initValue(GenEncoder *p, GenRegister dataReg, uint32_t wg_op) { // Store the identity value of a min/max op.
+    if (dataReg.type == GEN_TYPE_UD) {
+      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN
+          || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) {
+        p->MOV(dataReg, GenRegister::immud(0xFFFFFFFF)); // identity for unsigned min
+      } else {
+        GBE_ASSERT(wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX
+            || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX);
+        p->MOV(dataReg, GenRegister::immud(0)); // identity for unsigned max
+      }
+    } else if (dataReg.type == GEN_TYPE_F) {
+      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN
+          || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN) {
+        p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0x7F800000)); // +inf: identity for float min
+      } else {
+        GBE_ASSERT(wg_op == ir::WORKGROUP_OP_REDUCE_MAX || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX);
+        p->MOV(GenRegister::retype(dataReg, GEN_TYPE_UD), GenRegister::immud(0xFF800000)); // -inf: identity for float max
+      }
+    } else {
+      GBE_ASSERT(0);
+    }
+  }
+
+  static void workgroupOpInThread(GenRegister msgData, GenRegister theVal, GenRegister threadData,
+                                  GenRegister tmp, uint32_t simd, uint32_t wg_op, GenEncoder *p) { // Reduce theVal over this thread's lanes into threadData.
+    p->push();
+    p->curr.predicate = GEN_PREDICATE_NONE;
+    p->curr.noMask = 1;
+    p->curr.execWidth = 1;
+
+    /* Start the per-thread accumulator at the identity value of the op. */
+    threadData = GenRegister::retype(threadData, theVal.type);
+    initValue(p, threadData, wg_op);
+
+    if (theVal.hstride != GEN_HORIZONTAL_STRIDE_0) {
+      /* Lanes outside the dispatch mask keep the identity value, so they cannot affect the result. */
+      tmp = GenRegister::retype(tmp, theVal.type);
+      p->push();
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->curr.noMask = 1;
+      p->curr.execWidth = simd;
+      initValue(p, tmp, wg_op);
+      p->curr.noMask = 0;
+      p->MOV(tmp, theVal);
+      p->pop();
+    }
+
+    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+      uint32_t cond;
+      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN)
+        cond = GEN_CONDITIONAL_LE;
+      else
+        cond = GEN_CONDITIONAL_GE;
+
+      if (theVal.hstride == GEN_HORIZONTAL_STRIDE_0) { // a uniform (scalar) value.
+        p->SEL_CMP(cond, threadData, threadData, theVal);
+      } else {
+        GBE_ASSERT(tmp.type == theVal.type);
+        GenRegister v = GenRegister::toUniform(tmp, theVal.type);
+        for (uint32_t i = 0; i < simd; i++) { // Walk tmp one lane at a time.
+          p->SEL_CMP(cond, threadData, threadData, v);
+          v.subnr += typeSize(theVal.type);
+          if (v.subnr == 32) { // Crossed into the next 32-byte register.
+            v.subnr = 0;
+            v.nr++;
+          }
+        }
+      }
+    }
+
+    p->pop();
+  }
+
+#define SEND_RESULT_MSG() /* Send the gateway message built in nextThreadID; needs p, nextThreadID, theVal in scope. */ \
+do { \
+  p->push(); { /* then send msg. */ \
+    p->curr.noMask = 1; \
+    p->curr.predicate = GEN_PREDICATE_NONE; \
+    p->curr.execWidth = 1; \
+    GenRegister offLen = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 20), GEN_TYPE_UD); \
+    offLen.vstride = GEN_VERTICAL_STRIDE_0; \
+    offLen.width = GEN_WIDTH_1; \
+    offLen.hstride = GEN_HORIZONTAL_STRIDE_0; \
+    uint32_t szEnc = typeSize(theVal.type) >> 1; \
+    if (szEnc == 4) { \
+      szEnc = 3; \
+    } \
+    p->MOV(offLen, GenRegister::immud((szEnc << 8) | (nextThreadID.nr << 21))); \
+    \
+    GenRegister tidEuid = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 16), GEN_TYPE_UD); \
+    tidEuid.vstride = GEN_VERTICAL_STRIDE_0; \
+    tidEuid.width = GEN_WIDTH_1; \
+    tidEuid.hstride = GEN_HORIZONTAL_STRIDE_0; \
+    p->SHL(tidEuid, tidEuid, GenRegister::immud(16)); \
+    \
+    p->curr.execWidth = 8; \
+    p->FWD_GATEWAY_MSG(nextThreadID, 2); \
+  } p->pop(); \
+} while(0)
+
+
+  /* The basic idea (only REDUCE_MIN / REDUCE_MAX are implemented so far):
+  1. Every thread first reduces the value over its own work items, i.e. over its 16 work items
+     when SIMD == 16.
+  2. The thread with logical ID 0 sends a message carrying its result to thread 1. Every other
+     thread waits on n0.2 for a forwarded message.
+  3. A thread woken up by the message from thread (id - 1) combines the received value with its
+     own result and forwards the combined value to the next thread; the last thread sends it to thread 0.
+  4. Thread 0 receives the final value from the last thread and forwards it around the ring once
+     more, so that every thread can write it to dst. */
   void GenContext::emitWorkGroupOpInstruction(const SelectionInstruction &insn) {
+    const GenRegister dst = ra->genReg(insn.dst(0));
+    const GenRegister tmp = ra->genReg(insn.dst(2));
+    GenRegister flagReg = GenRegister::flag(insn.state.flag, insn.state.subFlag);
+    GenRegister nextThreadID = ra->genReg(insn.src(1));
+    const GenRegister theVal = ra->genReg(insn.src(0));
+    GenRegister threadid = ra->genReg(GenRegister::ud1grf(ir::ocl::threadid));
+    GenRegister threadnum = ra->genReg(GenRegister::ud1grf(ir::ocl::threadn));
+    GenRegister msgData = GenRegister::retype(nextThreadID, dst.type); // The value forwarded between threads.
+    msgData.vstride = GEN_VERTICAL_STRIDE_0;
+    msgData.width = GEN_WIDTH_1;
+    msgData.hstride = GEN_HORIZONTAL_STRIDE_0;
+    GenRegister threadData =
+      GenRegister::retype(GenRegister::offset(nextThreadID, 0, 24), dst.type); // This thread's partial result.
+    threadData.vstride = GEN_VERTICAL_STRIDE_0;
+    threadData.width = GEN_WIDTH_1;
+    threadData.hstride = GEN_HORIZONTAL_STRIDE_0;
+    uint32_t wg_op = insn.extra.workgroupOp;
+    uint32_t simd = p->curr.execWidth;
+    GenRegister flag_save = GenRegister::retype(GenRegister::offset(nextThreadID, 0, 8), GEN_TYPE_UW); // Expands a 1-lane CMP result to a 0/0xffff mask.
+    flag_save.vstride = GEN_VERTICAL_STRIDE_0;
+    flag_save.width = GEN_WIDTH_1;
+    flag_save.hstride = GEN_HORIZONTAL_STRIDE_0;
+    int32_t jip;
+    int32_t oneThreadJip = -1; // JMPI that skips all messaging when the work-group has one thread.
+
+    p->push(); { /* First, reduce the value within this thread. */
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      /* Do some calculation within each thread. */
+      workgroupOpInThread(msgData, theVal, threadData, tmp, simd, wg_op, p);
+    } p->pop();
+
+    /* If this is the only thread of the work-group, no message is needed: broadcast and bail out. */
+    p->push(); {
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->curr.noMask = 1;
+      p->curr.execWidth = 1;
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      p->CMP(GEN_CONDITIONAL_EQ, threadnum, GenRegister::immud(0x1));
+
+      /* Broadcast result; must cover every reduce op broadcast below (min and max). */
+      if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+        p->curr.inversePredicate = 1;
+        p->MOV(flag_save, GenRegister::immuw(0x0));
+        p->curr.inversePredicate = 0;
+        p->MOV(flag_save, GenRegister::immuw(0xffff));
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->MOV(flagReg, flag_save);
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+        p->curr.execWidth = simd;
+        p->MOV(dst, threadData);
+      }
+
+      /* Bail out: jump past the messaging code (target patched at the end). */
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      p->curr.inversePredicate = 0;
+      p->curr.execWidth = 1;
+      oneThreadJip = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+    } p->pop();
+
+    p->push(); {
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->curr.noMask = 1;
+      p->curr.execWidth = 1;
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      p->CMP(GEN_CONDITIONAL_EQ, threadid, GenRegister::immud(0x0));
+
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      p->curr.inversePredicate = 1;
+      p->MOV(flag_save, GenRegister::immuw(0x0));
+      p->curr.inversePredicate = 0;
+      p->MOV(flag_save, GenRegister::immuw(0xffff));
+
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->MOV(flagReg, flag_save);
+    } p->pop();
+
+    p->push(); {
+      p->curr.noMask = 1;
+      p->curr.execWidth = 1;
+
+      /* Thread 0: send its own result, then wait for the last thread. */
+      p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+      p->curr.inversePredicate = 1;
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      jip = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->MOV(msgData, threadData);
+      SEND_RESULT_MSG();
+      p->WAIT(2);
+      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+      /* Other threads: wait for the message, combine it with their own result, and pass it on. */
+      p->curr.predicate = GEN_PREDICATE_NORMAL;
+      p->curr.inversePredicate = 0;
+      jip = p->n_instruction();
+      p->JMPI(GenRegister::immud(0));
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->WAIT(2);
+      workgroupOpBetweenThread(msgData, theVal, threadData, simd, wg_op, p);
+      SEND_RESULT_MSG();
+      p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+      /* Restore the flag. */
+      p->curr.predicate = GEN_PREDICATE_NONE;
+      p->MOV(flagReg, flag_save);
+    } p->pop();
+
+    /* Broadcast the result. */
+    if (wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) {
+      p->push(); {
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+        p->curr.noMask = 1;
+        p->curr.execWidth = 1;
+        p->curr.useFlag(flagReg.flag_nr(), flagReg.flag_subnr());
+        p->curr.inversePredicate = 0;
+
+        /* Not the first thread, wait for msg first. */
+        jip = p->n_instruction();
+        p->JMPI(GenRegister::immud(0));
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->WAIT(2);
+        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+
+        /* Every thread: write the final value to dst and forward it to the next thread. */
+        p->curr.execWidth = simd;
+        p->MOV(dst, msgData);
+
+        p->curr.execWidth = 8;
+        p->FWD_GATEWAY_MSG(nextThreadID, 2);
+
+        p->curr.execWidth = 1;
+        p->curr.inversePredicate = 1;
+        p->curr.predicate = GEN_PREDICATE_NORMAL;
+
+        /* The first thread waits until the last one notifies it. */
+        jip = p->n_instruction();
+        p->JMPI(GenRegister::immud(0));
+        p->curr.predicate = GEN_PREDICATE_NONE;
+        p->WAIT(2);
+        p->patchJMPI(jip, (p->n_instruction() - jip), 0);
+      } p->pop();
+    }
+
+    if (oneThreadJip >= 0)
+      p->patchJMPI(oneThreadJip, (p->n_instruction() - oneThreadJip), 0);
   }
 
   void GenContext::setA0Content(uint16_t new_a0[16], uint16_t max_offset, int sz) {
-- 
2.5.0

_______________________________________________
Beignet mailing list
Beignet@lists.freedesktop.org
http://lists.freedesktop.org/mailman/listinfo/beignet