From 492c91a204b63403d38ce58364be04e962180dac Mon Sep 17 00:00:00 2001 From: xeals Date: Tue, 15 Jan 2019 11:20:10 +1100 Subject: [PATCH] chore: migrate to unique git repo --- autostart | 1 + cfg.fnl | 42 + conf/apps.fnl | 9 + conf/client.fnl | 1 + conf/client/buttons.fnl | 17 + conf/client/keys.fnl | 45 + conf/client/rules.fnl | 51 + conf/keys.fnl | 119 +++ conf/rules.fnl | 10 + conf/tags.fnl | 21 + conf/theme.fnl | 35 + lib/README.md | 8 + lib/fennel.lua | 1964 +++++++++++++++++++++++++++++++++++++++ lib/fun.lua | 1056 +++++++++++++++++++++ lib/keys.fnl | 35 + lib/std.fnl | 17 + module/decorate.fnl | 74 ++ module/sidebar.fnl | 74 ++ rc.lua | 4 + 19 files changed, 3583 insertions(+) create mode 120000 autostart create mode 100644 cfg.fnl create mode 100644 conf/apps.fnl create mode 100644 conf/client.fnl create mode 100644 conf/client/buttons.fnl create mode 100644 conf/client/keys.fnl create mode 100644 conf/client/rules.fnl create mode 100644 conf/keys.fnl create mode 100644 conf/rules.fnl create mode 100644 conf/tags.fnl create mode 100644 conf/theme.fnl create mode 100644 lib/README.md create mode 100644 lib/fennel.lua create mode 100644 lib/fun.lua create mode 100644 lib/keys.fnl create mode 100644 lib/std.fnl create mode 100644 module/decorate.fnl create mode 100644 module/sidebar.fnl create mode 100644 rc.lua diff --git a/autostart b/autostart new file mode 120000 index 0000000..8f20b3b --- /dev/null +++ b/autostart @@ -0,0 +1 @@ +../autostart.sh \ No newline at end of file diff --git a/cfg.fnl b/cfg.fnl new file mode 100644 index 0000000..8e92d3d --- /dev/null +++ b/cfg.fnl @@ -0,0 +1,42 @@ +(local fun (require :lib.fun)) + +(local awful (require :awful)) +(require :awful.autofocus) +(local naughty (require :naughty)) +(local beautiful (require :beautiful)) + +(beautiful.init (require :conf.theme)) + +(require :module.decorate) +(require :module.sidebar) + +(require :conf.client) +(require :conf.tags) +(_G.root.keys (require :conf.keys)) + +;; TEMP Next release should support this through `beautiful'. +(set naughty.config.defaults.border_width beautiful.notification_border_width) +(set naughty.config.defaults.icon_size beautiful.notification_icon_size) + +;;; +;; Post-config -- wiring it all up + +(->> {:manage (lambda [c] nil + (when (and awesome.startup + (not c.size_hints.user_position) + (not c.size_hints.program_position)) + ;; Prevent clients from being unreachable after screen count changes. + (awful.placement.no_offscreen c))) + :mouse::enter (lambda [c] + (when (and (~= (awful.layout.get c.screen) + awful.layout.suit.magnifier) + (awful.client.focus.filter c)) + (set client.focus c))) + :focus (lambda [c] (set c.border_color beautiful.border_focus)) + :unfocus (lambda [c] (set c.border_color beautiful.border_normal))} + (fun.each (fn [event callback] (client.connect_signal event callback)))) + +(awful.spawn.with_shell "~/.config/awesome/autostart") + +;; Empty return +{} diff --git a/conf/apps.fnl b/conf/apps.fnl new file mode 100644 index 0000000..d48d170 --- /dev/null +++ b/conf/apps.fnl @@ -0,0 +1,9 @@ +;;; conf/apps.fnl --- Common program parameters + +;;; Code: + +{:terminal "xfce4-terminal" + :dropdown "xfce4-terminal --drop-down" + :editor "emacs" + :rofi "rofi -show drun"} +;;; conf/apps.fnl ends here diff --git a/conf/client.fnl b/conf/client.fnl new file mode 100644 index 0000000..63251fe --- /dev/null +++ b/conf/client.fnl @@ -0,0 +1 @@ +(require :conf.client.rules) diff --git a/conf/client/buttons.fnl b/conf/client/buttons.fnl new file mode 100644 index 0000000..3cd9746 --- /dev/null +++ b/conf/client/buttons.fnl @@ -0,0 +1,17 @@ +;;; buttons.fnl --- Client buttons + +;;; Code: + +(var buttons {}) + +(local awful (require :awful)) +(local gears (require :gears)) + +(local button (. (require :lib.keys) :button)) + +(gears.table.join + (button [] 1 (lambda [c] (set client.focus c) (: c :raise))) + (button [:mod] 1 awful.mouse.client.move) + (button [:mod] 3 awful.mouse.client.resize)) + +;;; buttons.fnl ends here diff --git a/conf/client/keys.fnl b/conf/client/keys.fnl new file mode 100644 index 0000000..acc4e56 --- /dev/null +++ b/conf/client/keys.fnl @@ -0,0 +1,45 @@ +;;; conf/client/keys.fnl --- Client keys + +;;; Code: + +(var keys {}) + +(local awful (require :awful)) +(local beautiful (require :beautiful)) +(local gears (require :gears)) + +(local key (. (require :lib.keys) :key)) + +;;; +;; Functions + +(fn register [field ...] + (tset keys field (gears.table.join ...))) + +(fn cinv [k] + (fn [c] + (let [v (. c k)] + (tset c k (not v))))) + +;;; +;; Configuration + +(register :client + (key [:mod] :q (lambda [c] (: c :kill))) + (key [:mod] :w (lambda [c] (set c.minimized true))) + + (key [:mod] :t (lambda [c] (awful.titlebar.toggle c beautiful.titlebar_position))) + (key [:mod :shift] :t (lambda [c] (awful.titlebar.toggle c :left))) + (key [:mod :ctrl] :t (lambda [c] (awful.titlebar.toggle c))) + + (key [:mod] :a (cinv :floating)) + (key [:mod] :f (cinv :fullscreen)) + (key [:mod] :m (cinv :maximized)) ;; FIXME behaves weird + (key [:mod] :s (cinv :sticky))) + +(register :floating + (. keys :client) + (key [:alt :shift] :h (lambda [c] (: c :relative_move {:x -100})))) + +keys +;;; conf/client/keys.fnl ends here diff --git a/conf/client/rules.fnl b/conf/client/rules.fnl new file mode 100644 index 0000000..60e86e7 --- /dev/null +++ b/conf/client/rules.fnl @@ -0,0 +1,51 @@ +;;; conf/client/rules.fnl --- Client placement rules + +;;; Code: + +(local awful (require :awful)) +(local beautiful (require :beautiful)) + +(local keys (require :conf.client.keys)) +(local buttons (require :conf.client.buttons)) + +;;; +;; Functions + +(local + rules + [{:rule {} + :properties {:border_width beautiful.border_width + :border_color beautiful.border_normal + :focus true + :keys (. keys :client) + :buttons buttons + :placement (+ awful.placement.no_overlap + awful.placement.no_offscreen)}} + + {:rule {:floating true} + :properties {:keys (. keys :floating)}} + + ;; Floating clients. + {:rule_any {:class ["Gpick"] + :name ["Event Tester"] ;; xev + :role ["pop-up" + "xfce4-terminal-dropdown"]} + :properties {:floating true}} + + ;; Add titlebars to normal clients and dialogs. + {:rule_any {:type ["normal" "dialog"]} + :properties {:titlebars_enabled true}} + + ;; Awesome reserves space for tint2, but doesn't actually fucking place it. + {:rule {:class "Tint2"} + :properties {:x 0 :y 0}} + + ;; Set Firefox to always map on the tag named "2" on screen 1. + ;; {:rule {:class "Firefox"} + ;; :properties {:screen 1 :tag "2"}} + ]) + +(tset awful.rules :rules rules) + +{} +;;; conf/client/rules.fnl ends here diff --git a/conf/keys.fnl b/conf/keys.fnl new file mode 100644 index 0000000..e9ba16f --- /dev/null +++ b/conf/keys.fnl @@ -0,0 +1,119 @@ +;;; conf/keys.fnl --- Global keys + +;;; Code: + +(var keys {}) + +(local gears (require :gears)) +(local awful (require :awful)) + +(local apps (require :conf.apps)) +(local key (. (require :lib.keys) :key)) + +;;; +;; Functions + +(fn register [...] + (set keys (gears.table.join keys ...))) + +;; Spawn lamda +(fn run [prog] + (fn [] (awful.spawn prog))) + +;; Run silent +(fn run! [prog] + (fn [] (awful.spawn prog false))) + +(fn playerctl [a] + (run! (.. "playerctl " a))) + +(fn mpc [a] + (run! (.. "mpc " a))) + +(fn amixer [args] + (run! (.. "amixer sset Master " args))) + +(fn light [a v] + (let [flag (. {:up "-A" :down "-U"} a)] + (run! (.. "light " flag " " (tostring v))))) + +(fn with-tag [i fun] + (lambda [] (when client.focus + (let [tag (. client.focus.screen.tags i)] + (when tag + (fun tag)))))) + +(fn apply-focus [i attr] + (with-tag i (lambda [t] (: client.focus attr t)))) + +(fn current-tag [] + (. (awful.screen.focused) :selected_tag)) + +(fn focused-tag-by-index [i] + (. (awful.screen.focused) :tags i)) + +(fn with-focused-tag [i fun] + (lambda [] (let [tag (focused-tag-by-index i)] + (when tag + (fun tag))))) + +(fn focus [c] + (when c + (set client.focus c) + (: c :raise))) + +;;; +;; Configuration + +(register + (key [:mod :ctrl :shift] :q awesome.quit) + (key [:mod :ctrl] :r awesome.restart) + (key [:mod :ctrl] :q (run "oblogout")) + + (key [:mod] :equal (lambda [] + (doto (current-tag) + (tset :master_width_factor 0.5) + (tset :master_count 1) + (tset :column_count 1)))) + (key [:mod :ctrl] :l (lambda [] (awful.tag.incmwfact 0.05))) + (key [:mod :ctrl] :h (lambda [] (awful.tag.incmwfact -0.05))) + (key [:mod] :comma (lambda [] (awful.tag.incnmaster 1 nil true))) + (key [:mod] :period (lambda [] (awful.tag.incnmaster -1 nil true))) + (key [:mod :ctrl] :period (lambda [] (awful.tag.incncol 1 nil true))) + (key [:mod :ctrl] :comma (lambda [] (awful.tag.incncol -1 nil true))) + + (key [:mod] :Return (run apps.terminal)) + (key [:mod] :d (run apps.dropdown)) + + (key [:mod] :space (lambda [] (awful.layout.inc 1))) + (key [:mod] :r (run apps.rofi)) + (key [:mod] :p (run "xfce4-display-settings")) + + (key [:mod] :e (run apps.editor)) + + (key [:mod] :BackSpace (lambda [] (focus (awful.client.restore)))) + + (key [] :XF86AudioPlay (playerctl :play-pause)) + (key [] :XF86AudioNext (playerctl :next)) + (key [] :XF86AudioPrev (playerctl :previous)) + (key [] :XF86AudioRaiseVolume (amixer "5%+")) + (key [] :XF86AudioLowerVolume (amixer "5%-")) + (key [] :XF86AudioMute (amixer "toggle")) + (key [] :XF86MonBrightnessUp (light :up 5)) + (key [] :XF86MonBrightnessDown (light :down 5))) + +(each [_key dir (pairs {:h :left :j :down :k :up :l :right})] + (register + (key [:mod] _key (lambda [] (awful.client.focus.bydirection dir))) + (key [:mod :shift] _key (lambda [] (awful.client.swap.bydirection dir))))) + +(for [i 1 9] + (let [ksym (.. "#" (+ i 9))] + (register + (key [:mod] ksym (with-focused-tag i (lambda [t] (: t :view_only)))) + (key [:mod :ctrl] ksym (with-focused-tag i (lambda [t] (awful.tag.viewtoggle t)))) + (key [:mod :shift] ksym (apply-focus i :move_to_tag)) + (key [:mod :ctrl :shift] ksym (apply-focus i :toggle_tag))))) + +keys +;;; conf/keys.fnl ends here diff --git a/conf/rules.fnl b/conf/rules.fnl new file mode 100644 index 0000000..662ea4e --- /dev/null +++ b/conf/rules.fnl @@ -0,0 +1,10 @@ +;;; rules.fnl --- Rule definitions + +;;; Code: + +(var rules {}) + + + +rules +;;; rules.fnl ends here diff --git a/conf/tags.fnl b/conf/tags.fnl new file mode 100644 index 0000000..80513af --- /dev/null +++ b/conf/tags.fnl @@ -0,0 +1,21 @@ +;;; conf/tags.fnl --- Tag configuration + +;;; Code: + +(local awful (require :awful)) + +;;; +;; Configuration + +(local layouts [awful.layout.suit.tile + awful.layout.suit.floating]) + +(set awful.layout.layouts layouts) + +(awful.screen.connect_for_each_screen + (lambda [s] + (awful.tag ["1", "2", "3", "4", "5", "6", "7", "8", "9"] + s (. layouts 1)))) + +{} +;;; conf/tags.fnl ends here diff --git a/conf/theme.fnl b/conf/theme.fnl new file mode 100644 index 0000000..75f73b4 --- /dev/null +++ b/conf/theme.fnl @@ -0,0 +1,35 @@ +;; Have to use underscores for `beautiful' compat. +;; Non-beautiful variables are commented with an asterisk. +{ + :font "Sarasa Mono J" + :variable_font "Sarasa UI J" + + :sidebar_position :left ;; * + :sidebar_width 36 ;; * + :sidebar_bg "#1a1e24" ;; * + :sidebar_subbox "#252b33" ;; * + + :bg_focus "#252b33" + :bg_normal "#1a1e24" + :fg_focus "#cfcfcf" + :fg_normal "#cfcfcf" + + :useless_gap 0 + :border_width 0 + :border_focus "#5bb3b4" + :border_normal "#1a1e24" + :use_titlebars_for_borders nil ;; * + :titlebar_border_width 1 ;; * + + :titlebar_position :left ;; * + :titlebar_size 15 ;; * + :titlebar_bg_focus "#926b3e" + :titlebar_bg_normal "#252b33" + :titlebar_fg_focus "#1a1e24" + :titlebar_fg_normal "#cfcfcf" + + :icon_theme "Vertex Maia" + :notification_icon_size 48 + :notification_border_width 0 + :notification_bg "#252b33" + } diff --git a/lib/README.md b/lib/README.md new file mode 100644 index 0000000..9789473 --- /dev/null +++ b/lib/README.md @@ -0,0 +1,8 @@ +## lib + +These would normally be installed by `luarocks`; however, given how central they are to the config, they are vendored in instead. + +### Versions + +- `fennel`: 0.1.1-2 +- `fun`: 0.1.3-1 diff --git a/lib/fennel.lua b/lib/fennel.lua new file mode 100644 index 0000000..c7d3031 --- /dev/null +++ b/lib/fennel.lua @@ -0,0 +1,1964 @@ +--[[ +Copyright (c) 2016-2018 Calvin Rose and contributors +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +]] + +-- Make global variables local. +local setmetatable = setmetatable +local getmetatable = getmetatable +local type = type +local assert = assert +local pairs = pairs +local ipairs = ipairs +local tostring = tostring +local unpack = unpack or table.unpack + +-- +-- Main Types and support functions +-- + +local function deref(self) return self[1] end + +local SYMBOL_MT = { 'SYMBOL', __tostring = deref } +local EXPR_MT = { 'EXPR', __tostring = deref } +local VARARG = setmetatable({ '...' }, { 'VARARG', __tostring = deref }) +local LIST_MT = { 'LIST', + __tostring = function (self) + local strs = {} + for _, s in ipairs(self) do + table.insert(strs, tostring(s)) + end + return '(' .. table.concat(strs, ', ', 1, #self) .. ')' + end +} + +-- Load code with an environment in all recent Lua versions +local function loadCode(code, environment, filename) + environment = environment or _ENV or _G + if setfenv and loadstring then + local f = assert(loadstring(code, filename)) + setfenv(f, environment) + return f + else + return assert(load(code, filename, "t", environment)) + end +end + +-- Create a new list +local function list(...) + return setmetatable({...}, LIST_MT) +end + +-- Create a new symbol +local function sym(str, scope, meta) + local s = {str, scope = scope} + if meta then + for k, v in pairs(meta) do + if type(k) == 'string' then s[k] = v end + end + end + return setmetatable(s, SYMBOL_MT) +end + +-- Create a new expr +-- etype should be one of +-- "literal", -- literals like numbers, strings, nil, true, false +-- "expression", -- Complex strings of Lua code, may have side effects, etc, but is an expression +-- "statement", -- Same as expression, but is also a valid statement (function calls). +-- "vargs", -- varargs symbol +-- "sym", -- symbol reference +local function expr(strcode, etype) + return setmetatable({ strcode, type = etype }, EXPR_MT) +end + +local function varg() + return VARARG +end + +local function isVarg(x) + return x == VARARG and x +end + +-- Checks if an object is a List. Returns the object if is a List. +local function isList(x) + return type(x) == 'table' and getmetatable(x) == LIST_MT and x +end + +-- Checks if an object is a symbol. Returns the object if it is a symbol. +local function isSym(x) + return type(x) == 'table' and getmetatable(x) == SYMBOL_MT and x +end + +-- Checks if an object any kind of table, EXCEPT list or symbol +local function isTable(x) + return type(x) == 'table' and + x ~= VARARG and + getmetatable(x) ~= LIST_MT and getmetatable(x) ~= SYMBOL_MT and x +end + +-- +-- Parser +-- + +-- Convert a stream of chunks to a stream of bytes. +-- Also returns a second function to clear the buffer in the byte stream +local function granulate(getchunk) + local c = '' + local index = 1 + local done = false + return function () + if done then return nil end + if index <= #c then + local b = c:byte(index) + index = index + 1 + return b + else + c = getchunk() + if not c or c == '' then + done = true + return nil + end + index = 2 + return c:byte(1) + end + end, function () + c = '' + end +end + +-- Convert a string into a stream of bytes +local function stringStream(str) + local index = 1 + return function() + local r = str:byte(index) + index = index + 1 + return r + end +end + +-- Table of delimiter bytes - (, ), [, ], {, } +-- Opener keys have closer as the value, and closers keys +-- have true as their value. +local delims = { + [40] = 41, -- ( + [41] = true, -- ) + [91] = 93, -- [ + [93] = true, -- ] + [123] = 125, -- { + [125] = true -- } +} + +local function iswhitespace(b) + return b == 32 or (b >= 9 and b <= 13) or b == 44 +end + +local function issymbolchar(b) + return b > 32 and + not delims[b] and + b ~= 127 and + b ~= 34 and + b ~= 39 and + b ~= 59 and + b ~= 44 +end + +-- Parse one value given a function that +-- returns sequential bytes. Will throw an error as soon +-- as possible without getting more bytes on bad input. Returns +-- if a value was read, and then the value read. Will return nil +-- when input stream is finished. +local function parser(getbyte, filename) + + -- Stack of unfinished values + local stack = {} + + -- Provide one character buffer and keep + -- track of current line and byte index + local line = 1 + local byteindex = 0 + local lastb + local function ungetb(ub) + if ub == 10 then line = line - 1 end + byteindex = byteindex - 1 + lastb = ub + end + local function getb() + local r + if lastb then + r, lastb = lastb, nil + else + r = getbyte() + end + byteindex = byteindex + 1 + if r == 10 then line = line + 1 end + return r + end + local function parseError(msg) + return error(msg .. ' in ' .. (filename or 'unknown') .. ':' .. line, 0) + end + + -- Parse stream + return function () + + -- Dispatch when we complete a value + local done, retval + local function dispatch(v) + if #stack == 0 then + retval = v + done = true + else + table.insert(stack[#stack], v) + end + end + + -- The main parse loop + repeat + local b + + -- Skip whitespace + repeat + b = getb() + until not b or not iswhitespace(b) + if not b then + if #stack > 0 then parseError 'unexpected end of source' end + return nil + end + + if b == 59 then -- ; Comment + repeat + b = getb() + until not b or b == 10 -- newline + elseif type(delims[b]) == 'number' then -- Opening delimiter + table.insert(stack, setmetatable({ + closer = delims[b], + line = line, + filename = filename, + bytestart = byteindex + }, LIST_MT)) + elseif delims[b] then -- Closing delimiter + if #stack == 0 then parseError 'unexpected closing delimiter' end + local last = stack[#stack] + local val + if last.closer ~= b then + parseError('unexpected delimiter ' .. string.char(b) .. + ', expected ' .. string.char(last.closer)) + end + last.byteend = byteindex -- Set closing byte index + if b == 41 then -- ) + val = last + elseif b == 93 then -- ] + val = {} + for i = 1, #last do + val[i] = last[i] + end + else -- } + if #last % 2 ~= 0 then + parseError('expected even number of values in table literal') + end + val = {} + for i = 1, #last, 2 do + val[last[i]] = last[i + 1] + end + end + stack[#stack] = nil + dispatch(val) + elseif b == 34 or b == 39 then -- Quoted string + local start = b + local state = "base" + local chars = {start} + repeat + b = getb() + chars[#chars + 1] = b + if state == "base" then + if b == 92 then + state = "backslash" + elseif b == start then + state = "done" + end + else + -- state == "backslash" + state = "base" + end + until not b or (state == "done") + if not b then parseError('unexpected end of source') end + local raw = string.char(unpack(chars)) + local formatted = raw:gsub("[\1-\31]", function (c) return '\\' .. c:byte() end) + local loadFn = loadCode(('return %s'):format(formatted), nil, filename) + dispatch(loadFn()) + else -- Try symbol + local chars = {} + local bytestart = byteindex + repeat + chars[#chars + 1] = b + b = getb() + until not b or not issymbolchar(b) + if b then ungetb(b) end + local rawstr = string.char(unpack(chars)) + if rawstr == 'true' then dispatch(true) + elseif rawstr == 'false' then dispatch(false) + elseif rawstr == '...' then dispatch(VARARG) + elseif rawstr:match('^:.+$') then -- keyword style strings + dispatch(rawstr:sub(2)) + else + local forceNumber = rawstr:match('^%d') + local numberWithStrippedUnderscores = rawstr:gsub("_", "") + local x + if forceNumber then + x = tonumber(numberWithStrippedUnderscores) or + parseError('could not read token "' .. rawstr .. '"') + else + x = tonumber(numberWithStrippedUnderscores) or + sym(rawstr, nil, { line = line, + filename = filename, + bytestart = bytestart, + byteend = byteindex, }) + end + dispatch(x) + end + end + until done + return true, retval + end +end + +-- +-- Compilation +-- + +-- Create a new Scope, optionally under a parent scope. Scopes are compile time constructs +-- that are responsible for keeping track of local variables, name mangling, and macros. +-- They are accessible to user code via the '*compiler' special form (may change). They +-- use metatables to implement nesting via inheritance. +local function makeScope(parent) + return { + unmanglings = setmetatable({}, { + __index = parent and parent.unmanglings + }), + manglings = setmetatable({}, { + __index = parent and parent.manglings + }), + specials = setmetatable({}, { + __index = parent and parent.specials + }), + symmeta = setmetatable({}, { + __index = parent and parent.symmeta + }), + parent = parent, + vararg = parent and parent.vararg, + depth = parent and ((parent.depth or 0) + 1) or 0 + } +end + +-- Assert a condition and raise a compile error with line numbers. The ast arg +-- should be unmodified so that its first element is the form being called. +local function assertCompile(condition, msg, ast) + -- if we use regular `assert' we can't provide the `level' argument of zero + if not condition then + error(string.format("Compile error in '%s' %s:%s: %s", + isSym(ast[1]) and ast[1][1] or ast[1] or '()', + ast.filename or "unknown", ast.line or '?', msg), 0) + end + return condition +end + +local GLOBAL_SCOPE = makeScope() +GLOBAL_SCOPE.vararg = true +local SPECIALS = GLOBAL_SCOPE.specials +local COMPILER_SCOPE = makeScope(GLOBAL_SCOPE) + +local luaKeywords = { + 'and', 'break', 'do', 'else', 'elseif', 'end', 'false', 'for', 'function', + 'if', 'in', 'local', 'nil', 'not', 'or', 'repeat', 'return', 'then', 'true', + 'until', 'while' +} +for i, v in ipairs(luaKeywords) do + luaKeywords[v] = i +end + +local function isValidLuaIdentifier(str) + return (str:match('^[%a_][%w_]*$') and not luaKeywords[str]) +end + +-- Allow printing a string to Lua, also keep as 1 line. +local serializeSubst = { + ['\a'] = '\\a', + ['\b'] = '\\b', + ['\f'] = '\\f', + ['\n'] = 'n', + ['\t'] = '\\t', + ['\v'] = '\\v' +} +local function serializeString(str) + local s = ("%q"):format(str) + s = s:gsub('.', serializeSubst):gsub("[\128-\255]", function(c) + return "\\" .. c:byte() + end) + return s +end + +-- A multi symbol is a symbol that is actually composed of +-- two or more symbols using the dot syntax. The main differences +-- from normal symbols is that they cannot be declared local, and +-- they may have side effects on invocation (metatables) +local function isMultiSym(str) + if type(str) ~= 'string' then return end + local parts = {} + for part in str:gmatch('[^%.]+') do + parts[#parts + 1] = part + end + return #parts > 0 and + str:match('%.') and + (not str:match('%.%.')) and + str:byte() ~= string.byte '.' and + str:byte(-1) ~= string.byte '.' and + parts +end + +-- Mangler for global symbols. Does not protect against collisions, +-- but makes them unlikely. This is the mangling that is exposed to +-- to the world. +local function globalMangling(str) + if isValidLuaIdentifier(str) then + return str + end + -- Use underscore as escape character + return '__fnl_global__' .. str:gsub('[^%w]', function (c) + return ('_%02x'):format(c:byte()) + end) +end + +-- Reverse a global mangling. Takes a Lua identifier and +-- returns the fennel symbol string that created it. +local function globalUnmangling(identifier) + local rest = identifier:match('^__fnl_global__(.*)$') + if rest then + return rest:gsub('_[%da-f][%da-f]', function (code) + return string.char(tonumber(code:sub(2), 16)) + end) + else + return identifier + end +end + +-- Creates a symbol from a string by mangling it. +-- ensures that the generated symbol is unique +-- if the input string is unique in the scope. +local function localMangling(str, scope, ast) + if scope.manglings[str] then + return scope.manglings[str] + end + local append = 0 + local mangling = str + assertCompile(not isMultiSym(str), 'did not expect multi symbol ' .. str, ast) + + -- Mapping mangling to a valid Lua identifier + if luaKeywords[mangling] or mangling:match('^%d') then + mangling = '_' .. mangling + end + mangling = mangling:gsub('-', '_') + mangling = mangling:gsub('[^%w_]', function (c) + return ('_%02x'):format(c:byte()) + end) + + local raw = mangling + while scope.unmanglings[mangling] do + mangling = raw .. append + append = append + 1 + end + scope.unmanglings[mangling] = str + scope.manglings[str] = mangling + return mangling +end + +-- Combine parts of a symbol +local function combineParts(parts, scope) + local ret = scope.manglings[parts[1]] or globalMangling(parts[1]) + for i = 2, #parts do + if isValidLuaIdentifier(parts[i]) then + ret = ret .. '.' .. parts[i] + else + ret = ret .. '[' .. serializeString(parts[i]) .. ']' + end + end + return ret +end + +-- Generates a unique symbol in the scope. +local function gensym(scope) + local mangling + local append = 0 + repeat + mangling = '_' .. append .. '_' + append = append + 1 + until not scope.unmanglings[mangling] + scope.unmanglings[mangling] = true + return mangling +end + +-- Declare a local symbol +local function declareLocal(symbol, meta, scope, ast) + local name = symbol[1] + assertCompile(not isMultiSym(name), "did not expect mutltisym", ast) + local mangling = localMangling(name, scope, ast) + scope.symmeta[name] = meta + return mangling +end + +-- If there's a provided list of allowed globals, don't let references +-- thru that aren't on the list. This list is set at the compiler +-- entry points of compile and compileStream. +local allowedGlobals + +local function globalAllowed(name) + if not allowedGlobals then return true end + for _, g in ipairs(allowedGlobals) do + if g == name then return true end + end +end + +-- Convert symbol to Lua code. Will only work for local symbols +-- if they have already been declared via declareLocal +local function symbolToExpression(symbol, scope, isReference) + local name = symbol[1] + local parts = isMultiSym(name) or {name} + local etype = (#parts > 1) and "expression" or "sym" + local isLocal = scope.manglings[parts[1]] + -- if it's a reference and not a symbol which introduces a new binding + -- then we need to check for allowed globals + assertCompile(not isReference or isLocal or globalAllowed(parts[1]), + 'unknown global in strict mode: ' .. parts[1], symbol) + return expr(combineParts(parts, scope), etype) +end + + +-- Emit Lua code +local function emit(chunk, out, ast) + if type(out) == 'table' then + table.insert(chunk, out) + else + table.insert(chunk, {leaf = out, ast = ast}) + end +end + +-- Do some peephole optimization. +local function peephole(chunk) + if chunk.leaf then return chunk end + -- Optimize do ... end in some cases. + if #chunk == 3 and + chunk[1].leaf == 'do' and + not chunk[2].leaf and + chunk[3].leaf == 'end' then + return peephole(chunk[2]) + end + -- Recurse + for i, v in ipairs(chunk) do + chunk[i] = peephole(v) + end + return chunk +end + +-- correlate line numbers in input with line numbers in output +local function flattenChunkCorrelated(mainChunk) + local function flatten(chunk, out, lastLine, file) + if chunk.leaf then + out[lastLine] = (out[lastLine] or "") .. " " .. chunk.leaf + else + for _, subchunk in ipairs(chunk) do + -- Ignore empty chunks + if subchunk.leaf or #subchunk > 0 then + -- don't increase line unless it's from the same file + if subchunk.ast and file == subchunk.ast.file then + lastLine = math.max(lastLine, subchunk.ast.line or 0) + end + lastLine = flatten(subchunk, out, lastLine, file) + end + end + end + return lastLine + end + local out = {} + local last = flatten(mainChunk, out, 1, mainChunk.file) + for i = 1, last do + if out[i] == nil then out[i] = "" end + end + return table.concat(out, "\n") +end + +-- Flatten a tree of indented Lua source code lines. +-- Tab is what is used to indent a block. +local function flattenChunk(sm, chunk, tab, depth) + if type(tab) == 'boolean' then tab = tab and ' ' or '' end + if chunk.leaf then + local code = chunk.leaf + local info = chunk.ast + -- Just do line info for now to save memory + if sm then sm[#sm + 1] = info and info.line or -1 end + return code + else + local parts = {} + for i = 1, #chunk do + -- Ignore empty chunks + if chunk[i].leaf or #(chunk[i]) > 0 then + local sub = flattenChunk(sm, chunk[i], tab, depth + 1) + if depth > 0 then sub = tab .. sub:gsub('\n', '\n' .. tab) end + table.insert(parts, sub) + end + end + return table.concat(parts, '\n') + end +end + +-- Some global state for all fennel sourcemaps. For the time being, +-- this seems the easiest way to store the source maps. +-- Sourcemaps are stored with source being mapped as the key, prepended +-- with '@' if it is a filename (like debug.getinfo returns for source). +-- The value is an array of mappings for each line. +local fennelSourcemap = {} +-- TODO: loading, unloading, and saving sourcemaps? + +local function makeShortSrc(source) + source = source:gsub('\n', ' ') + if #source <= 49 then + return '[fennel "' .. source .. '"]' + else + return '[fennel "' .. source:sub(1, 46) .. '..."]' + end +end + +-- Return Lua source and source map table +local function flatten(chunk, options) + local sm = options.sourcemap and {} + chunk = peephole(chunk) + if(options.correlate) then + return flattenChunkCorrelated(chunk), {} + else + local ret = flattenChunk(sm, chunk, options.indent, 0) + if sm then + local key, short_src + if options.filename then + short_src = options.filename + key = '@' .. short_src + else + key = ret + short_src = makeShortSrc(options.source or ret) + end + sm.short_src = short_src + sm.key = key + fennelSourcemap[key] = sm + end + return ret, sm + end +end + +-- Convert expressions to Lua string +local function exprs1(exprs) + local t = {} + for _, e in ipairs(exprs) do + t[#t + 1] = e[1] + end + return table.concat(t, ', ') +end + +-- Compile side effects for a chunk +local function keepSideEffects(exprs, chunk, start, ast) + start = start or 1 + for j = start, #exprs do + local se = exprs[j] + -- Avoid the rogue 'nil' expression (nil is usually a literal, + -- but becomes an expression if a special form + -- returns 'nil'.) + if se.type == 'expression' and se[1] ~= 'nil' then + emit(chunk, ('do local _ = %s end'):format(tostring(se)), ast) + elseif se.type == 'statement' then + emit(chunk, tostring(se), ast) + end + end +end + +-- Does some common handling of returns and register +-- targets for special forms. Also ensures a list expression +-- has an acceptable number of expressions if opts contains the +-- "nval" option. +local function handleCompileOpts(exprs, parent, opts, ast) + if opts.nval then + local n = opts.nval + if n ~= #exprs then + local len = #exprs + if len > n then + -- Drop extra + keepSideEffects(exprs, parent, n + 1, ast) + for i = n, len do + exprs[i] = nil + end + else + -- Pad with nils + for i = #exprs + 1, n do + exprs[i] = expr('nil', 'literal') + end + end + end + end + if opts.tail then + emit(parent, ('return %s'):format(exprs1(exprs)), ast) + end + if opts.target then + emit(parent, ('%s = %s'):format(opts.target, exprs1(exprs)), ast) + end + if opts.tail or opts.target then + -- Prevent statements and expression from being used twice if they + -- have side-effects. Since if the target or tail options are set, + -- the expressions are already emitted, we should not return them. This + -- is fine, as when these options are set, the caller doesn't need the result + -- anyways. + exprs = {} + end + return exprs +end + +-- Compile an AST expression in the scope into parent, a tree +-- of lines that is eventually compiled into Lua code. Also +-- returns some information about the evaluation of the compiled expression, +-- which can be used by the calling function. Macros +-- are resolved here, as well as special forms in that order. +-- the 'ast' param is the root AST to compile +-- the 'scope' param is the scope in which we are compiling +-- the 'parent' param is the table of lines that we are compiling into. +-- add lines to parent by appending strings. Add indented blocks by appending +-- tables of more lines. +-- the 'opts' param contains info about where the form is being compiled. +-- Options include: +-- 'target' - mangled name of symbol(s) being compiled to. +-- Could be one variable, 'a', or a list, like 'a, b, _0_'. +-- 'tail' - boolean indicating tail position if set. If set, form will generate a return +-- instruction. +local function compile1(ast, scope, parent, opts) + opts = opts or {} + local exprs = {} + + -- Compile the form + if isList(ast) then + -- Function call or special form + local len = #ast + assertCompile(len > 0, "expected a function to call", ast) + -- Test for special form + local first = ast[1] + if isSym(first) then -- Resolve symbol + first = first[1] + end + local special = scope.specials[first] + if special and isSym(ast[1]) then + -- Special form + exprs = special(ast, scope, parent, opts) or expr('nil', 'literal') + -- Be very accepting of strings or expression + -- as well as lists or expressions + if type(exprs) == 'string' then exprs = expr(exprs, 'expression') end + if getmetatable(exprs) == EXPR_MT then exprs = {exprs} end + -- Unless the special form explicitly handles the target, tail, and nval properties, + -- (indicated via the 'returned' flag, handle these options. + if not exprs.returned then + exprs = handleCompileOpts(exprs, parent, opts, ast) + elseif opts.tail or opts.target then + exprs = {} + end + exprs.returned = true + return exprs + else + -- Function call + local fargs = {} + local fcallee = compile1(ast[1], scope, parent, { + nval = 1 + })[1] + assertCompile(fcallee.type ~= 'literal', + 'cannot call literal value', ast) + fcallee = tostring(fcallee) + for i = 2, len do + local subexprs = compile1(ast[i], scope, parent, { + nval = i ~= len and 1 or nil + }) + fargs[#fargs + 1] = subexprs[1] or expr('nil', 'literal') + if i == len then + -- Add sub expressions to function args + for j = 2, #subexprs do + fargs[#fargs + 1] = subexprs[j] + end + else + -- Emit sub expression only for side effects + keepSideEffects(subexprs, parent, 2, ast[i]) + end + end + local call = ('%s(%s)'):format(tostring(fcallee), exprs1(fargs)) + exprs = handleCompileOpts({expr(call, 'statement')}, parent, opts, ast) + end + elseif isVarg(ast) then + assertCompile(scope.vararg, "unexpected vararg", ast) + exprs = handleCompileOpts({expr('...', 'varg')}, parent, opts, ast) + elseif isSym(ast) then + local e + -- Handle nil as special symbol - it resolves to the nil literal rather than + -- being unmangled. Alternatively, we could remove it from the lua keywords table. + if ast[1] == 'nil' then + e = expr('nil', 'literal') + else + e = symbolToExpression(ast, scope, true) + end + exprs = handleCompileOpts({e}, parent, opts, ast) + elseif type(ast) == 'nil' or type(ast) == 'boolean' then + exprs = handleCompileOpts({expr(tostring(ast), 'literal')}, parent, opts) + elseif type(ast) == 'number' then + local n = ('%.17g'):format(ast) + exprs = handleCompileOpts({expr(n, 'literal')}, parent, opts) + elseif type(ast) == 'string' then + local s = serializeString(ast) + exprs = handleCompileOpts({expr(s, 'literal')}, parent, opts) + elseif type(ast) == 'table' then + local buffer = {} + for i = 1, #ast do -- Write numeric keyed values. + local nval = i ~= #ast and 1 + buffer[#buffer + 1] = exprs1(compile1(ast[i], scope, parent, {nval = nval})) + end + local keys = {} + for k, _ in pairs(ast) do -- Write other keys. + if type(k) ~= 'number' or math.floor(k) ~= k or k < 1 or k > #ast then + local kstr + if type(k) == 'string' and isValidLuaIdentifier(k) then + kstr = k + else + kstr = '[' .. tostring(compile1(k, scope, parent, {nval = 1})[1]) .. ']' + end + table.insert(keys, { kstr, k }) + end + end + table.sort(keys, function (a, b) return a[1] < b[1] end) + for _, k in ipairs(keys) do + local v = ast[k[2]] + buffer[#buffer + 1] = ('%s = %s'):format( + k[1], tostring(compile1(v, scope, parent, {nval = 1})[1])) + end + local tbl = '{' .. table.concat(buffer, ', ') ..'}' + exprs = handleCompileOpts({expr(tbl, 'expression')}, parent, opts, ast) + else + assertCompile(false, 'could not compile value of type ' .. type(ast), ast) + end + exprs.returned = true + return exprs +end + +-- SPECIALS -- + +-- For statements and expressions, put the value in a local to avoid +-- double-evaluating it. +local function once(val, ast, scope, parent) + if val.type == 'statement' or val.type == 'expression' then + local s = gensym(scope) + emit(parent, ('local %s = %s'):format(s, tostring(val)), ast) + return expr(s, 'sym') + else + return val + end +end + +-- Implements destructuring for forms like let, bindings, etc. +-- Takes a number of options to control behavior. +-- var: Whether or not to mark symbols as mutable +-- declaration: begin each assignment with 'local' in output +-- nomulti: disallow multisyms in the destructuring. Used for (local) and (global). +-- noundef: Don't set undefined bindings. (set) +-- forceglobal: Don't allow local bindings +local function destructure(to, from, ast, scope, parent, opts) + opts = opts or {} + local isvar = opts.isvar + local declaration = opts.declaration + local nomulti = opts.nomulti + local noundef = opts.noundef + local forceglobal = opts.forceglobal + local forceset = opts.forceset + local setter = declaration and "local %s = %s" or "%s = %s" + + -- Get Lua source for symbol, and check for errors + local function getname(symbol, up1) + local raw = symbol[1] + assertCompile(not (nomulti and isMultiSym(raw)), + 'did not expect multisym', up1) + if declaration then + return declareLocal(symbol, {var = isvar}, scope, symbol) + else + local parts = isMultiSym(raw) or {raw} + local meta = scope.symmeta[parts[1]] + if #parts == 1 and not forceset then + assertCompile(not(forceglobal and meta), + 'expected global, found var', up1) + assertCompile(meta or not noundef, + 'expected local var ' .. parts[1], up1) + assertCompile(not (meta and not meta.var), + 'expected local var', up1) + end + return symbolToExpression(symbol, scope)[1] + end + end + + -- Recursive auxiliary function + local function destructure1(left, rightexprs, up1) + if isSym(left) and left[1] ~= "nil" then + emit(parent, setter:format(getname(left, up1), exprs1(rightexprs)), left) + elseif isTable(left) then -- table destructuring + local s = gensym(scope) + emit(parent, ("local %s = %s"):format(s, exprs1(rightexprs)), left) + for i, v in ipairs(left) do + if isSym(left[i]) and left[i][1] == "&" then + assertCompile(not left[i+2], + "expected rest argument in final position", left) + local subexpr = expr(('{(table.unpack or unpack)(%s, %s)}'):format(s, i), + 'expression') + destructure1(left[i+1], {subexpr}, left) + return + else + local subexpr = expr(('%s[%d]'):format(s, i), 'expression') + destructure1(v, {subexpr}, left) + end + end + elseif isList(left) then -- values destructuring + local leftNames, tables = {}, {} + for i, name in ipairs(left) do + local symname + if isSym(name) then -- binding directly to a name + symname = getname(name, up1) + else -- further destructuring of tables inside values + symname = gensym(scope) + tables[i] = {name, expr(symname, 'sym')} + end + table.insert(leftNames, symname) + end + emit(parent, setter: + format(table.concat(leftNames, ", "), exprs1(rightexprs)), left) + for _, pair in pairs(tables) do -- recurse if left-side tables found + destructure1(pair[1], {pair[2]}, left) + end + else + assertCompile(false, 'unable to destructure ' .. tostring(left), up1) + end + end + + local rexps = compile1(from, scope, parent) + local ret = destructure1(to, rexps, ast) + return ret +end + +-- Unlike most expressions and specials, 'values' resolves with multiple +-- values, one for each argument, allowing multiple return values. The last +-- expression, can return multiple arguments as well, allowing for more than the number +-- of expected arguments. +local function values(ast, scope, parent) + local len = #ast + local exprs = {} + for i = 2, len do + local subexprs = compile1(ast[i], scope, parent, {}) + exprs[#exprs + 1] = subexprs[1] or expr('nil', 'literal') + if i == len then + for j = 2, #subexprs do + exprs[#exprs + 1] = subexprs[j] + end + else + -- Emit sub expression only for side effects + keepSideEffects(subexprs, parent, 2, ast) + end + end + return exprs +end + +-- Compile a list of forms for side effects +local function compileDo(ast, scope, parent, start) + start = start or 2 + local len = #ast + local subScope = makeScope(scope) + for i = start, len do + compile1(ast[i], subScope, parent, { + nval = 0 + }) + end +end + +-- Implements a do statement, starting at the 'start' element. By default, start is 2. +local function doImpl(ast, scope, parent, opts, start, chunk, subScope) + start = start or 2 + subScope = subScope or makeScope(scope) + chunk = chunk or {} + local len = #ast + local outerTarget = opts.target + local outerTail = opts.tail + local retexprs = {returned = true} + + -- See if we need special handling to get the return values + -- of the do block + if not outerTarget and opts.nval ~= 0 and not outerTail then + if opts.nval then + -- Generate a local target + local syms = {} + for i = 1, opts.nval do + local s = gensym(scope) + syms[i] = s + retexprs[i] = expr(s, 'sym') + end + outerTarget = table.concat(syms, ', ') + emit(parent, ('local %s'):format(outerTarget), ast) + emit(parent, 'do', ast) + else + -- We will use an IIFE for the do + local fname = gensym(scope) + local fargs = scope.vararg and '...' or '' + emit(parent, ('local function %s(%s)'):format(fname, fargs), ast) + retexprs = expr(fname .. '(' .. fargs .. ')', 'statement') + outerTail = true + outerTarget = nil + end + else + emit(parent, 'do', ast) + end + -- Compile the body + if start > len then + -- In the unlikely case we do a do with no arguments. + compile1(nil, subScope, chunk, { + tail = outerTail, + target = outerTarget + }) + -- There will be no side effects + else + for i = start, len do + local subopts = { + nval = i ~= len and 0 or opts.nval, + tail = i == len and outerTail or nil, + target = i == len and outerTarget or nil + } + local subexprs = compile1(ast[i], subScope, chunk, subopts) + if i ~= len then + keepSideEffects(subexprs, parent, nil, ast[i]) + end + end + end + emit(parent, chunk, ast) + emit(parent, 'end', ast) + return retexprs +end + +SPECIALS['do'] = doImpl +SPECIALS['values'] = values + +-- The fn special declares a function. Syntax is similar to other lisps; +-- (fn optional-name [arg ...] (body)) +-- Further decoration such as docstrings, meta info, and multibody functions a possibility. +SPECIALS['fn'] = function(ast, scope, parent) + local fScope = makeScope(scope) + local fChunk = {} + local index = 2 + local fnName = isSym(ast[index]) + local isLocalFn + fScope.vararg = false + if fnName and fnName[1] ~= 'nil' then + isLocalFn = not isMultiSym(fnName[1]) + if isLocalFn then + fnName = declareLocal(fnName, {}, scope, ast) + else + fnName = symbolToExpression(fnName, scope)[1] + end + index = index + 1 + else + isLocalFn = true + fnName = gensym(scope) + end + local argList = assertCompile(isTable(ast[index]), + 'expected vector arg list [a b ...]', ast) + local argNameList = {} + for i = 1, #argList do + if isVarg(argList[i]) then + assertCompile(i == #argList, "expected vararg in last parameter position", ast) + argNameList[i] = '...' + fScope.vararg = true + elseif(isSym(argList[i]) and argList[i][1] ~= "nil" + and not isMultiSym(argList[i][1])) then + argNameList[i] = declareLocal(argList[i], {}, fScope, ast) + elseif isTable(argList[i]) then + local raw = sym(gensym(scope)) + argNameList[i] = declareLocal(raw, {}, fScope, ast) + destructure(argList[i], raw, ast, fScope, fChunk, + { declaration = true, nomulti = true }) + else + assertCompile(false, 'expected symbol for function parameter', ast) + end + end + for i = index + 1, #ast do + compile1(ast[i], fScope, fChunk, { + tail = i == #ast, + nval = i ~= #ast and 0 or nil + }) + end + if isLocalFn then + emit(parent, ('local function %s(%s)') + :format(fnName, table.concat(argNameList, ', ')), ast) + else + emit(parent, ('%s = function(%s)') + :format(fnName, table.concat(argNameList, ', ')), ast) + end + emit(parent, fChunk, ast) + emit(parent, 'end', ast) + return expr(fnName, 'sym') +end + +SPECIALS['luaexpr'] = function(ast) + return tostring(ast[2]) +end + +SPECIALS['luastatement'] = function(ast) + return expr(tostring(ast[2]), 'statement') +end + +-- Wrapper for table access +SPECIALS['.'] = function(ast, scope, parent) + local len = #ast + assertCompile(len > 1, "expected table argument", ast) + local lhs = compile1(ast[2], scope, parent, {nval = 1}) + if len == 2 then + return tostring(lhs[1]) + else + local indices = {} + for i = 3, len do + local index = ast[i] + if type(index) == 'string' and isValidLuaIdentifier(index) then + table.insert(indices, '.' .. index) + else + index = compile1(index, scope, parent, {nval = 1})[1] + table.insert(indices, '[' .. tostring(index) .. ']') + end + end + -- extra parens are needed for table literals + if isTable(ast[2]) then + return '(' .. tostring(lhs[1]) .. ')' .. table.concat(indices) + else + return tostring(lhs[1]) .. table.concat(indices) + end + end +end + +SPECIALS['global'] = function(ast, scope, parent) + assertCompile(#ast == 3, "expected name and value", ast) + if allowedGlobals then table.insert(allowedGlobals, ast[2][1]) end + destructure(ast[2], ast[3], ast, scope, parent, { + nomulti = true, + forceglobal = true + }) +end + +SPECIALS['set'] = function(ast, scope, parent) + assertCompile(#ast == 3, "expected name and value", ast) + destructure(ast[2], ast[3], ast, scope, parent, { + noundef = true + }) +end + +SPECIALS['set-forcibly!'] = function(ast, scope, parent) + assertCompile(#ast == 3, "expected name and value", ast) + destructure(ast[2], ast[3], ast, scope, parent, { + forceset = true + }) +end + +SPECIALS['local'] = function(ast, scope, parent) + assertCompile(#ast == 3, "expected name and value", ast) + destructure(ast[2], ast[3], ast, scope, parent, { + declaration = true, + nomulti = true + }) +end + +SPECIALS['var'] = function(ast, scope, parent) + assertCompile(#ast == 3, "expected name and value", ast) + destructure(ast[2], ast[3], ast, scope, parent, { + declaration = true, + nomulti = true, + isvar = true + }) +end + +SPECIALS['let'] = function(ast, scope, parent, opts) + local bindings = ast[2] + assertCompile(isList(bindings) or isTable(bindings), + 'expected table for destructuring', ast) + assertCompile(#bindings % 2 == 0, + 'expected even number of name/value bindings', ast) + assertCompile(#ast >= 3, 'missing body expression', ast) + local subScope = makeScope(scope) + local subChunk = {} + for i = 1, #bindings, 2 do + destructure(bindings[i], bindings[i + 1], ast, subScope, subChunk, { + declaration = true, + nomulti = true + }) + end + return doImpl(ast, scope, parent, opts, 3, subChunk, subScope) +end + +-- For setting items in a table +SPECIALS['tset'] = function(ast, scope, parent) + local root = compile1(ast[2], scope, parent, {nval = 1})[1] + local keys = {} + for i = 3, #ast - 1 do + local key = compile1(ast[i], scope, parent, {nval = 1})[1] + keys[#keys + 1] = tostring(key) + end + local value = compile1(ast[#ast], scope, parent, {nval = 1})[1] + emit(parent, ('%s[%s] = %s'):format(tostring(root), + table.concat(keys, ']['), + tostring(value)), ast) +end + +-- The if special form behaves like the cond form in +-- many languages +SPECIALS['if'] = function(ast, scope, parent, opts) + local doScope = makeScope(scope) + local branches = {} + local elseBranch = nil + + -- Calculate some external stuff. Optimizes for tail calls and what not + local outerTail = true + local outerTarget = nil + local wrapper = 'iife' + if opts.tail then + wrapper = 'none' + end + + -- Compile bodies and conditions + local bodyOpts = { + tail = outerTail, + target = outerTarget + } + local function compileBody(i) + local chunk = {} + local cscope = makeScope(doScope) + compile1(ast[i], cscope, chunk, bodyOpts) + return { + chunk = chunk, + scope = cscope + } + end + for i = 2, #ast - 1, 2 do + local condchunk = {} + local cond = compile1(ast[i], doScope, condchunk, {nval = 1}) + local branch = compileBody(i + 1) + branch.cond = cond + branch.condchunk = condchunk + branch.nested = i ~= 2 and next(condchunk, nil) == nil + table.insert(branches, branch) + end + local hasElse = #ast > 3 and #ast % 2 == 0 + if hasElse then elseBranch = compileBody(#ast) end + + -- Emit code + local s = gensym(scope) + local buffer = {} + local lastBuffer = buffer + for i = 1, #branches do + local branch = branches[i] + local fstr = not branch.nested and 'if %s then' or 'elseif %s then' + local condLine = fstr:format(tostring(branch.cond[1])) + if branch.nested then + emit(lastBuffer, branch.condchunk, ast) + else + for _, v in ipairs(branch.condchunk) do emit(lastBuffer, v, ast) end + end + emit(lastBuffer, condLine, ast) + emit(lastBuffer, branch.chunk, ast) + if i == #branches then + if hasElse then + emit(lastBuffer, 'else', ast) + emit(lastBuffer, elseBranch.chunk, ast) + end + emit(lastBuffer, 'end', ast) + elseif not branches[i + 1].nested then + emit(lastBuffer, 'else', ast) + local nextBuffer = {} + emit(lastBuffer, nextBuffer, ast) + emit(lastBuffer, 'end', ast) + lastBuffer = nextBuffer + end + end + + if wrapper == 'iife' then + local iifeargs = scope.vararg and '...' or '' + emit(parent, ('local function %s(%s)'):format(tostring(s), iifeargs), ast) + emit(parent, buffer, ast) + emit(parent, 'end', ast) + return expr(('%s(%s)'):format(tostring(s), iifeargs), 'statement') + elseif wrapper == 'none' then + -- Splice result right into code + for i = 1, #buffer do + emit(parent, buffer[i], ast) + end + return {returned = true} + end +end + +-- (each [k v (pairs t)] body...) => [] +SPECIALS['each'] = function(ast, scope, parent) + local binding = assertCompile(isTable(ast[2]), 'expected binding table', ast) + local iter = table.remove(binding, #binding) -- last item is iterator call + local bindVars = {} + local destructures = {} + for _, v in ipairs(binding) do + assertCompile(isSym(v) or isTable(v), + 'expected iterator symbol or table', ast) + if(isSym(v)) then + table.insert(bindVars, declareLocal(v, {}, scope, ast)) + else + local raw = sym(gensym(scope)) + destructures[raw] = v + table.insert(bindVars, declareLocal(raw, {}, scope, ast)) + end + end + emit(parent, ('for %s in %s do'):format( + table.concat(bindVars, ', '), + tostring(compile1(iter, scope, parent, {nval = 1})[1])), ast) + local chunk = {} + for raw, args in pairs(destructures) do + destructure(args, raw, ast, scope, chunk, + { declaration = true, nomulti = true }) + end + compileDo(ast, scope, chunk, 3) + emit(parent, chunk, ast) + emit(parent, 'end', ast) +end + +-- (while condition body...) => [] +SPECIALS['while'] = function(ast, scope, parent) + local len1 = #parent + local condition = compile1(ast[2], scope, parent, {nval = 1})[1] + local len2 = #parent + local subChunk = {} + if len1 ~= len2 then + -- Compound condition + emit(parent, 'while true do', ast) + -- Move new compilation to subchunk + for i = len1 + 1, len2 do + subChunk[#subChunk + 1] = parent[i] + parent[i] = nil + end + emit(parent, ('if %s then break end'):format(condition[1]), ast) + else + -- Simple condition + emit(parent, 'while ' .. tostring(condition) .. ' do', ast) + end + compileDo(ast, makeScope(scope), subChunk, 3) + emit(parent, subChunk, ast) + emit(parent, 'end', ast) +end + +SPECIALS['for'] = function(ast, scope, parent) + local ranges = assertCompile(isTable(ast[2]), 'expected binding table', ast) + local bindingSym = assertCompile(isSym(table.remove(ast[2], 1)), + 'expected iterator symbol', ast) + local rangeArgs = {} + for i = 1, math.min(#ranges, 3) do + rangeArgs[i] = tostring(compile1(ranges[i], scope, parent, {nval = 1})[1]) + end + emit(parent, ('for %s = %s do'):format( + declareLocal(bindingSym, {}, scope, ast), + table.concat(rangeArgs, ', ')), ast) + local chunk = {} + compileDo(ast, scope, chunk, 3) + emit(parent, chunk, ast) + emit(parent, 'end', ast) +end + +SPECIALS[':'] = function(ast, scope, parent) + assertCompile(#ast >= 3, 'expected at least 3 arguments', ast) + -- Compile object + local objectexpr = compile1(ast[2], scope, parent, {nval = 1})[1] + -- Compile method selector + local methodstring + local methodident = false + if type(ast[3]) == 'string' and isValidLuaIdentifier(ast[3]) then + methodident = true + methodstring = ast[3] + else + methodstring = tostring(compile1(ast[3], scope, parent, {nval = 1})[1]) + objectexpr = once(objectexpr, ast[2], scope, parent) + end + -- Compile arguments + local args = {} + for i = 4, #ast do + local subexprs = compile1(ast[i], scope, parent, { + nval = i ~= #ast and 1 or nil + }) + for j = 1, #subexprs do + args[#args + 1] = tostring(subexprs[j]) + end + end + local fstring + if methodident then + fstring = objectexpr.type == 'literal' + and '(%s):%s(%s)' + or '%s:%s(%s)' + else + -- Make object first argument + table.insert(args, 1, tostring(objectexpr)) + fstring = objectexpr.type == 'sym' + and '%s[%s](%s)' + or '(%s)[%s](%s)' + end + return expr(fstring:format( + tostring(objectexpr), + methodstring, + table.concat(args, ', ')), 'statement') +end + +local function defineArithmeticSpecial(name, zeroArity, unaryPrefix) + local paddedOp = ' ' .. name .. ' ' + SPECIALS[name] = function(ast, scope, parent) + local len = #ast + if len == 1 then + assertCompile(zeroArity ~= nil, 'Expected more than 0 arguments', ast) + return expr(zeroArity, 'literal') + else + local operands = {} + for i = 2, len do + local subexprs = compile1(ast[i], scope, parent, { + nval = (i == 1 and 1 or nil) + }) + for j = 1, #subexprs do + operands[#operands + 1] = tostring(subexprs[j]) + end + end + if #operands == 1 then + if unaryPrefix then + return '(' .. unaryPrefix .. paddedOp .. operands[1] .. ')' + else + return operands[1] + end + else + return '(' .. table.concat(operands, paddedOp) .. ')' + end + end + end +end + +defineArithmeticSpecial('+', '0') +defineArithmeticSpecial('..', "''") +defineArithmeticSpecial('^') +defineArithmeticSpecial('-', nil, '') +defineArithmeticSpecial('*', '1') +defineArithmeticSpecial('%') +defineArithmeticSpecial('/', nil, '1') +defineArithmeticSpecial('//', nil, '1') +defineArithmeticSpecial('or', 'false') +defineArithmeticSpecial('and', 'true') + +local function defineComparatorSpecial(name, realop) + local op = realop or name + SPECIALS[name] = function(ast, scope, parent) + local len = #ast + assertCompile(len > 2, 'expected at least two arguments', ast) + local lhs = compile1(ast[2], scope, parent, {nval = 1})[1] + local lastval = compile1(ast[3], scope, parent, {nval = 1})[1] + -- avoid double-eval by introducing locals for possible side-effects + if len > 3 then lastval = once(lastval, ast[3], scope, parent) end + local out = ('(%s %s %s)'): + format(tostring(lhs), op, tostring(lastval)) + if len > 3 then + for i = 4, len do -- variadic comparison + local nextval = once(compile1(ast[i], scope, parent, {nval = 1})[1], + ast[i], scope, parent) + out = (out .. " and (%s %s %s)"): + format(tostring(lastval), op, tostring(nextval)) + lastval = nextval + end + out = '(' .. out .. ')' + end + return out + end +end + +defineComparatorSpecial('>') +defineComparatorSpecial('<') +defineComparatorSpecial('>=') +defineComparatorSpecial('<=') +defineComparatorSpecial('=', '==') +defineComparatorSpecial('~=') + +local function defineUnarySpecial(op, realop) + SPECIALS[op] = function(ast, scope, parent) + assertCompile(#ast == 2, 'expected one argument', ast) + local tail = compile1(ast[2], scope, parent, {nval = 1}) + return (realop or op) .. tostring(tail[1]) + end +end + +defineUnarySpecial('not', 'not ') +defineUnarySpecial('#') + +-- Covert a macro function to a special form +local function macroToSpecial(mac) + return function(ast, scope, parent, opts) + local ok, transformed = pcall(mac, unpack(ast, 2)) + assertCompile(ok, transformed, ast) + return compile1(transformed, scope, parent, opts) + end +end + +local function compile(ast, options) + options = options or {} + local oldGlobals = allowedGlobals + allowedGlobals = options.allowedGlobals + if options.indent == nil then options.indent = ' ' end + local chunk = {} + local scope = options.scope or makeScope(GLOBAL_SCOPE) + local exprs = compile1(ast, scope, chunk, {tail = true}) + keepSideEffects(exprs, chunk, nil, ast) + allowedGlobals = oldGlobals + return flatten(chunk, options) +end + +local function compileStream(strm, options) + options = options or {} + local oldGlobals = allowedGlobals + allowedGlobals = options.allowedGlobals + if options.indent == nil then options.indent = ' ' end + local scope = options.scope or makeScope(GLOBAL_SCOPE) + local vals = {} + for ok, val in parser(strm, options.filename) do + if not ok then break end + vals[#vals + 1] = val + end + local chunk = {} + for i = 1, #vals do + local exprs = compile1(vals[i], scope, chunk, { + tail = i == #vals + }) + keepSideEffects(exprs, chunk, nil, vals[i]) + end + allowedGlobals = oldGlobals + return flatten(chunk, options) +end + +local function compileString(str, options) + local strm = stringStream(str) + return compileStream(strm, options) +end + +--- +--- Evaluation +--- + +-- A custom traceback function for Fennel that looks similar to +-- the Lua's debug.traceback. +-- Use with xpcall to produce fennel specific stacktraces. +local function traceback(msg, start) + local level = start or 2 -- Can be used to skip some frames + local lines = {} + if msg then + table.insert(lines, msg) + end + table.insert(lines, 'stack traceback:') + while true do + local info = debug.getinfo(level, "Sln") + if not info then break end + local line + if info.what == "C" then + if info.name then + line = (' [C]: in function \'%s\''):format(info.name) + else + line = ' [C]: in ?' + end + else + local remap = fennelSourcemap[info.source] + if remap and remap[info.currentline] then + -- And some global info + info.short_src = remap.short_src + local mapping = remap[info.currentline] + -- Overwrite info with values from the mapping (mapping is now just integer, + -- but may eventually be a table + info.currentline = mapping + end + if info.what == 'Lua' then + local n = info.name and ("'" .. info.name .. "'") or '?' + line = (' %s:%d: in function %s'):format(info.short_src, info.currentline, n) + elseif info.short_src == '(tail call)' then + line = ' (tail call)' + else + line = (' %s:%d: in main chunk'):format(info.short_src, info.currentline) + end + end + table.insert(lines, line) + level = level + 1 + end + return table.concat(lines, '\n') +end + +local function currentGlobalNames(env) + local names = {} + for k in pairs(env or _G) do table.insert(names, k) end + return names +end + +local function eval(str, options, ...) + options = options or {} + -- eval and dofile are considered "live" entry points, so we can assume + -- that the globals available at compile time are a reasonable allowed list + -- UNLESS there's a metatable on env, in which case we can't assume that + -- pairs will return all the effective globals; for instance openresty + -- sets up _G in such a way that all the globals are available thru + -- the __index meta method, but as far as pairs is concerned it's empty. + if options.allowedGlobals == nil and not getmetatable(options.env) then + options.allowedGlobals = currentGlobalNames(options.env) + end + local luaSource = compileString(str, options) + local loader = loadCode(luaSource, options.env, + options.filename and ('@' .. options.filename) or str) + return loader(...) +end + +local function dofile_fennel(filename, options, ...) + options = options or {sourcemap = true} + if options.allowedGlobals == nil then + options.allowedGlobals = currentGlobalNames(options.env) + end + local f = assert(io.open(filename, "rb")) + local source = f:read("*all"):gsub("^#![^\n]*\n", "") + f:close() + options.filename = options.filename or filename + return eval(source, options, ...) +end + +-- Implements a configurable repl +local function repl(options) + + local opts = options or {} + -- This would get set for us when calling eval, but we want to seed it + -- with a value that is persistent so it doesn't get reset on each eval. + if opts.allowedGlobals == nil then + options.allowedGlobals = currentGlobalNames(opts.env) + end + + local env = opts.env or setmetatable({}, { + __index = _ENV or _G + }) + + local function defaultReadChunk() + io.write('>> ') + io.flush() + local input = io.read() + return input and input .. '\n' + end + + local function defaultOnValues(xs) + io.write(table.concat(xs, '\t')) + io.write('\n') + end + + local function defaultOnError(errtype, err, luaSource) + if (errtype == 'Lua Compile') then + io.write('Bad code generated - likely a bug with the compiler:\n') + io.write('--- Generated Lua Start ---\n') + io.write(luaSource .. '\n') + io.write('--- Generated Lua End ---\n') + end + if (errtype == 'Runtime') then + io.write(traceback(err, 4)) + io.write('\n') + else + io.write(('%s error: %s\n'):format(errtype, tostring(err))) + end + end + + -- Read options + local readChunk = opts.readChunk or defaultReadChunk + local onValues = opts.onValues or defaultOnValues + local onError = opts.onError or defaultOnError + local pp = opts.pp or tostring + + -- Make parser + local bytestream, clearstream = granulate(readChunk) + local chars = {} + local read = parser(function() + local c = bytestream() + chars[#chars + 1] = c + return c + end) + + local envdbg = (opts.env or _G)["debug"] + -- if the environment doesn't support debug.getlocal you can't save locals + local saveLocals = opts.saveLocals ~= false and envdbg and envdbg.getlocal + local saveSource = table. + concat({"local ___i___ = 1", + "while true do", + " local name, value = debug.getlocal(1, ___i___)", + " if(name and name ~= \"___i___\") then", + " ___replLocals___[name] = value", + " ___i___ = ___i___ + 1", + " else break end end"}, "\n") + + local spliceSaveLocals = function(luaSource) + -- we do some source munging in order to save off locals from each chunk + -- and reintroduce them to the beginning of the next chunk, allowing + -- locals to work in the repl the way you'd expect them to. + env.___replLocals___ = env.___replLocals___ or {} + local splicedSource = {} + for line in luaSource:gmatch("([^\n]+)\n?") do + table.insert(splicedSource, line) + end + -- reintroduce locals from the previous time around + local bind = "local %s = ___replLocals___['%s']" + for name in pairs(env.___replLocals___) do + table.insert(splicedSource, 1, bind:format(name, name)) + end + -- save off new locals at the end - if safe to do so (i.e. last line is a return) + if (string.match(splicedSource[#splicedSource], "^ *return .*$")) then + if (#splicedSource > 1) then + table.insert(splicedSource, #splicedSource, saveSource) + end + end + return table.concat(splicedSource, "\n") + end + + local scope = makeScope(GLOBAL_SCOPE) + + -- REPL loop + while true do + chars = {} + local ok, parseok, x = pcall(read) + local srcstring = string.char(unpack(chars)) + if not ok then + onError('Parse', parseok) + clearstream() + else + if not parseok then break end -- eof + local compileOk, luaSource = pcall(compile, x, { + sourcemap = opts.sourcemap, + source = srcstring, + scope = scope, + }) + if not compileOk then + clearstream() + onError('Compile', luaSource) -- luaSource is error message in this case + else + if saveLocals then + luaSource = spliceSaveLocals(luaSource) + end + local luacompileok, loader = pcall(loadCode, luaSource, env) + if not luacompileok then + clearstream() + onError('Lua Compile', loader, luaSource) + else + local loadok, ret = xpcall(function () return {loader()} end, + function (runtimeErr) + onError('Runtime', runtimeErr) + end) + if loadok then + env._ = ret[1] + env.__ = ret + for i = 1, #ret do ret[i] = pp(ret[i]) end + onValues(ret) + end + end + end + end + end +end + +local macroLoaded = {} + +local module = { + parser = parser, + granulate = granulate, + stringStream = stringStream, + compile = compile, + compileString = compileString, + compileStream = compileStream, + compile1 = compile1, + mangle = globalMangling, + unmangle = globalUnmangling, + list = list, + sym = sym, + varg = varg, + scope = makeScope, + gensym = gensym, + eval = eval, + repl = repl, + dofile = dofile_fennel, + macroLoaded = macroLoaded, + path = "./?.fnl;./?/init.fnl", + traceback = traceback, + version = "0.1.1-dev", +} + +local function searchModule(modulename) + modulename = modulename:gsub("%.", "/") + for path in string.gmatch(module.path..";", "([^;]*);") do + local filename = path:gsub("%?", modulename) + local file = io.open(filename, "rb") + if(file) then + file:close() + return filename + end + end +end + +module.make_searcher = function(options) + return function(modulename) + local opts = {} + for k,v in pairs(options or {}) do opts[k] = v end + local filename = searchModule(modulename) + if filename then + return function(modname) + return dofile_fennel(filename, opts, modname) + end + end + end +end + +-- This will allow regular `require` to work with Fennel: +-- table.insert(package.loaders, fennel.searcher) +module.searcher = module.make_searcher() + +local function makeCompilerEnv(ast, scope, parent) + return setmetatable({ + -- State of compiler if needed + _SCOPE = scope, + _CHUNK = parent, + _AST = ast, + _IS_COMPILER = true, + _SPECIALS = SPECIALS, + _VARARG = VARARG, + -- Expose the module in the compiler + fennel = module, + -- Useful for macros and meta programming. All of Fennel can be accessed + -- via fennel.myfun, for example (fennel.eval "(print 1)"). + list = list, + sym = sym, + unpack = unpack, + gensym = function() return sym(gensym(scope)) end, + [globalMangling("list?")] = isList, + [globalMangling("multi-sym?")] = isMultiSym, + [globalMangling("sym?")] = isSym, + [globalMangling("table?")] = isTable, + [globalMangling("varg?")] = isVarg, + }, { __index = _ENV or _G }) +end + +local function macroGlobals(env, globals) + local allowed = {} + for k in pairs(env) do + local g = globalUnmangling(k) + table.insert(allowed, g) + end + if globals then + for _, k in pairs(globals) do + table.insert(allowed, k) + end + end + return allowed +end + +SPECIALS['require-macros'] = function(ast, scope, parent) + for i = 2, #ast do + local modname = ast[i] + local mod + if macroLoaded[modname] then + mod = macroLoaded[modname] + else + local filename = assertCompile(searchModule(modname), + modname .. " not found.", ast) + local env = makeCompilerEnv(ast, scope, parent) + mod = dofile_fennel(filename, { + env = env, + allowedGlobals = macroGlobals(env, currentGlobalNames()) + }) + macroLoaded[modname] = mod + end + for k, v in pairs(assertCompile(isTable(mod), 'expected ' .. modname .. + ' module to be table', ast)) do + if allowedGlobals then table.insert(allowedGlobals, k) end + scope.specials[k] = macroToSpecial(v) + end + end +end + +SPECIALS['eval-compiler'] = function(ast, scope, parent) + local oldFirst = ast[1] + ast[1] = sym('do') + local luaSource = compile(ast, { scope = makeScope(COMPILER_SCOPE) }) + ast[1] = oldFirst + local loader = loadCode(luaSource, makeCompilerEnv(ast, scope, parent)) + loader() +end + +-- Load standard macros +local stdmacros = [===[ +{"->" (fn [val ...] + (var x val) + (each [_ elt (ipairs [...])] + (table.insert elt 2 x) + (set x elt)) + x) + "->>" (fn [val ...] + (var x val) + (each [_ elt (pairs [...])] + (table.insert elt x) + (set x elt)) + x) + :doto (fn [val ...] + (let [name (gensym) + form (list (sym :let) [name val])] + (each [_ elt (pairs [...])] + (table.insert elt 2 name) + (table.insert form elt)) + (table.insert form name) + form)) + :when (fn [condition body1 ...] + (assert body1 "expected body") + (list (sym 'if') condition + (list (sym 'do') body1 ...))) + :partial (fn [f ...] + (let [body (list f ...)] + (table.insert body _VARARG) + (list (sym "fn") [_VARARG] body))) + :lambda (fn [...] + (let [args [...] + has-internal-name? (sym? (. args 1)) + arglist (if has-internal-name? (. args 2) (. args 1)) + arity-check-position (if has-internal-name? 3 2)] + (assert (> (# args) 1) "missing body expression") + (each [i a (ipairs arglist)] + (if (and (not (: (tostring a) :match "^?")) + (~= (tostring a) "...")) + (table.insert args arity-check-position + (list (sym "assert") + (list (sym "~=") (sym "nil") a) + (: "Missing argument %s on %s:%s" + :format (tostring a) + (or a.filename "unknown") + (or a.line "?")))))) + (list (sym "fn") (unpack args)))) +} +]===] +do + local env = makeCompilerEnv(nil, GLOBAL_SCOPE, {}) + for name, fn in pairs(eval(stdmacros, { + env = env, + allowedGlobals = macroGlobals(env, currentGlobalNames()), + })) do + SPECIALS[name] = macroToSpecial(fn) + end +end +SPECIALS['λ'] = SPECIALS['lambda'] + +return module diff --git a/lib/fun.lua b/lib/fun.lua new file mode 100644 index 0000000..137f3bf --- /dev/null +++ b/lib/fun.lua @@ -0,0 +1,1056 @@ +--- +--- Lua Fun - a high-performance functional programming library for LuaJIT +--- +--- Copyright (c) 2013-2016 Roman Tsisyk +--- +--- Distributed under the MIT/X11 License. See COPYING.md for more details. +--- + +local exports = {} +local methods = {} + +-- compatibility with Lua 5.1/5.2 +local unpack = rawget(table, "unpack") or unpack + +-------------------------------------------------------------------------------- +-- Tools +-------------------------------------------------------------------------------- + +local return_if_not_empty = function(state_x, ...) + if state_x == nil then + return nil + end + return ... +end + +local call_if_not_empty = function(fun, state_x, ...) + if state_x == nil then + return nil + end + return state_x, fun(...) +end + +local function deepcopy(orig) -- used by cycle() + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in next, orig, nil do + copy[deepcopy(orig_key)] = deepcopy(orig_value) + end + else + copy = orig + end + return copy +end + +local iterator_mt = { + -- usually called by for-in loop + __call = function(self, param, state) + return self.gen(param, state) + end; + __tostring = function(self) + return '' + end; + -- add all exported methods + __index = methods; +} + +local wrap = function(gen, param, state) + return setmetatable({ + gen = gen, + param = param, + state = state + }, iterator_mt), param, state +end +exports.wrap = wrap + +local unwrap = function(self) + return self.gen, self.param, self.state +end +methods.unwrap = unwrap + +-------------------------------------------------------------------------------- +-- Basic Functions +-------------------------------------------------------------------------------- + +local nil_gen = function(_param, _state) + return nil +end + +local string_gen = function(param, state) + local state = state + 1 + if state > #param then + return nil + end + local r = string.sub(param, state, state) + return state, r +end + +local pairs_gen = pairs({ a = 0 }) -- get the generating function from pairs +local map_gen = function(tab, key) + local value + local key, value = pairs_gen(tab, key) + return key, key, value +end + +local rawiter = function(obj, param, state) + assert(obj ~= nil, "invalid iterator") + if type(obj) == "table" then + local mt = getmetatable(obj); + if mt ~= nil then + if mt == iterator_mt then + return obj.gen, obj.param, obj.state + elseif mt.__ipairs ~= nil then + return mt.__ipairs(obj) + elseif mt.__pairs ~= nil then + return mt.__pairs(obj) + end + end + if #obj > 0 then + -- array + return ipairs(obj) + else + -- hash + return map_gen, obj, nil + end + elseif (type(obj) == "function") then + return obj, param, state + elseif (type(obj) == "string") then + if #obj == 0 then + return nil_gen, nil, nil + end + return string_gen, obj, 0 + end + error(string.format('object %s of type "%s" is not iterable', + obj, type(obj))) +end + +local iter = function(obj, param, state) + return wrap(rawiter(obj, param, state)) +end +exports.iter = iter + +local method0 = function(fun) + return function(self) + return fun(self.gen, self.param, self.state) + end +end + +local method1 = function(fun) + return function(self, arg1) + return fun(arg1, self.gen, self.param, self.state) + end +end + +local method2 = function(fun) + return function(self, arg1, arg2) + return fun(arg1, arg2, self.gen, self.param, self.state) + end +end + +local export0 = function(fun) + return function(gen, param, state) + return fun(rawiter(gen, param, state)) + end +end + +local export1 = function(fun) + return function(arg1, gen, param, state) + return fun(arg1, rawiter(gen, param, state)) + end +end + +local export2 = function(fun) + return function(arg1, arg2, gen, param, state) + return fun(arg1, arg2, rawiter(gen, param, state)) + end +end + +local each = function(fun, gen, param, state) + repeat + state = call_if_not_empty(fun, gen(param, state)) + until state == nil +end +methods.each = method1(each) +exports.each = export1(each) +methods.for_each = methods.each +exports.for_each = exports.each +methods.foreach = methods.each +exports.foreach = exports.each + +-------------------------------------------------------------------------------- +-- Generators +-------------------------------------------------------------------------------- + +local range_gen = function(param, state) + local stop, step = param[1], param[2] + local state = state + step + if state > stop then + return nil + end + return state, state +end + +local range_rev_gen = function(param, state) + local stop, step = param[1], param[2] + local state = state + step + if state < stop then + return nil + end + return state, state +end + +local range = function(start, stop, step) + if step == nil then + if stop == nil then + if start == 0 then + return nil_gen, nil, nil + end + stop = start + start = stop > 0 and 1 or -1 + end + step = start <= stop and 1 or -1 + end + + assert(type(start) == "number", "start must be a number") + assert(type(stop) == "number", "stop must be a number") + assert(type(step) == "number", "step must be a number") + assert(step ~= 0, "step must not be zero") + + if (step > 0) then + return wrap(range_gen, {stop, step}, start - step) + elseif (step < 0) then + return wrap(range_rev_gen, {stop, step}, start - step) + end +end +exports.range = range + +local duplicate_table_gen = function(param_x, state_x) + return state_x + 1, unpack(param_x) +end + +local duplicate_fun_gen = function(param_x, state_x) + return state_x + 1, param_x(state_x) +end + +local duplicate_gen = function(param_x, state_x) + return state_x + 1, param_x +end + +local duplicate = function(...) + if select('#', ...) <= 1 then + return wrap(duplicate_gen, select(1, ...), 0) + else + return wrap(duplicate_table_gen, {...}, 0) + end +end +exports.duplicate = duplicate +exports.replicate = duplicate +exports.xrepeat = duplicate + +local tabulate = function(fun) + assert(type(fun) == "function") + return wrap(duplicate_fun_gen, fun, 0) +end +exports.tabulate = tabulate + +local zeros = function() + return wrap(duplicate_gen, 0, 0) +end +exports.zeros = zeros + +local ones = function() + return wrap(duplicate_gen, 1, 0) +end +exports.ones = ones + +local rands_gen = function(param_x, _state_x) + return 0, math.random(param_x[1], param_x[2]) +end + +local rands_nil_gen = function(_param_x, _state_x) + return 0, math.random() +end + +local rands = function(n, m) + if n == nil and m == nil then + return wrap(rands_nil_gen, 0, 0) + end + assert(type(n) == "number", "invalid first arg to rands") + if m == nil then + m = n + n = 0 + else + assert(type(m) == "number", "invalid second arg to rands") + end + assert(n < m, "empty interval") + return wrap(rands_gen, {n, m - 1}, 0) +end +exports.rands = rands + +-------------------------------------------------------------------------------- +-- Slicing +-------------------------------------------------------------------------------- + +local nth = function(n, gen_x, param_x, state_x) + assert(n > 0, "invalid first argument to nth") + -- An optimization for arrays and strings + if gen_x == ipairs then + return param_x[n] + elseif gen_x == string_gen then + if n <= #param_x then + return string.sub(param_x, n, n) + else + return nil + end + end + for i=1,n-1,1 do + state_x = gen_x(param_x, state_x) + if state_x == nil then + return nil + end + end + return return_if_not_empty(gen_x(param_x, state_x)) +end +methods.nth = method1(nth) +exports.nth = export1(nth) + +local head_call = function(state, ...) + if state == nil then + error("head: iterator is empty") + end + return ... +end + +local head = function(gen, param, state) + return head_call(gen(param, state)) +end +methods.head = method0(head) +exports.head = export0(head) +exports.car = exports.head +methods.car = methods.head + +local tail = function(gen, param, state) + state = gen(param, state) + if state == nil then + return wrap(nil_gen, nil, nil) + end + return wrap(gen, param, state) +end +methods.tail = method0(tail) +exports.tail = export0(tail) +exports.cdr = exports.tail +methods.cdr = methods.tail + +local take_n_gen_x = function(i, state_x, ...) + if state_x == nil then + return nil + end + return {i, state_x}, ... +end + +local take_n_gen = function(param, state) + local n, gen_x, param_x = param[1], param[2], param[3] + local i, state_x = state[1], state[2] + if i >= n then + return nil + end + return take_n_gen_x(i + 1, gen_x(param_x, state_x)) +end + +local take_n = function(n, gen, param, state) + assert(n >= 0, "invalid first argument to take_n") + return wrap(take_n_gen, {n, gen, param}, {0, state}) +end +methods.take_n = method1(take_n) +exports.take_n = export1(take_n) + +local take_while_gen_x = function(fun, state_x, ...) + if state_x == nil or not fun(...) then + return nil + end + return state_x, ... +end + +local take_while_gen = function(param, state_x) + local fun, gen_x, param_x = param[1], param[2], param[3] + return take_while_gen_x(fun, gen_x(param_x, state_x)) +end + +local take_while = function(fun, gen, param, state) + assert(type(fun) == "function", "invalid first argument to take_while") + return wrap(take_while_gen, {fun, gen, param}, state) +end +methods.take_while = method1(take_while) +exports.take_while = export1(take_while) + +local take = function(n_or_fun, gen, param, state) + if type(n_or_fun) == "number" then + return take_n(n_or_fun, gen, param, state) + else + return take_while(n_or_fun, gen, param, state) + end +end +methods.take = method1(take) +exports.take = export1(take) + +local drop_n = function(n, gen, param, state) + assert(n >= 0, "invalid first argument to drop_n") + local i + for i=1,n,1 do + state = gen(param, state) + if state == nil then + return wrap(nil_gen, nil, nil) + end + end + return wrap(gen, param, state) +end +methods.drop_n = method1(drop_n) +exports.drop_n = export1(drop_n) + +local drop_while_x = function(fun, state_x, ...) + if state_x == nil or not fun(...) then + return state_x, false + end + return state_x, true, ... +end + +local drop_while = function(fun, gen_x, param_x, state_x) + assert(type(fun) == "function", "invalid first argument to drop_while") + local cont, state_x_prev + repeat + state_x_prev = deepcopy(state_x) + state_x, cont = drop_while_x(fun, gen_x(param_x, state_x)) + until not cont + if state_x == nil then + return wrap(nil_gen, nil, nil) + end + return wrap(gen_x, param_x, state_x_prev) +end +methods.drop_while = method1(drop_while) +exports.drop_while = export1(drop_while) + +local drop = function(n_or_fun, gen_x, param_x, state_x) + if type(n_or_fun) == "number" then + return drop_n(n_or_fun, gen_x, param_x, state_x) + else + return drop_while(n_or_fun, gen_x, param_x, state_x) + end +end +methods.drop = method1(drop) +exports.drop = export1(drop) + +local split = function(n_or_fun, gen_x, param_x, state_x) + return take(n_or_fun, gen_x, param_x, state_x), + drop(n_or_fun, gen_x, param_x, state_x) +end +methods.split = method1(split) +exports.split = export1(split) +methods.split_at = methods.split +exports.split_at = exports.split +methods.span = methods.split +exports.span = exports.split + +-------------------------------------------------------------------------------- +-- Indexing +-------------------------------------------------------------------------------- + +local index = function(x, gen, param, state) + local i = 1 + for _k, r in gen, param, state do + if r == x then + return i + end + i = i + 1 + end + return nil +end +methods.index = method1(index) +exports.index = export1(index) +methods.index_of = methods.index +exports.index_of = exports.index +methods.elem_index = methods.index +exports.elem_index = exports.index + +local indexes_gen = function(param, state) + local x, gen_x, param_x = param[1], param[2], param[3] + local i, state_x = state[1], state[2] + local r + while true do + state_x, r = gen_x(param_x, state_x) + if state_x == nil then + return nil + end + i = i + 1 + if r == x then + return {i, state_x}, i + end + end +end + +local indexes = function(x, gen, param, state) + return wrap(indexes_gen, {x, gen, param}, {0, state}) +end +methods.indexes = method1(indexes) +exports.indexes = export1(indexes) +methods.elem_indexes = methods.indexes +exports.elem_indexes = exports.indexes +methods.indices = methods.indexes +exports.indices = exports.indexes +methods.elem_indices = methods.indexes +exports.elem_indices = exports.indexes + +-------------------------------------------------------------------------------- +-- Filtering +-------------------------------------------------------------------------------- + +local filter1_gen = function(fun, gen_x, param_x, state_x, a) + while true do + if state_x == nil or fun(a) then break; end + state_x, a = gen_x(param_x, state_x) + end + return state_x, a +end + +-- call each other +local filterm_gen +local filterm_gen_shrink = function(fun, gen_x, param_x, state_x) + return filterm_gen(fun, gen_x, param_x, gen_x(param_x, state_x)) +end + +filterm_gen = function(fun, gen_x, param_x, state_x, ...) + if state_x == nil then + return nil + end + if fun(...) then + return state_x, ... + end + return filterm_gen_shrink(fun, gen_x, param_x, state_x) +end + +local filter_detect = function(fun, gen_x, param_x, state_x, ...) + if select('#', ...) < 2 then + return filter1_gen(fun, gen_x, param_x, state_x, ...) + else + return filterm_gen(fun, gen_x, param_x, state_x, ...) + end +end + +local filter_gen = function(param, state_x) + local fun, gen_x, param_x = param[1], param[2], param[3] + return filter_detect(fun, gen_x, param_x, gen_x(param_x, state_x)) +end + +local filter = function(fun, gen, param, state) + return wrap(filter_gen, {fun, gen, param}, state) +end +methods.filter = method1(filter) +exports.filter = export1(filter) +methods.remove_if = methods.filter +exports.remove_if = exports.filter + +local grep = function(fun_or_regexp, gen, param, state) + local fun = fun_or_regexp + if type(fun_or_regexp) == "string" then + fun = function(x) return string.find(x, fun_or_regexp) ~= nil end + end + return filter(fun, gen, param, state) +end +methods.grep = method1(grep) +exports.grep = export1(grep) + +local partition = function(fun, gen, param, state) + local neg_fun = function(...) + return not fun(...) + end + return filter(fun, gen, param, state), + filter(neg_fun, gen, param, state) +end +methods.partition = method1(partition) +exports.partition = export1(partition) + +-------------------------------------------------------------------------------- +-- Reducing +-------------------------------------------------------------------------------- + +local foldl_call = function(fun, start, state, ...) + if state == nil then + return nil, start + end + return state, fun(start, ...) +end + +local foldl = function(fun, start, gen_x, param_x, state_x) + while true do + state_x, start = foldl_call(fun, start, gen_x(param_x, state_x)) + if state_x == nil then + break; + end + end + return start +end +methods.foldl = method2(foldl) +exports.foldl = export2(foldl) +methods.reduce = methods.foldl +exports.reduce = exports.foldl + +local length = function(gen, param, state) + if gen == ipairs or gen == string_gen then + return #param + end + local len = 0 + repeat + state = gen(param, state) + len = len + 1 + until state == nil + return len - 1 +end +methods.length = method0(length) +exports.length = export0(length) + +local is_null = function(gen, param, state) + return gen(param, deepcopy(state)) == nil +end +methods.is_null = method0(is_null) +exports.is_null = export0(is_null) + +local is_prefix_of = function(iter_x, iter_y) + local gen_x, param_x, state_x = iter(iter_x) + local gen_y, param_y, state_y = iter(iter_y) + + local r_x, r_y + for i=1,10,1 do + state_x, r_x = gen_x(param_x, state_x) + state_y, r_y = gen_y(param_y, state_y) + if state_x == nil then + return true + end + if state_y == nil or r_x ~= r_y then + return false + end + end +end +methods.is_prefix_of = is_prefix_of +exports.is_prefix_of = is_prefix_of + +local all = function(fun, gen_x, param_x, state_x) + local r + repeat + state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x)) + until state_x == nil or not r + return state_x == nil +end +methods.all = method1(all) +exports.all = export1(all) +methods.every = methods.all +exports.every = exports.all + +local any = function(fun, gen_x, param_x, state_x) + local r + repeat + state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x)) + until state_x == nil or r + return not not r +end +methods.any = method1(any) +exports.any = export1(any) +methods.some = methods.any +exports.some = exports.any + +local sum = function(gen, param, state) + local s = 0 + local r = 0 + repeat + s = s + r + state, r = gen(param, state) + until state == nil + return s +end +methods.sum = method0(sum) +exports.sum = export0(sum) + +local product = function(gen, param, state) + local p = 1 + local r = 1 + repeat + p = p * r + state, r = gen(param, state) + until state == nil + return p +end +methods.product = method0(product) +exports.product = export0(product) + +local min_cmp = function(m, n) + if n < m then return n else return m end +end + +local max_cmp = function(m, n) + if n > m then return n else return m end +end + +local min = function(gen, param, state) + local state, m = gen(param, state) + if state == nil then + error("min: iterator is empty") + end + + local cmp + if type(m) == "number" then + -- An optimization: use math.min for numbers + cmp = math.min + else + cmp = min_cmp + end + + for _, r in gen, param, state do + m = cmp(m, r) + end + return m +end +methods.min = method0(min) +exports.min = export0(min) +methods.minimum = methods.min +exports.minimum = exports.min + +local min_by = function(cmp, gen_x, param_x, state_x) + local state_x, m = gen_x(param_x, state_x) + if state_x == nil then + error("min: iterator is empty") + end + + for _, r in gen_x, param_x, state_x do + m = cmp(m, r) + end + return m +end +methods.min_by = method1(min_by) +exports.min_by = export1(min_by) +methods.minimum_by = methods.min_by +exports.minimum_by = exports.min_by + +local max = function(gen_x, param_x, state_x) + local state_x, m = gen_x(param_x, state_x) + if state_x == nil then + error("max: iterator is empty") + end + + local cmp + if type(m) == "number" then + -- An optimization: use math.max for numbers + cmp = math.max + else + cmp = max_cmp + end + + for _, r in gen_x, param_x, state_x do + m = cmp(m, r) + end + return m +end +methods.max = method0(max) +exports.max = export0(max) +methods.maximum = methods.max +exports.maximum = exports.max + +local max_by = function(cmp, gen_x, param_x, state_x) + local state_x, m = gen_x(param_x, state_x) + if state_x == nil then + error("max: iterator is empty") + end + + for _, r in gen_x, param_x, state_x do + m = cmp(m, r) + end + return m +end +methods.max_by = method1(max_by) +exports.max_by = export1(max_by) +methods.maximum_by = methods.maximum_by +exports.maximum_by = exports.maximum_by + +local totable = function(gen_x, param_x, state_x) + local tab, key, val = {} + while true do + state_x, val = gen_x(param_x, state_x) + if state_x == nil then + break + end + table.insert(tab, val) + end + return tab +end +methods.totable = method0(totable) +exports.totable = export0(totable) + +local tomap = function(gen_x, param_x, state_x) + local tab, key, val = {} + while true do + state_x, key, val = gen_x(param_x, state_x) + if state_x == nil then + break + end + tab[key] = val + end + return tab +end +methods.tomap = method0(tomap) +exports.tomap = export0(tomap) + +-------------------------------------------------------------------------------- +-- Transformations +-------------------------------------------------------------------------------- + +local map_gen = function(param, state) + local gen_x, param_x, fun = param[1], param[2], param[3] + return call_if_not_empty(fun, gen_x(param_x, state)) +end + +local map = function(fun, gen, param, state) + return wrap(map_gen, {gen, param, fun}, state) +end +methods.map = method1(map) +exports.map = export1(map) + +local enumerate_gen_call = function(state, i, state_x, ...) + if state_x == nil then + return nil + end + return {i + 1, state_x}, i, ... +end + +local enumerate_gen = function(param, state) + local gen_x, param_x = param[1], param[2] + local i, state_x = state[1], state[2] + return enumerate_gen_call(state, i, gen_x(param_x, state_x)) +end + +local enumerate = function(gen, param, state) + return wrap(enumerate_gen, {gen, param}, {1, state}) +end +methods.enumerate = method0(enumerate) +exports.enumerate = export0(enumerate) + +local intersperse_call = function(i, state_x, ...) + if state_x == nil then + return nil + end + return {i + 1, state_x}, ... +end + +local intersperse_gen = function(param, state) + local x, gen_x, param_x = param[1], param[2], param[3] + local i, state_x = state[1], state[2] + if i % 2 == 1 then + return {i + 1, state_x}, x + else + return intersperse_call(i, gen_x(param_x, state_x)) + end +end + +-- TODO: interperse must not add x to the tail +local intersperse = function(x, gen, param, state) + return wrap(intersperse_gen, {x, gen, param}, {0, state}) +end +methods.intersperse = method1(intersperse) +exports.intersperse = export1(intersperse) + +-------------------------------------------------------------------------------- +-- Compositions +-------------------------------------------------------------------------------- + +local function zip_gen_r(param, state, state_new, ...) + if #state_new == #param / 2 then + return state_new, ... + end + + local i = #state_new + 1 + local gen_x, param_x = param[2 * i - 1], param[2 * i] + local state_x, r = gen_x(param_x, state[i]) + if state_x == nil then + return nil + end + table.insert(state_new, state_x) + return zip_gen_r(param, state, state_new, r, ...) +end + +local zip_gen = function(param, state) + return zip_gen_r(param, state, {}) +end + +-- A special hack for zip/chain to skip last two state, if a wrapped iterator +-- has been passed +local numargs = function(...) + local n = select('#', ...) + if n >= 3 then + -- Fix last argument + local it = select(n - 2, ...) + if type(it) == 'table' and getmetatable(it) == iterator_mt and + it.param == select(n - 1, ...) and it.state == select(n, ...) then + return n - 2 + end + end + return n +end + +local zip = function(...) + local n = numargs(...) + if n == 0 then + return wrap(nil_gen, nil, nil) + end + local param = { [2 * n] = 0 } + local state = { [n] = 0 } + + local i, gen_x, param_x, state_x + for i=1,n,1 do + local it = select(n - i + 1, ...) + gen_x, param_x, state_x = rawiter(it) + param[2 * i - 1] = gen_x + param[2 * i] = param_x + state[i] = state_x + end + + return wrap(zip_gen, param, state) +end +methods.zip = zip +exports.zip = zip + +local cycle_gen_call = function(param, state_x, ...) + if state_x == nil then + local gen_x, param_x, state_x0 = param[1], param[2], param[3] + return gen_x(param_x, deepcopy(state_x0)) + end + return state_x, ... +end + +local cycle_gen = function(param, state_x) + local gen_x, param_x, state_x0 = param[1], param[2], param[3] + return cycle_gen_call(param, gen_x(param_x, state_x)) +end + +local cycle = function(gen, param, state) + return wrap(cycle_gen, {gen, param, state}, deepcopy(state)) +end +methods.cycle = method0(cycle) +exports.cycle = export0(cycle) + +-- call each other +local chain_gen_r1 +local chain_gen_r2 = function(param, state, state_x, ...) + if state_x == nil then + local i = state[1] + i = i + 1 + if i > #param / 3 then + return nil + end + local state_x = param[3 * i] + return chain_gen_r1(param, {i, state_x}) + end + return {state[1], state_x}, ... +end + +chain_gen_r1 = function(param, state) + local i, state_x = state[1], state[2] + local gen_x, param_x = param[3 * i - 2], param[3 * i - 1] + return chain_gen_r2(param, state, gen_x(param_x, state[2])) +end + +local chain = function(...) + local n = numargs(...) + if n == 0 then + return wrap(nil_gen, nil, nil) + end + + local param = { [3 * n] = 0 } + local i, gen_x, param_x, state_x + for i=1,n,1 do + local elem = select(i, ...) + gen_x, param_x, state_x = iter(elem) + param[3 * i - 2] = gen_x + param[3 * i - 1] = param_x + param[3 * i] = state_x + end + + return wrap(chain_gen_r1, param, {1, param[3]}) +end +methods.chain = chain +exports.chain = chain + +-------------------------------------------------------------------------------- +-- Operators +-------------------------------------------------------------------------------- + +local operator = { + ---------------------------------------------------------------------------- + -- Comparison operators + ---------------------------------------------------------------------------- + lt = function(a, b) return a < b end, + le = function(a, b) return a <= b end, + eq = function(a, b) return a == b end, + ne = function(a, b) return a ~= b end, + ge = function(a, b) return a >= b end, + gt = function(a, b) return a > b end, + + ---------------------------------------------------------------------------- + -- Arithmetic operators + ---------------------------------------------------------------------------- + add = function(a, b) return a + b end, + div = function(a, b) return a / b end, + floordiv = function(a, b) return math.floor(a/b) end, + intdiv = function(a, b) + local q = a / b + if a >= 0 then return math.floor(q) else return math.ceil(q) end + end, + mod = function(a, b) return a % b end, + mul = function(a, b) return a * b end, + neq = function(a) return -a end, + unm = function(a) return -a end, -- an alias + pow = function(a, b) return a ^ b end, + sub = function(a, b) return a - b end, + truediv = function(a, b) return a / b end, + + ---------------------------------------------------------------------------- + -- String operators + ---------------------------------------------------------------------------- + concat = function(a, b) return a..b end, + len = function(a) return #a end, + length = function(a) return #a end, -- an alias + + ---------------------------------------------------------------------------- + -- Logical operators + ---------------------------------------------------------------------------- + land = function(a, b) return a and b end, + lor = function(a, b) return a or b end, + lnot = function(a) return not a end, + truth = function(a) return not not a end, +} +exports.operator = operator +methods.operator = operator +exports.op = operator +methods.op = operator + +-------------------------------------------------------------------------------- +-- module definitions +-------------------------------------------------------------------------------- + +-- a special syntax sugar to export all functions to the global table +setmetatable(exports, { + __call = function(t, override) + for k, v in pairs(t) do + if _G[k] ~= nil then + local msg = 'function ' .. k .. ' already exists in global scope.' + if override then + _G[k] = v + print('WARNING: ' .. msg .. ' Overwritten.') + else + print('NOTICE: ' .. msg .. ' Skipped.') + end + else + _G[k] = v + end + end + end, +}) + +return exports diff --git a/lib/keys.fnl b/lib/keys.fnl new file mode 100644 index 0000000..0ca7b7e --- /dev/null +++ b/lib/keys.fnl @@ -0,0 +1,35 @@ +;;; lib/keys.fnl --- Key configuration utilities + +;;; Commentary: +;; Comments + +;;; Code: + +(var keys {}) + +(local awful (require :awful)) +(local fun (require :lib.fun)) + +(local modifiers + {:mod "Mod4" + :alt "Mod1" + :super "Mod4" + :shift "Shift" + :ctrl "Control"}) + +;;; +;; Functions + +(fn map-mods [mods] + (->> mods + (fun.map (partial . modifiers)) + (fun.totable))) + +(fn keys.key [mods kc fun] + (awful.key (map-mods mods) kc fun)) + +(fn keys.button [mods bc fun] + (awful.button (map-mods mods) bc fun)) + +keys +;;; lib/keys.fnl ends here diff --git a/lib/std.fnl b/lib/std.fnl new file mode 100644 index 0000000..97bf29c --- /dev/null +++ b/lib/std.fnl @@ -0,0 +1,17 @@ +;;; std.fnl --- A small standard library + +;;; Commentary: +;; Porting a few functions that I'm used to from other Lisps. + +;;; Code: + +(var std {}) + +;;; +;; Functions + +(fn std.zero? [n] + (= 0 n)) + +std +;;; std.fnl ends here diff --git a/module/decorate.fnl b/module/decorate.fnl new file mode 100644 index 0000000..898b7d2 --- /dev/null +++ b/module/decorate.fnl @@ -0,0 +1,74 @@ +;;; decorate.fnl --- Client decorations + +;;; Commentary: +;; Comments + +;;; Code: + +(local awful (require :awful)) +(local beautiful (require :beautiful)) +(local gears (require :gears)) +(local wibox (require :wibox)) + +(local std (require :lib.std)) + +;; TODO Properly abstract these into a library + +(local fun (require :lib.fun)) +(local modifiers + {:mod "Mod4" + :alt "Mod1" + :super "Mod4" + :shift "Shift" + :ctrl "Control"}) + +(fn map-mods [mods] + (->> mods + (fun.map (partial . modifiers)) + (fun.totable))) + +(fn button [mods bc fun] + (awful.button (map-mods mods) bc fun)) + +;;; +;; Functions + +(fn mouse-button [c bc cmd] + (button [] bc (lambda [] + (: c :emit_signal + :request::activate :titlebar {:raise true}) + ((. awful.mouse.client cmd) c)))) + +(fn make-titlebar [c side -size] + (let [size (or -size beautiful.titlebar_size 10)] + (awful.titlebar c {:size size :position side}))) + +(fn setup-empty-titlebar [bar buttons] + ;; Placeholder layouts + (: bar :setup + {1 {:layout wibox.layout.fixed.vertical} + 2 {:buttons buttons + :layout wibox.layout.flex.vertical} + 3 {:layout (wibox.layout.fixed.vertical)} + :layout wibox.layout.align.vertical})) + +(fn titlebar-hook [c] + (let [buttons (gears.table.join + (mouse-button c 1 :move) + (mouse-button c 3 :resize)) + mainbar (make-titlebar c beautiful.titlebar_position)] + (setup-empty-titlebar mainbar buttons) + + (when (and (not (std.zero? beautiful.titlebar_border_width)) + beautiful.use_titlebars_for_borders) + (let [size beautiful.titlebar_border_width] + (each [_ side (ipairs [:top :right :bottom])] + (setup-empty-titlebar (make-titlebar c side size) buttons)))))) + +;;; +;; Processing + +(_G.client.connect_signal :request::titlebars titlebar-hook) + +{} +;;; decorate.fnl ends here diff --git a/module/sidebar.fnl b/module/sidebar.fnl new file mode 100644 index 0000000..01e71c4 --- /dev/null +++ b/module/sidebar.fnl @@ -0,0 +1,74 @@ +;;; module/sidebar.fnl --- Informational sidebar + +;;; Code: + +(local awful (require :awful)) +(local beautiful (require :beautiful)) +(local wibox (require :wibox)) + +(local sb-clock (wibox.widget.textclock "%H\n%M")) +(local sb-systray (doto (wibox.widget.systray) + (: :set_base_size 24))) +(local fill-width (doto (wibox.layout.fixed.horizontal) + (: :fill_space true) + (: :set_spacing 10))) + +(local bounding (wibox.container.margin + (wibox.container.place + (wibox.container.margin + nil + 10 10 10 10 beautiful.sidebar_subbox)) + 1 1 1 1 beautiful.sidebar_bg)) + +;;; +;; Functions + +(fn draw-sidebar [s] + (set s.sb-tag (awful.widget.taglist + {:screen s + :filter awful.widget.taglist.filter.selected + :style {:font (.. beautiful.font " Bold 10")} + ;; :widget_template + ;; {1 + ;; {:id "index_role" + ;; :widget wibox.widget.textbox} + ;; :widget wibox.container.margin + ;; :margins 5 + ;; :create_callback + ;; (lambda [self, c3, index, objects])} + :layout wibox.layout.fixed.vertical + } + )) + (set s.sb-tasks (awful.widget.tasklist + {:screen s + :filter awful.widget.tasklist.filter.currenttags + :style {:disable_task_name true} + })) + + (set s.sb (awful.wibar {:position beautiful.sidebar_position + :width beautiful.sidebar_width + :screen s})) + + (: s.sb :setup + {:layout wibox.layout.align.vertical + 1 {:layout wibox.layout.fixed.vertical + 1 {1 s.sb-tag + :halign "center" + :layout (wibox.container.margin nil 10 10 10 10)} + 2 s.sb-tasks} + 2 {:layout wibox.layout.fixed.vertical} + 3 {:layout wibox.layout.fixed.vertical + 1 sb-systray + 2 {1 sb-clock + :valign "center" + :halign "center" + :layout bounding}}})) + +;;; +;; Configuration + +;; (awful.screen.connect_for_each_screen +;; draw-sidebar) + +{} +;;; module/sidebar.fnl ends here diff --git a/rc.lua b/rc.lua new file mode 100644 index 0000000..fffbd06 --- /dev/null +++ b/rc.lua @@ -0,0 +1,4 @@ +local fennel = require("lib.fennel") +fennel.path = fennel.path .. ";.config/awesome/?.fnl" +table.insert(package.loaders or package.searchers, fennel.searcher) +require("cfg") -- .config/awesome/cfg.fnl