/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.output.model;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import lombok.Generated;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;

public class ModelTensors
implements Writeable,
ToXContentObject {
    public static final String OUTPUT_FIELD = "output";
    public static final String STATUS_CODE_FIELD = "status_code";
    private List<ModelTensor> mlModelTensors;
    private Integer statusCode;

    public ModelTensors(List<ModelTensor> mlModelTensors) {
        this.mlModelTensors = mlModelTensors;
    }

    public ModelTensors(Integer statusCode) {
        this.statusCode = statusCode;
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.mlModelTensors != null && this.mlModelTensors.size() > 0) {
            builder.startArray(OUTPUT_FIELD);
            for (ModelTensor output : this.mlModelTensors) {
                output.toXContent(builder, params);
            }
            builder.endArray();
        }
        if (this.statusCode != null) {
            builder.field(STATUS_CODE_FIELD, this.statusCode);
        }
        builder.endObject();
        return builder;
    }

    public ModelTensors(StreamInput in) throws IOException {
        if (in.readBoolean()) {
            this.mlModelTensors = new ArrayList<ModelTensor>();
            int size = in.readInt();
            for (int i = 0; i < size; ++i) {
                this.mlModelTensors.add(new ModelTensor(in));
            }
        }
        this.statusCode = in.readOptionalInt();
    }

    public void writeTo(StreamOutput out) throws IOException {
        if (this.mlModelTensors != null && this.mlModelTensors.size() > 0) {
            out.writeBoolean(true);
            out.writeInt(this.mlModelTensors.size());
            for (ModelTensor output : this.mlModelTensors) {
                output.writeTo(out);
            }
        } else {
            out.writeBoolean(false);
        }
        out.writeOptionalInt(this.statusCode);
    }

    public void filter(ModelResultFilter resultFilter) {
        boolean returnBytes = resultFilter.isReturnBytes();
        boolean returnNumber = resultFilter.isReturnNumber();
        List<String> targetResponse = resultFilter.getTargetResponse();
        List<Integer> targetResponsePositions = resultFilter.getTargetResponsePositions();
        if (!(targetResponse != null && targetResponse.size() != 0 || targetResponsePositions != null && targetResponsePositions.size() != 0)) {
            this.mlModelTensors.forEach(output -> this.filter((ModelTensor)output, returnBytes, returnNumber));
            return;
        }
        ArrayList<ModelTensor> targetOutput = new ArrayList<ModelTensor>();
        if (this.mlModelTensors != null) {
            for (int i = 0; i < this.mlModelTensors.size(); ++i) {
                ModelTensor output2 = this.mlModelTensors.get(i);
                if (targetResponse != null && targetResponse.contains(output2.getName())) {
                    this.filter(output2, returnBytes, returnNumber);
                    targetOutput.add(output2);
                    continue;
                }
                if (targetResponsePositions == null || !targetResponsePositions.contains(i)) continue;
                this.filter(output2, returnBytes, returnNumber);
                targetOutput.add(output2);
            }
        }
        this.mlModelTensors = targetOutput;
    }

    private void filter(ModelTensor output, boolean returnBytes, boolean returnNUmber) {
        if (!returnBytes) {
            output.setByteBuffer(null);
        }
        if (!returnNUmber) {
            output.setData(null);
        }
    }

    public byte[] toBytes() {
        byte[] byArray;
        BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
        try {
            byte[] bytes;
            this.writeTo((StreamOutput)bytesStreamOutput);
            bytesStreamOutput.flush();
            byArray = bytes = bytesStreamOutput.bytes().toBytesRef().bytes;
        }
        catch (Throwable throwable) {
            try {
                try {
                    bytesStreamOutput.close();
                }
                catch (Throwable throwable2) {
                    throwable.addSuppressed(throwable2);
                }
                throw throwable;
            }
            catch (Exception e) {
                throw new MLException("Failed to parse result", e);
            }
        }
        bytesStreamOutput.close();
        return byArray;
    }

    public static ModelTensors fromBytes(byte[] bytes) {
        ModelTensors modelTensors;
        block8: {
            ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
            StreamInput streamInput = BytesReference.fromByteBuffer((ByteBuffer)byteBuffer).streamInput();
            try {
                ModelTensors tensorOutput;
                modelTensors = tensorOutput = new ModelTensors(streamInput);
                if (streamInput == null) break block8;
            }
            catch (Throwable tensorOutput) {
                try {
                    if (streamInput != null) {
                        try {
                            streamInput.close();
                        }
                        catch (Throwable throwable) {
                            tensorOutput.addSuppressed(throwable);
                        }
                    }
                    throw tensorOutput;
                }
                catch (Exception e) {
                    String errorMsg = "Failed to parse output";
                    throw new MLException(errorMsg, e);
                }
            }
            streamInput.close();
        }
        return modelTensors;
    }

    public static ModelTensors parse(XContentParser parser) throws IOException {
        Integer statusCode = null;
        ArrayList<ModelTensor> mlModelTensors = new ArrayList<ModelTensor>();
        XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_OBJECT, (XContentParser.Token)parser.currentToken(), (XContentParser)parser);
        block8: while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            switch (fieldName) {
                case "status_code": {
                    statusCode = parser.intValue(false);
                    continue block8;
                }
                case "output": {
                    XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_ARRAY, (XContentParser.Token)parser.currentToken(), (XContentParser)parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        mlModelTensors.add(ModelTensor.parser(parser));
                    }
                    continue block8;
                }
            }
            parser.skipChildren();
        }
        ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build();
        modelTensors.setStatusCode(statusCode);
        return modelTensors;
    }

    @Generated
    public static ModelTensorsBuilder builder() {
        return new ModelTensorsBuilder();
    }

    @Generated
    public List<ModelTensor> getMlModelTensors() {
        return this.mlModelTensors;
    }

    @Generated
    public Integer getStatusCode() {
        return this.statusCode;
    }

    @Generated
    public void setStatusCode(Integer statusCode) {
        this.statusCode = statusCode;
    }

    @Generated
    public static class ModelTensorsBuilder {
        @Generated
        private List<ModelTensor> mlModelTensors;
        @Generated
        private Integer statusCode;

        @Generated
        ModelTensorsBuilder() {
        }

        @Generated
        public ModelTensorsBuilder mlModelTensors(List<ModelTensor> mlModelTensors) {
            this.mlModelTensors = mlModelTensors;
            return this;
        }

        @Generated
        public ModelTensors build() {
            return new ModelTensors(this.mlModelTensors);
        }

        @Generated
        public String toString() {
            return "ModelTensors.ModelTensorsBuilder(mlModelTensors=" + String.valueOf(this.mlModelTensors) + ")";
        }

        @Generated
        public ModelTensorsBuilder statusCode(Integer statusCode) {
            this.statusCode = statusCode;
            return this;
        }
    }
}

