48 #include <type_traits>
64 using map_t = std::map<std::string, Own<Argument>>;
72 Substitution() =
default;
74 Substitution(
const std::string& var,
const Argument* arg) {
78 ~Substitution() =
default;
87 Own<Node> operator()(Own<Node> node)
const {
89 struct M :
public NodeMapper {
94 using NodeMapper::operator();
96 Own<Node> operator()(Own<Node> node)
const override {
98 if (
auto var =
dynamic_cast<ast::Variable*
>(node.get())) {
99 auto pos =
map.find(var->getName());
100 if (pos !=
map.end()) {
118 template <
typename T>
119 Own<T> operator()(Own<T> node)
const {
120 Own<Node> resPtr = (*this)(Own<Node>(node.release()));
121 assert(isA<T>(resPtr.get()) &&
"Invalid node type mapping.");
122 return Own<T>(
dynamic_cast<T*
>(resPtr.release()));
132 void append(
const Substitution&
sub) {
135 pair.second =
sub(std::move(pair.second));
139 for (
const auto& pair :
sub.varToTerm) {
148 void print(std::ostream& out)
const {
151 [](std::ostream& out,
const std::pair<
const std::string, Own<Argument>>& cur) {
152 out << cur.first <<
" -> " << *cur.second;
157 friend std::ostream&
operator<<(std::ostream& out,
const Substitution& s) __attribute__((unused)) {
173 Equation(
const Argument&
lhs,
const Argument&
rhs)
180 Equation(Equation&& other) =
default;
182 ~Equation() =
default;
187 void apply(
const Substitution&
sub) {
195 void print(std::ostream& out)
const {
196 out << *
lhs <<
" = " << *
rhs;
199 friend std::ostream&
operator<<(std::ostream& out,
const Equation&
e) __attribute__((unused)) {
211 auto isVar = [&](
const Argument& arg) {
return isA<ast::Variable>(&arg); };
214 auto isRec = [&](
const Argument& arg) {
return isA<RecordInit>(&arg); };
217 auto isMultiResultFunctor = [&](
const Argument& arg) {
219 if (inf ==
nullptr)
return false;
232 std::set<std::string> baseGroundedVariables;
233 for (
const auto* atom : getBodyLiterals<Atom>(clause)) {
234 for (
const Argument* arg : atom->getArguments()) {
235 if (
const auto* var =
dynamic_cast<const ast::Variable*
>(arg)) {
236 baseGroundedVariables.insert(var->getName());
240 for (
const Argument* arg : rec.getArguments()) {
241 if (const auto* var = dynamic_cast<const ast::Variable*>(arg)) {
242 baseGroundedVariables.insert(var->getName());
249 std::vector<Equation> equations;
252 equations.push_back(Equation(constraint.getLHS(), constraint.getRHS()));
257 Substitution substitution;
260 auto newMapping = [&](
const std::string& var,
const Argument* term) {
262 Substitution newMapping(var, term);
265 for (
auto& equation : equations) {
266 equation.apply(newMapping);
270 substitution.append(newMapping);
273 while (!equations.empty()) {
275 Equation equation = equations.back();
276 equations.pop_back();
279 const Argument&
lhs = *equation.lhs;
280 const Argument&
rhs = *equation.rhs;
288 if (isRec(
lhs) && isRec(
rhs)) {
290 const auto& lhs_args =
static_cast<const RecordInit&
>(
lhs).getArguments();
291 const auto& rhs_args =
static_cast<const RecordInit&
>(
rhs).getArguments();
294 assert(lhs_args.size() == rhs_args.size() &&
"Record lengths not equal");
297 for (
size_t i = 0;
i < lhs_args.size();
i++) {
298 equations.push_back(Equation(lhs_args[
i], rhs_args[
i]));
305 if (!isVar(
lhs) && !isVar(
rhs)) {
310 if (isVar(
lhs) && isVar(
rhs)) {
311 auto& var =
static_cast<const ast::Variable&
>(
lhs);
312 newMapping(var.getName(), &
rhs);
318 equations.push_back(Equation(
rhs,
lhs));
326 const auto& v =
static_cast<const ast::Variable&
>(
lhs);
327 const Argument& t =
rhs;
330 if (isMultiResultFunctor(t)) {
339 assert(!occurs(v, t));
343 newMapping(v.getName(), &t);
348 auto pos = baseGroundedVariables.find(v.getName());
349 if (pos != baseGroundedVariables.end()) {
354 newMapping(v.getName(), &t);
361 Own<Clause> ResolveAliasesTransformer::removeTrivialEquality(
const Clause& clause) {
365 for (Literal* literal : clause.getBodyLiterals()) {
366 if (
auto* constraint =
dynamic_cast<BinaryConstraint*
>(literal)) {
369 if (*constraint->getLHS() == *constraint->getRHS()) {
382 Own<Clause> ResolveAliasesTransformer::removeComplexTermsInAtoms(
const Clause& clause) {
383 Own<Clause> res(clause.clone());
386 std::vector<Atom*> atoms = getBodyLiterals<Atom>(*res);
389 std::vector<const Argument*> terms;
390 for (
const Atom* atom : atoms) {
391 for (
const Argument* arg : atom->getArguments()) {
393 if (!isA<Functor>(arg)) {
398 if (!
any_of(terms, [&](
const Argument* cur) {
return *cur == *arg; })) {
399 terms.push_back(arg);
408 if (!isA<Functor>(arg)) {
413 if (!
any_of(terms, [&](
const Argument* cur) {
return *cur == *arg; })) {
414 terms.push_back(arg);
420 using substitution_map = std::vector<std::pair<Own<Argument>, Own<ast::Variable>>>;
421 substitution_map termToVar;
423 static int varCounter = 0;
424 for (
const Argument* arg : terms) {
427 auto newVariable = mk<ast::Variable>(
" _tmp_" +
toString(varCounter++));
428 termToVar.push_back(std::make_pair(std::move(term), std::move(newVariable)));
432 struct Update :
public NodeMapper {
433 const substitution_map&
map;
435 Update(
const substitution_map&
map) :
map(
map) {}
437 Own<Node> operator()(Own<Node> node)
const override {
439 for (
const auto& pair :
map) {
440 auto& term = pair.first;
441 auto& variable = pair.second;
443 if (*term == *node) {
455 Update update(termToVar);
456 for (Atom* atom : atoms) {
461 for (
const auto& pair : termToVar) {
462 auto& term = pair.first;
463 auto& variable = pair.second;
472 bool ResolveAliasesTransformer::transform(TranslationUnit& translationUnit) {
473 bool changed =
false;
474 Program& program = translationUnit.getProgram();
477 std::vector<const Clause*>
clauses;
491 Own<Clause> cleaned = removeTrivialEquality(*noAlias);
495 Own<Clause> normalised = removeComplexTermsInAtoms(*cleaned);
498 if (*normalised != *clause) {
500 program.removeClause(clause);
501 program.addClause(std::move(normalised));