local fun_is_callable_module = "Module:fun/isCallable"
local math_module = "Module:math"
local scribunto_types_module = "Module:Scribunto/types"
local table_max_index_module = "Module:table/maxIndex"
local dump = mw.dumpObject
local error = error
local format = string.format
local ipairs = ipairs
local require = require
local select = select
local type = type
-- Lazy loader for the value exported by `fun_is_callable_module`.
-- The first call requires the module and rebinds this local to the loaded
-- value, so every later call goes straight to it without passing through
-- this wrapper again.
local function is_callable(...)
	local loaded = require(fun_is_callable_module)
	is_callable = loaded
	return is_callable(...)
end
-- Lazy loader for the `is_positive_integer` field of the module named by
-- `math_module`. The first call requires the module and rebinds this local to
-- that field, so later calls skip the wrapper.
local function is_positive_integer(...)
	local loaded = require(math_module).is_positive_integer
	is_positive_integer = loaded
	return is_positive_integer(...)
end
-- Lazy loader for the value exported by `table_max_index_module`.
-- The first call requires the module and rebinds this local to the loaded
-- value, so later calls skip the wrapper.
local function table_max_index(...)
	local loaded = require(table_max_index_module)
	table_max_index = loaded
	return table_max_index(...)
end
-- Cache for the value exported by `scribunto_types_module`; nil until
-- `get_scribunto_types` has run.
local scribunto_types

-- One-shot loader: requires `scribunto_types_module`, stores the result in
-- `scribunto_types`, clears itself (callers test `scribunto_types` first, so
-- the loader is not needed again) and returns the loaded value.
local function get_scribunto_types()
	local loaded = require(scribunto_types_module)
	scribunto_types = loaded
	get_scribunto_types = nil
	return scribunto_types
end
-- Validates one type name given in the `insert_if` spec of parameter `param`.
--
-- Parameters:
--   _type: the candidate value taken from the spec.
--   param: the parameter number, used only in the error message.
-- Returns true when `_type` maps to a truthy value in the table loaded from
-- `scribunto_types_module` (loaded lazily on first use).
-- Raises an error otherwise; the wording depends on whether `_type` was a
-- string (unknown name) or some other kind of value.
-- NOTE(review): assumes `scribunto_types_module` returns a table keyed by
-- type name - confirm against that module.
local function check_type(_type, param)
	-- Index by the candidate name: the table itself is always truthy, so it
	-- cannot be used as the test.
	local known_types = scribunto_types or get_scribunto_types()
	if known_types[_type] then
		return true
	end
	local type_tp = type(_type)
	-- Exactly two format arguments follow (`param` and the description), so
	-- the message has exactly one `%d` and one `%s`.
	error(format(
		"bad spec in 'fun/optionalParameters' in 'insert_if' for parameter #%d: name of Lua data type %s",
		param,
		type_tp == "string" and format("expected, got %s", dump(_type)) or
			format("as a string expected, got %s object", type_tp)
	))
