#!/usr/bin/lua

-- Metrics web server

-- Copyright (c) 2016 Jeff Schornick <jeff@schornick.org>
-- Copyright (c) 2015 Kevin Lyda
-- Licensed under the Apache License, Version 2.0

socket = require("socket")

-- Parsing

-- Break a string into its whitespace-separated words.
-- s: the string to split.
-- Returns a new array of the non-whitespace runs in order of appearance
-- (empty when s contains no non-whitespace characters).
function space_split(s)
  local words = {}
  local count = 0
  for word in s:gmatch("%S+") do
    count = count + 1
    words[count] = word
  end
  return words
end

-- Read an entire file in binary mode.
-- filename: path of the file to read.
-- Returns the file's full contents, or the empty string when the file
-- cannot be opened.
function get_contents(filename)
  local handle = io.open(filename, "rb")
  if not handle then
    return ""
  end

  local data = handle:read("*a")
  handle:close()
  return data
end

-- Metric printing

-- Characters that must be backslash-escaped inside a quoted label value
-- in the Prometheus text exposition format.
local LABEL_VALUE_ESCAPES = {
  ["\\"] = "\\\\",
  ['"'] = '\\"',
  ["\n"] = "\\n",
}

-- Emit one sample line through the global `output` function.
-- metric: metric name.
-- labels: optional table mapping label names to label values; when nil
--         (or false) the sample is written without a label block.
-- value:  the sample value, formatted with %s.
-- Label values are converted with tostring and have backslash, double
-- quote and newline escaped so they cannot break the quoted label syntax.
-- Labels are visited with pairs, so their order in the line is unspecified.
function print_metric(metric, labels, value)
  local label_string = ""
  if labels then
    local parts = {}
    for label, label_value in pairs(labels) do
      -- Parentheses drop gsub's second result (the substitution count).
      local escaped = (tostring(label_value):gsub('[\\"\n]', LABEL_VALUE_ESCAPES))
      parts[#parts + 1] = label .. '="' .. escaped .. '"'
    end
    label_string = "{" .. table.concat(parts, ",") .. "}"
  end
  output(string.format("%s%s %s", metric, label_string, value))
end

-- Declare a metric and obtain a function for emitting its samples.
-- Writes the "# TYPE <name> <mtype>" line through `output` immediately.
-- name:   metric name.
-- mtype:  metric type string placed in the TYPE line.
-- labels: optional labels for an initial sample.
-- value:  optional initial sample value; when truthy, one sample is
--         emitted right away with the given labels.
-- Returns a function(labels, value) that emits a sample for this metric
-- via print_metric.
function metric(name, mtype, labels, value)
  output("# TYPE " .. name .. " " .. mtype)

  local function emit(sample_labels, sample_value)
    print_metric(name, sample_labels, sample_value)
  end

  if not value then
    return emit
  end
  emit(labels, value)
  return emit
end

-- Run a collector's scrape function under pcall and time it.
-- collector: table whose `scrape` field is called with no arguments.
-- Returns two values: the elapsed time as measured by socket.gettime(),
-- and 1 if scrape completed without raising an error, 0 otherwise.
-- A raised error is reported on stderr rather than stdout: when no port is
-- configured the metrics themselves are written with print, so an error
-- message on stdout would end up inside the metrics output.
function timed_scrape(collector)
  local start_time = socket.gettime()
  local ok, err = pcall(collector.scrape)
  if not ok then
    io.stderr:write(tostring(err), "\n")
  end
  return (socket.gettime() - start_time), (ok and 1 or 0)
end

-- Scrape every requested collector and report per-collector statistics.
-- collectors: table whose values are collector names.
-- Names without an entry in the global col_mods table are skipped. For each
-- known collector the scrape duration and success flag are emitted as
-- gauges labelled with collector=<name>.
function run_all_collectors(collectors)
  local report_duration = metric("node_scrape_collector_duration_seconds", "gauge")
  local report_success = metric("node_scrape_collector_success", "gauge")

  for _, cname in pairs(collectors) do
    if col_mods[cname] ~= nil then
      local elapsed, ok_flag = timed_scrape(col_mods[cname])
      local collector_labels = { collector = cname }
      report_duration(collector_labels, elapsed)
      report_success(collector_labels, ok_flag)
    end
  end
end

-- Web server-specific functions

-- Send the status line and headers of a successful metrics response
-- through `output`. The text ends in a bare "\r"; the socket-backed
-- `output` installed by the main loop appends "\n" to whatever it sends.
function http_ok_header()
  local response_head = "HTTP/1.0 200 OK\r\nServer: lua-metrics\r\nContent-Type: text/plain; version=0.0.4\r\n\r"
  output(response_head)
end

-- Send a complete 404 response (status line, headers and a short
-- plain-text error body) through `output`.
function http_not_found()
  local response = "HTTP/1.0 404 Not Found\r\nServer: lua-metrics\r\nContent-Type: text/plain\r\n\r\nERROR: File Not Found."
  output(response)
end

-- Answer a single HTTP request and close the connection.
-- request: the request line, e.g. "GET /metrics?collect[]=cpu HTTP/1.1".
-- A GET whose path starts with /metrics gets a 200 header followed by the
-- output of the collectors named in "collect...=<name>" query parameters,
-- or of every collector in col_names when none are named. Any other
-- request line gets a 404 response.
-- The global `client` socket is closed in either case.
-- Returns true.
function serve(request)
  local query = request:match("^GET /metrics%??([^ ]*) HTTP/1%.[01]$")
  if query ~= nil then
    http_ok_header()
    local wanted = {}
    for cname in query:gmatch("collect[^=]*=([^&]+)") do
      table.insert(wanted, cname)
    end
    -- No explicit selection means "run everything".
    if #wanted == 0 then
      wanted = col_names
    end
    run_all_collectors(wanted)
  else
    http_not_found()
  end
  client:close()
  return true
end

-- Main program

-- Scan the command line for -p/--port and -b/--bind. The argument that
-- follows each flag is stored in the global `port` or `bind`; if a flag is
-- repeated, the last occurrence wins.
for idx, flag in ipairs(arg) do
  local operand = arg[idx + 1]
  if flag == "-p" or flag == "--port" then
    port = operand
  elseif flag == "-b" or flag == "--bind" then
    bind = operand
  end
end

-- Discover and load the collector modules.
-- col_mods maps a collector name to its loaded module; col_names lists the
-- names in the order `ls` printed them. A collector's name is its file
-- name without the .lua extension.
col_mods = {}
col_names = {}
ls_fd = io.popen("ls -1 /usr/lib/lua/prometheus-collectors/*.lua")
for path in ls_fd:lines() do
  local cname = path:match("([^/]+)%.lua$")
  col_mods[cname] = require('prometheus-collectors.'..cname)
  table.insert(col_names, cname)
end
ls_fd:close()

-- With a port configured, run a sequential HTTP server: accept one
-- connection at a time, read the request line and hand it to serve().
-- Without a port, run every collector once and print the metrics to stdout.
if port then
  -- Fall back to "*" (all local interfaces) when no bind address was given,
  -- so that passing only -p/--port still works.
  server = assert(socket.bind(bind or "*", port))

  while true do
    client = server:accept()
    client:settimeout(60)
    local request, err = client:receive()

    if not err then
      -- Metrics are written line by line straight to the client socket.
      output = function (str) client:send(str.."\n") end
      if not serve(request) then
        break
      end
    else
      -- serve() closes the socket on the normal path; when the request
      -- line could not be read, nothing else would, so close it here
      -- instead of leaving the connection open.
      client:close()
    end
  end
else
  output = print
  run_all_collectors(col_names)
end