-- (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

-- Reference design: https://www.internalfb.com/intern/px/p/3Wx3X

-- JSON parameters handed to Treatments:motionBlurTreatment (via flashEffect)
-- for the "Flash" effect placed at the END of a clip.
-- NOTE(review): angles look like degrees (the shader path in addFlashEffect
-- converts the same 90/270 values with pi/180) -- confirm against the
-- treatment's documentation.
MB_START = [[
  {
    "end_velocity": 0.4,
    "start_angle": 90,
    "end_angle": 270
  }
]]

-- JSON parameters handed to Treatments:motionBlurTreatment (via flashEffect)
-- for the "Flash" effect placed at the START of the following clip.
MB_END = [[
  {
    "start_velocity": 0.4,
    "start_angle": -90,
    "end_angle": 90
  }
]]

-- JSON parameters handed to Treatments:transformTreatment (via flashEffect)
-- for the "Flash" effect at the END of a clip: zoom goes 1.0 -> 1.1 -> 1.2.
T_PARAMS = [[
  {
    "start_zoom": 1.0,
    "mid_zoom": 1.1,
    "end_zoom": 1.2,
    "global": false,
    "start_curve": "LINEAR",
    "end_curve": "LINEAR"
  }
]];

-- Counterpart of T_PARAMS used for the "Flash" effect at the START of the
-- next clip: zoom goes back 1.2 -> 1.1 -> 1.0.
T_PARAMS_2 = [[
  {
    "start_zoom": 1.2,
    "mid_zoom": 1.1,
    "end_zoom": 1.0,
    "global": false,
    "start_curve": "LINEAR",
    "end_curve": "LINEAR"
  }
]];

-- JSON parameters handed to Treatments:mirroredTileTreatment (via slideEffect).
-- MIRROR_PARAMS: "SlideUp" effect at the END of a clip (see populateEndEffects).
MIRROR_PARAMS = [[
  {
    "distance_start": 0.0,
    "distance_end": -0.3,
    "direction": "vertical"
  }
]];

-- MIRROR_PARAMS_2: "SlideUp" effect at the START of the clip that follows a
-- clip ending with "SlideUp" (see populateStartEffects).
MIRROR_PARAMS_2 = [[
  {
    "distance_start": -0.5,
    "distance_end": -2.0,
    "direction": "vertical"
  }
]];

-- MIRROR_PARAMS_3: "SlideLeft" effect at the END of a clip
-- (see populateEndEffects).
MIRROR_PARAMS_3 = [[
  {
    "distance_start": 0.0,
    "distance_end": 0.4,
    "direction": "horizontal"
  }
]];

-- MIRROR_PARAMS_4: "SlideLeft" effect at the START of the clip that follows a
-- clip ending with "SlideLeft" (see populateStartEffects).
MIRROR_PARAMS_4 = [[
  {
    "distance_start": -1.8,
    "distance_end": 0.0,
    "direction": "horizontal"
  }
]];

-- JSON parameters handed to Treatments:brightnessTreatment (via flashEffect)
-- for the "Flash" effect at the END of a clip. Only "end_amount" is given;
-- NOTE(review): the starting amount presumably defaults to 0 -- confirm.
BRIGHTEN_START = [[
  {
    "end_amount": 0.6,
    "curve": "LINEAR"
  }
]]

-- JSON parameters handed to Treatments:brightnessTreatment (via flashEffect)
-- for the "Flash" effect at the START of the next clip: 0.6 back down to 0.0.
BRIGHTEN_END = [[
  {
    "start_amount": 0.6,
    "end_amount": 0.0,
    "curve": "LINEAR"
  }
]]

-- JSON parameters handed to Treatments:zoomTreatment2 for image clips.
-- main() applies ZOOM_IN to odd-indexed images; flashAndSlideSingleMedia()
-- applies it to the single image.
ZOOM_IN = [[
    {
      "start_amount": 0.0,
      "end_amount": 0.09,
      "curve": "LINEAR"
    }
]];

-- JSON parameters handed to Treatments:zoomTreatment2 for even-indexed image
-- clips in main(). Note it starts at 0.1 whereas ZOOM_IN ends at 0.09.
ZOOM_OUT = [[
    {
    "start_amount": 0.1,
    "end_amount": 0.0,
    "curve": "LINEAR"
    }
]];

-- Shaders --
-- Pass-through vertex shader source used with FRAGMENT_SHADER in
-- ShaderEffects:addShader. It copies inPosition to gl_Position and forwards
-- inTexCoord0 to the fragment stage as texCoord0.
-- The string contains no #version directive.
VERTEX_SHADER = [[
  precision highp float; // fill 1st line to help auto-indent
  layout(location = 0) in vec4 inPosition;
  layout(location = 1) in vec2 inTexCoord0;
  out vec2 texCoord0;

  void main(void) {
    gl_Position = inPosition;
    texCoord0 = inTexCoord0;
  }
]]

-- Fragment shader source implementing the "Flash" look in a single pass:
-- a directional motion blur followed by a brightness boost.
--
-- Uniforms (animated through the JSON curves built in addFlashEffect and
-- flashAndSlideSingleMedia):
--   float1 -> brightness: RGB is multiplied by (1.0 + brightness)
--   float2 -> velocity:   scales the per-sample blur offset
--   float3 -> direction:  blur angle fed to cos()/sin(), i.e. radians
--
-- The blur averages 64 samples ("steps") of texture0, starting at texCoord0
-- and stepping by (cos(direction), sin(direction)) * velocity * 0.0015 in
-- texture coordinates per sample. Output alpha is always 1.0.
FRAGMENT_SHADER = [[
  precision highp float; // fill 1st line to help auto-indent
  in vec2 texCoord0;
  uniform sampler2D texture0;
  uniform float float1;
  uniform float float2;
  uniform float float3;

  layout(location = 0) out vec4 outColor;

  #define steps 64
  #define stepsDivisor 0.0015
  #define brightness float1
  #define velocity float2
  #define direction float3

  vec4 applyMotionBlur(float direction, float velocity) {
    vec4 color = vec4(0, 0, 0, 0);
    vec2 offset = vec2(0, 0);

    float velocityX = (cos(direction) * velocity) * stepsDivisor;
    float velocityY = (sin(direction) * velocity) * stepsDivisor;
    vec2 velocityVector = vec2(velocityX, velocityY);

    for (int i = 0; i < steps; ++i) {
      color += texture(texture0, texCoord0 + offset);
      offset += velocityVector;
    }
    color /= float(steps);

    return color;
  }

  vec3 changeBrightness(in vec3 color, float brightness) {
    return color * (1.0 + brightness);
  }

  void main(void) {
    vec4 color = applyMotionBlur(direction, velocity);

    outColor = vec4(changeBrightness(color.rgb, brightness), 1.0);
  }
]]

---- Lua code ----

-- Helper functions to construct animation curves
-- Appends one serialized keyframe to a comma-separated keyframe list.
-- @param curve        keyframes accumulated so far ("" when none yet)
-- @param newKeyFrame  serialized keyframe to add
-- @return the extended list; no leading comma is emitted for the first entry
function appendKeyFrame(curve, newKeyFrame)
  local separator = ","
  if curve == "" then
    separator = ""
  end
  return curve .. separator .. newKeyFrame
end

-- Serializes a single keyframe as a JSON object string of the form
-- {"timestamp": <timestamp>, "value": <value>, "curve": "<curve>"}.
-- timestamp and value are rendered with tostring(); curve is inserted
-- verbatim between double quotes (no escaping is performed).
function makeKeyframe(timestamp, value, curve)
  local timestampField = [["timestamp": ]] .. tostring(timestamp)
  local valueField = [["value": ]] .. tostring(value)
  local curveField = [["curve": "]] .. curve .. [["]]
  return "{" .. timestampField .. ", " .. valueField .. ", " .. curveField .. "}"
end

-- Builds one JSON member mapping a uniform name to its keyframe array:
-- "<uniformName>": [<animationCurve>]
-- animationCurve is the comma-separated list produced by appendKeyFrame and
-- may be empty, which yields an empty array.
function makeJSONAnimationCurve(uniformName, animationCurve)
  local key = '"' .. uniformName .. '"'
  local keyframeArray = "[" .. animationCurve .. "]"
  return key .. ": " .. keyframeArray
end

-- Wraps three serialized curve members (see makeJSONAnimationCurve) into a
-- single JSON object string: {<curve1>,<curve2>,<curve3>}
function makeJSONCurves(curve1, curve2, curve3)
  local members = curve1 .. "," .. curve2 .. "," .. curve3
  return "{" .. members .. "}"
end

-- Returns the median of the numeric values stored in a table.
-- Non-number values are ignored and the input table is left unmodified
-- (the numbers are copied into a scratch array before sorting).
-- @param t  table whose values (any keys) are inspected
-- @return the middle value for an odd count, the mean of the two middle
--         values for an even count, or nil when t holds no numbers
function table.median( t )
  local values = {}

  -- Collect only the numeric values so sorting never touches the caller's table.
  for _, v in pairs(t) do
    if type(v) == 'number' then
      values[#values + 1] = v
    end
  end

  local count = #values
  if count == 0 then
    -- Previously this fell through to values[0] + values[1] and raised an
    -- arithmetic-on-nil error.
    return nil
  end

  table.sort(values)

  local lowerMiddle = math.floor(count / 2)
  if count % 2 == 0 then
    -- Even count: mean of the two central elements.
    return (values[lowerMiddle] + values[lowerMiddle + 1]) / 2
  end
  -- Odd count: the single central element.
  return values[lowerMiddle + 1]
end

-- Reports whether `element` equals (==) any value stored in the given table.
-- All keys are visited via pairs(), so this works for arrays and maps alike.
-- Note: the first parameter is named `table`, which hides the global `table`
-- library inside this function body.
-- @return true when a matching value exists, false otherwise
function table.contains(table, element)
  local found = false
  for _, candidate in pairs(table) do
    if candidate == element then
      found = true
      break
    end
  end
  return found
end

-- Builds the template configuration from the host-provided MediaEvents.
--
-- Starts from the defaults below and then scans every media event whose
-- source is "config". Each of its events (indexed 0 .. size()-1) carries the
-- parameter name in its first tag and the parameter value in the second
-- value returned by getEvent (called "weight" here). Numeric parameters are
-- floored; boolean parameters are true exactly when the weight equals 1.0.
-- Unknown parameter names are ignored. Note that the "seed" parameter is
-- stored under config.seedValue.
-- @return table with maxItemCount, isVideoSegmentationEnabled,
--         minVideoSegmentDuration, seedValue, shouldAllowDuplicatedMedia,
--         shouldEnableSingleMediaLogic, shouldEnableTemplatesOptimization
function getConfig()
  -- default config
  local config = {
    maxItemCount = 15,
    isVideoSegmentationEnabled = false,
    minVideoSegmentDuration = 30,
    seedValue = 0,
    shouldAllowDuplicatedMedia = true,
    shouldEnableSingleMediaLogic = false,
    shouldEnableTemplatesOptimization = false
  }

  for _, mediaEvent in ipairs(MediaEvents) do
    if mediaEvent:getSourceMedia() == "config" then
      for i = 0, mediaEvent:size() - 1 do
        -- Only the weight (value) and the tags (name) matter for config events.
        local _, weight, _, tags = mediaEvent:getEvent(i)
        local param = tags[1]
        if param == "maxItemCount" then
          config.maxItemCount = math.floor(weight)
        elseif param == "isVideoSegmentationEnabled" then
          config.isVideoSegmentationEnabled = weight == 1.0
        elseif param == "minVideoSegmentDuration" then
          config.minVideoSegmentDuration = math.floor(weight)
        elseif param == "seed" then
          config.seedValue = math.floor(weight)
        elseif param == "shouldAllowDuplicatedMedia" then
          -- Was assigned `param` (the tag string), which made the flag
          -- always truthy; decode it like the other boolean flags.
          config.shouldAllowDuplicatedMedia = weight == 1.0
        elseif param == "shouldEnableSingleMediaLogic" then
          config.shouldEnableSingleMediaLogic = weight == 1.0
        elseif param == "shouldEnableTemplatesOptimization" then
          config.shouldEnableTemplatesOptimization = weight == 1.0
        end
      end
    end
  end

  return config
end

-- Extracts beat information from every non-"config" entry of MediaEvents.
--
-- An event counts as a beat when its tags contain both "beat" and "onset".
-- Every other event seen after the first beat is counted as an onset and
-- attributed to the surrounding beats (onsetsAfter of the previous beat,
-- onsetsBefore of the next one). The onset counter restarts for each media
-- event, while the beat list and the previous beat time carry over.
--
-- @return beats   array of { startTime, weight, onset, onsetsBefore, onsetsAfter }
-- @return median  median spacing between consecutive beat start times (the
--                 first spacing is measured from 0), or -1 when fewer than
--                 two spacings were collected
function getBeatInfo()
  local beats = {}
  local spacings = {}
  local lastBeatStart = 0

  for _, mediaEvent in ipairs(MediaEvents) do
    -- "config" entries carry settings, not audio events (see getConfig).
    if mediaEvent:getSourceMedia() ~= "config" then
      local eventCount = mediaEvent:size()
      local pendingOnsets = 0

      for eventIndex = 0, eventCount - 1 do
        local start, weights, _, tags = mediaEvent:getEvent(eventIndex)
        local isBeat = table.contains(tags, "beat") and table.contains(tags, "onset")

        if not isBeat then
          -- Onsets before the very first beat are not counted.
          if #beats > 0 then
            pendingOnsets = pendingOnsets + 1
          end
        else
          -- Close out the previous beat with the onsets seen since it.
          local previousBeat = beats[#beats]
          if previousBeat ~= nil then
            previousBeat.onsetsAfter = pendingOnsets
          end

          beats[#beats + 1] = {
            startTime = start,
            weight = weights,
            onset = table.contains(tags, "onset"),
            onsetsBefore = pendingOnsets,
            onsetsAfter = 0, -- filled in when the next beat is found
          }
          pendingOnsets = 0

          spacings[#spacings + 1] = start - lastBeatStart
          lastBeatStart = start
        end
      end
    end
  end

  local medianDuration = -1
  if #spacings > 1 then
    medianDuration = table.median(spacings)
  end
  return beats, medianDuration
end

-- Effect callback for "Flash": applies a transform, a brightness and a
-- motion-blur treatment over the same time window.
-- @param params  array of JSON strings: [1] transform, [2] brightness,
--                [3] motion blur
local function flashEffect(startTime, endTime, params)
  local transformJson, brightnessJson, motionBlurJson = params[1], params[2], params[3]
  Treatments:transformTreatment(startTime, endTime, transformJson)
  Treatments:brightnessTreatment(startTime, endTime, brightnessJson)
  Treatments:motionBlurTreatment(startTime, endTime, motionBlurJson)
end

-- Effect callback for the slide transitions: applies a mirrored-tile
-- treatment over the given time window.
-- @param params  array whose first entry is the treatment's JSON string
local function slideEffect(startTime, endTime, params)
  local mirrorJson = params[1]
  Treatments:mirroredTileTreatment(startTime, endTime, mirrorJson)
end

-- Effect callback that applies only a brightness treatment over the given
-- time window.
-- @param params  array whose first entry is the treatment's JSON string
local function brightnessEffect(startTime, endTime, params)
  local brightnessJson = params[1]
  Treatments:brightnessTreatment(startTime, endTime, brightnessJson)
end

-- Adds the flash shader (VERTEX_SHADER + FRAGMENT_SHADER) to the current clip
-- for its whole duration and animates its three uniforms with keyframes.
--
-- When startEffect is "Flash", brightness and velocity fade out over the
-- first effectDuration seconds while direction sweeps from 90 to 270 degrees
-- and is reset to 0 one millisecond later. When endEffect is "Flash",
-- brightness and velocity ramp up over the last effectDuration seconds while
-- direction sweeps from -90 to 90 degrees. If neither is "Flash" the shader
-- is still added, with empty keyframe arrays.
--
-- @param startEffect     effect table (see createEffect) for the clip start
-- @param endEffect       effect table for the clip end
-- @param effectDuration  length of each flash, in seconds
-- @param clipDuration    length of the clip, in seconds
function addFlashEffect(startEffect, endEffect, effectDuration, clipDuration)
  local curveType = "LINEAR"
  local brightnessKeys, velocityKeys, directionKeys = "", "", ""

  -- Serializes one keyframe and appends it to the given keyframe list.
  local function withKeyframe(keys, timestamp, value)
    return appendKeyFrame(keys, makeKeyframe(timestamp, value, curveType))
  end

  -- The shader's direction uniform is consumed by cos()/sin(), so degrees
  -- are converted to radians here.
  local function radians(degrees)
    return degrees * math.pi / 180.0
  end

  if startEffect.name == "Flash" then
    brightnessKeys = withKeyframe(brightnessKeys, 0.0, 0.6)
    brightnessKeys = withKeyframe(brightnessKeys, effectDuration, 0.0)

    velocityKeys = withKeyframe(velocityKeys, 0.0, 0.4)
    velocityKeys = withKeyframe(velocityKeys, effectDuration, 0.0)

    directionKeys = withKeyframe(directionKeys, 0.0, radians(90))
    directionKeys = withKeyframe(directionKeys, effectDuration, radians(270))
    directionKeys = withKeyframe(directionKeys, effectDuration + 0.001, 0.0)
  end

  if endEffect.name == "Flash" then
    local flashStart = clipDuration - effectDuration

    brightnessKeys = withKeyframe(brightnessKeys, flashStart, 0.0)
    brightnessKeys = withKeyframe(brightnessKeys, clipDuration, 0.6)

    velocityKeys = withKeyframe(velocityKeys, flashStart, 0.0)
    velocityKeys = withKeyframe(velocityKeys, clipDuration, 0.4)

    directionKeys = withKeyframe(directionKeys, flashStart, radians(-90))
    directionKeys = withKeyframe(directionKeys, clipDuration, radians(90))
  end

  -- float1/float2/float3 are the uniform names declared in FRAGMENT_SHADER.
  local brightnessCurve = makeJSONAnimationCurve("float1", brightnessKeys)
  local velocityCurve = makeJSONAnimationCurve("float2", velocityKeys)
  local directionCurve = makeJSONAnimationCurve("float3", directionKeys)
  local curvesJson = makeJSONCurves(brightnessCurve, velocityCurve, directionCurve)

  ShaderEffects:addShader(VERTEX_SHADER, FRAGMENT_SHADER, 0.0, clipDuration, curvesJson)
end


-- Creates an effect descriptor.
-- @param name      effect label ("Flash", "SlideUp", "SlideLeft", "NoEffect")
-- @param callback  function(startTime, endTime, params) applying the effect,
--                  or nil for no-op effects (see addEffect)
-- @param params    array of JSON strings forwarded to the callback
-- @return table { name, callback, params }
local function createEffect(name, callback, params)
  local effect = {}
  effect.name = name
  effect.callback = callback
  effect.params = params
  return effect
end

-- Chooses the effect played at the END of each of `numEffects` clips.
--
-- The last clip always gets "NoEffect". For the others:
--   * single-media logic: position in the repeat cycle decides --
--     1 -> "Flash", 2 -> "SlideUp", anything else -> "SlideLeft";
--   * otherwise: clips are grouped in runs of `repeatEffectCount`; every
--     other run (starting with the first) is all "Flash", and in the
--     remaining runs the first clip is "SlideUp" and the rest "SlideLeft".
--
-- @return array of effect tables (see createEffect), one per clip
local function populateEndEffects(numEffects, repeatEffectCount, isSingleMediaLogic)
  local endEffects = {}

  for i = 1, numEffects do
    local positionInCycle = i % repeatEffectCount
    local wantsFlash, wantsSlideUp

    if isSingleMediaLogic then
      wantsFlash = positionInCycle == 1
      wantsSlideUp = positionInCycle == 2
    else
      wantsFlash = math.floor((i - 1) / repeatEffectCount) % 2 == 0
      wantsSlideUp = positionInCycle == 1
    end

    local effect
    if i == numEffects then
      -- Nothing follows the last clip, so it needs no outgoing transition.
      effect = createEffect("NoEffect", nil, nil)
    elseif wantsFlash then
      effect = createEffect("Flash", flashEffect, {T_PARAMS, BRIGHTEN_START, MB_START})
    elseif wantsSlideUp then
      effect = createEffect("SlideUp", slideEffect, {MIRROR_PARAMS})
    else
      effect = createEffect("SlideLeft", slideEffect, {MIRROR_PARAMS_3})
    end
    endEffects[i] = effect
  end

  return endEffects
end

-- Chooses the effect played at the START of each clip so that it continues
-- the transition begun at the end of the previous clip.
--
-- The first clip always gets "NoEffect". For clip i >= 2, based on the end
-- effect of clip i-1:
--   * outside single-media logic, previous "Flash" followed by a clip whose
--     own end effect is "SlideUp" -> "NoEffect";
--   * previous "SlideUp"   -> "SlideUp"   (MIRROR_PARAMS_2);
--   * previous "SlideLeft" -> "SlideLeft" (MIRROR_PARAMS_4);
--   * anything else        -> "Flash"     (T_PARAMS_2, BRIGHTEN_END, MB_END).
--
-- Fix: the "NoEffect" special case used to sit in a separate `if` ahead of
-- the chain, so the chain's final `else` always overwrote it with "Flash"
-- and the special case never took effect. It is now the first branch of the
-- chain.
-- NOTE(review): this changes the rendered start of such clips from "Flash"
-- to no effect -- confirm against the reference design.
--
-- @param numEffects          number of clips
-- @param endEffects          array produced by populateEndEffects
-- @param isSingleMediaLogic  disables the "NoEffect" special case when true
-- @return array of effect tables (see createEffect), one per clip
local function populateStartEffects(numEffects, endEffects, isSingleMediaLogic)
  local startEffects = {createEffect("NoEffect", nil, nil)}

  for i = 2, numEffects do
    local prevEndEffect = endEffects[i - 1].name
    local currEndEffect = endEffects[i].name

    if (not isSingleMediaLogic) and prevEndEffect == "Flash" and currEndEffect == "SlideUp" then
      startEffects[i] = createEffect("NoEffect", nil, nil)
    elseif prevEndEffect == "SlideUp" then
      startEffects[i] = createEffect("SlideUp", slideEffect, {MIRROR_PARAMS_2})
    elseif prevEndEffect == "SlideLeft" then
      startEffects[i] = createEffect("SlideLeft", slideEffect, {MIRROR_PARAMS_4})
    else
      startEffects[i] = createEffect("Flash", flashEffect, {T_PARAMS_2, BRIGHTEN_END, MB_END})
    end
  end

  return startEffects
end

-- Builds the per-clip start and end effects for `numEffects` clips.
--
-- A transition between two clips is made of an effect at the end of the
-- current clip plus a matching effect at the start of the next one, so the
-- end effects are chosen first and the start effects are derived from them.
-- Single-media logic is used only when there is exactly one media item and
-- the config flag shouldEnableSingleMediaLogic is set.
--
-- @return startEffects, endEffects  arrays of effect tables, one per clip
function generateEffects(numEffects, repeatEffectCount)
  local singleMediaEnabled = getConfig().shouldEnableSingleMediaLogic
  local isSingleMediaLogic = (#MediaItems == 1 and singleMediaEnabled)

  local endEffects = populateEndEffects(numEffects, repeatEffectCount, isSingleMediaLogic)
  return populateStartEffects(numEffects, endEffects, isSingleMediaLogic), endEffects
end

-- Appends a clip for MediaItems[mediaIndex] to the project.
-- The clip begins at the media item's own start time (getStartSec), lasts
-- clipDuration seconds and plays at normal speed (1.0).
function addClip(clipDuration, mediaIndex)
  local item = MediaItems[mediaIndex]
  local clipStart = item:getStartSec()
  Project:addClip(clipStart, clipDuration, 1.0, item)
end

-- Applies an effect descriptor (see createEffect) over [startTime, endTime].
-- Descriptors without a callback (e.g. "NoEffect") are silently skipped.
function addEffect(effect, startTime, endTime)
  local apply = effect.callback
  if apply == nil then
    return
  end
  apply(startTime, endTime, effect.params)
end

-- Builds the multi-clip timeline: one clip per media item, each with an
-- optional zoom (images only) and start/end transition effects.
--
-- Clip length:
--   * the suggested duration is a quadratic in the number of media items:
--     https://www.wolframalpha.com/input?i=0.02+*+%28x+%5E+2%29+-+0.6+*+x+%2B+6%2C+x+from+1.0+to+15&fbclid=IwAR28NqOXDY9gC2qRCoc5yoEzfyEdBjtZzgz4-XIP2Gw7rqxDu3r7oBg5etM
--   * when a median beat duration is known and shorter than the suggestion,
--     the duration is rounded down to a whole number of beats;
--   * a video shorter than that is used at its own duration (beat ignored).
--
-- Effects: with config.shouldEnableTemplatesOptimization the "Flash" effects
-- are rendered by the single-pass shader (addFlashEffect) and only non-Flash
-- effects go through their treatment callbacks; otherwise every effect goes
-- through its callback.
function main()
  local numMediaItems = #MediaItems
  local suggestedClipDuration = 0.02 * (numMediaItems ^ 2) - 0.6 * numMediaItems + 6
  local config = getConfig()

  -- Find beatDuration (-1 when it could not be determined)
  local _, beatDuration = getBeatInfo()

  local startClipEffects, endClipEffects = generateEffects(numMediaItems, 3)

  -- Sync clips to the beat. This depends only on values fixed for the whole
  -- timeline, so it is computed once instead of once per clip.
  local beatSyncedDuration = suggestedClipDuration
  if (beatDuration > 0.0 and suggestedClipDuration > beatDuration) then
    local beatsCount = math.floor(suggestedClipDuration / beatDuration)
    beatSyncedDuration = beatsCount * beatDuration
  end

  local effectDuration = 0.3

  for i = 1, numMediaItems do
    local clipDuration = beatSyncedDuration

    local mediaItem = MediaItems[i]
    if (mediaItem:isVideo()) then
      local mediaDuration = mediaItem:getDurationSec()
      if (mediaDuration < clipDuration) then
        -- In this case ignore the beat
        clipDuration = mediaDuration
      end
    end

    -- Add clips in the timeline
    addClip(clipDuration, i)

    -- Images alternate between zooming in (odd index) and out (even index)
    if (mediaItem:isImage()) then
      if (i % 2 == 0) then
        Treatments:zoomTreatment2(0.0, clipDuration, ZOOM_OUT)
      else
        Treatments:zoomTreatment2(0.0, clipDuration, ZOOM_IN)
      end
    end

    -- Add transition effects per clip
    local startEffect = startClipEffects[i]
    local endEffect = endClipEffects[i]
    if (config.shouldEnableTemplatesOptimization) then
      addFlashEffect(startEffect, endEffect, effectDuration, clipDuration)

      if (startEffect.name ~= "Flash") then
        addEffect(startEffect, 0.0, effectDuration)
      end

      if (endEffect.name ~= "Flash") then
        addEffect(endEffect, clipDuration - effectDuration, clipDuration)
      end
    else
      addEffect(startEffect, 0.0, effectDuration)
      addEffect(endEffect, clipDuration - effectDuration, clipDuration)
    end
  end
end

-- Builds the timeline for the single-media case: one clip made from
-- MediaItems[1] with effects applied periodically inside it, cycling through
-- "Flash", "SlideUp" and "SlideLeft".
--
-- * The clip lasts 8 seconds, or the video's own duration when it is shorter.
-- * An effect starts every `changeInSec` seconds (2 seconds, rounded down to
--   a whole number of beats when a shorter median beat duration is known);
--   the first one starts at changeInSec, not at 0.
-- * "Flash" effects are rendered by the flash shader through keyframes;
--   slide effects go through slideEffect.
--
-- Fix: the slide descriptors were created with the undefined global `fx`
-- as their name (i.e. nil); they now carry the actual effect name.
function flashAndSlideSingleMedia()
  local suggestedClipDuration = 8.0
  local effectDuration = 0.6
  local suggestedChangeInSec = 2.0
  local changeInSec = suggestedChangeInSec

  -- Find beatDuration (-1 when it could not be determined)
  local _, beatDuration = getBeatInfo()

  -- Sync effect changes to the beat
  if (beatDuration > 0.0 and suggestedChangeInSec > beatDuration) then
    local beatsCount = math.floor(suggestedChangeInSec / beatDuration)
    changeInSec = beatsCount * beatDuration
  end

  local clipDuration = suggestedClipDuration
  local mediaItem = MediaItems[1]
  if (mediaItem:isVideo()) then
    local mediaDuration = mediaItem:getDurationSec()
    if (mediaDuration < clipDuration) then
      -- The video is shorter than the suggested clip: use all of it
      clipDuration = mediaDuration
    end
  end

  -- Add clip in the timeline
  addClip(clipDuration, 1)

  -- Add zoom effect to the clip
  if (mediaItem:isImage()) then
    Treatments:zoomTreatment2(0.0, clipDuration, ZOOM_IN)
  end

  local flashKeyFrames = {
    brightness = "",
    velocity = "",
    direction = ""
  }
  local curveType = "LINEAR"

  -- Appends one LINEAR keyframe to the named flash channel.
  local function addFlashKey(channel, timestamp, value)
    flashKeyFrames[channel] = appendKeyFrame(flashKeyFrames[channel], makeKeyframe(timestamp, value, curveType))
  end

  -- NOTE(review): the effect count is derived from suggestedClipDuration, not
  -- from the actual clipDuration, and start times run up to
  -- totalEffectsApplied * changeInSec. For short videos (and for the last
  -- slot) effects can therefore be scheduled at or past the end of the clip
  -- -- confirm this is intended / harmless for the treatments.
  local totalEffectsApplied = math.floor(suggestedClipDuration / changeInSec)
  local effectsOrder = {"Flash", "SlideUp", "SlideLeft"}
  for i = 0, totalEffectsApplied - 1 do
    local effect = effectsOrder[i % #effectsOrder + 1]
    local startTime = (i + 1) * changeInSec

    if (effect == "Flash") then
      -- Brightness and blur velocity peak halfway through the effect while
      -- the blur direction makes one full turn (0 -> 2*pi).
      addFlashKey("brightness", startTime, 0.0)
      addFlashKey("brightness", startTime + 0.5 * effectDuration, 0.6)
      addFlashKey("brightness", startTime + effectDuration, 0.0)

      addFlashKey("velocity", startTime, 0.0)
      addFlashKey("velocity", startTime + 0.5 * effectDuration, 0.8)
      addFlashKey("velocity", startTime + effectDuration, 0.0)

      addFlashKey("direction", startTime, 0.0)
      addFlashKey("direction", startTime + effectDuration, 2.0 * math.pi)
    elseif (effect == "SlideUp") then
      local f = createEffect(effect, slideEffect, {[[{ "distance_start": 0.0, "distance_end": -2.0, "direction": "vertical" }]]})
      addEffect(f, startTime, startTime + effectDuration)
    elseif (effect == "SlideLeft") then
      local f = createEffect(effect, slideEffect, {[[{ "distance_start": 0.0, "distance_end": 2.0, "direction": "horizontal" } ]]})
      addEffect(f, startTime, startTime + effectDuration)
    end
  end

  -- float1/float2/float3 are the uniform names declared in FRAGMENT_SHADER.
  local flashCurves = {
    brightness = makeJSONAnimationCurve("float1", flashKeyFrames.brightness),
    velocity = makeJSONAnimationCurve("float2", flashKeyFrames.velocity),
    direction = makeJSONAnimationCurve("float3", flashKeyFrames.direction)
  }

  ShaderEffects:addShader(VERTEX_SHADER, FRAGMENT_SHADER, 0.0, clipDuration, makeJSONCurves(flashCurves.brightness, flashCurves.velocity, flashCurves.direction))
end

-- Script entry point: a single media item with shouldEnableSingleMediaLogic
-- set gets the dedicated single-media timeline; everything else goes through
-- the multi-clip main().
local config = getConfig()
if not (#MediaItems == 1 and config.shouldEnableSingleMediaLogic) then
  main()
else
  flashAndSlideSingleMedia()
end
