Implement basic Teal type fixer (#50) #51

Merged
Aire-One merged 6 commits from feat/#50 into master 2022-12-21 20:28:14 +01:00
8 changed files with 216 additions and 64 deletions

View File

@ -1,42 +1,56 @@
local List = require "pl.List" local List = require "pl.List"
local Type_Info = require "entity.Type_Info" local Type_Info = require "entity.Type_Info"
local Variable_Info = require "entity.Variable_Info"
local record Function_Info local record Function_Info
metamethod __call: function(Function_Info): Function_Info metamethod __call: function(
Function_Info,
name: string,
parameters: List<Variable_Info.Variable_Info>,
return_types: List<Type_Info.Type_Info>
): Function_Info
Function_Info: Function_Info Function_Info: Function_Info
record Parameter
name: string
types: List<Type_Info.Type_Info>
end
name: string name: string
parameters: List<Parameter> parameters: List<Variable_Info.Variable_Info>
return_types: List<string> return_types: List<Type_Info.Type_Info>
append_parameter: function(self: Function_Info, name: string, type: string) append_parameter: function(self: Function_Info, name: string, type: string)
append_return_type: function(self: Function_Info, return_type: string) append_return_type: function(self: Function_Info, return_type: string)
fixup: function(Function_Info)
end end
local __Function_Info: metatable<Function_Info> = { local __Function_Info: metatable<Function_Info> = {
__call = function(_self: Function_Info): Function_Info __call = function(
return { _self: Function_Info,
name = "", name: string,
parameters = List(), parameters: List<Variable_Info.Variable_Info>,
return_types = List(), return_types: List<Type_Info.Type_Info>): Function_Info
} return {
name = name or "",
parameters = parameters or (List() as List<Variable_Info.Variable_Info>),
return_types = return_types or (List() as List<Type_Info.Type_Info>),
fixup = function(self: Function_Info)
for p in self.parameters:iter() do
p:fixup()
end
for r in self.return_types:iter() do
r:fixup()
end
end,
}
end, end,
} }
function Function_Info:append_parameter(name: string, types: List<Type_Info.Type_Info>) function Function_Info:append_parameter(name: string, types: List<Type_Info.Type_Info>)
self.parameters:append { self.parameters:append(Variable_Info(name, types))
name = name,
types = types,
}
end end
function Function_Info:append_return_type(return_type: string) function Function_Info:append_return_type(return_type: string)
self.return_types:append(return_type) self.return_types:append(Type_Info(return_type))
end end
return setmetatable({} as Function_Info, __Function_Info) return setmetatable({} as Function_Info, __Function_Info)

View File

@ -1,7 +1,14 @@
local Function_Info = require "entity.Function_Info" local Function_Info = require "entity.Function_Info"
local List = require "pl.List" local List = require "pl.List"
local Map = require "pl.Map"
local Variable_Info = require "entity.Variable_Info" local Variable_Info = require "entity.Variable_Info"
local module_to_require <const> : Map<string, string> = Map({
Shape = "gears.shape",
Surface = "gears.surface",
Widget = "wibox.widget",
})
local record Module_Doc local record Module_Doc
metamethod __call: function(Module_Doc): Module_Doc metamethod __call: function(Module_Doc): Module_Doc
@ -15,6 +22,11 @@ local record Module_Doc
properties: List<Variable_Info.Variable_Info> properties: List<Variable_Info.Variable_Info>
static_functions: List<Function_Info.Function_Info> static_functions: List<Function_Info.Function_Info>
signals: List<string> signals: List<string>
requires: Map<string, string>
fixup: function(Module_Doc)
populate_requires: function(Module_Doc)
end end
local __Module_Doc: metatable<Module_Doc> = { local __Module_Doc: metatable<Module_Doc> = {
@ -25,6 +37,78 @@ local __Module_Doc: metatable<Module_Doc> = {
properties = List(), properties = List(),
static_functions = List(), static_functions = List(),
signals = List(), signals = List(),
requires = Map(),
fixup = function(self: Module_Doc)
for c in self.constructors:iter() do
c:fixup()
if #c.return_types == 1 then
c.return_types[1].name = self.record_name
end
end
for m in self.methods:iter() do
m:fixup()
end
for p in self.properties:iter() do
p:fixup()
end
for s in self.static_functions:iter() do
s:fixup()
end
end,
populate_requires = function(self: Module_Doc)
-- TODO : Move this to other Entities. Can be a little tricky because we populate a map
for c in self.constructors:iter() do
for p in c.parameters:iter() do
for t in p.types:iter() do
local mod = module_to_require:get(t.name)
if mod then
self.requires:set(t.name, mod)
end
end
end
end
for m in self.methods:iter() do
for p in m.parameters:iter() do
for t in p.types:iter() do
local mod = module_to_require:get(t.name)
if mod then
self.requires:set(t.name, mod)
end
end
end
for t in m.return_types:iter() do
local mod = module_to_require:get(t.name)
if mod then
self.requires:set(t.name, mod)
end
end
end
for p in self.properties:iter() do
for t in p.types:iter() do
local mod = module_to_require:get(t.name)
if mod then
self.requires:set(t.name, mod)
end
end
end
for s in self.static_functions:iter() do
for p in s.parameters:iter() do
for t in p.types:iter() do
local mod = module_to_require:get(t.name)
if mod then
self.requires:set(t.name, mod)
end
end
end
for t in s.return_types:iter() do
local mod = module_to_require:get(t.name)
if mod then
self.requires:set(t.name, mod)
end
end
end
end
} }
end, end,
} }