end
-- Builds the wrapper used when parameter 1 is the only parameter with a spec,
-- so the "first" and "last" handlers are combined into one closure.
--
-- Parameters:
--   func:           the function being wrapped.
--   insert_if:      payload for the test: a type name ("string" test), a table
--                   keyed by type name ("set" test) or a predicate called as
--                   insert_if(arg, 1, args_n) ("call" test); unused otherwise.
--   insert_if_type: nil, "string", "callable", "set" or "call".
--   buffer:         0 for no positional limit; otherwise the first argument is
--                   only accepted when 1 <= args_n - buffer.
--   default:        value placed in front of the arguments when the first
--                   argument is not accepted (nil is used if there is none).
-- Returns a function that calls `func(...)` unchanged when the first argument
-- is accepted, and `func(default, ...)` otherwise. When there is no default, a
-- call with zero arguments is always forwarded unchanged.
-- Raises an internal error if `insert_if_type` is not one of the known values.
local function get_first_and_last_parameter(func, insert_if, insert_if_type, buffer, default)
	if buffer == 0 then
		if default == nil then
			if insert_if_type == nil then -- should be unreachable
				-- buffer: no, default: no, test: none
				return func
			elseif insert_if_type == "string" then
				-- buffer: no, default: no, test: string
				return function(...)
					if select("#", ...) == 0 or type((...)) == insert_if then
						return func(...)
					end
					return func(nil, ...)
				end
			elseif insert_if_type == "callable" then
				-- buffer: no, default: no, test: callable
				return function(...)
					if select("#", ...) == 0 or is_callable((...)) then
						return func(...)
					end
					return func(nil, ...)
				end
			elseif insert_if_type == "set" then
				-- buffer: no, default: no, test: set
				return function(...)
					-- Look the argument's type up in the set; the set table
					-- itself is always truthy.
					if select("#", ...) == 0 or insert_if[type((...))] then
						return func(...)
					end
					return func(nil, ...)
				end
			elseif insert_if_type == "call" then
				-- buffer: no, default: no, test: call
				return function(...)
					local args_n = select("#", ...)
					-- `...` is not last in the list, so only the first
					-- argument is passed to the predicate.
					if args_n == 0 or insert_if(..., 1, args_n) then
						return func(...)
					end
					return func(nil, ...)
				end
			end
		elseif insert_if_type == nil then
			-- buffer: no, default: yes, test: none
			return function(_arg, ...)
				if _arg == nil then
					return func(default, ...)
				end
				return func(_arg, ...)
			end
		elseif insert_if_type == "string" then
			-- buffer: no, default: yes, test: string
			return function(...)
				if type((...)) == insert_if then
					return func(...)
				end
				return func(default, ...)
			end
		elseif insert_if_type == "callable" then
			-- buffer: no, default: yes, test: callable
			return function(...)
				if is_callable((...)) then
					return func(...)
				end
				return func(default, ...)
			end
		elseif insert_if_type == "set" then
			-- buffer: no, default: yes, test: set
			return function(...)
				if insert_if[type((...))] then
					return func(...)
				end
				return func(default, ...)
			end
		elseif insert_if_type == "call" then
			-- buffer: no, default: yes, test: call
			return function(...)
				if insert_if(..., 1, select("#", ...)) then
					return func(...)
				end
				return func(default, ...)
			end
		end
	elseif default == nil then
		if insert_if_type == nil then
			-- buffer: yes, default: no, test: none
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 or 1 <= args_n - buffer then
					return func(...)
				end
				return func(nil, ...)
			end
		elseif insert_if_type == "string" then
			-- buffer: yes, default: no, test: string
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 or 1 <= args_n - buffer and type((...)) == insert_if then
					return func(...)
				end
				return func(nil, ...)
			end
		elseif insert_if_type == "callable" then
			-- buffer: yes, default: no, test: callable
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 or 1 <= args_n - buffer and is_callable((...)) then
					return func(...)
				end
				return func(nil, ...)
			end
		elseif insert_if_type == "set" then
			-- buffer: yes, default: no, test: set
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 or 1 <= args_n - buffer and insert_if[type((...))] then
					return func(...)
				end
				return func(nil, ...)
			end
		elseif insert_if_type == "call" then
			-- buffer: yes, default: no, test: call
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 or 1 <= args_n - buffer and insert_if(..., 1, args_n) then
					return func(...)
				end
				return func(nil, ...)
			end
		end
	elseif insert_if_type == nil then
		-- buffer: yes, default: yes, test: none
		return function(...)
			if 1 <= select("#", ...) - buffer then
				return func(...)
			end
			return func(default, ...)
		end
	elseif insert_if_type == "string" then
		-- buffer: yes, default: yes, test: string
		return function(...)
			if 1 <= select("#", ...) - buffer and type((...)) == insert_if then
				return func(...)
			end
			return func(default, ...)
		end
	elseif insert_if_type == "callable" then
		-- buffer: yes, default: yes, test: callable
		return function(...)
			if 1 <= select("#", ...) - buffer and is_callable((...)) then
				return func(...)
			end
			return func(default, ...)
		end
	elseif insert_if_type == "set" then
		-- buffer: yes, default: yes, test: set
		return function(...)
			if 1 <= select("#", ...) - buffer and insert_if[type((...))] then
				return func(...)
			end
			return func(default, ...)
		end
	elseif insert_if_type == "call" then
		-- buffer: yes, default: yes, test: call
		return function(...)
			local args_n = select("#", ...)
			if 1 <= args_n - buffer and insert_if(..., 1, args_n) then
				return func(...)
			end
			return func(default, ...)
		end
	end
	error("Internal error: invalid value for `insert_if_type`")
