-- Identifiers shared by every callback in this file.
-- Log tag prepended to all messages from this effect.
local AE_EFFECT_TAG = "AE_EFFECT_TAG Cutout-Auto"
-- Input slot passed to Amaz.Algorithm.setInputTexture for the stroke mask.
local INPUT1 = 1
-- Algorithm graph / node names and parameter keys used for interactive matting.
-- NOTE(review): "defaut" is presumably the registered graph name; do not "fix"
-- the spelling without checking the graph definition.
local GRAPH_NAME = "effectsdk_defaut_graph"
local MATTING_NAME = "interactive_matting_0"
local MATTING_REINIT = "interactive_matting_reinit"
-- Also used as the key of the step counter kept in the canvas TableComponent.
local MATTING_STEP_COUNT = "interactive_matting_step_count"

-- Module table returned to the loader (reuses a pre-existing global if any).
local exports = exports or {}
-- Class table; instances resolve their methods through __index.
local stroke_lua = stroke_lua or {}
stroke_lua.__index = stroke_lua

--- Create a new stroke instance whose metatable is `stroke_lua`.
-- @param construct when truthy, `stroke_lua.constructor` (if defined) is run
--        on the fresh instance
-- @param ... extra arguments forwarded to the constructor
-- @return the new instance
function stroke_lua.new(construct, ...)
    local instance = setmetatable({}, stroke_lua)
    if construct then
        local ctor = stroke_lua.constructor
        if ctor then
            ctor(instance, ...)
        end
    end
    return instance
end

--- Instance initializer called by `stroke_lua.new` when `construct` is truthy.
-- Intentionally empty: the instance's fields are populated by `onStrokeStart`.
function stroke_lua:constructor()
    -- nothing to initialize here
end

--- Host callback at the start of a stroke.
-- Caches canvas/stroke handles, configures the Brush2D input processor and mesh
-- generator at half the stroke resolution, sets the brush material's stroke size,
-- allocates the per-stroke render targets, meshes and command buffers, ensures the
-- canvas entity has a TableComponent, and enables the interactive matting algorithm.
-- @param comp unused here
-- @param stroke stroke object (resolution, undo/redo texture caches, entity)
-- @param canvas canvas object (brushConfig, blitMaterial, entity)
function stroke_lua:onStrokeStart(comp, stroke, canvas)
    Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeStart()')
    self.canvas = canvas
    self.config = canvas.brushConfig
    self.blitMaterial = canvas.blitMaterial
    self.stroke = stroke
    self.resolution = stroke.resolution
    -- Processor, generator and bbox alignment all work at half resolution.
    self.halfResolution = self.resolution * 0.5
    self.undoTexCache = stroke.undoTexCache
    self.redoTexCache = stroke.redoTexCache
    self.entity = stroke.entity
    self.processor = self.entity:getComponent("Brush2DInputProcessor")
    self.processor:setResolution(self.halfResolution)
    self.processor:setConfig(self.config)
    self.generator = self.entity:getComponent("Brush2DMeshGenerator")
    self.generator:setResolution(self.halfResolution)
    self.generator:setConfig(self.config)
    self.renderMaterial = self.config.brushMaterial
    local baseResolution = self.config.baseResolution
    local strokeSize = self.config.strokeSize
    self.renderMaterial:setFloat("strokeSize", baseResolution * strokeSize.value)

    -- Per-stroke render state; -1 means "no submesh drawn yet".
    self.subMeshIndex = -1
    -- Becomes true once the undo/redo snapshots have been taken in onStrokeRender.
    self.cacheFlag = false
    -- Copy of the target before this stroke (composited under the stroke).
    self.targetCache = Amaz.RenderTexture()
    self.targetCache.width = self.resolution.x
    self.targetCache.height = self.resolution.y
    -- Accumulated brush geometry for this stroke.
    self.strokeTex = Amaz.RenderTexture()
    self.strokeTex.width = self.resolution.x
    self.strokeTex.height = self.resolution.y
    self.renderCmd = Amaz.CommandBuffer()
    self.cacheCmd = Amaz.CommandBuffer()
    self.stModel = Amaz.Matrix4x4f():SetIdentity()
    self.rectMesh = Amaz.Mesh()
    self.cropMesh = Amaz.Mesh()
    -- Property blocks used when drawing the undo (block1) and redo (block2) snapshots.
    self.block1 = Amaz.MaterialPropertyBlock()
    self.block2 = Amaz.MaterialPropertyBlock() 
    self.mattingTex = nil
    self.maskTex = nil
    self.strokeFinished = false
    -- Number of Increase frames rendered so far (see onStrokeRender).
    self.activeCount = 0
    -- Step counter for the matting algorithm lives on the canvas entity.
    self.properties = canvas.entity:getComponent("TableComponent")
    if self.properties == nil then
        self.properties = canvas.entity:addComponent("TableComponent")
    end
    -- NOTE(review): this resets the shared step counter to -1 on every stroke start,
    -- even when the TableComponent already existed; onStrokeRender treats <= -1 as
    -- "reinitialize matting". Confirm that a per-stroke reset is intended.
    self.properties.table:set(MATTING_STEP_COUNT, -1)
    Amaz.Algorithm.setAlgorithmEnable(GRAPH_NAME, MATTING_NAME, true)
    Amaz.Algorithm.setAlgorithmParamInt(GRAPH_NAME, MATTING_NAME, MATTING_REINIT, 0)