View File

@ -1,6 +1,21 @@
local List = require "pl.List" local List = require "pl.List"
local Map = require "pl.Map" local Map = require "pl.Map"
local type_fix <const> : Map<string, string> = Map({
bool = "boolean",
client = "Client",
["gears.shape"] = "Shape",
["gears.surface"] = "Surface",
image = "Image",
int = "integer",
screen = "Screen",
shape = "Shape",
surface = "Surface",
tag = "Tag",
["wibox.widget"] = "Widget",
widget = "Widget",
})
local record Type_Info local record Type_Info
metamethod __call: function(Type_Info, record_name: string): Type_Info metamethod __call: function(Type_Info, record_name: string): Type_Info
@ -11,19 +26,26 @@ local record Type_Info
-- Map : name -> type -- Map : name -> type
-- We can't use Variable_Info here because it's a circular dependency. -- We can't use Variable_Info here because it's a circular dependency.
record_entries: Map<string, List<Type_Info>> | nil record_entries: Map<string, List<Type_Info>> | nil
fixup: function(self: Type_Info)
end end
local __Type_Info: metatable<Type_Info> = { local __Type_Info: metatable<Type_Info> = {
__call = function(_self: Type_Info, record_name: string): Type_Info __call = function(_self: Type_Info, record_name: string): Type_Info
if record_name ~= nil then
return {
name = record_name,
record_entries = Map()
}
end
return { return {
name = "", name = record_name and record_name or "",
record_entries = nil, record_entries = record_name and Map() or nil,
fixup = function(self: Type_Info)
self.name = type_fix:get(self.name) or self.name
if self.record_entries then
for _, types in (self.record_entries as Map<string, List<Type_Info>>):iter() do
for t in types:iter() do
t:fixup()
end
end
end
end
} }
end, end,
} }

View File