end
-- Builds the outermost wrapper for parameter 1 when later parameters are
-- handled by `chain` and the chain must always be run (no early return).
--
-- Parameters:
--   func:           the function being wrapped.
--   chain:          handler for the following parameters, called as
--                   chain(args_n, i, ...) where `i` is the position of the
--                   first argument being passed on.
--   insert_if:      payload for the test: a type name ("string" test), a table
--                   keyed by type name ("set" test) or a predicate called as
--                   insert_if(arg, 1, args_n) ("call" test); unused otherwise.
--   insert_if_type: nil, "string", "callable", "set" or "call".
--   buffer:         0 for no positional limit; otherwise the first argument is
--                   only accepted when 1 <= args_n - buffer.
--   default:        value used for parameter 1 when the first argument is not
--                   accepted (may be nil).
-- Returns a function that passes either the first argument or `default` to
-- `func`, followed by whatever `chain` returns. When the argument is accepted
-- the chain continues from position 2; otherwise the argument is handed back
-- to the chain at position 1.
-- The argument count given to `chain` is computed as select("#", ...) + 1,
-- i.e. the first argument is always counted.
-- Raises an internal error if `insert_if_type` is not one of the known values.
local function get_first_parameter(func, chain, insert_if, insert_if_type, buffer, default)
	if buffer == 0 then
		if insert_if_type == nil then
			if default == nil then
				-- buffer: no, test: none, default: no
				return function(_arg, ...)
					return func(_arg, chain(select("#", ...) + 1, 2, ...))
				end
			end
			-- buffer: no, test: none, default: yes
			return function(_arg, ...)
				if _arg == nil then
					return func(default, chain(select("#", ...) + 1, 2, ...))
				end
				return func(_arg, chain(select("#", ...) + 1, 2, ...))
			end
		elseif insert_if_type == "string" then
			-- buffer: no, test: string
			return function(_arg, ...)
				local args_n = select("#", ...) + 1
				if type(_arg) == insert_if then
					return func(_arg, chain(args_n, 2, ...))
				end
				return func(default, chain(args_n, 1, _arg, ...))
			end
		elseif insert_if_type == "callable" then
			-- buffer: no, test: callable
			return function(_arg, ...)
				local args_n = select("#", ...) + 1
				if is_callable(_arg) then
					return func(_arg, chain(args_n, 2, ...))
				end
				return func(default, chain(args_n, 1, _arg, ...))
			end
		elseif insert_if_type == "set" then
			-- buffer: no, test: set
			return function(_arg, ...)
				local args_n = select("#", ...) + 1
				-- Look the argument's type up in the set; the set table
				-- itself is always truthy.
				if insert_if[type(_arg)] then
					return func(_arg, chain(args_n, 2, ...))
				end
				return func(default, chain(args_n, 1, _arg, ...))
			end
		elseif insert_if_type == "call" then
			-- buffer: no, test: call
			return function(_arg, ...)
				local args_n = select("#", ...) + 1
				if insert_if(_arg, 1, args_n) then
					return func(_arg, chain(args_n, 2, ...))
				end
				return func(default, chain(args_n, 1, _arg, ...))
			end
		end
	elseif insert_if_type == nil then
		-- buffer: yes, test: none
		return function(_arg, ...)
			local args_n = select("#", ...) + 1
			if 1 <= args_n - buffer then
				return func(_arg, chain(args_n, 2, ...))
			end
			return func(default, chain(args_n, 1, _arg, ...))
		end
	elseif insert_if_type == "string" then
		-- buffer: yes, test: string
		return function(_arg, ...)
			local args_n = select("#", ...) + 1
			if 1 <= args_n - buffer and type(_arg) == insert_if then
				return func(_arg, chain(args_n, 2, ...))
			end
			return func(default, chain(args_n, 1, _arg, ...))
		end
	elseif insert_if_type == "callable" then
		-- buffer: yes, test: callable
		return function(_arg, ...)
			local args_n = select("#", ...) + 1
			if 1 <= args_n - buffer and is_callable(_arg) then
				return func(_arg, chain(args_n, 2, ...))
			end
			return func(default, chain(args_n, 1, _arg, ...))
		end
	elseif insert_if_type == "set" then
		-- buffer: yes, test: set
		return function(_arg, ...)
			local args_n = select("#", ...) + 1
			if 1 <= args_n - buffer and insert_if[type(_arg)] then
				return func(_arg, chain(args_n, 2, ...))
			end
			return func(default, chain(args_n, 1, _arg, ...))
		end
	elseif insert_if_type == "call" then
		-- buffer: yes, test: call
		return function(_arg, ...)
			local args_n = select("#", ...) + 1
			if 1 <= args_n - buffer and insert_if(_arg, 1, args_n) then
				return func(_arg, chain(args_n, 2, ...))
			end
			return func(default, chain(args_n, 1, _arg, ...))
		end
	end
	error("Internal error: invalid value for `insert_if_type`")
