Signed-off-by: Pierre Moreau <pierre.mor...@free.fr> --- src/gallium/auxiliary/Makefile.sources | 4 +- src/gallium/auxiliary/spirv/spirv_linker.c | 1324 ++++++++++++++++++++++++++++ src/gallium/auxiliary/spirv/spirv_linker.h | 67 ++ 3 files changed, 1394 insertions(+), 1 deletion(-) create mode 100644 src/gallium/auxiliary/spirv/spirv_linker.c create mode 100644 src/gallium/auxiliary/spirv/spirv_linker.h
diff --git a/src/gallium/auxiliary/Makefile.sources b/src/gallium/auxiliary/Makefile.sources index f4817742ff..91aac49dfb 100644 --- a/src/gallium/auxiliary/Makefile.sources +++ b/src/gallium/auxiliary/Makefile.sources @@ -314,7 +314,9 @@ NIR_SOURCES := \ SPIRV_SOURCES := \ spirv/spirv_utils.c \ - spirv/spirv_utils.h + spirv/spirv_utils.h \ + spirv/spirv_linker.c \ + spirv/spirv_linker.h VL_SOURCES := \ vl/vl_bicubic_filter.c \ diff --git a/src/gallium/auxiliary/spirv/spirv_linker.c b/src/gallium/auxiliary/spirv/spirv_linker.c new file mode 100644 index 0000000000..9d060be0cc --- /dev/null +++ b/src/gallium/auxiliary/spirv/spirv_linker.c @@ -0,0 +1,1324 @@ +/************************************************************************** + * + * Copyright 2017 Pierre Moreau + * All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sub license, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice (including the + * next paragraph) shall be included in all copies or substantial portions + * of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS AND/OR ITS SUPPLIERS BE LIABLE FOR + * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ * + **************************************************************************/ + +#include "spirv_linker.h" +#include "spirv_utils.h" + +#include "compiler/spirv/spirv.h" +#include "util/u_debug.h" +#include "util/u_hash_table.h" +#include "util/u_pointer.h" + +#include <stdlib.h> +#include <stdio.h> +#include <string.h> + +#define PTR_TO_UINT(x) ((unsigned)pointer_to_uintptr(x)) +#define UINT_TO_PTR(x) (uintptr_to_pointer((uintptr_t)(x))) + +/** + * Extracts the opcode and the number of words making up this instruction. + * + * @param binary binary to extract the information from + * @param word_id index of the word to extract + * @param word_count if not null, will be set to the number of words making up + * the instruction, otherwise will be left untouched + * @return the opcode + */ +static SpvOp +spirv_get_opcode(const char *binary, size_t word_offset, unsigned *word_count) +{ + const unsigned desc_word = spirv_get_word(binary, word_offset); + if (word_count) + *word_count = desc_word >> SpvWordCountShift; + return (SpvOp) (desc_word & SpvOpCodeMask); +} + +static unsigned +spirv_spvid_hash(void *id) +{ + return PTR_TO_UINT(id); +} + +static int +spirv_spvid_compare(void *id1, void *id2) +{ + return PTR_TO_UINT(id1) != PTR_TO_UINT(id2); +} + +/** + * Adds a specified base ID to the ID found at a specified position in the + * binary. + */ +static void +spirv_bump_id(char *binary, unsigned word_offset, void *base_id) +{ + SpvId old_id = spirv_get_word(binary, word_offset); + spirv_set_word(binary, word_offset, PTR_TO_UINT(base_id) + old_id); +} + +/** + * Replaces an ID with another one, if found in the link table. 
+ */ +static void +spirv_link_ids(char *binary, unsigned word_offset, void *link_table) +{ + SpvId old_id = spirv_get_word(binary, word_offset); + void *new_id_ptr = util_hash_table_get((struct util_hash_table *) link_table, + UINT_TO_PTR(old_id)); + SpvId new_id = PTR_TO_UINT(new_id_ptr); + if (new_id_ptr != NULL) + spirv_set_word(binary, word_offset, new_id); +} + +/** + * Associates the given variable to its width, if found. + */ +static void +spirv_register_variable(char *binary, unsigned type_offset, + unsigned variable_offset, struct util_hash_table *types, + struct util_hash_table *variables) +{ + SpvId type_id = spirv_get_word(binary, type_offset); + SpvId var_id = spirv_get_word(binary, variable_offset); + void *width_ptr = util_hash_table_get(types, UINT_TO_PTR(type_id)); + if (width_ptr != NULL) + util_hash_table_set(variables, UINT_TO_PTR(var_id), width_ptr); +} + +/** + * Applies the given function onto the specified IDs. + */ +static void +spirv_transform_ids(void (*transform_id)(char *, unsigned, void *), + char *binary, unsigned offset, + void *data, int ids_count, ...) +{ + va_list ids; + va_start(ids, ids_count); + for (int i = 0; i < ids_count; ++i) + transform_id(binary, offset + va_arg(ids, SpvId), data); + va_end(ids); +} + +/** + * Applies the given function to all IDs found in the given binary. 
+ */ +static int +spirv_transform_binary(char *binary, unsigned binary_word_count, + void (*transform_id)(char *, unsigned, void *), + void *data) +{ + unsigned i = 5u; // Skip header + unsigned j = 0u, k = 0u; + int ret = 0; + unsigned opcode = 0u, insn_word_count = 0u; + + struct util_hash_table *int_types = util_hash_table_create(&spirv_spvid_hash, + &spirv_spvid_compare); + struct util_hash_table *int_variables = util_hash_table_create(&spirv_spvid_hash, + &spirv_spvid_compare); + + while (i < binary_word_count) { + opcode = spirv_get_opcode(binary, i, &insn_word_count); + + switch (opcode) { + case SpvOpNop: + break; + case SpvOpUndef: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpSourceContinued: + break; + case SpvOpSource: + if (insn_word_count > 3u) + transform_id(binary, i + 3u, data); + break; + case SpvOpSourceExtension: + break; + case SpvOpName: + case SpvOpMemberName: + case SpvOpString: + case SpvOpLine: + transform_id(binary, i + 1u, data); + break; + case SpvOpNoLine: + break; + case SpvOpDecorate: + case SpvOpMemberDecorate: + case SpvOpDecorationGroup: + transform_id(binary, i + 1u, data); + break; + case SpvOpGroupDecorate: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpGroupMemberDecorate: + transform_id(binary, i + 1u, data); + for (j = 2u; j < insn_word_count; j += 2u) + transform_id(binary, i + j, data); + break; + case SpvOpExtension: + break; + case SpvOpExtInstImport: + transform_id(binary, i + 1u, data); + break; + case SpvOpExtInst: + spirv_transform_ids(transform_id, binary, i, data, 3, 1u, 2u, 3u); + for (j = 5u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpMemoryModel: + break; + case SpvOpEntryPoint: + transform_id(binary, i + 2u, data); + j = 3u; + k = 0u; + /* We have to compute first how long the string is */ + while (j < insn_word_count && + binary[(i + j) * sizeof(spirv_word) + k] != 
'\0') { + ++k; + if (k >= sizeof(spirv_word)) { + k = 0u; + ++j; + } + } + if (j < insn_word_count) + ++k; + if (k > 0u) { + k = 0u; + ++j; + } + for (; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpExecutionMode: + transform_id(binary, i + 1u, data); + break; + case SpvOpCapability: + break; + case SpvOpTypeVoid: + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + transform_id(binary, i + 1u, data); + if (opcode == SpvOpTypeInt) { + SpvId new_id = spirv_get_word(binary, i + 1u); + uint32_t width = spirv_get_word(binary, i + 2u); + width = width / 32u + (width % 32u != 0u); + util_hash_table_set(int_types, UINT_TO_PTR(new_id), + UINT_TO_PTR(width)); + } + break; + case SpvOpTypeVector: + case SpvOpTypeMatrix: + case SpvOpTypeImage: + for (j = 1u; j < 3u; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpTypeSampler: + transform_id(binary, i + 1u, data); + break; + case SpvOpTypeSampledImage: + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + case SpvOpTypeStruct: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpTypeOpaque: + transform_id(binary, i + 1u, data); + break; + case SpvOpTypePointer: + spirv_transform_ids(transform_id, binary, i, data, 2, 1u, 3u); + break; + case SpvOpTypeFunction: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpTypeEvent: + case SpvOpTypeDeviceEvent: + case SpvOpTypeReserveId: + case SpvOpTypeQueue: + case SpvOpTypePipe: + case SpvOpTypeForwardPointer: + transform_id(binary, i + 1u, data); + break; + case SpvOpConstantTrue: + case SpvOpConstantFalse: + case SpvOpConstant: + for (j = 1u; j < 3u; ++j) + transform_id(binary, i + j, data); + if (opcode == SpvOpConstant) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpConstantComposite: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; 
+ case SpvOpConstantSampler: + case SpvOpConstantNull: + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstant: + for (j = 1u; j < 3u; ++j) + transform_id(binary, i + j, data); + if (opcode == SpvOpSpecConstant) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpSpecConstantComposite: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpSpecConstantOp: + for (j = 1u; j < 3u; ++j) + transform_id(binary, i + j, data); + for (j = 4u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpVariable: + for (j = 1u; j < 3u; ++j) + transform_id(binary, i + j, data); + if (insn_word_count == 5u) + transform_id(binary, i + 4u, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpImageTexelPointer: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpLoad: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpStore: + case SpvOpCopyMemory: + for (j = 1u; j < 3u; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpCopyMemorySized: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpArrayLength: + case SpvOpGenericPtrMemSemantics: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpInBoundsPtrAccessChain: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpFunction: + 
spirv_transform_ids(transform_id, binary, i, data, + 3, 1u, 2u, 4u); + break; + case SpvOpFunctionParameter: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpFunctionEnd: + break; + case SpvOpFunctionCall: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpSampledImage: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleExplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageFetch: + case SpvOpImageRead: + for (j = 1u; j < 5u; ++j) + transform_id(binary, i + j, data); + for (j = 6u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + if (opcode == SpvOpImageRead) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleDrefExplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageGather: + case SpvOpImageDrefGather: + for (j = 1u; j < 6u; ++j) + transform_id(binary, i + j, data); + for (j = 7u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + if (opcode != SpvOpImageGather) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpImageWrite: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + for (j = 5u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpImage: + case SpvOpImageQueryFormat: + case SpvOpImageQueryOrder: + case SpvOpImageQuerySizeLod: + case SpvOpImageQuerySize: + case SpvOpImageQueryLod: + case SpvOpImageQueryLevels: + case SpvOpImageQuerySamples: + for (j = 1u; j < insn_word_count; 
++j) + transform_id(binary, i + j, data); + if (opcode != SpvOpImage && opcode != SpvOpImageQueryLod) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleExplicitLod: + case SpvOpImageSparseFetch: + case SpvOpImageSparseRead: + for (j = 1u; j < 5u; ++j) + transform_id(binary, i + j, data); + for (j = 6u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + case SpvOpImageSparseGather: + case SpvOpImageSparseDrefGather: + for (j = 1u; j < 6u; ++j) + transform_id(binary, i + j, data); + for (j = 7u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + assert(false); + return -1; + case SpvOpImageSparseTexelsResident: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpConvertFToU: + case SpvOpConvertFToS: + case SpvOpUConvert: + case SpvOpSConvert: + case SpvOpConvertPtrToU: + case SpvOpSatConvertSToU: + case SpvOpSatConvertUToS: + case SpvOpBitcast: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpConvertSToF: + case SpvOpConvertUToF: + case SpvOpFConvert: + case SpvOpQuantizeToF16: + case SpvOpConvertUToPtr: + case SpvOpPtrCastToGeneric: + case SpvOpGenericCastToPtr: + case SpvOpGenericCastToPtrExplicit: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpVectorExtractDynamic: + case SpvOpVectorInsertDynamic: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + if (opcode == SpvOpVectorExtractDynamic) + 
spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpVectorShuffle: + case SpvOpCompositeInsert: + for (j = 1u; j < 5u; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpCompositeConstruct: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpCompositeExtract: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpCopyObject: + case SpvOpTranspose: + case SpvOpSNegate: + case SpvOpFNegate: + case SpvOpIAdd: + case SpvOpISub: + case SpvOpIMul: + case SpvOpUDiv: + case SpvOpSDiv: + case SpvOpUMod: + case SpvOpSRem: + case SpvOpSMod: + case SpvOpShiftRightLogical: + case SpvOpShiftRightArithmetic: + case SpvOpShiftLeftLogical: + case SpvOpBitwiseOr: + case SpvOpBitwiseXor: + case SpvOpBitwiseAnd: + case SpvOpNot: + case SpvOpBitFieldInsert: + case SpvOpBitFieldSExtract: + case SpvOpBitFieldUExtract: + case SpvOpSelect: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + if (opcode != SpvOpTranspose && opcode != SpvOpFNegate) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpFAdd: + case SpvOpFSub: + case SpvOpFMul: + case SpvOpFDiv: + case SpvOpFRem: + case SpvOpFMod: + case SpvOpVectorTimesScalar: + case SpvOpMatrixTimesScalar: + case SpvOpVectorTimesMatrix: + case SpvOpMatrixTimesVector: + case SpvOpMatrixTimesMatrix: + case SpvOpOuterProduct: + case SpvOpDot: + case SpvOpIAddCarry: + case SpvOpISubBorrow: + case SpvOpUMulExtended: + case SpvOpSMulExtended: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpBitReverse: + case SpvOpBitCount: + case SpvOpAny: + case SpvOpAll: + case SpvOpIsNan: + case SpvOpIsInf: + case SpvOpIsFinite: + case SpvOpIsNormal: + case SpvOpSignBitSet: + case SpvOpPhi: + for (j = 1u; j < insn_word_count; 
++j) + transform_id(binary, i + j, data); + if (opcode == SpvOpBitReverse || opcode == SpvOpBitCount || + opcode == SpvOpPhi) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpLessOrGreater: + case SpvOpOrdered: + case SpvOpUnordered: + case SpvOpLogicalEqual: + case SpvOpLogicalNotEqual: + case SpvOpLogicalOr: + case SpvOpLogicalAnd: + case SpvOpLogicalNot: + case SpvOpIEqual: + case SpvOpINotEqual: + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: + case SpvOpULessThan: + case SpvOpSLessThan: + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + case SpvOpFOrdEqual: + case SpvOpFUnordEqual: + case SpvOpFOrdNotEqual: + case SpvOpFUnordNotEqual: + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: + case SpvOpDPdx: + case SpvOpDPdy: + case SpvOpFwidth: + case SpvOpDPdxFine: + case SpvOpDPdyFine: + case SpvOpFwidthFine: + case SpvOpDPdxCoarse: + case SpvOpDPdyCoarse: + case SpvOpFwidthCoarse: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpLoopMerge: + for (j = 1u; j < 3u; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpSelectionMerge: + case SpvOpLabel: + case SpvOpBranch: + transform_id(binary, i + 1u, data); + break; + case SpvOpBranchConditional: + spirv_transform_ids(transform_id, binary, i, data, 3, 1u, 2u, 3u); + break; + case SpvOpSwitch: + for (j = 1u; j < 4u; ++j) + transform_id(binary, i + j, data); + { + SpvId selector_id = spirv_get_word(binary, i + 1u); + void *width_ptr = util_hash_table_get(int_variables, + UINT_TO_PTR(selector_id)); + assert(width_ptr); + unsigned width = PTR_TO_UINT(width_ptr); + for (j = 3u + width; j < insn_word_count; j += 1u + width) + transform_id(binary, 
i + j, data); + } + break; + case SpvOpKill: + case SpvOpReturn: + case SpvOpUnreachable: + break; + case SpvOpReturnValue: + case SpvOpLifetimeStart: + case SpvOpLifetimeStop: + transform_id(binary, i + 1u, data); + break; + case SpvOpAtomicLoad: + case SpvOpAtomicStore: + case SpvOpAtomicExchange: + case SpvOpAtomicCompareExchange: + case SpvOpAtomicCompareExchangeWeak: + case SpvOpAtomicIIncrement: + case SpvOpAtomicIDecrement: + case SpvOpAtomicIAdd: + case SpvOpAtomicISub: + case SpvOpAtomicSMin: + case SpvOpAtomicUMin: + case SpvOpAtomicSMax: + case SpvOpAtomicUMax: + case SpvOpAtomicAnd: + case SpvOpAtomicOr: + case SpvOpAtomicXor: + case SpvOpAtomicFlagTestAndSet: + case SpvOpAtomicFlagClear: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + if (opcode != SpvOpAtomicStore && + opcode != SpvOpAtomicFlagTestAndSet && + opcode != SpvOpAtomicFlagClear) + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpEmitVertex: + case SpvOpEndPrimitive: + break; + case SpvOpEmitStreamVertex: + case SpvOpEndStreamPrimitive: + transform_id(binary, i + 1u, data); + break; + case SpvOpControlBarrier: + case SpvOpMemoryBarrier: + case SpvOpGroupAsyncCopy: + case SpvOpGroupWaitEvents: + case SpvOpGroupAll: + case SpvOpGroupAny: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpGroupBroadcast: + case SpvOpGroupIAdd: + case SpvOpGroupUMin: + case SpvOpGroupSMin: + case SpvOpGroupUMax: + case SpvOpGroupSMax: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpGroupFAdd: + case SpvOpGroupFMin: + case SpvOpGroupFMax: + case SpvOpSubgroupBallotKHR: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpSubgroupFirstInvocationKHR: + case SpvOpEnqueueMarker: + case SpvOpEnqueueKernel: + case 
SpvOpGetKernelNDrangeSubGroupCount: + case SpvOpGetKernelNDrangeMaxSubGroupSize: + case SpvOpGetKernelWorkGroupSize: + case SpvOpGetKernelPreferredWorkGroupSizeMultiple: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpRetainEvent: + case SpvOpReleaseEvent: + case SpvOpCreateUserEvent: + case SpvOpIsValidEvent: + case SpvOpSetUserEventStatus: + case SpvOpCaptureEventProfilingInfo: + case SpvOpGetDefaultQueue: + case SpvOpBuildNDRange: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpReadPipe: + case SpvOpWritePipe: + case SpvOpReservedReadPipe: + case SpvOpReservedWritePipe: + case SpvOpGetNumPipePackets: + case SpvOpGetMaxPipePackets: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + spirv_register_variable(binary, i + 1u, i + 2u, int_types, + int_variables); + break; + case SpvOpReserveReadPipePackets: + case SpvOpReserveWritePipePackets: + case SpvOpCommitReadPipe: + case SpvOpCommitWritePipe: + case SpvOpIsValidReserveId: + case SpvOpGroupReserveReadPipePackets: + case SpvOpGroupReserveWritePipePackets: + case SpvOpGroupCommitReadPipe: + case SpvOpGroupCommitWritePipe: + for (j = 1u; j < insn_word_count; ++j) + transform_id(binary, i + j, data); + break; + case SpvOpMax: //FALLTHROUGH + default: + assert(false); + return -1; + } + + i += insn_word_count; + } + + util_hash_table_destroy(int_types); + util_hash_table_destroy(int_variables); + + return ret; +} + +struct spirv_link_data { + struct util_hash_table *exports; + struct util_hash_table *link_table; + const char *linkage_uid; + SpvId export_id; + char **msg; + unsigned *msg_length; +}; + +static const char * +addressing_model_to_string(SpvAddressingModel model) +{ +#define CVT_ADDR_MODEL(v) case SpvAddressingModel##v: return #v + switch (model) { + CVT_ADDR_MODEL(Logical); + 
CVT_ADDR_MODEL(Physical32); + CVT_ADDR_MODEL(Physical64); + default: + return "Unsupported"; + } +#undef CVT_ADDR_MODEL +} + +static const char * +memory_model_to_string(SpvAddressingModel model) +{ +#define CVT_MEMORY_MODEL(v) case SpvMemoryModel##v: return #v + switch (model) { + CVT_MEMORY_MODEL(Simple); + CVT_MEMORY_MODEL(GLSL450); + CVT_MEMORY_MODEL(OpenCL); + default: + return "Unsupported"; + } +#undef CVT_MEMORY_MODEL +} + +/* Returns PIPE_ERROR when item found */ +static enum pipe_error +find_export(void *export_id, void *linkage_uid, void *data) +{ + struct spirv_link_data *link_data = (struct spirv_link_data *) data; + if (!strcmp(link_data->linkage_uid, (const char *) linkage_uid)) { + link_data->export_id = PTR_TO_UINT(export_id); + return PIPE_ERROR; + } + return PIPE_OK; +} + +static enum pipe_error +generate_link_table(void *import_id, void *linkage_uid, void *data) +{ + struct spirv_link_data *link_data = (struct spirv_link_data *) data; + link_data->export_id = 0u; + link_data->linkage_uid = (const char *) linkage_uid; + util_hash_table_foreach(link_data->exports, &find_export, data); + if (link_data->export_id == 0u) { + *link_data->msg_length = snprintf(NULL, 0, + "SPIR-V linker: Missing symbol \"%s\"\n", + link_data->linkage_uid) + 1; + *link_data->msg = (char *) malloc(*link_data->msg_length); + snprintf(*link_data->msg, *link_data->msg_length, + "SPIR-V linker: Missing symbol \"%s\"\n", + link_data->linkage_uid); + return PIPE_ERROR; + } + util_hash_table_set(link_data->link_table, import_id, + UINT_TO_PTR(link_data->export_id)); + return PIPE_OK; +} + +static bool +spirv_is_capability_op(SpvId opcode) +{ + return opcode == SpvOpCapability; +} + +static bool +spirv_is_extension_op(SpvId opcode) +{ + return opcode == SpvOpExtension; +} + +static bool +spirv_is_ext_inst_import_op(SpvId opcode) +{ + return opcode == SpvOpExtInstImport; +} + +static bool +spirv_is_entry_point_op(SpvId opcode) +{ + return opcode == SpvOpEntryPoint; +} + +static 
bool +spirv_is_execution_mode_op(SpvId opcode) +{ + return opcode == SpvOpExecutionMode; +} + +static bool +spirv_is_debug_a_op(SpvId opcode) +{ + switch (opcode) { + case SpvOpString: + case SpvOpSourceExtension: + case SpvOpSource: + case SpvOpSourceContinued: + return true; + default: + return false; + } +} + +static bool +spirv_is_debug_b_op(SpvId opcode) +{ + return opcode == SpvOpName || opcode == SpvOpMemberName; +} + +static bool +spirv_is_type_constant_op(SpvId opcode) +{ + switch (opcode) { + case SpvOpLine: + case SpvOpTypeVoid: + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypeVector: + case SpvOpTypeMatrix: + case SpvOpTypeImage: + case SpvOpTypeSampler: + case SpvOpTypeSampledImage: + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + case SpvOpTypeStruct: + case SpvOpTypeOpaque: + case SpvOpTypePointer: + case SpvOpTypeFunction: + case SpvOpTypeEvent: + case SpvOpTypeDeviceEvent: + case SpvOpTypeReserveId: + case SpvOpTypeQueue: + case SpvOpTypePipe: + case SpvOpTypeForwardPointer: + case SpvOpConstantTrue: + case SpvOpConstantFalse: + case SpvOpConstant: + case SpvOpConstantComposite: + case SpvOpConstantSampler: + case SpvOpConstantNull: + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstant: + case SpvOpSpecConstantComposite: + case SpvOpSpecConstantOp: + case SpvOpVariable: + case SpvOpUndef: + return true; + default: + return false; + } +} + +static void +spirv_copy_sections(char **binaries, const unsigned *lengths, + unsigned *indices, unsigned num_binaries, + char *merged_binary, unsigned *merged_index, + bool (*check_section)(SpvId)) +{ + unsigned i = 0u, j = 0u; + unsigned length = 0u, prev_j = 0u, index = 0u; + const char *binary = NULL; + unsigned opcode = 0u, word_count = 0u; + + for (i = 0u; i < num_binaries; ++i) { + binary = binaries[i]; + length = lengths[i]; + index = indices[i]; + prev_j = j = index; + while (j < length) { + opcode = spirv_get_opcode(binary, j, 
&word_count); + + prev_j = j; + + if (!check_section(opcode)) + break; + + j += word_count; + } + if (j == length) + prev_j = j; + + memcpy(&merged_binary[*merged_index * sizeof(spirv_word)], + &binary[index * sizeof(spirv_word)], + (prev_j - index) * sizeof(spirv_word)); + *merged_index += prev_j - index; + indices[i] = j; + } +} + +static int +spirv_merge_modules(char **binaries, const unsigned *lengths, + unsigned num_binaries, unsigned max_index, unsigned version, + char *merged_binary, bool create_library, + unsigned *final_length, char **msg, + unsigned *msg_length) +{ + unsigned *indices = (unsigned *) malloc(num_binaries * sizeof(unsigned)); + unsigned i = 0u, j = 0u, merged_index = 0u; + unsigned length = 0u, prev_j = 0u, index = 0u, function_begin = 0u; + int ret = 0; + const char *binary = NULL; + unsigned opcode = 0u, word_count = 0u; + SpvAddressingModel addressing_model, addressing_model_tmp; + const char *model_str = NULL, *model_tmp_str = NULL; + SpvMemoryModel memory_model, memory_model_tmp; + + struct util_hash_table *exports = util_hash_table_create(&spirv_spvid_hash, + &spirv_spvid_compare); + struct util_hash_table *imports = util_hash_table_create(&spirv_spvid_hash, + &spirv_spvid_compare); + struct util_hash_table *link_table = util_hash_table_create(&spirv_spvid_hash, + &spirv_spvid_compare); + + /* Generate new header */ + spirv_set_word(merged_binary, 0u, SpvMagicNumber); + spirv_set_word(merged_binary, 1u, version); + spirv_set_word(merged_binary, 2u, 0u); + spirv_set_word(merged_binary, 3u, max_index); + spirv_set_word(merged_binary, 4u, 0u); + merged_index += 5u; + + for (i = 0u; i < num_binaries; ++i) + indices[i] = 5u; + + /* Copy OpCapabilities */ + spirv_copy_sections(binaries, lengths, indices, num_binaries, merged_binary, + &merged_index, &spirv_is_capability_op); + + /* Copy OpExtensions */ + spirv_copy_sections(binaries, lengths, indices, num_binaries, merged_binary, + &merged_index, &spirv_is_extension_op); + + /* Copy 
OpExtInstImport */ + spirv_copy_sections(binaries, lengths, indices, num_binaries, merged_binary, + &merged_index, &spirv_is_ext_inst_import_op); + + /* Merge OpMemoryModel */ + for (i = 0u; i < num_binaries; ++i) { + binary = binaries[i]; + index = indices[i]; + opcode = spirv_get_opcode(binary, index, &word_count); + + assert(opcode == SpvOpMemoryModel); + if (i == 0u) { + addressing_model = (SpvAddressingModel) spirv_get_word(binary, + index + 1u); + memory_model = (SpvMemoryModel) spirv_get_word(binary, index + 2u); + } else { + addressing_model_tmp = (SpvAddressingModel) spirv_get_word(binary, + index + 1u); + memory_model_tmp = (SpvMemoryModel) spirv_get_word(binary, index + 2u); + if (addressing_model != addressing_model_tmp) { + model_str = addressing_model_to_string(addressing_model); + model_tmp_str = addressing_model_to_string(addressing_model_tmp); + *msg_length = snprintf(NULL, 0, + "SPIR-V linker: Inconsistent addressing models: '%s' for binary 0 and '%s' for binary %d\n", + model_str, model_tmp_str, i) + 1; + *msg = (char *) malloc(*msg_length); + snprintf(*msg, *msg_length, + "SPIR-V linker: Inconsistent addressing models: '%s' for binary 0 and '%s' for binary %d\n", + model_str, model_tmp_str, i); + ret = -1; + goto end; + } + if (memory_model != memory_model_tmp) { + model_str = memory_model_to_string(memory_model); + model_tmp_str = memory_model_to_string(memory_model_tmp); + *msg_length = snprintf(NULL, 0, + "SPIR-V linker: Inconsistent memory models: '%s' for binary 0 and '%s' for binary %d\n", + model_str, model_tmp_str, i) + 1; + *msg = (char *) malloc(*msg_length); + snprintf(*msg, *msg_length, + "SPIR-V linker: Inconsistent memory models: '%s' for binary 0 and '%s' for binary %d\n", + model_str, model_tmp_str, i); + ret = -1; + goto end; + } + } + + indices[i] = index + 3u; + } + spirv_set_word(merged_binary, merged_index, (3u << SpvWordCountShift) + 14); + spirv_set_word(merged_binary, merged_index + 1u, addressing_model); + 
spirv_set_word(merged_binary, merged_index + 2u, memory_model); + merged_index += 3u; + + /* Copy OpEntryPoint */ + spirv_copy_sections(binaries, lengths, indices, num_binaries, merged_binary, + &merged_index, &spirv_is_entry_point_op); + + /* Copy OpExecutionMode */ + spirv_copy_sections(binaries, lengths, indices, num_binaries, merged_binary, + &merged_index, &spirv_is_execution_mode_op); + + /* Copy debug a) */ + spirv_copy_sections(binaries, lengths, indices, num_binaries, merged_binary, + &merged_index, &spirv_is_debug_a_op); + /* Copy debug b) */ + spirv_copy_sections(binaries, lengths, indices, num_binaries, merged_binary, + &merged_index, &spirv_is_debug_b_op); + + /* Copy annotations */ + for (i = 0u; i < num_binaries; ++i) { + binary = binaries[i]; + length = lengths[i]; + index = indices[i]; + j = index; + prev_j = j; + while (j < length) { + opcode = spirv_get_opcode(binary, j, &word_count); + + prev_j = j; + + if (opcode != SpvOpDecorate && opcode != SpvOpMemberDecorate && + opcode != SpvOpGroupDecorate && opcode != SpvOpGroupMemberDecorate && + opcode != SpvOpDecorationGroup) + break; + + j += word_count; + + if (opcode != SpvOpDecorate) + continue; + const SpvDecoration decoration = spirv_get_word(binary, prev_j + 2u); + if (decoration != SpvDecorationLinkageAttributes) + continue; + + const SpvId id = spirv_get_word(binary, prev_j + 1u); + const SpvLinkageType linkage_type = spirv_get_word(binary, prev_j + (word_count - 1u)); + if (linkage_type == SpvLinkageTypeExport) + util_hash_table_set(exports, UINT_TO_PTR(id), + spirv_get_string(binary, prev_j + 3u)); + else + util_hash_table_set(imports, UINT_TO_PTR(id), + spirv_get_string(binary, prev_j + 3u)); + + /* Remove exports from SPIR-V if we are not building a library */ + if (create_library && linkage_type == SpvLinkageTypeExport) + continue; + + memcpy(&merged_binary[merged_index * sizeof(spirv_word)], + &binary[index * sizeof(spirv_word)], + (prev_j - index) * sizeof(spirv_word)); + merged_index 
+= prev_j - index; + index = j; + } + if (j == length) + prev_j = j; + + memcpy(&merged_binary[merged_index * sizeof(spirv_word)], + &binary[index * sizeof(spirv_word)], + (prev_j - index) * sizeof(spirv_word)); + merged_index += prev_j - index; + indices[i] = j; + } + + struct spirv_link_data link_data = { + .exports = exports, + .link_table = link_table, + .linkage_uid = NULL, + .export_id = 0u, + .msg = msg, + .msg_length = msg_length + }; + ret = util_hash_table_foreach(imports, &generate_link_table, &link_data); + if (ret != PIPE_OK) + goto end; + + /* Copy types/constants/global variables/OpUndef */ + for (i = 0u; i < num_binaries; ++i) { + binary = binaries[i]; + length = lengths[i]; + index = indices[i]; + j = index; + prev_j = j; + while (j < length) { + opcode = spirv_get_opcode(binary, j, &word_count); + + prev_j = j; + + SpvStorageClass storage_type = spirv_get_word(binary, j + 3u); + if (!spirv_is_type_constant_op(opcode) || + (opcode == SpvOpVariable && + storage_type == SpvStorageClassFunction)) + break; + + j += word_count; + + /* Remove imported variables */ + void *tmp = UINT_TO_PTR(spirv_get_word(binary, prev_j + 2u)); + if (opcode == SpvOpVariable && + util_hash_table_get(link_table, tmp) != NULL) { + memcpy(&merged_binary[merged_index * sizeof(spirv_word)], + &binary[index * sizeof(spirv_word)], + (prev_j - index) * sizeof(spirv_word)); + merged_index += prev_j - index; + index = j; + } + } + if (j == length) + prev_j = j; + + memcpy(&merged_binary[merged_index * sizeof(spirv_word)], + &binary[index * sizeof(spirv_word)], + (prev_j - index) * sizeof(spirv_word)); + merged_index += prev_j - index; + indices[i] = j; + } + + /* Skip function declarations */ + for (i = 0u; i < num_binaries; ++i) { + binary = binaries[i]; + length = lengths[i]; + index = indices[i]; + j = index; + function_begin = j; + while (j < length) { + opcode = spirv_get_opcode(binary, j, &word_count); + + if (opcode == SpvOpFunction) + function_begin = j; + + /* This was not 
a prototype: revert to the beginning of the function */ + if (opcode != SpvOpFunction && opcode != SpvOpFunctionParameter && + opcode != SpvOpFunctionEnd) { + j = function_begin; + break; + } + + j += word_count; + } + + indices[i] = j; + } + + /* Copy function defintions */ + for (i = 0u; i < num_binaries; ++i) { + binary = binaries[i]; + length = lengths[i]; + index = indices[i]; + memcpy(&merged_binary[merged_index * sizeof(spirv_word)], + &binary[index * sizeof(spirv_word)], + (length - index) * sizeof(spirv_word)); + merged_index += length - index; + } + + /* Link everything together */ + ret = spirv_transform_binary(merged_binary, merged_index, &spirv_link_ids, + link_table); + if (ret) + goto end; + + if (final_length) + *final_length = merged_index; + +end: + free(indices); + util_hash_table_destroy(imports); + util_hash_table_destroy(exports); + util_hash_table_destroy(link_table); + return ret; +} + +const char * +spirv_link_binaries(const char **binaries, const unsigned *binaries_word_count, + unsigned num_binaries, bool create_library, + unsigned *linked_word_count, char **msg, + unsigned *length) +{ + char **tmp_binaries = NULL; + char *linked_binary = NULL; + const char *binary = NULL; + char *tmp_binary = NULL; + unsigned total_word_count = 0u, binary_max_id = 0u, id_upper_bound = 0u; + unsigned word_count = 0u, byte_count = 0u; + int i = 0, j = 0, ret = 0; + unsigned max_version = 0u, version = 0u; + + /* As we need to bump IDs in each module, so as they do not conflict as we + * merge them together, we need to make a copy which we can modify. 
+ */ + tmp_binaries = (char **) malloc(num_binaries * sizeof(char *)); + + for (i = 0; i < (int) num_binaries; ++i) { + word_count = binaries_word_count[i]; + binary = binaries[i]; + byte_count = word_count * sizeof(spirv_word); + + version = spirv_get_word(binary, 1u); + if (version > SpvVersion) { + *length = snprintf(NULL, 0, + "SPIR-V linker: Binary %d uses an unsupported SPIR-V version: %u.%u\n", + i, ((version >> 16u) & 0xff), + ((version >> 8u) & 0xff)) + 1; + *msg = (char *) malloc(*length); + snprintf(*msg, *length, + "SPIR-V linker: Binary %d uses an unsupported SPIR-V version: %u.%u\n", + i, ((version >> 16u) & 0xff), ((version >> 8u) & 0xff)); + goto error; + } + if (version > max_version) + max_version = version; + + tmp_binary = (char *) malloc(byte_count); + if (!tmp_binary) + goto error; + memcpy(tmp_binary, binary, byte_count); + tmp_binaries[i] = tmp_binary; + + binary_max_id = spirv_get_word(binary, 3u) - 1u; + + ret = spirv_transform_binary(tmp_binary, word_count, &spirv_bump_id, + UINT_TO_PTR(id_upper_bound)); + if (ret) + goto error2; + + id_upper_bound += binary_max_id; + total_word_count += word_count; + } + + /* id_upper_bound is currently equal to the highest id being used, so add + * one to get the strict upper bound + */ + ++id_upper_bound; + + linked_binary = (char *) malloc(total_word_count * sizeof(spirv_word)); + ret = spirv_merge_modules(tmp_binaries, binaries_word_count, num_binaries, + id_upper_bound, max_version, linked_binary, + create_library, &total_word_count, msg, length); + if (ret) + goto error3; + + if (linked_word_count) + *linked_word_count = total_word_count; + + for (i = 0; i < (int) num_binaries; ++i) + free(tmp_binaries[i]); + free(tmp_binaries); + + return linked_binary; + +error3: + free(linked_binary); +error2: + if (i < (int) num_binaries) + free(tmp_binaries[i]); +error: + for (j = 0; j < i; ++j) + free(tmp_binaries[j]); + free(tmp_binaries); + return NULL; +} diff --git 
a/src/gallium/auxiliary/spirv/spirv_linker.h b/src/gallium/auxiliary/spirv/spirv_linker.h new file mode 100644 index 0000000000..0090c8071f --- /dev/null +++ b/src/gallium/auxiliary/spirv/spirv_linker.h @@ -0,0 +1,67 @@ +/************************************************************************** + * + * Copyright 2017 Pierre Moreau + * All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sub license, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice (including the + * next paragraph) shall be included in all copies or substantial portions + * of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS AND/OR ITS SUPPLIERS BE LIABLE FOR + * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + **************************************************************************/ + +#ifndef SPIRV_LINKER_H +#define SPIRV_LINKER_H + +#include <stdbool.h> + +#if defined __cplusplus +extern "C" { +#endif + +/** + * Link SPIR-V binaries into a SPIR-V library or a SPIR-V executable. + * + * The given SPIR-V modules are expected to + * - have been validated; + * - use the same endianness as the CPU.
+ * + * @param binaries binaries to be linked together + * @param binaries_word_count number of words making up each binary + * @param num_binaries number of binaries to be linked + * @param create_library whether to create a library (a SPIR-V module with no + * import linkage attributes, only export) or an + * executable (a SPIR-V module with no linkage attributes) + * @param linked_word_count if non-NULL, receives the number of words + * making up the resulting linked SPIR-V binary + * @param msg on some failures, receives an error message allocated + * with malloc(); the caller is responsible for freeing it + * @param length receives the size of the error message msg in bytes, + * including the terminating NUL character + * @return if linking was successful, the linked SPIR-V binary (allocated + * with malloc(); the caller is responsible for freeing it), a null + * pointer otherwise + */ +const char * +spirv_link_binaries(const char **binaries, const unsigned *binaries_word_count, + unsigned num_binaries, bool create_library, + unsigned *linked_word_count, char **msg, + unsigned *length); + +#if defined __cplusplus +} +#endif + +#endif /* SPIRV_LINKER_H */ -- 2.12.2 _______________________________________________ mesa-dev mailing list mesa-dev@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/mesa-dev