This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 40fab65be3 [Unity] Update LM Sample builtins (#14793)
40fab65be3 is described below
commit 40fab65be3ec1ce126e94a1c3f8be778a5764feb
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat May 6 19:47:26 2023 -0400
[Unity] Update LM Sample builtins (#14793)
This PR updates the LM sample routine builtins.
Make the sample from logits aware of temperature.
Add sample from prob that samples from a probability
distribution after softmax scaling.
---
src/runtime/relax_vm/lm_support.cc | 92 +++++++++++++++++++++++++++++++-------
1 file changed, 77 insertions(+), 15 deletions(-)
diff --git a/src/runtime/relax_vm/lm_support.cc
b/src/runtime/relax_vm/lm_support.cc
index 8b867cc602..8f7e8ebdf9 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -214,34 +214,41 @@ int SampleTopPFromLogits(NDArray logits, double
temperature, double top_p, doubl
for (size_t i = 0; i < data.size(); ++i) {
data[i] = std::make_pair(plogits[i], static_cast<int>(i));
}
- // sort by logits from smallest to largest
- std::sort(data.begin(), data.end());
- float max_value = data.back().first;
+
+ auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float,
int>& rhs) {
+ return lhs.first > rhs.first;
+ };
+ // sort by logits from largest to smallest
+ std::sort(data.begin(), data.end(), fcmp);
+
// argmax
if (temperature < 1e-6f) {
- return data.back().second;
+ return data[0].second;
}
- // compute expf
- float sum = 0.0f;
- for (size_t i = 0; i < data.size(); ++i) {
- data[i].first = expf(data[i].first - max_value);
- sum += data[i].first;
+
+ // compute expf scaled by temp
+ float sum = 0.0f, logit_scale = 1.0f / temperature;
+ float max_value = data[0].first;
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ it->first = expf((it->first - max_value) * logit_scale);
+ sum += it->first;
}
+
// do a cumsum in order of data
float cum_sum_prob = 0.0f;
float top_p_sum = 0.0f;
- for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
- float prob = rit->first / sum;
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ float prob = it->first / sum;
if (cum_sum_prob < top_p) {
top_p_sum += prob;
}
cum_sum_prob += prob;
- rit->first = cum_sum_prob;
+ it->first = cum_sum_prob;
}
// pick a number based on random in (0, 1)
- for (auto rit = data.rbegin(); rit != data.rend(); ++rit) {
- if (uniform_sample < rit->first / top_p_sum) {
- return rit->second;
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ if (uniform_sample < it->first / top_p_sum) {
+ return it->second;
}
}
ICHECK_LE(uniform_sample, data[0].first);
@@ -250,6 +257,61 @@ int SampleTopPFromLogits(NDArray logits, double
temperature, double top_p, doubl
TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits);
+int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) {  // Picks an index along the last axis of prob: entries are sorted descending and the first whose renormalized cumulative sum exceeds uniform_sample is returned.
+ ICHECK(prob.IsContiguous());
+ ICHECK(prob.DataType() == DataType::Float(32));  // only float32 input is accepted
+ // The values are read through a host float pointer below, so data living on another device is copied to CPU first.
+ if (prob->device.device_type != kDLCPU) {
+ prob = prob.CopyTo(DLDevice{kDLCPU, 0});
+ }
+
+ ICHECK(prob->device.device_type == kDLCPU);  // sanity check after the optional copy
+ // Every dimension except the last must have extent 1; only the last axis is sampled.
+ for (int i = 0; i < prob->ndim - 1; ++i) {  // NOTE(review): the message below says logits although the argument checked here is prob
+ ICHECK_EQ(prob->shape[i], 1) << "The leading dimensions of logits must be
1";
+ }
+ // Pair each value with its original index so the index survives the sort.
+ std::vector<std::pair<float, int>> data;
+ data.resize(prob->shape[prob->ndim - 1]);
+ const float* p_prob = static_cast<float*>(prob->data);
+ for (size_t i = 0; i < data.size(); ++i) {
+ data[i] = std::make_pair(p_prob[i], static_cast<int>(i));
+ }
+ // Comparator that orders the pairs by value, descending.
+ auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float,
int>& rhs) {
+ return lhs.first > rhs.first;
+ };
+
+ // sort by value from largest to smallest (the entries here come from prob, not logits)
+ std::sort(data.begin(), data.end(), fcmp);
+ // A near-zero top_p short-circuits to the index of the largest entry.
+ if (top_p < 1e-6f) {
+ return data.begin()->second;
+ }
+
+ // do a cumsum in order of data; top_p_sum adds up the entries taken while the running sum is still below top_p
+ float cum_sum_prob = 0.0f;
+ float top_p_sum = 0.0f;
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ float prob = it->first;  // shadows the NDArray parameter of the same name inside this loop
+ if (cum_sum_prob < top_p) {
+ top_p_sum += prob;
+ }
+ cum_sum_prob += prob;
+ it->first = cum_sum_prob;  // the value slot now holds the running cumulative sum
+ }
+ // pick a number based on random in (0, 1): return the first index whose cumulative sum divided by top_p_sum exceeds uniform_sample
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ if (uniform_sample < it->first / top_p_sum) {
+ return it->second;
+ }
+ }
+ ICHECK_LE(uniform_sample, data[0].first);  // fallthrough: no entry matched; NOTE(review): data[0].first is a cumulative sum by now and data is assumed non-empty -- confirm intent
+ return data[0].second;  // falls back to the index of the largest entry
+}
+ // Register SampleTopPFromProb as the typed body of the global function vm.builtin.sample_top_p_from_prob.
+TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);
+
} // namespace relax_vm
} // namespace runtime
} // namespace tvm