end
-- Builds the outermost wrapper for parameter 1 when later parameters are
-- handled by `chain` but the call may end early, i.e. when the arguments run
-- out at parameter 1 the chain is not entered.
--
-- Parameters:
--   func:           the function being wrapped.
--   chain:          handler for the following parameters, called as
--                   chain(args_n, i, ...) where `i` is the position of the
--                   first argument being passed on.
--   insert_if:      payload for the test: a type name ("string" test), a table
--                   keyed by type name ("set" test) or a predicate called as
--                   insert_if(arg, 1, args_n) ("call" test); unused otherwise.
--   insert_if_type: nil, "string", "callable", "set" or "call".
--   buffer:         0 for no positional limit; otherwise the first argument is
--                   only accepted when 1 <= args_n - buffer.
--   default:        value used for parameter 1 when the first argument is not
--                   accepted (may be nil).
-- Returns a function which:
--   * calls `func()` when given no arguments;
--   * inserts `default` and hands all arguments to `chain` at position 1 when
--     the first argument is not accepted;
--   * calls `func(...)` directly when the accepted argument is the only one;
--   * otherwise passes the first argument followed by the chain's results for
--     the remaining arguments (position 2 onwards).
-- Raises an internal error if `insert_if_type` is not one of the known values.
local function get_first_and_possible_last_parameter(func, chain, insert_if, insert_if_type, buffer, default)
	if buffer == 0 then
		if insert_if_type == nil then
			if default == nil then
				-- buffer: no, test: none, default: no
				return function(...)
					local args_n = select("#", ...)
					if args_n <= 1 then
						return func(...)
					end
					-- `...` is not last in the list, so it contributes only
					-- the first argument here.
					return func(..., chain(args_n, 2, select(2, ...)))
				end
			end
			-- buffer: no, test: none, default: yes
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 then
					return func()
				elseif args_n == 1 then
					if ... == nil then
						return func(default)
					end
					return func(...)
				elseif ... == nil then
					return func(default, chain(args_n, 2, select(2, ...)))
				end
				return func(..., chain(args_n, 2, select(2, ...)))
			end
		elseif insert_if_type == "string" then
			-- buffer: no, test: string
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 then
					return func()
				elseif type((...)) ~= insert_if then
					return func(default, chain(args_n, 1, ...))
				elseif args_n == 1 then
					return func(...)
				end
				return func(..., chain(args_n, 2, select(2, ...)))
			end
		elseif insert_if_type == "callable" then
			-- buffer: no, test: callable
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 then
					return func()
				elseif not is_callable((...)) then
					return func(default, chain(args_n, 1, ...))
				elseif args_n == 1 then
					return func(...)
				end
				return func(..., chain(args_n, 2, select(2, ...)))
			end
		elseif insert_if_type == "set" then
			-- buffer: no, test: set
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 then
					return func()
				-- Look the argument's type up in the set; the set table
				-- itself is always truthy.
				elseif not insert_if[type((...))] then
					return func(default, chain(args_n, 1, ...))
				elseif args_n == 1 then
					return func(...)
				end
				return func(..., chain(args_n, 2, select(2, ...)))
			end
		elseif insert_if_type == "call" then
			-- buffer: no, test: call
			return function(...)
				local args_n = select("#", ...)
				if args_n == 0 then
					return func()
				elseif not insert_if(..., 1, args_n) then
					return func(default, chain(args_n, 1, ...))
				elseif args_n == 1 then
					return func(...)
				end
				return func(..., chain(args_n, 2, select(2, ...)))
			end
		end
	elseif insert_if_type == nil then
		-- buffer: yes, test: none
		return function(...)
			local args_n = select("#", ...)
			if args_n == 0 then
				return func()
			elseif 1 > args_n - buffer then
				return func(default, chain(args_n, 1, ...))
			elseif args_n == 1 then
				return func(...)
			end
			return func(..., chain(args_n, 2, select(2, ...)))
		end
	elseif insert_if_type == "string" then
		-- buffer: yes, test: string
		return function(...)
			local args_n = select("#", ...)
			if args_n == 0 then
				return func()
			elseif not (1 <= args_n - buffer and type((...)) == insert_if) then
				return func(default, chain(args_n, 1, ...))
			elseif args_n == 1 then
				return func(...)
			end
			return func(..., chain(args_n, 2, select(2, ...)))
		end
	elseif insert_if_type == "callable" then
		-- buffer: yes, test: callable
		return function(...)
			local args_n = select("#", ...)
			if args_n == 0 then
				return func()
			elseif not (1 <= args_n - buffer and is_callable((...))) then
				return func(default, chain(args_n, 1, ...))
			elseif args_n == 1 then
				return func(...)
			end
			return func(..., chain(args_n, 2, select(2, ...)))
		end
	elseif insert_if_type == "set" then
		-- buffer: yes, test: set
		return function(...)
			local args_n = select("#", ...)
			if args_n == 0 then
				return func()
			elseif not (1 <= args_n - buffer and insert_if[type((...))]) then
				return func(default, chain(args_n, 1, ...))
			elseif args_n == 1 then
				return func(...)
			end
			return func(..., chain(args_n, 2, select(2, ...)))
		end
	elseif insert_if_type == "call" then
		-- buffer: yes, test: call
		return function(...)
			local args_n = select("#", ...)
			if args_n == 0 then
				return func()
			elseif not (1 <= args_n - buffer and insert_if(..., 1, args_n)) then
				return func(default, chain(args_n, 1, ...))
			elseif args_n == 1 then
				return func(...)
			end
			return func(..., chain(args_n, 2, select(2, ...)))
		end
	end
	error("Internal error: invalid value for `insert_if_type`")