end

--- Host callback run before rendering: validates `stroke.dirtyFlag` and refreshes
-- the stroke's bounding boxes.
--  * Increase (only while `strokeFinished` is false): sets the matting reinit
--    parameter to 0 and drops any previous matting texture. If the generator
--    produced a new submesh, its bbox becomes the dirty rect and the whole mesh
--    bbox becomes the stroke bbox. Otherwise the flag is kept when the stroke is
--    finished and has at least one submesh, and downgraded to Unchanged if not.
--  * Undo / Redo (finished strokes only): marks the whole stroke bbox dirty and
--    moves the shared matting step counter down / up by one. For unfinished
--    strokes the flag is downgraded to Unchanged.
-- The (possibly downgraded) flag is written back, and both command buffers are
-- cleared for the upcoming onStrokeRender.
-- @param comp unused here
-- @param deltaTime unused here
function stroke_lua:onStrokeUpdate(comp, deltaTime)
    local dirty = self.stroke.dirtyFlag
    local finished = self.stroke.finished
    Amaz.LOGD(AE_EFFECT_TAG, 'onStrokeUpdate() dirty: ' .. tostring(dirty) .. ' finished: ' .. tostring(finished))
    if dirty == Amaz.Brush2DDirtyFlag.Increase and not self.strokeFinished then
        Amaz.Algorithm.setAlgorithmParamInt(GRAPH_NAME, MATTING_NAME, MATTING_REINIT, 0)
        self.mattingTex = nil
        self.strokeMesh = self.generator:getMesh()
        -- Index of the newest submesh (0-based); -1 when the mesh is still empty.
        local newIndex = self.strokeMesh.submeshes:size() - 1
        if self.subMeshIndex < newIndex then
            local subMesh = self.strokeMesh:getSubMesh(newIndex)
            -- subMesh:releaseIBO()
            self.subMeshIndex = newIndex
            self.dirtyRect = subMesh.boundingBox
            self.stroke.dirtyRect = self.dirtyRect
            self.strokeBBox = self.strokeMesh.boundingBox
            self.stroke.boundingBox = self.strokeBBox
            -- Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeUpdate() mesh increasing')
        elseif finished and newIndex >= 0 then
            -- Finished with no new submesh: keep Increase so onStrokeRender can run
            -- its finish handling.
            -- dirty = Amaz.Brush2DDirtyFlag.Cache
        else
            Amaz.LOGW(AE_EFFECT_TAG, "onStrokeUpdate() submesh is not increasing. self.subMeshIndex = " .. self.subMeshIndex)
            dirty = Amaz.Brush2DDirtyFlag.Unchanged
        end
    elseif dirty == Amaz.Brush2DDirtyFlag.Undo or dirty == Amaz.Brush2DDirtyFlag.Redo then
        if finished then
            self.stroke.dirtyRect = self.strokeBBox
            -- Keep the matting step counter in sync with the undo/redo history.
            if dirty == Amaz.Brush2DDirtyFlag.Undo then
                self.properties.table:set(MATTING_STEP_COUNT, self.properties.table:get(MATTING_STEP_COUNT) - 1)
            else
                self.properties.table:set(MATTING_STEP_COUNT, self.properties.table:get(MATTING_STEP_COUNT) + 1)
            end
        else
            Amaz.LOGE(AE_EFFECT_TAG, "onStrokeUpdate() stroke is unfinished.")
            dirty = Amaz.Brush2DDirtyFlag.Unchanged
        end
    end

    self.stroke.dirtyFlag = dirty
    self.renderCmd:clearAll()
    self.cacheCmd:clearAll()
    --self.blitMaterial:setTex("_MainTex", nil)
