viiccwen commented on code in PR #918:
URL: https://github.com/apache/mahout/pull/918#discussion_r2725082321
##########
qdp/qdp-kernels/src/amplitude.cu:
##########
@@ -391,6 +455,46 @@ __global__ void l2_norm_batch_kernel(
}
}
+/// Kernel: accumulate L2 norms for a batch (float32).
+/// Grid is organized as (blocks_per_sample * num_samples) blocks.
+__global__ void l2_norm_batch_kernel_f32(
+ const float* __restrict__ input_batch,
+ size_t num_samples,
+ size_t sample_len,
+ size_t blocks_per_sample,
+ float* __restrict__ out_norms
+) {
+ const size_t sample_idx = blockIdx.x / blocks_per_sample;
+ if (sample_idx >= num_samples) return;
+
+ const size_t block_in_sample = blockIdx.x % blocks_per_sample;
+ const size_t base = sample_idx * sample_len;
+
+ const size_t vec_idx = block_in_sample * blockDim.x + threadIdx.x;
+ const size_t stride = blockDim.x * blocks_per_sample;
+
+ float local_sum = 0.0f;
+
+ size_t vec_offset = vec_idx;
+ size_t offset = vec_offset * 2;
+ while (offset + 1 < sample_len) {
+ const float2 v = __ldg(reinterpret_cast<const float2*>(input_batch +
base) + vec_offset);
+ local_sum += v.x * v.x + v.y * v.y;
Review Comment:
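Thanks for reviewing!
Sure, the float2 load may be misaligned. When sample_len is odd,
base = sample_idx * sample_len is odd for every odd sample_idx. In that case
input_batch + base is only 4B-aligned, while a float2 access requires 8B alignment
(assuming input_batch itself is 8B-aligned).
We could handle element 0 (of this sample) once whenever the sample start is
misaligned, then shift base by +1 so the remaining pointer is 8B-aligned.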
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]