end
-- Builds a chain handler for a parameter that is neither first nor last, in
-- the variant that always continues into the next handler (no early return).
--
-- Parameters:
--   chain:          handler for the following parameters, called as
--                   chain(args_n, i, ...).
--   insert_if:      payload for the test: a type name ("string" test), a table
--                   keyed by type name ("set" test) or a predicate called as
--                   insert_if(arg, i, args_n) ("call" test); unused otherwise.
--   insert_if_type: nil, "string", "callable", "set" or "call".
--   buffer:         0 for no positional limit; otherwise the argument at
--                   position `i` is only accepted when i <= args_n - buffer.
--   default:        value produced for this parameter when the argument is
--                   not accepted (may be nil).
-- Returns a handler `function(args_n, i, _arg, ...)` where `args_n` is the
-- argument count supplied by the caller, `i` the position of `_arg`, and `...`
-- the arguments after it. It returns the value for this parameter followed by
-- the results of `chain`: an accepted argument is consumed (chain continues at
-- i + 1), a rejected one is passed on to the chain at the same position.
-- Raises an internal error if `insert_if_type` is not one of the known values.
local function get_middle_parameter(chain, insert_if, insert_if_type, buffer, default)
	if buffer == 0 then
		if insert_if_type == nil then
			if default == nil then
				-- buffer: no, test: none, default: no
				return function(args_n, i, _arg, ...)
					return _arg, chain(args_n, i + 1, ...)
				end
			end
			-- buffer: no, test: none, default: yes
			return function(args_n, i, _arg, ...)
				if _arg == nil then
					return default, chain(args_n, i + 1, ...)
				end
				return _arg, chain(args_n, i + 1, ...)
			end
		elseif insert_if_type == "string" then
			-- buffer: no, test: string
			return function(args_n, i, _arg, ...)
				if type(_arg) == insert_if then
					return _arg, chain(args_n, i + 1, ...)
				end
				return default, chain(args_n, i, _arg, ...)
			end
		elseif insert_if_type == "callable" then
			-- buffer: no, test: callable
			return function(args_n, i, _arg, ...)
				if is_callable(_arg) then
					return _arg, chain(args_n, i + 1, ...)
				end
				return default, chain(args_n, i, _arg, ...)
			end
		elseif insert_if_type == "set" then
			-- buffer: no, test: set
			return function(args_n, i, _arg, ...)
				-- Look the argument's type up in the set; the set table
				-- itself is always truthy.
				if insert_if[type(_arg)] then
					return _arg, chain(args_n, i + 1, ...)
				end
				return default, chain(args_n, i, _arg, ...)
			end
		elseif insert_if_type == "call" then
			-- buffer: no, test: call
			return function(args_n, i, _arg, ...)
				if insert_if(_arg, i, args_n) then
					return _arg, chain(args_n, i + 1, ...)
				end
				return default, chain(args_n, i, _arg, ...)
			end
		end
	elseif insert_if_type == nil then
		-- buffer: yes, test: none
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer then
				return _arg, chain(args_n, i + 1, ...)
			end
			return default, chain(args_n, i, _arg, ...)
		end
	elseif insert_if_type == "string" then
		-- buffer: yes, test: string
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and type(_arg) == insert_if then
				return _arg, chain(args_n, i + 1, ...)
			end
			return default, chain(args_n, i, _arg, ...)
		end
	elseif insert_if_type == "callable" then
		-- buffer: yes, test: callable
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and is_callable(_arg) then
				return _arg, chain(args_n, i + 1, ...)
			end
			return default, chain(args_n, i, _arg, ...)
		end
	elseif insert_if_type == "set" then
		-- buffer: yes, test: set
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and insert_if[type(_arg)] then
				return _arg, chain(args_n, i + 1, ...)
			end
			return default, chain(args_n, i, _arg, ...)
		end
	elseif insert_if_type == "call" then
		-- buffer: yes, test: call
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and insert_if(_arg, i, args_n) then
				return _arg, chain(args_n, i + 1, ...)
			end
			return default, chain(args_n, i, _arg, ...)
		end
	end
	error("Internal error: invalid value for `insert_if_type`")