end

--- Host callback that records and commits this stroke's draw commands for a frame.
--
-- Behaviour by `stroke.dirtyFlag`:
--  * Increase: draws the newest submesh into `strokeTex` (maskMode 0) and composites
--    it over `targetCache` into `target`. On the first frame where the stroke is
--    finished and intersects the canvas valid area, a resized copy of `strokeTex`
--    (short side 256, see getAlgorithmProperSize) is sent to the interactive
--    matting algorithm. Once `strokeFinished` is set and a matting result exists,
--    the algorithm's alpha mask is used instead (maskMode 1).
--  * Undo / Redo: picks up the texture held by the matching texture cache.
--  * Cache: drops the cached texture and runs a full garbage collection.
-- If the stroke is (still) finished afterwards, this either snapshots the
-- undo/redo crops (the first time a matting texture is available) or redraws the
-- cached undo/redo texture.
-- @param comp component whose scene receives the command buffers
-- @param target render texture the stroke is composited into
function stroke_lua:onStrokeRender(comp, target)
    local dirty = self.stroke.dirtyFlag
    local finished = self.stroke.finished
    Amaz.LOGD(AE_EFFECT_TAG, 'onStrokeRender() dirty: ' .. tostring(dirty) .. ' finished: ' .. tostring(finished))
    -- Nothing changed and the undo/redo snapshots already exist: skip this frame.
    if dirty == Amaz.Brush2DDirtyFlag.Unchanged and self.cacheFlag then
        return
    end

    if dirty == Amaz.Brush2DDirtyFlag.Increase then
        if self.subMeshIndex == 0 then
            -- First segment of the stroke: start from an empty stroke texture.
            self.renderCmd:setRenderTexture(self.strokeTex)
            self.renderCmd:clearRenderTexture(true, false, Amaz.Color(0, 0, 0, 0))
            Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRender() clear strokeTex')
            if self.activeCount < 1 then
                -- Keep a copy of the untouched target: it is the "layerTexture" the
                -- stroke is composited over and the source of the undo snapshot.
                Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRender() Cache target')
                self.renderCmd:blit(target, self.targetCache)
            end
        end

        if self.strokeFinished then
            Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRender() getInteractiveMattingInfo')
            local result = Amaz.Algorithm.getAEAlgorithmResult()
            local mattingResult = result:getInteractiveMattingInfo()
            if mattingResult then
                if self.mattingTex == nil
                or self.mattingTex.width ~= mattingResult.width
                or self.mattingTex.height ~= mattingResult.height
                then
                    self.mattingTex = Amaz.Texture2D()
                end
                if self.mattingTex then
                    self.mattingTex:storage(mattingResult.alphaMask)
                    -- Grow both boxes so the composite covers the whole matting region.
                    self:expandBBox(self.strokeBBox, mattingResult.left, mattingResult.bottom, mattingResult.right, mattingResult.top)
                    self:expandBBox(self.dirtyRect, mattingResult.left, mattingResult.bottom, mattingResult.right, mattingResult.top)
                end
            else
                Amaz.LOGE(AE_EFFECT_TAG, "onStrokeRender mattingResult is nil.")
                self.mattingTex = nil
            end
        end

        if self.mattingTex and self.activeCount >= 1 then
            Amaz.LOGD(AE_EFFECT_TAG, 'onStrokeRender() maskMode 1')
            -- Threshold blends linearly from 3/8 (long side <= 800) to 31/64
            -- (long side >= 6400).
            local longSide = math.max(self.mattingTex.width, self.mattingTex.height)
            local t = (longSide - 800.0) / (6400.0 - 800.0)
            t = math.max(0.0, math.min(1.0, t))
            t = (3.0 / 8.0) * (1.0 - t) + (31.0 / 64.0) * t

            self.renderMaterial:setInt("maskMode", 1)
            self.renderMaterial:setFloat("threshold", t)
            self.renderCmd:setGlobalTexture("strokeTexture", self.mattingTex)

            dirty = Amaz.Brush2DDirtyFlag.Cache
            self.stroke.finished = true
            Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRender() Set real finish')
        else
            Amaz.LOGD(AE_EFFECT_TAG, 'onStrokeRender() maskMode 0')
            -- No matting result yet: render the brush geometry itself.
            self.renderMaterial:setInt("maskMode", 0)
            self.renderCmd:setRenderTexture(self.strokeTex)
            self.renderCmd:drawMesh(self.strokeMesh, self.stModel, self.renderMaterial, self.subMeshIndex, 0, nil)
            self.renderCmd:setGlobalTexture("strokeTexture", self.strokeTex)

            local validArea = self.canvas:validArea()
            local isIntersect = self:intersects(self.stroke.boundingBox, validArea)
            if finished and isIntersect then
                -- First finished frame: stay in Increase (and skip the snapshot step
                -- below) while the stroke mask is handed to the matting algorithm.
                dirty = Amaz.Brush2DDirtyFlag.Increase
                finished = false
                Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRender() First finish')

                local w, h = self:getAlgorithmProperSize(self.strokeTex.width, self.strokeTex.height)
                if w > 0 and h > 0
                and (self.maskTex == nil
                or self.maskTex.width ~= w
                or self.maskTex.height ~= h)
                then
                    self.maskTex = Amaz.RenderTexture()
                    self.maskTex.width = w
                    self.maskTex.height = h
                    -- Build a single message string, as every other log call here does;
                    -- extra positional arguments would not reach the log.
                    Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRender() Create maskTex w: ' .. tostring(w) .. ' h: ' .. tostring(h))
                end
                if self.maskTex then
                    Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRender() Blit strokeTex to maskTex ' .. tostring(self.maskTex.width) .. ' ' .. tostring(self.maskTex.height))
                    local block = Amaz.MaterialPropertyBlock()
                    block:setFloat("flip", 1)
                    block:setTexture("_MainTex", self.strokeTex)
                    self.renderCmd:blitWithMaterialAndProperties(self.strokeTex, self.maskTex, self.blitMaterial, 0, block)

                    -- Hand the mask to the interactive matting algorithm. A stored step
                    -- count <= -1 or >= 7 forces a reinit and restarts counting at -1.
                    local top_index = self.properties.table:get(MATTING_STEP_COUNT)
                    if top_index <= -1 or top_index >= 7 then
                        Amaz.Algorithm.setAlgorithmParamInt(GRAPH_NAME, MATTING_NAME, MATTING_REINIT, 1)
                        top_index = -1
                    end
                    Amaz.Algorithm.setAlgorithmParamInt(GRAPH_NAME, MATTING_NAME, MATTING_STEP_COUNT, top_index)
                    Amaz.Algorithm.setInputTexture(INPUT1, self.maskTex, Amaz.Map())
                    top_index = top_index + 1
                    self.properties.table:set(MATTING_STEP_COUNT, top_index)
                else
                    Amaz.LOGE(AE_EFFECT_TAG, 'onStrokeRender() maskTex is nil')
                end
            else
                self.maskTex = nil
                Amaz.Algorithm.setInputTexture(INPUT1, nil, Amaz.Map())
                if not isIntersect then
                    Amaz.LOGW(AE_EFFECT_TAG, 'onStrokeRender() Not intersect, strokeBBox: ' .. tostring(self.stroke.boundingBox) .. ' canvasValidArea: ' .. tostring(validArea))
                end
            end
        end

        -- Composite over the whole stroke once finished, otherwise only the new segment.
        local alignedBBox
        if finished then
            alignedBBox = self.strokeBBox
        else
            alignedBBox = self.dirtyRect
        end
        Amaz.Brush2DUtils.alignBBoxToResolution(self.halfResolution, alignedBBox)
        Amaz.Brush2DUtils.generateBBoxMesh(alignedBBox, self.rectMesh, true, true)
        self.renderCmd:setGlobalTexture("layerTexture", self.targetCache)
        self.renderCmd:setRenderTexture(target)
        self.renderCmd:drawMesh(self.rectMesh, self.stModel, self.renderMaterial, 0, 1, nil)
        comp.entity.scene:commitCommandBuffer(self.renderCmd)
        self.activeCount = self.activeCount + 1
    elseif dirty == Amaz.Brush2DDirtyFlag.Undo then
        if self.undoTexCache then
            self.cacheTex = self.undoTexCache:getTexture()
        end
        dirty = Amaz.Brush2DDirtyFlag.Unchanged
    elseif dirty == Amaz.Brush2DDirtyFlag.Redo then
        if self.redoTexCache then
            self.cacheTex = self.redoTexCache:getTexture()
        end
        dirty = Amaz.Brush2DDirtyFlag.Unchanged
    elseif dirty == Amaz.Brush2DDirtyFlag.Cache then
        self.cacheTex = nil
        collectgarbage("collect")
        dirty = Amaz.Brush2DDirtyFlag.Unchanged
    end

    if finished then
        if not self.cacheFlag and self.mattingTex then
            -- First frame with a matting texture: snapshot the region under the
            -- stroke before (undo) and after (redo) it was applied.
            Amaz.Brush2DUtils.alignBBoxToResolution(self.halfResolution, self.strokeBBox)
            Amaz.Brush2DUtils.generateBBoxMesh(self.strokeBBox, self.cropMesh, false, true)
            local cacheWidth = (self.strokeBBox.max_x - self.strokeBBox.min_x) * self.halfResolution.x
            local cacheHeight = (self.strokeBBox.max_y - self.strokeBBox.min_y) * self.halfResolution.y
            if self.undoTexCache then
                self.undoTex = Amaz.RenderTexture()
                self.undoTex.width = cacheWidth
                self.undoTex.height = cacheHeight
                self.undoTexCache:setTexture(self.undoTex)
                self.cacheCmd:setRenderTexture(self.undoTex)
                self.block1:setTexture("_MainTex",self.targetCache)
                self.cacheCmd:setGlobalTexture("_MainTex", self.targetCache)
                self.cacheCmd:drawMesh(self.cropMesh, self.stModel, self.blitMaterial, 0, 0, self.block1)
            end
            if self.redoTexCache then
                self.redoTex = Amaz.RenderTexture()
                self.redoTex.width = cacheWidth
                self.redoTex.height = cacheHeight
                self.redoTexCache:setTexture(self.redoTex)
                self.cacheCmd:setRenderTexture(self.redoTex)
                self.block2:setTexture("_MainTex",target)
                self.cacheCmd:setGlobalTexture("_MainTex", target)
                self.cacheCmd:drawMesh(self.cropMesh, self.stModel, self.blitMaterial, 0, 0, self.block2)
            end
            comp.entity.scene:commitCommandBuffer(self.cacheCmd)
            dirty = Amaz.Brush2DDirtyFlag.Cache
            self.cacheFlag = true
            self:onStrokeRealFinish()
        elseif self.cacheTex and self.cacheMesh then
            -- cacheMesh only exists after onStrokeRealFinish, while cacheTex can be
            -- set by the Undo/Redo branches above; never draw a nil mesh.
            local block = Amaz.MaterialPropertyBlock()
            block:setTexture("_MainTex", self.cacheTex)
            self.renderCmd:setRenderTexture(target)
            self.renderCmd:drawMesh(self.cacheMesh, self.stModel, self.blitMaterial, 0, 0, block)
            comp.entity.scene:commitCommandBuffer(self.renderCmd)
            dirty = Amaz.Brush2DDirtyFlag.Cache
        end
    end

    self.stroke.dirtyFlag = dirty