@ -10,6 +10,8 @@ local record Variable_Info
types: List<Type_Info.Type_Info> types: List<Type_Info.Type_Info>
constraints: List<string> constraints: List<string>
fixup: function(self: Variable_Info)
end end
local __Variable_Info: metatable<Variable_Info> = { local __Variable_Info: metatable<Variable_Info> = {
@ -19,6 +21,12 @@ local __Variable_Info: metatable<Variable_Info> = {
return { return {
name = name, name = name,
types = types, types = types,
fixup = function(self: Variable_Info)
for t in self.types:iter() do
t:fixup()
end
end,
} }
end, end,
} }

View File

@ -9,13 +9,14 @@ local Variable_Info = require "entity.Variable_Info"
local record Module local record Module
indent: function(str: string, level: number): string indent: function(str: string, level: number): string
render_requires: function(requires: Map<string, string>): string
render_typed_variable: function(name: string, types: List<Type_Info.Type_Info>): string render_typed_variable: function(name: string, types: List<Type_Info.Type_Info>): string
render_anonymous_function_signature: function(item: Function_Info.Function_Info): string render_anonymous_function_signature: function(item: Function_Info.Function_Info): string
render_record_functions: function(items: List<Function_Info.Function_Info>): string render_record_functions: function(items: List<Function_Info.Function_Info>): string
render_enum: function(name: string, values: List<string>): string render_enum: function(name: string, values: List<string>): string
render_record_properties: function(items: List<Variable_Info.Variable_Info>): string render_record_properties: function(items: List<Variable_Info.Variable_Info>): string
render_record: function(name: string, items: List<Variable_Info.Variable_Info>): string render_record: function(name: string, items: List<Variable_Info.Variable_Info>): string
render_records_from_Parameters: function(items: List<Function_Info.Parameter>): string render_records_from_Parameters: function(items: List<Variable_Info.Variable_Info>): string
end end
@ -26,6 +27,22 @@ function snippets.indent(str: string, level: number): string
return stringx.rstrip(stringx.indent(str, level, string.rep(" ", 3))) return stringx.rstrip(stringx.indent(str, level, string.rep(" ", 3)))
end end
function snippets.render_requires(requires: Map<string, string>): string
local tmpl = [[local $(name) = require "$(path)"]]
local require_statements <const> = List()
for name, path in requires:iter() do
local tmpl_args = {
name = name,
path = path,
}
require_statements:append(utils.do_or_fail(template.substitute, tmpl, tmpl_args))
end
return require_statements:concat("\n")
end
function snippets.render_typed_variable(name: string, types: List<Type_Info.Type_Info>): string function snippets.render_typed_variable(name: string, types: List<Type_Info.Type_Info>): string
local tmpl = local tmpl =
[[$(name): $(types)]] [[$(name): $(types)]]
@ -44,17 +61,21 @@ function snippets.render_anonymous_function_signature(item: Function_Info.Functi
local tmpl_args = { local tmpl_args = {
function_name = item.name, function_name = item.name,
function_parameter = item.parameters:map(function(param: Function_Info.Parameter): string function_parameter = item.parameters:map(function(param: Variable_Info.Variable_Info): string
return snippets.render_typed_variable(param.name, param.types) return snippets.render_typed_variable(param.name, param.types)
end):concat(", "), end):concat(", "),
function_return = item.return_types:concat(", "), function_return = item.return_types:map(
function(return_type: Type_Info.Type_Info): string
return return_type.name
end
):concat(", "),
} }
return utils.do_or_fail(template.substitute, tmpl, tmpl_args) return utils.do_or_fail(template.substitute, tmpl, tmpl_args)
end end
function snippets.render_records_from_Parameters(items: List<Function_Info.Parameter>): string function snippets.render_records_from_Parameters(items: List<Variable_Info.Variable_Info>): string
return items:map(function(param: Function_Info.Parameter): string return items:map(function(param: Variable_Info.Variable_Info): string
if #param.types == 0 then if #param.types == 0 then
return "" return ""
end end

View File

@ -8,6 +8,10 @@ local snippets = require "generator.snippets"
local tmpl = [[ local tmpl = [[
-- Auto generated file (Do not manually edit this file!) -- Auto generated file (Do not manually edit this file!)
# if module.requires:len() ~= 0 then
$(snippets.render_requires(module.requires))
# end -- /requires
local record $(module.record_name) local record $(module.record_name)
# if #module.signals ~= 0 then # if #module.signals ~= 0 then
$(snippets.indent(snippets.render_enum("Signal", module.signals))) $(snippets.indent(snippets.render_enum("Signal", module.signals)))

View File

@ -27,6 +27,8 @@ log:info("Finished Module List scrapping, found " .. #module_infos .. " modules"
local function do_one_file(url: string, module_name: string, output: string) local function do_one_file(url: string, module_name: string, output: string)
local html = crawler.fetch(url) local html = crawler.fetch(url)
local module_doc = scraper.module_doc.get_doc_from_page(html, module_name) local module_doc = scraper.module_doc.get_doc_from_page(html, module_name)
module_doc:fixup()
module_doc:populate_requires()
filesystem.file_writer.write( filesystem.file_writer.write(
generator.teal_type_definitions.generate_teal(module_doc), generator.teal_type_definitions.generate_teal(module_doc),
output output

View File

@ -18,13 +18,13 @@ end
local function parse_parameter_types(parameter_type: string): List<Type_Info.Type_Info> local function parse_parameter_types(parameter_type: string): List<Type_Info.Type_Info>
if parameter_type == "" then if parameter_type == "" then
local type_info: Type_Info.Type_Info = { name = "any" } local type_info: Type_Info.Type_Info = Type_Info("any")
return List({ type_info }) return List({ type_info })
end end
return stringx.split(parameter_type, " or "):map( return stringx.split(parameter_type, " or "):map(
function(type_name: string): Type_Info.Type_Info function(type_name: string): Type_Info.Type_Info
return { name = utils.sanitize_string(type_name) } return Type_Info(utils.sanitize_string(type_name))
end end
) )
end end
@ -33,7 +33,7 @@ local function extract_item_name(item_name_node: scan.HTMLNode): string
return item_name_node and ((item_name_node.attr.name as string):gsub("^.*[%.:]", "")) return item_name_node and ((item_name_node.attr.name as string):gsub("^.*[%.:]", ""))
end end
local function extract_function_parameter_Parameters(tr_node: scan.HTMLNode): { Function_Info.Parameter } local function extract_function_parameter_Parameters(tr_node: scan.HTMLNode): { Variable_Info.Variable_Info }
local query_selectors = { local query_selectors = {
name = "span.parameter", name = "span.parameter",
types = "span.types" types = "span.types"
@ -42,21 +42,21 @@ local function extract_function_parameter_Parameters(tr_node: scan.HTMLNode): {
return scraper_utils.scrape_tuples( return scraper_utils.scrape_tuples(
tr_node:outer_html(), tr_node:outer_html(),
{ query_selectors.name, query_selectors.types }, { query_selectors.name, query_selectors.types },
function(nodes: { string : scan.HTMLNode | nil }): Function_Info.Parameter function(nodes: { string : scan.HTMLNode | nil }): Variable_Info.Variable_Info
return { return Variable_Info(
name = extract_node_text(nodes[query_selectors.name] as scan.HTMLNode), extract_node_text(nodes[query_selectors.name] as scan.HTMLNode),
types = parse_parameter_types(extract_node_text(nodes[query_selectors.types] as scan.HTMLNode)), parse_parameter_types(extract_node_text(nodes[query_selectors.types] as scan.HTMLNode))
} )
end) end)
end end
local function extract_function_parameters(function_parameters_node: scan.HTMLNode): { Function_Info.Parameter } local function extract_function_parameters(function_parameters_node: scan.HTMLNode): { Variable_Info.Variable_Info }
local current_record_parameter: Type_Info.Type_Info | nil = nil local current_record_parameter: Type_Info.Type_Info | nil = nil
return scraper_utils.scrape( return scraper_utils.scrape(
function_parameters_node:outer_html(), function_parameters_node:outer_html(),
"tr", "tr",
function(line_node: scan.HTMLNode): Function_Info.Parameter function(line_node: scan.HTMLNode): Variable_Info.Variable_Info
local parameters = extract_function_parameter_Parameters(line_node) local parameters = extract_function_parameter_Parameters(line_node)
if #parameters == 0 then if #parameters == 0 then
return nil return nil
@ -81,20 +81,17 @@ local function extract_function_parameters(function_parameters_node: scan.HTMLNo
if #types == 1 and types[1].name == "table" then if #types == 1 and types[1].name == "table" then
local record_name = utils.capitalize(name) local record_name = utils.capitalize(name)
current_record_parameter = Type_Info(record_name) current_record_parameter = Type_Info(record_name)
return { return Variable_Info(
name = name, name,
types = List({ current_record_parameter }), List({ current_record_parameter })
} )
end end
return { return Variable_Info(name, types)
name = name,
types = types,
}
end) end)
end end
local function extract_function_return_types(function_return_types_node: scan.HTMLNode): { string } local function extract_function_return_types(function_return_types_node: scan.HTMLNode): List<Type_Info.Type_Info>
if not function_return_types_node then if not function_return_types_node then
return {} return {}
end end
@ -102,7 +99,10 @@ local function extract_function_return_types(function_return_types_node: scan.HT
local selector = "span.types .type" local selector = "span.types .type"
local html = function_return_types_node:outer_html() local html = function_return_types_node:outer_html()
return scraper_utils.scrape(html, selector, extract_node_text) return List(scraper_utils.scrape(html, selector, extract_node_text)):map(
function(type_name: string): Type_Info.Type_Info
return Type_Info(type_name)
end)
end end
local function extract_property_constraints(property_constraint_node: scan.HTMLNode): { string } local function extract_property_constraints(property_constraint_node: scan.HTMLNode): { string }
@ -141,19 +141,19 @@ local function extract_section_functions(dl: string): { Function_Info.Function_I
body:outer_html(), body:outer_html(),
{ query_selectors.parameters, query_selectors.return_types } { query_selectors.parameters, query_selectors.return_types }
) )
return { return Function_Info(
name = scraper_utils.scrape( scraper_utils.scrape(
header:outer_html(), header:outer_html(),
query_selectors.name, query_selectors.name,
extract_item_name extract_item_name
)[1], )[1],
parameters = #body_elements:get(query_selectors.parameters) ~= 0 and #body_elements:get(query_selectors.parameters) ~= 0 and
List(extract_function_parameters(body_elements:get(query_selectors.parameters)[1])) or List(extract_function_parameters(body_elements:get(query_selectors.parameters)[1])) or
(List() as List<Function_Info.Parameter>), (List() as List<Variable_Info.Variable_Info>),
return_types = #body_elements:get(query_selectors.return_types) ~= 0 and #body_elements:get(query_selectors.return_types) ~= 0 and
List(extract_function_return_types(body_elements:get(query_selectors.return_types)[1])) or extract_function_return_types(body_elements:get(query_selectors.return_types)[1]) or
(List() as List<string>), (List() as List<Type_Info.Type_Info>)
} )
end end
) )
end end
@ -209,11 +209,8 @@ function module.get_doc_from_page(html: string, module_name: string): Module_Doc
local module_doc = Module_Doc() local module_doc = Module_Doc()
module_doc.record_name = utils.capitalize((module_name:gsub(".*%.", ""))) module_doc.record_name = utils.capitalize((module_name:gsub(".*%.", "")))
local self_type: Type_Info.Type_Info = { name = module_doc.record_name } local self_type = Type_Info(module_doc.record_name)
local self_parameter: Function_Info.Parameter = { local self_parameter = Variable_Info("self", List({ self_type }))
name = "self",
types = List({ self_type }),
}
for i = 1, #nodes:get("h2.section-header") do for i = 1, #nodes:get("h2.section-header") do
local h2 = nodes:get("h2.section-header")[i] local h2 = nodes:get("h2.section-header")[i]