end
-- Builds a chain handler for a parameter that is neither first nor last, in
-- the variant that may end the chain early: once an accepted argument is the
-- final one supplied (i >= args_n), the next handler is not called.
--
-- Parameters:
--   chain:          handler for the following parameters, called as
--                   chain(args_n, i, ...).
--   insert_if:      payload for the test: a type name ("string" test), a table
--                   keyed by type name ("set" test) or a predicate called as
--                   insert_if(arg, i, args_n) ("call" test); unused otherwise.
--   insert_if_type: nil, "string", "callable", "set" or "call".
--   buffer:         0 for no positional limit; otherwise the argument at
--                   position `i` is only accepted when i <= args_n - buffer.
--   default:        value produced for this parameter when the argument is
--                   not accepted (may be nil).
-- Returns a handler `function(args_n, i, _arg, ...)` where `args_n` is the
-- argument count supplied by the caller, `i` the position of `_arg`, and `...`
-- the arguments after it. A rejected argument yields `default` followed by the
-- chain's results for the same position; an accepted argument is returned on
-- its own when it is the final one, or followed by the chain's results for
-- position i + 1 otherwise.
-- Raises an internal error if `insert_if_type` is not one of the known values.
local function get_middle_and_possible_last_parameter(chain, insert_if, insert_if_type, buffer, default)
	if buffer == 0 then
		if insert_if_type == nil then
			if default == nil then
				-- buffer: no, test: none, default: no
				return function(args_n, i, _arg, ...)
					if i >= args_n then
						return _arg
					end
					return _arg, chain(args_n, i + 1, ...)
				end
			end
			-- buffer: no, test: none, default: yes
			return function(args_n, i, _arg, ...)
				if i >= args_n then
					if _arg == nil then
						return default
					end
					return _arg
				elseif _arg == nil then
					return default, chain(args_n, i + 1, ...)
				end
				return _arg, chain(args_n, i + 1, ...)
			end
		elseif insert_if_type == "string" then
			-- buffer: no, test: string
			return function(args_n, i, _arg, ...)
				if type(_arg) ~= insert_if then
					return default, chain(args_n, i, _arg, ...)
				elseif i >= args_n then
					return _arg
				end
				return _arg, chain(args_n, i + 1, ...)
			end
		elseif insert_if_type == "callable" then
			-- buffer: no, test: callable
			return function(args_n, i, _arg, ...)
				if not is_callable(_arg) then
					return default, chain(args_n, i, _arg, ...)
				elseif i >= args_n then
					return _arg
				end
				return _arg, chain(args_n, i + 1, ...)
			end
		elseif insert_if_type == "set" then
			-- buffer: no, test: set
			return function(args_n, i, _arg, ...)
				-- Look the argument's type up in the set; the set table
				-- itself is always truthy.
				if not insert_if[type(_arg)] then
					return default, chain(args_n, i, _arg, ...)
				elseif i >= args_n then
					return _arg
				end
				return _arg, chain(args_n, i + 1, ...)
			end
		elseif insert_if_type == "call" then
			-- buffer: no, test: call
			return function(args_n, i, _arg, ...)
				if not insert_if(_arg, i, args_n) then
					return default, chain(args_n, i, _arg, ...)
				elseif i >= args_n then
					return _arg
				end
				return _arg, chain(args_n, i + 1, ...)
			end
		end
	elseif insert_if_type == nil then
		-- buffer: yes, test: none
		return function(args_n, i, _arg, ...)
			if i > args_n - buffer then
				return default, chain(args_n, i, _arg, ...)
			elseif i >= args_n then
				return _arg
			end
			return _arg, chain(args_n, i + 1, ...)
		end
	elseif insert_if_type == "string" then
		-- buffer: yes, test: string
		return function(args_n, i, _arg, ...)
			if not (i <= args_n - buffer and type(_arg) == insert_if) then
				return default, chain(args_n, i, _arg, ...)
			elseif i >= args_n then
				return _arg
			end
			return _arg, chain(args_n, i + 1, ...)
		end
	elseif insert_if_type == "callable" then
		-- buffer: yes, test: callable
		return function(args_n, i, _arg, ...)
			if not (i <= args_n - buffer and is_callable(_arg)) then
				return default, chain(args_n, i, _arg, ...)
			elseif i >= args_n then
				return _arg
			end
			return _arg, chain(args_n, i + 1, ...)
		end
	elseif insert_if_type == "set" then
		-- buffer: yes, test: set
		return function(args_n, i, _arg, ...)
			if not (i <= args_n - buffer and insert_if[type(_arg)]) then
				return default, chain(args_n, i, _arg, ...)
			elseif i >= args_n then
				return _arg
			end
			return _arg, chain(args_n, i + 1, ...)
		end
	elseif insert_if_type == "call" then
		-- buffer: yes, test: call
		return function(args_n, i, _arg, ...)
			if not (i <= args_n - buffer and insert_if(_arg, i, args_n)) then
				return default, chain(args_n, i, _arg, ...)
			elseif i >= args_n then
				return _arg
			end
			return _arg, chain(args_n, i + 1, ...)
		end
	end
	error("Internal error: invalid value for `insert_if_type`")