end

--- Host callback signalling the end of stroke input.
-- Only raises `strokeFinished`. onStrokeUpdate then stops handling Increase
-- updates, and onStrokeRender starts querying the interactive matting result.
function stroke_lua:onStrokeFinish()
    Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeFinish()')
    self.strokeFinished = true
end

--- Final teardown, called from onStrokeRender once the undo/redo snapshots exist.
-- Detaches the Brush2D processor/generator, builds `cacheMesh` from the final
-- stroke bbox (used to redraw cached undo/redo textures), releases the
-- per-stroke resources, unbinds the algorithm input and forces a GC.
function stroke_lua:onStrokeRealFinish()
    Amaz.LOGI(AE_EFFECT_TAG, 'onStrokeRealFinish()')
    local owner = self.entity
    owner:removeComponent("Brush2DInputProcessor")
    owner:removeComponent("Brush2DMeshGenerator")

    local mesh = Amaz.Mesh()
    self.cacheMesh = mesh
    Amaz.Brush2DUtils.generateBBoxMesh(self.strokeBBox, mesh, true, false)

    -- Drop references to everything only needed while the stroke was live
    -- (cleared in this exact order).
    local released = {
        "resolution", "config", "entity", "processor", "generator",
        "targetCache", "strokeTex", "undoTex", "redoTex", "mattingTex",
        "maskTex", "strokeMesh", "rectMesh", "cropMesh", "canvas",
        "block1", "block2",
    }
    for i = 1, #released do
        self[released[i]] = nil
    end
    self.strokeFinished = false
    self.activeCount = 0

    Amaz.Algorithm.setInputTexture(INPUT1, nil, Amaz.Map())

    collectgarbage("collect")
