diff --git a/.vscode/launch.json b/.vscode/launch.json index 455046b4e4..77827f6218 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -285,6 +285,17 @@ "cwd": "${workspaceRoot}/../fable-test", "stopAtEntry": false, "console": "internalConsole" + }, + { + "name": "BuildLuaTest", + "type": "coreclr", + "request": "launch", + "preLaunchTask": "build", + "program": "${workspaceFolder}/src/Fable.Cli/bin/Debug/net6.0/fable.dll", + "args": ["watch", "--cwd", "tests/Lua", "--exclude", "Fable.Core", "--outDir", "build/tests/Lua", "--lang", "Lua", "--fableLib", "build/tests/Lua/fable-lib"], + "cwd": "${workspaceFolder}", + "stopAtEntry": true, + "console": "internalConsole" } ] } diff --git a/build_old.fsx b/build_old.fsx index 1179805e2b..4d56babdee 100644 --- a/build_old.fsx +++ b/build_old.fsx @@ -1,5 +1,5 @@ #load "src/Fable.PublishUtils/PublishUtils.fs" - +//This is old, please use build project via build.cmd or build.sh open System open System.Text.RegularExpressions open PublishUtils @@ -282,6 +282,33 @@ let buildLibraryPy () = removeDirRecursive (buildDirPy "fable_library/fable-library") +let buildLibraryLua () = + let libraryDir = "src/fable-library-lua" + let projectDir = libraryDir + "/fable" + let buildDirLua = "build/fable-library-lua" + + cleanDirs [ buildDirLua ] + + runFableWithArgs + projectDir + [ + "--outDir " + buildDirLua "fable" + "--fableLib " + buildDirLua "fable" + "--lang Lua" + "--exclude Fable.Core" + "--define FABLE_LIBRARY" + ] + // Copy *.lua from projectDir to buildDir + copyDirRecursive libraryDir buildDirLua + + runInDir buildDirLua ("lua -v") +//runInDir buildDirLua ("lua ./setup.lua develop") +let buildLuaLibraryIfNotExists () = + let baseDir = __SOURCE_DIRECTORY__ + + if not (pathExists (baseDir "build/fable-library-lua")) then + buildLibraryLua () + let buildLibraryPyIfNotExists () = let baseDir = __SOURCE_DIRECTORY__ @@ -658,6 +685,29 @@ let testPython () = // Testing in Windows // runInDir buildDir "python -m pytest -x" +let testLua () = + buildLuaLibraryIfNotExists () // NOTE: fable-library-py needs to be built separately. + + let projectDir = "tests/Lua" + let buildDir = "build/tests/Lua" + + cleanDirs [ buildDir ] + copyDirRecursive ("build" "fable-library-lua" "fable") (buildDir "fable-lib") + runInDir projectDir "dotnet test" + + runFableWithArgs + projectDir + [ + "--outDir " + buildDir + "--exclude Fable.Core" + "--lang Lua" + "--fableLib " + buildDir "fable-lib" //("fable-library-lua" "fable") //cannot use relative paths in lua. Copy to subfolder? + ] + + copyFile (projectDir "luaunit.lua") (buildDir "luaunit.lua") + copyFile (projectDir "runtests.lua") (buildDir "runtests.lua") + runInDir buildDir "lua runtests.lua" + type RustTestMode = | NoStd | Default @@ -970,6 +1020,8 @@ match BUILD_ARGS_LOWER with | "test-rust-threaded" :: _ -> testRust Threaded | "test-dart" :: _ -> testDart (false) | "watch-test-dart" :: _ -> testDart (true) +| "test-lua" :: _ -> testLua () + | "quicktest" :: _ -> buildLibraryTsIfNotExists () @@ -1028,6 +1080,7 @@ match BUILD_ARGS_LOWER with | ("fable-library-dart" | "library-dart") :: _ -> let clean = hasFlag "--no-clean" |> not buildLibraryDart (clean) +| ("fable-library-lua" | "library-lua") :: _ -> buildLibraryLua () | ("fable-compiler-js" | "compiler-js") :: _ -> let minify = hasFlag "--no-minify" |> not diff --git a/share/lua/5.1/luaunit.lua b/share/lua/5.1/luaunit.lua new file mode 100644 index 0000000000..4741478bc7 --- /dev/null +++ b/share/lua/5.1/luaunit.lua @@ -0,0 +1,3453 @@ +--[[ + luaunit.lua + +Description: A unit testing framework +Homepage: https://github.com/bluebird75/luaunit +Development by Philippe Fremy +Based on initial work of Ryu, Gwang (http://www.gpgstudy.com/gpgiki/LuaUnit) +License: BSD License, see LICENSE.txt +]]-- + +require("math") +local M={} + +-- private exported functions (for testing) +M.private = {} + +M.VERSION='3.4' +M._VERSION=M.VERSION -- For LuaUnit v2 compatibility + +-- a version which distinguish between regular Lua and LuaJit +M._LUAVERSION = (jit and jit.version) or _VERSION + +--[[ Some people like assertEquals( actual, expected ) and some people prefer +assertEquals( expected, actual ). +]]-- +M.ORDER_ACTUAL_EXPECTED = true +M.PRINT_TABLE_REF_IN_ERROR_MSG = false +M.LINE_LENGTH = 80 +M.TABLE_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items +M.LIST_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items + +-- this setting allow to remove entries from the stack-trace, for +-- example to hide a call to a framework which would be calling luaunit +M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE = 0 + +--[[ EPS is meant to help with Lua's floating point math in simple corner +cases like almostEquals(1.1-0.1, 1), which may not work as-is (e.g. on numbers +with rational binary representation) if the user doesn't provide some explicit +error margin. + +The default margin used by almostEquals() in such cases is EPS; and since +Lua may be compiled with different numeric precisions (single vs. double), we +try to select a useful default for it dynamically. Note: If the initial value +is not acceptable, it can be changed by the user to better suit specific needs. + +See also: https://en.wikipedia.org/wiki/Machine_epsilon +]] +M.EPS = 2^-52 -- = machine epsilon for "double", ~2.22E-16 +if math.abs(1.1 - 1 - 0.1) > M.EPS then + -- rounding error is above EPS, assume single precision + M.EPS = 2^-23 -- = machine epsilon for "float", ~1.19E-07 +end + +-- set this to false to debug luaunit +local STRIP_LUAUNIT_FROM_STACKTRACE = true + +M.VERBOSITY_DEFAULT = 10 +M.VERBOSITY_LOW = 1 +M.VERBOSITY_QUIET = 0 +M.VERBOSITY_VERBOSE = 20 +M.DEFAULT_DEEP_ANALYSIS = nil +M.FORCE_DEEP_ANALYSIS = true +M.DISABLE_DEEP_ANALYSIS = false + +-- set EXPORT_ASSERT_TO_GLOBALS to have all asserts visible as global values +-- EXPORT_ASSERT_TO_GLOBALS = true + +-- we need to keep a copy of the script args before it is overriden +local cmdline_argv = rawget(_G, "arg") + +M.FAILURE_PREFIX = 'LuaUnit test FAILURE: ' -- prefix string for failed tests +M.SUCCESS_PREFIX = 'LuaUnit test SUCCESS: ' -- prefix string for successful tests finished early +M.SKIP_PREFIX = 'LuaUnit test SKIP: ' -- prefix string for skipped tests + + + +M.USAGE=[[Usage: lua [options] [testname1 [testname2] ... ] +Options: + -h, --help: Print this help + --version: Print version information + -v, --verbose: Increase verbosity + -q, --quiet: Set verbosity to minimum + -e, --error: Stop on first error + -f, --failure: Stop on first failure or error + -s, --shuffle: Shuffle tests before running them + -o, --output OUTPUT: Set output type to OUTPUT + Possible values: text, tap, junit, nil + -n, --name NAME: For junit only, mandatory name of xml file + -r, --repeat NUM: Execute all tests NUM times, e.g. to trig the JIT + -p, --pattern PATTERN: Execute all test names matching the Lua PATTERN + May be repeated to include several patterns + Make sure you escape magic chars like +? with % + -x, --exclude PATTERN: Exclude all test names matching the Lua PATTERN + May be repeated to exclude several patterns + Make sure you escape magic chars like +? with % + testname1, testname2, ... : tests to run in the form of testFunction, + TestClass or TestClass.testMethod + +You may also control LuaUnit options with the following environment variables: +* LUAUNIT_OUTPUT: same as --output +* LUAUNIT_JUNIT_FNAME: same as --name ]] + +---------------------------------------------------------------- +-- +-- general utility functions +-- +---------------------------------------------------------------- + +--[[ Note on catching exit + +I have seen the case where running a big suite of test cases and one of them would +perform a os.exit(0), making the outside world think that the full test suite was executed +successfully. + +This is an attempt to mitigate this problem: we override os.exit() to now let a test +exit the framework while we are running. When we are not running, it behaves normally. +]] + +M.oldOsExit = os.exit +os.exit = function(...) + if M.LuaUnit and #M.LuaUnit.instances ~= 0 then + local msg = [[You are trying to exit but there is still a running instance of LuaUnit. +LuaUnit expects to run until the end before exiting with a complete status of successful/failed tests. + +To force exit LuaUnit while running, please call before os.exit (assuming lu is the luaunit module loaded): + + lu.unregisterCurrentSuite() + +]] + M.private.error_fmt(2, msg) + end + M.oldOsExit(...) +end + +local function pcall_or_abort(func, ...) + -- unpack is a global function for Lua 5.1, otherwise use table.unpack + local unpack = rawget(_G, "unpack") or table.unpack + local result = {pcall(func, ...)} + if not result[1] then + -- an error occurred + print(result[2]) -- error message + print() + print(M.USAGE) + os.exit(-1) + end + return unpack(result, 2) +end + +local crossTypeOrdering = { + number = 1, boolean = 2, string = 3, table = 4, other = 5 +} +local crossTypeComparison = { + number = function(a, b) return a < b end, + string = function(a, b) return a < b end, + other = function(a, b) return tostring(a) < tostring(b) end, +} + +local function crossTypeSort(a, b) + local type_a, type_b = type(a), type(b) + if type_a == type_b then + local func = crossTypeComparison[type_a] or crossTypeComparison.other + return func(a, b) + end + type_a = crossTypeOrdering[type_a] or crossTypeOrdering.other + type_b = crossTypeOrdering[type_b] or crossTypeOrdering.other + return type_a < type_b +end + +local function __genSortedIndex( t ) + -- Returns a sequence consisting of t's keys, sorted. + local sortedIndex = {} + + for key,_ in pairs(t) do + table.insert(sortedIndex, key) + end + + table.sort(sortedIndex, crossTypeSort) + return sortedIndex +end +M.private.__genSortedIndex = __genSortedIndex + +local function sortedNext(state, control) + -- Equivalent of the next() function of table iteration, but returns the + -- keys in sorted order (see __genSortedIndex and crossTypeSort). + -- The state is a temporary variable during iteration and contains the + -- sorted key table (state.sortedIdx). It also stores the last index (into + -- the keys) used by the iteration, to find the next one quickly. + local key + + --print("sortedNext: control = "..tostring(control) ) + if control == nil then + -- start of iteration + state.count = #state.sortedIdx + state.lastIdx = 1 + key = state.sortedIdx[1] + return key, state.t[key] + end + + -- normally, we expect the control variable to match the last key used + if control ~= state.sortedIdx[state.lastIdx] then + -- strange, we have to find the next value by ourselves + -- the key table is sorted in crossTypeSort() order! -> use bisection + local lower, upper = 1, state.count + repeat + state.lastIdx = math.modf((lower + upper) / 2) + key = state.sortedIdx[state.lastIdx] + if key == control then + break -- key found (and thus prev index) + end + if crossTypeSort(key, control) then + -- key < control, continue search "right" (towards upper bound) + lower = state.lastIdx + 1 + else + -- key > control, continue search "left" (towards lower bound) + upper = state.lastIdx - 1 + end + until lower > upper + if lower > upper then -- only true if the key wasn't found, ... + state.lastIdx = state.count -- ... so ensure no match in code below + end + end + + -- proceed by retrieving the next value (or nil) from the sorted keys + state.lastIdx = state.lastIdx + 1 + key = state.sortedIdx[state.lastIdx] + if key then + return key, state.t[key] + end + + -- getting here means returning `nil`, which will end the iteration +end + +local function sortedPairs(tbl) + -- Equivalent of the pairs() function on tables. Allows to iterate in + -- sorted order. As required by "generic for" loops, this will return the + -- iterator (function), an "invariant state", and the initial control value. + -- (see http://www.lua.org/pil/7.2.html) + return sortedNext, {t = tbl, sortedIdx = __genSortedIndex(tbl)}, nil +end +M.private.sortedPairs = sortedPairs + +-- seed the random with a strongly varying seed +math.randomseed(math.floor(os.clock()*1E11)) + +local function randomizeTable( t ) + -- randomize the item orders of the table t + for i = #t, 2, -1 do + local j = math.random(i) + if i ~= j then + t[i], t[j] = t[j], t[i] + end + end +end +M.private.randomizeTable = randomizeTable + +local function strsplit(delimiter, text) +-- Split text into a list consisting of the strings in text, separated +-- by strings matching delimiter (which may _NOT_ be a pattern). +-- Example: strsplit(", ", "Anna, Bob, Charlie, Dolores") + if delimiter == "" or delimiter == nil then -- this would result in endless loops + error("delimiter is nil or empty string!") + end + if text == nil then + return nil + end + + local list, pos, first, last = {}, 1 + while true do + first, last = text:find(delimiter, pos, true) + if first then -- found? + table.insert(list, text:sub(pos, first - 1)) + pos = last + 1 + else + table.insert(list, text:sub(pos)) + break + end + end + return list +end +M.private.strsplit = strsplit + +local function hasNewLine( s ) + -- return true if s has a newline + return (string.find(s, '\n', 1, true) ~= nil) +end +M.private.hasNewLine = hasNewLine + +local function prefixString( prefix, s ) + -- Prefix all the lines of s with prefix + return prefix .. string.gsub(s, '\n', '\n' .. prefix) +end +M.private.prefixString = prefixString + +local function strMatch(s, pattern, start, final ) + -- return true if s matches completely the pattern from index start to index end + -- return false in every other cases + -- if start is nil, matches from the beginning of the string + -- if final is nil, matches to the end of the string + start = start or 1 + final = final or string.len(s) + + local foundStart, foundEnd = string.find(s, pattern, start, false) + return foundStart == start and foundEnd == final +end +M.private.strMatch = strMatch + +local function patternFilter(patterns, expr) + -- Run `expr` through the inclusion and exclusion rules defined in patterns + -- and return true if expr shall be included, false for excluded. + -- Inclusion pattern are defined as normal patterns, exclusions + -- patterns start with `!` and are followed by a normal pattern + + -- result: nil = UNKNOWN (not matched yet), true = ACCEPT, false = REJECT + -- default: true if no explicit "include" is found, set to false otherwise + local default, result = true, nil + + if patterns ~= nil then + for _, pattern in ipairs(patterns) do + local exclude = pattern:sub(1,1) == '!' + if exclude then + pattern = pattern:sub(2) + else + -- at least one include pattern specified, a match is required + default = false + end + -- print('pattern: ',pattern) + -- print('exclude: ',exclude) + -- print('default: ',default) + + if string.find(expr, pattern) then + -- set result to false when excluding, true otherwise + result = not exclude + end + end + end + + if result ~= nil then + return result + end + return default +end +M.private.patternFilter = patternFilter + +local function xmlEscape( s ) + -- Return s escaped for XML attributes + -- escapes table: + -- " " + -- ' ' + -- < < + -- > > + -- & & + + return string.gsub( s, '.', { + ['&'] = "&", + ['"'] = """, + ["'"] = "'", + ['<'] = "<", + ['>'] = ">", + } ) +end +M.private.xmlEscape = xmlEscape + +local function xmlCDataEscape( s ) + -- Return s escaped for CData section, escapes: "]]>" + return string.gsub( s, ']]>', ']]>' ) +end +M.private.xmlCDataEscape = xmlCDataEscape + + +local function lstrip( s ) + --[[Return s with all leading white spaces and tabs removed]] + local idx = 0 + while idx < s:len() do + idx = idx + 1 + local c = s:sub(idx,idx) + if c ~= ' ' and c ~= '\t' then + break + end + end + return s:sub(idx) +end +M.private.lstrip = lstrip + +local function extractFileLineInfo( s ) + --[[ From a string in the form "(leading spaces) dir1/dir2\dir3\file.lua:linenb: msg" + + Return the "file.lua:linenb" information + ]] + local s2 = lstrip(s) + local firstColon = s2:find(':', 1, true) + if firstColon == nil then + -- string is not in the format file:line: + return s + end + local secondColon = s2:find(':', firstColon+1, true) + if secondColon == nil then + -- string is not in the format file:line: + return s + end + + return s2:sub(1, secondColon-1) +end +M.private.extractFileLineInfo = extractFileLineInfo + + +local function stripLuaunitTrace2( stackTrace, errMsg ) + --[[ + -- Example of a traceback: + < + [C]: in function 'xpcall' + ./luaunit.lua:1449: in function 'protectedCall' + ./luaunit.lua:1508: in function 'execOneFunction' + ./luaunit.lua:1596: in function 'runSuiteByInstances' + ./luaunit.lua:1660: in function 'runSuiteByNames' + ./luaunit.lua:1736: in function 'runSuite' + example_with_luaunit.lua:140: in main chunk + [C]: in ?>> + error message: <> + + Other example: + < + [C]: in function 'xpcall' + ./luaunit.lua:1517: in function 'protectedCall' + ./luaunit.lua:1578: in function 'execOneFunction' + ./luaunit.lua:1677: in function 'runSuiteByInstances' + ./luaunit.lua:1730: in function 'runSuiteByNames' + ./luaunit.lua:1806: in function 'runSuite' + example_with_luaunit.lua:140: in main chunk + [C]: in ?>> + error message: <> + + < + [C]: in function 'xpcall' + luaunit2/luaunit.lua:1532: in function 'protectedCall' + luaunit2/luaunit.lua:1591: in function 'execOneFunction' + luaunit2/luaunit.lua:1679: in function 'runSuiteByInstances' + luaunit2/luaunit.lua:1743: in function 'runSuiteByNames' + luaunit2/luaunit.lua:1819: in function 'runSuite' + luaunit2/example_with_luaunit.lua:140: in main chunk + [C]: in ?>> + error message: <> + + + -- first line is "stack traceback": KEEP + -- next line may be luaunit line: REMOVE + -- next lines are call in the program under testOk: REMOVE + -- next lines are calls from luaunit to call the program under test: KEEP + + -- Strategy: + -- keep first line + -- remove lines that are part of luaunit + -- kepp lines until we hit a luaunit line + + The strategy for stripping is: + * keep first line "stack traceback:" + * part1: + * analyse all lines of the stack from bottom to top of the stack (first line to last line) + * extract the "file:line:" part of the line + * compare it with the "file:line" part of the error message + * if it does not match strip the line + * if it matches, keep the line and move to part 2 + * part2: + * anything NOT starting with luaunit.lua is the interesting part of the stack trace + * anything starting again with luaunit.lua is part of the test launcher and should be stripped out + ]] + + local function isLuaunitInternalLine( s ) + -- return true if line of stack trace comes from inside luaunit + return s:find('[/\\]luaunit%.lua:%d+: ') ~= nil + end + + -- print( '<<'..stackTrace..'>>' ) + + local t = strsplit( '\n', stackTrace ) + -- print( prettystr(t) ) + + local idx = 2 + + local errMsgFileLine = extractFileLineInfo(errMsg) + -- print('emfi="'..errMsgFileLine..'"') + + -- remove lines that are still part of luaunit + while t[idx] and extractFileLineInfo(t[idx]) ~= errMsgFileLine do + -- print('Removing : '..t[idx] ) + table.remove(t, idx) + end + + -- keep lines until we hit luaunit again + while t[idx] and (not isLuaunitInternalLine(t[idx])) do + -- print('Keeping : '..t[idx] ) + idx = idx + 1 + end + + -- remove remaining luaunit lines + while t[idx] do + -- print('Removing2 : '..t[idx] ) + table.remove(t, idx) + end + + -- print( prettystr(t) ) + return table.concat( t, '\n') + +end +M.private.stripLuaunitTrace2 = stripLuaunitTrace2 + + +local function prettystr_sub(v, indentLevel, printTableRefs, cycleDetectTable ) + local type_v = type(v) + if "string" == type_v then + -- use clever delimiters according to content: + -- enclose with single quotes if string contains ", but no ' + if v:find('"', 1, true) and not v:find("'", 1, true) then + return "'" .. v .. "'" + end + -- use double quotes otherwise, escape embedded " + return '"' .. v:gsub('"', '\\"') .. '"' + + elseif "table" == type_v then + --if v.__class__ then + -- return string.gsub( tostring(v), 'table', v.__class__ ) + --end + return M.private._table_tostring(v, indentLevel, printTableRefs, cycleDetectTable) + + elseif "number" == type_v then + -- eliminate differences in formatting between various Lua versions + if v ~= v then + return "#NaN" -- "not a number" + end + if v == math.huge then + return "#Inf" -- "infinite" + end + if v == -math.huge then + return "-#Inf" + end + if _VERSION == "Lua 5.3" then + local i = math.tointeger(v) + if i then + return tostring(i) + end + end + end + + return tostring(v) +end + +local function prettystr( v ) + --[[ Pretty string conversion, to display the full content of a variable of any type. + + * string are enclosed with " by default, or with ' if string contains a " + * tables are expanded to show their full content, with indentation in case of nested tables + ]]-- + local cycleDetectTable = {} + local s = prettystr_sub(v, 1, M.PRINT_TABLE_REF_IN_ERROR_MSG, cycleDetectTable) + if cycleDetectTable.detected and not M.PRINT_TABLE_REF_IN_ERROR_MSG then + -- some table contain recursive references, + -- so we must recompute the value by including all table references + -- else the result looks like crap + cycleDetectTable = {} + s = prettystr_sub(v, 1, true, cycleDetectTable) + end + return s +end +M.prettystr = prettystr + +function M.adjust_err_msg_with_iter( err_msg, iter_msg ) + --[[ Adjust the error message err_msg: trim the FAILURE_PREFIX or SUCCESS_PREFIX information if needed, + add the iteration message if any and return the result. + + err_msg: string, error message captured with pcall + iter_msg: a string describing the current iteration ("iteration N") or nil + if there is no iteration in this test. + + Returns: (new_err_msg, test_status) + new_err_msg: string, adjusted error message, or nil in case of success + test_status: M.NodeStatus.FAIL, SUCCESS or ERROR according to the information + contained in the error message. + ]] + if iter_msg then + iter_msg = iter_msg..', ' + else + iter_msg = '' + end + + local RE_FILE_LINE = '.*:%d+: ' + + -- error message is not necessarily a string, + -- so convert the value to string with prettystr() + if type( err_msg ) ~= 'string' then + err_msg = prettystr( err_msg ) + end + + if (err_msg:find( M.SUCCESS_PREFIX ) == 1) or err_msg:match( '('..RE_FILE_LINE..')' .. M.SUCCESS_PREFIX .. ".*" ) then + -- test finished early with success() + return nil, M.NodeStatus.SUCCESS + end + + if (err_msg:find( M.SKIP_PREFIX ) == 1) or (err_msg:match( '('..RE_FILE_LINE..')' .. M.SKIP_PREFIX .. ".*" ) ~= nil) then + -- substitute prefix by iteration message + err_msg = err_msg:gsub('.*'..M.SKIP_PREFIX, iter_msg, 1) + -- print("failure detected") + return err_msg, M.NodeStatus.SKIP + end + + if (err_msg:find( M.FAILURE_PREFIX ) == 1) or (err_msg:match( '('..RE_FILE_LINE..')' .. M.FAILURE_PREFIX .. ".*" ) ~= nil) then + -- substitute prefix by iteration message + err_msg = err_msg:gsub(M.FAILURE_PREFIX, iter_msg, 1) + -- print("failure detected") + return err_msg, M.NodeStatus.FAIL + end + + + + -- print("error detected") + -- regular error, not a failure + if iter_msg then + local match + -- "./test\\test_luaunit.lua:2241: some error msg + match = err_msg:match( '(.*:%d+: ).*' ) + if match then + err_msg = err_msg:gsub( match, match .. iter_msg ) + else + -- no file:line: infromation, just add the iteration info at the beginning of the line + err_msg = iter_msg .. err_msg + end + end + return err_msg, M.NodeStatus.ERROR +end + +local function tryMismatchFormatting( table_a, table_b, doDeepAnalysis, margin ) + --[[ + Prepares a nice error message when comparing tables, performing a deeper + analysis. + + Arguments: + * table_a, table_b: tables to be compared + * doDeepAnalysis: + M.DEFAULT_DEEP_ANALYSIS: (the default if not specified) perform deep analysis only for big lists and big dictionnaries + M.FORCE_DEEP_ANALYSIS : always perform deep analysis + M.DISABLE_DEEP_ANALYSIS: never perform deep analysis + * margin: supplied only for almost equality + + Returns: {success, result} + * success: false if deep analysis could not be performed + in this case, just use standard assertion message + * result: if success is true, a multi-line string with deep analysis of the two lists + ]] + + -- check if table_a & table_b are suitable for deep analysis + if type(table_a) ~= 'table' or type(table_b) ~= 'table' then + return false + end + + if doDeepAnalysis == M.DISABLE_DEEP_ANALYSIS then + return false + end + + local len_a, len_b, isPureList = #table_a, #table_b, true + + for k1, v1 in pairs(table_a) do + if type(k1) ~= 'number' or k1 > len_a then + -- this table a mapping + isPureList = false + break + end + end + + if isPureList then + for k2, v2 in pairs(table_b) do + if type(k2) ~= 'number' or k2 > len_b then + -- this table a mapping + isPureList = false + break + end + end + end + + if isPureList and math.min(len_a, len_b) < M.LIST_DIFF_ANALYSIS_THRESHOLD then + if not (doDeepAnalysis == M.FORCE_DEEP_ANALYSIS) then + return false + end + end + + if isPureList then + return M.private.mismatchFormattingPureList( table_a, table_b, margin ) + else + -- only work on mapping for the moment + -- return M.private.mismatchFormattingMapping( table_a, table_b, doDeepAnalysis ) + return false + end +end +M.private.tryMismatchFormatting = tryMismatchFormatting + +local function getTaTbDescr() + if not M.ORDER_ACTUAL_EXPECTED then + return 'expected', 'actual' + end + return 'actual', 'expected' +end + +local function extendWithStrFmt( res, ... ) + table.insert( res, string.format( ... ) ) +end + +local function mismatchFormattingMapping( table_a, table_b, doDeepAnalysis ) + --[[ + Prepares a nice error message when comparing tables which are not pure lists, performing a deeper + analysis. + + Returns: {success, result} + * success: false if deep analysis could not be performed + in this case, just use standard assertion message + * result: if success is true, a multi-line string with deep analysis of the two lists + ]] + + -- disable for the moment + --[[ + local result = {} + local descrTa, descrTb = getTaTbDescr() + + local keysCommon = {} + local keysOnlyTa = {} + local keysOnlyTb = {} + local keysDiffTaTb = {} + + local k, v + + for k,v in pairs( table_a ) do + if is_equal( v, table_b[k] ) then + table.insert( keysCommon, k ) + else + if table_b[k] == nil then + table.insert( keysOnlyTa, k ) + else + table.insert( keysDiffTaTb, k ) + end + end + end + + for k,v in pairs( table_b ) do + if not is_equal( v, table_a[k] ) and table_a[k] == nil then + table.insert( keysOnlyTb, k ) + end + end + + local len_a = #keysCommon + #keysDiffTaTb + #keysOnlyTa + local len_b = #keysCommon + #keysDiffTaTb + #keysOnlyTb + local limited_display = (len_a < 5 or len_b < 5) + + if math.min(len_a, len_b) < M.TABLE_DIFF_ANALYSIS_THRESHOLD then + return false + end + + if not limited_display then + if len_a == len_b then + extendWithStrFmt( result, 'Table A (%s) and B (%s) both have %d items', descrTa, descrTb, len_a ) + else + extendWithStrFmt( result, 'Table A (%s) has %d items and table B (%s) has %d items', descrTa, len_a, descrTb, len_b ) + end + + if #keysCommon == 0 and #keysDiffTaTb == 0 then + table.insert( result, 'Table A and B have no keys in common, they are totally different') + else + local s_other = 'other ' + if #keysCommon then + extendWithStrFmt( result, 'Table A and B have %d identical items', #keysCommon ) + else + table.insert( result, 'Table A and B have no identical items' ) + s_other = '' + end + + if #keysDiffTaTb ~= 0 then + result[#result] = string.format( '%s and %d items differing present in both tables', result[#result], #keysDiffTaTb) + else + result[#result] = string.format( '%s and no %sitems differing present in both tables', result[#result], s_other, #keysDiffTaTb) + end + end + + extendWithStrFmt( result, 'Table A has %d keys not present in table B and table B has %d keys not present in table A', #keysOnlyTa, #keysOnlyTb ) + end + + local function keytostring(k) + if "string" == type(k) and k:match("^[_%a][_%w]*$") then + return k + end + return prettystr(k) + end + + if #keysDiffTaTb ~= 0 then + table.insert( result, 'Items differing in A and B:') + for k,v in sortedPairs( keysDiffTaTb ) do + extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) ) + extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) ) + end + end + + if #keysOnlyTa ~= 0 then + table.insert( result, 'Items only in table A:' ) + for k,v in sortedPairs( keysOnlyTa ) do + extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) ) + end + end + + if #keysOnlyTb ~= 0 then + table.insert( result, 'Items only in table B:' ) + for k,v in sortedPairs( keysOnlyTb ) do + extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) ) + end + end + + if #keysCommon ~= 0 then + table.insert( result, 'Items common to A and B:') + for k,v in sortedPairs( keysCommon ) do + extendWithStrFmt( result, ' = A and B [%s]: %s', keytostring(v), prettystr(table_a[v]) ) + end + end + + return true, table.concat( result, '\n') + ]] +end +M.private.mismatchFormattingMapping = mismatchFormattingMapping + +local function mismatchFormattingPureList( table_a, table_b, margin ) + --[[ + Prepares a nice error message when comparing tables which are lists, performing a deeper + analysis. + + margin is supplied only for almost equality + + Returns: {success, result} + * success: false if deep analysis could not be performed + in this case, just use standard assertion message + * result: if success is true, a multi-line string with deep analysis of the two lists + ]] + local result, descrTa, descrTb = {}, getTaTbDescr() + + local len_a, len_b, refa, refb = #table_a, #table_b, '', '' + if M.PRINT_TABLE_REF_IN_ERROR_MSG then + refa, refb = string.format( '<%s> ', M.private.table_ref(table_a)), string.format('<%s> ', M.private.table_ref(table_b) ) + end + local longest, shortest = math.max(len_a, len_b), math.min(len_a, len_b) + local deltalv = longest - shortest + + local commonUntil = shortest + for i = 1, shortest do + if not M.private.is_table_equals(table_a[i], table_b[i], margin) then + commonUntil = i - 1 + break + end + end + + local commonBackTo = shortest - 1 + for i = 0, shortest - 1 do + if not M.private.is_table_equals(table_a[len_a-i], table_b[len_b-i], margin) then + commonBackTo = i - 1 + break + end + end + + + table.insert( result, 'List difference analysis:' ) + if len_a == len_b then + -- TODO: handle expected/actual naming + extendWithStrFmt( result, '* lists %sA (%s) and %sB (%s) have the same size', refa, descrTa, refb, descrTb ) + else + extendWithStrFmt( result, '* list sizes differ: list %sA (%s) has %d items, list %sB (%s) has %d items', refa, descrTa, len_a, refb, descrTb, len_b ) + end + + extendWithStrFmt( result, '* lists A and B start differing at index %d', commonUntil+1 ) + if commonBackTo >= 0 then + if deltalv > 0 then + extendWithStrFmt( result, '* lists A and B are equal again from index %d for A, %d for B', len_a-commonBackTo, len_b-commonBackTo ) + else + extendWithStrFmt( result, '* lists A and B are equal again from index %d', len_a-commonBackTo ) + end + end + + local function insertABValue(ai, bi) + bi = bi or ai + if M.private.is_table_equals( table_a[ai], table_b[bi], margin) then + return extendWithStrFmt( result, ' = A[%d], B[%d]: %s', ai, bi, prettystr(table_a[ai]) ) + else + extendWithStrFmt( result, ' - A[%d]: %s', ai, prettystr(table_a[ai])) + extendWithStrFmt( result, ' + B[%d]: %s', bi, prettystr(table_b[bi])) + end + end + + -- common parts to list A & B, at the beginning + if commonUntil > 0 then + table.insert( result, '* Common parts:' ) + for i = 1, commonUntil do + insertABValue( i ) + end + end + + -- diffing parts to list A & B + if commonUntil < shortest - commonBackTo - 1 then + table.insert( result, '* Differing parts:' ) + for i = commonUntil + 1, shortest - commonBackTo - 1 do + insertABValue( i ) + end + end + + -- display indexes of one list, with no match on other list + if shortest - commonBackTo <= longest - commonBackTo - 1 then + table.insert( result, '* Present only in one list:' ) + for i = shortest - commonBackTo, longest - commonBackTo - 1 do + if len_a > len_b then + extendWithStrFmt( result, ' - A[%d]: %s', i, prettystr(table_a[i]) ) + -- table.insert( result, '+ (no matching B index)') + else + -- table.insert( result, '- no matching A index') + extendWithStrFmt( result, ' + B[%d]: %s', i, prettystr(table_b[i]) ) + end + end + end + + -- common parts to list A & B, at the end + if commonBackTo >= 0 then + table.insert( result, '* Common parts at the end of the lists' ) + for i = longest - commonBackTo, longest do + if len_a > len_b then + insertABValue( i, i-deltalv ) + else + insertABValue( i-deltalv, i ) + end + end + end + + return true, table.concat( result, '\n') +end +M.private.mismatchFormattingPureList = mismatchFormattingPureList + +local function prettystrPairs(value1, value2, suffix_a, suffix_b) + --[[ + This function helps with the recurring task of constructing the "expected + vs. actual" error messages. It takes two arbitrary values and formats + corresponding strings with prettystr(). + + To keep the (possibly complex) output more readable in case the resulting + strings contain line breaks, they get automatically prefixed with additional + newlines. Both suffixes are optional (default to empty strings), and get + appended to the "value1" string. "suffix_a" is used if line breaks were + encountered, "suffix_b" otherwise. + + Returns the two formatted strings (including padding/newlines). + ]] + local str1, str2 = prettystr(value1), prettystr(value2) + if hasNewLine(str1) or hasNewLine(str2) then + -- line break(s) detected, add padding + return "\n" .. str1 .. (suffix_a or ""), "\n" .. str2 + end + return str1 .. (suffix_b or ""), str2 +end +M.private.prettystrPairs = prettystrPairs + +local UNKNOWN_REF = 'table 00-unknown ref' +local ref_generator = { value=1, [UNKNOWN_REF]=0 } + +local function table_ref( t ) + -- return the default tostring() for tables, with the table ID, even if the table has a metatable + -- with the __tostring converter + local ref = '' + local mt = getmetatable( t ) + if mt == nil then + ref = tostring(t) + else + local success, result + success, result = pcall(setmetatable, t, nil) + if not success then + -- protected table, if __tostring is defined, we can + -- not get the reference. And we can not know in advance. + ref = tostring(t) + if not ref:match( 'table: 0?x?[%x]+' ) then + return UNKNOWN_REF + end + else + ref = tostring(t) + setmetatable( t, mt ) + end + end + -- strip the "table: " part + ref = ref:sub(8) + if ref ~= UNKNOWN_REF and ref_generator[ref] == nil then + -- Create a new reference number + ref_generator[ref] = ref_generator.value + ref_generator.value = ref_generator.value+1 + end + if M.PRINT_TABLE_REF_IN_ERROR_MSG then + return string.format('table %02d-%s', ref_generator[ref], ref) + else + return string.format('table %02d', ref_generator[ref]) + end +end +M.private.table_ref = table_ref + +local TABLE_TOSTRING_SEP = ", " +local TABLE_TOSTRING_SEP_LEN = string.len(TABLE_TOSTRING_SEP) + +local function _table_tostring( tbl, indentLevel, printTableRefs, cycleDetectTable ) + printTableRefs = printTableRefs or M.PRINT_TABLE_REF_IN_ERROR_MSG + cycleDetectTable = cycleDetectTable or {} + cycleDetectTable[tbl] = true + + local result, dispOnMultLines = {}, false + + -- like prettystr but do not enclose with "" if the string is just alphanumerical + -- this is better for displaying table keys who are often simple strings + local function keytostring(k) + if "string" == type(k) and k:match("^[_%a][_%w]*$") then + return k + end + return prettystr_sub(k, indentLevel+1, printTableRefs, cycleDetectTable) + end + + local mt = getmetatable( tbl ) + + if mt and mt.__tostring then + -- if table has a __tostring() function in its metatable, use it to display the table + -- else, compute a regular table + result = tostring(tbl) + if type(result) ~= 'string' then + return string.format( '', prettystr(result) ) + end + result = strsplit( '\n', result ) + return M.private._table_tostring_format_multiline_string( result, indentLevel ) + + else + -- no metatable, compute the table representation + + local entry, count, seq_index = nil, 0, 1 + for k, v in sortedPairs( tbl ) do + + -- key part + if k == seq_index then + -- for the sequential part of tables, we'll skip the "=" output + entry = '' + seq_index = seq_index + 1 + elseif cycleDetectTable[k] then + -- recursion in the key detected + cycleDetectTable.detected = true + entry = "<"..table_ref(k)..">=" + else + entry = keytostring(k) .. "=" + end + + -- value part + if cycleDetectTable[v] then + -- recursion in the value detected! + cycleDetectTable.detected = true + entry = entry .. "<"..table_ref(v)..">" + else + entry = entry .. + prettystr_sub( v, indentLevel+1, printTableRefs, cycleDetectTable ) + end + count = count + 1 + result[count] = entry + end + return M.private._table_tostring_format_result( tbl, result, indentLevel, printTableRefs ) + end + +end +M.private._table_tostring = _table_tostring -- prettystr_sub() needs it + +local function _table_tostring_format_multiline_string( tbl_str, indentLevel ) + local indentString = '\n'..string.rep(" ", indentLevel - 1) + return table.concat( tbl_str, indentString ) + +end +M.private._table_tostring_format_multiline_string = _table_tostring_format_multiline_string + + +local function _table_tostring_format_result( tbl, result, indentLevel, printTableRefs ) + -- final function called in _table_to_string() to format the resulting list of + -- string describing the table. + + local dispOnMultLines = false + + -- set dispOnMultLines to true if the maximum LINE_LENGTH would be exceeded with the values + local totalLength = 0 + for k, v in ipairs( result ) do + totalLength = totalLength + string.len( v ) + if totalLength >= M.LINE_LENGTH then + dispOnMultLines = true + break + end + end + + -- set dispOnMultLines to true if the max LINE_LENGTH would be exceeded + -- with the values and the separators. + if not dispOnMultLines then + -- adjust with length of separator(s): + -- two items need 1 sep, three items two seps, ... plus len of '{}' + if #result > 0 then + totalLength = totalLength + TABLE_TOSTRING_SEP_LEN * (#result - 1) + end + dispOnMultLines = (totalLength + 2 >= M.LINE_LENGTH) + end + + -- now reformat the result table (currently holding element strings) + if dispOnMultLines then + local indentString = string.rep(" ", indentLevel - 1) + result = { + "{\n ", + indentString, + table.concat(result, ",\n " .. indentString), + "\n", + indentString, + "}" + } + else + result = {"{", table.concat(result, TABLE_TOSTRING_SEP), "}"} + end + if printTableRefs then + table.insert(result, 1, "<"..table_ref(tbl).."> ") -- prepend table ref + end + return table.concat(result) +end +M.private._table_tostring_format_result = _table_tostring_format_result -- prettystr_sub() needs it + +local function table_findkeyof(t, element) + -- Return the key k of the given element in table t, so that t[k] == element + -- (or `nil` if element is not present within t). Note that we use our + -- 'general' is_equal comparison for matching, so this function should + -- handle table-type elements gracefully and consistently. + if type(t) == "table" then + for k, v in pairs(t) do + if M.private.is_table_equals(v, element) then + return k + end + end + end + return nil +end + +local function _is_table_items_equals(actual, expected ) + local type_a, type_e = type(actual), type(expected) + + if type_a ~= type_e then + return false + + elseif (type_a == 'table') --[[and (type_e == 'table')]] then + for k, v in pairs(actual) do + if table_findkeyof(expected, v) == nil then + return false -- v not contained in expected + end + end + for k, v in pairs(expected) do + if table_findkeyof(actual, v) == nil then + return false -- v not contained in actual + end + end + return true + + elseif actual ~= expected then + return false + end + + return true +end + +--[[ +This is a specialized metatable to help with the bookkeeping of recursions +in _is_table_equals(). It provides an __index table that implements utility +functions for easier management of the table. The "cached" method queries +the state of a specific (actual,expected) pair; and the "store" method sets +this state to the given value. The state of pairs not "seen" / visited is +assumed to be `nil`. +]] +local _recursion_cache_MT = { + __index = { + -- Return the cached value for an (actual,expected) pair (or `nil`) + cached = function(t, actual, expected) + local subtable = t[actual] or {} + return subtable[expected] + end, + + -- Store cached value for a specific (actual,expected) pair. + -- Returns the value, so it's easy to use for a "tailcall" (return ...). + store = function(t, actual, expected, value, asymmetric) + local subtable = t[actual] + if not subtable then + subtable = {} + t[actual] = subtable + end + subtable[expected] = value + + -- Unless explicitly marked "asymmetric": Consider the recursion + -- on (expected,actual) to be equivalent to (actual,expected) by + -- default, and thus cache the value for both. + if not asymmetric then + t:store(expected, actual, value, true) + end + + return value + end + } +} + +local function _is_table_equals(actual, expected, cycleDetectTable, marginForAlmostEqual) + --[[Returns true if both table are equal. + + If argument marginForAlmostEqual is suppied, number comparison is done using alomstEqual instead + of strict equality. + + cycleDetectTable is an internal argument used during recursion on tables. + ]] + --print('_is_table_equals( \n '..prettystr(actual)..'\n , '..prettystr(expected).. + -- '\n , '..prettystr(cycleDetectTable)..'\n , '..prettystr(marginForAlmostEqual)..' )') + + local type_a, type_e = type(actual), type(expected) + + if type_a ~= type_e then + return false -- different types won't match + end + + if type_a == 'number' then + if marginForAlmostEqual ~= nil then + return M.almostEquals(actual, expected, marginForAlmostEqual) + else + return actual == expected + end + elseif type_a ~= 'table' then + -- other types compare directly + return actual == expected + end + + cycleDetectTable = cycleDetectTable or { actual={}, expected={} } + if cycleDetectTable.actual[ actual ] then + -- oh, we hit a cycle in actual + if cycleDetectTable.expected[ expected ] then + -- uh, we hit a cycle at the same time in expected + -- so the two tables have similar structure + return true + end + + -- cycle was hit only in actual, the structure differs from expected + return false + end + + if cycleDetectTable.expected[ expected ] then + -- no cycle in actual, but cycle in expected + -- the structure differ + return false + end + + -- at this point, no table cycle detected, we are + -- seeing this table for the first time + + -- mark the cycle detection + cycleDetectTable.actual[ actual ] = true + cycleDetectTable.expected[ expected ] = true + + + local actualKeysMatched = {} + for k, v in pairs(actual) do + actualKeysMatched[k] = true -- Keep track of matched keys + if not _is_table_equals(v, expected[k], cycleDetectTable, marginForAlmostEqual) then + -- table differs on this key + -- clear the cycle detection before returning + cycleDetectTable.actual[ actual ] = nil + cycleDetectTable.expected[ expected ] = nil + return false + end + end + + for k, v in pairs(expected) do + if not actualKeysMatched[k] then + -- Found a key that we did not see in "actual" -> mismatch + -- clear the cycle detection before returning + cycleDetectTable.actual[ actual ] = nil + cycleDetectTable.expected[ expected ] = nil + return false + end + -- Otherwise actual[k] was already matched against v = expected[k]. + end + + -- all key match, we have a match ! + cycleDetectTable.actual[ actual ] = nil + cycleDetectTable.expected[ expected ] = nil + return true +end +M.private._is_table_equals = _is_table_equals + +local function failure(main_msg, extra_msg_or_nil, level) + -- raise an error indicating a test failure + -- for error() compatibility we adjust "level" here (by +1), to report the + -- calling context + local msg + if type(extra_msg_or_nil) == 'string' and extra_msg_or_nil:len() > 0 then + msg = extra_msg_or_nil .. '\n' .. main_msg + else + msg = main_msg + end + error(M.FAILURE_PREFIX .. msg, (level or 1) + 1 + M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE) +end + +local function is_table_equals(actual, expected, marginForAlmostEqual) + return _is_table_equals(actual, expected, nil, marginForAlmostEqual) +end +M.private.is_table_equals = is_table_equals + +local function fail_fmt(level, extra_msg_or_nil, ...) + -- failure with printf-style formatted message and given error level + failure(string.format(...), extra_msg_or_nil, (level or 1) + 1) +end +M.private.fail_fmt = fail_fmt + +local function error_fmt(level, ...) + -- printf-style error() + error(string.format(...), (level or 1) + 1 + M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE) +end +M.private.error_fmt = error_fmt + +---------------------------------------------------------------- +-- +-- assertions +-- +---------------------------------------------------------------- + +local function errorMsgEquality(actual, expected, doDeepAnalysis, margin) + -- margin is supplied only for almost equal verification + + if not M.ORDER_ACTUAL_EXPECTED then + expected, actual = actual, expected + end + if type(expected) == 'string' or type(expected) == 'table' then + local strExpected, strActual = prettystrPairs(expected, actual) + local result = string.format("expected: %s\nactual: %s", strExpected, strActual) + if margin then + result = result .. '\nwere not equal by the margin of: '..prettystr(margin) + end + + -- extend with mismatch analysis if possible: + local success, mismatchResult + success, mismatchResult = tryMismatchFormatting( actual, expected, doDeepAnalysis, margin ) + if success then + result = table.concat( { result, mismatchResult }, '\n' ) + end + return result + end + return string.format("expected: %s, actual: %s", + prettystr(expected), prettystr(actual)) +end + +function M.assertError(f, ...) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + if pcall( f, ... ) then + failure( "Expected an error when calling function but no error generated", nil, 2 ) + end +end + +function M.fail( msg ) + -- stops a test due to a failure + failure( msg, nil, 2 ) +end + +function M.failIf( cond, msg ) + -- Fails a test with "msg" if condition is true + if cond then + failure( msg, nil, 2 ) + end +end + +function M.skip(msg) + -- skip a running test + error_fmt(2, M.SKIP_PREFIX .. msg) +end + +function M.skipIf( cond, msg ) + -- skip a running test if condition is met + if cond then + error_fmt(2, M.SKIP_PREFIX .. msg) + end +end + +function M.runOnlyIf( cond, msg ) + -- continue a running test if condition is met, else skip it + if not cond then + error_fmt(2, M.SKIP_PREFIX .. prettystr(msg)) + end +end + +function M.success() + -- stops a test with a success + error_fmt(2, M.SUCCESS_PREFIX) +end + +function M.successIf( cond ) + -- stops a test with a success if condition is met + if cond then + error_fmt(2, M.SUCCESS_PREFIX) + end +end + + +------------------------------------------------------------------ +-- Equality assertions +------------------------------------------------------------------ + +function M.assertEquals(actual, expected, extra_msg_or_nil, doDeepAnalysis) + if type(actual) == 'table' and type(expected) == 'table' then + if not is_table_equals(actual, expected) then + failure( errorMsgEquality(actual, expected, doDeepAnalysis), extra_msg_or_nil, 2 ) + end + elseif type(actual) ~= type(expected) then + failure( errorMsgEquality(actual, expected), extra_msg_or_nil, 2 ) + elseif actual ~= expected then + failure( errorMsgEquality(actual, expected), extra_msg_or_nil, 2 ) + end +end + +function M.almostEquals( actual, expected, margin ) + if type(actual) ~= 'number' or type(expected) ~= 'number' or type(margin) ~= 'number' then + error_fmt(3, 'almostEquals: must supply only number arguments.\nArguments supplied: %s, %s, %s', + prettystr(actual), prettystr(expected), prettystr(margin)) + end + if margin < 0 then + error_fmt(3, 'almostEquals: margin must not be negative, current value is ' .. margin) + end + return math.abs(expected - actual) <= margin +end + +function M.assertAlmostEquals( actual, expected, margin, extra_msg_or_nil ) + -- check that two floats are close by margin + margin = margin or M.EPS + if type(margin) ~= 'number' then + error_fmt(2, 'almostEquals: margin must be a number, not %s', prettystr(margin)) + end + + if type(actual) == 'table' and type(expected) == 'table' then + -- handle almost equals for table + if not is_table_equals(actual, expected, margin) then + failure( errorMsgEquality(actual, expected, nil, margin), extra_msg_or_nil, 2 ) + end + elseif type(actual) == 'number' and type(expected) == 'number' and type(margin) == 'number' then + if not M.almostEquals(actual, expected, margin) then + if not M.ORDER_ACTUAL_EXPECTED then + expected, actual = actual, expected + end + local delta = math.abs(actual - expected) + fail_fmt(2, extra_msg_or_nil, 'Values are not almost equal\n' .. + 'Actual: %s, expected: %s, delta %s above margin of %s', + actual, expected, delta, margin) + end + else + error_fmt(3, 'almostEquals: must supply only number or table arguments.\nArguments supplied: %s, %s, %s', + prettystr(actual), prettystr(expected), prettystr(margin)) + end +end + +function M.assertNotEquals(actual, expected, extra_msg_or_nil) + if type(actual) ~= type(expected) then + return + end + + if type(actual) == 'table' and type(expected) == 'table' then + if not is_table_equals(actual, expected) then + return + end + elseif actual ~= expected then + return + end + fail_fmt(2, extra_msg_or_nil, 'Received the not expected value: %s', prettystr(actual)) +end + +function M.assertNotAlmostEquals( actual, expected, margin, extra_msg_or_nil ) + -- check that two floats are not close by margin + margin = margin or M.EPS + if M.almostEquals(actual, expected, margin) then + if not M.ORDER_ACTUAL_EXPECTED then + expected, actual = actual, expected + end + local delta = math.abs(actual - expected) + fail_fmt(2, extra_msg_or_nil, 'Values are almost equal\nActual: %s, expected: %s' .. + ', delta %s below margin of %s', + actual, expected, delta, margin) + end +end + +function M.assertItemsEquals(actual, expected, extra_msg_or_nil) + -- checks that the items of table expected + -- are contained in table actual. Warning, this function + -- is at least O(n^2) + if not _is_table_items_equals(actual, expected ) then + expected, actual = prettystrPairs(expected, actual) + fail_fmt(2, extra_msg_or_nil, 'Content of the tables are not identical:\nExpected: %s\nActual: %s', + expected, actual) + end +end + +------------------------------------------------------------------ +-- String assertion +------------------------------------------------------------------ + +function M.assertStrContains( str, sub, isPattern, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + -- assert( type(str) == 'string', 'Argument 1 of assertStrContains() should be a string.' ) ) + -- assert( type(sub) == 'string', 'Argument 2 of assertStrContains() should be a string.' ) ) + if not string.find(str, sub, 1, not isPattern) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Could not find %s %s in string %s', + isPattern and 'pattern' or 'substring', sub, str) + end +end + +function M.assertStrIContains( str, sub, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + if not string.find(str:lower(), sub:lower(), 1, true) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Could not find (case insensitively) substring %s in string %s', + sub, str) + end +end + +function M.assertNotStrContains( str, sub, isPattern, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + if string.find(str, sub, 1, not isPattern) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Found the not expected %s %s in string %s', + isPattern and 'pattern' or 'substring', sub, str) + end +end + +function M.assertNotStrIContains( str, sub, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + if string.find(str:lower(), sub:lower(), 1, true) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Found (case insensitively) the not expected substring %s in string %s', + sub, str) + end +end + +function M.assertStrMatches( str, pattern, start, final, extra_msg_or_nil ) + -- Verify a full match for the string + if not strMatch( str, pattern, start, final ) then + pattern, str = prettystrPairs(pattern, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Could not match pattern %s with string %s', + pattern, str) + end +end + +local function _assertErrorMsgEquals( stripFileAndLine, expectedMsg, func, ... ) + local no_error, error_msg = pcall( func, ... ) + if no_error then + failure( 'No error generated when calling function but expected error: '..M.prettystr(expectedMsg), nil, 3 ) + end + if type(expectedMsg) == "string" and type(error_msg) ~= "string" then + -- table are converted to string automatically + error_msg = tostring(error_msg) + end + local differ = false + if stripFileAndLine then + if error_msg:gsub("^.+:%d+: ", "") ~= expectedMsg then + differ = true + end + else + if error_msg ~= expectedMsg then + local tr = type(error_msg) + local te = type(expectedMsg) + if te == 'table' then + if tr ~= 'table' then + differ = true + else + local ok = pcall(M.assertItemsEquals, error_msg, expectedMsg) + if not ok then + differ = true + end + end + else + differ = true + end + end + end + + if differ then + error_msg, expectedMsg = prettystrPairs(error_msg, expectedMsg) + fail_fmt(3, nil, 'Error message expected: %s\nError message received: %s\n', + expectedMsg, error_msg) + end +end + +function M.assertErrorMsgEquals( expectedMsg, func, ... ) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + _assertErrorMsgEquals(false, expectedMsg, func, ...) +end + +function M.assertErrorMsgContentEquals(expectedMsg, func, ...) + _assertErrorMsgEquals(true, expectedMsg, func, ...) +end + +function M.assertErrorMsgContains( partialMsg, func, ... ) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + local no_error, error_msg = pcall( func, ... ) + if no_error then + failure( 'No error generated when calling function but expected error containing: '..prettystr(partialMsg), nil, 2 ) + end + if type(error_msg) ~= "string" then + error_msg = tostring(error_msg) + end + if not string.find( error_msg, partialMsg, nil, true ) then + error_msg, partialMsg = prettystrPairs(error_msg, partialMsg) + fail_fmt(2, nil, 'Error message does not contain: %s\nError message received: %s\n', + partialMsg, error_msg) + end +end + +function M.assertErrorMsgMatches( expectedMsg, func, ... ) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + local no_error, error_msg = pcall( func, ... ) + if no_error then + failure( 'No error generated when calling function but expected error matching: "'..expectedMsg..'"', nil, 2 ) + end + if type(error_msg) ~= "string" then + error_msg = tostring(error_msg) + end + if not strMatch( error_msg, expectedMsg ) then + expectedMsg, error_msg = prettystrPairs(expectedMsg, error_msg) + fail_fmt(2, nil, 'Error message does not match pattern: %s\nError message received: %s\n', + expectedMsg, error_msg) + end +end + +------------------------------------------------------------------ +-- Type assertions +------------------------------------------------------------------ + +function M.assertEvalToTrue(value, extra_msg_or_nil) + if not value then + failure("expected: a value evaluating to true, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertEvalToFalse(value, extra_msg_or_nil) + if value then + failure("expected: false or nil, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsTrue(value, extra_msg_or_nil) + if value ~= true then + failure("expected: true, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsTrue(value, extra_msg_or_nil) + if value == true then + failure("expected: not true, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsFalse(value, extra_msg_or_nil) + if value ~= false then + failure("expected: false, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsFalse(value, extra_msg_or_nil) + if value == false then + failure("expected: not false, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsNil(value, extra_msg_or_nil) + if value ~= nil then + failure("expected: nil, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsNil(value, extra_msg_or_nil) + if value == nil then + failure("expected: not nil, actual: nil", extra_msg_or_nil, 2) + end +end + +--[[ +Add type assertion functions to the module table M. Each of these functions +takes a single parameter "value", and checks that its Lua type matches the +expected string (derived from the function name): + +M.assertIsXxx(value) -> ensure that type(value) conforms to "xxx" +]] +for _, funcName in ipairs( + {'assertIsNumber', 'assertIsString', 'assertIsTable', 'assertIsBoolean', + 'assertIsFunction', 'assertIsUserdata', 'assertIsThread'} +) do + local typeExpected = funcName:match("^assertIs([A-Z]%a*)$") + -- Lua type() always returns lowercase, also make sure the match() succeeded + typeExpected = typeExpected and typeExpected:lower() + or error("bad function name '"..funcName.."' for type assertion") + + M[funcName] = function(value, extra_msg_or_nil) + if type(value) ~= typeExpected then + if type(value) == 'nil' then + fail_fmt(2, extra_msg_or_nil, 'expected: a %s value, actual: nil', + typeExpected, type(value), prettystrPairs(value)) + else + fail_fmt(2, extra_msg_or_nil, 'expected: a %s value, actual: type %s, value %s', + typeExpected, type(value), prettystrPairs(value)) + end + end + end +end + +--[[ +Add shortcuts for verifying type of a variable, without failure (luaunit v2 compatibility) +M.isXxx(value) -> returns true if type(value) conforms to "xxx" +]] +for _, typeExpected in ipairs( + {'Number', 'String', 'Table', 'Boolean', + 'Function', 'Userdata', 'Thread', 'Nil' } +) do + local typeExpectedLower = typeExpected:lower() + local isType = function(value) + return (type(value) == typeExpectedLower) + end + M['is'..typeExpected] = isType + M['is_'..typeExpectedLower] = isType +end + +--[[ +Add non-type assertion functions to the module table M. Each of these functions +takes a single parameter "value", and checks that its Lua type differs from the +expected string (derived from the function name): + +M.assertNotIsXxx(value) -> ensure that type(value) is not "xxx" +]] +for _, funcName in ipairs( + {'assertNotIsNumber', 'assertNotIsString', 'assertNotIsTable', 'assertNotIsBoolean', + 'assertNotIsFunction', 'assertNotIsUserdata', 'assertNotIsThread'} +) do + local typeUnexpected = funcName:match("^assertNotIs([A-Z]%a*)$") + -- Lua type() always returns lowercase, also make sure the match() succeeded + typeUnexpected = typeUnexpected and typeUnexpected:lower() + or error("bad function name '"..funcName.."' for type assertion") + + M[funcName] = function(value, extra_msg_or_nil) + if type(value) == typeUnexpected then + fail_fmt(2, extra_msg_or_nil, 'expected: not a %s type, actual: value %s', + typeUnexpected, prettystrPairs(value)) + end + end +end + +function M.assertIs(actual, expected, extra_msg_or_nil) + if actual ~= expected then + if not M.ORDER_ACTUAL_EXPECTED then + actual, expected = expected, actual + end + local old_print_table_ref_in_error_msg = M.PRINT_TABLE_REF_IN_ERROR_MSG + M.PRINT_TABLE_REF_IN_ERROR_MSG = true + expected, actual = prettystrPairs(expected, actual, '\n', '') + M.PRINT_TABLE_REF_IN_ERROR_MSG = old_print_table_ref_in_error_msg + fail_fmt(2, extra_msg_or_nil, 'expected and actual object should not be different\nExpected: %s\nReceived: %s', + expected, actual) + end +end + +function M.assertNotIs(actual, expected, extra_msg_or_nil) + if actual == expected then + local old_print_table_ref_in_error_msg = M.PRINT_TABLE_REF_IN_ERROR_MSG + M.PRINT_TABLE_REF_IN_ERROR_MSG = true + local s_expected + if not M.ORDER_ACTUAL_EXPECTED then + s_expected = prettystrPairs(actual) + else + s_expected = prettystrPairs(expected) + end + M.PRINT_TABLE_REF_IN_ERROR_MSG = old_print_table_ref_in_error_msg + fail_fmt(2, extra_msg_or_nil, 'expected and actual object should be different: %s', s_expected ) + end +end + + +------------------------------------------------------------------ +-- Scientific assertions +------------------------------------------------------------------ + + +function M.assertIsNaN(value, extra_msg_or_nil) + if type(value) ~= "number" or value == value then + failure("expected: NaN, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsNaN(value, extra_msg_or_nil) + if type(value) == "number" and value ~= value then + failure("expected: not NaN, actual: NaN", extra_msg_or_nil, 2) + end +end + +function M.assertIsInf(value, extra_msg_or_nil) + if type(value) ~= "number" or math.abs(value) ~= math.huge then + failure("expected: #Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsPlusInf(value, extra_msg_or_nil) + if type(value) ~= "number" or value ~= math.huge then + failure("expected: #Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsMinusInf(value, extra_msg_or_nil) + if type(value) ~= "number" or value ~= -math.huge then + failure("expected: -#Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsPlusInf(value, extra_msg_or_nil) + if type(value) == "number" and value == math.huge then + failure("expected: not #Inf, actual: #Inf", extra_msg_or_nil, 2) + end +end + +function M.assertNotIsMinusInf(value, extra_msg_or_nil) + if type(value) == "number" and value == -math.huge then + failure("expected: not -#Inf, actual: -#Inf", extra_msg_or_nil, 2) + end +end + +function M.assertNotIsInf(value, extra_msg_or_nil) + if type(value) == "number" and math.abs(value) == math.huge then + failure("expected: not infinity, actual: " .. prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsPlusZero(value, extra_msg_or_nil) + if type(value) ~= 'number' or value ~= 0 then + failure("expected: +0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + else if (1/value == -math.huge) then + -- more precise error diagnosis + failure("expected: +0.0, actual: -0.0", extra_msg_or_nil, 2) + else if (1/value ~= math.huge) then + -- strange, case should have already been covered + failure("expected: +0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end + end + end +end + +function M.assertIsMinusZero(value, extra_msg_or_nil) + if type(value) ~= 'number' or value ~= 0 then + failure("expected: -0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + else if (1/value == math.huge) then + -- more precise error diagnosis + failure("expected: -0.0, actual: +0.0", extra_msg_or_nil, 2) + else if (1/value ~= -math.huge) then + -- strange, case should have already been covered + failure("expected: -0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end + end + end +end + +function M.assertNotIsPlusZero(value, extra_msg_or_nil) + if type(value) == 'number' and (1/value == math.huge) then + failure("expected: not +0.0, actual: +0.0", extra_msg_or_nil, 2) + end +end + +function M.assertNotIsMinusZero(value, extra_msg_or_nil) + if type(value) == 'number' and (1/value == -math.huge) then + failure("expected: not -0.0, actual: -0.0", extra_msg_or_nil, 2) + end +end + +function M.assertTableContains(t, expected, extra_msg_or_nil) + -- checks that table t contains the expected element + if table_findkeyof(t, expected) == nil then + t, expected = prettystrPairs(t, expected) + fail_fmt(2, extra_msg_or_nil, 'Table %s does NOT contain the expected element %s', + t, expected) + end +end + +function M.assertNotTableContains(t, expected, extra_msg_or_nil) + -- checks that table t doesn't contain the expected element + local k = table_findkeyof(t, expected) + if k ~= nil then + t, expected = prettystrPairs(t, expected) + fail_fmt(2, extra_msg_or_nil, 'Table %s DOES contain the unwanted element %s (at key %s)', + t, expected, prettystr(k)) + end +end + +---------------------------------------------------------------- +-- Compatibility layer +---------------------------------------------------------------- + +-- for compatibility with LuaUnit v2.x +function M.wrapFunctions() + -- In LuaUnit version <= 2.1 , this function was necessary to include + -- a test function inside the global test suite. Nowadays, the functions + -- are simply run directly as part of the test discovery process. + -- so just do nothing ! + io.stderr:write[[Use of WrapFunctions() is no longer needed. +Just prefix your test function names with "test" or "Test" and they +will be picked up and run by LuaUnit. +]] +end + +local list_of_funcs = { + -- { official function name , alias } + + -- general assertions + { 'assertEquals' , 'assert_equals' }, + { 'assertItemsEquals' , 'assert_items_equals' }, + { 'assertNotEquals' , 'assert_not_equals' }, + { 'assertAlmostEquals' , 'assert_almost_equals' }, + { 'assertNotAlmostEquals' , 'assert_not_almost_equals' }, + { 'assertEvalToTrue' , 'assert_eval_to_true' }, + { 'assertEvalToFalse' , 'assert_eval_to_false' }, + { 'assertStrContains' , 'assert_str_contains' }, + { 'assertStrIContains' , 'assert_str_icontains' }, + { 'assertNotStrContains' , 'assert_not_str_contains' }, + { 'assertNotStrIContains' , 'assert_not_str_icontains' }, + { 'assertStrMatches' , 'assert_str_matches' }, + { 'assertError' , 'assert_error' }, + { 'assertErrorMsgEquals' , 'assert_error_msg_equals' }, + { 'assertErrorMsgContains' , 'assert_error_msg_contains' }, + { 'assertErrorMsgMatches' , 'assert_error_msg_matches' }, + { 'assertErrorMsgContentEquals', 'assert_error_msg_content_equals' }, + { 'assertIs' , 'assert_is' }, + { 'assertNotIs' , 'assert_not_is' }, + { 'assertTableContains' , 'assert_table_contains' }, + { 'assertNotTableContains' , 'assert_not_table_contains' }, + { 'wrapFunctions' , 'WrapFunctions' }, + { 'wrapFunctions' , 'wrap_functions' }, + + -- type assertions: assertIsXXX -> assert_is_xxx + { 'assertIsNumber' , 'assert_is_number' }, + { 'assertIsString' , 'assert_is_string' }, + { 'assertIsTable' , 'assert_is_table' }, + { 'assertIsBoolean' , 'assert_is_boolean' }, + { 'assertIsNil' , 'assert_is_nil' }, + { 'assertIsTrue' , 'assert_is_true' }, + { 'assertIsFalse' , 'assert_is_false' }, + { 'assertIsNaN' , 'assert_is_nan' }, + { 'assertIsInf' , 'assert_is_inf' }, + { 'assertIsPlusInf' , 'assert_is_plus_inf' }, + { 'assertIsMinusInf' , 'assert_is_minus_inf' }, + { 'assertIsPlusZero' , 'assert_is_plus_zero' }, + { 'assertIsMinusZero' , 'assert_is_minus_zero' }, + { 'assertIsFunction' , 'assert_is_function' }, + { 'assertIsThread' , 'assert_is_thread' }, + { 'assertIsUserdata' , 'assert_is_userdata' }, + + -- type assertions: assertIsXXX -> assertXxx + { 'assertIsNumber' , 'assertNumber' }, + { 'assertIsString' , 'assertString' }, + { 'assertIsTable' , 'assertTable' }, + { 'assertIsBoolean' , 'assertBoolean' }, + { 'assertIsNil' , 'assertNil' }, + { 'assertIsTrue' , 'assertTrue' }, + { 'assertIsFalse' , 'assertFalse' }, + { 'assertIsNaN' , 'assertNaN' }, + { 'assertIsInf' , 'assertInf' }, + { 'assertIsPlusInf' , 'assertPlusInf' }, + { 'assertIsMinusInf' , 'assertMinusInf' }, + { 'assertIsPlusZero' , 'assertPlusZero' }, + { 'assertIsMinusZero' , 'assertMinusZero'}, + { 'assertIsFunction' , 'assertFunction' }, + { 'assertIsThread' , 'assertThread' }, + { 'assertIsUserdata' , 'assertUserdata' }, + + -- type assertions: assertIsXXX -> assert_xxx (luaunit v2 compat) + { 'assertIsNumber' , 'assert_number' }, + { 'assertIsString' , 'assert_string' }, + { 'assertIsTable' , 'assert_table' }, + { 'assertIsBoolean' , 'assert_boolean' }, + { 'assertIsNil' , 'assert_nil' }, + { 'assertIsTrue' , 'assert_true' }, + { 'assertIsFalse' , 'assert_false' }, + { 'assertIsNaN' , 'assert_nan' }, + { 'assertIsInf' , 'assert_inf' }, + { 'assertIsPlusInf' , 'assert_plus_inf' }, + { 'assertIsMinusInf' , 'assert_minus_inf' }, + { 'assertIsPlusZero' , 'assert_plus_zero' }, + { 'assertIsMinusZero' , 'assert_minus_zero' }, + { 'assertIsFunction' , 'assert_function' }, + { 'assertIsThread' , 'assert_thread' }, + { 'assertIsUserdata' , 'assert_userdata' }, + + -- type assertions: assertNotIsXXX -> assert_not_is_xxx + { 'assertNotIsNumber' , 'assert_not_is_number' }, + { 'assertNotIsString' , 'assert_not_is_string' }, + { 'assertNotIsTable' , 'assert_not_is_table' }, + { 'assertNotIsBoolean' , 'assert_not_is_boolean' }, + { 'assertNotIsNil' , 'assert_not_is_nil' }, + { 'assertNotIsTrue' , 'assert_not_is_true' }, + { 'assertNotIsFalse' , 'assert_not_is_false' }, + { 'assertNotIsNaN' , 'assert_not_is_nan' }, + { 'assertNotIsInf' , 'assert_not_is_inf' }, + { 'assertNotIsPlusInf' , 'assert_not_plus_inf' }, + { 'assertNotIsMinusInf' , 'assert_not_minus_inf' }, + { 'assertNotIsPlusZero' , 'assert_not_plus_zero' }, + { 'assertNotIsMinusZero' , 'assert_not_minus_zero' }, + { 'assertNotIsFunction' , 'assert_not_is_function' }, + { 'assertNotIsThread' , 'assert_not_is_thread' }, + { 'assertNotIsUserdata' , 'assert_not_is_userdata' }, + + -- type assertions: assertNotIsXXX -> assertNotXxx (luaunit v2 compat) + { 'assertNotIsNumber' , 'assertNotNumber' }, + { 'assertNotIsString' , 'assertNotString' }, + { 'assertNotIsTable' , 'assertNotTable' }, + { 'assertNotIsBoolean' , 'assertNotBoolean' }, + { 'assertNotIsNil' , 'assertNotNil' }, + { 'assertNotIsTrue' , 'assertNotTrue' }, + { 'assertNotIsFalse' , 'assertNotFalse' }, + { 'assertNotIsNaN' , 'assertNotNaN' }, + { 'assertNotIsInf' , 'assertNotInf' }, + { 'assertNotIsPlusInf' , 'assertNotPlusInf' }, + { 'assertNotIsMinusInf' , 'assertNotMinusInf' }, + { 'assertNotIsPlusZero' , 'assertNotPlusZero' }, + { 'assertNotIsMinusZero' , 'assertNotMinusZero' }, + { 'assertNotIsFunction' , 'assertNotFunction' }, + { 'assertNotIsThread' , 'assertNotThread' }, + { 'assertNotIsUserdata' , 'assertNotUserdata' }, + + -- type assertions: assertNotIsXXX -> assert_not_xxx + { 'assertNotIsNumber' , 'assert_not_number' }, + { 'assertNotIsString' , 'assert_not_string' }, + { 'assertNotIsTable' , 'assert_not_table' }, + { 'assertNotIsBoolean' , 'assert_not_boolean' }, + { 'assertNotIsNil' , 'assert_not_nil' }, + { 'assertNotIsTrue' , 'assert_not_true' }, + { 'assertNotIsFalse' , 'assert_not_false' }, + { 'assertNotIsNaN' , 'assert_not_nan' }, + { 'assertNotIsInf' , 'assert_not_inf' }, + { 'assertNotIsPlusInf' , 'assert_not_plus_inf' }, + { 'assertNotIsMinusInf' , 'assert_not_minus_inf' }, + { 'assertNotIsPlusZero' , 'assert_not_plus_zero' }, + { 'assertNotIsMinusZero' , 'assert_not_minus_zero' }, + { 'assertNotIsFunction' , 'assert_not_function' }, + { 'assertNotIsThread' , 'assert_not_thread' }, + { 'assertNotIsUserdata' , 'assert_not_userdata' }, + + -- all assertions with Coroutine duplicate Thread assertions + { 'assertIsThread' , 'assertIsCoroutine' }, + { 'assertIsThread' , 'assertCoroutine' }, + { 'assertIsThread' , 'assert_is_coroutine' }, + { 'assertIsThread' , 'assert_coroutine' }, + { 'assertNotIsThread' , 'assertNotIsCoroutine' }, + { 'assertNotIsThread' , 'assertNotCoroutine' }, + { 'assertNotIsThread' , 'assert_not_is_coroutine' }, + { 'assertNotIsThread' , 'assert_not_coroutine' }, +} + +-- Create all aliases in M +for _,v in ipairs( list_of_funcs ) do + local funcname, alias = v[1], v[2] + M[alias] = M[funcname] + + if EXPORT_ASSERT_TO_GLOBALS then + _G[funcname] = M[funcname] + _G[alias] = M[funcname] + end +end + +---------------------------------------------------------------- +-- +-- Outputters +-- +---------------------------------------------------------------- + +-- A common "base" class for outputters +-- For concepts involved (class inheritance) see http://www.lua.org/pil/16.2.html + +local genericOutput = { __class__ = 'genericOutput' } -- class +local genericOutput_MT = { __index = genericOutput } -- metatable +M.genericOutput = genericOutput -- publish, so that custom classes may derive from it + +function genericOutput.new(runner, default_verbosity) + -- runner is the "parent" object controlling the output, usually a LuaUnit instance + local t = { runner = runner } + if runner then + t.result = runner.result + t.verbosity = runner.verbosity or default_verbosity + t.fname = runner.fname + else + t.verbosity = default_verbosity + end + return setmetatable( t, genericOutput_MT) +end + +-- abstract ("empty") methods +function genericOutput:startSuite() + -- Called once, when the suite is started +end + +function genericOutput:startClass(className) + -- Called each time a new test class is started +end + +function genericOutput:startTest(testName) + -- called each time a new test is started, right before the setUp() + -- the current test status node is already created and available in: self.result.currentNode +end + +function genericOutput:updateStatus(node) + -- called with status failed or error as soon as the error/failure is encountered + -- this method is NOT called for a successful test because a test is marked as successful by default + -- and does not need to be updated +end + +function genericOutput:endTest(node) + -- called when the test is finished, after the tearDown() method +end + +function genericOutput:endClass() + -- called when executing the class is finished, before moving on to the next class of at the end of the test execution +end + +function genericOutput:endSuite() + -- called at the end of the test suite execution +end + + +---------------------------------------------------------------- +-- class TapOutput +---------------------------------------------------------------- + +local TapOutput = genericOutput.new() -- derived class +local TapOutput_MT = { __index = TapOutput } -- metatable +TapOutput.__class__ = 'TapOutput' + + -- For a good reference for TAP format, check: http://testanything.org/tap-specification.html + + function TapOutput.new(runner) + local t = genericOutput.new(runner, M.VERBOSITY_LOW) + return setmetatable( t, TapOutput_MT) + end + function TapOutput:startSuite() + print("1.."..self.result.selectedCount) + print('# Started on '..self.result.startDate) + end + function TapOutput:startClass(className) + if className ~= '[TestFunctions]' then + print('# Starting class: '..className) + end + end + + function TapOutput:updateStatus( node ) + if node:isSkipped() then + io.stdout:write("ok ", self.result.currentTestNumber, "\t# SKIP ", node.msg, "\n" ) + return + end + + io.stdout:write("not ok ", self.result.currentTestNumber, "\t", node.testName, "\n") + if self.verbosity > M.VERBOSITY_LOW then + print( prefixString( '# ', node.msg ) ) + end + if (node:isFailure() or node:isError()) and self.verbosity > M.VERBOSITY_DEFAULT then + print( prefixString( '# ', node.stackTrace ) ) + end + end + + function TapOutput:endTest( node ) + if node:isSuccess() then + io.stdout:write("ok ", self.result.currentTestNumber, "\t", node.testName, "\n") + end + end + + function TapOutput:endSuite() + print( '# '..M.LuaUnit.statusLine( self.result ) ) + return self.result.notSuccessCount + end + + +-- class TapOutput end + +---------------------------------------------------------------- +-- class JUnitOutput +---------------------------------------------------------------- + +-- See directory junitxml for more information about the junit format +local JUnitOutput = genericOutput.new() -- derived class +local JUnitOutput_MT = { __index = JUnitOutput } -- metatable +JUnitOutput.__class__ = 'JUnitOutput' + + function JUnitOutput.new(runner) + local t = genericOutput.new(runner, M.VERBOSITY_LOW) + t.testList = {} + return setmetatable( t, JUnitOutput_MT ) + end + + function JUnitOutput:startSuite() + -- open xml file early to deal with errors + if self.fname == nil then + error('With Junit, an output filename must be supplied with --name!') + end + if string.sub(self.fname,-4) ~= '.xml' then + self.fname = self.fname..'.xml' + end + self.fd = io.open(self.fname, "w") + if self.fd == nil then + error("Could not open file for writing: "..self.fname) + end + + print('# XML output to '..self.fname) + print('# Started on '..self.result.startDate) + end + function JUnitOutput:startClass(className) + if className ~= '[TestFunctions]' then + print('# Starting class: '..className) + end + end + function JUnitOutput:startTest(testName) + print('# Starting test: '..testName) + end + + function JUnitOutput:updateStatus( node ) + if node:isFailure() then + print( '# Failure: ' .. prefixString( '# ', node.msg ):sub(4, nil) ) + -- print('# ' .. node.stackTrace) + elseif node:isError() then + print( '# Error: ' .. prefixString( '# ' , node.msg ):sub(4, nil) ) + -- print('# ' .. node.stackTrace) + end + end + + function JUnitOutput:endSuite() + print( '# '..M.LuaUnit.statusLine(self.result)) + + -- XML file writing + self.fd:write('\n') + self.fd:write('\n') + self.fd:write(string.format( + ' \n', + self.result.runCount, self.result.startIsodate, self.result.duration, self.result.errorCount, self.result.failureCount, self.result.skippedCount )) + self.fd:write(" \n") + self.fd:write(string.format(' \n', _VERSION ) ) + self.fd:write(string.format(' \n', M.VERSION) ) + -- XXX please include system name and version if possible + self.fd:write(" \n") + + for i,node in ipairs(self.result.allTests) do + self.fd:write(string.format(' \n', + node.className, node.testName, node.duration ) ) + if node:isNotSuccess() then + self.fd:write(node:statusXML()) + end + self.fd:write(' \n') + end + + -- Next two lines are needed to validate junit ANT xsd, but really not useful in general: + self.fd:write(' \n') + self.fd:write(' \n') + + self.fd:write(' \n') + self.fd:write('\n') + self.fd:close() + return self.result.notSuccessCount + end + + +-- class TapOutput end + +---------------------------------------------------------------- +-- class TextOutput +---------------------------------------------------------------- + +--[[ Example of other unit-tests suite text output + +-- Python Non verbose: + +For each test: . or F or E + +If some failed tests: + ============== + ERROR / FAILURE: TestName (testfile.testclass) + --------- + Stack trace + + +then -------------- +then "Ran x tests in 0.000s" +then OK or FAILED (failures=1, error=1) + +-- Python Verbose: +testname (filename.classname) ... ok +testname (filename.classname) ... FAIL +testname (filename.classname) ... ERROR + +then -------------- +then "Ran x tests in 0.000s" +then OK or FAILED (failures=1, error=1) + +-- Ruby: +Started + . + Finished in 0.002695 seconds. + + 1 tests, 2 assertions, 0 failures, 0 errors + +-- Ruby: +>> ruby tc_simple_number2.rb +Loaded suite tc_simple_number2 +Started +F.. +Finished in 0.038617 seconds. + + 1) Failure: +test_failure(TestSimpleNumber) [tc_simple_number2.rb:16]: +Adding doesn't work. +<3> expected but was +<4>. + +3 tests, 4 assertions, 1 failures, 0 errors + +-- Java Junit +.......F. +Time: 0,003 +There was 1 failure: +1) testCapacity(junit.samples.VectorTest)junit.framework.AssertionFailedError + at junit.samples.VectorTest.testCapacity(VectorTest.java:87) + at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) + at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) + at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) + +FAILURES!!! +Tests run: 8, Failures: 1, Errors: 0 + + +-- Maven + +# mvn test +------------------------------------------------------- + T E S T S +------------------------------------------------------- +Running math.AdditionTest +Tests run: 2, Failures: 1, Errors: 0, Skipped: 0, Time elapsed: +0.03 sec <<< FAILURE! + +Results : + +Failed tests: + testLireSymbole(math.AdditionTest) + +Tests run: 2, Failures: 1, Errors: 0, Skipped: 0 + + +-- LuaUnit +---- non verbose +* display . or F or E when running tests +---- verbose +* display test name + ok/fail +---- +* blank line +* number) ERROR or FAILURE: TestName + Stack trace +* blank line +* number) ERROR or FAILURE: TestName + Stack trace + +then -------------- +then "Ran x tests in 0.000s (%d not selected, %d skipped)" +then OK or FAILED (failures=1, error=1) + + +]] + +local TextOutput = genericOutput.new() -- derived class +local TextOutput_MT = { __index = TextOutput } -- metatable +TextOutput.__class__ = 'TextOutput' + + function TextOutput.new(runner) + local t = genericOutput.new(runner, M.VERBOSITY_DEFAULT) + t.errorList = {} + return setmetatable( t, TextOutput_MT ) + end + + function TextOutput:startSuite() + if self.verbosity > M.VERBOSITY_DEFAULT then + print( 'Started on '.. self.result.startDate ) + end + end + + function TextOutput:startTest(testName) + if self.verbosity > M.VERBOSITY_DEFAULT then + io.stdout:write( " ", self.result.currentNode.testName, " ... " ) + end + end + + function TextOutput:endTest( node ) + if node:isSuccess() then + if self.verbosity > M.VERBOSITY_DEFAULT then + io.stdout:write("Ok\n") + else + io.stdout:write(".") + io.stdout:flush() + end + else + if self.verbosity > M.VERBOSITY_DEFAULT then + print( node.status ) + print( node.msg ) + --[[ + -- find out when to do this: + if self.verbosity > M.VERBOSITY_DEFAULT then + print( node.stackTrace ) + end + ]] + else + -- write only the first character of status E, F or S + io.stdout:write(string.sub(node.status, 1, 1)) + io.stdout:flush() + end + end + end + + function TextOutput:displayOneFailedTest( index, fail ) + print(index..") "..fail.testName ) + print( fail.msg ) + print( fail.stackTrace ) + print() + end + + function TextOutput:displayErroredTests() + if #self.result.errorTests ~= 0 then + print("Tests with errors:") + print("------------------") + for i, v in ipairs(self.result.errorTests) do + self:displayOneFailedTest(i, v) + end + end + end + + function TextOutput:displayFailedTests() + if #self.result.failedTests ~= 0 then + print("Failed tests:") + print("-------------") + for i, v in ipairs(self.result.failedTests) do + self:displayOneFailedTest(i, v) + end + end + end + + function TextOutput:endSuite() + if self.verbosity > M.VERBOSITY_DEFAULT then + print("=========================================================") + else + print() + end + self:displayErroredTests() + self:displayFailedTests() + print( M.LuaUnit.statusLine( self.result ) ) + if self.result.notSuccessCount == 0 then + print('OK') + end + end + +-- class TextOutput end + + +---------------------------------------------------------------- +-- class NilOutput +---------------------------------------------------------------- + +local function nopCallable() + --print(42) + return nopCallable +end + +local NilOutput = { __class__ = 'NilOuptut' } -- class +local NilOutput_MT = { __index = nopCallable } -- metatable + +function NilOutput.new(runner) + return setmetatable( { __class__ = 'NilOutput' }, NilOutput_MT ) +end + +---------------------------------------------------------------- +-- +-- class LuaUnit +-- +---------------------------------------------------------------- + +M.LuaUnit = { + outputType = TextOutput, + verbosity = M.VERBOSITY_DEFAULT, + __class__ = 'LuaUnit', + instances = {} +} +local LuaUnit_MT = { __index = M.LuaUnit } + +if EXPORT_ASSERT_TO_GLOBALS then + LuaUnit = M.LuaUnit +end + + function M.LuaUnit.new() + local newInstance = setmetatable( {}, LuaUnit_MT ) + return newInstance + end + + -----------------[[ Utility methods ]]--------------------- + + function M.LuaUnit.asFunction(aObject) + -- return "aObject" if it is a function, and nil otherwise + if 'function' == type(aObject) then + return aObject + end + end + + function M.LuaUnit.splitClassMethod(someName) + --[[ + Return a pair of className, methodName strings for a name in the form + "class.method". If no class part (or separator) is found, will return + nil, someName instead (the latter being unchanged). + + This convention thus also replaces the older isClassMethod() test: + You just have to check for a non-nil className (return) value. + ]] + local separator = string.find(someName, '.', 1, true) + if separator then + return someName:sub(1, separator - 1), someName:sub(separator + 1) + end + return nil, someName + end + + function M.LuaUnit.isMethodTestName( s ) + -- return true is the name matches the name of a test method + -- default rule is that is starts with 'Test' or with 'test' + return string.sub(s, 1, 4):lower() == 'test' + end + + function M.LuaUnit.isTestName( s ) + -- return true is the name matches the name of a test + -- default rule is that is starts with 'Test' or with 'test' + return string.sub(s, 1, 4):lower() == 'test' + end + + function M.LuaUnit.collectTests() + -- return a list of all test names in the global namespace + -- that match LuaUnit.isTestName + + local testNames = {} + for k, _ in pairs(_G) do + if type(k) == "string" and M.LuaUnit.isTestName( k ) then + table.insert( testNames , k ) + end + end + table.sort( testNames ) + return testNames + end + + function M.LuaUnit.parseCmdLine( cmdLine ) + -- parse the command line + -- Supported command line parameters: + -- --verbose, -v: increase verbosity + -- --quiet, -q: silence output + -- --error, -e: treat errors as fatal (quit program) + -- --output, -o, + name: select output type + -- --pattern, -p, + pattern: run test matching pattern, may be repeated + -- --exclude, -x, + pattern: run test not matching pattern, may be repeated + -- --shuffle, -s, : shuffle tests before reunning them + -- --name, -n, + fname: name of output file for junit, default to stdout + -- --repeat, -r, + num: number of times to execute each test + -- [testnames, ...]: run selected test names + -- + -- Returns a table with the following fields: + -- verbosity: nil, M.VERBOSITY_DEFAULT, M.VERBOSITY_QUIET, M.VERBOSITY_VERBOSE + -- output: nil, 'tap', 'junit', 'text', 'nil' + -- testNames: nil or a list of test names to run + -- exeRepeat: num or 1 + -- pattern: nil or a list of patterns + -- exclude: nil or a list of patterns + + local result, state = {}, nil + local SET_OUTPUT = 1 + local SET_PATTERN = 2 + local SET_EXCLUDE = 3 + local SET_FNAME = 4 + local SET_REPEAT = 5 + + if cmdLine == nil then + return result + end + + local function parseOption( option ) + if option == '--help' or option == '-h' then + result['help'] = true + return + elseif option == '--version' then + result['version'] = true + return + elseif option == '--verbose' or option == '-v' then + result['verbosity'] = M.VERBOSITY_VERBOSE + return + elseif option == '--quiet' or option == '-q' then + result['verbosity'] = M.VERBOSITY_QUIET + return + elseif option == '--error' or option == '-e' then + result['quitOnError'] = true + return + elseif option == '--failure' or option == '-f' then + result['quitOnFailure'] = true + return + elseif option == '--shuffle' or option == '-s' then + result['shuffle'] = true + return + elseif option == '--output' or option == '-o' then + state = SET_OUTPUT + return state + elseif option == '--name' or option == '-n' then + state = SET_FNAME + return state + elseif option == '--repeat' or option == '-r' then + state = SET_REPEAT + return state + elseif option == '--pattern' or option == '-p' then + state = SET_PATTERN + return state + elseif option == '--exclude' or option == '-x' then + state = SET_EXCLUDE + return state + end + error('Unknown option: '..option,3) + end + + local function setArg( cmdArg, state ) + if state == SET_OUTPUT then + result['output'] = cmdArg + return + elseif state == SET_FNAME then + result['fname'] = cmdArg + return + elseif state == SET_REPEAT then + result['exeRepeat'] = tonumber(cmdArg) + or error('Malformed -r argument: '..cmdArg) + return + elseif state == SET_PATTERN then + if result['pattern'] then + table.insert( result['pattern'], cmdArg ) + else + result['pattern'] = { cmdArg } + end + return + elseif state == SET_EXCLUDE then + local notArg = '!'..cmdArg + if result['pattern'] then + table.insert( result['pattern'], notArg ) + else + result['pattern'] = { notArg } + end + return + end + error('Unknown parse state: '.. state) + end + + + for i, cmdArg in ipairs(cmdLine) do + if state ~= nil then + setArg( cmdArg, state, result ) + state = nil + else + if cmdArg:sub(1,1) == '-' then + state = parseOption( cmdArg ) + else + if result['testNames'] then + table.insert( result['testNames'], cmdArg ) + else + result['testNames'] = { cmdArg } + end + end + end + end + + if result['help'] then + M.LuaUnit.help() + end + + if result['version'] then + M.LuaUnit.version() + end + + if state ~= nil then + error('Missing argument after '..cmdLine[ #cmdLine ],2 ) + end + + return result + end + + function M.LuaUnit.help() + print(M.USAGE) + os.exit(0) + end + + function M.LuaUnit.version() + print('LuaUnit v'..M.VERSION..' by Philippe Fremy ') + os.exit(0) + end + +---------------------------------------------------------------- +-- class NodeStatus +---------------------------------------------------------------- + + local NodeStatus = { __class__ = 'NodeStatus' } -- class + local NodeStatus_MT = { __index = NodeStatus } -- metatable + M.NodeStatus = NodeStatus + + -- values of status + NodeStatus.SUCCESS = 'SUCCESS' + NodeStatus.SKIP = 'SKIP' + NodeStatus.FAIL = 'FAIL' + NodeStatus.ERROR = 'ERROR' + + function NodeStatus.new( number, testName, className ) + -- default constructor, test are PASS by default + local t = { number = number, testName = testName, className = className } + setmetatable( t, NodeStatus_MT ) + t:success() + return t + end + + function NodeStatus:success() + self.status = self.SUCCESS + -- useless because lua does this for us, but it helps me remembering the relevant field names + self.msg = nil + self.stackTrace = nil + end + + function NodeStatus:skip(msg) + self.status = self.SKIP + self.msg = msg + self.stackTrace = nil + end + + function NodeStatus:fail(msg, stackTrace) + self.status = self.FAIL + self.msg = msg + self.stackTrace = stackTrace + end + + function NodeStatus:error(msg, stackTrace) + self.status = self.ERROR + self.msg = msg + self.stackTrace = stackTrace + end + + function NodeStatus:isSuccess() + return self.status == NodeStatus.SUCCESS + end + + function NodeStatus:isNotSuccess() + -- Return true if node is either failure or error or skip + return (self.status == NodeStatus.FAIL or self.status == NodeStatus.ERROR or self.status == NodeStatus.SKIP) + end + + function NodeStatus:isSkipped() + return self.status == NodeStatus.SKIP + end + + function NodeStatus:isFailure() + return self.status == NodeStatus.FAIL + end + + function NodeStatus:isError() + return self.status == NodeStatus.ERROR + end + + function NodeStatus:statusXML() + if self:isError() then + return table.concat( + {' \n', + ' \n'}) + elseif self:isFailure() then + return table.concat( + {' \n', + ' \n'}) + elseif self:isSkipped() then + return table.concat({' ', xmlEscape(self.msg),'\n' } ) + end + return ' \n' -- (not XSD-compliant! normally shouldn't get here) + end + + --------------[[ Output methods ]]------------------------- + + local function conditional_plural(number, singular) + -- returns a grammatically well-formed string "%d " + local suffix = '' + if number ~= 1 then -- use plural + suffix = (singular:sub(-2) == 'ss') and 'es' or 's' + end + return string.format('%d %s%s', number, singular, suffix) + end + + function M.LuaUnit.statusLine(result) + -- return status line string according to results + local s = { + string.format('Ran %d tests in %0.3f seconds', + result.runCount, result.duration), + conditional_plural(result.successCount, 'success'), + } + if result.notSuccessCount > 0 then + if result.failureCount > 0 then + table.insert(s, conditional_plural(result.failureCount, 'failure')) + end + if result.errorCount > 0 then + table.insert(s, conditional_plural(result.errorCount, 'error')) + end + else + table.insert(s, '0 failures') + end + if result.skippedCount > 0 then + table.insert(s, string.format("%d skipped", result.skippedCount)) + end + if result.nonSelectedCount > 0 then + table.insert(s, string.format("%d non-selected", result.nonSelectedCount)) + end + return table.concat(s, ', ') + end + + function M.LuaUnit:startSuite(selectedCount, nonSelectedCount) + self.result = { + selectedCount = selectedCount, + nonSelectedCount = nonSelectedCount, + successCount = 0, + runCount = 0, + currentTestNumber = 0, + currentClassName = "", + currentNode = nil, + suiteStarted = true, + startTime = os.clock(), + startDate = os.date(os.getenv('LUAUNIT_DATEFMT')), + startIsodate = os.date('%Y-%m-%dT%H:%M:%S'), + patternIncludeFilter = self.patternIncludeFilter, + + -- list of test node status + allTests = {}, + failedTests = {}, + errorTests = {}, + skippedTests = {}, + + failureCount = 0, + errorCount = 0, + notSuccessCount = 0, + skippedCount = 0, + } + + self.outputType = self.outputType or TextOutput + self.output = self.outputType.new(self) + self.output:startSuite() + end + + function M.LuaUnit:startClass( className, classInstance ) + self.result.currentClassName = className + self.output:startClass( className ) + self:setupClass( className, classInstance ) + end + + function M.LuaUnit:startTest( testName ) + self.result.currentTestNumber = self.result.currentTestNumber + 1 + self.result.runCount = self.result.runCount + 1 + self.result.currentNode = NodeStatus.new( + self.result.currentTestNumber, + testName, + self.result.currentClassName + ) + self.result.currentNode.startTime = os.clock() + table.insert( self.result.allTests, self.result.currentNode ) + self.output:startTest( testName ) + end + + function M.LuaUnit:updateStatus( err ) + -- "err" is expected to be a table / result from protectedCall() + if err.status == NodeStatus.SUCCESS then + return + end + + local node = self.result.currentNode + + --[[ As a first approach, we will report only one error or one failure for one test. + + However, we can have the case where the test is in failure, and the teardown is in error. + In such case, it's a good idea to report both a failure and an error in the test suite. This is + what Python unittest does for example. However, it mixes up counts so need to be handled carefully: for + example, there could be more (failures + errors) count that tests. What happens to the current node ? + + We will do this more intelligent version later. + ]] + + -- if the node is already in failure/error, just don't report the new error (see above) + if node.status ~= NodeStatus.SUCCESS then + return + end + + if err.status == NodeStatus.FAIL then + node:fail( err.msg, err.trace ) + table.insert( self.result.failedTests, node ) + elseif err.status == NodeStatus.ERROR then + node:error( err.msg, err.trace ) + table.insert( self.result.errorTests, node ) + elseif err.status == NodeStatus.SKIP then + node:skip( err.msg ) + table.insert( self.result.skippedTests, node ) + else + error('No such status: ' .. prettystr(err.status)) + end + + self.output:updateStatus( node ) + end + + function M.LuaUnit:endTest() + local node = self.result.currentNode + -- print( 'endTest() '..prettystr(node)) + -- print( 'endTest() '..prettystr(node:isNotSuccess())) + node.duration = os.clock() - node.startTime + node.startTime = nil + self.output:endTest( node ) + + if node:isSuccess() then + self.result.successCount = self.result.successCount + 1 + elseif node:isError() then + if self.quitOnError or self.quitOnFailure then + -- Runtime error - abort test execution as requested by + -- "--error" option. This is done by setting a special + -- flag that gets handled in internalRunSuiteByInstances(). + print("\nERROR during LuaUnit test execution:\n" .. node.msg) + self.result.aborted = true + end + elseif node:isFailure() then + if self.quitOnFailure then + -- Failure - abort test execution as requested by + -- "--failure" option. This is done by setting a special + -- flag that gets handled in internalRunSuiteByInstances(). + print("\nFailure during LuaUnit test execution:\n" .. node.msg) + self.result.aborted = true + end + elseif node:isSkipped() then + self.result.runCount = self.result.runCount - 1 + else + error('No such node status: ' .. prettystr(node.status)) + end + self.result.currentNode = nil + end + + function M.LuaUnit:endClass() + self:teardownClass( self.lastClassName, self.lastClassInstance ) + self.output:endClass() + end + + function M.LuaUnit:endSuite() + if self.result.suiteStarted == false then + error('LuaUnit:endSuite() -- suite was already ended' ) + end + self.result.duration = os.clock()-self.result.startTime + self.result.suiteStarted = false + + -- Expose test counts for outputter's endSuite(). This could be managed + -- internally instead by using the length of the lists of failed tests + -- but unit tests rely on these fields being present. + self.result.failureCount = #self.result.failedTests + self.result.errorCount = #self.result.errorTests + self.result.notSuccessCount = self.result.failureCount + self.result.errorCount + self.result.skippedCount = #self.result.skippedTests + + self.output:endSuite() + end + + function M.LuaUnit:setOutputType(outputType, fname) + -- Configures LuaUnit runner output + -- outputType is one of: NIL, TAP, JUNIT, TEXT + -- when outputType is junit, the additional argument fname is used to set the name of junit output file + -- for other formats, fname is ignored + if outputType:upper() == "NIL" then + self.outputType = NilOutput + return + end + if outputType:upper() == "TAP" then + self.outputType = TapOutput + return + end + if outputType:upper() == "JUNIT" then + self.outputType = JUnitOutput + if fname then + self.fname = fname + end + return + end + if outputType:upper() == "TEXT" then + self.outputType = TextOutput + return + end + error( 'No such format: '..outputType,2) + end + + --------------[[ Runner ]]----------------- + + function M.LuaUnit:protectedCall(classInstance, methodInstance, prettyFuncName) + -- if classInstance is nil, this is just a function call + -- else, it's method of a class being called. + + local function err_handler(e) + -- transform error into a table, adding the traceback information + return { + status = NodeStatus.ERROR, + msg = e, + trace = string.sub(debug.traceback("", 1), 2) + } + end + + local ok, err + if classInstance then + -- stupid Lua < 5.2 does not allow xpcall with arguments so let's use a workaround + ok, err = xpcall( function () methodInstance(classInstance) end, err_handler ) + else + ok, err = xpcall( function () methodInstance() end, err_handler ) + end + if ok then + return {status = NodeStatus.SUCCESS} + end + -- print('ok="'..prettystr(ok)..'" err="'..prettystr(err)..'"') + + local iter_msg + iter_msg = self.exeRepeat and 'iteration '..self.currentCount + + err.msg, err.status = M.adjust_err_msg_with_iter( err.msg, iter_msg ) + + if err.status == NodeStatus.SUCCESS or err.status == NodeStatus.SKIP then + err.trace = nil + return err + end + + -- reformat / improve the stack trace + if prettyFuncName then -- we do have the real method name + err.trace = err.trace:gsub("in (%a+) 'methodInstance'", "in %1 '"..prettyFuncName.."'") + end + if STRIP_LUAUNIT_FROM_STACKTRACE then + err.trace = stripLuaunitTrace2(err.trace, err.msg) + end + + return err -- return the error "object" (table) + end + + + function M.LuaUnit:execOneFunction(className, methodName, classInstance, methodInstance) + -- When executing a test function, className and classInstance must be nil + -- When executing a class method, all parameters must be set + + if type(methodInstance) ~= 'function' then + self:unregisterSuite() + error( tostring(methodName)..' must be a function, not '..type(methodInstance)) + end + + local prettyFuncName + if className == nil then + className = '[TestFunctions]' + prettyFuncName = methodName + else + prettyFuncName = className..'.'..methodName + end + + if self.lastClassName ~= className then + if self.lastClassName ~= nil then + self:endClass() + end + self:startClass( className, classInstance ) + self.lastClassName = className + self.lastClassInstance = classInstance + end + + self:startTest(prettyFuncName) + + local node = self.result.currentNode + for iter_n = 1, self.exeRepeat or 1 do + if node:isNotSuccess() then + break + end + self.currentCount = iter_n + + -- run setUp first (if any) + if classInstance then + local func = self.asFunction( classInstance.setUp ) or + self.asFunction( classInstance.Setup ) or + self.asFunction( classInstance.setup ) or + self.asFunction( classInstance.SetUp ) + if func then + self:updateStatus(self:protectedCall(classInstance, func, className..'.setUp')) + end + end + + -- run testMethod() + if node:isSuccess() then + self:updateStatus(self:protectedCall(classInstance, methodInstance, prettyFuncName)) + end + + -- lastly, run tearDown (if any) + if classInstance then + local func = self.asFunction( classInstance.tearDown ) or + self.asFunction( classInstance.TearDown ) or + self.asFunction( classInstance.teardown ) or + self.asFunction( classInstance.Teardown ) + if func then + self:updateStatus(self:protectedCall(classInstance, func, className..'.tearDown')) + end + end + end + + self:endTest() + end + + function M.LuaUnit.expandOneClass( result, className, classInstance ) + --[[ + Input: a list of { name, instance }, a class name, a class instance + Ouptut: modify result to add all test method instance in the form: + { className.methodName, classInstance } + ]] + for methodName, methodInstance in sortedPairs(classInstance) do + if M.LuaUnit.asFunction(methodInstance) and M.LuaUnit.isMethodTestName( methodName ) then + table.insert( result, { className..'.'..methodName, classInstance } ) + end + end + end + + function M.LuaUnit.expandClasses( listOfNameAndInst ) + --[[ + -- expand all classes (provided as {className, classInstance}) to a list of {className.methodName, classInstance} + -- functions and methods remain untouched + + Input: a list of { name, instance } + + Output: + * { function name, function instance } : do nothing + * { class.method name, class instance }: do nothing + * { class name, class instance } : add all method names in the form of (className.methodName, classInstance) + ]] + local result = {} + + for i,v in ipairs( listOfNameAndInst ) do + local name, instance = v[1], v[2] + if M.LuaUnit.asFunction(instance) then + table.insert( result, { name, instance } ) + else + if type(instance) ~= 'table' then + error( 'Instance must be a table or a function, not a '..type(instance)..' with value '..prettystr(instance)) + end + local className, methodName = M.LuaUnit.splitClassMethod( name ) + if className then + local methodInstance = instance[methodName] + if methodInstance == nil then + error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) ) + end + table.insert( result, { name, instance } ) + else + M.LuaUnit.expandOneClass( result, name, instance ) + end + end + end + + return result + end + + function M.LuaUnit.applyPatternFilter( patternIncFilter, listOfNameAndInst ) + local included, excluded = {}, {} + for i, v in ipairs( listOfNameAndInst ) do + -- local name, instance = v[1], v[2] + if patternFilter( patternIncFilter, v[1] ) then + table.insert( included, v ) + else + table.insert( excluded, v ) + end + end + return included, excluded + end + + local function getKeyInListWithGlobalFallback( key, listOfNameAndInst ) + local result = nil + for i,v in ipairs( listOfNameAndInst ) do + if(listOfNameAndInst[i][1] == key) then + result = listOfNameAndInst[i][2] + break + end + end + if(not M.LuaUnit.asFunction( result ) ) then + result = _G[key] + end + return result + end + + function M.LuaUnit:setupSuite( listOfNameAndInst ) + local setupSuite = getKeyInListWithGlobalFallback("setupSuite", listOfNameAndInst) + if self.asFunction( setupSuite ) then + self:updateStatus( self:protectedCall( nil, setupSuite, 'setupSuite' ) ) + end + end + + function M.LuaUnit:teardownSuite(listOfNameAndInst) + local teardownSuite = getKeyInListWithGlobalFallback("teardownSuite", listOfNameAndInst) + if self.asFunction( teardownSuite ) then + self:updateStatus( self:protectedCall( nil, teardownSuite, 'teardownSuite') ) + end + end + + function M.LuaUnit:setupClass( className, instance ) + if type( instance ) == 'table' and self.asFunction( instance.setupClass ) then + self:updateStatus( self:protectedCall( instance, instance.setupClass, className..'.setupClass' ) ) + end + end + + function M.LuaUnit:teardownClass( className, instance ) + if type( instance ) == 'table' and self.asFunction( instance.teardownClass ) then + self:updateStatus( self:protectedCall( instance, instance.teardownClass, className..'.teardownClass' ) ) + end + end + + function M.LuaUnit:internalRunSuiteByInstances( listOfNameAndInst ) + --[[ Run an explicit list of tests. Each item of the list must be one of: + * { function name, function instance } + * { class name, class instance } + * { class.method name, class instance } + + This function is internal to LuaUnit. The official API to perform this action is runSuiteByInstances() + ]] + + local expandedList = self.expandClasses( listOfNameAndInst ) + if self.shuffle then + randomizeTable( expandedList ) + end + local filteredList, filteredOutList = self.applyPatternFilter( + self.patternIncludeFilter, expandedList ) + + self:startSuite( #filteredList, #filteredOutList ) + self:setupSuite( listOfNameAndInst ) + + for i,v in ipairs( filteredList ) do + local name, instance = v[1], v[2] + if M.LuaUnit.asFunction(instance) then + self:execOneFunction( nil, name, nil, instance ) + else + -- expandClasses() should have already taken care of sanitizing the input + assert( type(instance) == 'table' ) + local className, methodName = M.LuaUnit.splitClassMethod( name ) + assert( className ~= nil ) + local methodInstance = instance[methodName] + assert(methodInstance ~= nil) + self:execOneFunction( className, methodName, instance, methodInstance ) + end + if self.result.aborted then + break -- "--error" or "--failure" option triggered + end + end + + if self.lastClassName ~= nil then + self:endClass() + end + + self:teardownSuite( listOfNameAndInst ) + self:endSuite() + + if self.result.aborted then + print("LuaUnit ABORTED (as requested by --error or --failure option)") + self:unregisterSuite() + os.exit(-2) + end + end + + function M.LuaUnit:internalRunSuiteByNames( listOfName ) + --[[ Run LuaUnit with a list of generic names, coming either from command-line or from global + namespace analysis. Convert the list into a list of (name, valid instances (table or function)) + and calls internalRunSuiteByInstances. + ]] + + local instanceName, instance + local listOfNameAndInst = {} + + for i,name in ipairs( listOfName ) do + local className, methodName = M.LuaUnit.splitClassMethod( name ) + if className then + instanceName = className + instance = _G[instanceName] + + if instance == nil then + self:unregisterSuite() + error( "No such name in global space: "..instanceName ) + end + + if type(instance) ~= 'table' then + self:unregisterSuite() + error( 'Instance of '..instanceName..' must be a table, not '..type(instance)) + end + + local methodInstance = instance[methodName] + if methodInstance == nil then + self:unregisterSuite() + error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) ) + end + + else + -- for functions and classes + instanceName = name + instance = _G[instanceName] + end + + if instance == nil then + self:unregisterSuite() + error( "No such name in global space: "..instanceName ) + end + + if (type(instance) ~= 'table' and type(instance) ~= 'function') then + self:unregisterSuite() + error( 'Name must match a function or a table: '..instanceName ) + end + + table.insert( listOfNameAndInst, { name, instance } ) + end + + self:internalRunSuiteByInstances( listOfNameAndInst ) + end + + function M.LuaUnit.run(...) + -- Run some specific test classes. + -- If no arguments are passed, run the class names specified on the + -- command line. If no class name is specified on the command line + -- run all classes whose name starts with 'Test' + -- + -- If arguments are passed, they must be strings of the class names + -- that you want to run or generic command line arguments (-o, -p, -v, ...) + local runner = M.LuaUnit.new() + return runner:runSuite(...) + end + + function M.LuaUnit:registerSuite() + -- register the current instance into our global array of instances + -- print('-> Register suite') + M.LuaUnit.instances[ #M.LuaUnit.instances+1 ] = self + end + + function M.unregisterCurrentSuite() + -- force unregister the last registered suite + table.remove(M.LuaUnit.instances, #M.LuaUnit.instances) + end + + function M.LuaUnit:unregisterSuite() + -- print('<- Unregister suite') + -- remove our current instqances from the global array of instances + local instanceIdx = nil + for i, instance in ipairs(M.LuaUnit.instances) do + if instance == self then + instanceIdx = i + break + end + end + + if instanceIdx ~= nil then + table.remove(M.LuaUnit.instances, instanceIdx) + -- print('Unregister done') + end + + end + + function M.LuaUnit:initFromArguments( ... ) + --[[Parses all arguments from either command-line or direct call and set internal + flags of LuaUnit runner according to it. + + Return the list of names which were possibly passed on the command-line or as arguments + ]] + local args = {...} + if type(args[1]) == 'table' and args[1].__class__ == 'LuaUnit' then + -- run was called with the syntax M.LuaUnit:runSuite() + -- we support both M.LuaUnit.run() and M.LuaUnit:run() + -- strip out the first argument self to make it a command-line argument list + table.remove(args,1) + end + + if #args == 0 then + args = cmdline_argv + end + + local options = pcall_or_abort( M.LuaUnit.parseCmdLine, args ) + + -- We expect these option fields to be either `nil` or contain + -- valid values, so it's safe to always copy them directly. + self.verbosity = options.verbosity + self.quitOnError = options.quitOnError + self.quitOnFailure = options.quitOnFailure + + self.exeRepeat = options.exeRepeat + self.patternIncludeFilter = options.pattern + self.shuffle = options.shuffle + + options.output = options.output or os.getenv('LUAUNIT_OUTPUT') + options.fname = options.fname or os.getenv('LUAUNIT_JUNIT_FNAME') + + if options.output then + if options.output:lower() == 'junit' and options.fname == nil then + print('With junit output, a filename must be supplied with -n or --name') + os.exit(-1) + end + pcall_or_abort(self.setOutputType, self, options.output, options.fname) + end + + return options.testNames + end + + function M.LuaUnit:runSuite( ... ) + testNames = self:initFromArguments(...) + self:registerSuite() + self:internalRunSuiteByNames( testNames or M.LuaUnit.collectTests() ) + self:unregisterSuite() + return self.result.notSuccessCount + end + + function M.LuaUnit:runSuiteByInstances( listOfNameAndInst, commandLineArguments ) + --[[ + Run all test functions or tables provided as input. + + Input: a list of { name, instance } + instance can either be a function or a table containing test functions starting with the prefix "test" + + return the number of failures and errors, 0 meaning success + ]] + -- parse the command-line arguments + testNames = self:initFromArguments( commandLineArguments ) + self:registerSuite() + self:internalRunSuiteByInstances( listOfNameAndInst ) + self:unregisterSuite() + return self.result.notSuccessCount + end + + + +-- class LuaUnit + +-- For compatbility with LuaUnit v2 +M.run = M.LuaUnit.run +M.Run = M.LuaUnit.run + +function M:setVerbosity( verbosity ) + -- set the verbosity value (as integer) + M.LuaUnit.verbosity = verbosity +end +M.set_verbosity = M.setVerbosity +M.SetVerbosity = M.setVerbosity + + +return M + diff --git a/src/Fable.AST/Plugins.fs b/src/Fable.AST/Plugins.fs index b6f5ce0a04..da749efeb5 100644 --- a/src/Fable.AST/Plugins.fs +++ b/src/Fable.AST/Plugins.fs @@ -16,6 +16,7 @@ type Language = | Php | Dart | Rust + | Lua override this.ToString() = match this with @@ -25,6 +26,7 @@ type Language = | Php -> "PHP" | Dart -> "Dart" | Rust -> "Rust" + | Lua -> "Lua" type CompilerOptions = { diff --git a/src/Fable.Build/Fable.Build.fsproj b/src/Fable.Build/Fable.Build.fsproj index 21aedbefd0..52ce78143a 100644 --- a/src/Fable.Build/Fable.Build.fsproj +++ b/src/Fable.Build/Fable.Build.fsproj @@ -16,6 +16,7 @@ + @@ -27,6 +28,7 @@ + @@ -36,6 +38,7 @@ + diff --git a/src/Fable.Build/FableLibrary/Lua.fs b/src/Fable.Build/FableLibrary/Lua.fs new file mode 100644 index 0000000000..4873c426a3 --- /dev/null +++ b/src/Fable.Build/FableLibrary/Lua.fs @@ -0,0 +1,18 @@ +namespace Build.FableLibrary + +open System.IO +open Fake.IO + +type BuildFableLibraryLua() = + inherit + BuildFableLibrary( + "lua", + Path.Combine("src", "fable-library-lua"), + Path.Combine("src", "fable-library-lua"), + Path.Combine("temp", "fable-library-lua"), + Path.Combine("temp", "fable-library-lua"), + Path.Combine(".", "temp", "fable-library-lua") + ) + + override this.CopyStage() = + Directory.GetFiles(this.LibraryDir, "*") |> Shell.copyFiles this.BuildDir diff --git a/src/Fable.Build/Main.fs b/src/Fable.Build/Main.fs index db0ae5a746..19874145ee 100644 --- a/src/Fable.Build/Main.fs +++ b/src/Fable.Build/Main.fs @@ -19,6 +19,7 @@ Available commands: --python Build fable-library for Python --dart Build fable-library for Dart --rust Build fable-library for Rust + --lua Build fable-library for Lua quicktest Watch for changes and re-run the quicktest This is useful to work on a feature in an isolated @@ -30,6 +31,7 @@ Available commands: python Run for Python dart Run for Dart rust Run for Rust + lua Run for Lua Options: --skip-fable-library Skip building fable-library if folder already exists @@ -41,6 +43,7 @@ Available commands: python Run the tests for Python dart Run the tests for Dart rust Run the tests for Rust + lua Run the tests for Lua integration Run the integration test suite standalone Tests the standalone version of Fable (Fable running on top of Node.js) @@ -119,6 +122,7 @@ let main argv = | "--python" :: _ -> BuildFableLibraryPython().Run() | "--dart" :: _ -> BuildFableLibraryDart().Run() | "--rust" :: _ -> BuildFableLibraryRust().Run() + | "--lua" :: _ -> BuildFableLibraryLua().Run() | _ -> printHelp () | "test" :: args -> match args with @@ -132,6 +136,7 @@ let main argv = // This test is using quicktest project for now, // because it can't compile (yet?) the Main JavaScript tests | "compiler-js" :: _ -> Test.CompilerJs.handle args + | "lua" :: args -> Test.Lua.handle args | _ -> printHelp () | "quicktest" :: args -> match args with @@ -140,6 +145,7 @@ let main argv = | "python" :: _ -> Quicktest.Python.handle args | "dart" :: _ -> Quicktest.Dart.handle args | "rust" :: _ -> Quicktest.Rust.handle args + | "lua" :: _ -> Quicktest.Lua.handle args | _ -> printHelp () | "standalone" :: args -> Standalone.handle args | "compiler-js" :: args -> CompilerJs.handle args diff --git a/src/Fable.Build/Quicktest/Lua.fs b/src/Fable.Build/Quicktest/Lua.fs new file mode 100644 index 0000000000..171abe509d --- /dev/null +++ b/src/Fable.Build/Quicktest/Lua.fs @@ -0,0 +1,15 @@ +module Build.Quicktest.Lua + +open Build.FableLibrary +open Build.Quicktest.Core +//TODO: Quicktest still needs to be actually written. +let handle (args: string list) = + genericQuicktest + { + Language = "lua" + FableLibBuilder = BuildFableLibraryLua() + ProjectDir = "src/quicktest-lua" + Extension = ".lua" + RunMode = RunScript + } + args diff --git a/src/Fable.Build/Test/Lua.fs b/src/Fable.Build/Test/Lua.fs new file mode 100644 index 0000000000..d0d2a5821c --- /dev/null +++ b/src/Fable.Build/Test/Lua.fs @@ -0,0 +1,63 @@ +module Build.Test.Lua + +open Build.FableLibrary +open System.IO +open Build.Utils +open BlackFox.CommandLine +open SimpleExec +open Fake.IO + +let private buildDir = Path.Resolve("temp", "tests", "Lua") +let private testsFolder = Path.Resolve("tests", "Lua") +let private testsFsprojFolder = Path.Resolve("tests", "Lua") + +let handle (args: string list) = + let skipFableLibrary = args |> List.contains "--skip-fable-library" + let isWatch = args |> List.contains "--watch" + let noDotnet = args |> List.contains "--no-dotnet" + + BuildFableLibraryLua().Run(skipFableLibrary) + + Directory.clean buildDir + + Directory.GetFiles(testsFolder, "*") |> Seq.iter (Shell.copyFile buildDir) + + Directory.GetFiles(testsFolder, "*.lua") |> Seq.iter (Shell.copyFile buildDir) + + let testCmd = $"lua test {buildDir}/main.lua" + + let fableArgs = + CmdLine.concat + [ + CmdLine.empty + |> CmdLine.appendRaw testsFsprojFolder + |> CmdLine.appendPrefix "--outDir" (buildDir "src") + |> CmdLine.appendPrefix "--lang" "lua" + |> CmdLine.appendPrefix "--exclude" "Fable.Core" + |> CmdLine.appendRaw "--noCache" + + if isWatch then + CmdLine.empty + |> CmdLine.appendRaw "--watch" + |> CmdLine.appendRaw "--runWatch" + |> CmdLine.appendRaw testCmd + else + CmdLine.empty |> CmdLine.appendRaw "--run" |> CmdLine.appendRaw testCmd + ] + + if isWatch then + Async.Parallel + [ + if not noDotnet then + Command.RunAsync("dotnet", "watch test -c Release", workingDirectory = testsFsprojFolder) + |> Async.AwaitTask + + Command.WatchFableAsync(fableArgs, workingDirectory = buildDir) + |> Async.AwaitTask + ] + |> Async.RunSynchronously + |> ignore + else + Command.Run("dotnet", "test -c Release", workingDirectory = testsFsprojFolder) + + Command.Fable(fableArgs, workingDirectory = buildDir) diff --git a/src/Fable.Cli/Entry.fs b/src/Fable.Cli/Entry.fs index d889896734..0a211bb2c1 100644 --- a/src/Fable.Cli/Entry.fs +++ b/src/Fable.Cli/Entry.fs @@ -185,6 +185,8 @@ let argLanguage (args: CliArgs) = | "javascript" -> Ok JavaScript | "ts" | "typescript" -> Ok TypeScript + | "lua" + | "Lua" -> Ok Lua | "py" | "python" -> Ok Python | "php" -> Ok Php @@ -199,6 +201,7 @@ let argLanguage (args: CliArgs) = "Available options:" " - javascript (alias js)" " - typescript (alias ts)" + " - lua" " - python (alias py)" " - rust (alias rs)" " - php" @@ -319,6 +322,7 @@ type Runner = | Python -> "FABLE_COMPILER_PYTHON" | TypeScript -> "FABLE_COMPILER_TYPESCRIPT" | JavaScript -> "FABLE_COMPILER_JAVASCRIPT" + | Lua -> "FABLE_COMPILER_LUA" ] |> List.distinct @@ -464,6 +468,7 @@ let getStatus = | Rust -> "alpha" | Dart -> "beta" | Php -> "experimental" + | Lua -> "experimental" let getLibPkgVersion = function @@ -472,7 +477,8 @@ let getLibPkgVersion = | Python | Rust | Dart - | Php -> None + | Php + | Lua -> None let private logPrelude commands language = match commands with diff --git a/src/Fable.Cli/Pipeline.fs b/src/Fable.Cli/Pipeline.fs index af45bde624..144d63c1ba 100644 --- a/src/Fable.Cli/Pipeline.fs +++ b/src/Fable.Cli/Pipeline.fs @@ -475,6 +475,51 @@ module Rust = do! RustPrinter.run writer crate } +module Lua = + open Fable.Transforms + + type LuaWriter(com: Compiler, cliArgs: CliArgs, pathResolver, targetPath: string) = + let sourcePath = com.CurrentFile + let fileExt = cliArgs.CompilerOptions.FileExtension + let stream = new IO.StreamWriter(targetPath) + + interface Printer.Writer with + member _.Write(str) = + stream.WriteAsync(str) |> Async.AwaitTask + + member _.MakeImportPath(path) = + let projDir = IO.Path.GetDirectoryName(cliArgs.ProjectFile) + + let path = + Imports.getImportPath pathResolver sourcePath targetPath projDir cliArgs.OutDir path + + if path.EndsWith(".fs") then + Path.ChangeExtension(path, fileExt) + else + path + + member _.AddSourceMapping(_, _, _, _, _, _) = () + + member _.AddLog(msg, severity, ?range) = + com.AddLog(msg, severity, ?range = range, fileName = com.CurrentFile) + + member _.Dispose() = stream.Dispose() + + let compileFile (com: Compiler) (cliArgs: CliArgs) pathResolver isSilent (outPath: string) = + async { + let lua = + FSharp2Fable.Compiler.transformFile com + |> FableTransforms.transformFile com + |> Fable2Lua.transformFile com + + // + if not (isSilent || LuaPrinter.isEmpty lua) then + use writer = new LuaWriter(com, cliArgs, pathResolver, outPath) + do! LuaPrinter.run writer lua + + } + + let compileFile (com: Compiler) (cliArgs: CliArgs) pathResolver isSilent (outPath: string) = match com.Options.Language with | JavaScript @@ -483,3 +528,4 @@ let compileFile (com: Compiler) (cliArgs: CliArgs) pathResolver isSilent (outPat | Php -> Php.compileFile com cliArgs pathResolver isSilent outPath | Dart -> Dart.compileFile com cliArgs pathResolver isSilent outPath | Rust -> Rust.compileFile com cliArgs pathResolver isSilent outPath + | Lua -> Lua.compileFile com cliArgs pathResolver isSilent outPath diff --git a/src/Fable.Compiler/ProjectCracker.fs b/src/Fable.Compiler/ProjectCracker.fs index 95fe64c2a8..ea9541e0a6 100644 --- a/src/Fable.Compiler/ProjectCracker.fs +++ b/src/Fable.Compiler/ProjectCracker.fs @@ -662,6 +662,7 @@ let getFableLibraryPath (opts: CrackerOptions) = | JavaScript, None -> "fable-library-js", $"fable-library-js.%s{Literals.VERSION}" | Python, None -> "fable-library-py/fable_library", "fable_library" | Python, Some Py.Naming.sitePackages -> "fable-library-py", "fable-library" + | Lua, None -> "fable-library-lua", "fable-library-lua" | _, Some path -> if path.StartsWith("./", StringComparison.Ordinal) then "", Path.normalizeFullPath path diff --git a/src/Fable.Compiler/Util.fs b/src/Fable.Compiler/Util.fs index 3c393d6bc9..ee9ac56e45 100644 --- a/src/Fable.Compiler/Util.fs +++ b/src/Fable.Compiler/Util.fs @@ -134,6 +134,7 @@ module File = | Fable.Dart -> ".dart" | Fable.Rust -> ".rs" | Fable.JavaScript -> ".js" + | Fable.Lua -> ".lua" match language, usesOutDir with | Fable.Python, _ -> fileExt // Extension will always be .py for Python diff --git a/src/Fable.Core/Fable.Core.LuaInterop.fs b/src/Fable.Core/Fable.Core.LuaInterop.fs new file mode 100644 index 0000000000..27c542a49b --- /dev/null +++ b/src/Fable.Core/Fable.Core.LuaInterop.fs @@ -0,0 +1,90 @@ +module Fable.Core.LuaInterop + +open System +open Fable.Core + +// /// Has same effect as `unbox` (dynamic casting erased in compiled Lua code). +// /// The casted type can be defined on the call site: `!!myObj?bar(5): float` +// let (!!) x: 'T = nativeOnly + +// /// Implicit cast for erased unions (U2, U3...) +// let inline (!^) (x:^t1) : ^t2 = ((^t1 or ^t2) : (static member op_ErasedCast : ^t1 -> ^t2) x) + +// /// Dynamically access a property of an arbitrary object. +// /// `myObj?propA` in Lua becomes `myObj.propA` +// /// `myObj?(propA)` in Lua becomes `myObj[propA]` +// let (?) (o: obj) (prop: obj): 'a = nativeOnly + +// /// Dynamically assign a value to a property of an arbitrary object. +// /// `myObj?propA <- 5` in Lua becomes `myObj.propA = 5` +// /// `myObj?(propA) <- 5` in Lua becomes `myObj[propA] = 5` +// let (?<-) (o: obj) (prop: obj) (v: obj): unit = nativeOnly + +// /// Works like `ImportAttribute` (same semantics as ES6 imports). +// /// You can use "*" or "default" selectors. +// let import<'T> (selector: string) (path: string):'T = nativeOnly + +// /// F#: let myMember = importMember "myModule" +// /// Py: from my_module import my_member +// /// Note the import must be immediately assigned to a value in a let binding +// let importMember<'T> (path: string):'T = nativeOnly + +// /// F#: let myLib = importAll "myLib" +// /// Py: from my_lib import * +// let importAll<'T> (path: string):'T = nativeOnly + +[] +module Lua = + + [] + type ArrayConstructor = + [] + abstract Create: size: int -> 'T[] + + [] + abstract isArray: arg: obj -> bool + + abstract from: arg: obj -> 'T[] + + [] + let Array: ArrayConstructor = nativeOnly +(* +/// Destructure and apply a tuple to an arbitrary value. +/// E.g. `myFn $ (arg1, arg2)` in Python becomes `myFn(arg1, arg2)` +let ($) (callee: obj) (args: obj): 'a = nativeOnly + +/// Upcast the right operand to obj (and uncurry it if it's a function) and create a key-value tuple. +/// Mostly convenient when used with `createObj`. +/// E.g. `createObj [ "a" ==> 5 ]` in Python becomes `{ a: 5 }` +let (==>) (key: string) (v: obj): string*obj = nativeOnly + +/// Destructure a tuple of arguments and applies to literal Python code as with EmitAttribute. +/// E.g. `emitExpr (arg1, arg2) "$0 + $1"` in Python becomes `arg1 + arg2` +let emitExpr<'T> (args: obj) (pyCode: string): 'T = nativeOnly + +/// Same as emitExpr but intended for Python code that must appear in a statement position +/// E.g. `emitStatement aValue "while($0 < 5) doSomething()"` +let emitStatement<'T> (args: obj) (pyCode: string): 'T = nativeOnly + +/// Create a literal Python object from a collection of key-value tuples. +/// E.g. `createObj [ "a" ==> 5 ]` in Python becomes `{ a: 5 }` +let createObj (fields: #seq): obj = nativeOnly + +/// Create a literal Python object from a collection of union constructors. +/// E.g. `keyValueList CaseRules.LowerFirst [ MyUnion 4 ]` in Python becomes `{ myUnion: 4 }` +let keyValueList (caseRule: CaseRules) (li: 'T seq): obj = nativeOnly + +/// Create an empty Python object: {} +let createEmpty<'T> : 'T = nativeOnly + +[] +let pyTypeof (x: obj): string = nativeOnly + +[] +let pyInstanceof (x: obj) (cons: obj): bool = nativeOnly + + + +/// Imports a file only for its side effects +let importSideEffects (path: string): unit = nativeOnly +*) diff --git a/src/Fable.Core/Fable.Core.fsproj b/src/Fable.Core/Fable.Core.fsproj index 6c68c417f4..d81449790a 100644 --- a/src/Fable.Core/Fable.Core.fsproj +++ b/src/Fable.Core/Fable.Core.fsproj @@ -1,4 +1,4 @@ - + true @@ -15,6 +15,7 @@ + @@ -26,4 +27,4 @@ - + \ No newline at end of file diff --git a/src/Fable.Transforms/FSharp2Fable.Util.fs b/src/Fable.Transforms/FSharp2Fable.Util.fs index e4cb8e414e..66e262cb67 100644 --- a/src/Fable.Transforms/FSharp2Fable.Util.fs +++ b/src/Fable.Transforms/FSharp2Fable.Util.fs @@ -381,7 +381,6 @@ type FsEnt(maybeAbbrevEnt: FSharpEntity) = isInstance: bool, ?argTypes: Fable.Type[], ?genArgs, - // ?searchHierarchy: bool, ?requireDispatchSlot: bool ) = @@ -910,6 +909,7 @@ module Helpers = match com.Options.Language with | Python | JavaScript + | Lua | TypeScript -> memb.IsMutable && isNotPrivate memb | Rust -> true // always | Php diff --git a/src/Fable.Transforms/Fable.Transforms.fsproj b/src/Fable.Transforms/Fable.Transforms.fsproj index 44a8779e54..6b29adf68a 100644 --- a/src/Fable.Transforms/Fable.Transforms.fsproj +++ b/src/Fable.Transforms/Fable.Transforms.fsproj @@ -38,6 +38,10 @@ + + + + diff --git a/src/Fable.Transforms/Lua/Compiler.fs b/src/Fable.Transforms/Lua/Compiler.fs new file mode 100644 index 0000000000..4bdb5fb7eb --- /dev/null +++ b/src/Fable.Transforms/Lua/Compiler.fs @@ -0,0 +1,17 @@ +module rec Fable.Compilers.Lua + +open Fable +open Fable.AST +open Fable.AST.Fable + +type LuaCompiler(com: Fable.Compiler) = + let mutable types = Map.empty + let mutable decisionTreeTargets = [] + member this.Com = com + member this.AddClassDecl(c: ClassDecl) = types <- types |> Map.add c.Entity c + member this.GetByRef(e: EntityRef) = types |> Map.tryFind e + member this.DecisionTreeTargets(exprs: (Fable.Ident list * Expr) list) = decisionTreeTargets <- exprs + member this.GetDecisionTreeTargets(idx: int) = decisionTreeTargets.[idx] + + member this.GetMember(memberRef: Fable.MemberRef) : Fable.MemberFunctionOrValue = + com.GetMember(memberRef: Fable.MemberRef) diff --git a/src/Fable.Transforms/Lua/Fable2Lua.fs b/src/Fable.Transforms/Lua/Fable2Lua.fs new file mode 100644 index 0000000000..8fe92266cd --- /dev/null +++ b/src/Fable.Transforms/Lua/Fable2Lua.fs @@ -0,0 +1,547 @@ +module rec Fable.Transforms.Fable2Lua + +//cloned from FableToBabel + +open System +open System.Collections.Generic +open System.Text.RegularExpressions + +open Fable +open Fable.AST +open Fable.AST.Lua +open Fable.Compilers.Lua +open Fable.Naming +open Fable.Core + + +type UsedNames = + { + RootScope: HashSet + DeclarationScopes: HashSet + CurrentDeclarationScope: HashSet + } + +type BoundVars = + { + EnclosingScope: HashSet + LocalScope: HashSet + Inceptions: int + } + +type ITailCallOpportunity = + abstract Label: string + abstract Args: Fable.Ident list + abstract IsRecursiveRef: Fable.Expr -> bool + +type Context = + { + File: Fable.File + UsedNames: UsedNames + BoundVars: BoundVars + DecisionTargets: (Fable.Ident list * Fable.Expr) list + HoistVars: Fable.Ident list -> bool + TailCallOpportunity: ITailCallOpportunity option + OptimizeTailCall: unit -> unit + ScopedTypeParams: Set + TypeParamsScope: int + } + +module Transforms = + module Helpers = + let transformStatements transformStatements transformReturn exprs = + [ + match exprs |> List.rev with + | h :: t -> + for x in t |> List.rev do + yield transformStatements x + + yield transformReturn h + | [] -> () + ] + + let ident name = + Ident + { + Name = name + Namespace = None + } + + let fcall args expr = FunctionCall(expr, args) + + ///lua's immediately invoked function expressions + let iife statements = + FunctionCall(AnonymousFunc([], statements), []) + + let debugLog expr = + FunctionCall(Helpers.ident "print", [ expr ]) |> Do + + let libEquality a b = + FunctionCall( + GetObjMethod( + FunctionCall(Helpers.ident "require", [ ConstString "./fable-lib/Util" |> Const ]), + "equals" + ), + [ a; b ] + ) + + let maybeIife = + function + | [] -> NoOp + | [ Return expr ] -> expr + | statements -> iife statements + + let tryNewObj (names: string list) (values: Expr list) = + if names.Length = values.Length then + let pairs = List.zip names values + + let compareExprs = + names + |> List.map (fun name -> + libEquality (GetField(Helpers.ident "self", name)) (GetField(Helpers.ident "toCompare", name)) + ) + + let compareExprAcc = + compareExprs |> List.reduce (fun acc item -> Binary(And, acc, item)) + + let equality = + "Equals", + Function( + [ "self"; "toCompare" ], + [ + //yield debugLog (ConstString "Calling equality" |> Const) + // debugLog (Helpers.ident "self") + // debugLog (Helpers.ident "toCompare") + //yield! compareExprs |> List.map debugLog + Return compareExprAcc + ] + ) + + NewObj(equality :: pairs) + else + sprintf "Names and values do not match %A %A" names values |> Unknown + + let transformValueKind (com: LuaCompiler) = + function + | Fable.NumberConstant(Fable.AST.Fable.NumberValue.Float64 v, _kind) -> Const(ConstNumber v) + | Fable.NumberConstant(Fable.AST.Fable.NumberValue.Int32 v, _kind) -> Const(ConstInteger v) + | Fable.StringConstant(s) -> Const(ConstString s) + | Fable.BoolConstant(b) -> Const(ConstBool b) + | Fable.UnitConstant -> Const(ConstNull) + | Fable.CharConstant(c: char) -> Const(ConstString(string c)) + // | Fable.EnumConstant(e,ref) -> + // convertExpr com e + | Fable.NewRecord(values, ref, args) -> + let entity = com.Com.GetEntity(ref) + + if entity.IsFSharpRecord then + let names = entity.FSharpFields |> List.map (fun f -> f.Name) + let values = values |> List.map (transformExpr com) + Helpers.tryNewObj names values + else + sprintf "unknown ety %A %A %A %A" values ref args entity |> Unknown + | Fable.NewAnonymousRecord(values, names, _, _) -> + let transformedValues = values |> List.map (transformExpr com) + Helpers.tryNewObj (Array.toList names) transformedValues + | Fable.NewUnion(values, tag, _, _) -> + let values = + values + |> List.map (transformExpr com) + |> List.mapi (fun i x -> sprintf "p_%i" i, x) + + NewObj(("tag", tag |> float |> ConstNumber |> Const) :: values) + | Fable.NewOption(value, t, _) -> + value |> Option.map (transformExpr com) |> Option.defaultValue (Const ConstNull) + | Fable.NewTuple(values, isStruct) -> + // let fields = values |> List.mapi(fun i x -> sprintf "p_%i" i, transformExpr com x) + // NewObj(fields) + NewArr(values |> List.map (transformExpr com)) + | Fable.NewArray(kind, t, _) -> + match kind with + | Fable.ArrayValues values -> NewArr(values |> List.map (transformExpr com)) + | _ -> NewArr([]) + | Fable.Null _ -> Const(ConstNull) + | x -> sprintf "unknown %A" x |> ConstString |> Const + + let transformOp com = + let transformExpr = transformExpr com + + function + | Fable.OperationKind.Binary(BinaryModulus, left, right) -> + GetField(Helpers.ident "math", "fmod") + |> Helpers.fcall [ transformExpr left; transformExpr right ] + | Fable.OperationKind.Binary(op, left, right) -> + let op = + match op with + | BinaryOperator.BinaryMultiply -> Multiply + | BinaryOperator.BinaryDivide -> Divide + | BinaryOperator.BinaryEqual -> Equals + | BinaryOperator.BinaryPlus -> Plus + | BinaryOperator.BinaryMinus -> Minus + | BinaryOperator.BinaryUnequal -> Unequal + | BinaryOperator.BinaryLess -> Less + | BinaryOperator.BinaryGreater -> Greater + | BinaryOperator.BinaryLessOrEqual -> LessOrEqual + | BinaryOperator.BinaryGreaterOrEqual -> GreaterOrEqual + | x -> sprintf "%A" x |> BinaryTodo + + Binary(op, transformExpr left, transformExpr right) + | Fable.OperationKind.Unary(op, expr) -> + match op with + | UnaryOperator.UnaryNotBitwise -> transformExpr expr //not sure why this is being added + | UnaryOperator.UnaryNot -> Unary(Not, transformExpr expr) + | _ -> sprintf "%A %A" op expr |> Unknown + | x -> Unknown(sprintf "%A" x) + + ///lua's immediately invoked function expressions + let asSingleExprIife (exprs: Expr list) : Expr = //function + match exprs with + | [] -> NoOp + | [ h ] -> h + | exprs -> + let statements = Helpers.transformStatements (Do) (Return) exprs + statements |> Helpers.maybeIife + + let flattenReturnIifes e = + let rec collectStatementsRec = + function + | Return(FunctionCall(AnonymousFunc([], [ Return s ]), [])) -> [ Return s ] + | Return(FunctionCall(AnonymousFunc([], statements), [])) -> //self executing functions only + statements |> List.collect collectStatementsRec + | x -> [ x ] + + let statements = collectStatementsRec e + + match statements with + | [ Return s ] -> Return s + | [] -> NoOp |> Do + | _ -> FunctionCall(AnonymousFunc([], statements), []) |> Return + + ///lua's immediately invoked function expressions + let asSingleExprIifeTr com : Fable.Expr list -> Expr = + List.map (transformExpr com) >> asSingleExprIife + + let (|Regex|_|) pattern input = + let m = Regex.Match(input, pattern) + + if m.Success then + Some(List.tail [ for g in m.Groups -> g.Value ]) + else + None + + let transformExpr (com: LuaCompiler) expr = + let transformExpr = transformExpr com + let transformOp = transformOp com + + match expr with + | Fable.IdentExpr i when i.Name = "" -> Unknown "ident" + | Fable.IdentExpr i -> + Ident + { + Namespace = None + Name = i.Name + } + | Fable.Value(value, _) -> transformValueKind com value + | Fable.Lambda(arg, body, name) -> Function([ arg.Name ], [ transformExpr body |> Return ]) + | Fable.Delegate(idents, body, _, _) -> + Function(idents |> List.map (fun i -> i.Name), [ transformExpr body |> Return |> flattenReturnIifes ]) //can be flattened + | Fable.ObjectExpr(_members, typ, _baseCall) -> Unknown $"Obj %A{typ}" + | Fable.TypeCast(expr, t) -> transformExpr expr //typecasts are meaningless + | Fable.Test(expr, kind, b) -> + match kind with + | Fable.UnionCaseTest i -> Binary(Equals, GetField(transformExpr expr, "tag"), Const(ConstNumber(float i))) + | Fable.OptionTest isSome -> + if isSome then + Binary(Unequal, Const ConstNull, transformExpr expr) + else + Binary(Equals, Const ConstNull, transformExpr expr) + | Fable.TestKind.TypeTest t -> + // match t with + // | Fable.DeclaredType (ent, genArgs) -> + // match ent.FullName with + // | Fable.Transforms.Types.ienumerable -> //isArrayLike + // | Fable.Transforms.Types.array + // | _ -> + // | _ -> () + Binary(Equals, GetField(transformExpr expr, "type"), Const(t.ToString() |> ConstString)) + | _ -> Unknown(sprintf "test %A %A" expr kind) + | Fable.Call(expr, callInfo, t, r) -> + let lhs = + match expr with + | Fable.Get(expr, Fable.GetKind.FieldGet info, t, _) -> + match t with + | Fable.DeclaredType(_, _) + | Fable.AnonymousRecordType(_, _, _) -> GetObjMethod(transformExpr expr, info.Name) + | _ -> transformExpr expr + | Fable.Delegate _ -> transformExpr expr |> Parentheses + | _ -> transformExpr expr + + FunctionCall(lhs, List.map transformExpr callInfo.Args) + | Fable.CurriedApply(applied, args, _, _) -> FunctionCall(transformExpr applied, args |> List.map transformExpr) + | Fable.Operation(kind, _, _, _) -> transformOp kind + | Fable.Import(info, t, r) -> + let path = + match info.Kind, info.Path with + | libImport, Regex "fable-lib\/(\w+).(?:fs|js)" [ name ] -> "fable-lib/" + name + | _, Regex "fable-library-lua\/fable\/fable-library\/(\w+).(?:fs|js)" [ name ] -> + "fable-lib/fable-library" + name + | _, Regex "fable-library-lua\/fable\/(\w+).(?:fs|js)" [ name ] -> "fable-lib/" + name + | _ -> info.Path.Replace(".fs", "").Replace(".js", "") //todo - make less brittle + + let rcall = + FunctionCall( + Ident + { + Namespace = None + Name = "require" + }, + [ Const(ConstString path) ] + ) + + if String.IsNullOrEmpty info.Selector then + rcall + else + GetObjMethod(rcall, info.Selector) + + | Fable.Emit(m, _, _) -> + // let argsExprs = m.CallInfo.Args |> List.map transformExpr + // let macroExpr = Macro(m.Macro, argsExprs) + // let exprs = + // argsExprs + // @ [macroExpr] + // asSingleExprIife exprs + Macro(m.Macro, m.CallInfo.Args |> List.map transformExpr) + | Fable.DecisionTree(expr, lst) -> + com.DecisionTreeTargets(lst) + transformExpr expr + | Fable.DecisionTreeSuccess(i, boundValues, _) -> + let idents, target = com.GetDecisionTreeTargets(i) + + if idents.Length = boundValues.Length then + let statements = + [ + for (ident, value) in List.zip idents boundValues do + yield Assignment([ ident.Name ], transformExpr value, false) + yield transformExpr target |> Return + ] + + statements |> Helpers.maybeIife + else + sprintf "not equal lengths %A %A" idents boundValues |> Unknown + | Fable.Let(ident, value, body) -> + let statements = + [ + Assignment([ ident.Name ], transformExpr value, true) + transformExpr body |> Return + ] + + Helpers.maybeIife statements + | Fable.LetRec(ls, m) -> + match ls with + | [] -> Unknown "let rec" + | [ (i, e) ] -> Unknown $"let rec %A{i.Name}" + | (i, e) :: ls -> Unknown $"let rec %A{i.Name}" + | Fable.Get(expr, kind, t, _) -> + match kind with + | Fable.GetKind.FieldGet info -> GetField(transformExpr expr, info.Name) + | Fable.GetKind.UnionField info -> GetField(transformExpr expr, sprintf "p_%i" info.FieldIndex) + | Fable.GetKind.ExprGet e -> GetAtIndex(transformExpr expr, transformExpr e) + | Fable.GetKind.TupleIndex i -> GetAtIndex(transformExpr expr, Const(ConstNumber(float i))) + | Fable.GetKind.OptionValue -> transformExpr expr //todo null check, throw if null? + | Fable.ListHead -> Unknown "list Head" + | Fable.ListTail -> Unknown "list Tail" + | Fable.UnionTag -> Unknown "Union Tag" + | Fable.Set(expr, kind, t, value, _) -> + match kind with + | Fable.SetKind.ValueSet -> SetValue(transformExpr expr, transformExpr value) + | Fable.SetKind.ExprSet e -> SetExpr(transformExpr expr, transformExpr e, transformExpr value) + | Fable.SetKind.FieldSet name -> Unknown $"FieldSet %s{name} of type %A{t}" + | Fable.Sequential exprs -> asSingleExprIifeTr com exprs + | Fable.WhileLoop(guard, body, _label) -> + Helpers.maybeIife [ WhileLoop(transformExpr guard, [ transformExpr body |> Do ]) ] + | Fable.ForLoop(ident, start, limit, body, _isUp, _) -> + Helpers.maybeIife + [ + ForLoop(ident.Name, transformExpr start, transformExpr limit, [ transformExpr body |> Do ]) + ] + | Fable.TryCatch(body, catch, finalizer, _) -> + Helpers.maybeIife + [ + Assignment( + [ "status"; "resOrErr" ], + FunctionCall(Helpers.ident "pcall", [ Function([], [ transformExpr body |> Return ]) ]), + true + ) + let finalizer = finalizer |> Option.map transformExpr + + let catch = + catch |> Option.map (fun (ident, expr) -> ident.Name, transformExpr expr) + + IfThenElse( + Helpers.ident "status", + [ + match finalizer with + | Some finalizer -> yield Do finalizer + | None -> () + yield Helpers.ident "resOrErr" |> Return + ], + [ + match catch with + | Some(ident, expr) -> yield expr |> Return + | _ -> () + ] + ) + ] + | Fable.IfThenElse(guardExpr, thenExpr, elseExpr, _) -> + Ternary(transformExpr guardExpr, transformExpr thenExpr, transformExpr elseExpr) + | Fable.Unresolved(kind, _typ, _range) -> + match kind with + | Fable.UnresolvedExpr.UnresolvedTraitCall _ -> Unknown "Unresolved Trait" + | Fable.UnresolvedExpr.UnresolvedReplaceCall _ -> Unknown "Unresolved Replace" + | Fable.UnresolvedExpr.UnresolvedInlineCall _ -> Unknown "Unresolved Inline" + | Fable.Extended(kind, t) -> + match kind with + | Fable.ExtendedSet.Throw(expr, typ) -> + let errorExpr = Const(ConstString "There was an error, todo") + FunctionCall(Helpers.ident "error", [ errorExpr ]) + | Fable.ExtendedSet.Debugger -> Unknown "Debugger" + | Fable.ExtendedSet.Curry(e, a) -> + //transformExpr expr |> sprintf "(Fable2Lua:~266) todo curry %A" |> Unknown + Unknown $"Curry (arity: %i{a})" //in rare cases currying may need to happen at runtime + + + let transformDeclarations (com: LuaCompiler) ctx decl = + let withCurrentScope (ctx: Context) (usedNames: Set) f = + let ctx = + { ctx with UsedNames = { ctx.UsedNames with CurrentDeclarationScope = HashSet usedNames } } + + let result = f ctx + ctx.UsedNames.DeclarationScopes.UnionWith(ctx.UsedNames.CurrentDeclarationScope) + result + + let transformAttachedProperty + (com: LuaCompiler) + ctx + (info: Fable.MemberFunctionOrValue) + (memb: Fable.MemberDecl) + = + //TODO For some reason this never gets hit + let isStatic = not info.IsInstance + let isGetter = info.IsGetter + + let decorators = + [ + if isStatic then + Helpers.ident "staticmethod" + elif isGetter then + Helpers.ident "property" + else + Helpers.ident $"%s{memb.Name}.setter" + ] + + let args, body, returnType = [ "" ], [ Do(Unknown "") ], Unknown "" + //getMemberArgsAndBody com ctx (Attached isStatic) false memb.Args memb.Body + + let key = "key" + // memberFromName com ctx memb.Name + // |> nameFromKey com ctx + + // let arguments = + // if isStatic then + // // { args with Args = [""] } + // else + // { args with Args = args}//args.Args } + + Function(args, body) + //(key, arguments, body = body, decoratorList = decorators, returns = returnType) + |> List.singleton + + let transformAttachedMethod + (com: LuaCompiler) + ctx + (info: Fable.MemberFunctionOrValue) + (memb: Fable.MemberDecl) + = + [ Helpers.ident info.FullName ] + //TODO For some reason this never gets hit + + + match decl with + | Fable.ModuleDeclaration _m -> Assignment([ "moduleDecTest" ], Expr.Const(ConstString "moduledectest"), false) + | Fable.MemberDeclaration m -> + if m.Args.Length = 0 then + Assignment([ m.Name ], transformExpr com m.Body, true) + else + let info = com.Com.GetMember(m.MemberRef) + + let unwrapSelfExStatements = + match transformExpr com m.Body |> Return |> flattenReturnIifes with + | Return(FunctionCall(AnonymousFunc([], statements), [])) -> statements + | s -> [ s ] + + FunctionDeclaration(m.Name, m.Args |> List.map (fun a -> a.Name), unwrapSelfExStatements, info.IsPublic) + | Fable.ClassDeclaration(d) -> + com.AddClassDecl d + let _ent = d.Entity + + let transformAttached (memb: Fable.MemberDecl) ctx = + let info = + memb.ImplementedSignatureRef + |> Option.map com.GetMember + |> Option.defaultWith (fun () -> com.GetMember(memb.MemberRef)) + + if not memb.IsMangled && (info.IsGetter || info.IsSetter) then + transformAttachedProperty com ctx info memb + else + transformAttachedMethod com ctx info memb + + let classMembers = + d.AttachedMembers + |> List.collect (fun memb -> withCurrentScope ctx memb.UsedNames <| transformAttached memb) + + match d.Constructor with + | Some cons -> + withCurrentScope ctx cons.UsedNames + <| fun _ctx -> Assignment([ d.Name ], NewArr(classMembers), true) //transformClassWithPrimaryConstructor com ctx ent decl classMembers cons + | None -> Assignment([ d.Name ], NewArr(classMembers), true) + + | x -> sprintf "%A" x |> Unknown |> Do + +let transformFile com (file: Fable.File) : File = + let declScopes = + let hs = HashSet() + + for decl in file.Declarations do + hs.UnionWith(decl.UsedNames) + + hs + + let ctx: Context = + { + File = file + UsedNames = + { + RootScope = HashSet file.UsedNamesInRootScope + DeclarationScopes = declScopes + CurrentDeclarationScope = Unchecked.defaultof<_> + } + BoundVars = + { + EnclosingScope = HashSet() + LocalScope = HashSet() + Inceptions = 0 + } + DecisionTargets = [] + HoistVars = fun _ -> false + TailCallOpportunity = None + OptimizeTailCall = fun () -> () + ScopedTypeParams = Set.empty + TypeParamsScope = 0 + } + + let comp = LuaCompiler(com) + + { + Filename = "abc" + Statements = file.Declarations |> List.map (Transforms.transformDeclarations comp ctx) + ASTDebug = sprintf "%A" file.Declarations + } diff --git a/src/Fable.Transforms/Lua/Lua.fs b/src/Fable.Transforms/Lua/Lua.fs new file mode 100644 index 0000000000..24242d3193 --- /dev/null +++ b/src/Fable.Transforms/Lua/Lua.fs @@ -0,0 +1,73 @@ +// fsharplint:disable MemberNames InterfaceNames +namespace rec Fable.AST.Lua + + +type Const = + | ConstNumber of float + | ConstInteger of int + | ConstString of string + | ConstBool of bool + | ConstNull + +type LuaIdentity = + { + Namespace: string option + Name: string + } + +type UnaryOp = + | Not + | NotBitwise + +type BinaryOp = + | Equals + | Unequal + | Less + | LessOrEqual + | Greater + | GreaterOrEqual + | Multiply + | Divide + | Plus + | Minus + | BinaryTodo of string + | And + | Or + +type Expr = + | Ident of LuaIdentity + | Const of Const + | Unary of UnaryOp * Expr + | Binary of BinaryOp * Expr * Expr + | GetField of Expr * name: string + | GetObjMethod of Expr * name: string + | GetAtIndex of Expr * idx: Expr + | SetValue of Expr * value: Expr + | SetExpr of Expr * Expr * value: Expr + | FunctionCall of f: Expr * args: Expr list + | Parentheses of Expr + | AnonymousFunc of args: string list * body: Statement list + | Unknown of string + | Macro of string * args: Expr list + | Ternary of guardExpr: Expr * thenExpr: Expr * elseExpr: Expr + | NoOp + | Function of args: string list * body: Statement list + | NewObj of values: (string * Expr) list + | NewArr of values: Expr list + +type Statement = + | Assignment of names: string list * Expr * isLocal: bool + | FunctionDeclaration of name: string * args: string list * body: Statement list * exportToMod: bool + | Return of Expr + | Do of Expr + | SNoOp + | ForLoop of string * start: Expr * limit: Expr * body: Statement list + | WhileLoop of guard: Expr * body: Statement list + | IfThenElse of guard: Expr * thenSt: Statement list * elseSt: Statement list + +type File = + { + Filename: string + Statements: (Statement) list + ASTDebug: string + } diff --git a/src/Fable.Transforms/Lua/LuaPrinter.fs b/src/Fable.Transforms/Lua/LuaPrinter.fs new file mode 100644 index 0000000000..be0ab3da25 --- /dev/null +++ b/src/Fable.Transforms/Lua/LuaPrinter.fs @@ -0,0 +1,369 @@ +// fsharplint:disable InterfaceNames +module Fable.Transforms.LuaPrinter + +open System +open System.IO +open Fable +open Fable.AST +open Fable.AST.Lua + +type System.Text.StringBuilder with + + member sb.Write(txt: string) = sb.Append(txt) |> ignore + + member sb.WriteLine(txt: string) = + sb.Append(txt) |> ignore + sb.AppendLine() |> ignore + +module Output = + + type Writer = + { + Writer: System.Text.StringBuilder + Indent: int + Precedence: int + CurrentNamespace: string option + } + + module Helper = + let separateWithCommas = + function + | [] -> "" + | [ x ] -> x + | lst -> lst |> List.reduce (fun acc item -> acc + " ," + item) + + let indent ctx = { ctx with Indent = ctx.Indent + 1 } + + module Writer = + let create w = + { + Writer = w + Indent = 0 + Precedence = Int32.MaxValue + CurrentNamespace = None + } + + let writeIndent ctx = + for _ in 1 .. ctx.Indent do + ctx.Writer.Write(" ") + + let write ctx txt = ctx.Writer.Write(txt: string) + + let writei ctx txt = + writeIndent ctx + write ctx txt + + let writeln ctx txt = ctx.Writer.WriteLine(txt: string) + + let writeCommented ctx help txt = + writeln ctx "--[[" + write ctx help + writeln ctx txt + writeln ctx " --]]" + + let writeOp ctx = + function + | Multiply -> write ctx "*" + | Equals -> write ctx "==" + | Unequal -> write ctx "~=" + | Less -> write ctx "<" + | LessOrEqual -> write ctx "<=" + | Greater -> write ctx ">" + | GreaterOrEqual -> write ctx ">=" + | Divide -> write ctx """/""" + | Plus -> write ctx "+" + | Minus -> write ctx "-" + | And -> write ctx "and" + | Or -> write ctx "or" + | BinaryTodo x -> writeCommented ctx "binary todo" x + + let sprintExprSimple = + function + | Ident i -> i.Name + | _ -> "" + + let rec writeExpr ctx = + function + | Ident i -> write ctx i.Name + | Const c -> + match c with + | ConstString s -> s |> sprintf "'%s'" |> write ctx + | ConstNumber n -> n |> sprintf "%f" |> write ctx + | ConstInteger n -> n |> sprintf "%i" |> write ctx + | ConstBool b -> b |> sprintf "%b" |> write ctx + | ConstNull -> write ctx "nil" + | FunctionCall(e, args) -> + writeExpr ctx e + write ctx "(" + args |> writeExprs ctx + write ctx ")" + | AnonymousFunc(args, body) -> + write ctx "(function " + write ctx "(" + args |> Helper.separateWithCommas |> write ctx + write ctx ")" + writeln ctx "" + let ctxI = indent ctx + + for b in body do + writeStatement ctxI b + + writei ctx "end)" + | Unary(Not, expr) -> + write ctx "not " + writeExpr ctx expr + | Unary(NotBitwise, expr) -> + write ctx "~" + writeExpr ctx expr + | Binary(op, left, right) -> + writeExpr ctx left + write ctx " " + writeOp ctx op + write ctx " " + writeExpr ctx right + | GetField(expr, fieldName) -> + writeExpr ctx expr + write ctx "." + write ctx fieldName + | GetObjMethod(expr, fieldName) -> + writeExpr ctx expr + write ctx ":" + write ctx fieldName + | GetAtIndex(expr, idx) -> + writeExpr ctx expr + write ctx "[" + //hack alert - lua indexers are 1-based and not 0-based, so we need to "add1". Probably correct soln here is to simplify ast after +1 if possible + let add1 = Binary(BinaryOp.Plus, Const(ConstNumber 1.0), idx) + writeExpr ctx add1 + write ctx "]" + | SetValue(expr, value) -> + writeExpr ctx expr + write ctx " = " + writeExpr ctx value + | SetExpr(expr, a, value) -> + writeExpr ctx expr + write ctx " = " + // writeExpr ctx a + // write ctx " " + writeExpr ctx value + | Ternary(guardExpr, thenExpr, elseExpr) -> + //let ctxA = indent ctx + write ctx "(" + writeExpr ctx guardExpr + //writeln ctx "" + let ctxI = indent ctx + write ctx " and " + //writei ctx "and " + writeExpr ctxI thenExpr + //writeln ctx "" + write ctx " or " + //writei ctx "or " + writeExpr ctxI elseExpr + write ctx ")" + | Macro(macro, args) -> + + // let subbedMacro = + // (s, args |> List.mapi(fun i x -> i.ToString(), sprintExprSimple x)) + // ||> List.fold (fun acc (i, arg) -> acc.Replace("$"+i, arg) ) + // writei ctx subbedMacro + let regex = System.Text.RegularExpressions.Regex("\$(?\d)(?\.\.\.)?") + let matches = regex.Matches(macro) + let mutable pos = 0 + + for m in matches do + let n = int m.Groups.["n"].Value + write ctx (macro.Substring(pos, m.Index - pos)) + + if m.Groups.["s"].Success then + if n < args.Length then + match args.[n] with + | NewArr items -> + let mutable first = true + + for value in items do + if first then + first <- false + else + write ctx ", " + + writeExpr ctx value + | _ -> writeExpr ctx args.[n] + + elif n < args.Length then + writeExpr ctx args.[n] + + pos <- m.Index + m.Length + + write ctx (macro.Substring(pos)) + | Function(args, body) -> + write ctx "function " + write ctx "(" + args |> Helper.separateWithCommas |> write ctx + write ctx ")" + let ctxI = indent ctx + writeln ctxI "" + body |> List.iter (writeStatement ctxI) + writei ctx "end" + | NewObj(args) -> + write ctx "{" + let ctxI = indent ctx + writeln ctxI "" + + for idx, (name, expr) in args |> List.mapi (fun i x -> i, x) do + writei ctxI name + write ctxI " = " + writeExpr ctxI expr + + if idx < args.Length - 1 then + writeln ctxI "," + //writeExprs ctxI args + writeln ctx "" + writei ctx "}" + | NewArr(args) -> + write ctx "{" + let ctxI = indent ctx + writeln ctxI "" + + for idx, expr in args |> List.mapi (fun i x -> i, x) do + writei ctxI "" + writeExpr ctxI expr + + if idx < args.Length - 1 then + writeln ctxI "," + //writeExprs ctxI args + writeln ctx "" + writei ctx "}" + | NoOp -> () + | Parentheses expr -> + write ctx "(" + writeExpr ctx expr + write ctx ")" + + | Unknown x -> writeCommented ctx "todo: unknown" x + + and writeExprs ctx = + function + | [] -> () + | h :: t -> + writeExpr ctx h + + for item in t do + write ctx ", " + writeExpr ctx item + + and writeStatement ctx = + function + | Assignment(names, expr, isLocal) -> + let names = names |> Helper.separateWithCommas + writei ctx "" + + if isLocal then + write ctx "local " + + write ctx names + write ctx " = " + writeExpr ctx expr + writeln ctx "" + | FunctionDeclaration(name, args, body, exportToMod) -> + writei ctx "function " + write ctx name + write ctx "(" + // let args = if exportToMod then "self"::args else args + args |> Helper.separateWithCommas |> write ctx + write ctx ")" + let ctxI = indent ctx + writeln ctxI "" + body |> List.iter (writeStatement ctxI) + writeln ctx "end" + + if exportToMod then + writei ctx "mod." + write ctx name + write ctx " = function(self, ...) " + write ctx name + write ctx "(...)" + write ctx " end" + writeln ctxI "" + | Return expr -> + writei ctx "return " + writeExpr ctx expr + writeln ctx "" + | Do expr -> + writei ctx "" + writeExpr ctx expr + writeln ctx "" + | ForLoop(name, start, limit, body) -> + writei ctx "for " + write ctx name + write ctx "=" + writeExpr ctx start + write ctx ", " + writeExpr ctx limit + write ctx " do" + let ctxI = indent ctx + + for statement in body do + writeln ctxI "" + writeStatement ctxI statement + + writeln ctx "" + writei ctx "end" + writeln ctx "" + | WhileLoop(guard, body) -> + writei ctx "while " + writeExpr ctx guard + write ctx " do" + let ctxI = indent ctx + + for statement in body do + writeln ctxI "" + writeStatement ctxI statement + + writeln ctx "" + writei ctx "end" + writeln ctx "" + | IfThenElse(guard, thenSt, elseSt) -> + writei ctx "if " + writeExpr ctx guard + write ctx " then" + let ctxI = indent ctx + + for statement in thenSt do + writeln ctxI "" + writeStatement ctxI statement + + writeln ctx "" + writei ctx "else" + + for statement in elseSt do + writeln ctxI "" + writeStatement ctxI statement + + writeln ctx "" + writei ctx "end" + writeln ctx "" + | SNoOp -> () + + let writeFile ctx (file: File) = + writeln ctx "mod = {}" + + for s in file.Statements do + writeStatement ctx s + + write ctx "return mod" + //debugging + writeln ctx "" +// writeln ctx "--[[" +// sprintf "%s" file.ASTDebug |> write ctx +//sprintf "%A" file.Statements |> write ctx +//writeln ctx " --]]" + +let isEmpty (file: File) : bool = false //TODO: determine if printer will not print anything + +let run (writer: Printer.Writer) (lib: File) : Async = + async { + let sb = System.Text.StringBuilder() + let ctx = Output.Writer.create sb + Output.writeFile ctx lib + do! writer.Write(sb.ToString()) + } diff --git a/src/Fable.Transforms/Replacements.Api.fs b/src/Fable.Transforms/Replacements.Api.fs index 4cfe41e217..9ce608c3e7 100644 --- a/src/Fable.Transforms/Replacements.Api.fs +++ b/src/Fable.Transforms/Replacements.Api.fs @@ -81,7 +81,8 @@ let createMutablePublicValue (com: ICompiler) value = | TypeScript -> JS.Replacements.createAtom com value | Rust | Php - | Dart -> value + | Dart + | Lua -> value let getRefCell (com: ICompiler) r typ (expr: Expr) = match com.Options.Language with diff --git a/src/fable-library-lua/Array.fs b/src/fable-library-lua/Array.fs new file mode 100644 index 0000000000..d01cfafbce --- /dev/null +++ b/src/fable-library-lua/Array.fs @@ -0,0 +1,1226 @@ +module ArrayModule + +// Disables warn:1204 raised by use of LanguagePrimitives.ErrorStrings.* +#nowarn "1204" + +open System.Collections.Generic +open Fable.Core +open Fable.Core.JsInterop + +open Native +open Native.Helpers + +let private indexNotFound () = + failwith "An index satisfying the predicate was not found in the collection." + +let private differentLengths () = failwith "Arrays had different lengths" + +// Pay attention when benchmarking to append and filter functions below +// if implementing via native JS array .concat() and .filter() do not fall behind due to js-native transitions. + +// Don't use native JS Array.prototype.concat as it doesn't work with typed arrays +let append (array1: 'T[]) (array2: 'T[]) ([] cons: Cons<'T>) : 'T[] = + let len1 = array1.Length + let len2 = array2.Length + let newArray = allocateArrayFromCons cons (len1 + len2) + + for i = 0 to len1 - 1 do + newArray.[i] <- array1.[i] + + for i = 0 to len2 - 1 do + newArray.[i + len1] <- array2.[i] + + newArray + +let filter (predicate: 'T -> bool) (array: 'T[]) = filterImpl predicate array + +// intentionally returns target instead of unit +let fill (target: 'T[]) (targetIndex: int) (count: int) (value: 'T) : 'T[] = fillImpl target value targetIndex count + +let getSubArray (array: 'T[]) (start: int) (count: int) : 'T[] = subArrayImpl array start count + +let last (array: 'T[]) = + if array.Length = 0 then + invalidArg "array" LanguagePrimitives.ErrorStrings.InputArrayEmptyString + + array.[array.Length - 1] + +let tryLast (array: 'T[]) = + if array.Length = 0 then + None + else + Some array.[array.Length - 1] + +let mapIndexed (f: int -> 'T -> 'U) (source: 'T[]) ([] cons: Cons<'U>) : 'U[] = + let len = source.Length + let target = allocateArrayFromCons cons len + + for i = 0 to (len - 1) do + target.[i] <- f i source.[i] + + target + +let map (f: 'T -> 'U) (source: 'T[]) ([] cons: Cons<'U>) : 'U[] = + let len = source.Length + let target = allocateArrayFromCons cons len + + for i = 0 to (len - 1) do + target.[i] <- f source.[i] + + target + +let mapIndexed2 + (f: int -> 'T1 -> 'T2 -> 'U) + (source1: 'T1[]) + (source2: 'T2[]) + ([] cons: Cons<'U>) + : 'U[] + = + if source1.Length <> source2.Length then + failwith "Arrays had different lengths" + + let result = allocateArrayFromCons cons source1.Length + + for i = 0 to source1.Length - 1 do + result.[i] <- f i source1.[i] source2.[i] + + result + +let map2 (f: 'T1 -> 'T2 -> 'U) (source1: 'T1[]) (source2: 'T2[]) ([] cons: Cons<'U>) : 'U[] = + if source1.Length <> source2.Length then + failwith "Arrays had different lengths" + + let result = allocateArrayFromCons cons source1.Length + + for i = 0 to source1.Length - 1 do + result.[i] <- f source1.[i] source2.[i] + + result + +let mapIndexed3 + (f: int -> 'T1 -> 'T2 -> 'T3 -> 'U) + (source1: 'T1[]) + (source2: 'T2[]) + (source3: 'T3[]) + ([] cons: Cons<'U>) + : 'U[] + = + if source1.Length <> source2.Length || source2.Length <> source3.Length then + failwith "Arrays had different lengths" + + let result = allocateArrayFromCons cons source1.Length + + for i = 0 to source1.Length - 1 do + result.[i] <- f i source1.[i] source2.[i] source3.[i] + + result + +let map3 + (f: 'T1 -> 'T2 -> 'T3 -> 'U) + (source1: 'T1[]) + (source2: 'T2[]) + (source3: 'T3[]) + ([] cons: Cons<'U>) + : 'U[] + = + if source1.Length <> source2.Length || source2.Length <> source3.Length then + failwith "Arrays had different lengths" + + let result = allocateArrayFromCons cons source1.Length + + for i = 0 to source1.Length - 1 do + result.[i] <- f source1.[i] source2.[i] source3.[i] + + result + +let mapFold<'T, 'State, 'Result> + (mapping: 'State -> 'T -> 'Result * 'State) + state + (array: 'T[]) + ([] cons: Cons<'Result>) + = + match array.Length with + | 0 -> [||], state + | len -> + let mutable acc = state + let res = allocateArrayFromCons cons len + + for i = 0 to array.Length - 1 do + let h, s = mapping acc array.[i] + res.[i] <- h + acc <- s + + res, acc + +let mapFoldBack<'T, 'State, 'Result> + (mapping: 'T -> 'State -> 'Result * 'State) + (array: 'T[]) + state + ([] cons: Cons<'Result>) + = + match array.Length with + | 0 -> [||], state + | len -> + let mutable acc = state + let res = allocateArrayFromCons cons len + + for i = array.Length - 1 downto 0 do + let h, s = mapping array.[i] acc + res.[i] <- h + acc <- s + + res, acc + +let indexed (source: 'T[]) = + let len = source.Length + let target = allocateArray len + + for i = 0 to (len - 1) do + target.[i] <- i, source.[i] + + target + +let truncate (count: int) (array: 'T[]) : 'T[] = + let count = max 0 count + subArrayImpl array 0 count + +let concat (arrays: 'T[] seq) ([] cons: Cons<'T>) : 'T[] = + let arrays = + if isDynamicArrayImpl arrays then + arrays :?> 'T[][] // avoid extra copy + else + arrayFrom arrays + + match arrays.Length with + | 0 -> allocateArrayFromCons cons 0 + | 1 -> arrays.[0] + | _ -> + let mutable totalIdx = 0 + let mutable totalLength = 0 + + for arr in arrays do + totalLength <- totalLength + arr.Length + + let result = allocateArrayFromCons cons totalLength + + for arr in arrays do + for j = 0 to (arr.Length - 1) do + result.[totalIdx] <- arr.[j] + totalIdx <- totalIdx + 1 + + result + +let collect (mapping: 'T -> 'U[]) (array: 'T[]) ([] cons: Cons<'U>) : 'U[] = + let mapped = map mapping array Unchecked.defaultof<_> + concat mapped cons +// collectImpl mapping array // flatMap not widely available yet + +let where predicate (array: _[]) = filterImpl predicate array + +let indexOf<'T> + (array: 'T[]) + (item: 'T) + (start: int option) + (count: int option) + ([] eq: IEqualityComparer<'T>) + = + let start = defaultArg start 0 + + let end' = + count |> Option.map (fun c -> start + c) |> Option.defaultValue array.Length + + let rec loop i = + if i >= end' then + -1 + else if eq.Equals(item, array.[i]) then + i + else + loop (i + 1) + + loop start + +let contains<'T> (value: 'T) (array: 'T[]) ([] eq: IEqualityComparer<'T>) = + indexOf array value None None eq >= 0 + +let empty cons = allocateArrayFromCons cons 0 + +let singleton value ([] cons: Cons<'T>) = + let ar = allocateArrayFromCons cons 1 + ar.[0] <- value + ar + +let initialize count initializer ([] cons: Cons<'T>) = + if count < 0 then + invalidArg "count" LanguagePrimitives.ErrorStrings.InputMustBeNonNegativeString + + let result = allocateArrayFromCons cons count + + for i = 0 to count - 1 do + result.[i] <- initializer i + + result + +let pairwise (array: 'T[]) = + if array.Length < 2 then + [||] + else + let count = array.Length - 1 + let result = allocateArray count + + for i = 0 to count - 1 do + result.[i] <- array.[i], array.[i + 1] + + result + +let replicate count initial ([] cons: Cons<'T>) = + // Shorthand version: = initialize count (fun _ -> initial) + if count < 0 then + invalidArg "count" LanguagePrimitives.ErrorStrings.InputMustBeNonNegativeString + + let result: 'T array = allocateArrayFromCons cons count + + for i = 0 to result.Length - 1 do + result.[i] <- initial + + result + +let copy (array: 'T[]) = + // if isTypedArrayImpl array then + // let res = allocateArrayFrom array array.Length + // for i = 0 to array.Length-1 do + // res.[i] <- array.[i] + // res + // else + copyImpl array + +let copyTo (source: 'T[]) (sourceIndex: int) (target: 'T[]) (targetIndex: int) (count: int) = + // TODO: Check array lengths + System.Array.Copy(source, sourceIndex, target, targetIndex, count) + +let reverse (array: 'T[]) = + // if isTypedArrayImpl array then + // let res = allocateArrayFrom array array.Length + // let mutable j = array.Length-1 + // for i = 0 to array.Length-1 do + // res.[j] <- array.[i] + // j <- j - 1 + // res + // else + copyImpl array |> reverseImpl + +let scan<'T, 'State> folder (state: 'State) (array: 'T[]) ([] cons: Cons<'State>) = + let res = allocateArrayFromCons cons (array.Length + 1) + res.[0] <- state + + for i = 0 to array.Length - 1 do + res.[i + 1] <- folder res.[i] array.[i] + + res + +let scanBack<'T, 'State> folder (array: 'T[]) (state: 'State) ([] cons: Cons<'State>) = + let res = allocateArrayFromCons cons (array.Length + 1) + res.[array.Length] <- state + + for i = array.Length - 1 downto 0 do + res.[i] <- folder array.[i] res.[i + 1] + + res + +let skip count (array: 'T[]) ([] cons: Cons<'T>) = + if count > array.Length then + invalidArg "count" "count is greater than array length" + + if count = array.Length then + allocateArrayFromCons cons 0 + else + let count = + if count < 0 then + 0 + else + count + + skipImpl array count + +let skipWhile predicate (array: 'T[]) ([] cons: Cons<'T>) = + let mutable count = 0 + + while count < array.Length && predicate array.[count] do + count <- count + 1 + + if count = array.Length then + allocateArrayFromCons cons 0 + else + skipImpl array count + +let take count (array: 'T[]) ([] cons: Cons<'T>) = + if count < 0 then + invalidArg "count" LanguagePrimitives.ErrorStrings.InputMustBeNonNegativeString + + if count > array.Length then + invalidArg "count" "count is greater than array length" + + if count = 0 then + allocateArrayFromCons cons 0 + else + subArrayImpl array 0 count + +let takeWhile predicate (array: 'T[]) ([] cons: Cons<'T>) = + let mutable count = 0 + + while count < array.Length && predicate array.[count] do + count <- count + 1 + + if count = 0 then + allocateArrayFromCons cons 0 + else + subArrayImpl array 0 count + +let addInPlace (x: 'T) (array: 'T[]) = + // if isTypedArrayImpl array then invalidArg "array" "Typed arrays not supported" + pushImpl array x |> ignore + +let addRangeInPlace (range: seq<'T>) (array: 'T[]) = + // if isTypedArrayImpl array then invalidArg "array" "Typed arrays not supported" + for x in range do + addInPlace x array + +let insertRangeInPlace index (range: seq<'T>) (array: 'T[]) = + // if isTypedArrayImpl array then invalidArg "array" "Typed arrays not supported" + let mutable i = index + + for x in range do + insertImpl array i x |> ignore + i <- i + 1 + +let removeInPlace (item: 'T) (array: 'T[]) ([] eq: IEqualityComparer<'T>) = + let i = indexOf array item None None eq + + if i > -1 then + spliceImpl array i 1 |> ignore + true + else + false + +let removeAllInPlace predicate (array: 'T[]) = + let rec countRemoveAll count = + let i = findIndexImpl predicate array + + if i > -1 then + spliceImpl array i 1 |> ignore + countRemoveAll count + 1 + else + count + + countRemoveAll 0 + +let partition (f: 'T -> bool) (source: 'T[]) ([] cons: Cons<'T>) = + let len = source.Length + let res1 = allocateArrayFromCons cons len + let res2 = allocateArrayFromCons cons len + let mutable iTrue = 0 + let mutable iFalse = 0 + + for i = 0 to len - 1 do + if f source.[i] then + res1.[iTrue] <- source.[i] + iTrue <- iTrue + 1 + else + res2.[iFalse] <- source.[i] + iFalse <- iFalse + 1 + + res1 |> truncate iTrue, res2 |> truncate iFalse + +let find (predicate: 'T -> bool) (array: 'T[]) : 'T = + match findImpl predicate array with + | Some res -> res + | None -> indexNotFound () + +let tryFind (predicate: 'T -> bool) (array: 'T[]) : 'T option = findImpl predicate array + +let findIndex (predicate: 'T -> bool) (array: 'T[]) : int = + match findIndexImpl predicate array with + | index when index > -1 -> index + | _ -> + indexNotFound () + -1 + +let tryFindIndex (predicate: 'T -> bool) (array: 'T[]) : int option = + match findIndexImpl predicate array with + | index when index > -1 -> Some index + | _ -> None + +let pick chooser (array: _[]) = + let rec loop i = + if i >= array.Length then + indexNotFound () + else + match chooser array.[i] with + | None -> loop (i + 1) + | Some res -> res + + loop 0 + +let tryPick chooser (array: _[]) = + let rec loop i = + if i >= array.Length then + None + else + match chooser array.[i] with + | None -> loop (i + 1) + | res -> res + + loop 0 + +let findBack predicate (array: _[]) = + let rec loop i = + if i < 0 then + indexNotFound () + elif predicate array.[i] then + array.[i] + else + loop (i - 1) + + loop (array.Length - 1) + +let tryFindBack predicate (array: _[]) = + let rec loop i = + if i < 0 then + None + elif predicate array.[i] then + Some array.[i] + else + loop (i - 1) + + loop (array.Length - 1) + +let findLastIndex predicate (array: _[]) = + let rec loop i = + if i < 0 then + -1 + elif predicate array.[i] then + i + else + loop (i - 1) + + loop (array.Length - 1) + +let findIndexBack predicate (array: _[]) = + let rec loop i = + if i < 0 then + indexNotFound () + -1 + elif predicate array.[i] then + i + else + loop (i - 1) + + loop (array.Length - 1) + +let tryFindIndexBack predicate (array: _[]) = + let rec loop i = + if i < 0 then + None + elif predicate array.[i] then + Some i + else + loop (i - 1) + + loop (array.Length - 1) + +let choose (chooser: 'T -> 'U option) (array: 'T[]) ([] cons: Cons<'U>) = + let res: 'U[] = [||] + + for i = 0 to array.Length - 1 do + match chooser array.[i] with + | None -> () + | Some y -> pushImpl res y |> ignore + + match box cons with + | null -> res // avoid extra copy + | _ -> map id res cons + +let foldIndexed<'T, 'State> folder (state: 'State) (array: 'T[]) = + // if isTypedArrayImpl array then + // let mutable acc = state + // for i = 0 to array.Length - 1 do + // acc <- folder i acc array.[i] + // acc + // else + foldIndexedImpl (fun acc x i -> folder i acc x) state array + +let fold<'T, 'State> folder (state: 'State) (array: 'T[]) = + // if isTypedArrayImpl array then + // let mutable acc = state + // for i = 0 to array.Length - 1 do + // acc <- folder acc array.[i] + // acc + // else + foldImpl (fun acc x -> folder acc x) state array + +let iterate action (array: 'T[]) = + for i = 0 to array.Length - 1 do + action array.[i] + +let iterateIndexed action (array: 'T[]) = + for i = 0 to array.Length - 1 do + action i array.[i] + +let iterate2 action (array1: 'T1[]) (array2: 'T2[]) = + if array1.Length <> array2.Length then + differentLengths () + + for i = 0 to array1.Length - 1 do + action array1.[i] array2.[i] + +let iterateIndexed2 action (array1: 'T1[]) (array2: 'T2[]) = + if array1.Length <> array2.Length then + differentLengths () + + for i = 0 to array1.Length - 1 do + action i array1.[i] array2.[i] + +let isEmpty (array: 'T[]) = array.Length = 0 + +let forAll predicate (array: 'T[]) = + // if isTypedArrayImpl array then + // let mutable i = 0 + // let mutable result = true + // while i < array.Length && result do + // result <- predicate array.[i] + // i <- i + 1 + // result + // else + forAllImpl predicate array + +let permute f (array: 'T[]) = + let size = array.Length + let res = copyImpl array + let checkFlags = allocateArray size + + iterateIndexed + (fun i x -> + let j = f i + + if j < 0 || j >= size then + invalidOp "Not a valid permutation" + + res.[j] <- x + checkFlags.[j] <- 1 + ) + array + + let isValid = checkFlags |> forAllImpl ((=) 1) + + if not isValid then + invalidOp "Not a valid permutation" + + res + +let setSlice (target: 'T[]) (lower: int option) (upper: int option) (source: 'T[]) = + let lower = defaultArg lower 0 + let upper = defaultArg upper -1 + + let length = + (if upper >= 0 then + upper + else + target.Length - 1) + - lower + // can't cast to TypedArray, so can't use TypedArray-specific methods + // if isTypedArrayImpl target && source.Length <= length then + // typedArraySetImpl target source lower + // else + for i = 0 to length do + target.[i + lower] <- source.[i] + +let sortInPlaceBy (projection: 'a -> 'b) (xs: 'a[]) ([] comparer: IComparer<'b>) : unit = + sortInPlaceWithImpl (fun x y -> comparer.Compare(projection x, projection y)) xs + +let sortInPlace (xs: 'T[]) ([] comparer: IComparer<'T>) = + sortInPlaceWithImpl (fun x y -> comparer.Compare(x, y)) xs + +let inline internal sortInPlaceWith (comparer: 'T -> 'T -> int) (xs: 'T[]) = + sortInPlaceWithImpl comparer xs + xs + +let sort (xs: 'T[]) ([] comparer: IComparer<'T>) : 'T[] = + sortInPlaceWith (fun x y -> comparer.Compare(x, y)) (copyImpl xs) + +let sortBy (projection: 'a -> 'b) (xs: 'a[]) ([] comparer: IComparer<'b>) : 'a[] = + sortInPlaceWith (fun x y -> comparer.Compare(projection x, projection y)) (copyImpl xs) + +let sortDescending (xs: 'T[]) ([] comparer: IComparer<'T>) : 'T[] = + sortInPlaceWith (fun x y -> comparer.Compare(x, y) * -1) (copyImpl xs) + +let sortByDescending (projection: 'a -> 'b) (xs: 'a[]) ([] comparer: IComparer<'b>) : 'a[] = + sortInPlaceWith (fun x y -> comparer.Compare(projection x, projection y) * -1) (copyImpl xs) + +let sortWith (comparer: 'T -> 'T -> int) (xs: 'T[]) : 'T[] = sortInPlaceWith comparer (copyImpl xs) + +let allPairs (xs: 'T1[]) (ys: 'T2[]) : ('T1 * 'T2)[] = + let len1 = xs.Length + let len2 = ys.Length + let res = allocateArray (len1 * len2) + + for i = 0 to xs.Length - 1 do + for j = 0 to ys.Length - 1 do + res.[i * len2 + j] <- (xs.[i], ys.[j]) + + res + +let unfold<'T, 'State> (generator: 'State -> ('T * 'State) option) (state: 'State) : 'T[] = + let res: 'T[] = [||] + + let rec loop state = + match generator state with + | None -> () + | Some(x, s) -> + pushImpl res x |> ignore + loop s + + loop state + res + +// TODO: We should pass Cons<'T> here (and unzip3) but 'a and 'b may differ +let unzip (array: _[]) = + let len = array.Length + let res1 = allocateArray len + let res2 = allocateArray len + + iterateIndexed + (fun i (item1, item2) -> + res1.[i] <- item1 + res2.[i] <- item2 + ) + array + + res1, res2 + +let unzip3 (array: _[]) = + let len = array.Length + let res1 = allocateArray len + let res2 = allocateArray len + let res3 = allocateArray len + + iterateIndexed + (fun i (item1, item2, item3) -> + res1.[i] <- item1 + res2.[i] <- item2 + res3.[i] <- item3 + ) + array + + res1, res2, res3 + +let zip (array1: 'T[]) (array2: 'U[]) = + // Shorthand version: map2 (fun x y -> x, y) array1 array2 + if array1.Length <> array2.Length then + differentLengths () + + let result = allocateArray array1.Length + + for i = 0 to array1.Length - 1 do + result.[i] <- array1.[i], array2.[i] + + result + +let zip3 (array1: 'T[]) (array2: 'U[]) (array3: 'V[]) = + // Shorthand version: map3 (fun x y z -> x, y, z) array1 array2 array3 + if array1.Length <> array2.Length || array2.Length <> array3.Length then + differentLengths () + + let result = allocateArray array1.Length + + for i = 0 to array1.Length - 1 do + result.[i] <- array1.[i], array2.[i], array3.[i] + + result + +let chunkBySize (chunkSize: int) (array: 'T[]) : 'T[][] = + if chunkSize < 1 then + invalidArg "size" "The input must be positive." + + if array.Length = 0 then + [| [||] |] + else + let result: 'T[][] = [||] + // add each chunk to the result + for x = 0 to int (System.Math.Ceiling(float (array.Length) / float (chunkSize))) - 1 do + let start = x * chunkSize + let slice = subArrayImpl array start chunkSize + pushImpl result slice |> ignore + + result + +let splitAt (index: int) (array: 'T[]) : 'T[] * 'T[] = + if index < 0 || index > array.Length then + invalidArg "index" SR.indexOutOfBounds + + subArrayImpl array 0 index, skipImpl array index + +// Note that, though it's not consistent with `compare` operator, +// Array.compareWith doesn't compare first the length, see #2961 +let compareWith (comparer: 'T -> 'T -> int) (source1: 'T[]) (source2: 'T[]) = + if isNull source1 then + if isNull source2 then + 0 + else + -1 + elif isNull source2 then + 1 + else + let len1 = source1.Length + let len2 = source2.Length + + let len = + if len1 < len2 then + len1 + else + len2 + + let mutable i = 0 + let mutable res = 0 + + while res = 0 && i < len do + res <- comparer source1.[i] source2.[i] + i <- i + 1 + + if res <> 0 then + res + elif len1 > len2 then + 1 + elif len1 < len2 then + -1 + else + 0 + +let compareTo (comparer: 'T -> 'T -> int) (source1: 'T[]) (source2: 'T[]) = + if isNull source1 then + if isNull source2 then + 0 + else + -1 + elif isNull source2 then + 1 + else + let len1 = source1.Length + let len2 = source2.Length + + if len1 > len2 then + 1 + elif len1 < len2 then + -1 + else + let mutable i = 0 + let mutable res = 0 + + while res = 0 && i < len1 do + res <- comparer source1.[i] source2.[i] + i <- i + 1 + + res + +let equalsWith (equals: 'T -> 'T -> bool) (array1: 'T[]) (array2: 'T[]) = + if isNull array1 then + if isNull array2 then + true + else + false + elif isNull array2 then + false + else + let mutable i = 0 + let mutable result = true + let length1 = array1.Length + let length2 = array2.Length + + if length1 > length2 then + false + elif length1 < length2 then + false + else + while i < length1 && result do + result <- equals array1.[i] array2.[i] + i <- i + 1 + + result + +let exactlyOne (array: 'T[]) = + if array.Length = 1 then + array.[0] + elif array.Length = 0 then + invalidArg "array" LanguagePrimitives.ErrorStrings.InputSequenceEmptyString + else + invalidArg "array" "Input array too long" + +let tryExactlyOne (array: 'T[]) = + if array.Length = 1 then + Some(array.[0]) + else + None + +let head (array: 'T[]) = + if array.Length = 0 then + invalidArg "array" LanguagePrimitives.ErrorStrings.InputArrayEmptyString + else + array.[0] + +let tryHead (array: 'T[]) = + if array.Length = 0 then + None + else + Some array.[0] + +let tail (array: 'T[]) = + if array.Length = 0 then + invalidArg "array" "Not enough elements" + + skipImpl array 1 + +let item index (array: _[]) = array.[index] + +let tryItem index (array: 'T[]) = + if index < 0 || index >= array.Length then + None + else + Some array.[index] + +let foldBackIndexed<'T, 'State> folder (array: 'T[]) (state: 'State) = + // if isTypedArrayImpl array then + // let mutable acc = state + // let size = array.Length + // for i = 1 to size do + // acc <- folder (i-1) array.[size - i] acc + // acc + // else + foldBackIndexedImpl (fun acc x i -> folder i x acc) state array + +let foldBack<'T, 'State> folder (array: 'T[]) (state: 'State) = + // if isTypedArrayImpl array then + // foldBackIndexed (fun _ x acc -> folder x acc) array state + // else + foldBackImpl (fun acc x -> folder x acc) state array + +let foldIndexed2 folder state (array1: _[]) (array2: _[]) = + let mutable acc = state + + if array1.Length <> array2.Length then + failwith "Arrays have different lengths" + + for i = 0 to array1.Length - 1 do + acc <- folder i acc array1.[i] array2.[i] + + acc + +let fold2<'T1, 'T2, 'State> folder (state: 'State) (array1: 'T1[]) (array2: 'T2[]) = + foldIndexed2 (fun _ acc x y -> folder acc x y) state array1 array2 + +let foldBackIndexed2<'T1, 'T2, 'State> folder (array1: 'T1[]) (array2: 'T2[]) (state: 'State) = + let mutable acc = state + + if array1.Length <> array2.Length then + differentLengths () + + let size = array1.Length + + for i = 1 to size do + acc <- folder (i - 1) array1.[size - i] array2.[size - i] acc + + acc + +let foldBack2<'T1, 'T2, 'State> f (array1: 'T1[]) (array2: 'T2[]) (state: 'State) = + foldBackIndexed2 (fun _ x y acc -> f x y acc) array1 array2 state + +let reduce reduction (array: 'T[]) = + if array.Length = 0 then + invalidOp LanguagePrimitives.ErrorStrings.InputArrayEmptyString + // if isTypedArrayImpl array then + // foldIndexed (fun i acc x -> if i = 0 then x else reduction acc x) Unchecked.defaultof<_> array + // else + reduceImpl reduction array + +let reduceBack reduction (array: 'T[]) = + if array.Length = 0 then + invalidOp LanguagePrimitives.ErrorStrings.InputArrayEmptyString + // if isTypedArrayImpl array then + // foldBackIndexed (fun i x acc -> if i = 0 then x else reduction acc x) array Unchecked.defaultof<_> + // else + reduceBackImpl reduction array + +let forAll2 predicate array1 array2 = + fold2 (fun acc x y -> acc && predicate x y) true array1 array2 + +let rec existsOffset predicate (array: 'T[]) index = + if index = array.Length then + false + else + predicate array.[index] || existsOffset predicate array (index + 1) + +let exists predicate array = existsOffset predicate array 0 + +let rec existsOffset2 predicate (array1: _[]) (array2: _[]) index = + if index = array1.Length then + false + else + predicate array1.[index] array2.[index] + || existsOffset2 predicate array1 array2 (index + 1) + +let rec exists2 predicate (array1: _[]) (array2: _[]) = + if array1.Length <> array2.Length then + differentLengths () + + existsOffset2 predicate array1 array2 0 + +let sum (array: 'T[]) ([] adder: IGenericAdder<'T>) : 'T = + let mutable acc = adder.GetZero() + + for i = 0 to array.Length - 1 do + acc <- adder.Add(acc, array.[i]) + + acc + +let sumBy (projection: 'T -> 'T2) (array: 'T[]) ([] adder: IGenericAdder<'T2>) : 'T2 = + let mutable acc = adder.GetZero() + + for i = 0 to array.Length - 1 do + acc <- adder.Add(acc, projection array.[i]) + + acc + +let maxBy (projection: 'a -> 'b) (xs: 'a[]) ([] comparer: IComparer<'b>) : 'a = + reduce + (fun x y -> + if comparer.Compare(projection y, projection x) > 0 then + y + else + x + ) + xs + +let max (xs: 'a[]) ([] comparer: IComparer<'a>) : 'a = + reduce + (fun x y -> + if comparer.Compare(y, x) > 0 then + y + else + x + ) + xs + +let minBy (projection: 'a -> 'b) (xs: 'a[]) ([] comparer: IComparer<'b>) : 'a = + reduce + (fun x y -> + if comparer.Compare(projection y, projection x) > 0 then + x + else + y + ) + xs + +let min (xs: 'a[]) ([] comparer: IComparer<'a>) : 'a = + reduce + (fun x y -> + if comparer.Compare(y, x) > 0 then + x + else + y + ) + xs + +let average (array: 'T[]) ([] averager: IGenericAverager<'T>) : 'T = + if array.Length = 0 then + invalidArg "array" LanguagePrimitives.ErrorStrings.InputArrayEmptyString + + let mutable total = averager.GetZero() + + for i = 0 to array.Length - 1 do + total <- averager.Add(total, array.[i]) + + averager.DivideByInt(total, array.Length) + +let averageBy (projection: 'T -> 'T2) (array: 'T[]) ([] averager: IGenericAverager<'T2>) : 'T2 = + if array.Length = 0 then + invalidArg "array" LanguagePrimitives.ErrorStrings.InputArrayEmptyString + + let mutable total = averager.GetZero() + + for i = 0 to array.Length - 1 do + total <- averager.Add(total, projection array.[i]) + + averager.DivideByInt(total, array.Length) + +// let toList (source: 'T[]) = List.ofArray (see Replacements) + +let windowed (windowSize: int) (source: 'T[]) : 'T[][] = + if windowSize <= 0 then + failwith "windowSize must be positive" + + let res = + FSharp.Core.Operators.max 0 (source.Length - windowSize + 1) |> allocateArray + + for i = windowSize to source.Length do + res.[i - windowSize] <- source.[i - windowSize .. i - 1] + + res + +let splitInto (chunks: int) (array: 'T[]) : 'T[][] = + if chunks < 1 then + invalidArg "chunks" "The input must be positive." + + if array.Length = 0 then + [| [||] |] + else + let result: 'T[][] = [||] + let chunks = FSharp.Core.Operators.min chunks array.Length + let minChunkSize = array.Length / chunks + let chunksWithExtraItem = array.Length % chunks + + for i = 0 to chunks - 1 do + let chunkSize = + if i < chunksWithExtraItem then + minChunkSize + 1 + else + minChunkSize + + let start = i * minChunkSize + (FSharp.Core.Operators.min chunksWithExtraItem i) + let slice = subArrayImpl array start chunkSize + pushImpl result slice |> ignore + + result + +let transpose (arrays: 'T[] seq) ([] cons: Cons<'T>) : 'T[][] = + let arrays = + if isDynamicArrayImpl arrays then + arrays :?> 'T[][] // avoid extra copy + else + arrayFrom arrays + + let len = arrays.Length + + match len with + | 0 -> allocateArray 0 + | _ -> + let firstArray = arrays.[0] + let lenInner = firstArray.Length + + if arrays |> forAll (fun a -> a.Length = lenInner) |> not then + differentLengths () + + let result: 'T[][] = allocateArray lenInner + + for i in 0 .. lenInner - 1 do + result.[i] <- allocateArrayFromCons cons len + + for j in 0 .. len - 1 do + result.[i].[j] <- arrays.[j].[i] + + result + +let insertAt (index: int) (y: 'T) (xs: 'T[]) ([] cons: Cons<'T>) : 'T[] = + let len = xs.Length + + if index < 0 || index > len then + invalidArg "index" SR.indexOutOfBounds + + let target = allocateArrayFromCons cons (len + 1) + + for i = 0 to (index - 1) do + target.[i] <- xs.[i] + + target.[index] <- y + + for i = index to (len - 1) do + target.[i + 1] <- xs.[i] + + target + +let insertManyAt (index: int) (ys: seq<'T>) (xs: 'T[]) ([] cons: Cons<'T>) : 'T[] = + let len = xs.Length + + if index < 0 || index > len then + invalidArg "index" SR.indexOutOfBounds + + let ys = arrayFrom ys + let len2 = ys.Length + let target = allocateArrayFromCons cons (len + len2) + + for i = 0 to (index - 1) do + target.[i] <- xs.[i] + + for i = 0 to (len2 - 1) do + target.[index + i] <- ys.[i] + + for i = index to (len - 1) do + target.[i + len2] <- xs.[i] + + target + +let removeAt (index: int) (xs: 'T[]) : 'T[] = + if index < 0 || index >= xs.Length then + invalidArg "index" SR.indexOutOfBounds + + let mutable i = -1 + + xs + |> filter (fun _ -> + i <- i + 1 + i <> index + ) + +let removeManyAt (index: int) (count: int) (xs: 'T[]) : 'T[] = + let mutable i = -1 + // incomplete -1, in-progress 0, complete 1 + let mutable status = -1 + + let ys = + xs + |> filter (fun _ -> + i <- i + 1 + + if i = index then + status <- 0 + false + elif i > index then + if i < index + count then + false + else + status <- 1 + true + else + true + ) + + let status = + if status = 0 && i + 1 = index + count then + 1 + else + status + + if status < 1 then + // F# always says the wrong parameter is index but the problem may be count + let arg = + if status < 0 then + "index" + else + "count" + + invalidArg arg SR.indexOutOfBounds + + ys + +let updateAt (index: int) (y: 'T) (xs: 'T[]) ([] cons: Cons<'T>) : 'T[] = + let len = xs.Length + + if index < 0 || index >= len then + invalidArg "index" SR.indexOutOfBounds + + let target = allocateArrayFromCons cons len + + for i = 0 to (len - 1) do + target.[i] <- + if i = index then + y + else + xs.[i] + + target diff --git a/src/fable-library-lua/Choice.fs b/src/fable-library-lua/Choice.fs new file mode 100644 index 0000000000..4d6ae677c0 --- /dev/null +++ b/src/fable-library-lua/Choice.fs @@ -0,0 +1,85 @@ +namespace FSharp.Core + +[] +type Result<'T, 'TError> = + | Ok of ResultValue: 'T + | Error of ErrorValue: 'TError + +module Result = + [] + let map (mapping: 'a -> 'b) (result: Result<'a, 'c>) : Result<'b, 'c> = + match result with + | Error e -> Error e + | Ok x -> Ok(mapping x) + + [] + let mapError (mapping: 'a -> 'b) (result: Result<'c, 'a>) : Result<'c, 'b> = + match result with + | Error e -> Error(mapping e) + | Ok x -> Ok x + + [] + let bind (binder: 'a -> Result<'b, 'c>) (result: Result<'a, 'c>) : Result<'b, 'c> = + match result with + | Error e -> Error e + | Ok x -> binder x + +[] +type Choice<'T1, 'T2> = + | Choice1Of2 of 'T1 + | Choice2Of2 of 'T2 + +[] +type Choice<'T1, 'T2, 'T3> = + | Choice1Of3 of 'T1 + | Choice2Of3 of 'T2 + | Choice3Of3 of 'T3 + +[] +type Choice<'T1, 'T2, 'T3, 'T4> = + | Choice1Of4 of 'T1 + | Choice2Of4 of 'T2 + | Choice3Of4 of 'T3 + | Choice4Of4 of 'T4 + +[] +type Choice<'T1, 'T2, 'T3, 'T4, 'T5> = + | Choice1Of5 of 'T1 + | Choice2Of5 of 'T2 + | Choice3Of5 of 'T3 + | Choice4Of5 of 'T4 + | Choice5Of5 of 'T5 + +[] +type Choice<'T1, 'T2, 'T3, 'T4, 'T5, 'T6> = + | Choice1Of6 of 'T1 + | Choice2Of6 of 'T2 + | Choice3Of6 of 'T3 + | Choice4Of6 of 'T4 + | Choice5Of6 of 'T5 + | Choice6Of6 of 'T6 + +[] +type Choice<'T1, 'T2, 'T3, 'T4, 'T5, 'T6, 'T7> = + | Choice1Of7 of 'T1 + | Choice2Of7 of 'T2 + | Choice3Of7 of 'T3 + | Choice4Of7 of 'T4 + | Choice5Of7 of 'T5 + | Choice6Of7 of 'T6 + | Choice7Of7 of 'T7 + +module Choice = + let makeChoice1Of2 (x: 'T1) : Choice<'T1, 'a> = Choice1Of2 x + + let makeChoice2Of2 (x: 'T2) : Choice<'a, 'T2> = Choice2Of2 x + + let tryValueIfChoice1Of2 (x: Choice<'T1, 'T2>) : Option<'T1> = + match x with + | Choice1Of2 x -> Some x + | _ -> None + + let tryValueIfChoice2Of2 (x: Choice<'T1, 'T2>) : Option<'T2> = + match x with + | Choice2Of2 x -> Some x + | _ -> None diff --git a/src/fable-library-lua/Fable.Library.fsproj b/src/fable-library-lua/Fable.Library.fsproj new file mode 100644 index 0000000000..0858de5d60 --- /dev/null +++ b/src/fable-library-lua/Fable.Library.fsproj @@ -0,0 +1,42 @@ + + + + net8.0 + $(DefineConstants);FABLE_COMPILER + $(DefineConstants);FX_NO_BIGINT + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/fable-library-lua/Global.fs b/src/fable-library-lua/Global.fs new file mode 100644 index 0000000000..acffeabc1c --- /dev/null +++ b/src/fable-library-lua/Global.fs @@ -0,0 +1,34 @@ +namespace Fable.Core + +type IGenericAdder<'T> = + abstract GetZero: unit -> 'T + abstract Add: 'T * 'T -> 'T + +type IGenericAverager<'T> = + abstract GetZero: unit -> 'T + abstract Add: 'T * 'T -> 'T + abstract DivideByInt: 'T * int -> 'T + +type Symbol_wellknown = + abstract ``Symbol.toStringTag``: string + +type IJsonSerializable = + abstract toJSON: unit -> obj + +namespace global + +[] +module SR = + let indexOutOfBounds = + "The index was outside the range of elements in the collection." + + let inputWasEmpty = "Collection was empty." + let inputMustBeNonNegative = "The input must be non-negative." + let inputSequenceEmpty = "The input sequence was empty." + let inputSequenceTooLong = "The input sequence contains more than one element." + + let keyNotFoundAlt = + "An index satisfying the predicate was not found in the collection." + + let differentLengths = "The collections had different lengths." + let notEnoughElements = "The input sequence has an insufficient number of elements." diff --git a/src/fable-library-lua/Native.fs b/src/fable-library-lua/Native.fs new file mode 100644 index 0000000000..dcfcb141a8 --- /dev/null +++ b/src/fable-library-lua/Native.fs @@ -0,0 +1,112 @@ +module Native + +// Disables warn:1204 raised by use of LanguagePrimitives.ErrorStrings.* +#nowarn "1204" + +open System.Collections.Generic +open Fable.Core +open Fable.Core.JsInterop +open Fable.Core.LuaInterop + +[] +type Cons<'T> = + [] + abstract Allocate: len: int -> 'T[] + +module Helpers = + [] + let arrayFrom (xs: 'T seq) : 'T[] = nativeOnly + + [] + let allocateArray (len: int) : 'T[] = nativeOnly + + [] + let allocateArrayFrom (xs: 'T[]) (len: int) : 'T[] = nativeOnly + + let allocateArrayFromCons (cons: Cons<'T>) (len: int) : 'T[] = + if isNull cons then + Lua.Array.Create(len) + else + cons.Allocate(len) + + let inline isDynamicArrayImpl arr = Lua.Array.isArray arr + + // let inline typedArraySetImpl (target: obj) (source: obj) (offset: int): unit = + // !!target?set(source, offset) + + [] + let concatImpl (array1: 'T[]) (arrays: 'T[] seq) : 'T[] = nativeOnly + + let fillImpl (array: 'T[]) (value: 'T) (start: int) (count: int) : 'T[] = + for i = 0 to count - 1 do + array.[i + start] <- value + + array + + [] + let foldImpl (folder: 'State -> 'T -> 'State) (state: 'State) (array: 'T[]) : 'State = nativeOnly + + let inline foldIndexedImpl (folder: 'State -> 'T -> int -> 'State) (state: 'State) (array: 'T[]) : 'State = + !! array?reduce (System.Func<'State, 'T, int, 'State>(folder), state) + + let inline foldBackImpl (folder: 'State -> 'T -> 'State) (state: 'State) (array: 'T[]) : 'State = + !! array?reduceRight (System.Func<'State, 'T, 'State>(folder), state) + + let inline foldBackIndexedImpl (folder: 'State -> 'T -> int -> 'State) (state: 'State) (array: 'T[]) : 'State = + !! array?reduceRight (System.Func<'State, 'T, int, 'State>(folder), state) + + // Typed arrays not supported, only dynamic ones do + let inline pushImpl (array: 'T[]) (item: 'T) : int = !! array?append (item) + + // Typed arrays not supported, only dynamic ones do + let inline insertImpl (array: 'T[]) (index: int) (item: 'T) : 'T[] = !! array?splice (index, 0, item) + + // Typed arrays not supported, only dynamic ones do + let inline spliceImpl (array: 'T[]) (start: int) (deleteCount: int) : 'T[] = !! array?splice (start, deleteCount) + + [] + let reverseImpl (array: 'T[]) : 'T[] = nativeOnly + + [] + let copyImpl (array: 'T[]) : 'T[] = nativeOnly + + [] + let skipImpl (array: 'T[]) (count: int) : 'T[] = nativeOnly + //__TS__ArraySplice + [] + let subArrayImpl (array: 'T[]) (start: int) (count: int) : 'T[] = nativeOnly + + let inline indexOfImpl (array: 'T[]) (item: 'T) (start: int) : int = !! array?indexOf (item, start) + + let inline findImpl (predicate: 'T -> bool) (array: 'T[]) : 'T option = !! array?find (predicate) + + let inline findIndexImpl (predicate: 'T -> bool) (array: 'T[]) : int = !! array?findIndex (predicate) + + let inline collectImpl (mapping: 'T -> 'U[]) (array: 'T[]) : 'U[] = !! array?flatMap (mapping) + + let inline containsImpl (predicate: 'T -> bool) (array: 'T[]) : bool = !! array?filter (predicate) + + let inline existsImpl (predicate: 'T -> bool) (array: 'T[]) : bool = !! array?some (predicate) + + let inline forAllImpl (predicate: 'T -> bool) (array: 'T[]) : bool = !! array?every (predicate) + + let inline filterImpl (predicate: 'T -> bool) (array: 'T[]) : 'T[] = !! array?filter (predicate) + + [] + let reduceImpl (reduction: 'T -> 'T -> 'T) (array: 'T[]) : 'T = nativeOnly + + let inline reduceBackImpl (reduction: 'T -> 'T -> 'T) (array: 'T[]) : 'T = !! array?reduceRight (reduction) + + // Inlining in combination with dynamic application may cause problems with uncurrying + // Using Emit keeps the argument signature. Note: Python cannot take an argument here. + [] + let sortInPlaceWithImpl (comparer: 'T -> 'T -> int) (array: 'T[]) : unit = nativeOnly //!!array?sort(comparer) + + [] + let copyToTypedArray (src: 'T[]) (srci: int) (trg: 'T[]) (trgi: int) (cnt: int) : unit = nativeOnly diff --git a/src/fable-library-lua/README.md b/src/fable-library-lua/README.md new file mode 100644 index 0000000000..b472dd39ea --- /dev/null +++ b/src/fable-library-lua/README.md @@ -0,0 +1,13 @@ +# Fable Library for Lua + +This module is used as the [Fable](https://fable.io/) library for +Lua. + +On windows, testing was done against lua 5.2.4, which you can get through chocolatey: + +https://community.chocolatey.org/packages/lua52 +or direct, although it needs to be in the path: +http://luabinaries.sourceforge.net/download.html + + +choco install lua52 \ No newline at end of file diff --git a/src/fable-library-lua/Timer.fs b/src/fable-library-lua/Timer.fs new file mode 100644 index 0000000000..87f27724d2 --- /dev/null +++ b/src/fable-library-lua/Timer.fs @@ -0,0 +1,26 @@ +module Timer + +open Fable.Core + +/// This class represents an action that should be run only after a +/// certain amount of time has passed — a timer. Timer is a subclass of +/// Thread and as such also functions as an example of creating custom +/// threads. +type ITimer = + abstract daemon: bool with get, set + + /// Start the thread’s activity. + abstract start: unit -> unit + /// Stop the timer, and cancel the execution of the timer’s action. + /// This will only work if the timer is still in its waiting stage. + abstract cancel: unit -> unit + + /// Create a timer that will run function with arguments args and + /// keyword arguments kwargs, after interval seconds have passed. If + /// args is None (the default) then an empty list will be used. If + /// kwargs is None (the default) then an empty dict will be used. + [] + abstract Create: float * (unit -> unit) -> ITimer + +[] +let Timer: ITimer = nativeOnly diff --git a/src/fable-library-lua/Util.lua b/src/fable-library-lua/Util.lua new file mode 100644 index 0000000000..6f238be622 --- /dev/null +++ b/src/fable-library-lua/Util.lua @@ -0,0 +1,1630 @@ +-- mod = {} + +-- -- https://web.archive.org/web/20131225070434/http://snippets.luacode.org/snippets/Deep_Comparison_of_Two_Values_3 +-- function deepcompare(t1,t2,ignore_mt) +-- local ty1 = type(t1) +-- local ty2 = type(t2) +-- if ty1 ~= ty2 then return false end +-- -- non-table types can be directly compared +-- if ty1 ~= 'table' and ty2 ~= 'table' then return t1 == t2 end +-- -- as well as tables which have the metamethod __eq +-- local mt = getmetatable(t1) +-- if not ignore_mt and mt and mt.__eq then return t1 == t2 end +-- for k1,v1 in pairs(t1) do +-- local v2 = t2[k1] +-- if v2 == nil or not deepcompare(v1,v2) then return false end +-- end +-- for k2,v2 in pairs(t2) do +-- local v1 = t1[k2] +-- if v1 == nil or not deepcompare(v1,v2) then return false end +-- end +-- return true +-- end + +-- function mod.equals(a, b) +-- return deepcompare(a, b, true) +-- end +-- return mod + +function TableConcat(t1,t2) + for i=1,#t2 do + t1[#t1+1] = t2[i] + end + return t1 +end + +function table.create(len) + local a = {} + for i = 1, len do + a[i] = 0 + end + return a +end + +function table.slice(tbl, first, last) + local sliced = {} + + for i = first or 1, last or #tbl do + sliced[#sliced+1] = tbl[i] + end + + return sliced +end + +function table.reverse(tbl) + local reved = {} + for i = #tbl, 1, -1 do + reved[#reved+1] = tbl[i] + end + return reved +end + +function table.shallow_copy(t) + local t2 = {} + for k,v in pairs(t) do + t2[k] = v + end + return t2 +end +-- https://stackoverflow.com/questions/5977654/how-do-i-use-the-bitwise-operator-xor-in-lua +local function BitXOR(a,b)--Bitwise xor + local p,c=1,0 + while a>0 and b>0 do + local ra,rb=a%2,b%2 + if ra~=rb then c=c+p end + a,b,p=(a-ra)/2,(b-rb)/2,p*2 + end + if a0 do + local ra=a%2 + if ra>0 then c=c+p end + a,p=(a-ra)/2,p*2 + end + return c +end + +local function BitOR(a,b)--Bitwise or + local p,c=1,0 + while a+b>0 do + local ra,rb=a%2,b%2 + if ra+rb>0 then c=c+p end + a,b,p=(a-ra)/2,(b-rb)/2,p*2 + end + return c +end + +local function BitAND(a,b)--Bitwise and + local p,c=1,0 + while a>0 and b>0 do + local ra,rb=a%2,b%2 + if ra+rb>1 then c=c+p end + a,b,p=(a-ra)/2,(b-rb)/2,p*2 + end + return c +end + +function lshift(x, by) + return x * 2 ^ by + end + + function rshift(x, by) + return math.floor(x / 2 ^ by) + end + + +--[[ Generated with https://github.com/TypeScriptToLua/TypeScriptToLua ]] +-- Lua Library inline imports +____symbolMetatable = { + __tostring = function(self) + return ("Symbol(" .. (self.description or "")) .. ")" + end +} +function __TS__Symbol(description) + return setmetatable({description = description}, ____symbolMetatable) +end +Symbol = { + iterator = __TS__Symbol("Symbol.iterator"), + hasInstance = __TS__Symbol("Symbol.hasInstance"), + species = __TS__Symbol("Symbol.species"), + toStringTag = __TS__Symbol("Symbol.toStringTag") +} + +function __TS__ArrayIsArray(value) + return (type(value) == "table") and ((value[1] ~= nil) or (next(value, nil) == nil)) +end + +function __TS__Class(self) + local c = {prototype = {}} + c.prototype.__index = c.prototype + c.prototype.constructor = c + return c +end + +function __TS__ClassExtends(target, base) + target.____super = base + local staticMetatable = setmetatable({__index = base}, base) + setmetatable(target, staticMetatable) + local baseMetatable = getmetatable(base) + if baseMetatable then + if type(baseMetatable.__index) == "function" then + staticMetatable.__index = baseMetatable.__index + end + if type(baseMetatable.__newindex) == "function" then + staticMetatable.__newindex = baseMetatable.__newindex + end + end + setmetatable(target.prototype, base.prototype) + if type(base.prototype.__index) == "function" then + target.prototype.__index = base.prototype.__index + end + if type(base.prototype.__newindex) == "function" then + target.prototype.__newindex = base.prototype.__newindex + end + if type(base.prototype.__tostring) == "function" then + target.prototype.__tostring = base.prototype.__tostring + end +end + +function __TS__New(target, ...) + local instance = setmetatable({}, target.prototype) + instance:____constructor(...) + return instance +end + +function __TS__GetErrorStack(self, constructor) + local level = 1 + while true do + local info = debug.getinfo(level, "f") + level = level + 1 + if not info then + level = 1 + break + elseif info.func == constructor then + break + end + end + return debug.traceback(nil, level) +end +function __TS__WrapErrorToString(self, getDescription) + return function(self) + local description = getDescription(self) + local caller = debug.getinfo(3, "f") + if (_VERSION == "Lua 5.1") or (caller and (caller.func ~= error)) then + return description + else + return (tostring(description) .. "\n") .. self.stack + end + end +end +function __TS__InitErrorClass(self, Type, name) + Type.name = name + return setmetatable( + Type, + { + __call = function(____, _self, message) return __TS__New(Type, message) end + } + ) +end +Error = __TS__InitErrorClass( + _G, + (function() + local ____ = __TS__Class() + ____.name = "" + function ____.prototype.____constructor(self, message) + if message == nil then + message = "" + end + self.message = message + self.name = "Error" + self.stack = __TS__GetErrorStack(_G, self.constructor.new) + local metatable = getmetatable(self) + if not metatable.__errorToStringPatched then + metatable.__errorToStringPatched = true + metatable.__tostring = __TS__WrapErrorToString(_G, metatable.__tostring) + end + end + function ____.prototype.__tostring(self) + return (((self.message ~= "") and (function() return (self.name .. ": ") .. self.message end)) or (function() return self.name end))() + end + return ____ + end)(), + "Error" +) +for ____, errorName in ipairs({"RangeError", "ReferenceError", "SyntaxError", "TypeError", "URIError"}) do + _G[errorName] = __TS__InitErrorClass( + _G, + (function() + local ____ = __TS__Class() + ____.name = ____.name + __TS__ClassExtends(____, Error) + function ____.prototype.____constructor(self, ...) + Error.prototype.____constructor(self, ...) + self.name = errorName + end + return ____ + end)(), + errorName + ) +end + +function __TS__ObjectAssign(to, ...) + local sources = {...} + if to == nil then + return to + end + for ____, source in ipairs(sources) do + for key in pairs(source) do + to[key] = source[key] + end + end + return to +end + +function __TS__CloneDescriptor(____bindingPattern0) + local enumerable + enumerable = ____bindingPattern0.enumerable + local configurable + configurable = ____bindingPattern0.configurable + local get + get = ____bindingPattern0.get + local set + set = ____bindingPattern0.set + local writable + writable = ____bindingPattern0.writable + local value + value = ____bindingPattern0.value + local descriptor = {enumerable = enumerable == true, configurable = configurable == true} + local hasGetterOrSetter = (get ~= nil) or (set ~= nil) + local hasValueOrWritableAttribute = (writable ~= nil) or (value ~= nil) + if hasGetterOrSetter and hasValueOrWritableAttribute then + error("Invalid property descriptor. Cannot both specify accessors and a value or writable attribute.", 0) + end + if get or set then + descriptor.get = get + descriptor.set = set + else + descriptor.value = value + descriptor.writable = writable == true + end + return descriptor +end + +function ____descriptorIndex(self, key) + local value = rawget(self, key) + if value ~= nil then + return value + end + local metatable = getmetatable(self) + while metatable do + local rawResult = rawget(metatable, key) + if rawResult ~= nil then + return rawResult + end + local descriptors = rawget(metatable, "_descriptors") + if descriptors then + local descriptor = descriptors[key] + if descriptor then + if descriptor.get then + return descriptor.get(self) + end + return descriptor.value + end + end + metatable = getmetatable(metatable) + end +end +function ____descriptorNewindex(self, key, value) + local metatable = getmetatable(self) + while metatable do + local descriptors = rawget(metatable, "_descriptors") + if descriptors then + local descriptor = descriptors[key] + if descriptor then + if descriptor.set then + descriptor.set(self, value) + else + if descriptor.writable == false then + error( + ((("Cannot assign to read only property '" .. key) .. "' of object '") .. tostring(self)) .. "'", + 0 + ) + end + descriptor.value = value + end + return + end + end + metatable = getmetatable(metatable) + end + rawset(self, key, value) +end +function __TS__SetDescriptor(target, key, desc, isPrototype) + if isPrototype == nil then + isPrototype = false + end + local metatable = ((isPrototype and (function() return target end)) or (function() return getmetatable(target) end))() + if not metatable then + metatable = {} + setmetatable(target, metatable) + end + local value = rawget(target, key) + if value ~= nil then + rawset(target, key, nil) + end + if not rawget(metatable, "_descriptors") then + metatable._descriptors = {} + end + local descriptor = __TS__CloneDescriptor(desc) + metatable._descriptors[key] = descriptor + metatable.__index = ____descriptorIndex + metatable.__newindex = ____descriptorNewindex +end + +function __TS__StringAccess(self, index) + if (index >= 0) and (index < #self) then + return string.sub(self, index + 1, index + 1) + end +end + +____radixChars = "0123456789abcdefghijklmnopqrstuvwxyz" +function __TS__NumberToString(self, radix) + if ((((radix == nil) or (radix == 10)) or (self == math.huge)) or (self == -math.huge)) or (self ~= self) then + return tostring(self) + end + radix = math.floor(radix) + if (radix < 2) or (radix > 36) then + error("toString() radix argument must be between 2 and 36", 0) + end + local integer, fraction = math.modf( + math.abs(self) + ) + local result = "" + if radix == 8 then + result = string.format("%o", integer) + elseif radix == 16 then + result = string.format("%x", integer) + else + repeat + do + result = __TS__StringAccess(____radixChars, integer % radix) .. result + integer = math.floor(integer / radix) + end + until not (integer ~= 0) + end + if fraction ~= 0 then + result = result .. "." + local delta = 1e-16 + repeat + do + fraction = fraction * radix + delta = delta * radix + local digit = math.floor(fraction) + result = result .. __TS__StringAccess(____radixChars, digit) + fraction = fraction - digit + end + until not (fraction >= delta) + end + if self < 0 then + result = "-" .. result + end + return result +end + +function __TS__InstanceOf(obj, classTbl) + if type(classTbl) ~= "table" then + error("Right-hand side of 'instanceof' is not an object", 0) + end + if classTbl[Symbol.hasInstance] ~= nil then + return not (not classTbl[Symbol.hasInstance](classTbl, obj)) + end + if type(obj) == "table" then + local luaClass = obj.constructor + while luaClass ~= nil do + if luaClass == classTbl then + return true + end + luaClass = luaClass.____super + end + end + return false +end + +function __TS__IteratorGeneratorStep(self) + local co = self.____coroutine + local status, value = coroutine.resume(co) + if not status then + error(value, 0) + end + if coroutine.status(co) == "dead" then + return + end + return true, value +end +function __TS__IteratorIteratorStep(self) + local result = self:next() + if result.done then + return + end + return true, result.value +end +function __TS__IteratorStringStep(self, index) + index = index + 1 + if index > #self then + return + end + return index, string.sub(self, index, index) +end +function __TS__Iterator(iterable) + if type(iterable) == "string" then + return __TS__IteratorStringStep, iterable, 0 + elseif iterable.____coroutine ~= nil then + return __TS__IteratorGeneratorStep, iterable + elseif iterable[Symbol.iterator] then + local iterator = iterable[Symbol.iterator](iterable) + return __TS__IteratorIteratorStep, iterator + else + return ipairs(iterable) + end +end + +WeakMap = (function() + local WeakMap = __TS__Class() + WeakMap.name = "WeakMap" + function WeakMap.prototype.____constructor(self, entries) + self[Symbol.toStringTag] = "WeakMap" + self.items = {} + setmetatable(self.items, {__mode = "k"}) + if entries == nil then + return + end + local iterable = entries + if iterable[Symbol.iterator] then + local iterator = iterable[Symbol.iterator](iterable) + while true do + local result = iterator:next() + if result.done then + break + end + local value = result.value + self.items[value[1]] = value[2] + end + else + for ____, kvp in ipairs(entries) do + self.items[kvp[1]] = kvp[2] + end + end + end + function WeakMap.prototype.delete(self, key) + local contains = self:has(key) + self.items[key] = nil + return contains + end + function WeakMap.prototype.get(self, key) + return self.items[key] + end + function WeakMap.prototype.has(self, key) + return self.items[key] ~= nil + end + function WeakMap.prototype.set(self, key, value) + self.items[key] = value + return self + end + WeakMap[Symbol.species] = WeakMap + return WeakMap +end)() + +function __TS__StringCharCodeAt(self, index) + if index ~= index then + index = 0 + end + if index < 0 then + return 0 / 0 + end + return string.byte(self, index + 1) or (0 / 0) +end + +function __TS__ArrayReduce(arr, callbackFn, ...) + local len = #arr + local k = 0 + local accumulator = nil + if select("#", ...) ~= 0 then + accumulator = select(1, ...) + elseif len > 0 then + accumulator = arr[1] + k = 1 + else + error("Reduce of empty array with no initial value", 0) + end + for i = k, len - 1 do + accumulator = callbackFn(_G, accumulator, arr[i + 1], i, arr) + end + return accumulator +end + +function __TS__TypeOf(value) + local luaType = type(value) + if luaType == "table" then + return "object" + elseif luaType == "nil" then + return "undefined" + else + return luaType + end +end + +function __TS__ObjectValues(obj) + local result = {} + for key in pairs(obj) do + result[#result + 1] = obj[key] + end + return result +end + +function __TS__ArrayMap(arr, callbackfn) + local newArray = {} + do + local i = 0 + while i < #arr do + newArray[i + 1] = callbackfn(_G, arr[i + 1], i, arr) + i = i + 1 + end + end + return newArray +end + +function __TS__ObjectKeys(obj) + local result = {} + for key in pairs(obj) do + result[#result + 1] = key + end + return result +end + +function __TS__ArraySort(arr, compareFn) + if compareFn ~= nil then + table.sort( + arr, + function(a, b) return compareFn(_G, a, b) < 0 end + ) + else + table.sort(arr) + end + return arr +end + +function __TS__StringReplace(source, searchValue, replaceValue) + searchValue = string.gsub(searchValue, "[%%%(%)%.%+%-%*%?%[%^%$]", "%%%1") + if type(replaceValue) == "string" then + replaceValue = string.gsub(replaceValue, "%%", "%%%%") + local result = string.gsub(source, searchValue, replaceValue, 1) + return result + else + local result = string.gsub( + source, + searchValue, + function(match) return replaceValue(_G, match) end, + 1 + ) + return result + end +end + +function __TS__ArraySplice(list, ...) + local len = #list + local actualArgumentCount = select("#", ...) + local start = select(1, ...) + local deleteCount = select(2, ...) + local actualStart + if start < 0 then + actualStart = math.max(len + start, 0) + else + actualStart = math.min(start, len) + end + local itemCount = math.max(actualArgumentCount - 2, 0) + local actualDeleteCount + if actualArgumentCount == 0 then + actualDeleteCount = 0 + elseif actualArgumentCount == 1 then + actualDeleteCount = len - actualStart + else + actualDeleteCount = math.min( + math.max(deleteCount or 0, 0), + len - actualStart + ) + end + local out = {} + do + local k = 0 + while k < actualDeleteCount do + local from = actualStart + k + if list[from + 1] then + out[k + 1] = list[from + 1] + end + k = k + 1 + end + end + if itemCount < actualDeleteCount then + do + local k = actualStart + while k < (len - actualDeleteCount) do + local from = k + actualDeleteCount + local to = k + itemCount + if list[from + 1] then + list[to + 1] = list[from + 1] + else + list[to + 1] = nil + end + k = k + 1 + end + end + do + local k = len + while k > ((len - actualDeleteCount) + itemCount) do + list[k] = nil + k = k - 1 + end + end + elseif itemCount > actualDeleteCount then + do + local k = len - actualDeleteCount + while k > actualStart do + local from = (k + actualDeleteCount) - 1 + local to = (k + itemCount) - 1 + if list[from + 1] then + list[to + 1] = list[from + 1] + else + list[to + 1] = nil + end + k = k - 1 + end + end + end + local j = actualStart + for i = 3, actualArgumentCount do + list[j + 1] = select(i, ...) + j = j + 1 + end + do + local k = #list - 1 + while k >= ((len - actualDeleteCount) + itemCount) do + list[k + 1] = nil + k = k - 1 + end + end + return out +end + +function __TS__ArrayConcat(arr1, ...) + local args = {...} + local out = {} + for ____, val in ipairs(arr1) do + out[#out + 1] = val + end + for ____, arg in ipairs(args) do + if __TS__ArrayIsArray(arg) then + local argAsArray = arg + for ____, val in ipairs(argAsArray) do + out[#out + 1] = val + end + else + out[#out + 1] = arg + end + end + return out +end + +local ____exports = {} +local isComparable, isEquatable, isHashable, equalObjects, compareObjects +function ____exports.isArrayLike(self, x) + return __TS__ArrayIsArray(x) or ArrayBuffer:isView(x) +end +function isComparable(self, x) + return type(x.CompareTo) == "function" +end +function isEquatable(self, x) + return type(x.Equals) == "function" +end +function isHashable(self, x) + return type(x.GetHashCode) == "function" +end +function ____exports.dateOffset(self, date) + local date1 = date + return ((type(date1.offset) == "number") and date1.offset) or (((date.kind == 1) and 0) or (date:getTimezoneOffset() * -60000)) +end +function ____exports.stringHash(self, s) + local i = 0 + local h = 5381 + local len = #s + while i < len do + h = BitXOR((h * 33), __TS__StringCharCodeAt( + s, + (function() + local ____tmp = i + i = ____tmp + 1 + return ____tmp + end)() + )) + end + return h +end +function ____exports.numberHash(self, x) + return BitOR((x * 2654435761), 0) +end +function ____exports.combineHashCodes(self, hashes) + if #hashes == 0 then + return 0 + end + return __TS__ArrayReduce( + hashes, + function(____, h1, h2) + return BitXOR((lshift(h1, 5) + h1), h2) + end + ) +end +function ____exports.dateHash(self, x) + return x:getTime() +end +function ____exports.arrayHash(self, x) + local len = x.length + local hashes = __TS__New(Array, len) + do + local i = 0 + while i < len do + hashes[i + 1] = ____exports.structuralHash(nil, x[i]) + i = i + 1 + end + end + return ____exports.combineHashCodes(nil, hashes) +end +function ____exports.structuralHash(self, x) + if x == nil then + return 0 + end + local ____switch68 = __TS__TypeOf(x) + if ____switch68 == "boolean" then + goto ____switch68_case_0 + elseif ____switch68 == "number" then + goto ____switch68_case_1 + elseif ____switch68 == "string" then + goto ____switch68_case_2 + end + goto ____switch68_case_default + ::____switch68_case_0:: + do + return (x and 1) or 0 + end + ::____switch68_case_1:: + do + return ____exports.numberHash(nil, x) + end + ::____switch68_case_2:: + do + return ____exports.stringHash(nil, x) + end + ::____switch68_case_default:: + do + do + if isHashable(nil, x) then + return x:GetHashCode() + elseif ____exports.isArrayLike(nil, x) then + return ____exports.arrayHash(nil, x) + elseif __TS__InstanceOf(x, Date) then + return ____exports.dateHash(nil, x) + elseif Object:getPrototypeOf(x).constructor == Object then + local hashes = __TS__ArrayMap( + __TS__ObjectValues(x), + function(____, v) return ____exports.structuralHash(nil, v) end + ) + return ____exports.combineHashCodes(nil, hashes) + else + return ____exports.numberHash( + nil, + ____exports.ObjectRef:id(x) + ) + end + end + end + ::____switch68_end:: +end +function ____exports.equalArraysWith(self, x, y, eq) + if x == nil then + return y == nil + end + if y == nil then + return false + end + if x.length ~= y.length then + return false + end + do + local i = 0 + while i < x.length do + if not eq(nil, x[i], y[i]) then + return false + end + i = i + 1 + end + end + return true +end +function ____exports.equalArrays(self, x, y) + return ____exports.equalArraysWith(nil, x, y, ____exports.equals) +end +function equalObjects(self, x, y) + local xKeys = __TS__ObjectKeys(x) + local yKeys = __TS__ObjectKeys(y) + if #xKeys ~= #yKeys then + return false + end + __TS__ArraySort(xKeys) + __TS__ArraySort(yKeys) + do + local i = 0 + while i < #xKeys do + if (xKeys[i + 1] ~= yKeys[i + 1]) or (not ____exports.equals(nil, x[xKeys[i + 1]], y[yKeys[i + 1]])) then + return false + end + i = i + 1 + end + end + return true +end +function ____exports.equals(self, x, y) + if x == y then + return true + elseif x == nil then + return y == nil + elseif y == nil then + return false + elseif type(x) ~= "table" then + return false + elseif isEquatable(nil, x) then + return x:Equals(y) + elseif ____exports.isArrayLike(nil, x) then + return ____exports.isArrayLike(nil, y) and ____exports.equalArrays(nil, x, y) + elseif __TS__InstanceOf(x, Date) then + return __TS__InstanceOf(y, Date) and (____exports.compareDates(nil, x, y) == 0) + else + return (Object:getPrototypeOf(x).constructor == Object) and equalObjects(nil, x, y) + end +end +function ____exports.compareDates(self, x, y) + local xtime + local ytime + if (x.offset ~= nil) and (y.offset ~= nil) then + xtime = x:getTime() + ytime = y:getTime() + else + xtime = x:getTime() + ____exports.dateOffset(nil, x) + ytime = y:getTime() + ____exports.dateOffset(nil, y) + end + return ((xtime == ytime) and 0) or (((xtime < ytime) and -1) or 1) +end +function ____exports.compareArraysWith(self, x, y, comp) + if x == nil then + return ((y == nil) and 0) or 1 + end + if y == nil then + return -1 + end + if x.length ~= y.length then + return ((x.length < y.length) and -1) or 1 + end + do + local i = 0 + local j = 0 + while i < x.length do + j = comp(nil, x[i], y[i]) + if j ~= 0 then + return j + end + i = i + 1 + end + end + return 0 +end +function ____exports.compareArrays(self, x, y) + return ____exports.compareArraysWith(nil, x, y, ____exports.compare) +end +function compareObjects(self, x, y) + local xKeys = __TS__ObjectKeys(x) + local yKeys = __TS__ObjectKeys(y) + if #xKeys ~= #yKeys then + return ((#xKeys < #yKeys) and -1) or 1 + end + __TS__ArraySort(xKeys) + __TS__ArraySort(yKeys) + do + local i = 0 + local j = 0 + while i < #xKeys do + local key = xKeys[i + 1] + if key ~= yKeys[i + 1] then + return ((key < yKeys[i + 1]) and -1) or 1 + else + j = ____exports.compare(nil, x[key], y[key]) + if j ~= 0 then + return j + end + end + i = i + 1 + end + end + return 0 +end +function ____exports.compare(self, x, y) + if x == y then + return 0 + elseif x == nil then + return ((y == nil) and 0) or -1 + elseif y == nil then + return 1 + elseif type(x) ~= "table" then + return ((x < y) and -1) or 1 + elseif isComparable(nil, x) then + return x:CompareTo(y) + elseif ____exports.isArrayLike(nil, x) then + return (____exports.isArrayLike(nil, y) and ____exports.compareArrays(nil, x, y)) or -1 + elseif __TS__InstanceOf(x, Date) then + return (__TS__InstanceOf(y, Date) and ____exports.compareDates(nil, x, y)) or -1 + else + return ((Object:getPrototypeOf(x).constructor == Object) and compareObjects(nil, x, y)) or -1 + end +end +function ____exports.isIterable(self, x) + return ((x ~= nil) and (type(x) == "table")) and (x[Symbol.iterator] ~= nil) +end +local function isComparer(self, x) + return type(x.Compare) == "function" +end +function ____exports.isDisposable(self, x) + return (x ~= nil) and (type(x.Dispose) == "function") +end +function ____exports.sameConstructor(self, x, y) + return Object:getPrototypeOf(x).constructor == Object:getPrototypeOf(y).constructor +end +____exports.Enumerator = __TS__Class() +local Enumerator = ____exports.Enumerator +Enumerator.name = "Enumerator" +function Enumerator.prototype.____constructor(self, iter) + self.iter = iter +end +Enumerator.prototype["System.Collections.Generic.IEnumerator`1.get_Current"] = function(self) + return self.current +end +Enumerator.prototype["System.Collections.IEnumerator.get_Current"] = function(self) + return self.current +end +Enumerator.prototype["System.Collections.IEnumerator.MoveNext"] = function(self) + local cur = self.iter:next() + self.current = cur.value + return not cur.done +end +Enumerator.prototype["System.Collections.IEnumerator.Reset"] = function(self) + error( + __TS__New(Error, "JS iterators cannot be reset"), + 0 + ) +end +function Enumerator.prototype.Dispose(self) + return +end +function ____exports.getEnumerator(self, o) + return ((type(o.GetEnumerator) == "function") and o:GetEnumerator()) or __TS__New( + ____exports.Enumerator, + o[Symbol.iterator](o) + ) +end +function ____exports.toIterator(self, en) + return { + [Symbol.iterator] = function(self) + return self + end, + next = function(self) + local hasNext = en["System.Collections.IEnumerator.MoveNext"](en) + local current = ((hasNext and (function() return en["System.Collections.IEnumerator.get_Current"](en) end)) or (function() return nil end))() + return {done = not hasNext, value = current} + end + } +end +____exports.Comparer = __TS__Class() +local Comparer = ____exports.Comparer +Comparer.name = "Comparer" +function Comparer.prototype.____constructor(self, f) + self.Compare = f or ____exports.compare +end +function ____exports.comparerFromEqualityComparer(self, comparer) + if isComparer(nil, comparer) then + return __TS__New(____exports.Comparer, comparer.Compare) + else + return __TS__New( + ____exports.Comparer, + function(____, x, y) + local xhash = comparer:GetHashCode(x) + local yhash = comparer:GetHashCode(y) + if xhash == yhash then + return (comparer:Equals(x, y) and 0) or -1 + else + return ((xhash < yhash) and -1) or 1 + end + end + ) + end +end +function ____exports.assertEqual(self, actual, expected, msg) + if not ____exports.equals(nil, actual, expected) then + error( + __TS__ObjectAssign( + __TS__New( + Error, + msg or ((("Expected: " .. tostring(expected)) .. " - Actual: ") .. tostring(actual)) + ), + {actual = actual, expected = expected} + ), + 0 + ) + end +end +function ____exports.assertNotEqual(self, actual, expected, msg) + if ____exports.equals(nil, actual, expected) then + error( + __TS__ObjectAssign( + __TS__New( + Error, + msg or ((("Expected: " .. tostring(expected)) .. " - Actual: ") .. tostring(actual)) + ), + {actual = actual, expected = expected} + ), + 0 + ) + end +end +____exports.Lazy = __TS__Class() +local Lazy = ____exports.Lazy +Lazy.name = "Lazy" +function Lazy.prototype.____constructor(self, factory) + self.factory = factory + self.isValueCreated = false +end +__TS__SetDescriptor( + Lazy.prototype, + "Value", + { + get = function(self) + if not self.isValueCreated then + self.createdValue = self:factory() + self.isValueCreated = true + end + return self.createdValue + end + }, + true +) +__TS__SetDescriptor( + Lazy.prototype, + "IsValueCreated", + { + get = function(self) + return self.isValueCreated + end + }, + true +) +function ____exports.lazyFromValue(self, v) + return __TS__New( + ____exports.Lazy, + function() return v end + ) +end +function ____exports.padWithZeros(self, i, length) + local str = __TS__NumberToString(i, 10) + while #str < length do + str = "0" .. str + end + return str +end +function ____exports.padLeftAndRightWithZeros(self, i, lengthLeft, lengthRight) + local str = __TS__NumberToString(i, 10) + while #str < lengthLeft do + str = "0" .. str + end + while #str < lengthRight do + str = str .. "0" + end + return str +end +function ____exports.int16ToString(self, i, radix) + i = ((((i < 0) and (radix ~= nil)) and (radix ~= 10)) and ((65535 + i) + 1)) or i + return __TS__NumberToString(i, radix) +end +function ____exports.int32ToString(self, i, radix) + i = ((((i < 0) and (radix ~= nil)) and (radix ~= 10)) and ((4294967295 + i) + 1)) or i + return __TS__NumberToString(i, radix) +end +____exports.ObjectRef = __TS__Class() +local ObjectRef = ____exports.ObjectRef +ObjectRef.name = "ObjectRef" +function ObjectRef.prototype.____constructor(self) +end +function ObjectRef.id(self, o) + if not ____exports.ObjectRef.idMap:has(o) then + ____exports.ObjectRef.idMap:set( + o, + (function() + local ____tmp = ____exports.ObjectRef.count + 1 + ____exports.ObjectRef.count = ____tmp + return ____tmp + end)() + ) + end + return ____exports.ObjectRef.idMap:get(o) +end +ObjectRef.idMap = __TS__New(WeakMap) +ObjectRef.count = 0 +function ____exports.physicalHash(self, x) + if x == nil then + return 0 + end + local ____switch58 = __TS__TypeOf(x) + if ____switch58 == "boolean" then + goto ____switch58_case_0 + elseif ____switch58 == "number" then + goto ____switch58_case_1 + elseif ____switch58 == "string" then + goto ____switch58_case_2 + end + goto ____switch58_case_default + ::____switch58_case_0:: + do + return (x and 1) or 0 + end + ::____switch58_case_1:: + do + return ____exports.numberHash(nil, x) + end + ::____switch58_case_2:: + do + return ____exports.stringHash(nil, x) + end + ::____switch58_case_default:: + do + return ____exports.numberHash( + nil, + ____exports.ObjectRef:id(x) + ) + end + ::____switch58_end:: +end +function ____exports.identityHash(self, x) + if x == nil then + return 0 + elseif isHashable(nil, x) then + return x:GetHashCode() + else + return ____exports.physicalHash(nil, x) + end +end +function ____exports.fastStructuralHash(self, x) + return ____exports.stringHash( + nil, + String(nil, x) + ) +end +function ____exports.safeHash(self, x) + return ((x == nil) and 0) or ((isHashable(nil, x) and x:GetHashCode()) or ____exports.numberHash( + nil, + ____exports.ObjectRef:id(x) + )) +end +function ____exports.comparePrimitives(self, x, y) + return ((x == y) and 0) or (((x < y) and -1) or 1) +end +function ____exports.min(self, comparer, x, y) + return ((comparer(nil, x, y) < 0) and x) or y +end +function ____exports.max(self, comparer, x, y) + return ((comparer(nil, x, y) > 0) and x) or y +end +function ____exports.clamp(self, comparer, value, min, max) + return ((comparer(nil, value, min) < 0) and min) or (((comparer(nil, value, max) > 0) and max) or value) +end +function ____exports.createAtom(self, value) + local atom = value + return function(____, value, isSetter) + if not isSetter then + return atom + else + atom = value + return nil + end + end +end +function ____exports.createObj(self, fields) + local obj = {} + for ____, kv in __TS__Iterator(fields) do + obj[kv[1]] = kv[2] + end + return obj +end +function ____exports.jsOptions(self, mutator) + local opts = {} + mutator(nil, opts) + return opts +end +function ____exports.round(self, value, digits) + if digits == nil then + digits = 0 + end + local m = math.pow(10, digits) + local n = ((digits and (value * m)) or value):toFixed(8) + local i = math.floor(n) + local f = n - i + local e = 1e-8 + local r = (((f > (0.5 - e)) and (f < (0.5 + e))) and ((((i % 2) == 0) and i) or (i + 1))) or math.floor(n + 0.5) + return (digits and (r / m)) or r +end +function ____exports.sign(self, x) + return ((x > 0) and 1) or (((x < 0) and -1) or 0) +end +function ____exports.randomNext(self, min, max) + return math.floor( + math.random() * (max - min) + ) + min +end +function ____exports.randomBytes(self, buffer) + if buffer == nil then + error( + __TS__New(Error, "Buffer cannot be null"), + 0 + ) + end + do + local i = 0 + while i < buffer.length do + local r = math.floor( + math.random() * 281474976710656 + ) + local rhi = math.floor(r / 16777216) + do + local j = 0 + while (j < 6) and ((i + j) < buffer.length) do + if j == 3 then + r = rhi + end + buffer[i + j] = BitAND(r, 255) + r = rshift(r, 8) + j = j + 1 + end + end + i = i + 6 + end + end +end +function ____exports.unescapeDataString(self, s) + return decodeURIComponent( + nil, + __TS__StringReplace(s, nil, "%20") + ) +end +function ____exports.escapeDataString(self, s) + return __TS__StringReplace( + __TS__StringReplace( + __TS__StringReplace( + __TS__StringReplace( + __TS__StringReplace( + encodeURIComponent(nil, s), + nil, + "%21" + ), + nil, + "%27" + ), + nil, + "%28" + ), + nil, + "%29" + ), + nil, + "%2A" + ) +end +function ____exports.escapeUriString(self, s) + return encodeURI(nil, s) +end +function ____exports.count(self, col) + if ____exports.isArrayLike(nil, col) then + return col.length + else + local count = 0 + for ____, _ in __TS__Iterator(col) do + count = count + 1 + end + return count + end +end +function ____exports.clear(self, col) + if ____exports.isArrayLike(nil, col) then + __TS__ArraySplice(col, 0) + else + col:clear() + end +end +local CURRIED_KEY = "__CURRIED__" +function ____exports.uncurry(self, arity, f) + if (f == nil) or (f.length > 1) then + return f + end + local uncurriedFn + local ____switch154 = arity + if ____switch154 == 2 then + goto ____switch154_case_0 + elseif ____switch154 == 3 then + goto ____switch154_case_1 + elseif ____switch154 == 4 then + goto ____switch154_case_2 + elseif ____switch154 == 5 then + goto ____switch154_case_3 + elseif ____switch154 == 6 then + goto ____switch154_case_4 + elseif ____switch154 == 7 then + goto ____switch154_case_5 + elseif ____switch154 == 8 then + goto ____switch154_case_6 + end + goto ____switch154_case_default + ::____switch154_case_0:: + do + uncurriedFn = function(____, a1, a2) return f(nil, a1)(nil, a2) end + goto ____switch154_end + end + ::____switch154_case_1:: + do + uncurriedFn = function(____, a1, a2, a3) return f(nil, a1)(nil, a2)(nil, a3) end + goto ____switch154_end + end + ::____switch154_case_2:: + do + uncurriedFn = function(____, a1, a2, a3, a4) return f(nil, a1)(nil, a2)(nil, a3)(nil, a4) end + goto ____switch154_end + end + ::____switch154_case_3:: + do + uncurriedFn = function(____, a1, a2, a3, a4, a5) return f(nil, a1)(nil, a2)(nil, a3)(nil, a4)(nil, a5) end + goto ____switch154_end + end + ::____switch154_case_4:: + do + uncurriedFn = function(____, a1, a2, a3, a4, a5, a6) return f(nil, a1)(nil, a2)(nil, a3)(nil, a4)(nil, a5)(nil, a6) end + goto ____switch154_end + end + ::____switch154_case_5:: + do + uncurriedFn = function(____, a1, a2, a3, a4, a5, a6, a7) return f(nil, a1)(nil, a2)(nil, a3)(nil, a4)(nil, a5)(nil, a6)(nil, a7) end + goto ____switch154_end + end + ::____switch154_case_6:: + do + uncurriedFn = function(____, a1, a2, a3, a4, a5, a6, a7, a8) return f(nil, a1)(nil, a2)(nil, a3)(nil, a4)(nil, a5)(nil, a6)(nil, a7)(nil, a8) end + goto ____switch154_end + end + ::____switch154_case_default:: + do + error( + __TS__New( + Error, + "Uncurrying to more than 8-arity is not supported: " .. tostring(arity) + ), + 0 + ) + end + ::____switch154_end:: + uncurriedFn[CURRIED_KEY] = f + return uncurriedFn +end +function ____exports.curry(self, arity, f) + if (f == nil) or (f.length == 1) then + return f + end + if f[CURRIED_KEY] ~= nil then + return f[CURRIED_KEY] + end + local ____switch165 = arity + if ____switch165 == 2 then + goto ____switch165_case_0 + elseif ____switch165 == 3 then + goto ____switch165_case_1 + elseif ____switch165 == 4 then + goto ____switch165_case_2 + elseif ____switch165 == 5 then + goto ____switch165_case_3 + elseif ____switch165 == 6 then + goto ____switch165_case_4 + elseif ____switch165 == 7 then + goto ____switch165_case_5 + elseif ____switch165 == 8 then + goto ____switch165_case_6 + end + goto ____switch165_case_default + ::____switch165_case_0:: + do + return function(____, a1) return function(____, a2) return f(nil, a1, a2) end end + end + ::____switch165_case_1:: + do + return function(____, a1) return function(____, a2) return function(____, a3) return f(nil, a1, a2, a3) end end end + end + ::____switch165_case_2:: + do + return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return f(nil, a1, a2, a3, a4) end end end end + end + ::____switch165_case_3:: + do + return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return f(nil, a1, a2, a3, a4, a5) end end end end end + end + ::____switch165_case_4:: + do + return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return function(____, a6) return f(nil, a1, a2, a3, a4, a5, a6) end end end end end end + end + ::____switch165_case_5:: + do + return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return function(____, a6) return function(____, a7) return f(nil, a1, a2, a3, a4, a5, a6, a7) end end end end end end end + end + ::____switch165_case_6:: + do + return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return function(____, a6) return function(____, a7) return function(____, a8) return f(nil, a1, a2, a3, a4, a5, a6, a7, a8) end end end end end end end end + end + ::____switch165_case_default:: + do + error( + __TS__New( + Error, + "Currying to more than 8-arity is not supported: " .. tostring(arity) + ), + 0 + ) + end + ::____switch165_end:: +end +function ____exports.checkArity(self, arity, f) + return ((f.length > arity) and (function(____, ...) + local args1 = {...} + return function(____, ...) + local args2 = {...} + return f:apply( + nil, + __TS__ArrayConcat(args1, args2) + ) + end + end)) or f +end +function ____exports.partialApply(self, arity, f, args) + if arity == 1 then + return function (a) return f(table.unpack(__TS__ArrayConcat({ table.unpack(args) }, a))) end + elseif arity == 2 then + return function (a, b) return f(table.unpack(__TS__ArrayConcat({ table.unpack(args) }, a, b))) end + end + -- if f == nil then + -- return nil + -- elseif type(f) == 'table' and f[CURRIED_KEY] ~= nil then + -- f = f[CURRIED_KEY] + -- do + -- local i = 0 + -- while i < #args do + -- f = f(nil, args[i + 1]) + -- i = i + 1 + -- end + -- end + -- return f + -- else + -- local ____switch209 = arity + -- if ____switch209 == 1 then + -- goto ____switch209_case_0 + -- elseif ____switch209 == 2 then + -- goto ____switch209_case_1 + -- elseif ____switch209 == 3 then + -- goto ____switch209_case_2 + -- elseif ____switch209 == 4 then + -- goto ____switch209_case_3 + -- elseif ____switch209 == 5 then + -- goto ____switch209_case_4 + -- elseif ____switch209 == 6 then + -- goto ____switch209_case_5 + -- elseif ____switch209 == 7 then + -- goto ____switch209_case_6 + -- elseif ____switch209 == 8 then + -- goto ____switch209_case_7 + -- end + -- goto ____switch209_case_default + -- ::____switch209_case_0:: + -- do + -- return function(____, a1) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1}) + -- ) end + -- end + -- ::____switch209_case_1:: + -- do + -- return function(____, a1) return function(____, a2) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1, a2}) + -- ) end end + -- end + -- ::____switch209_case_2:: + -- do + -- return function(____, a1) return function(____, a2) return function(____, a3) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1, a2, a3}) + -- ) end end end + -- end + -- ::____switch209_case_3:: + -- do + -- return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1, a2, a3, a4}) + -- ) end end end end + -- end + -- ::____switch209_case_4:: + -- do + -- return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1, a2, a3, a4, a5}) + -- ) end end end end end + -- end + -- ::____switch209_case_5:: + -- do + -- return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return function(____, a6) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1, a2, a3, a4, a5, a6}) + -- ) end end end end end end + -- end + -- ::____switch209_case_6:: + -- do + -- return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return function(____, a6) return function(____, a7) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1, a2, a3, a4, a5, a6, a7}) + -- ) end end end end end end end + -- end + -- ::____switch209_case_7:: + -- do + -- return function(____, a1) return function(____, a2) return function(____, a3) return function(____, a4) return function(____, a5) return function(____, a6) return function(____, a7) return function(____, a8) return f:apply( + -- nil, + -- __TS__ArrayConcat(args, {a1, a2, a3, a4, a5, a6, a7, a8}) + -- ) end end end end end end end end + -- end + -- ::____switch209_case_default:: + -- do + -- error( + -- __TS__New( + -- Error, + -- "Partially applying to more than 8-arity is not supported: " .. tostring(arity) + -- ), + -- 0 + -- ) + -- end + -- ::____switch209_end:: + -- end +end +function ____exports.mapCurriedArgs(self, fn, mappings) + local function mapArg(self, fn, arg, mappings, idx) + local mapping = mappings[idx + 1] + if mapping ~= 0 then + local expectedArity = mapping[1] + local actualArity = mapping[2] + if expectedArity > 1 then + arg = ____exports.curry(nil, expectedArity, arg) + end + if actualArity > 1 then + arg = ____exports.uncurry(nil, actualArity, arg) + end + end + local res = fn(nil, arg) + if (idx + 1) == #mappings then + return res + else + return function(____, arg) return mapArg(nil, res, arg, mappings, idx + 1) end + end + end + return function(____, arg) return mapArg(nil, fn, arg, mappings, 0) end +end +return ____exports \ No newline at end of file diff --git a/src/fable-library-py/fable_library/async_builder.py b/src/fable-library-py/fable_library/async_builder.py index 04acd03473..c385d4d490 100644 --- a/src/fable-library-py/fable_library/async_builder.py +++ b/src/fable-library-py/fable_library/async_builder.py @@ -35,8 +35,7 @@ def __init__(self, msg: str | None = None) -> None: class _Listener(Protocol): - def __call__(self, __state: Any | None = None) -> None: - ... + def __call__(self, __state: Any | None = None) -> None: ... class CancellationToken: @@ -93,36 +92,29 @@ class IAsyncContext(Generic[_T]): __slots__ = () @abstractmethod - def on_success(self, value: _T) -> None: - ... + def on_success(self, value: _T) -> None: ... @abstractmethod - def on_error(self, error: Exception) -> None: - ... + def on_error(self, error: Exception) -> None: ... @abstractmethod - def on_cancel(self, error: OperationCanceledError) -> None: - ... + def on_cancel(self, error: OperationCanceledError) -> None: ... @property @abstractmethod - def trampoline(self) -> Trampoline: - ... + def trampoline(self) -> Trampoline: ... @trampoline.setter @abstractmethod - def trampoline(self, val: Trampoline): - ... + def trampoline(self, val: Trampoline): ... @property @abstractmethod - def cancel_token(self) -> CancellationToken: - ... + def cancel_token(self) -> CancellationToken: ... @cancel_token.setter @abstractmethod - def cancel_token(self, val: CancellationToken): - ... + def cancel_token(self, val: CancellationToken): ... @staticmethod def create( @@ -308,12 +300,10 @@ def delay() -> Async[_U]: return self.While(lambda: not done, self.Delay(delay)) @overload - def Return(self) -> Async[None]: - ... + def Return(self) -> Async[None]: ... @overload - def Return(self, value: _T) -> Async[_T]: - ... + def Return(self, value: _T) -> Async[_T]: ... def Return(self, value: Any = None) -> Async[Any]: return protected_return(value) @@ -367,12 +357,10 @@ def compensation() -> None: return self.TryFinally(binder(resource), compensation) @overload - def While(self, guard: Callable[[], bool], computation: Async[Literal[None]]) -> Async[None]: - ... + def While(self, guard: Callable[[], bool], computation: Async[Literal[None]]) -> Async[None]: ... @overload - def While(self, guard: Callable[[], bool], computation: Async[_T]) -> Async[_T]: - ... + def While(self, guard: Callable[[], bool], computation: Async[_T]) -> Async[_T]: ... def While(self, guard: Callable[[], bool], computation: Async[Any]) -> Async[Any]: if guard(): diff --git a/src/fable-library-py/fable_library/event.py b/src/fable-library-py/fable_library/event.py index 13c8e280a3..947ceb84e8 100644 --- a/src/fable-library-py/fable_library/event.py +++ b/src/fable-library-py/fable_library/event.py @@ -33,16 +33,13 @@ class IDelegateEvent(Generic[_T_co], Protocol): @abstractmethod - def AddHandler(self, d: DotNetDelegate[_T]) -> None: - ... + def AddHandler(self, d: DotNetDelegate[_T]) -> None: ... @abstractmethod - def RemoveHandler(self, d: DotNetDelegate[_T]) -> None: - ... + def RemoveHandler(self, d: DotNetDelegate[_T]) -> None: ... -class IEvent_2(IObservable[_Args], IDelegateEvent[_Delegate], Protocol): - ... +class IEvent_2(IObservable[_Args], IDelegateEvent[_Delegate], Protocol): ... IEvent = IEvent_2[_T, _T] @@ -60,12 +57,10 @@ def Publish(self) -> IEvent[_T]: return self @overload - def Trigger(self, value: _T) -> None: - ... + def Trigger(self, value: _T) -> None: ... @overload - def Trigger(self, sender: Any, value: _T) -> None: - ... + def Trigger(self, sender: Any, value: _T) -> None: ... def Trigger(self, sender_or_value: Any, value_or_undefined: _T | None = None) -> None: if value_or_undefined is None: diff --git a/src/fable-library-py/fable_library/mailbox_processor.py b/src/fable-library-py/fable_library/mailbox_processor.py index d4ed25a78a..ae7bca26e7 100644 --- a/src/fable-library-py/fable_library/mailbox_processor.py +++ b/src/fable-library-py/fable_library/mailbox_processor.py @@ -71,9 +71,9 @@ def post_and_async_reply(self, build_message: Callable[[AsyncReplyChannel[_Reply """ result: _Reply | None = None - continuation: Continuations[ - Any - ] | None = None # This is the continuation for the `done` callback of the awaiting poster. + continuation: Continuations[Any] | None = ( + None # This is the continuation for the `done` callback of the awaiting poster. + ) def check_completion() -> None: if result is not None and continuation is not None: diff --git a/src/fable-library-py/fable_library/map_util.py b/src/fable-library-py/fable_library/map_util.py index 275a97d1d7..86d0e42e4c 100644 --- a/src/fable-library-py/fable_library/map_util.py +++ b/src/fable-library-py/fable_library/map_util.py @@ -54,8 +54,7 @@ def change_case(string: str, case_rule: CaseRules) -> str: if TYPE_CHECKING: - class FSharpMap(dict[_K, _V]): - ... + class FSharpMap(dict[_K, _V]): ... else: from .map import FSharpMap diff --git a/src/fable-library-py/fable_library/observable.py b/src/fable-library-py/fable_library/observable.py index d702cca320..94202b432d 100644 --- a/src/fable-library-py/fable_library/observable.py +++ b/src/fable-library-py/fable_library/observable.py @@ -23,16 +23,13 @@ class IObserver(Protocol, Generic[_T_contra]): __slots__ = () @abstractmethod - def OnNext(self, __value: _T_contra) -> None: - ... + def OnNext(self, __value: _T_contra) -> None: ... @abstractmethod - def OnError(self, __error: Exception) -> None: - ... + def OnError(self, __error: Exception) -> None: ... @abstractmethod - def OnCompleted(self) -> None: - ... + def OnCompleted(self) -> None: ... def _noop(__arg: Any = None) -> None: @@ -66,8 +63,7 @@ class IObservable(Protocol, Generic[_T_co]): __slots__ = () @abstractmethod - def Subscribe(self, __obs: IObserver[_T_co]) -> IDisposable: - ... + def Subscribe(self, __obs: IObserver[_T_co]) -> IDisposable: ... class Observable(IObservable[_T]): diff --git a/src/fable-library-py/fable_library/string_.py b/src/fable-library-py/fable_library/string_.py index 4b47291565..e14f6a5415 100644 --- a/src/fable-library-py/fable_library/string_.py +++ b/src/fable-library-py/fable_library/string_.py @@ -478,8 +478,7 @@ def compare(string1: str, string2: str, /) -> int: @overload -def compare(string1: str, string2: str, ignore_case: bool, culture: StringComparison, /) -> int: - ... +def compare(string1: str, string2: str, ignore_case: bool, culture: StringComparison, /) -> int: ... def compare(*args: Any) -> int: diff --git a/src/fable-library-py/fable_library/task.py b/src/fable-library-py/fable_library/task.py index 089847e1fb..2da909dec8 100644 --- a/src/fable-library-py/fable_library/task.py +++ b/src/fable-library-py/fable_library/task.py @@ -4,6 +4,7 @@ [tasks](https://docs.microsoft.com/en-us/dotnet/standard/async-in-depth) using Python async / await. """ + from __future__ import annotations import asyncio diff --git a/src/fable-library-py/fable_library/task_builder.py b/src/fable-library-py/fable_library/task_builder.py index c0ba1fdd64..e86b6ac58a 100644 --- a/src/fable-library-py/fable_library/task_builder.py +++ b/src/fable-library-py/fable_library/task_builder.py @@ -19,8 +19,7 @@ class Delayed(Protocol[_T_co]): - def __call__(self, __unit: None | None = None) -> Awaitable[_T_co]: - ... + def __call__(self, __unit: None | None = None) -> Awaitable[_T_co]: ... class TaskBuilder: @@ -61,12 +60,10 @@ def delay(): return self.While(lambda: not done, self.Delay(delay)) @overload - def Return(self) -> Awaitable[None]: - ... + def Return(self) -> Awaitable[None]: ... @overload - def Return(self, value: _T) -> Awaitable[_T]: - ... + def Return(self, value: _T) -> Awaitable[_T]: ... def Return(self, value: Any = None) -> Awaitable[Any]: return from_result(value) @@ -98,12 +95,10 @@ def Using(self, resource: _TD, binder: Callable[[_TD], Awaitable[_U]]) -> Awaita return self.TryFinally(self.Delay(lambda: binder(resource)), lambda: resource.Dispose()) @overload - def While(self, guard: Callable[[], bool], computation: Delayed[None]) -> Awaitable[None]: - ... + def While(self, guard: Callable[[], bool], computation: Delayed[None]) -> Awaitable[None]: ... @overload - def While(self, guard: Callable[[], bool], computation: Delayed[_T]) -> Awaitable[_T]: - ... + def While(self, guard: Callable[[], bool], computation: Delayed[_T]) -> Awaitable[_T]: ... def While(self, guard: Callable[[], bool], computation: Delayed[Any]) -> Awaitable[Any]: if guard(): diff --git a/src/fcs-fable/src/Compiler/Driver/parallel-optimization.md b/src/fcs-fable/src/Compiler/Driver/parallel-optimization.md index bf7d79104e..3045b3821a 100644 --- a/src/fcs-fable/src/Compiler/Driver/parallel-optimization.md +++ b/src/fcs-fable/src/Compiler/Driver/parallel-optimization.md @@ -12,4 +12,4 @@ This allows us to parallelize the whole process as shown in the diagram below: ![Optimisation chart](parallel-optimization.drawio.svg) This parallelization is implemented in `OptimizeInputs.fs`. -It can enabled with an experimental flag `--test:ParallelOptimization`. \ No newline at end of file +It can enabled with an experimental flag `--test:ParallelOptimization`. diff --git a/tests/Lua/Fable.Tests.Lua.fsproj b/tests/Lua/Fable.Tests.Lua.fsproj new file mode 100644 index 0000000000..8997a7ed14 --- /dev/null +++ b/tests/Lua/Fable.Tests.Lua.fsproj @@ -0,0 +1,30 @@ + + + net8.0 + false + false + true + preview + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + + + diff --git a/tests/Lua/Main.fs b/tests/Lua/Main.fs new file mode 100644 index 0000000000..ae5f6f6959 --- /dev/null +++ b/tests/Lua/Main.fs @@ -0,0 +1,6 @@ +#if FABLE_COMPILER +module Program +() +#else +module Program = let [] main _ = 0 +#endif \ No newline at end of file diff --git a/tests/Lua/TestArithmetic.fs b/tests/Lua/TestArithmetic.fs new file mode 100644 index 0000000000..2189e3e0f8 --- /dev/null +++ b/tests/Lua/TestArithmetic.fs @@ -0,0 +1,151 @@ +module Fable.Tests.Arithmetic + +open System +open Util.Testing + +[] +let testTwoPlusTwo () = + 2 + 2 |> equal 4 + +[] +let testMinus () = + 7 - 2 |> equal 5 + +[] +let testMultiply () = + 3 * 2 |> equal 6 + +[] +let testDivide () = + 10 / 2 |> equal 5 + +[] +let testFloatAdd () = + 3.141 + 2.85 |> equal 5.991 + +let private addFn a b = a + b + +[] +let testAddThroughTrivialFn () = + addFn 2 2 |> equal 4 + +[] +let testLocalsWithFcalls () = + let a = addFn 1 0 + let b = 2 + let c = addFn 3 0 + a + b + c |> equal 6 + +[] +let testAddStrings () = + let a () = "hello" + let b () = "world" + a() + " " + b() |> equal "hello world" + +[] +let testLocalFunction () = + let locAdd1 a = + addFn 1 a + locAdd1 2 |> equal 3 + +[] +let testInlineLambda () = + 1 |> fun x -> x + 1 |> fun x -> x - 3 |> equal (-1) + +let add42 = addFn 42 +[] +let testPartialApply () = + add42 3 |> equal 45 +// let [] aLiteral = 5 +// let notALiteral = 5 +// let [] literalNegativeValue = -345 + +// let checkTo3dp (expected: float) actual = +// floor (actual * 1000.) |> equal expected + +// // let positiveInfinity = System.Double.PositiveInfinity +// // let negativeInfinity = System.Double.NegativeInfinity +// //let isNaN = fun x -> System.Double.IsNaN(x) + +// let equals (x:'a) (y:'a) = x = y +// let compareTo (x:'a) (y:'a) = compare x y + +// let decimalOne = 1M +// let decimalTwo = 2M + +// [] +// let ``test Infix add can be generated`` () = +// 4 + 2 |> equal 6 + +// [] +// let ``test Int32 literal addition is optimized`` () = +// aLiteral + 7 |> equal 12 +// notALiteral + 7 |> equal 12 + +// // FIXME +// // [] +// // let ``test Unary negation with negative literal values works`` () = +// // -literalNegativeValue |> equal 345 + +// [] +// let ``test Unary negation with integer MinValue works`` () = +// -(-128y) |> equal System.SByte.MinValue +// -(-32768s) |> equal System.Int16.MinValue +// -(-2147483648) |> equal System.Int32.MinValue +// // FIXME -(-9223372036854775808L) |> equal System.Int64.MinValue + +// [] +// let ``test Infix subtract can be generated`` () = +// 4 - 2 |> equal 2 + +// [] +// let ``test Infix multiply can be generated`` () = +// 4 * 2 |> equal 8 + +// [] +// let ``test Infix divide can be generated`` () = +// 4 / 2 |> equal 2 + +// [] +// let ``test Integer division doesn't produce floats`` () = +// 5. / 2. |> equal 2.5 +// 5 / 2 |> equal 2 +// 5 / 3 |> equal 1 + +// [] +// let ``test Infix modulo can be generated`` () = +// 4 % 3 |> equal 1 + +// [] +// let ``test Evaluation order is preserved by generated code`` () = +// (4 - 2) * 2 + 1 |> equal 5 + +// [] +// let ``test Decimal.ToString works`` () = +// string 001.23456M |> equal "1.23456" +// string 1.23456M |> equal "1.23456" +// string 0.12345M |> equal "0.12345" +// string 0.01234M |> equal "0.01234" +// string 0.00123M |> equal "0.00123" +// string 0.00012M |> equal "0.00012" +// string 0.00001M |> equal "0.00001" +// // FIXME: +// // string 0.00000M |> equal "0.00000" +// // string 0.12300M |> equal "0.12300" +// // string 0.0M |> equal "0.0" +// string 0M |> equal "0" +// string 1M |> equal "1" +// string -1M |> equal "-1" +// string 00000000000000000000000000000.M |> equal "0" +// // string 0.0000000000000000000000000000M |> equal "0.0000000000000000000000000000" +// string 79228162514264337593543950335M |> equal "79228162514264337593543950335" +// string -79228162514264337593543950335M |> equal "-79228162514264337593543950335" + +// [] +// let ``test Decimal precision is kept`` () = +// let items = [ 290.8M +// 290.8M +// 337.12M +// 6.08M +// -924.8M ] +// List.sum items |> equal 0M diff --git a/tests/Lua/TestArray.fs b/tests/Lua/TestArray.fs new file mode 100644 index 0000000000..7f9ead6555 --- /dev/null +++ b/tests/Lua/TestArray.fs @@ -0,0 +1,31 @@ +module Fable.Tests.Array + +open System +open Util.Testing + +[] +let testCreateArray () = + let arr = [||] + arr |> equal [||] + +[] +let testCreateArray2 () = + let arr = [|1;2;3|] + arr |> equal [|1;2;3|] + +[] +let testDeref1 () = + let arr = [|1;2;3|] + arr.[0] |> equal 1 + arr.[1] |> equal 2 + arr.[2] |> equal 3 + +// [] +// let testDestructure () = +// let x = [|1;2|] +// let res = +// match x with +// | [|a; b|] -> +// Some (a, b) +// | _ -> None +// res |> equal (Some(1, 2)) \ No newline at end of file diff --git a/tests/Lua/TestControlFlow.fs b/tests/Lua/TestControlFlow.fs new file mode 100644 index 0000000000..c3154f0ef4 --- /dev/null +++ b/tests/Lua/TestControlFlow.fs @@ -0,0 +1,86 @@ +module Fable.Tests.TestControlFlow + +open System +open Util.Testing + +[] +let testIfElse () = + let r = + if true then 4 else 6 + r |> equal 4 +[] +let testIfElse2 () = + let r = + if false then 4 else 6 + r |> equal 6 + +let bfn x a b= + if x then a else b + +[] +let testIfElseFn1 () = + bfn true 1 2 |> equal 1 +[] +let testIfElseFn2 () = + bfn false 3 4 |> equal 4 + +[] +let testIfElseIf () = + let a x = + if x = 1 then + 1 + else if x = 2 then + 2 + else 3 + a 1 |> equal 1 + a 2 |> equal 2 + a 3 |> equal 3 + +[] +let testForEach1 () = + let mutable a = 42 + for i in 0..5 do + a <- i + a + a |> equal 57 + +[] +let testWhile1 () = + let mutable a = 1 + while a < 3 do + a <- a + 1 + a |> equal 3 + +[] +let testExHappy() = + let a = + try + 3 + with ex -> + 4 + a |> equal 3 + +[] +let testExThrow() = + let a = + try + failwith "boom" + 3 + with ex -> + 4 + a |> equal 4 + + +[] +let testSimpleFnParam() = + let add a b = a + b + let fn addFn a b = addFn a b + fn add 3 2 |> equal 5 + +[] +let testPartialApply() = + let add a b = a + b + let fn addFn a b = addFn a b + fn add 3 2 |> equal 5 + let fnAdd4 = fn add 4 + fnAdd4 5 |> equal 9 + fnAdd4 2 |> equal 6 \ No newline at end of file diff --git a/tests/Lua/TestRecords.fs b/tests/Lua/TestRecords.fs new file mode 100644 index 0000000000..71138fefb8 --- /dev/null +++ b/tests/Lua/TestRecords.fs @@ -0,0 +1,49 @@ +module Fable.Tests.Record + +open Util.Testing + +type Simple = { + one: string + two: int +} + +type Parent = { + a: Simple + b: Simple +} + +// type Recursive = { +// x: Recursive +// } + +[] +let testMakeRecord () = + let r = { one="string_one"; two=2} + r.one |> equal "string_one" + r.two |> equal 2 + +[] +let testMakeNestedRecord () = + let r = { + a = { one = "a"; two = 2} + b = { one = "b"; two = 4} + } + r.a.one |> equal "a" + r.b.one |> equal "b" + r.a.two |> equal 2 + r.b.two |> equal 4 + +[] +let testStructuralCompareRecords () = + let a = { one="string_one"; two=2} + let b = { one="string_one"; two=2} + let c = { one="string_two"; two=4} + a = a |> equal true + a = b |> equal true + a = c |> equal false + +[] +let testMakeAnonRecord () = + let r = {| x = 3.142; y = true |} + r.x |> equal 3.142 + r.y |> equal true diff --git a/tests/Lua/TestUnionType.fs b/tests/Lua/TestUnionType.fs new file mode 100644 index 0000000000..ddcd1406fb --- /dev/null +++ b/tests/Lua/TestUnionType.fs @@ -0,0 +1,45 @@ +module Fable.Tests.UnionTypes + +open Util.Testing + +type Shape = Square | Circle + +[] +let testMakeUnion () = + let r = Square + r |> equal Square + +[] +let testMakeUnion2 () = + let r = Circle + r |> equal Circle + +type Stuff = + | A of string * int + | B + | C of bool + +[] +let testMakeUnionContent () = + let r = A ("abc", 42) + r |> equal (A ("abc", 42)) + +[] +let testMakeUnionContent2 () = + let r = C true + r |> equal (C true) + +[] +let testMakeUnionContent3 () = + let r = B + r |> equal B + +[] +let testMatch1 () = + let thing = A("abc", 123) + let res = + match thing with + | A(s, i) -> Some(s, i) + | B -> None + | C(_) -> None + res |> equal (Some("abc", 123)) \ No newline at end of file diff --git a/tests/Lua/Util.fs b/tests/Lua/Util.fs new file mode 100644 index 0000000000..3e73f61d40 --- /dev/null +++ b/tests/Lua/Util.fs @@ -0,0 +1,32 @@ +module Fable.Tests.Util + +open System + +module Testing = +#if FABLE_COMPILER + open Fable.Core + open Fable.Core.PyInterop + + type Assert = + [] + static member AreEqual(actual: 'T, expected: 'T, ?msg: string): unit = nativeOnly + [] + static member NotEqual(actual: 'T, expected: 'T, ?msg: string): unit = nativeOnly + + let equal expected actual: unit = Assert.AreEqual(actual, expected) + let notEqual expected actual: unit = Assert.NotEqual(actual, expected) + + type Fact() = inherit System.Attribute() +#else + open Xunit + type FactAttribute = Xunit.FactAttribute + + let equal<'T> (expected: 'T) (actual: 'T): unit = Assert.Equal(expected, actual) + let notEqual<'T> (expected: 'T) (actual: 'T) : unit = Assert.NotEqual(expected, actual) +#endif + + // let rec sumFirstSeq (zs: seq) (n: int): float = + // match n with + // | 0 -> 0. + // | 1 -> Seq.head zs + // | _ -> (Seq.head zs) + sumFirstSeq (Seq.skip 1 zs) (n-1) diff --git a/tests/Lua/luaunit.lua b/tests/Lua/luaunit.lua new file mode 100644 index 0000000000..937c0f90a6 --- /dev/null +++ b/tests/Lua/luaunit.lua @@ -0,0 +1,3372 @@ +--[[ + luaunit.lua +Description: A unit testing framework +Homepage: https://github.com/bluebird75/luaunit +Development by Philippe Fremy +Based on initial work of Ryu, Gwang (http://www.gpgstudy.com/gpgiki/LuaUnit) +License: BSD License, see LICENSE.txt +]]-- + +require("math") +local M={} + +-- private exported functions (for testing) +M.private = {} + +M.VERSION='3.4' +M._VERSION=M.VERSION -- For LuaUnit v2 compatibility + +-- a version which distinguish between regular Lua and LuaJit +M._LUAVERSION = (jit and jit.version) or _VERSION + +--[[ Some people like assertEquals( actual, expected ) and some people prefer +assertEquals( expected, actual ). +]]-- +M.ORDER_ACTUAL_EXPECTED = true +M.PRINT_TABLE_REF_IN_ERROR_MSG = false +M.LINE_LENGTH = 80 +M.TABLE_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items +M.LIST_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items + +-- this setting allow to remove entries from the stack-trace, for +-- example to hide a call to a framework which would be calling luaunit +M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE = 0 + +--[[ EPS is meant to help with Lua's floating point math in simple corner +cases like almostEquals(1.1-0.1, 1), which may not work as-is (e.g. on numbers +with rational binary representation) if the user doesn't provide some explicit +error margin. +The default margin used by almostEquals() in such cases is EPS; and since +Lua may be compiled with different numeric precisions (single vs. double), we +try to select a useful default for it dynamically. Note: If the initial value +is not acceptable, it can be changed by the user to better suit specific needs. +See also: https://en.wikipedia.org/wiki/Machine_epsilon +]] +M.EPS = 2^-52 -- = machine epsilon for "double", ~2.22E-16 +if math.abs(1.1 - 1 - 0.1) > M.EPS then + -- rounding error is above EPS, assume single precision + M.EPS = 2^-23 -- = machine epsilon for "float", ~1.19E-07 +end + +-- set this to false to debug luaunit +local STRIP_LUAUNIT_FROM_STACKTRACE = true + +M.VERBOSITY_DEFAULT = 10 +M.VERBOSITY_LOW = 1 +M.VERBOSITY_QUIET = 0 +M.VERBOSITY_VERBOSE = 20 +M.DEFAULT_DEEP_ANALYSIS = nil +M.FORCE_DEEP_ANALYSIS = true +M.DISABLE_DEEP_ANALYSIS = false + +-- set EXPORT_ASSERT_TO_GLOBALS to have all asserts visible as global values +-- EXPORT_ASSERT_TO_GLOBALS = true + +-- we need to keep a copy of the script args before it is overriden +local cmdline_argv = rawget(_G, "arg") + +M.FAILURE_PREFIX = 'LuaUnit test FAILURE: ' -- prefix string for failed tests +M.SUCCESS_PREFIX = 'LuaUnit test SUCCESS: ' -- prefix string for successful tests finished early +M.SKIP_PREFIX = 'LuaUnit test SKIP: ' -- prefix string for skipped tests + + + +M.USAGE=[[Usage: lua [options] [testname1 [testname2] ... ] +Options: + -h, --help: Print this help + --version: Print version information + -v, --verbose: Increase verbosity + -q, --quiet: Set verbosity to minimum + -e, --error: Stop on first error + -f, --failure: Stop on first failure or error + -s, --shuffle: Shuffle tests before running them + -o, --output OUTPUT: Set output type to OUTPUT + Possible values: text, tap, junit, nil + -n, --name NAME: For junit only, mandatory name of xml file + -r, --repeat NUM: Execute all tests NUM times, e.g. to trig the JIT + -p, --pattern PATTERN: Execute all test names matching the Lua PATTERN + May be repeated to include several patterns + Make sure you escape magic chars like +? with % + -x, --exclude PATTERN: Exclude all test names matching the Lua PATTERN + May be repeated to exclude several patterns + Make sure you escape magic chars like +? with % + testname1, testname2, ... : tests to run in the form of testFunction, + TestClass or TestClass.testMethod +You may also control LuaUnit options with the following environment variables: +* LUAUNIT_OUTPUT: same as --output +* LUAUNIT_JUNIT_FNAME: same as --name ]] + +---------------------------------------------------------------- +-- +-- general utility functions +-- +---------------------------------------------------------------- + +--[[ Note on catching exit +I have seen the case where running a big suite of test cases and one of them would +perform a os.exit(0), making the outside world think that the full test suite was executed +successfully. +This is an attempt to mitigate this problem: we override os.exit() to now let a test +exit the framework while we are running. When we are not running, it behaves normally. +]] + +M.oldOsExit = os.exit +os.exit = function(...) + if M.LuaUnit and #M.LuaUnit.instances ~= 0 then + local msg = [[You are trying to exit but there is still a running instance of LuaUnit. +LuaUnit expects to run until the end before exiting with a complete status of successful/failed tests. +To force exit LuaUnit while running, please call before os.exit (assuming lu is the luaunit module loaded): + lu.unregisterCurrentSuite() +]] + M.private.error_fmt(2, msg) + end + M.oldOsExit(...) +end + +local function pcall_or_abort(func, ...) + -- unpack is a global function for Lua 5.1, otherwise use table.unpack + local unpack = rawget(_G, "unpack") or table.unpack + local result = {pcall(func, ...)} + if not result[1] then + -- an error occurred + print(result[2]) -- error message + print() + print(M.USAGE) + os.exit(-1) + end + return unpack(result, 2) +end + +local crossTypeOrdering = { + number = 1, boolean = 2, string = 3, table = 4, other = 5 +} +local crossTypeComparison = { + number = function(a, b) return a < b end, + string = function(a, b) return a < b end, + other = function(a, b) return tostring(a) < tostring(b) end, +} + +local function crossTypeSort(a, b) + local type_a, type_b = type(a), type(b) + if type_a == type_b then + local func = crossTypeComparison[type_a] or crossTypeComparison.other + return func(a, b) + end + type_a = crossTypeOrdering[type_a] or crossTypeOrdering.other + type_b = crossTypeOrdering[type_b] or crossTypeOrdering.other + return type_a < type_b +end + +local function __genSortedIndex( t ) + -- Returns a sequence consisting of t's keys, sorted. + local sortedIndex = {} + + for key,_ in pairs(t) do + table.insert(sortedIndex, key) + end + + table.sort(sortedIndex, crossTypeSort) + return sortedIndex +end +M.private.__genSortedIndex = __genSortedIndex + +local function sortedNext(state, control) + -- Equivalent of the next() function of table iteration, but returns the + -- keys in sorted order (see __genSortedIndex and crossTypeSort). + -- The state is a temporary variable during iteration and contains the + -- sorted key table (state.sortedIdx). It also stores the last index (into + -- the keys) used by the iteration, to find the next one quickly. + local key + + --print("sortedNext: control = "..tostring(control) ) + if control == nil then + -- start of iteration + state.count = #state.sortedIdx + state.lastIdx = 1 + key = state.sortedIdx[1] + return key, state.t[key] + end + + -- normally, we expect the control variable to match the last key used + if control ~= state.sortedIdx[state.lastIdx] then + -- strange, we have to find the next value by ourselves + -- the key table is sorted in crossTypeSort() order! -> use bisection + local lower, upper = 1, state.count + repeat + state.lastIdx = math.modf((lower + upper) / 2) + key = state.sortedIdx[state.lastIdx] + if key == control then + break -- key found (and thus prev index) + end + if crossTypeSort(key, control) then + -- key < control, continue search "right" (towards upper bound) + lower = state.lastIdx + 1 + else + -- key > control, continue search "left" (towards lower bound) + upper = state.lastIdx - 1 + end + until lower > upper + if lower > upper then -- only true if the key wasn't found, ... + state.lastIdx = state.count -- ... so ensure no match in code below + end + end + + -- proceed by retrieving the next value (or nil) from the sorted keys + state.lastIdx = state.lastIdx + 1 + key = state.sortedIdx[state.lastIdx] + if key then + return key, state.t[key] + end + + -- getting here means returning `nil`, which will end the iteration +end + +local function sortedPairs(tbl) + -- Equivalent of the pairs() function on tables. Allows to iterate in + -- sorted order. As required by "generic for" loops, this will return the + -- iterator (function), an "invariant state", and the initial control value. + -- (see http://www.lua.org/pil/7.2.html) + return sortedNext, {t = tbl, sortedIdx = __genSortedIndex(tbl)}, nil +end +M.private.sortedPairs = sortedPairs + +-- seed the random with a strongly varying seed +math.randomseed(math.floor(os.clock()*1E11)) + +local function randomizeTable( t ) + -- randomize the item orders of the table t + for i = #t, 2, -1 do + local j = math.random(i) + if i ~= j then + t[i], t[j] = t[j], t[i] + end + end +end +M.private.randomizeTable = randomizeTable + +local function strsplit(delimiter, text) +-- Split text into a list consisting of the strings in text, separated +-- by strings matching delimiter (which may _NOT_ be a pattern). +-- Example: strsplit(", ", "Anna, Bob, Charlie, Dolores") + if delimiter == "" or delimiter == nil then -- this would result in endless loops + error("delimiter is nil or empty string!") + end + if text == nil then + return nil + end + + local list, pos, first, last = {}, 1 + while true do + first, last = text:find(delimiter, pos, true) + if first then -- found? + table.insert(list, text:sub(pos, first - 1)) + pos = last + 1 + else + table.insert(list, text:sub(pos)) + break + end + end + return list +end +M.private.strsplit = strsplit + +local function hasNewLine( s ) + -- return true if s has a newline + return (string.find(s, '\n', 1, true) ~= nil) +end +M.private.hasNewLine = hasNewLine + +local function prefixString( prefix, s ) + -- Prefix all the lines of s with prefix + return prefix .. string.gsub(s, '\n', '\n' .. prefix) +end +M.private.prefixString = prefixString + +local function strMatch(s, pattern, start, final ) + -- return true if s matches completely the pattern from index start to index end + -- return false in every other cases + -- if start is nil, matches from the beginning of the string + -- if final is nil, matches to the end of the string + start = start or 1 + final = final or string.len(s) + + local foundStart, foundEnd = string.find(s, pattern, start, false) + return foundStart == start and foundEnd == final +end +M.private.strMatch = strMatch + +local function patternFilter(patterns, expr) + -- Run `expr` through the inclusion and exclusion rules defined in patterns + -- and return true if expr shall be included, false for excluded. + -- Inclusion pattern are defined as normal patterns, exclusions + -- patterns start with `!` and are followed by a normal pattern + + -- result: nil = UNKNOWN (not matched yet), true = ACCEPT, false = REJECT + -- default: true if no explicit "include" is found, set to false otherwise + local default, result = true, nil + + if patterns ~= nil then + for _, pattern in ipairs(patterns) do + local exclude = pattern:sub(1,1) == '!' + if exclude then + pattern = pattern:sub(2) + else + -- at least one include pattern specified, a match is required + default = false + end + -- print('pattern: ',pattern) + -- print('exclude: ',exclude) + -- print('default: ',default) + + if string.find(expr, pattern) then + -- set result to false when excluding, true otherwise + result = not exclude + end + end + end + + if result ~= nil then + return result + end + return default +end +M.private.patternFilter = patternFilter + +local function xmlEscape( s ) + -- Return s escaped for XML attributes + -- escapes table: + -- " " + -- ' ' + -- < < + -- > > + -- & & + + return string.gsub( s, '.', { + ['&'] = "&", + ['"'] = """, + ["'"] = "'", + ['<'] = "<", + ['>'] = ">", + } ) +end +M.private.xmlEscape = xmlEscape + +local function xmlCDataEscape( s ) + -- Return s escaped for CData section, escapes: "]]>" + return string.gsub( s, ']]>', ']]>' ) +end +M.private.xmlCDataEscape = xmlCDataEscape + + +local function lstrip( s ) + --[[Return s with all leading white spaces and tabs removed]] + local idx = 0 + while idx < s:len() do + idx = idx + 1 + local c = s:sub(idx,idx) + if c ~= ' ' and c ~= '\t' then + break + end + end + return s:sub(idx) +end +M.private.lstrip = lstrip + +local function extractFileLineInfo( s ) + --[[ From a string in the form "(leading spaces) dir1/dir2\dir3\file.lua:linenb: msg" + Return the "file.lua:linenb" information + ]] + local s2 = lstrip(s) + local firstColon = s2:find(':', 1, true) + if firstColon == nil then + -- string is not in the format file:line: + return s + end + local secondColon = s2:find(':', firstColon+1, true) + if secondColon == nil then + -- string is not in the format file:line: + return s + end + + return s2:sub(1, secondColon-1) +end +M.private.extractFileLineInfo = extractFileLineInfo + + +local function stripLuaunitTrace2( stackTrace, errMsg ) + --[[ + -- Example of a traceback: + < + [C]: in function 'xpcall' + ./luaunit.lua:1449: in function 'protectedCall' + ./luaunit.lua:1508: in function 'execOneFunction' + ./luaunit.lua:1596: in function 'runSuiteByInstances' + ./luaunit.lua:1660: in function 'runSuiteByNames' + ./luaunit.lua:1736: in function 'runSuite' + example_with_luaunit.lua:140: in main chunk + [C]: in ?>> + error message: <> + Other example: + < + [C]: in function 'xpcall' + ./luaunit.lua:1517: in function 'protectedCall' + ./luaunit.lua:1578: in function 'execOneFunction' + ./luaunit.lua:1677: in function 'runSuiteByInstances' + ./luaunit.lua:1730: in function 'runSuiteByNames' + ./luaunit.lua:1806: in function 'runSuite' + example_with_luaunit.lua:140: in main chunk + [C]: in ?>> + error message: <> + < + [C]: in function 'xpcall' + luaunit2/luaunit.lua:1532: in function 'protectedCall' + luaunit2/luaunit.lua:1591: in function 'execOneFunction' + luaunit2/luaunit.lua:1679: in function 'runSuiteByInstances' + luaunit2/luaunit.lua:1743: in function 'runSuiteByNames' + luaunit2/luaunit.lua:1819: in function 'runSuite' + luaunit2/example_with_luaunit.lua:140: in main chunk + [C]: in ?>> + error message: <> + -- first line is "stack traceback": KEEP + -- next line may be luaunit line: REMOVE + -- next lines are call in the program under testOk: REMOVE + -- next lines are calls from luaunit to call the program under test: KEEP + -- Strategy: + -- keep first line + -- remove lines that are part of luaunit + -- kepp lines until we hit a luaunit line + The strategy for stripping is: + * keep first line "stack traceback:" + * part1: + * analyse all lines of the stack from bottom to top of the stack (first line to last line) + * extract the "file:line:" part of the line + * compare it with the "file:line" part of the error message + * if it does not match strip the line + * if it matches, keep the line and move to part 2 + * part2: + * anything NOT starting with luaunit.lua is the interesting part of the stack trace + * anything starting again with luaunit.lua is part of the test launcher and should be stripped out + ]] + + local function isLuaunitInternalLine( s ) + -- return true if line of stack trace comes from inside luaunit + return s:find('[/\\]luaunit%.lua:%d+: ') ~= nil + end + + -- print( '<<'..stackTrace..'>>' ) + + local t = strsplit( '\n', stackTrace ) + -- print( prettystr(t) ) + + local idx = 2 + + local errMsgFileLine = extractFileLineInfo(errMsg) + -- print('emfi="'..errMsgFileLine..'"') + + -- remove lines that are still part of luaunit + while t[idx] and extractFileLineInfo(t[idx]) ~= errMsgFileLine do + -- print('Removing : '..t[idx] ) + table.remove(t, idx) + end + + -- keep lines until we hit luaunit again + while t[idx] and (not isLuaunitInternalLine(t[idx])) do + -- print('Keeping : '..t[idx] ) + idx = idx + 1 + end + + -- remove remaining luaunit lines + while t[idx] do + -- print('Removing2 : '..t[idx] ) + table.remove(t, idx) + end + + -- print( prettystr(t) ) + return table.concat( t, '\n') + +end +M.private.stripLuaunitTrace2 = stripLuaunitTrace2 + + +local function prettystr_sub(v, indentLevel, printTableRefs, cycleDetectTable ) + local type_v = type(v) + if "string" == type_v then + -- use clever delimiters according to content: + -- enclose with single quotes if string contains ", but no ' + if v:find('"', 1, true) and not v:find("'", 1, true) then + return "'" .. v .. "'" + end + -- use double quotes otherwise, escape embedded " + return '"' .. v:gsub('"', '\\"') .. '"' + + elseif "table" == type_v then + --if v.__class__ then + -- return string.gsub( tostring(v), 'table', v.__class__ ) + --end + return M.private._table_tostring(v, indentLevel, printTableRefs, cycleDetectTable) + + elseif "number" == type_v then + -- eliminate differences in formatting between various Lua versions + if v ~= v then + return "#NaN" -- "not a number" + end + if v == math.huge then + return "#Inf" -- "infinite" + end + if v == -math.huge then + return "-#Inf" + end + if _VERSION == "Lua 5.3" then + local i = math.tointeger(v) + if i then + return tostring(i) + end + end + end + + return tostring(v) +end + +local function prettystr( v ) + --[[ Pretty string conversion, to display the full content of a variable of any type. + * string are enclosed with " by default, or with ' if string contains a " + * tables are expanded to show their full content, with indentation in case of nested tables + ]]-- + local cycleDetectTable = {} + local s = prettystr_sub(v, 1, M.PRINT_TABLE_REF_IN_ERROR_MSG, cycleDetectTable) + if cycleDetectTable.detected and not M.PRINT_TABLE_REF_IN_ERROR_MSG then + -- some table contain recursive references, + -- so we must recompute the value by including all table references + -- else the result looks like crap + cycleDetectTable = {} + s = prettystr_sub(v, 1, true, cycleDetectTable) + end + return s +end +M.prettystr = prettystr + +function M.adjust_err_msg_with_iter( err_msg, iter_msg ) + --[[ Adjust the error message err_msg: trim the FAILURE_PREFIX or SUCCESS_PREFIX information if needed, + add the iteration message if any and return the result. + err_msg: string, error message captured with pcall + iter_msg: a string describing the current iteration ("iteration N") or nil + if there is no iteration in this test. + Returns: (new_err_msg, test_status) + new_err_msg: string, adjusted error message, or nil in case of success + test_status: M.NodeStatus.FAIL, SUCCESS or ERROR according to the information + contained in the error message. + ]] + if iter_msg then + iter_msg = iter_msg..', ' + else + iter_msg = '' + end + + local RE_FILE_LINE = '.*:%d+: ' + + -- error message is not necessarily a string, + -- so convert the value to string with prettystr() + if type( err_msg ) ~= 'string' then + err_msg = prettystr( err_msg ) + end + + if (err_msg:find( M.SUCCESS_PREFIX ) == 1) or err_msg:match( '('..RE_FILE_LINE..')' .. M.SUCCESS_PREFIX .. ".*" ) then + -- test finished early with success() + return nil, M.NodeStatus.SUCCESS + end + + if (err_msg:find( M.SKIP_PREFIX ) == 1) or (err_msg:match( '('..RE_FILE_LINE..')' .. M.SKIP_PREFIX .. ".*" ) ~= nil) then + -- substitute prefix by iteration message + err_msg = err_msg:gsub('.*'..M.SKIP_PREFIX, iter_msg, 1) + -- print("failure detected") + return err_msg, M.NodeStatus.SKIP + end + + if (err_msg:find( M.FAILURE_PREFIX ) == 1) or (err_msg:match( '('..RE_FILE_LINE..')' .. M.FAILURE_PREFIX .. ".*" ) ~= nil) then + -- substitute prefix by iteration message + err_msg = err_msg:gsub(M.FAILURE_PREFIX, iter_msg, 1) + -- print("failure detected") + return err_msg, M.NodeStatus.FAIL + end + + + + -- print("error detected") + -- regular error, not a failure + if iter_msg then + local match + -- "./test\\test_luaunit.lua:2241: some error msg + match = err_msg:match( '(.*:%d+: ).*' ) + if match then + err_msg = err_msg:gsub( match, match .. iter_msg ) + else + -- no file:line: infromation, just add the iteration info at the beginning of the line + err_msg = iter_msg .. err_msg + end + end + return err_msg, M.NodeStatus.ERROR +end + +local function tryMismatchFormatting( table_a, table_b, doDeepAnalysis, margin ) + --[[ + Prepares a nice error message when comparing tables, performing a deeper + analysis. + Arguments: + * table_a, table_b: tables to be compared + * doDeepAnalysis: + M.DEFAULT_DEEP_ANALYSIS: (the default if not specified) perform deep analysis only for big lists and big dictionnaries + M.FORCE_DEEP_ANALYSIS : always perform deep analysis + M.DISABLE_DEEP_ANALYSIS: never perform deep analysis + * margin: supplied only for almost equality + Returns: {success, result} + * success: false if deep analysis could not be performed + in this case, just use standard assertion message + * result: if success is true, a multi-line string with deep analysis of the two lists + ]] + + -- check if table_a & table_b are suitable for deep analysis + if type(table_a) ~= 'table' or type(table_b) ~= 'table' then + return false + end + + if doDeepAnalysis == M.DISABLE_DEEP_ANALYSIS then + return false + end + + local len_a, len_b, isPureList = #table_a, #table_b, true + + for k1, v1 in pairs(table_a) do + if type(k1) ~= 'number' or k1 > len_a then + -- this table a mapping + isPureList = false + break + end + end + + if isPureList then + for k2, v2 in pairs(table_b) do + if type(k2) ~= 'number' or k2 > len_b then + -- this table a mapping + isPureList = false + break + end + end + end + + if isPureList and math.min(len_a, len_b) < M.LIST_DIFF_ANALYSIS_THRESHOLD then + if not (doDeepAnalysis == M.FORCE_DEEP_ANALYSIS) then + return false + end + end + + if isPureList then + return M.private.mismatchFormattingPureList( table_a, table_b, margin ) + else + -- only work on mapping for the moment + -- return M.private.mismatchFormattingMapping( table_a, table_b, doDeepAnalysis ) + return false + end +end +M.private.tryMismatchFormatting = tryMismatchFormatting + +local function getTaTbDescr() + if not M.ORDER_ACTUAL_EXPECTED then + return 'expected', 'actual' + end + return 'actual', 'expected' +end + +local function extendWithStrFmt( res, ... ) + table.insert( res, string.format( ... ) ) +end + +local function mismatchFormattingMapping( table_a, table_b, doDeepAnalysis ) + --[[ + Prepares a nice error message when comparing tables which are not pure lists, performing a deeper + analysis. + Returns: {success, result} + * success: false if deep analysis could not be performed + in this case, just use standard assertion message + * result: if success is true, a multi-line string with deep analysis of the two lists + ]] + + -- disable for the moment + --[[ + local result = {} + local descrTa, descrTb = getTaTbDescr() + local keysCommon = {} + local keysOnlyTa = {} + local keysOnlyTb = {} + local keysDiffTaTb = {} + local k, v + for k,v in pairs( table_a ) do + if is_equal( v, table_b[k] ) then + table.insert( keysCommon, k ) + else + if table_b[k] == nil then + table.insert( keysOnlyTa, k ) + else + table.insert( keysDiffTaTb, k ) + end + end + end + for k,v in pairs( table_b ) do + if not is_equal( v, table_a[k] ) and table_a[k] == nil then + table.insert( keysOnlyTb, k ) + end + end + local len_a = #keysCommon + #keysDiffTaTb + #keysOnlyTa + local len_b = #keysCommon + #keysDiffTaTb + #keysOnlyTb + local limited_display = (len_a < 5 or len_b < 5) + if math.min(len_a, len_b) < M.TABLE_DIFF_ANALYSIS_THRESHOLD then + return false + end + if not limited_display then + if len_a == len_b then + extendWithStrFmt( result, 'Table A (%s) and B (%s) both have %d items', descrTa, descrTb, len_a ) + else + extendWithStrFmt( result, 'Table A (%s) has %d items and table B (%s) has %d items', descrTa, len_a, descrTb, len_b ) + end + if #keysCommon == 0 and #keysDiffTaTb == 0 then + table.insert( result, 'Table A and B have no keys in common, they are totally different') + else + local s_other = 'other ' + if #keysCommon then + extendWithStrFmt( result, 'Table A and B have %d identical items', #keysCommon ) + else + table.insert( result, 'Table A and B have no identical items' ) + s_other = '' + end + if #keysDiffTaTb ~= 0 then + result[#result] = string.format( '%s and %d items differing present in both tables', result[#result], #keysDiffTaTb) + else + result[#result] = string.format( '%s and no %sitems differing present in both tables', result[#result], s_other, #keysDiffTaTb) + end + end + extendWithStrFmt( result, 'Table A has %d keys not present in table B and table B has %d keys not present in table A', #keysOnlyTa, #keysOnlyTb ) + end + local function keytostring(k) + if "string" == type(k) and k:match("^[_%a][_%w]*$") then + return k + end + return prettystr(k) + end + if #keysDiffTaTb ~= 0 then + table.insert( result, 'Items differing in A and B:') + for k,v in sortedPairs( keysDiffTaTb ) do + extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) ) + extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) ) + end + end + if #keysOnlyTa ~= 0 then + table.insert( result, 'Items only in table A:' ) + for k,v in sortedPairs( keysOnlyTa ) do + extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) ) + end + end + if #keysOnlyTb ~= 0 then + table.insert( result, 'Items only in table B:' ) + for k,v in sortedPairs( keysOnlyTb ) do + extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) ) + end + end + if #keysCommon ~= 0 then + table.insert( result, 'Items common to A and B:') + for k,v in sortedPairs( keysCommon ) do + extendWithStrFmt( result, ' = A and B [%s]: %s', keytostring(v), prettystr(table_a[v]) ) + end + end + return true, table.concat( result, '\n') + ]] +end +M.private.mismatchFormattingMapping = mismatchFormattingMapping + +local function mismatchFormattingPureList( table_a, table_b, margin ) + --[[ + Prepares a nice error message when comparing tables which are lists, performing a deeper + analysis. + margin is supplied only for almost equality + Returns: {success, result} + * success: false if deep analysis could not be performed + in this case, just use standard assertion message + * result: if success is true, a multi-line string with deep analysis of the two lists + ]] + local result, descrTa, descrTb = {}, getTaTbDescr() + + local len_a, len_b, refa, refb = #table_a, #table_b, '', '' + if M.PRINT_TABLE_REF_IN_ERROR_MSG then + refa, refb = string.format( '<%s> ', M.private.table_ref(table_a)), string.format('<%s> ', M.private.table_ref(table_b) ) + end + local longest, shortest = math.max(len_a, len_b), math.min(len_a, len_b) + local deltalv = longest - shortest + + local commonUntil = shortest + for i = 1, shortest do + if not M.private.is_table_equals(table_a[i], table_b[i], margin) then + commonUntil = i - 1 + break + end + end + + local commonBackTo = shortest - 1 + for i = 0, shortest - 1 do + if not M.private.is_table_equals(table_a[len_a-i], table_b[len_b-i], margin) then + commonBackTo = i - 1 + break + end + end + + + table.insert( result, 'List difference analysis:' ) + if len_a == len_b then + -- TODO: handle expected/actual naming + extendWithStrFmt( result, '* lists %sA (%s) and %sB (%s) have the same size', refa, descrTa, refb, descrTb ) + else + extendWithStrFmt( result, '* list sizes differ: list %sA (%s) has %d items, list %sB (%s) has %d items', refa, descrTa, len_a, refb, descrTb, len_b ) + end + + extendWithStrFmt( result, '* lists A and B start differing at index %d', commonUntil+1 ) + if commonBackTo >= 0 then + if deltalv > 0 then + extendWithStrFmt( result, '* lists A and B are equal again from index %d for A, %d for B', len_a-commonBackTo, len_b-commonBackTo ) + else + extendWithStrFmt( result, '* lists A and B are equal again from index %d', len_a-commonBackTo ) + end + end + + local function insertABValue(ai, bi) + bi = bi or ai + if M.private.is_table_equals( table_a[ai], table_b[bi], margin) then + return extendWithStrFmt( result, ' = A[%d], B[%d]: %s', ai, bi, prettystr(table_a[ai]) ) + else + extendWithStrFmt( result, ' - A[%d]: %s', ai, prettystr(table_a[ai])) + extendWithStrFmt( result, ' + B[%d]: %s', bi, prettystr(table_b[bi])) + end + end + + -- common parts to list A & B, at the beginning + if commonUntil > 0 then + table.insert( result, '* Common parts:' ) + for i = 1, commonUntil do + insertABValue( i ) + end + end + + -- diffing parts to list A & B + if commonUntil < shortest - commonBackTo - 1 then + table.insert( result, '* Differing parts:' ) + for i = commonUntil + 1, shortest - commonBackTo - 1 do + insertABValue( i ) + end + end + + -- display indexes of one list, with no match on other list + if shortest - commonBackTo <= longest - commonBackTo - 1 then + table.insert( result, '* Present only in one list:' ) + for i = shortest - commonBackTo, longest - commonBackTo - 1 do + if len_a > len_b then + extendWithStrFmt( result, ' - A[%d]: %s', i, prettystr(table_a[i]) ) + -- table.insert( result, '+ (no matching B index)') + else + -- table.insert( result, '- no matching A index') + extendWithStrFmt( result, ' + B[%d]: %s', i, prettystr(table_b[i]) ) + end + end + end + + -- common parts to list A & B, at the end + if commonBackTo >= 0 then + table.insert( result, '* Common parts at the end of the lists' ) + for i = longest - commonBackTo, longest do + if len_a > len_b then + insertABValue( i, i-deltalv ) + else + insertABValue( i-deltalv, i ) + end + end + end + + return true, table.concat( result, '\n') +end +M.private.mismatchFormattingPureList = mismatchFormattingPureList + +local function prettystrPairs(value1, value2, suffix_a, suffix_b) + --[[ + This function helps with the recurring task of constructing the "expected + vs. actual" error messages. It takes two arbitrary values and formats + corresponding strings with prettystr(). + To keep the (possibly complex) output more readable in case the resulting + strings contain line breaks, they get automatically prefixed with additional + newlines. Both suffixes are optional (default to empty strings), and get + appended to the "value1" string. "suffix_a" is used if line breaks were + encountered, "suffix_b" otherwise. + Returns the two formatted strings (including padding/newlines). + ]] + local str1, str2 = prettystr(value1), prettystr(value2) + if hasNewLine(str1) or hasNewLine(str2) then + -- line break(s) detected, add padding + return "\n" .. str1 .. (suffix_a or ""), "\n" .. str2 + end + return str1 .. (suffix_b or ""), str2 +end +M.private.prettystrPairs = prettystrPairs + +local UNKNOWN_REF = 'table 00-unknown ref' +local ref_generator = { value=1, [UNKNOWN_REF]=0 } + +local function table_ref( t ) + -- return the default tostring() for tables, with the table ID, even if the table has a metatable + -- with the __tostring converter + local ref = '' + local mt = getmetatable( t ) + if mt == nil then + ref = tostring(t) + else + local success, result + success, result = pcall(setmetatable, t, nil) + if not success then + -- protected table, if __tostring is defined, we can + -- not get the reference. And we can not know in advance. + ref = tostring(t) + if not ref:match( 'table: 0?x?[%x]+' ) then + return UNKNOWN_REF + end + else + ref = tostring(t) + setmetatable( t, mt ) + end + end + -- strip the "table: " part + ref = ref:sub(8) + if ref ~= UNKNOWN_REF and ref_generator[ref] == nil then + -- Create a new reference number + ref_generator[ref] = ref_generator.value + ref_generator.value = ref_generator.value+1 + end + if M.PRINT_TABLE_REF_IN_ERROR_MSG then + return string.format('table %02d-%s', ref_generator[ref], ref) + else + return string.format('table %02d', ref_generator[ref]) + end +end +M.private.table_ref = table_ref + +local TABLE_TOSTRING_SEP = ", " +local TABLE_TOSTRING_SEP_LEN = string.len(TABLE_TOSTRING_SEP) + +local function _table_tostring( tbl, indentLevel, printTableRefs, cycleDetectTable ) + printTableRefs = printTableRefs or M.PRINT_TABLE_REF_IN_ERROR_MSG + cycleDetectTable = cycleDetectTable or {} + cycleDetectTable[tbl] = true + + local result, dispOnMultLines = {}, false + + -- like prettystr but do not enclose with "" if the string is just alphanumerical + -- this is better for displaying table keys who are often simple strings + local function keytostring(k) + if "string" == type(k) and k:match("^[_%a][_%w]*$") then + return k + end + return prettystr_sub(k, indentLevel+1, printTableRefs, cycleDetectTable) + end + + local mt = getmetatable( tbl ) + + if mt and mt.__tostring then + -- if table has a __tostring() function in its metatable, use it to display the table + -- else, compute a regular table + result = tostring(tbl) + if type(result) ~= 'string' then + return string.format( '', prettystr(result) ) + end + result = strsplit( '\n', result ) + return M.private._table_tostring_format_multiline_string( result, indentLevel ) + + else + -- no metatable, compute the table representation + + local entry, count, seq_index = nil, 0, 1 + for k, v in sortedPairs( tbl ) do + + -- key part + if k == seq_index then + -- for the sequential part of tables, we'll skip the "=" output + entry = '' + seq_index = seq_index + 1 + elseif cycleDetectTable[k] then + -- recursion in the key detected + cycleDetectTable.detected = true + entry = "<"..table_ref(k)..">=" + else + entry = keytostring(k) .. "=" + end + + -- value part + if cycleDetectTable[v] then + -- recursion in the value detected! + cycleDetectTable.detected = true + entry = entry .. "<"..table_ref(v)..">" + else + entry = entry .. + prettystr_sub( v, indentLevel+1, printTableRefs, cycleDetectTable ) + end + count = count + 1 + result[count] = entry + end + return M.private._table_tostring_format_result( tbl, result, indentLevel, printTableRefs ) + end + +end +M.private._table_tostring = _table_tostring -- prettystr_sub() needs it + +local function _table_tostring_format_multiline_string( tbl_str, indentLevel ) + local indentString = '\n'..string.rep(" ", indentLevel - 1) + return table.concat( tbl_str, indentString ) + +end +M.private._table_tostring_format_multiline_string = _table_tostring_format_multiline_string + + +local function _table_tostring_format_result( tbl, result, indentLevel, printTableRefs ) + -- final function called in _table_to_string() to format the resulting list of + -- string describing the table. + + local dispOnMultLines = false + + -- set dispOnMultLines to true if the maximum LINE_LENGTH would be exceeded with the values + local totalLength = 0 + for k, v in ipairs( result ) do + totalLength = totalLength + string.len( v ) + if totalLength >= M.LINE_LENGTH then + dispOnMultLines = true + break + end + end + + -- set dispOnMultLines to true if the max LINE_LENGTH would be exceeded + -- with the values and the separators. + if not dispOnMultLines then + -- adjust with length of separator(s): + -- two items need 1 sep, three items two seps, ... plus len of '{}' + if #result > 0 then + totalLength = totalLength + TABLE_TOSTRING_SEP_LEN * (#result - 1) + end + dispOnMultLines = (totalLength + 2 >= M.LINE_LENGTH) + end + + -- now reformat the result table (currently holding element strings) + if dispOnMultLines then + local indentString = string.rep(" ", indentLevel - 1) + result = { + "{\n ", + indentString, + table.concat(result, ",\n " .. indentString), + "\n", + indentString, + "}" + } + else + result = {"{", table.concat(result, TABLE_TOSTRING_SEP), "}"} + end + if printTableRefs then + table.insert(result, 1, "<"..table_ref(tbl).."> ") -- prepend table ref + end + return table.concat(result) +end +M.private._table_tostring_format_result = _table_tostring_format_result -- prettystr_sub() needs it + +local function table_findkeyof(t, element) + -- Return the key k of the given element in table t, so that t[k] == element + -- (or `nil` if element is not present within t). Note that we use our + -- 'general' is_equal comparison for matching, so this function should + -- handle table-type elements gracefully and consistently. + if type(t) == "table" then + for k, v in pairs(t) do + if M.private.is_table_equals(v, element) then + return k + end + end + end + return nil +end + +local function _is_table_items_equals(actual, expected ) + local type_a, type_e = type(actual), type(expected) + + if type_a ~= type_e then + return false + + elseif (type_a == 'table') --[[and (type_e == 'table')]] then + for k, v in pairs(actual) do + if table_findkeyof(expected, v) == nil then + return false -- v not contained in expected + end + end + for k, v in pairs(expected) do + if table_findkeyof(actual, v) == nil then + return false -- v not contained in actual + end + end + return true + + elseif actual ~= expected then + return false + end + + return true +end + +--[[ +This is a specialized metatable to help with the bookkeeping of recursions +in _is_table_equals(). It provides an __index table that implements utility +functions for easier management of the table. The "cached" method queries +the state of a specific (actual,expected) pair; and the "store" method sets +this state to the given value. The state of pairs not "seen" / visited is +assumed to be `nil`. +]] +local _recursion_cache_MT = { + __index = { + -- Return the cached value for an (actual,expected) pair (or `nil`) + cached = function(t, actual, expected) + local subtable = t[actual] or {} + return subtable[expected] + end, + + -- Store cached value for a specific (actual,expected) pair. + -- Returns the value, so it's easy to use for a "tailcall" (return ...). + store = function(t, actual, expected, value, asymmetric) + local subtable = t[actual] + if not subtable then + subtable = {} + t[actual] = subtable + end + subtable[expected] = value + + -- Unless explicitly marked "asymmetric": Consider the recursion + -- on (expected,actual) to be equivalent to (actual,expected) by + -- default, and thus cache the value for both. + if not asymmetric then + t:store(expected, actual, value, true) + end + + return value + end + } +} + +local function _is_table_equals(actual, expected, cycleDetectTable, marginForAlmostEqual) + --[[Returns true if both table are equal. + If argument marginForAlmostEqual is suppied, number comparison is done using alomstEqual instead + of strict equality. + cycleDetectTable is an internal argument used during recursion on tables. + ]] + --print('_is_table_equals( \n '..prettystr(actual)..'\n , '..prettystr(expected).. + -- '\n , '..prettystr(cycleDetectTable)..'\n , '..prettystr(marginForAlmostEqual)..' )') + + local type_a, type_e = type(actual), type(expected) + + if type_a ~= type_e then + return false -- different types won't match + end + + if type_a == 'number' then + if marginForAlmostEqual ~= nil then + return M.almostEquals(actual, expected, marginForAlmostEqual) + else + return actual == expected + end + elseif type_a ~= 'table' then + -- other types compare directly + return actual == expected + end + + cycleDetectTable = cycleDetectTable or { actual={}, expected={} } + if cycleDetectTable.actual[ actual ] then + -- oh, we hit a cycle in actual + if cycleDetectTable.expected[ expected ] then + -- uh, we hit a cycle at the same time in expected + -- so the two tables have similar structure + return true + end + + -- cycle was hit only in actual, the structure differs from expected + return false + end + + if cycleDetectTable.expected[ expected ] then + -- no cycle in actual, but cycle in expected + -- the structure differ + return false + end + + -- at this point, no table cycle detected, we are + -- seeing this table for the first time + + -- mark the cycle detection + cycleDetectTable.actual[ actual ] = true + cycleDetectTable.expected[ expected ] = true + + + local actualKeysMatched = {} + for k, v in pairs(actual) do + actualKeysMatched[k] = true -- Keep track of matched keys + if not _is_table_equals(v, expected[k], cycleDetectTable, marginForAlmostEqual) then + -- table differs on this key + -- clear the cycle detection before returning + cycleDetectTable.actual[ actual ] = nil + cycleDetectTable.expected[ expected ] = nil + return false + end + end + + for k, v in pairs(expected) do + if not actualKeysMatched[k] then + -- Found a key that we did not see in "actual" -> mismatch + -- clear the cycle detection before returning + cycleDetectTable.actual[ actual ] = nil + cycleDetectTable.expected[ expected ] = nil + return false + end + -- Otherwise actual[k] was already matched against v = expected[k]. + end + + -- all key match, we have a match ! + cycleDetectTable.actual[ actual ] = nil + cycleDetectTable.expected[ expected ] = nil + return true +end +M.private._is_table_equals = _is_table_equals + +local function failure(main_msg, extra_msg_or_nil, level) + -- raise an error indicating a test failure + -- for error() compatibility we adjust "level" here (by +1), to report the + -- calling context + local msg + if type(extra_msg_or_nil) == 'string' and extra_msg_or_nil:len() > 0 then + msg = extra_msg_or_nil .. '\n' .. main_msg + else + msg = main_msg + end + error(M.FAILURE_PREFIX .. msg, (level or 1) + 1 + M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE) +end + +local function is_table_equals(actual, expected, marginForAlmostEqual) + return _is_table_equals(actual, expected, nil, marginForAlmostEqual) +end +M.private.is_table_equals = is_table_equals + +local function fail_fmt(level, extra_msg_or_nil, ...) + -- failure with printf-style formatted message and given error level + failure(string.format(...), extra_msg_or_nil, (level or 1) + 1) +end +M.private.fail_fmt = fail_fmt + +local function error_fmt(level, ...) + -- printf-style error() + error(string.format(...), (level or 1) + 1 + M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE) +end +M.private.error_fmt = error_fmt + +---------------------------------------------------------------- +-- +-- assertions +-- +---------------------------------------------------------------- + +local function errorMsgEquality(actual, expected, doDeepAnalysis, margin) + -- margin is supplied only for almost equal verification + + if not M.ORDER_ACTUAL_EXPECTED then + expected, actual = actual, expected + end + if type(expected) == 'string' or type(expected) == 'table' then + local strExpected, strActual = prettystrPairs(expected, actual) + local result = string.format("expected: %s\nactual: %s", strExpected, strActual) + if margin then + result = result .. '\nwere not equal by the margin of: '..prettystr(margin) + end + + -- extend with mismatch analysis if possible: + local success, mismatchResult + success, mismatchResult = tryMismatchFormatting( actual, expected, doDeepAnalysis, margin ) + if success then + result = table.concat( { result, mismatchResult }, '\n' ) + end + return result + end + return string.format("expected: %s, actual: %s", + prettystr(expected), prettystr(actual)) +end + +function M.assertError(f, ...) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + if pcall( f, ... ) then + failure( "Expected an error when calling function but no error generated", nil, 2 ) + end +end + +function M.fail( msg ) + -- stops a test due to a failure + failure( msg, nil, 2 ) +end + +function M.failIf( cond, msg ) + -- Fails a test with "msg" if condition is true + if cond then + failure( msg, nil, 2 ) + end +end + +function M.skip(msg) + -- skip a running test + error_fmt(2, M.SKIP_PREFIX .. msg) +end + +function M.skipIf( cond, msg ) + -- skip a running test if condition is met + if cond then + error_fmt(2, M.SKIP_PREFIX .. msg) + end +end + +function M.runOnlyIf( cond, msg ) + -- continue a running test if condition is met, else skip it + if not cond then + error_fmt(2, M.SKIP_PREFIX .. prettystr(msg)) + end +end + +function M.success() + -- stops a test with a success + error_fmt(2, M.SUCCESS_PREFIX) +end + +function M.successIf( cond ) + -- stops a test with a success if condition is met + if cond then + error_fmt(2, M.SUCCESS_PREFIX) + end +end + + +------------------------------------------------------------------ +-- Equality assertions +------------------------------------------------------------------ + +function M.assertEquals(actual, expected, extra_msg_or_nil, doDeepAnalysis) + if type(actual) == 'table' and type(expected) == 'table' then + if not is_table_equals(actual, expected) then + failure( errorMsgEquality(actual, expected, doDeepAnalysis), extra_msg_or_nil, 2 ) + end + elseif type(actual) ~= type(expected) then + failure( errorMsgEquality(actual, expected), extra_msg_or_nil, 2 ) + elseif actual ~= expected then + failure( errorMsgEquality(actual, expected), extra_msg_or_nil, 2 ) + end +end + +function M.almostEquals( actual, expected, margin ) + if type(actual) ~= 'number' or type(expected) ~= 'number' or type(margin) ~= 'number' then + error_fmt(3, 'almostEquals: must supply only number arguments.\nArguments supplied: %s, %s, %s', + prettystr(actual), prettystr(expected), prettystr(margin)) + end + if margin < 0 then + error_fmt(3, 'almostEquals: margin must not be negative, current value is ' .. margin) + end + return math.abs(expected - actual) <= margin +end + +function M.assertAlmostEquals( actual, expected, margin, extra_msg_or_nil ) + -- check that two floats are close by margin + margin = margin or M.EPS + if type(margin) ~= 'number' then + error_fmt(2, 'almostEquals: margin must be a number, not %s', prettystr(margin)) + end + + if type(actual) == 'table' and type(expected) == 'table' then + -- handle almost equals for table + if not is_table_equals(actual, expected, margin) then + failure( errorMsgEquality(actual, expected, nil, margin), extra_msg_or_nil, 2 ) + end + elseif type(actual) == 'number' and type(expected) == 'number' and type(margin) == 'number' then + if not M.almostEquals(actual, expected, margin) then + if not M.ORDER_ACTUAL_EXPECTED then + expected, actual = actual, expected + end + local delta = math.abs(actual - expected) + fail_fmt(2, extra_msg_or_nil, 'Values are not almost equal\n' .. + 'Actual: %s, expected: %s, delta %s above margin of %s', + actual, expected, delta, margin) + end + else + error_fmt(3, 'almostEquals: must supply only number or table arguments.\nArguments supplied: %s, %s, %s', + prettystr(actual), prettystr(expected), prettystr(margin)) + end +end + +function M.assertNotEquals(actual, expected, extra_msg_or_nil) + if type(actual) ~= type(expected) then + return + end + + if type(actual) == 'table' and type(expected) == 'table' then + if not is_table_equals(actual, expected) then + return + end + elseif actual ~= expected then + return + end + fail_fmt(2, extra_msg_or_nil, 'Received the not expected value: %s', prettystr(actual)) +end + +function M.assertNotAlmostEquals( actual, expected, margin, extra_msg_or_nil ) + -- check that two floats are not close by margin + margin = margin or M.EPS + if M.almostEquals(actual, expected, margin) then + if not M.ORDER_ACTUAL_EXPECTED then + expected, actual = actual, expected + end + local delta = math.abs(actual - expected) + fail_fmt(2, extra_msg_or_nil, 'Values are almost equal\nActual: %s, expected: %s' .. + ', delta %s below margin of %s', + actual, expected, delta, margin) + end +end + +function M.assertItemsEquals(actual, expected, extra_msg_or_nil) + -- checks that the items of table expected + -- are contained in table actual. Warning, this function + -- is at least O(n^2) + if not _is_table_items_equals(actual, expected ) then + expected, actual = prettystrPairs(expected, actual) + fail_fmt(2, extra_msg_or_nil, 'Content of the tables are not identical:\nExpected: %s\nActual: %s', + expected, actual) + end +end + +------------------------------------------------------------------ +-- String assertion +------------------------------------------------------------------ + +function M.assertStrContains( str, sub, isPattern, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + -- assert( type(str) == 'string', 'Argument 1 of assertStrContains() should be a string.' ) ) + -- assert( type(sub) == 'string', 'Argument 2 of assertStrContains() should be a string.' ) ) + if not string.find(str, sub, 1, not isPattern) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Could not find %s %s in string %s', + isPattern and 'pattern' or 'substring', sub, str) + end +end + +function M.assertStrIContains( str, sub, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + if not string.find(str:lower(), sub:lower(), 1, true) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Could not find (case insensitively) substring %s in string %s', + sub, str) + end +end + +function M.assertNotStrContains( str, sub, isPattern, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + if string.find(str, sub, 1, not isPattern) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Found the not expected %s %s in string %s', + isPattern and 'pattern' or 'substring', sub, str) + end +end + +function M.assertNotStrIContains( str, sub, extra_msg_or_nil ) + -- this relies on lua string.find function + -- a string always contains the empty string + if string.find(str:lower(), sub:lower(), 1, true) then + sub, str = prettystrPairs(sub, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Found (case insensitively) the not expected substring %s in string %s', + sub, str) + end +end + +function M.assertStrMatches( str, pattern, start, final, extra_msg_or_nil ) + -- Verify a full match for the string + if not strMatch( str, pattern, start, final ) then + pattern, str = prettystrPairs(pattern, str, '\n') + fail_fmt(2, extra_msg_or_nil, 'Could not match pattern %s with string %s', + pattern, str) + end +end + +local function _assertErrorMsgEquals( stripFileAndLine, expectedMsg, func, ... ) + local no_error, error_msg = pcall( func, ... ) + if no_error then + failure( 'No error generated when calling function but expected error: '..M.prettystr(expectedMsg), nil, 3 ) + end + if type(expectedMsg) == "string" and type(error_msg) ~= "string" then + -- table are converted to string automatically + error_msg = tostring(error_msg) + end + local differ = false + if stripFileAndLine then + if error_msg:gsub("^.+:%d+: ", "") ~= expectedMsg then + differ = true + end + else + if error_msg ~= expectedMsg then + local tr = type(error_msg) + local te = type(expectedMsg) + if te == 'table' then + if tr ~= 'table' then + differ = true + else + local ok = pcall(M.assertItemsEquals, error_msg, expectedMsg) + if not ok then + differ = true + end + end + else + differ = true + end + end + end + + if differ then + error_msg, expectedMsg = prettystrPairs(error_msg, expectedMsg) + fail_fmt(3, nil, 'Error message expected: %s\nError message received: %s\n', + expectedMsg, error_msg) + end +end + +function M.assertErrorMsgEquals( expectedMsg, func, ... ) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + _assertErrorMsgEquals(false, expectedMsg, func, ...) +end + +function M.assertErrorMsgContentEquals(expectedMsg, func, ...) + _assertErrorMsgEquals(true, expectedMsg, func, ...) +end + +function M.assertErrorMsgContains( partialMsg, func, ... ) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + local no_error, error_msg = pcall( func, ... ) + if no_error then + failure( 'No error generated when calling function but expected error containing: '..prettystr(partialMsg), nil, 2 ) + end + if type(error_msg) ~= "string" then + error_msg = tostring(error_msg) + end + if not string.find( error_msg, partialMsg, nil, true ) then + error_msg, partialMsg = prettystrPairs(error_msg, partialMsg) + fail_fmt(2, nil, 'Error message does not contain: %s\nError message received: %s\n', + partialMsg, error_msg) + end +end + +function M.assertErrorMsgMatches( expectedMsg, func, ... ) + -- assert that calling f with the arguments will raise an error + -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error + local no_error, error_msg = pcall( func, ... ) + if no_error then + failure( 'No error generated when calling function but expected error matching: "'..expectedMsg..'"', nil, 2 ) + end + if type(error_msg) ~= "string" then + error_msg = tostring(error_msg) + end + if not strMatch( error_msg, expectedMsg ) then + expectedMsg, error_msg = prettystrPairs(expectedMsg, error_msg) + fail_fmt(2, nil, 'Error message does not match pattern: %s\nError message received: %s\n', + expectedMsg, error_msg) + end +end + +------------------------------------------------------------------ +-- Type assertions +------------------------------------------------------------------ + +function M.assertEvalToTrue(value, extra_msg_or_nil) + if not value then + failure("expected: a value evaluating to true, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertEvalToFalse(value, extra_msg_or_nil) + if value then + failure("expected: false or nil, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsTrue(value, extra_msg_or_nil) + if value ~= true then + failure("expected: true, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsTrue(value, extra_msg_or_nil) + if value == true then + failure("expected: not true, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsFalse(value, extra_msg_or_nil) + if value ~= false then + failure("expected: false, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsFalse(value, extra_msg_or_nil) + if value == false then + failure("expected: not false, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsNil(value, extra_msg_or_nil) + if value ~= nil then + failure("expected: nil, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsNil(value, extra_msg_or_nil) + if value == nil then + failure("expected: not nil, actual: nil", extra_msg_or_nil, 2) + end +end + +--[[ +Add type assertion functions to the module table M. Each of these functions +takes a single parameter "value", and checks that its Lua type matches the +expected string (derived from the function name): +M.assertIsXxx(value) -> ensure that type(value) conforms to "xxx" +]] +for _, funcName in ipairs( + {'assertIsNumber', 'assertIsString', 'assertIsTable', 'assertIsBoolean', + 'assertIsFunction', 'assertIsUserdata', 'assertIsThread'} +) do + local typeExpected = funcName:match("^assertIs([A-Z]%a*)$") + -- Lua type() always returns lowercase, also make sure the match() succeeded + typeExpected = typeExpected and typeExpected:lower() + or error("bad function name '"..funcName.."' for type assertion") + + M[funcName] = function(value, extra_msg_or_nil) + if type(value) ~= typeExpected then + if type(value) == 'nil' then + fail_fmt(2, extra_msg_or_nil, 'expected: a %s value, actual: nil', + typeExpected, type(value), prettystrPairs(value)) + else + fail_fmt(2, extra_msg_or_nil, 'expected: a %s value, actual: type %s, value %s', + typeExpected, type(value), prettystrPairs(value)) + end + end + end +end + +--[[ +Add shortcuts for verifying type of a variable, without failure (luaunit v2 compatibility) +M.isXxx(value) -> returns true if type(value) conforms to "xxx" +]] +for _, typeExpected in ipairs( + {'Number', 'String', 'Table', 'Boolean', + 'Function', 'Userdata', 'Thread', 'Nil' } +) do + local typeExpectedLower = typeExpected:lower() + local isType = function(value) + return (type(value) == typeExpectedLower) + end + M['is'..typeExpected] = isType + M['is_'..typeExpectedLower] = isType +end + +--[[ +Add non-type assertion functions to the module table M. Each of these functions +takes a single parameter "value", and checks that its Lua type differs from the +expected string (derived from the function name): +M.assertNotIsXxx(value) -> ensure that type(value) is not "xxx" +]] +for _, funcName in ipairs( + {'assertNotIsNumber', 'assertNotIsString', 'assertNotIsTable', 'assertNotIsBoolean', + 'assertNotIsFunction', 'assertNotIsUserdata', 'assertNotIsThread'} +) do + local typeUnexpected = funcName:match("^assertNotIs([A-Z]%a*)$") + -- Lua type() always returns lowercase, also make sure the match() succeeded + typeUnexpected = typeUnexpected and typeUnexpected:lower() + or error("bad function name '"..funcName.."' for type assertion") + + M[funcName] = function(value, extra_msg_or_nil) + if type(value) == typeUnexpected then + fail_fmt(2, extra_msg_or_nil, 'expected: not a %s type, actual: value %s', + typeUnexpected, prettystrPairs(value)) + end + end +end + +function M.assertIs(actual, expected, extra_msg_or_nil) + if actual ~= expected then + if not M.ORDER_ACTUAL_EXPECTED then + actual, expected = expected, actual + end + local old_print_table_ref_in_error_msg = M.PRINT_TABLE_REF_IN_ERROR_MSG + M.PRINT_TABLE_REF_IN_ERROR_MSG = true + expected, actual = prettystrPairs(expected, actual, '\n', '') + M.PRINT_TABLE_REF_IN_ERROR_MSG = old_print_table_ref_in_error_msg + fail_fmt(2, extra_msg_or_nil, 'expected and actual object should not be different\nExpected: %s\nReceived: %s', + expected, actual) + end +end + +function M.assertNotIs(actual, expected, extra_msg_or_nil) + if actual == expected then + local old_print_table_ref_in_error_msg = M.PRINT_TABLE_REF_IN_ERROR_MSG + M.PRINT_TABLE_REF_IN_ERROR_MSG = true + local s_expected + if not M.ORDER_ACTUAL_EXPECTED then + s_expected = prettystrPairs(actual) + else + s_expected = prettystrPairs(expected) + end + M.PRINT_TABLE_REF_IN_ERROR_MSG = old_print_table_ref_in_error_msg + fail_fmt(2, extra_msg_or_nil, 'expected and actual object should be different: %s', s_expected ) + end +end + + +------------------------------------------------------------------ +-- Scientific assertions +------------------------------------------------------------------ + + +function M.assertIsNaN(value, extra_msg_or_nil) + if type(value) ~= "number" or value == value then + failure("expected: NaN, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsNaN(value, extra_msg_or_nil) + if type(value) == "number" and value ~= value then + failure("expected: not NaN, actual: NaN", extra_msg_or_nil, 2) + end +end + +function M.assertIsInf(value, extra_msg_or_nil) + if type(value) ~= "number" or math.abs(value) ~= math.huge then + failure("expected: #Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsPlusInf(value, extra_msg_or_nil) + if type(value) ~= "number" or value ~= math.huge then + failure("expected: #Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsMinusInf(value, extra_msg_or_nil) + if type(value) ~= "number" or value ~= -math.huge then + failure("expected: -#Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertNotIsPlusInf(value, extra_msg_or_nil) + if type(value) == "number" and value == math.huge then + failure("expected: not #Inf, actual: #Inf", extra_msg_or_nil, 2) + end +end + +function M.assertNotIsMinusInf(value, extra_msg_or_nil) + if type(value) == "number" and value == -math.huge then + failure("expected: not -#Inf, actual: -#Inf", extra_msg_or_nil, 2) + end +end + +function M.assertNotIsInf(value, extra_msg_or_nil) + if type(value) == "number" and math.abs(value) == math.huge then + failure("expected: not infinity, actual: " .. prettystr(value), extra_msg_or_nil, 2) + end +end + +function M.assertIsPlusZero(value, extra_msg_or_nil) + if type(value) ~= 'number' or value ~= 0 then + failure("expected: +0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + else if (1/value == -math.huge) then + -- more precise error diagnosis + failure("expected: +0.0, actual: -0.0", extra_msg_or_nil, 2) + else if (1/value ~= math.huge) then + -- strange, case should have already been covered + failure("expected: +0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end + end + end +end + +function M.assertIsMinusZero(value, extra_msg_or_nil) + if type(value) ~= 'number' or value ~= 0 then + failure("expected: -0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + else if (1/value == math.huge) then + -- more precise error diagnosis + failure("expected: -0.0, actual: +0.0", extra_msg_or_nil, 2) + else if (1/value ~= -math.huge) then + -- strange, case should have already been covered + failure("expected: -0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2) + end + end + end +end + +function M.assertNotIsPlusZero(value, extra_msg_or_nil) + if type(value) == 'number' and (1/value == math.huge) then + failure("expected: not +0.0, actual: +0.0", extra_msg_or_nil, 2) + end +end + +function M.assertNotIsMinusZero(value, extra_msg_or_nil) + if type(value) == 'number' and (1/value == -math.huge) then + failure("expected: not -0.0, actual: -0.0", extra_msg_or_nil, 2) + end +end + +function M.assertTableContains(t, expected, extra_msg_or_nil) + -- checks that table t contains the expected element + if table_findkeyof(t, expected) == nil then + t, expected = prettystrPairs(t, expected) + fail_fmt(2, extra_msg_or_nil, 'Table %s does NOT contain the expected element %s', + t, expected) + end +end + +function M.assertNotTableContains(t, expected, extra_msg_or_nil) + -- checks that table t doesn't contain the expected element + local k = table_findkeyof(t, expected) + if k ~= nil then + t, expected = prettystrPairs(t, expected) + fail_fmt(2, extra_msg_or_nil, 'Table %s DOES contain the unwanted element %s (at key %s)', + t, expected, prettystr(k)) + end +end + +---------------------------------------------------------------- +-- Compatibility layer +---------------------------------------------------------------- + +-- for compatibility with LuaUnit v2.x +function M.wrapFunctions() + -- In LuaUnit version <= 2.1 , this function was necessary to include + -- a test function inside the global test suite. Nowadays, the functions + -- are simply run directly as part of the test discovery process. + -- so just do nothing ! + io.stderr:write[[Use of WrapFunctions() is no longer needed. +Just prefix your test function names with "test" or "Test" and they +will be picked up and run by LuaUnit. +]] +end + +local list_of_funcs = { + -- { official function name , alias } + + -- general assertions + { 'assertEquals' , 'assert_equals' }, + { 'assertItemsEquals' , 'assert_items_equals' }, + { 'assertNotEquals' , 'assert_not_equals' }, + { 'assertAlmostEquals' , 'assert_almost_equals' }, + { 'assertNotAlmostEquals' , 'assert_not_almost_equals' }, + { 'assertEvalToTrue' , 'assert_eval_to_true' }, + { 'assertEvalToFalse' , 'assert_eval_to_false' }, + { 'assertStrContains' , 'assert_str_contains' }, + { 'assertStrIContains' , 'assert_str_icontains' }, + { 'assertNotStrContains' , 'assert_not_str_contains' }, + { 'assertNotStrIContains' , 'assert_not_str_icontains' }, + { 'assertStrMatches' , 'assert_str_matches' }, + { 'assertError' , 'assert_error' }, + { 'assertErrorMsgEquals' , 'assert_error_msg_equals' }, + { 'assertErrorMsgContains' , 'assert_error_msg_contains' }, + { 'assertErrorMsgMatches' , 'assert_error_msg_matches' }, + { 'assertErrorMsgContentEquals', 'assert_error_msg_content_equals' }, + { 'assertIs' , 'assert_is' }, + { 'assertNotIs' , 'assert_not_is' }, + { 'assertTableContains' , 'assert_table_contains' }, + { 'assertNotTableContains' , 'assert_not_table_contains' }, + { 'wrapFunctions' , 'WrapFunctions' }, + { 'wrapFunctions' , 'wrap_functions' }, + + -- type assertions: assertIsXXX -> assert_is_xxx + { 'assertIsNumber' , 'assert_is_number' }, + { 'assertIsString' , 'assert_is_string' }, + { 'assertIsTable' , 'assert_is_table' }, + { 'assertIsBoolean' , 'assert_is_boolean' }, + { 'assertIsNil' , 'assert_is_nil' }, + { 'assertIsTrue' , 'assert_is_true' }, + { 'assertIsFalse' , 'assert_is_false' }, + { 'assertIsNaN' , 'assert_is_nan' }, + { 'assertIsInf' , 'assert_is_inf' }, + { 'assertIsPlusInf' , 'assert_is_plus_inf' }, + { 'assertIsMinusInf' , 'assert_is_minus_inf' }, + { 'assertIsPlusZero' , 'assert_is_plus_zero' }, + { 'assertIsMinusZero' , 'assert_is_minus_zero' }, + { 'assertIsFunction' , 'assert_is_function' }, + { 'assertIsThread' , 'assert_is_thread' }, + { 'assertIsUserdata' , 'assert_is_userdata' }, + + -- type assertions: assertIsXXX -> assertXxx + { 'assertIsNumber' , 'assertNumber' }, + { 'assertIsString' , 'assertString' }, + { 'assertIsTable' , 'assertTable' }, + { 'assertIsBoolean' , 'assertBoolean' }, + { 'assertIsNil' , 'assertNil' }, + { 'assertIsTrue' , 'assertTrue' }, + { 'assertIsFalse' , 'assertFalse' }, + { 'assertIsNaN' , 'assertNaN' }, + { 'assertIsInf' , 'assertInf' }, + { 'assertIsPlusInf' , 'assertPlusInf' }, + { 'assertIsMinusInf' , 'assertMinusInf' }, + { 'assertIsPlusZero' , 'assertPlusZero' }, + { 'assertIsMinusZero' , 'assertMinusZero'}, + { 'assertIsFunction' , 'assertFunction' }, + { 'assertIsThread' , 'assertThread' }, + { 'assertIsUserdata' , 'assertUserdata' }, + + -- type assertions: assertIsXXX -> assert_xxx (luaunit v2 compat) + { 'assertIsNumber' , 'assert_number' }, + { 'assertIsString' , 'assert_string' }, + { 'assertIsTable' , 'assert_table' }, + { 'assertIsBoolean' , 'assert_boolean' }, + { 'assertIsNil' , 'assert_nil' }, + { 'assertIsTrue' , 'assert_true' }, + { 'assertIsFalse' , 'assert_false' }, + { 'assertIsNaN' , 'assert_nan' }, + { 'assertIsInf' , 'assert_inf' }, + { 'assertIsPlusInf' , 'assert_plus_inf' }, + { 'assertIsMinusInf' , 'assert_minus_inf' }, + { 'assertIsPlusZero' , 'assert_plus_zero' }, + { 'assertIsMinusZero' , 'assert_minus_zero' }, + { 'assertIsFunction' , 'assert_function' }, + { 'assertIsThread' , 'assert_thread' }, + { 'assertIsUserdata' , 'assert_userdata' }, + + -- type assertions: assertNotIsXXX -> assert_not_is_xxx + { 'assertNotIsNumber' , 'assert_not_is_number' }, + { 'assertNotIsString' , 'assert_not_is_string' }, + { 'assertNotIsTable' , 'assert_not_is_table' }, + { 'assertNotIsBoolean' , 'assert_not_is_boolean' }, + { 'assertNotIsNil' , 'assert_not_is_nil' }, + { 'assertNotIsTrue' , 'assert_not_is_true' }, + { 'assertNotIsFalse' , 'assert_not_is_false' }, + { 'assertNotIsNaN' , 'assert_not_is_nan' }, + { 'assertNotIsInf' , 'assert_not_is_inf' }, + { 'assertNotIsPlusInf' , 'assert_not_plus_inf' }, + { 'assertNotIsMinusInf' , 'assert_not_minus_inf' }, + { 'assertNotIsPlusZero' , 'assert_not_plus_zero' }, + { 'assertNotIsMinusZero' , 'assert_not_minus_zero' }, + { 'assertNotIsFunction' , 'assert_not_is_function' }, + { 'assertNotIsThread' , 'assert_not_is_thread' }, + { 'assertNotIsUserdata' , 'assert_not_is_userdata' }, + + -- type assertions: assertNotIsXXX -> assertNotXxx (luaunit v2 compat) + { 'assertNotIsNumber' , 'assertNotNumber' }, + { 'assertNotIsString' , 'assertNotString' }, + { 'assertNotIsTable' , 'assertNotTable' }, + { 'assertNotIsBoolean' , 'assertNotBoolean' }, + { 'assertNotIsNil' , 'assertNotNil' }, + { 'assertNotIsTrue' , 'assertNotTrue' }, + { 'assertNotIsFalse' , 'assertNotFalse' }, + { 'assertNotIsNaN' , 'assertNotNaN' }, + { 'assertNotIsInf' , 'assertNotInf' }, + { 'assertNotIsPlusInf' , 'assertNotPlusInf' }, + { 'assertNotIsMinusInf' , 'assertNotMinusInf' }, + { 'assertNotIsPlusZero' , 'assertNotPlusZero' }, + { 'assertNotIsMinusZero' , 'assertNotMinusZero' }, + { 'assertNotIsFunction' , 'assertNotFunction' }, + { 'assertNotIsThread' , 'assertNotThread' }, + { 'assertNotIsUserdata' , 'assertNotUserdata' }, + + -- type assertions: assertNotIsXXX -> assert_not_xxx + { 'assertNotIsNumber' , 'assert_not_number' }, + { 'assertNotIsString' , 'assert_not_string' }, + { 'assertNotIsTable' , 'assert_not_table' }, + { 'assertNotIsBoolean' , 'assert_not_boolean' }, + { 'assertNotIsNil' , 'assert_not_nil' }, + { 'assertNotIsTrue' , 'assert_not_true' }, + { 'assertNotIsFalse' , 'assert_not_false' }, + { 'assertNotIsNaN' , 'assert_not_nan' }, + { 'assertNotIsInf' , 'assert_not_inf' }, + { 'assertNotIsPlusInf' , 'assert_not_plus_inf' }, + { 'assertNotIsMinusInf' , 'assert_not_minus_inf' }, + { 'assertNotIsPlusZero' , 'assert_not_plus_zero' }, + { 'assertNotIsMinusZero' , 'assert_not_minus_zero' }, + { 'assertNotIsFunction' , 'assert_not_function' }, + { 'assertNotIsThread' , 'assert_not_thread' }, + { 'assertNotIsUserdata' , 'assert_not_userdata' }, + + -- all assertions with Coroutine duplicate Thread assertions + { 'assertIsThread' , 'assertIsCoroutine' }, + { 'assertIsThread' , 'assertCoroutine' }, + { 'assertIsThread' , 'assert_is_coroutine' }, + { 'assertIsThread' , 'assert_coroutine' }, + { 'assertNotIsThread' , 'assertNotIsCoroutine' }, + { 'assertNotIsThread' , 'assertNotCoroutine' }, + { 'assertNotIsThread' , 'assert_not_is_coroutine' }, + { 'assertNotIsThread' , 'assert_not_coroutine' }, +} + +-- Create all aliases in M +for _,v in ipairs( list_of_funcs ) do + local funcname, alias = v[1], v[2] + M[alias] = M[funcname] + + if EXPORT_ASSERT_TO_GLOBALS then + _G[funcname] = M[funcname] + _G[alias] = M[funcname] + end +end + +---------------------------------------------------------------- +-- +-- Outputters +-- +---------------------------------------------------------------- + +-- A common "base" class for outputters +-- For concepts involved (class inheritance) see http://www.lua.org/pil/16.2.html + +local genericOutput = { __class__ = 'genericOutput' } -- class +local genericOutput_MT = { __index = genericOutput } -- metatable +M.genericOutput = genericOutput -- publish, so that custom classes may derive from it + +function genericOutput.new(runner, default_verbosity) + -- runner is the "parent" object controlling the output, usually a LuaUnit instance + local t = { runner = runner } + if runner then + t.result = runner.result + t.verbosity = runner.verbosity or default_verbosity + t.fname = runner.fname + else + t.verbosity = default_verbosity + end + return setmetatable( t, genericOutput_MT) +end + +-- abstract ("empty") methods +function genericOutput:startSuite() + -- Called once, when the suite is started +end + +function genericOutput:startClass(className) + -- Called each time a new test class is started +end + +function genericOutput:startTest(testName) + -- called each time a new test is started, right before the setUp() + -- the current test status node is already created and available in: self.result.currentNode +end + +function genericOutput:updateStatus(node) + -- called with status failed or error as soon as the error/failure is encountered + -- this method is NOT called for a successful test because a test is marked as successful by default + -- and does not need to be updated +end + +function genericOutput:endTest(node) + -- called when the test is finished, after the tearDown() method +end + +function genericOutput:endClass() + -- called when executing the class is finished, before moving on to the next class of at the end of the test execution +end + +function genericOutput:endSuite() + -- called at the end of the test suite execution +end + + +---------------------------------------------------------------- +-- class TapOutput +---------------------------------------------------------------- + +local TapOutput = genericOutput.new() -- derived class +local TapOutput_MT = { __index = TapOutput } -- metatable +TapOutput.__class__ = 'TapOutput' + + -- For a good reference for TAP format, check: http://testanything.org/tap-specification.html + + function TapOutput.new(runner) + local t = genericOutput.new(runner, M.VERBOSITY_LOW) + return setmetatable( t, TapOutput_MT) + end + function TapOutput:startSuite() + print("1.."..self.result.selectedCount) + print('# Started on '..self.result.startDate) + end + function TapOutput:startClass(className) + if className ~= '[TestFunctions]' then + print('# Starting class: '..className) + end + end + + function TapOutput:updateStatus( node ) + if node:isSkipped() then + io.stdout:write("ok ", self.result.currentTestNumber, "\t# SKIP ", node.msg, "\n" ) + return + end + + io.stdout:write("not ok ", self.result.currentTestNumber, "\t", node.testName, "\n") + if self.verbosity > M.VERBOSITY_LOW then + print( prefixString( '# ', node.msg ) ) + end + if (node:isFailure() or node:isError()) and self.verbosity > M.VERBOSITY_DEFAULT then + print( prefixString( '# ', node.stackTrace ) ) + end + end + + function TapOutput:endTest( node ) + if node:isSuccess() then + io.stdout:write("ok ", self.result.currentTestNumber, "\t", node.testName, "\n") + end + end + + function TapOutput:endSuite() + print( '# '..M.LuaUnit.statusLine( self.result ) ) + return self.result.notSuccessCount + end + + +-- class TapOutput end + +---------------------------------------------------------------- +-- class JUnitOutput +---------------------------------------------------------------- + +-- See directory junitxml for more information about the junit format +local JUnitOutput = genericOutput.new() -- derived class +local JUnitOutput_MT = { __index = JUnitOutput } -- metatable +JUnitOutput.__class__ = 'JUnitOutput' + + function JUnitOutput.new(runner) + local t = genericOutput.new(runner, M.VERBOSITY_LOW) + t.testList = {} + return setmetatable( t, JUnitOutput_MT ) + end + + function JUnitOutput:startSuite() + -- open xml file early to deal with errors + if self.fname == nil then + error('With Junit, an output filename must be supplied with --name!') + end + if string.sub(self.fname,-4) ~= '.xml' then + self.fname = self.fname..'.xml' + end + self.fd = io.open(self.fname, "w") + if self.fd == nil then + error("Could not open file for writing: "..self.fname) + end + + print('# XML output to '..self.fname) + print('# Started on '..self.result.startDate) + end + function JUnitOutput:startClass(className) + if className ~= '[TestFunctions]' then + print('# Starting class: '..className) + end + end + function JUnitOutput:startTest(testName) + print('# Starting test: '..testName) + end + + function JUnitOutput:updateStatus( node ) + if node:isFailure() then + print( '# Failure: ' .. prefixString( '# ', node.msg ):sub(4, nil) ) + -- print('# ' .. node.stackTrace) + elseif node:isError() then + print( '# Error: ' .. prefixString( '# ' , node.msg ):sub(4, nil) ) + -- print('# ' .. node.stackTrace) + end + end + + function JUnitOutput:endSuite() + print( '# '..M.LuaUnit.statusLine(self.result)) + + -- XML file writing + self.fd:write('\n') + self.fd:write('\n') + self.fd:write(string.format( + ' \n', + self.result.runCount, self.result.startIsodate, self.result.duration, self.result.errorCount, self.result.failureCount, self.result.skippedCount )) + self.fd:write(" \n") + self.fd:write(string.format(' \n', _VERSION ) ) + self.fd:write(string.format(' \n', M.VERSION) ) + -- XXX please include system name and version if possible + self.fd:write(" \n") + + for i,node in ipairs(self.result.allTests) do + self.fd:write(string.format(' \n', + node.className, node.testName, node.duration ) ) + if node:isNotSuccess() then + self.fd:write(node:statusXML()) + end + self.fd:write(' \n') + end + + -- Next two lines are needed to validate junit ANT xsd, but really not useful in general: + self.fd:write(' \n') + self.fd:write(' \n') + + self.fd:write(' \n') + self.fd:write('\n') + self.fd:close() + return self.result.notSuccessCount + end + + +-- class TapOutput end + +---------------------------------------------------------------- +-- class TextOutput +---------------------------------------------------------------- + +--[[ Example of other unit-tests suite text output +-- Python Non verbose: +For each test: . or F or E +If some failed tests: + ============== + ERROR / FAILURE: TestName (testfile.testclass) + --------- + Stack trace +then -------------- +then "Ran x tests in 0.000s" +then OK or FAILED (failures=1, error=1) +-- Python Verbose: +testname (filename.classname) ... ok +testname (filename.classname) ... FAIL +testname (filename.classname) ... ERROR +then -------------- +then "Ran x tests in 0.000s" +then OK or FAILED (failures=1, error=1) +-- Ruby: +Started + . + Finished in 0.002695 seconds. + 1 tests, 2 assertions, 0 failures, 0 errors +-- Ruby: +>> ruby tc_simple_number2.rb +Loaded suite tc_simple_number2 +Started +F.. +Finished in 0.038617 seconds. + 1) Failure: +test_failure(TestSimpleNumber) [tc_simple_number2.rb:16]: +Adding doesn't work. +<3> expected but was +<4>. +3 tests, 4 assertions, 1 failures, 0 errors +-- Java Junit +.......F. +Time: 0,003 +There was 1 failure: +1) testCapacity(junit.samples.VectorTest)junit.framework.AssertionFailedError + at junit.samples.VectorTest.testCapacity(VectorTest.java:87) + at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) + at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) + at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) +FAILURES!!! +Tests run: 8, Failures: 1, Errors: 0 +-- Maven +# mvn test +------------------------------------------------------- + T E S T S +------------------------------------------------------- +Running math.AdditionTest +Tests run: 2, Failures: 1, Errors: 0, Skipped: 0, Time elapsed: +0.03 sec <<< FAILURE! +Results : +Failed tests: + testLireSymbole(math.AdditionTest) +Tests run: 2, Failures: 1, Errors: 0, Skipped: 0 +-- LuaUnit +---- non verbose +* display . or F or E when running tests +---- verbose +* display test name + ok/fail +---- +* blank line +* number) ERROR or FAILURE: TestName + Stack trace +* blank line +* number) ERROR or FAILURE: TestName + Stack trace +then -------------- +then "Ran x tests in 0.000s (%d not selected, %d skipped)" +then OK or FAILED (failures=1, error=1) +]] + +local TextOutput = genericOutput.new() -- derived class +local TextOutput_MT = { __index = TextOutput } -- metatable +TextOutput.__class__ = 'TextOutput' + + function TextOutput.new(runner) + local t = genericOutput.new(runner, M.VERBOSITY_DEFAULT) + t.errorList = {} + return setmetatable( t, TextOutput_MT ) + end + + function TextOutput:startSuite() + if self.verbosity > M.VERBOSITY_DEFAULT then + print( 'Started on '.. self.result.startDate ) + end + end + + function TextOutput:startTest(testName) + if self.verbosity > M.VERBOSITY_DEFAULT then + io.stdout:write( " ", self.result.currentNode.testName, " ... " ) + end + end + + function TextOutput:endTest( node ) + if node:isSuccess() then + if self.verbosity > M.VERBOSITY_DEFAULT then + io.stdout:write("Ok\n") + else + io.stdout:write(".") + io.stdout:flush() + end + else + if self.verbosity > M.VERBOSITY_DEFAULT then + print( node.status ) + print( node.msg ) + --[[ + -- find out when to do this: + if self.verbosity > M.VERBOSITY_DEFAULT then + print( node.stackTrace ) + end + ]] + else + -- write only the first character of status E, F or S + io.stdout:write(string.sub(node.status, 1, 1)) + io.stdout:flush() + end + end + end + + function TextOutput:displayOneFailedTest( index, fail ) + print(index..") "..fail.testName ) + print( fail.msg ) + print( fail.stackTrace ) + print() + end + + function TextOutput:displayErroredTests() + if #self.result.errorTests ~= 0 then + print("Tests with errors:") + print("------------------") + for i, v in ipairs(self.result.errorTests) do + self:displayOneFailedTest(i, v) + end + end + end + + function TextOutput:displayFailedTests() + if #self.result.failedTests ~= 0 then + print("Failed tests:") + print("-------------") + for i, v in ipairs(self.result.failedTests) do + self:displayOneFailedTest(i, v) + end + end + end + + function TextOutput:endSuite() + if self.verbosity > M.VERBOSITY_DEFAULT then + print("=========================================================") + else + print() + end + self:displayErroredTests() + self:displayFailedTests() + print( M.LuaUnit.statusLine( self.result ) ) + if self.result.notSuccessCount == 0 then + print('OK') + end + end + +-- class TextOutput end + + +---------------------------------------------------------------- +-- class NilOutput +---------------------------------------------------------------- + +local function nopCallable() + --print(42) + return nopCallable +end + +local NilOutput = { __class__ = 'NilOuptut' } -- class +local NilOutput_MT = { __index = nopCallable } -- metatable + +function NilOutput.new(runner) + return setmetatable( { __class__ = 'NilOutput' }, NilOutput_MT ) +end + +---------------------------------------------------------------- +-- +-- class LuaUnit +-- +---------------------------------------------------------------- + +M.LuaUnit = { + outputType = TextOutput, + verbosity = M.VERBOSITY_DEFAULT, + __class__ = 'LuaUnit', + instances = {} +} +local LuaUnit_MT = { __index = M.LuaUnit } + +if EXPORT_ASSERT_TO_GLOBALS then + LuaUnit = M.LuaUnit +end + + function M.LuaUnit.new() + local newInstance = setmetatable( {}, LuaUnit_MT ) + return newInstance + end + + -----------------[[ Utility methods ]]--------------------- + + function M.LuaUnit.asFunction(aObject) + -- return "aObject" if it is a function, and nil otherwise + if 'function' == type(aObject) then + return aObject + end + end + + function M.LuaUnit.splitClassMethod(someName) + --[[ + Return a pair of className, methodName strings for a name in the form + "class.method". If no class part (or separator) is found, will return + nil, someName instead (the latter being unchanged). + This convention thus also replaces the older isClassMethod() test: + You just have to check for a non-nil className (return) value. + ]] + local separator = string.find(someName, '.', 1, true) + if separator then + return someName:sub(1, separator - 1), someName:sub(separator + 1) + end + return nil, someName + end + + function M.LuaUnit.isMethodTestName( s ) + -- return true is the name matches the name of a test method + -- default rule is that is starts with 'Test' or with 'test' + return string.sub(s, 1, 4):lower() == 'test' + end + + function M.LuaUnit.isTestName( s ) + -- return true is the name matches the name of a test + -- default rule is that is starts with 'Test' or with 'test' + return string.sub(s, 1, 4):lower() == 'test' + end + + function M.LuaUnit.collectTests() + -- return a list of all test names in the global namespace + -- that match LuaUnit.isTestName + + local testNames = {} + for k, _ in pairs(_G) do + if type(k) == "string" and M.LuaUnit.isTestName( k ) then + table.insert( testNames , k ) + end + end + table.sort( testNames ) + return testNames + end + + function M.LuaUnit.parseCmdLine( cmdLine ) + -- parse the command line + -- Supported command line parameters: + -- --verbose, -v: increase verbosity + -- --quiet, -q: silence output + -- --error, -e: treat errors as fatal (quit program) + -- --output, -o, + name: select output type + -- --pattern, -p, + pattern: run test matching pattern, may be repeated + -- --exclude, -x, + pattern: run test not matching pattern, may be repeated + -- --shuffle, -s, : shuffle tests before reunning them + -- --name, -n, + fname: name of output file for junit, default to stdout + -- --repeat, -r, + num: number of times to execute each test + -- [testnames, ...]: run selected test names + -- + -- Returns a table with the following fields: + -- verbosity: nil, M.VERBOSITY_DEFAULT, M.VERBOSITY_QUIET, M.VERBOSITY_VERBOSE + -- output: nil, 'tap', 'junit', 'text', 'nil' + -- testNames: nil or a list of test names to run + -- exeRepeat: num or 1 + -- pattern: nil or a list of patterns + -- exclude: nil or a list of patterns + + local result, state = {}, nil + local SET_OUTPUT = 1 + local SET_PATTERN = 2 + local SET_EXCLUDE = 3 + local SET_FNAME = 4 + local SET_REPEAT = 5 + + if cmdLine == nil then + return result + end + + local function parseOption( option ) + if option == '--help' or option == '-h' then + result['help'] = true + return + elseif option == '--version' then + result['version'] = true + return + elseif option == '--verbose' or option == '-v' then + result['verbosity'] = M.VERBOSITY_VERBOSE + return + elseif option == '--quiet' or option == '-q' then + result['verbosity'] = M.VERBOSITY_QUIET + return + elseif option == '--error' or option == '-e' then + result['quitOnError'] = true + return + elseif option == '--failure' or option == '-f' then + result['quitOnFailure'] = true + return + elseif option == '--shuffle' or option == '-s' then + result['shuffle'] = true + return + elseif option == '--output' or option == '-o' then + state = SET_OUTPUT + return state + elseif option == '--name' or option == '-n' then + state = SET_FNAME + return state + elseif option == '--repeat' or option == '-r' then + state = SET_REPEAT + return state + elseif option == '--pattern' or option == '-p' then + state = SET_PATTERN + return state + elseif option == '--exclude' or option == '-x' then + state = SET_EXCLUDE + return state + end + error('Unknown option: '..option,3) + end + + local function setArg( cmdArg, state ) + if state == SET_OUTPUT then + result['output'] = cmdArg + return + elseif state == SET_FNAME then + result['fname'] = cmdArg + return + elseif state == SET_REPEAT then + result['exeRepeat'] = tonumber(cmdArg) + or error('Malformed -r argument: '..cmdArg) + return + elseif state == SET_PATTERN then + if result['pattern'] then + table.insert( result['pattern'], cmdArg ) + else + result['pattern'] = { cmdArg } + end + return + elseif state == SET_EXCLUDE then + local notArg = '!'..cmdArg + if result['pattern'] then + table.insert( result['pattern'], notArg ) + else + result['pattern'] = { notArg } + end + return + end + error('Unknown parse state: '.. state) + end + + + for i, cmdArg in ipairs(cmdLine) do + if state ~= nil then + setArg( cmdArg, state, result ) + state = nil + else + if cmdArg:sub(1,1) == '-' then + state = parseOption( cmdArg ) + else + if result['testNames'] then + table.insert( result['testNames'], cmdArg ) + else + result['testNames'] = { cmdArg } + end + end + end + end + + if result['help'] then + M.LuaUnit.help() + end + + if result['version'] then + M.LuaUnit.version() + end + + if state ~= nil then + error('Missing argument after '..cmdLine[ #cmdLine ],2 ) + end + + return result + end + + function M.LuaUnit.help() + print(M.USAGE) + os.exit(0) + end + + function M.LuaUnit.version() + print('LuaUnit v'..M.VERSION..' by Philippe Fremy ') + os.exit(0) + end + +---------------------------------------------------------------- +-- class NodeStatus +---------------------------------------------------------------- + + local NodeStatus = { __class__ = 'NodeStatus' } -- class + local NodeStatus_MT = { __index = NodeStatus } -- metatable + M.NodeStatus = NodeStatus + + -- values of status + NodeStatus.SUCCESS = 'SUCCESS' + NodeStatus.SKIP = 'SKIP' + NodeStatus.FAIL = 'FAIL' + NodeStatus.ERROR = 'ERROR' + + function NodeStatus.new( number, testName, className ) + -- default constructor, test are PASS by default + local t = { number = number, testName = testName, className = className } + setmetatable( t, NodeStatus_MT ) + t:success() + return t + end + + function NodeStatus:success() + self.status = self.SUCCESS + -- useless because lua does this for us, but it helps me remembering the relevant field names + self.msg = nil + self.stackTrace = nil + end + + function NodeStatus:skip(msg) + self.status = self.SKIP + self.msg = msg + self.stackTrace = nil + end + + function NodeStatus:fail(msg, stackTrace) + self.status = self.FAIL + self.msg = msg + self.stackTrace = stackTrace + end + + function NodeStatus:error(msg, stackTrace) + self.status = self.ERROR + self.msg = msg + self.stackTrace = stackTrace + end + + function NodeStatus:isSuccess() + return self.status == NodeStatus.SUCCESS + end + + function NodeStatus:isNotSuccess() + -- Return true if node is either failure or error or skip + return (self.status == NodeStatus.FAIL or self.status == NodeStatus.ERROR or self.status == NodeStatus.SKIP) + end + + function NodeStatus:isSkipped() + return self.status == NodeStatus.SKIP + end + + function NodeStatus:isFailure() + return self.status == NodeStatus.FAIL + end + + function NodeStatus:isError() + return self.status == NodeStatus.ERROR + end + + function NodeStatus:statusXML() + if self:isError() then + return table.concat( + {' \n', + ' \n'}) + elseif self:isFailure() then + return table.concat( + {' \n', + ' \n'}) + elseif self:isSkipped() then + return table.concat({' ', xmlEscape(self.msg),'\n' } ) + end + return ' \n' -- (not XSD-compliant! normally shouldn't get here) + end + + --------------[[ Output methods ]]------------------------- + + local function conditional_plural(number, singular) + -- returns a grammatically well-formed string "%d " + local suffix = '' + if number ~= 1 then -- use plural + suffix = (singular:sub(-2) == 'ss') and 'es' or 's' + end + return string.format('%d %s%s', number, singular, suffix) + end + + function M.LuaUnit.statusLine(result) + -- return status line string according to results + local s = { + string.format('Ran %d tests in %0.3f seconds', + result.runCount, result.duration), + conditional_plural(result.successCount, 'success'), + } + if result.notSuccessCount > 0 then + if result.failureCount > 0 then + table.insert(s, conditional_plural(result.failureCount, 'failure')) + end + if result.errorCount > 0 then + table.insert(s, conditional_plural(result.errorCount, 'error')) + end + else + table.insert(s, '0 failures') + end + if result.skippedCount > 0 then + table.insert(s, string.format("%d skipped", result.skippedCount)) + end + if result.nonSelectedCount > 0 then + table.insert(s, string.format("%d non-selected", result.nonSelectedCount)) + end + return table.concat(s, ', ') + end + + function M.LuaUnit:startSuite(selectedCount, nonSelectedCount) + self.result = { + selectedCount = selectedCount, + nonSelectedCount = nonSelectedCount, + successCount = 0, + runCount = 0, + currentTestNumber = 0, + currentClassName = "", + currentNode = nil, + suiteStarted = true, + startTime = os.clock(), + startDate = os.date(os.getenv('LUAUNIT_DATEFMT')), + startIsodate = os.date('%Y-%m-%dT%H:%M:%S'), + patternIncludeFilter = self.patternIncludeFilter, + + -- list of test node status + allTests = {}, + failedTests = {}, + errorTests = {}, + skippedTests = {}, + + failureCount = 0, + errorCount = 0, + notSuccessCount = 0, + skippedCount = 0, + } + + self.outputType = self.outputType or TextOutput + self.output = self.outputType.new(self) + self.output:startSuite() + end + + function M.LuaUnit:startClass( className, classInstance ) + self.result.currentClassName = className + self.output:startClass( className ) + self:setupClass( className, classInstance ) + end + + function M.LuaUnit:startTest( testName ) + self.result.currentTestNumber = self.result.currentTestNumber + 1 + self.result.runCount = self.result.runCount + 1 + self.result.currentNode = NodeStatus.new( + self.result.currentTestNumber, + testName, + self.result.currentClassName + ) + self.result.currentNode.startTime = os.clock() + table.insert( self.result.allTests, self.result.currentNode ) + self.output:startTest( testName ) + end + + function M.LuaUnit:updateStatus( err ) + -- "err" is expected to be a table / result from protectedCall() + if err.status == NodeStatus.SUCCESS then + return + end + + local node = self.result.currentNode + + --[[ As a first approach, we will report only one error or one failure for one test. + However, we can have the case where the test is in failure, and the teardown is in error. + In such case, it's a good idea to report both a failure and an error in the test suite. This is + what Python unittest does for example. However, it mixes up counts so need to be handled carefully: for + example, there could be more (failures + errors) count that tests. What happens to the current node ? + We will do this more intelligent version later. + ]] + + -- if the node is already in failure/error, just don't report the new error (see above) + if node.status ~= NodeStatus.SUCCESS then + return + end + + if err.status == NodeStatus.FAIL then + node:fail( err.msg, err.trace ) + table.insert( self.result.failedTests, node ) + elseif err.status == NodeStatus.ERROR then + node:error( err.msg, err.trace ) + table.insert( self.result.errorTests, node ) + elseif err.status == NodeStatus.SKIP then + node:skip( err.msg ) + table.insert( self.result.skippedTests, node ) + else + error('No such status: ' .. prettystr(err.status)) + end + + self.output:updateStatus( node ) + end + + function M.LuaUnit:endTest() + local node = self.result.currentNode + -- print( 'endTest() '..prettystr(node)) + -- print( 'endTest() '..prettystr(node:isNotSuccess())) + node.duration = os.clock() - node.startTime + node.startTime = nil + self.output:endTest( node ) + + if node:isSuccess() then + self.result.successCount = self.result.successCount + 1 + elseif node:isError() then + if self.quitOnError or self.quitOnFailure then + -- Runtime error - abort test execution as requested by + -- "--error" option. This is done by setting a special + -- flag that gets handled in internalRunSuiteByInstances(). + print("\nERROR during LuaUnit test execution:\n" .. node.msg) + self.result.aborted = true + end + elseif node:isFailure() then + if self.quitOnFailure then + -- Failure - abort test execution as requested by + -- "--failure" option. This is done by setting a special + -- flag that gets handled in internalRunSuiteByInstances(). + print("\nFailure during LuaUnit test execution:\n" .. node.msg) + self.result.aborted = true + end + elseif node:isSkipped() then + self.result.runCount = self.result.runCount - 1 + else + error('No such node status: ' .. prettystr(node.status)) + end + self.result.currentNode = nil + end + + function M.LuaUnit:endClass() + self:teardownClass( self.lastClassName, self.lastClassInstance ) + self.output:endClass() + end + + function M.LuaUnit:endSuite() + if self.result.suiteStarted == false then + error('LuaUnit:endSuite() -- suite was already ended' ) + end + self.result.duration = os.clock()-self.result.startTime + self.result.suiteStarted = false + + -- Expose test counts for outputter's endSuite(). This could be managed + -- internally instead by using the length of the lists of failed tests + -- but unit tests rely on these fields being present. + self.result.failureCount = #self.result.failedTests + self.result.errorCount = #self.result.errorTests + self.result.notSuccessCount = self.result.failureCount + self.result.errorCount + self.result.skippedCount = #self.result.skippedTests + + self.output:endSuite() + end + + function M.LuaUnit:setOutputType(outputType, fname) + -- Configures LuaUnit runner output + -- outputType is one of: NIL, TAP, JUNIT, TEXT + -- when outputType is junit, the additional argument fname is used to set the name of junit output file + -- for other formats, fname is ignored + if outputType:upper() == "NIL" then + self.outputType = NilOutput + return + end + if outputType:upper() == "TAP" then + self.outputType = TapOutput + return + end + if outputType:upper() == "JUNIT" then + self.outputType = JUnitOutput + if fname then + self.fname = fname + end + return + end + if outputType:upper() == "TEXT" then + self.outputType = TextOutput + return + end + error( 'No such format: '..outputType,2) + end + + --------------[[ Runner ]]----------------- + + function M.LuaUnit:protectedCall(classInstance, methodInstance, prettyFuncName) + -- if classInstance is nil, this is just a function call + -- else, it's method of a class being called. + + local function err_handler(e) + -- transform error into a table, adding the traceback information + return { + status = NodeStatus.ERROR, + msg = e, + trace = string.sub(debug.traceback("", 1), 2) + } + end + + local ok, err + if classInstance then + -- stupid Lua < 5.2 does not allow xpcall with arguments so let's use a workaround + ok, err = xpcall( function () methodInstance(classInstance) end, err_handler ) + else + ok, err = xpcall( function () methodInstance() end, err_handler ) + end + if ok then + return {status = NodeStatus.SUCCESS} + end + -- print('ok="'..prettystr(ok)..'" err="'..prettystr(err)..'"') + + local iter_msg + iter_msg = self.exeRepeat and 'iteration '..self.currentCount + + err.msg, err.status = M.adjust_err_msg_with_iter( err.msg, iter_msg ) + + if err.status == NodeStatus.SUCCESS or err.status == NodeStatus.SKIP then + err.trace = nil + return err + end + + -- reformat / improve the stack trace + if prettyFuncName then -- we do have the real method name + err.trace = err.trace:gsub("in (%a+) 'methodInstance'", "in %1 '"..prettyFuncName.."'") + end + if STRIP_LUAUNIT_FROM_STACKTRACE then + err.trace = stripLuaunitTrace2(err.trace, err.msg) + end + + return err -- return the error "object" (table) + end + + + function M.LuaUnit:execOneFunction(className, methodName, classInstance, methodInstance) + -- When executing a test function, className and classInstance must be nil + -- When executing a class method, all parameters must be set + + if type(methodInstance) ~= 'function' then + self:unregisterSuite() + error( tostring(methodName)..' must be a function, not '..type(methodInstance)) + end + + local prettyFuncName + if className == nil then + className = '[TestFunctions]' + prettyFuncName = methodName + else + prettyFuncName = className..'.'..methodName + end + + if self.lastClassName ~= className then + if self.lastClassName ~= nil then + self:endClass() + end + self:startClass( className, classInstance ) + self.lastClassName = className + self.lastClassInstance = classInstance + end + + self:startTest(prettyFuncName) + + local node = self.result.currentNode + for iter_n = 1, self.exeRepeat or 1 do + if node:isNotSuccess() then + break + end + self.currentCount = iter_n + + -- run setUp first (if any) + if classInstance then + local func = self.asFunction( classInstance.setUp ) or + self.asFunction( classInstance.Setup ) or + self.asFunction( classInstance.setup ) or + self.asFunction( classInstance.SetUp ) + if func then + self:updateStatus(self:protectedCall(classInstance, func, className..'.setUp')) + end + end + + -- run testMethod() + if node:isSuccess() then + self:updateStatus(self:protectedCall(classInstance, methodInstance, prettyFuncName)) + end + + -- lastly, run tearDown (if any) + if classInstance then + local func = self.asFunction( classInstance.tearDown ) or + self.asFunction( classInstance.TearDown ) or + self.asFunction( classInstance.teardown ) or + self.asFunction( classInstance.Teardown ) + if func then + self:updateStatus(self:protectedCall(classInstance, func, className..'.tearDown')) + end + end + end + + self:endTest() + end + + function M.LuaUnit.expandOneClass( result, className, classInstance ) + --[[ + Input: a list of { name, instance }, a class name, a class instance + Ouptut: modify result to add all test method instance in the form: + { className.methodName, classInstance } + ]] + for methodName, methodInstance in sortedPairs(classInstance) do + if M.LuaUnit.asFunction(methodInstance) and M.LuaUnit.isMethodTestName( methodName ) then + table.insert( result, { className..'.'..methodName, classInstance } ) + end + end + end + + function M.LuaUnit.expandClasses( listOfNameAndInst ) + --[[ + -- expand all classes (provided as {className, classInstance}) to a list of {className.methodName, classInstance} + -- functions and methods remain untouched + Input: a list of { name, instance } + Output: + * { function name, function instance } : do nothing + * { class.method name, class instance }: do nothing + * { class name, class instance } : add all method names in the form of (className.methodName, classInstance) + ]] + local result = {} + + for i,v in ipairs( listOfNameAndInst ) do + local name, instance = v[1], v[2] + if M.LuaUnit.asFunction(instance) then + table.insert( result, { name, instance } ) + else + if type(instance) ~= 'table' then + error( 'Instance must be a table or a function, not a '..type(instance)..' with value '..prettystr(instance)) + end + local className, methodName = M.LuaUnit.splitClassMethod( name ) + if className then + local methodInstance = instance[methodName] + if methodInstance == nil then + error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) ) + end + table.insert( result, { name, instance } ) + else + M.LuaUnit.expandOneClass( result, name, instance ) + end + end + end + + return result + end + + function M.LuaUnit.applyPatternFilter( patternIncFilter, listOfNameAndInst ) + local included, excluded = {}, {} + for i, v in ipairs( listOfNameAndInst ) do + -- local name, instance = v[1], v[2] + if patternFilter( patternIncFilter, v[1] ) then + table.insert( included, v ) + else + table.insert( excluded, v ) + end + end + return included, excluded + end + + local function getKeyInListWithGlobalFallback( key, listOfNameAndInst ) + local result = nil + for i,v in ipairs( listOfNameAndInst ) do + if(listOfNameAndInst[i][1] == key) then + result = listOfNameAndInst[i][2] + break + end + end + if(not M.LuaUnit.asFunction( result ) ) then + result = _G[key] + end + return result + end + + function M.LuaUnit:setupSuite( listOfNameAndInst ) + local setupSuite = getKeyInListWithGlobalFallback("setupSuite", listOfNameAndInst) + if self.asFunction( setupSuite ) then + self:updateStatus( self:protectedCall( nil, setupSuite, 'setupSuite' ) ) + end + end + + function M.LuaUnit:teardownSuite(listOfNameAndInst) + local teardownSuite = getKeyInListWithGlobalFallback("teardownSuite", listOfNameAndInst) + if self.asFunction( teardownSuite ) then + self:updateStatus( self:protectedCall( nil, teardownSuite, 'teardownSuite') ) + end + end + + function M.LuaUnit:setupClass( className, instance ) + if type( instance ) == 'table' and self.asFunction( instance.setupClass ) then + self:updateStatus( self:protectedCall( instance, instance.setupClass, className..'.setupClass' ) ) + end + end + + function M.LuaUnit:teardownClass( className, instance ) + if type( instance ) == 'table' and self.asFunction( instance.teardownClass ) then + self:updateStatus( self:protectedCall( instance, instance.teardownClass, className..'.teardownClass' ) ) + end + end + + function M.LuaUnit:internalRunSuiteByInstances( listOfNameAndInst ) + --[[ Run an explicit list of tests. Each item of the list must be one of: + * { function name, function instance } + * { class name, class instance } + * { class.method name, class instance } + This function is internal to LuaUnit. The official API to perform this action is runSuiteByInstances() + ]] + + local expandedList = self.expandClasses( listOfNameAndInst ) + if self.shuffle then + randomizeTable( expandedList ) + end + local filteredList, filteredOutList = self.applyPatternFilter( + self.patternIncludeFilter, expandedList ) + + self:startSuite( #filteredList, #filteredOutList ) + self:setupSuite( listOfNameAndInst ) + + for i,v in ipairs( filteredList ) do + local name, instance = v[1], v[2] + if M.LuaUnit.asFunction(instance) then + self:execOneFunction( nil, name, nil, instance ) + else + -- expandClasses() should have already taken care of sanitizing the input + assert( type(instance) == 'table' ) + local className, methodName = M.LuaUnit.splitClassMethod( name ) + assert( className ~= nil ) + local methodInstance = instance[methodName] + assert(methodInstance ~= nil) + self:execOneFunction( className, methodName, instance, methodInstance ) + end + if self.result.aborted then + break -- "--error" or "--failure" option triggered + end + end + + if self.lastClassName ~= nil then + self:endClass() + end + + self:teardownSuite( listOfNameAndInst ) + self:endSuite() + + if self.result.aborted then + print("LuaUnit ABORTED (as requested by --error or --failure option)") + self:unregisterSuite() + os.exit(-2) + end + end + + function M.LuaUnit:internalRunSuiteByNames( listOfName ) + --[[ Run LuaUnit with a list of generic names, coming either from command-line or from global + namespace analysis. Convert the list into a list of (name, valid instances (table or function)) + and calls internalRunSuiteByInstances. + ]] + + local instanceName, instance + local listOfNameAndInst = {} + + for i,name in ipairs( listOfName ) do + local className, methodName = M.LuaUnit.splitClassMethod( name ) + if className then + instanceName = className + instance = _G[instanceName] + + if instance == nil then + self:unregisterSuite() + error( "No such name in global space: "..instanceName ) + end + + if type(instance) ~= 'table' then + self:unregisterSuite() + error( 'Instance of '..instanceName..' must be a table, not '..type(instance)) + end + + local methodInstance = instance[methodName] + if methodInstance == nil then + self:unregisterSuite() + error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) ) + end + + else + -- for functions and classes + instanceName = name + instance = _G[instanceName] + end + + if instance == nil then + self:unregisterSuite() + error( "No such name in global space: "..instanceName ) + end + + if (type(instance) ~= 'table' and type(instance) ~= 'function') then + self:unregisterSuite() + error( 'Name must match a function or a table: '..instanceName ) + end + + table.insert( listOfNameAndInst, { name, instance } ) + end + + self:internalRunSuiteByInstances( listOfNameAndInst ) + end + + function M.LuaUnit.run(...) + -- Run some specific test classes. + -- If no arguments are passed, run the class names specified on the + -- command line. If no class name is specified on the command line + -- run all classes whose name starts with 'Test' + -- + -- If arguments are passed, they must be strings of the class names + -- that you want to run or generic command line arguments (-o, -p, -v, ...) + local runner = M.LuaUnit.new() + return runner:runSuite(...) + end + + function M.LuaUnit:registerSuite() + -- register the current instance into our global array of instances + -- print('-> Register suite') + M.LuaUnit.instances[ #M.LuaUnit.instances+1 ] = self + end + + function M.unregisterCurrentSuite() + -- force unregister the last registered suite + table.remove(M.LuaUnit.instances, #M.LuaUnit.instances) + end + + function M.LuaUnit:unregisterSuite() + -- print('<- Unregister suite') + -- remove our current instqances from the global array of instances + local instanceIdx = nil + for i, instance in ipairs(M.LuaUnit.instances) do + if instance == self then + instanceIdx = i + break + end + end + + if instanceIdx ~= nil then + table.remove(M.LuaUnit.instances, instanceIdx) + -- print('Unregister done') + end + + end + + function M.LuaUnit:initFromArguments( ... ) + --[[Parses all arguments from either command-line or direct call and set internal + flags of LuaUnit runner according to it. + Return the list of names which were possibly passed on the command-line or as arguments + ]] + local args = {...} + if type(args[1]) == 'table' and args[1].__class__ == 'LuaUnit' then + -- run was called with the syntax M.LuaUnit:runSuite() + -- we support both M.LuaUnit.run() and M.LuaUnit:run() + -- strip out the first argument self to make it a command-line argument list + table.remove(args,1) + end + + if #args == 0 then + args = cmdline_argv + end + + local options = pcall_or_abort( M.LuaUnit.parseCmdLine, args ) + + -- We expect these option fields to be either `nil` or contain + -- valid values, so it's safe to always copy them directly. + self.verbosity = options.verbosity + self.quitOnError = options.quitOnError + self.quitOnFailure = options.quitOnFailure + + self.exeRepeat = options.exeRepeat + self.patternIncludeFilter = options.pattern + self.shuffle = options.shuffle + + options.output = options.output or os.getenv('LUAUNIT_OUTPUT') + options.fname = options.fname or os.getenv('LUAUNIT_JUNIT_FNAME') + + if options.output then + if options.output:lower() == 'junit' and options.fname == nil then + print('With junit output, a filename must be supplied with -n or --name') + os.exit(-1) + end + pcall_or_abort(self.setOutputType, self, options.output, options.fname) + end + + return options.testNames + end + + function M.LuaUnit:runSuite( ... ) + testNames = self:initFromArguments(...) + self:registerSuite() + self:internalRunSuiteByNames( testNames or M.LuaUnit.collectTests() ) + self:unregisterSuite() + return self.result.notSuccessCount + end + + function M.LuaUnit:runSuiteByInstances( listOfNameAndInst, commandLineArguments ) + --[[ + Run all test functions or tables provided as input. + Input: a list of { name, instance } + instance can either be a function or a table containing test functions starting with the prefix "test" + return the number of failures and errors, 0 meaning success + ]] + -- parse the command-line arguments + testNames = self:initFromArguments( commandLineArguments ) + self:registerSuite() + self:internalRunSuiteByInstances( listOfNameAndInst ) + self:unregisterSuite() + return self.result.notSuccessCount + end + + + +-- class LuaUnit + +-- For compatbility with LuaUnit v2 +M.run = M.LuaUnit.run +M.Run = M.LuaUnit.run + +function M:setVerbosity( verbosity ) + -- set the verbosity value (as integer) + M.LuaUnit.verbosity = verbosity +end +M.set_verbosity = M.setVerbosity +M.SetVerbosity = M.setVerbosity + + +return M \ No newline at end of file diff --git a/tests/Lua/runtests.lua b/tests/Lua/runtests.lua new file mode 100644 index 0000000000..d055ab15fc --- /dev/null +++ b/tests/Lua/runtests.lua @@ -0,0 +1,14 @@ +luaunit = require('luaunit') + +-- TestMod = {} +-- function TestMod.testHello() +-- assertEquals(1, 1) +-- end + +TestArithmetic = require('TestArithmetic') +TestArray = require('TestArray') +TestRecords = require('TestRecords') +TestControlFlow = require('TestControlFlow') +TestUnionType = require('TestUnionType') + +luaunit.run()