end
-- Builds the final handler of the chain, for the highest-numbered parameter
-- that has a meaningful spec. It has no successor, so any remaining arguments
-- are returned as they are.
--
-- Parameters:
--   insert_if:      payload for the test: a type name ("string" test), a table
--                   keyed by type name ("set" test) or a predicate called as
--                   insert_if(arg, i, args_n) ("call" test); unused otherwise.
--   insert_if_type: nil, "string", "callable", "set" or "call".
--   buffer:         0 for no positional limit; otherwise the argument at
--                   position `i` is only accepted when i <= args_n - buffer.
--   default:        value produced for this parameter when the argument is
--                   not accepted (may be nil).
-- Returns a handler `function(args_n, i, _arg, ...)` where `args_n` is the
-- argument count supplied by the caller, `i` the position of `_arg`, and `...`
-- the arguments after it. An accepted argument is returned followed by the
-- rest; a rejected one has `default` inserted in front of it.
-- Raises an internal error if `insert_if_type` is not one of the known values.
local function get_last_parameter(insert_if, insert_if_type, buffer, default)
	if buffer == 0 then
		if insert_if_type == nil then
			if default == nil then -- should be unreachable
				-- buffer: no, test: none, default: no
				return function(_, _, _arg, ...)
					return _arg, ...
				end
			end
			-- buffer: no, test: none, default: yes
			return function(_, _, _arg, ...)
				if _arg == nil then
					return default, ...
				end
				return _arg, ...
			end
		elseif insert_if_type == "string" then
			-- buffer: no, test: string
			return function(_, _, _arg, ...)
				if type(_arg) == insert_if then
					return _arg, ...
				end
				return default, _arg, ...
			end
		elseif insert_if_type == "callable" then
			-- buffer: no, test: callable
			return function(_, _, _arg, ...)
				if is_callable(_arg) then
					return _arg, ...
				end
				return default, _arg, ...
			end
		elseif insert_if_type == "set" then
			-- buffer: no, test: set
			return function(_, _, _arg, ...)
				-- Look the argument's type up in the set; the set table
				-- itself is always truthy.
				if insert_if[type(_arg)] then
					return _arg, ...
				end
				return default, _arg, ...
			end
		elseif insert_if_type == "call" then
			-- buffer: no, test: call
			return function(args_n, i, _arg, ...)
				if insert_if(_arg, i, args_n) then
					return _arg, ...
				end
				return default, _arg, ...
			end
		end
	elseif insert_if_type == nil then
		-- buffer: yes, test: none
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer then
				return _arg, ...
			end
			return default, _arg, ...
		end
	elseif insert_if_type == "string" then
		-- buffer: yes, test: string
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and type(_arg) == insert_if then
				return _arg, ...
			end
			return default, _arg, ...
		end
	elseif insert_if_type == "callable" then
		-- buffer: yes, test: callable
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and is_callable(_arg) then
				return _arg, ...
			end
			return default, _arg, ...
		end
	elseif insert_if_type == "set" then
		-- buffer: yes, test: set
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and insert_if[type(_arg)] then
				return _arg, ...
			end
			return default, _arg, ...
		end
	elseif insert_if_type == "call" then
		-- buffer: yes, test: call
		return function(args_n, i, _arg, ...)
			if i <= args_n - buffer and insert_if(_arg, i, args_n) then
				return _arg, ...
			end
			return default, _arg, ...
		end
	end
	error("Internal error: invalid value for `insert_if_type`")
