// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// These templates are not found via ADL. using hwy::HWY_NAMESPACE::Eq; using hwy::HWY_NAMESPACE::IfThenElse; using hwy::HWY_NAMESPACE::Lt; using hwy::HWY_NAMESPACE::Max;
// Quickselect-style partition of the samples in [begin, end) around the
// property value at position `pos`: repeatedly three-way-partitions on a
// random pivot and recurses (iteratively) into whichever side contains
// `pos`, stopping once `pos` lies inside the equal-to-pivot band.
// On return, samples on each side of `pos` compare consistently against
// the sample at `pos` for property `prop`.
//
// Fixes over the previous text: `returnstatic_cast` and `elseif` were fused
// tokens (compile errors) restored to `return static_cast` / `else if`.
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
                      size_t end, size_t prop) {
  // Three-way comparator on the quantized property value.
  auto cmp = [&](size_t a, size_t b) {
    return static_cast<int32_t>(tree_samples.Property(prop, a)) -
           static_cast<int32_t>(tree_samples.Property(prop, b));
  };
  // Fixed seed: partitioning (and thus the learned tree) is deterministic.
  Rng rng(0);
  while (end > begin + 1) {
    {
      // Move a random pivot to the front of the current range.
      size_t pivot = rng.UniformU(begin, end);
      tree_samples.Swap(begin, pivot);
    }
    // Invariant maintained below: [begin, pivot_begin) < pivot,
    // [pivot_begin, pivot_end) == pivot, rest unclassified or > pivot.
    size_t pivot_begin = begin;
    size_t pivot_end = pivot_begin + 1;
    for (size_t i = begin + 1; i < end; i++) {
      JXL_DASSERT(i >= pivot_end);
      JXL_DASSERT(pivot_end > pivot_begin);
      int32_t cmp_result = cmp(i, pivot_begin);
      if (cmp_result < 0) {  // i < pivot: rotate i in front of the pivot band.
        tree_samples.ThreeShuffle(pivot_begin, pivot_end, i);
        pivot_begin++;
        pivot_end++;
      } else if (cmp_result == 0) {  // i == pivot: extend the pivot band.
        tree_samples.Swap(pivot_end, i);
        pivot_end++;
      }
      // cmp_result > 0: leave i where it is (greater-than side).
    }
    JXL_DASSERT(pivot_begin >= begin);
    JXL_DASSERT(pivot_end > pivot_begin);
    JXL_DASSERT(pivot_end <= end);
    for (size_t i = begin; i < pivot_begin; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) < 0);
    }
    for (size_t i = pivot_end; i < end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) > 0);
    }
    for (size_t i = pivot_begin; i < pivot_end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) == 0);
    }
    // We now have that [begin, pivot_begin) is < pivot, [pivot_begin,
    // pivot_end) is = pivot, and [pivot_end, end) is > pivot.
    // If pos falls in the first or the last interval, we continue in that
    // interval; otherwise, we are done.
    if (pivot_begin > pos) {
      end = pivot_begin;
    } else if (pivot_end < pos) {
      begin = pivot_end;
    } else {
      break;
    }
  }
}
// Compute the maximum token in the range.
// NOTE(review): this span is the interior of a larger split-search routine
// whose signature is not visible here (the surrounding text appears to have
// been garbled during extraction). Names such as `num_predictors`, `begin`,
// `end`, `tree`, `pos`, `mul_info`, `first_used`, `last_used`,
// `counts_above`, `counts_below`, `threshold`, `used_properties`,
// `fast_decode_multiplier` and the `best_split_*` objects come from the
// missing enclosing scope. The `constauto` token below also looks like a
// garbled `const auto` -- confirm against the original file before editing.
size_t max_symbols = 0; for (size_t pred = 0; pred < num_predictors; pred++) { for (size_t i = begin; i < end; i++) {
uint32_t tok = tree_samples.Token(pred, i);
max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
}
}
// Pad the symbol count so each predictor gets a uniform histogram stride.
max_symbols = Padded(max_symbols);
// counts: one row of `max_symbols` buckets per predictor; tot_extra_bits:
// per-predictor sum of raw extra bits weighted by sample multiplicity.
std::vector<int32_t> counts(max_symbols * num_predictors);
std::vector<uint32_t> tot_extra_bits(num_predictors); for (size_t pred = 0; pred < num_predictors; pred++) { for (size_t i = begin; i < end; i++) {
counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
tree_samples.Count(i);
tot_extra_bits[pred] +=
tree_samples.NBits(pred, i) * tree_samples.Count(i);
}
}
// Estimated cost (in bits) of leaving this node unsplit with its current
// predictor; candidate splits are judged against this baseline.
float base_bits;
{
size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
base_bits =
EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
tot_extra_bits[pred];
}
SplitInfo *best = &best_split_nonstatic;
SplitInfo forced_split; // The multiplier ranges cut halfway through the current ranges of static // properties. We do this even if the current node is not a leaf, to // minimize the number of nodes in the resulting tree. for (constauto &mmi : mul_info) {
uint32_t axis;
uint32_t val;
IntersectionType t =
BoxIntersects(static_prop_range, mmi.range, axis, val); if (t == IntersectionType::kNone) continue; if (t == IntersectionType::kInside) {
(*tree)[pos].multiplier = mmi.multiplier; break;
} if (t == IntersectionType::kPartial) {
// Partial overlap: force a split on the static property `axis` so a
// later node can be fully inside the multiplier range. The split
// position is found by counting samples on the <= side.
forced_split.val = tree_samples.QuantizeProperty(axis, val);
forced_split.prop = axis;
forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
best = &forced_split;
best->pos = begin;
JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); for (size_t x = begin; x < end; x++) { if (tree_samples.Property(best->prop, x) <= best->val) {
best->pos++;
}
} break;
}
}
if (best != &forced_split) {
// NOTE(review): the three vectors below are indexed later without any
// visible resize; their sizing presumably happens in text lost during
// extraction -- verify against the original source.
std::vector<int> prop_value_used_count;
std::vector<int> count_increase;
std::vector<size_t> extra_bits_increase; // For each property, compute which of its values are used, and what // tokens correspond to those usages. Then, iterate through the values, // and compute the entropy of each side of the split (of the form `prop > // threshold`). Finally, find the split that minimizes the cost. struct CostInfo { float cost = std::numeric_limits<float>::max(); float extra_cost = 0; float Cost() const { return cost + extra_cost; }
Predictor pred; // will be uninitialized in some cases, but never used.
};
std::vector<CostInfo> costs_l;
std::vector<CostInfo> costs_r;
// TODO(veluca): consider finding multiple splits along a single // property at the same time, possibly with a bottom-up approach. for (size_t i = begin; i < end; i++) {
size_t p = tree_samples.Property(prop, i);
prop_value_used_count[p]++;
last_used = std::max(last_used, p);
first_used = std::min(first_used, p);
}
costs_l.resize(last_used - first_used);
costs_r.resize(last_used - first_used); // For all predictors, compute the right and left costs of each split. for (size_t pred = 0; pred < num_predictors; pred++) { // Compute cost and histogram increments for each property value. for (size_t i = begin; i < end; i++) {
size_t p = tree_samples.Property(prop, i);
size_t cnt = tree_samples.Count(i);
size_t sym = tree_samples.Token(pred, i);
count_increase[p * max_symbols + sym] += cnt;
extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
}
// Sweep: start with all samples on the "above" side, then move one
// property value at a time to the "below" side and re-estimate costs.
memcpy(counts_above.data(), counts.data() + pred * max_symbols,
max_symbols * sizeof counts_above[0]);
memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
size_t extra_bits_below = 0; // Exclude last used: this ensures neither counts_above nor // counts_below is empty. for (size_t i = first_used; i < last_used; i++) { if (!prop_value_used_count[i]) continue;
extra_bits_below += extra_bits_increase[i]; // The increase for this property value has been used, and will not // be used again: clear it. Also below.
extra_bits_increase[i] = 0; for (size_t sym = 0; sym < max_symbols; sym++) {
counts_above[sym] -= count_increase[i * max_symbols + sym];
counts_below[sym] += count_increase[i * max_symbols + sym];
count_increase[i * max_symbols + sym] = 0;
} float rcost = EstimateBits(counts_above.data(), max_symbols) +
tot_extra_bits[pred] - extra_bits_below; float lcost = EstimateBits(counts_below.data(), max_symbols) +
extra_bits_below;
JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); float penalty = 0; // Never discourage moving away from the Weighted predictor. if (tree_samples.PredictorFromIndex(pred) !=
(*tree)[pos].predictor &&
(*tree)[pos].predictor != Predictor::Weighted) {
penalty = change_pred_penalty;
} // If everything else is equal, disfavour Weighted (slower) and // favour Zero (faster if it's the only predictor used in a // group+channel combination) if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
penalty += 1e-8;
} if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
penalty -= 1e-8;
} if (rcost + penalty < costs_r[i - first_used].Cost()) {
costs_r[i - first_used].cost = rcost;
costs_r[i - first_used].extra_cost = penalty;
costs_r[i - first_used].pred =
tree_samples.PredictorFromIndex(pred);
} if (lcost + penalty < costs_l[i - first_used].Cost()) {
costs_l[i - first_used].cost = lcost;
costs_l[i - first_used].extra_cost = penalty;
costs_l[i - first_used].pred =
tree_samples.PredictorFromIndex(pred);
}
}
} // Iterate through the possible splits and find the one with minimum sum // of costs of the two sides.
size_t split = begin; for (size_t i = first_used; i < last_used; i++) { if (!prop_value_used_count[i]) continue;
// NOTE(review): `split`, `adds_wp` and `zero_entropy_side` are computed
// but never consumed in the visible text, and the jump from this loop to
// the `best_split_*` comparisons below is abrupt -- a section of this
// function appears to be missing; compare with the original file.
split += prop_value_used_count[i]; float rcost = costs_r[i - first_used].cost; float lcost = costs_l[i - first_used].cost; // WP was not used + we would use the WP property or predictor bool adds_wp =
(tree_samples.PropertyFromIndex(prop) == kWPProp &&
(used_properties & (1LU << prop)) == 0) ||
((costs_l[i - first_used].pred == Predictor::Weighted ||
costs_r[i - first_used].pred == Predictor::Weighted) &&
(*tree)[pos].predictor != Predictor::Weighted); bool zero_entropy_side = rcost == 0 || lcost == 0;
// Try to avoid introducing WP. if (best_split_nowp.Cost() + threshold < base_bits &&
best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
best = &best_split_nowp;
} // Split along static props if possible and not significantly more // expensive. if (best_split_static.Cost() + threshold < base_bits &&
best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
best = &best_split_static;
} // Split along static props to create constant nodes if possible. if (best_split_static_constant.Cost() + threshold < base_bits) {
best = &best_split_static_constant;
}
}
size_t TreeSamples::Hash1(size_t a) const {
constexpr uint64_t constant = 0x1e35a7bd;
uint64_t h = constant; for (constauto &r : residuals) {
h = h * constant + r[a].tok;
h = h * constant + r[a].nbits;
} for (constauto &p : props) {
h = h * constant + p[a];
} return (h >> 16) & (dedup_table_.size() - 1);
}
size_t TreeSamples::Hash2(size_t a) const {
constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
uint64_t h = constant; for (constauto &p : props) {
h = h * constant ^ p[a];
} for (constauto &r : residuals) {
h = h * constant ^ r[a].tok;
h = h * constant ^ r[a].nbits;
} return (h >> 16) & (dedup_table_.size() - 1);
}
bool TreeSamples::IsSameSample(size_t a, size_t b) const { bool ret = true; for (constauto &r : residuals) { if (r[a].tok != r[b].tok) {
ret = false;
} if (r[a].nbits != r[b].nbits) {
ret = false;
}
} for (constauto &p : props) { if (p[a] != p[b]) {
ret = false;
}
} return ret;
}
// Records one pixel as a training sample: tokenizes the residual of every
// candidate predictor, stores the quantized property values used for tree
// decisions, and merges the new sample into an existing identical one
// (via the dedup table) instead of keeping a duplicate.
void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  // One residual token per candidate predictor.
  for (size_t idx = 0; idx < predictors.size(); idx++) {
    pixel_type res = pixel - predictions[static_cast<int>(predictors[idx])];
    uint32_t tok;
    uint32_t nbits;
    uint32_t bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(res), &tok, &nbits, &bits);
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    residuals[idx].emplace_back(
        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
  }
  // Quantized values of each property the tree may split on.
  for (size_t idx = 0; idx < props_to_use.size(); idx++) {
    props[idx].push_back(QuantizeProperty(idx, properties[props_to_use[idx]]));
  }
  sample_counts.push_back(1);
  num_samples++;
  // If an identical sample already exists, its count was bumped by the
  // merge; drop the freshly appended copy from every parallel array.
  if (AddToTableAndMerge(sample_counts.size() - 1)) {
    for (auto &r : residuals) {
      r.pop_back();
    }
    for (auto &p : props) {
      p.pop_back();
    }
    sample_counts.pop_back();
  }
}
// Exchanges every per-sample record (residual tokens for each predictor,
// quantized properties, and the multiplicity count) between positions
// `a` and `b`. No-op when the positions coincide.
void TreeSamples::Swap(size_t a, size_t b) {
  if (a == b) {
    return;
  }
  for (auto &residual_stream : residuals) {
    std::swap(residual_stream[a], residual_stream[b]);
  }
  for (auto &prop_stream : props) {
    std::swap(prop_stream[a], prop_stream[b]);
  }
  std::swap(sample_counts[a], sample_counts[b]);
}
// Appears to perform a three-way cyclic rearrangement of the samples at
// positions a, b, c (see its use in SplitTreeSamples' partitioning loop).
// NOTE(review): the visible text ends right after the b == c early-out;
// the rest of the body appears to have been lost during extraction, so
// only the degenerate case can be documented here -- recover the remainder
// from the original file before modifying.
void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) { if (b == c) {
// With b == c the three-way rotation degenerates to a plain swap.
Swap(a, b); return;
}
namespace {
// Returns the bucket indices at which the cumulative histogram crosses
// successive multiples of sum/num_chunks, i.e. approximate quantile split
// points. The last bucket never emits a threshold, so neither side of any
// resulting split is empty. `num_chunks` must be nonzero.
//
// Fixes over the previous text: the inner `while` used `>` where the guard
// used `>=`; when the cumulative sum landed exactly on a quantile boundary,
// `threshold` never advanced and every subsequent bucket (including empty
// ones) emitted a spurious threshold. Also, accumulating with `0LU` sums in
// `unsigned long` (32 bits on LLP64 targets) despite the uint64_t
// destination; the init value now forces 64-bit accumulation.
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
                                       size_t num_chunks) {
  if (histogram.empty()) return {};
  // TODO(veluca): selecting distinct quantiles is likely not the best
  // way to go about this.
  std::vector<int32_t> thresholds;
  uint64_t sum =
      std::accumulate(histogram.begin(), histogram.end(), uint64_t{0});
  uint64_t cumsum = 0;
  uint64_t threshold = 1;
  for (size_t i = 0; i + 1 < histogram.size(); i++) {
    cumsum += histogram[i];
    if (cumsum >= threshold * sum / num_chunks) {
      thresholds.push_back(i);
      // Skip every quantile boundary already covered by this bucket.
      while (cumsum >= threshold * sum / num_chunks) threshold++;
    }
  }
  return thresholds;
}

// Quantile-based thresholds for a list of raw samples. Samples are clamped
// to [-kRange, kRange], histogrammed relative to the (clamped) minimum, and
// the histogram thresholds are shifted back into sample-value space.
std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
                                     size_t num_chunks) {
  if (samples.empty()) return {};
  int min = *std::min_element(samples.begin(), samples.end());
  constexpr int kRange = 512;
  min = std::min(std::max(min, -kRange), kRange);
  std::vector<uint32_t> counts(2 * kRange + 1);
  for (int s : samples) {
    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
    counts[sample_offset]++;
  }
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
  // Histogram bin i corresponds to sample value i + min.
  for (auto &v : thresholds) v += min;
  return thresholds;
}
}  // namespace
// Pre-computes the quantization thresholds for the properties used during
// tree learning. NOTE(review): the visible text of this function stops
// right after the threshold de-duplication below -- the remainder of the
// body (and its closing brace) appears to have been lost during extraction,
// so only this visible prefix is documented. The `constauto` token below
// also looks like a garbled `const auto` (compile error); recover both from
// the original file before editing.
void TreeSamples::PreQuantizeProperties( const StaticPropRange &range, const std::vector<ModularMultiplierInfo> &multiplier_info, const std::vector<uint32_t> &group_pixel_count, const std::vector<uint32_t> &channel_pixel_count,
std::vector<pixel_type> &pixel_samples,
std::vector<pixel_type> &diff_samples, size_t max_property_values) { // If we have forced splits because of multipliers, choose channel and group // thresholds accordingly.
std::vector<int32_t> group_multiplier_thresholds;
// Collect a threshold just below each multiplier-range edge that differs
// from the corresponding edge of the global static range; index [0] reads
// as the channel axis and [1] as the group axis here.
std::vector<int32_t> channel_multiplier_thresholds; for (constauto &v : multiplier_info) { if (v.range[0][0] != range[0][0]) {
channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
} if (v.range[0][1] != range[0][1]) {
channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
} if (v.range[1][0] != range[1][0]) {
group_multiplier_thresholds.push_back(v.range[1][0] - 1);
} if (v.range[1][1] != range[1][1]) {
group_multiplier_thresholds.push_back(v.range[1][1] - 1);
}
}
// Sort and de-duplicate both threshold lists (sort + unique + resize).
std::sort(channel_multiplier_thresholds.begin(),
channel_multiplier_thresholds.end());
channel_multiplier_thresholds.resize(
std::unique(channel_multiplier_thresholds.begin(),
channel_multiplier_thresholds.end()) -
channel_multiplier_thresholds.begin());
std::sort(group_multiplier_thresholds.begin(),
group_multiplier_thresholds.end());
group_multiplier_thresholds.resize(
std::unique(group_multiplier_thresholds.begin(),
group_multiplier_thresholds.end()) -
group_multiplier_thresholds.begin());
Die Informationen auf dieser Webseite wurden
nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit, noch Richtigkeit,
noch Qualität der bereit gestellten Informationen zugesichert.
Bemerkung:
Die farbliche Syntaxdarstellung und die Messung sind noch experimentell.