/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <memory>
#include <utility>

#include "ocl/gemm_x8s8s32x_inner_product.hpp"
#include "ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace ocl {

// Executes the forward inner product as a nested GEMM call, optionally
// followed by a post-processing kernel.
//
// Steps:
//   1. Build the GEMM argument list: the primitive's weights are passed as
//      DNNL_ARG_SRC_0 and its source as DNNL_ARG_SRC_1.
//   2. If the primitive descriptor requests a scratchpad, wrap the
//      scratchpad storage in a temporary memory object; when a temporary
//      destination is requested the GEMM writes into that object instead of
//      the user-provided DNNL_ARG_DST.
//   3. Run the GEMM and, if pd()->with_post_process() is set, enqueue
//      post_process_kernel_ over an {1, MB, OC} global range.
//
// ctx: execution context providing the stream and the SRC / WEIGHTS / BIAS /
//      DST arguments.
//
// Returns status::success when every step succeeds; otherwise the first
// non-success status reported by the scratchpad handle calls, the GEMM
// execution, or the kernel launch.
status_t gemm_x8s8s32x_inner_product_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    exec_args_t gemm_args;
    gemm_args[DNNL_ARG_SRC_0] = ctx.args().at(DNNL_ARG_WEIGHTS);
    gemm_args[DNNL_ARG_SRC_1] = ctx.args().at(DNNL_ARG_SRC);

    // The wrapper is only referenced through a raw pointer in gemm_args and
    // never outlives this call, so exclusive ownership is sufficient.
    std::unique_ptr<memory_t> scratchpad_mem;
    if (pd()->use_scratchpad()) {
        // Created without a data pointer; the handle of scratchpad_ is
        // attached right below.
        scratchpad_mem.reset(new memory_t(engine(), pd()->ip_scratchpad_md(),
                memory_flags_t::use_runtime_ptr, nullptr));

        // Propagate failures instead of silently running the GEMM on a
        // missing handle.
        void *scratchpad_handle = nullptr;
        status_t handle_status
                = scratchpad_->get_data_handle(&scratchpad_handle);
        if (handle_status != status::success) return handle_status;

        handle_status = scratchpad_mem->set_data_handle(scratchpad_handle);
        if (handle_status != status::success) return handle_status;
    }

    if (pd()->use_temp_dst()) {
        // NOTE(review): this relies on use_temp_dst() implying
        // use_scratchpad(); otherwise scratchpad_mem is null here — confirm
        // against the primitive descriptor.
        gemm_args[DNNL_ARG_DST] = {scratchpad_mem.get(), false};
    } else {
        gemm_args[DNNL_ARG_DST] = ctx.args().at(DNNL_ARG_DST);
    }

    exec_ctx_t gemm_ctx(ctx.stream(), std::move(gemm_args));
    status_t gemm_exec_status = gemm_->execute(gemm_ctx);
    if (gemm_exec_status != status::success) return gemm_exec_status;

    if (pd()->with_post_process()) {
        // Kernel arguments are positional. Slots 0 and 2 are both bound to
        // the DST storage; the scratchpad and scales slots receive an empty
        // storage when the corresponding feature is disabled.
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, CTX_OUT_STORAGE(DNNL_ARG_DST));
        arg_list.set(1, CTX_IN_STORAGE(DNNL_ARG_BIAS));
        arg_list.set(2, CTX_OUT_STORAGE(DNNL_ARG_DST));
        arg_list.set(3, pd()->eltwise_alpha());
        arg_list.set(4, pd()->eltwise_beta());
        arg_list.set(5, pd()->sum_scale());
        arg_list.set(6,
                pd()->use_scratchpad() ? *scratchpad_
                                       : memory_storage_t::empty_storage());
        arg_list.set(7,
                pd()->with_scales() ? *scales_mem_->memory_storage()
                                    : memory_storage_t::empty_storage());

        size_t mb = pd()->MB();
        size_t oc = pd()->OC();

        compute::compute_stream_t *compute_stream
                = utils::downcast<compute::compute_stream_t *>(ctx.stream());

        // Global range of 1 x MB x OC with a local range of 1 x 1 x 1.
        const size_t gws[] = {1, mb, oc};
        const size_t lws[] = {1, 1, 1};
        auto nd_range = compute::nd_range_t(gws, lws);
        status_t status = compute_stream->parallel_for(
                nd_range, post_process_kernel_, arg_list);
        if (status != status::success) return status;
    }

    return status::success;
}

} // namespace ocl
} // namespace impl
} // namespace dnnl
