Skip to content

Commit

Permalink
Fixes for regressions in local variable behavior. (#98)
Browse files Browse the repository at this point in the history
As a result of the optimization passes, some incorrect behavior was introduced for local variables. This commit restores correct functionality. It also fixes a crash when accessing certain nodes with an undefined key.

Signed-off-by: Matthew Johnson <[email protected]>
  • Loading branch information
matajoh authored Jan 19, 2024
1 parent ebda1aa commit 6d84fc7
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 51 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# Changelog

## 2024-01-19 - Version 0.3.11
Minor improvements and bug fixes.

**New Features**
- Updated to more recent Trieste version
- More sophisticated logging

**Bug fixes**
- Comprehensions over local variables were not properly capturing the local (regression due to optimization)
- Local variable initializations were order-dependent (regression due to optimization)
- In some circumstances, indexing the data object with an undefined key caused a segfault.

**Other**
- Various CI changes due to issues with Github actions.

## 2023-09-21 - Version 0.3.10
Instrumentation and optimization.

Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.10
0.3.11
2 changes: 1 addition & 1 deletion examples/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
regorust = "0.3.10"
regorust = "0.3.11"
clap = { version = "4.0", features = ["derive"] }
1 change: 1 addition & 0 deletions src/internal.hh
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ namespace rego
PassDef explicit_enums();
PassDef body_locals(const BuiltIns& builtins);
PassDef value_locals(const BuiltIns& builtins);
PassDef compr_locals(const BuiltIns& builtins);
PassDef rules_to_compr();
PassDef compr();
PassDef absolute_refs();
Expand Down
206 changes: 161 additions & 45 deletions src/passes/init.cc
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
#include "internal.hh"

#include <algorithm>
#include <cstddef>
#include <deque>
#include <stdexcept>

namespace
{
using namespace rego;
using namespace wf::ops;

struct InitSide
{
std::set<Location> vars;
std::set<Location> inits;
};

struct InitInfo
{
std::set<Location> lhs_vars;
std::set<Location> rhs_vars;
std::size_t index;
InitSide lhs;
InitSide rhs;
};

Node to_init(
Expand Down Expand Up @@ -37,11 +49,12 @@ namespace
return LiteralInit << lhs_vars << rhs_vars << (AssignInfix << lhs << rhs);
}

void vars_from(Node node, std::set<Location>& vars)
void inits_from(
Node node, const std::set<Location>& locals, std::set<Location>& inits)
{
if (node->type() == Var)
if (node->type() == Var && contains(locals, node->location()))
{
vars.insert(node->location());
inits.insert(node->location());
return;
}

Expand Down Expand Up @@ -78,13 +91,102 @@ namespace

for (Node child : *node)
{
vars_from(child, vars);
inits_from(child, locals, inits);
}
}

void vars_from(
Node node, const std::set<Location>& locals, std::set<Location>& vars)
{
if (node->type() == Var && contains(locals, node->location()))
{
vars.insert(node->location());
}

for (Node child : *node)
{
vars_from(child, locals, vars);
}
}

InitSide side_from(Node node, const std::set<Location>& locals)
{
InitSide side;
inits_from(node, locals, side.inits);
vars_from(node, locals, side.vars);
return side;
}

bool any_compiler_inits(const InitSide& lhs)
{
return std::any_of(lhs.inits.begin(), lhs.inits.end(), [](auto& loc) {
std::string name = loc.str();
return starts_with(name, "unify$") || starts_with(name, "out$") ||
starts_with(name, "value$");
});
}

void remove_locals(
std::deque<InitInfo>& init_deque, const std::set<Location>& to_remove)
{
std::size_t count = init_deque.size();
for (std::size_t i = 0; i < count; ++i)
{
InitInfo& init_stmt = init_deque.front();
for (auto& loc : to_remove)
{
init_stmt.lhs.vars.erase(loc);
init_stmt.lhs.inits.erase(loc);
init_stmt.rhs.vars.erase(loc);
init_stmt.rhs.inits.erase(loc);
}
if (!init_stmt.lhs.inits.empty() || !init_stmt.rhs.inits.empty())
{
init_deque.push_back(init_stmt);
}

init_deque.pop_front();
}
}

std::vector<InitInfo> sort_init_stmts(
const std::set<Location>& locals, std::deque<InitInfo>& init_deque)
{
std::set<Location> initialized;
std::vector<InitInfo> init_stmts;
while (!init_deque.empty() && initialized != locals)
{
// find all strict init statements
auto it =
std::find_if(init_deque.begin(), init_deque.end(), [](auto& init_stmt) {
return init_stmt.lhs.vars.empty() || init_stmt.rhs.vars.empty();
});

if (it == init_deque.end())
{
// we have a cycle, so we use the first statement
it = init_deque.begin();
init_stmts.push_back(*it);
}
else
{
init_stmts.push_back(*it);
}

std::set<Location> to_remove;
to_remove.insert(it->lhs.inits.begin(), it->lhs.inits.end());
to_remove.insert(it->rhs.inits.begin(), it->rhs.inits.end());
init_deque.erase(it);
remove_locals(init_deque, to_remove);
initialized.insert(to_remove.begin(), to_remove.end());
}

return init_stmts;
}

void find_init_stmts(Node unifybody, std::set<Location>& locals)
{
// gather all locals
std::deque<InitInfo> potential_init_stmts;
for (std::size_t i = 0; i < unifybody->size(); ++i)
{
Node stmt = unifybody->at(i);
Expand All @@ -95,15 +197,6 @@ namespace
else if (stmt->type() == LiteralEnum)
{
locals.erase((stmt / Item)->location());
find_init_stmts(stmt / UnifyBody, locals);
}
else if (stmt->type() == LiteralWith)
{
find_init_stmts(stmt / UnifyBody, locals);
}
else if (stmt->type() == LiteralNot)
{
find_init_stmts(stmt / UnifyBody, locals);
}
else if (stmt->type() == Literal)
{
Expand All @@ -115,42 +208,65 @@ namespace

Node lhs = expr->front();
Node rhs = expr->back();
std::set<Location> lhs_vars;
vars_from(lhs, lhs_vars);
std::set<Location> lhs_found;
std::set_intersection(
lhs_vars.begin(),
lhs_vars.end(),
locals.begin(),
locals.end(),
std::inserter(lhs_found, lhs_found.begin()));

std::set<Location> rhs_vars;
vars_from(rhs, rhs_vars);
std::set<Location> rhs_found;
std::set_intersection(
rhs_vars.begin(),
rhs_vars.end(),
locals.begin(),
locals.end(),
std::inserter(rhs_found, rhs_found.begin()));

if (lhs_found.empty() && rhs_found.empty())
{
continue;
}

for (auto& loc : lhs_found)
InitSide lhs_side = side_from(lhs, locals);
InitSide rhs_side = side_from(rhs, locals);

if (any_compiler_inits(lhs_side))
{
locals.erase(loc);
// compiler statements will never be right-assign, so we can
// use this fact later to help resolve some ambiguities
rhs_side.inits.clear();
}

for (auto& loc : rhs_found)
if (lhs_side.inits.empty() && rhs_side.inits.empty())
{
locals.erase(loc);
continue;
}

unifybody->replace_at(i, to_init(lhs, lhs_found, rhs, rhs_found));
potential_init_stmts.push_back({i, lhs_side, rhs_side});
}
}

std::vector<InitInfo> init_stmts =
sort_init_stmts(locals, potential_init_stmts);
for (std::size_t i = 0; i < init_stmts.size(); ++i)
{
InitInfo& init_stmt = init_stmts[i];
Node expr = unifybody->at(init_stmt.index)->front()->front();

Node lhs = expr->front();
Node rhs = expr->back();

for (auto& loc : init_stmt.lhs.inits)
{
locals.erase(loc);
}

for (auto& loc : init_stmt.rhs.inits)
{
locals.erase(loc);
}

unifybody->replace_at(
init_stmt.index,
to_init(lhs, init_stmt.lhs.inits, rhs, init_stmt.rhs.inits));
}

// where appropriate, recurse with the updated locals
for (Node stmt : *unifybody)
{
if (stmt->type() == LiteralEnum)
{
find_init_stmts(stmt / UnifyBody, locals);
}
else if (stmt->type() == LiteralWith)
{
find_init_stmts(stmt / UnifyBody, locals);
}
else if (stmt->type() == LiteralNot)
{
find_init_stmts(stmt / UnifyBody, locals);
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/passes/locals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ namespace rego
RuleObj, [builtins](Node n) { return preprocess_body(n, builtins); });
locals.pre(
RuleSet, [builtins](Node n) { return preprocess_body(n, builtins); });

return locals;
}

PassDef compr_locals(const BuiltIns& builtins)
{
PassDef locals = {
"compr_locals", wf_pass_locals, dir::bottomup | dir::once};
locals.pre(
ArrayCompr, [builtins](Node n) { return preprocess_body(n, builtins); });
locals.pre(
Expand Down
1 change: 0 additions & 1 deletion src/passes/rulebody.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ namespace rego
<< (T(Var)[Lhs] * T(Var)[Rhs] * T(UnifyBody)[UnifyBody])) >>
[](Match& _) {
ACTION();
logging::Debug() << "enum";
Location value = _.fresh({"value"});
return Seq << (Lift << UnifyBody
<< (Local << (Var ^ value) << Undefined))
Expand Down
1 change: 1 addition & 0 deletions src/rego.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace rego
explicit_enums(),
body_locals(builtins),
value_locals(builtins),
compr_locals(builtins),
rules_to_compr(),
compr(),
absolute_refs(),
Expand Down
10 changes: 9 additions & 1 deletion src/unifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,15 @@ namespace rego
}
else
{
auto maybe_nodes = Resolver::apply_access(container, args[1]->node());
Node index = args[1]->node();
if (index->type() == Undefined)
{
values.push_back(
ValueDef::create(var, Undefined ^ "undefined", sources));
return values;
}

auto maybe_nodes = Resolver::apply_access(container, index);
if (maybe_nodes)
{
Nodes defs = maybe_nodes.value();
Expand Down
28 changes: 28 additions & 0 deletions tests/regocpp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1094,3 +1094,31 @@ cases:
query: data.every_some.output = x
want_result:
- x: true
- note: regocpp/bug95
modules:
- |
package test
x = c {
a = b
b = c
a = 12
}
query: data.test.x = x
want_result:
- x: 12
- note: regocpp/bug97
modules:
- |
package test
x = y {
a = [1, 2, 3]
y = {z | z = a[_]}
}
query: data.test.x = x
want_result:
- x:
- 1
- 2
- 3
2 changes: 1 addition & 1 deletion wrappers/python/docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
project = 'regopy'
copyright = '2023, Microsoft'
author = 'Microsoft'
release = '0.3.10'
release = '0.3.11'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 1 addition & 1 deletion wrappers/rust/regorust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "regorust"
version = "0.3.10"
version = "0.3.11"
edition = "2021"
description = "Rust bindings for the rego-cpp Rego compiler and interpreter"
license = "MIT"
Expand Down

0 comments on commit 6d84fc7

Please sign in to comment.