end

--- Compute the size of the mask texture fed to the interactive matting algorithm.
-- The shorter side is scaled to `shortSide` pixels and the longer side by the
-- same factor, preserving the aspect ratio. The results are rounded to whole
-- pixels: they are used as RenderTexture dimensions and compared against an
-- existing texture's width/height, and `/` always yields a float in Lua 5.3+
-- (e.g. 1920x1080 would otherwise give 455.11x256).
-- @param w source width in pixels
-- @param h source height in pixels
-- @param shortSide optional target length of the shorter side (default 256)
-- @return width, height; the inputs are returned unchanged if either is 0
function stroke_lua:getAlgorithmProperSize(w, h, shortSide)
    if w == 0 or h == 0 then return w, h end
    shortSide = shortSide or 256
    local outW, outH
    if w > h then
        outW = w * shortSide / h
        outH = shortSide
    else
        outH = h * shortSide / w
        outW = shortSide
    end
    -- Round to the nearest whole pixel.
    return math.floor(outW + 0.5), math.floor(outH + 0.5)
end

--- Grow `bbox` in place so that it also covers [min_x, max_x] x [min_y, max_y].
-- Only the x/y extents are touched, and a field is written only when the new
-- bound lies strictly outside the current one.
function stroke_lua:expandBBox(bbox, min_x, min_y, max_x, max_y)
    -- Lower bounds move outward when the new value is smaller.
    if bbox.min_x > min_x then bbox.min_x = min_x end
    if bbox.min_y > min_y then bbox.min_y = min_y end
    -- Upper bounds move outward when the new value is larger.
    if bbox.max_x < max_x then bbox.max_x = max_x end
    if bbox.max_y < max_y then bbox.max_y = max_y end
end

--- Test whether two axis-aligned boxes overlap (touching edges count as overlap).
-- @param bboxA box with min_/max_ x, y, z fields, or nil
-- @param bboxB box with min_/max_ x, y, z fields, or nil
-- @return false if either box is missing or they are separated on any axis,
--         true otherwise
function stroke_lua:intersects(bboxA, bboxB)
    if not (bboxA and bboxB) then
        return false
    end
    -- Separated when A lies entirely below B on some axis, or entirely above it.
    local separated = bboxA.max_x < bboxB.min_x
        or bboxA.max_y < bboxB.min_y
        or bboxA.max_z < bboxB.min_z
        or bboxA.min_x > bboxB.max_x
        or bboxA.min_y > bboxB.max_y
        or bboxA.min_z > bboxB.max_z
    return not separated
end

-- Publish the class through the module table handed back to the loader.
exports["stroke_lua"] = stroke_lua
return exports