end
-- Module export. Wraps `func` so that arguments are shifted or defaulted
-- according to `specs`, a table indexed by parameter number.
--
-- Each spec is a table that may contain:
--   max:       a negative integer; it is converted to `buffer = -(max + 1)`,
--              and an argument at position i is then only accepted when
--              i <= args_n - buffer.
--   default:   value used for the parameter when no argument is accepted.
--   insert_if: the acceptance test: the string "callable", a type name, a
--              callable, or a list of type names.
--
-- Returns `func` itself when no parameter has a spec that differs from
-- "no spec"; otherwise returns a wrapper built from one specialised closure
-- per parameter.
-- Raises an error for a malformed spec.
-- NOTE(review): `table_max_index(specs)` is presumed to return the highest
-- index used in `specs` (0 when there is none) - confirm against that module.
return function(func, specs)
	-- Start with the getters whose handlers may end the chain early; get_spec
	-- replaces them with the always-continuing variants once a default is seen.
	local get_first = get_first_and_possible_last_parameter
	local get_middle = get_middle_and_possible_last_parameter

	-- Reads and validates the spec for parameter number `param`.
	-- Returns `insert_if, insert_if_type, buffer, default` in the form the
	-- getter functions expect.
	local function get_spec(param)
		local spec = specs[param]
		if spec == nil then
			return nil, nil, 0, nil
		end
		-- `buffer` is subtracted from `args_n` to get the absolute max index.
		local max_idx, buffer = spec.max
		if max_idx == nil then
			buffer = 0
		else
			local type_max = type(max_idx)
			if not (type_max == "number" and is_positive_integer(-max_idx)) then
				error(format(
					"bad spec in 'fun/optionalParameters' in 'max' for parameter #%d: expected negative index, got %s",
					param,
					type_max == "number" and tostring(max_idx) or type_max
				))
			end
			buffer = -(max_idx + 1)
		end
		-- Once a default has been specified, the chain must continue to that
		-- point, so use the getters that return functions that don't allow
		-- early return, which will be first used on the next param (since
		-- the chain can still end early at this param).
		local default = spec.default
		if default ~= nil then
			get_first, get_middle = get_first_parameter, get_middle_parameter
		end
		local insert_if = spec.insert_if
		if insert_if == nil then
			return nil, nil, buffer, default
		elseif insert_if == "callable" then
			return nil, "callable", buffer, default
		end
		local insert_if_type = type(insert_if)
		if insert_if_type == "string" then
			check_type(insert_if, param)
			return insert_if, "string", buffer, default
		elseif is_callable(insert_if) then
			return insert_if, "call", buffer, default
		elseif insert_if_type ~= "table" then
			error(format(
				"bad spec in 'fun/optionalParameters' in 'insert_if' for parameter #%d: string, function or table expected, got %s",
				param, insert_if_type
			))
		end
		-- If it's a table, unwind the first two iterations manually, and
		-- confirm the first iteration returned a result.
		local iter, state, i, type1, type2 = ipairs(insert_if)
		i, type1 = iter(state, i)
		if i == nil then
			error(format("bad spec in 'fun/optionalParameters' for parameter #%d: list is empty", param))
		end
		check_type(type1, param)
		-- If there's only one type specified, optimise by using the string
		-- check with `type1`.
		i, type2 = iter(state, i)
		if i == nil then
			return type1, "string", buffer, default
		end
		check_type(type2, param)
		-- Otherwise, return a type map: a set keyed by type name, seeded with
		-- the two names already read and completed from the rest of the list.
		local types = {
			[type1] = true,
			[type2] = true,
		}
		for _, _type in iter, state, i do
			check_type(_type, param)
			types[_type] = true
		end
		return types, "set", buffer, default
	end

	-- As an optimisation, every possible handler function has been specified
	-- manually, as this removes as much overhead as possible from calls.
	local param, chain = table_max_index(specs)
	repeat
		-- If no spec is found, return `func`.
		if param == 0 then
			return func
		end
		-- Ignore any final specs which are indistinguishable from no spec.
		local insert_if, insert_if_type, buffer, default = get_spec(param)
		if not (insert_if_type == nil and buffer == 0 and default == nil) then
			-- If the only spec is parameter 1, the first and last functions
			-- need to be combined.
			if param == 1 then
				return get_first_and_last_parameter(func, insert_if, insert_if_type, buffer, default)
			end
			chain = get_last_parameter(insert_if, insert_if_type, buffer, default)
		end
		param = param - 1
	until chain
	-- Build the remaining handlers from the back. The getter is looked up
	-- before get_spec runs, so a default on this parameter only changes the
	-- getter used for the parameters before it.
	while param > 1 do
		chain = get_middle(chain, get_spec(param))
		param = param - 1
	end
	return get_first(func, chain, get_spec(param))
end