/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.backpressure.trackers;

import java.io.IOException;
import java.util.List;
import java.util.LongSummaryStatistics;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.DoubleSupplier;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.util.MovingAverage;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.monitor.jvm.JvmStats;
import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker;
import org.opensearch.search.backpressure.trackers.TaskResourceUsageTrackerType;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancellation;

/**
 * Resource-usage tracker that flags a task for cancellation based on its heap memory usage.
 *
 * <p>A task becomes a cancellation candidate only when its current memory usage is at or above
 * <em>both</em> of these limits:
 * <ul>
 *   <li>an absolute threshold, {@code heapPercentThreshold * HEAP_SIZE_BYTES}, and</li>
 *   <li>the moving average of usages recorded through {@link #update(Task)}, multiplied by
 *       {@code heapVariance}.</li>
 * </ul>
 *
 * <p>The moving-average window size is driven by a dynamic cluster setting. Changing it replaces
 * the moving average with an empty one.
 */
public class HeapUsageTracker
extends TaskResourceUsageTracker {
    private static final Logger logger = LogManager.getLogger(HeapUsageTracker.class);

    /**
     * Maximum heap size, in bytes, as reported by {@link JvmStats} when this class is loaded.
     * A non-positive value makes {@link #isHeapTrackingSupported()} return {@code false}.
     */
    private static final long HEAP_SIZE_BYTES = JvmStats.jvmStats().getMem().getHeapMax().getBytes();

    /** Supplies the multiplier applied to the moving average to derive the allowed per-task usage. */
    private final DoubleSupplier heapVarianceSupplier;

    /** Supplies the factor multiplied by {@link #HEAP_SIZE_BYTES} to derive the absolute threshold. */
    private final DoubleSupplier heapPercentThresholdSupplier;

    /**
     * Moving average of recorded task memory usages. Kept in an {@link AtomicReference} so that a
     * window-size change can swap in a fresh instance atomically.
     */
    private final AtomicReference<MovingAverage> movingAverageReference;

    /**
     * Creates a tracker and subscribes it to updates of the moving-average window size.
     *
     * @param heapVarianceSupplier         multiplier applied to the moving average to compute the
     *                                     allowed usage
     * @param heapPercentThresholdSupplier factor multiplied directly by the max heap size to compute
     *                                     the absolute threshold (presumably a fraction in [0, 1])
     * @param heapMovingAverageWindowSize  initial window size of the moving average
     * @param clusterSettings              settings registry used to subscribe to {@code windowSizeSetting}
     * @param windowSizeSetting            dynamic setting whose updates reset the moving average to
     *                                     the new window size
     */
    public HeapUsageTracker(DoubleSupplier heapVarianceSupplier, DoubleSupplier heapPercentThresholdSupplier, int heapMovingAverageWindowSize, ClusterSettings clusterSettings, Setting<Integer> windowSizeSetting) {
        this.heapVarianceSupplier = heapVarianceSupplier;
        this.heapPercentThresholdSupplier = heapPercentThresholdSupplier;
        this.movingAverageReference = new AtomicReference<MovingAverage>(new MovingAverage(heapMovingAverageWindowSize));
        // Registered last so that every field is initialized before the callback can fire.
        clusterSettings.addSettingsUpdateConsumer(windowSizeSetting, this::updateWindowSize);
    }

    /** Returns the registered name of this tracker type. */
    @Override
    public String name() {
        return TaskResourceUsageTrackerType.HEAP_USAGE_TRACKER.getName();
    }

    /**
     * Records the task's total memory usage, in bytes, into the current moving average.
     *
     * @param task the task whose usage is sampled
     */
    @Override
    public void update(Task task) {
        movingAverageReference.get().record(task.getTotalResourceStats().getMemoryInBytes());
    }

    /**
     * Decides whether {@code task} should be cancelled because of its heap usage.
     *
     * <p>A reason is returned only when all of these hold:
     * <ul>
     *   <li>the moving average reports it is ready;</li>
     *   <li>heap tracking is supported;</li>
     *   <li>the task's current usage is at or above the absolute threshold;</li>
     *   <li>the task's current usage is at or above the variance-scaled average.</li>
     * </ul>
     * The reason's score is {@code currentUsage / averageUsage}, narrowed to an {@code int}.
     *
     * @param task the task to evaluate
     * @return a cancellation reason, or {@link Optional#empty()} if the task should not be cancelled
     */
    @Override
    public Optional<TaskCancellation.Reason> checkAndMaybeGetCancellationReason(Task task) {
        MovingAverage movingAverage = movingAverageReference.get();
        // No baseline to compare against until the moving average reports ready.
        if (!movingAverage.isReady()) {
            return Optional.empty();
        }
        double currentUsage = task.getTotalResourceStats().getMemoryInBytes();
        double averageUsage = movingAverage.getAverage();
        double variance = heapVarianceSupplier.getAsDouble();
        double allowedUsage = averageUsage * variance;
        double threshold = heapPercentThresholdSupplier.getAsDouble() * HEAP_SIZE_BYTES;
        if (!isHeapTrackingSupported() || currentUsage < threshold || currentUsage < allowedUsage) {
            return Optional.empty();
        }
        String message = "heap usage exceeded [" + new ByteSizeValue((long) currentUsage) + " >= " + new ByteSizeValue((long) allowedUsage) + "]";
        // If averageUsage is 0, the division yields Infinity (or NaN when currentUsage is also 0).
        // The int narrowing then gives Integer.MAX_VALUE (or 0).
        int score = (int) (currentUsage / averageUsage);
        return Optional.of(new TaskCancellation.Reason(message, score));
    }

    /**
     * Replaces the moving average with an empty one of the given window size. Samples recorded so
     * far are discarded, so this tracker reports no cancellation reason until the new moving average
     * reports ready.
     *
     * @param heapMovingAverageWindowSize the new window size
     */
    private void updateWindowSize(int heapMovingAverageWindowSize) {
        movingAverageReference.set(new MovingAverage(heapMovingAverageWindowSize));
    }

    /**
     * Returns whether the JVM reported a positive maximum heap size. When it did not, the heap
     * threshold checks in this class do not apply.
     *
     * @return {@code true} if {@link #HEAP_SIZE_BYTES} is positive
     */
    public static boolean isHeapTrackingSupported() {
        return HEAP_SIZE_BYTES > 0L;
    }

    /**
     * Returns whether the combined memory usage of {@code cancellableTasks} reaches
     * {@code heapPercentThreshold * HEAP_SIZE_BYTES}. Always returns {@code true} when heap tracking
     * is unsupported.
     *
     * @param cancellableTasks     tasks whose memory usage is summed
     * @param heapPercentThreshold factor multiplied by the max heap size to obtain the threshold
     * @return {@code false} only if heap tracking is supported and the summed usage is below the threshold
     */
    public static boolean isHeapUsageDominatedBySearch(List<CancellableTask> cancellableTasks, double heapPercentThreshold) {
        long usage = cancellableTasks.stream().mapToLong(task -> task.getTotalResourceStats().getMemoryInBytes()).sum();
        long threshold = (long) (heapPercentThreshold * HEAP_SIZE_BYTES);
        if (isHeapTrackingSupported() && usage < threshold) {
            logger.debug("heap usage not dominated by search requests [{}/{}]", usage, threshold);
            return false;
        }
        return true;
    }

    /**
     * Builds a statistics snapshot from the given active tasks and this tracker's state.
     *
     * @param activeTasks tasks whose current memory usage is summarized; may be empty
     * @return stats with the cancellation count, the max and average usage over {@code activeTasks}
     *         (both {@code 0} when there are no tasks), and the moving average
     */
    @Override
    public TaskResourceUsageTracker.Stats stats(List<? extends Task> activeTasks) {
        // One pass over the tasks yields both the max and the average.
        LongSummaryStatistics usage = activeTasks.stream()
            .mapToLong(t -> t.getTotalResourceStats().getMemoryInBytes())
            .summaryStatistics();
        // getMax() is Long.MIN_VALUE for empty input; report 0 when there are no active tasks.
        long currentMax = usage.getCount() == 0 ? 0L : usage.getMax();
        // getAverage() is 0.0 for empty input.
        long currentAvg = (long) usage.getAverage();
        return new Stats(getCancellations(), currentMax, currentAvg, (long) movingAverageReference.get().getAverage());
    }

    /** Immutable snapshot of {@link HeapUsageTracker} statistics. All sizes are in bytes. */
    public static class Stats
    implements TaskResourceUsageTracker.Stats {
        private final long cancellationCount;
        private final long currentMax;
        private final long currentAvg;
        private final long rollingAvg;

        /**
         * @param cancellationCount cancellation count reported by the tracker
         * @param currentMax        highest memory usage among the active tasks
         * @param currentAvg        average memory usage of the active tasks
         * @param rollingAvg        current value of the tracker's moving average
         */
        public Stats(long cancellationCount, long currentMax, long currentAvg, long rollingAvg) {
            this.cancellationCount = cancellationCount;
            this.currentMax = currentMax;
            this.currentAvg = currentAvg;
            this.rollingAvg = rollingAvg;
        }

        /**
         * Reads stats written by {@link #writeTo(StreamOutput)}. Fields are read in the same order
         * they are written.
         *
         * @param in the stream to read from
         * @throws IOException if reading fails
         */
        public Stats(StreamInput in) throws IOException {
            this(in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong());
        }

        /** Writes the stats as an object with raw byte fields and their human-readable counterparts. */
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            return builder.startObject()
                .field("cancellation_count", cancellationCount)
                .humanReadableField("current_max_bytes", "current_max", new ByteSizeValue(currentMax))
                .humanReadableField("current_avg_bytes", "current_avg", new ByteSizeValue(currentAvg))
                .humanReadableField("rolling_avg_bytes", "rolling_avg", new ByteSizeValue(rollingAvg))
                .endObject();
        }

        /** Serializes the four fields as variable-length longs, in declaration order. */
        public void writeTo(StreamOutput out) throws IOException {
            out.writeVLong(cancellationCount);
            out.writeVLong(currentMax);
            out.writeVLong(currentAvg);
            out.writeVLong(rollingAvg);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Stats stats = (Stats) o;
            return cancellationCount == stats.cancellationCount
                && currentMax == stats.currentMax
                && currentAvg == stats.currentAvg
                && rollingAvg == stats.rollingAvg;
        }

        @Override
        public int hashCode() {
            return Objects.hash(cancellationCount, currentMax, currentAvg, rollingAvg);
        }
    }
}

