souffle  2.0.2-371-g6315b36
InlineRelations.cpp
Go to the documentation of this file.
1 /*
2  * Souffle - A Datalog Compiler
3  * Copyright (c) 2018, The Souffle Developers. All rights reserved.
4  * Licensed under the Universal Permissive License v 1.0 as shown at:
5  * - https://opensource.org/licenses/UPL
6  * - <souffle root>/licenses/SOUFFLE-UPL.txt
7  */
8 
9 /************************************************************************
10  *
11  * @file InlineRelations.cpp
12  *
13  * Define classes and functionality related to inlining.
14  *
15  ***********************************************************************/
16 
18 #include "AggregateOp.h"
19 #include "RelationTag.h"
20 #include "ast/Aggregator.h"
21 #include "ast/Argument.h"
22 #include "ast/Atom.h"
23 #include "ast/BinaryConstraint.h"
24 #include "ast/BooleanConstraint.h"
25 #include "ast/Clause.h"
26 #include "ast/Constant.h"
27 #include "ast/Constraint.h"
28 #include "ast/Functor.h"
29 #include "ast/IntrinsicFunctor.h"
30 #include "ast/Literal.h"
31 #include "ast/Negation.h"
32 #include "ast/Node.h"
33 #include "ast/NumericConstant.h"
34 #include "ast/Program.h"
35 #include "ast/QualifiedName.h"
36 #include "ast/RecordInit.h"
37 #include "ast/Relation.h"
38 #include "ast/TranslationUnit.h"
39 #include "ast/TypeCast.h"
40 #include "ast/UnnamedVariable.h"
41 #include "ast/UserDefinedFunctor.h"
42 #include "ast/Variable.h"
44 #include "ast/utility/NodeMapper.h"
45 #include "ast/utility/Utils.h"
46 #include "ast/utility/Visitor.h"
49 #include <algorithm>
50 #include <cassert>
51 #include <cstddef>
52 #include <memory>
53 #include <optional>
54 #include <ostream>
55 #include <set>
56 #include <string>
57 #include <utility>
58 #include <vector>
59 
60 namespace souffle::ast::transform {
61 
/**
 * A std::vector that may be absent.
 *
 * A default-constructed instance holds no vector and reports isValid() == false; an instance
 * constructed from a std::vector owns that vector and reports isValid() == true (even when the
 * vector is empty). Used in this file to tell "nothing to do / failed" apart from
 * "here is a (possibly empty) list of results".
 */
template <class T>
class NullableVector {
public:
    NullableVector() = default;
    NullableVector(std::vector<T> elements) : items(std::move(elements)), present(true) {}

    /** @return true iff a vector was supplied at construction. */
    bool isValid() const {
        return present;
    }

    /** @return the held vector; asserts that one is present. */
    const std::vector<T>& getVector() const {
        assert(present && "Accessing invalid vector!");
        return items;
    }

private:
    // The wrapped elements (empty when no vector was supplied).
    std::vector<T> items;
    // Whether a vector was supplied at construction.
    bool present = false;
};
81 
83 
84 /**
85  * Replace constants in the head of inlined clauses with (constrained) variables.
86  */
87 bool normaliseInlinedHeads(Program& program) {
88  bool changed = false;
89  static int newVarCount = 0;
90 
91  // Go through the clauses of all inlined relations
92  for (Relation* rel : program.getRelations()) {
93  if (!rel->hasQualifier(RelationQualifier::INLINE)) {
94  continue;
95  }
96 
97  for (Clause* clause : getClauses(program, *rel)) {
98  // Set up the new clause with an empty body and no arguments in the head
99  auto newClause = mk<Clause>();
100  newClause->setSrcLoc(clause->getSrcLoc());
101  auto clauseHead = mk<Atom>(clause->getHead()->getQualifiedName());
102 
103  // Add in everything in the original body
104  for (Literal* lit : clause->getBodyLiterals()) {
105  newClause->addToBody(souffle::clone(lit));
106  }
107 
108  // Set up the head arguments in the new clause
109  for (Argument* arg : clause->getHead()->getArguments()) {
110  if (auto* constant = dynamic_cast<Constant*>(arg)) {
111  // Found a constant in the head, so replace it with a variable
112  std::stringstream newVar;
113  newVar << "<new_var_" << newVarCount++ << ">";
114  clauseHead->addArgument(mk<ast::Variable>(newVar.str()));
115 
116  // Add a body constraint to set the variable's value to be the original constant
117  newClause->addToBody(mk<BinaryConstraint>(BinaryConstraintOp::EQ,
118  mk<ast::Variable>(newVar.str()), souffle::clone(constant)));
119  } else {
120  // Already a variable
121  clauseHead->addArgument(souffle::clone(arg));
122  }
123  }
124 
125  newClause->setHead(std::move(clauseHead));
126 
127  // Replace the old clause with this one
128  program.addClause(std::move(newClause));
129  program.removeClause(clause);
130  changed = true;
131  }
132  }
133 
134  return changed;
135 }
136 
137 /**
138  * Removes all underscores in all atoms of inlined relations
139  */
140 bool nameInlinedUnderscores(Program& program) {
141  struct M : public NodeMapper {
142  mutable bool changed = false;
143  const std::set<QualifiedName> inlinedRelations;
144  bool replaceUnderscores;
145 
146  M(std::set<QualifiedName> inlinedRelations, bool replaceUnderscores)
147  : inlinedRelations(std::move(inlinedRelations)), replaceUnderscores(replaceUnderscores) {}
148 
149  Own<Node> operator()(Own<Node> node) const override {
150  static int underscoreCount = 0;
151 
152  if (!replaceUnderscores) {
153  // Check if we should start replacing underscores for this node's subnodes
154  if (auto* atom = dynamic_cast<Atom*>(node.get())) {
155  if (inlinedRelations.find(atom->getQualifiedName()) != inlinedRelations.end()) {
156  // Atom associated with an inlined relation, so replace the underscores
157  // in all of its subnodes with named variables.
158  M replace(inlinedRelations, true);
159  node->apply(replace);
160  changed |= replace.changed;
161  return node;
162  }
163  }
164  } else if (isA<UnnamedVariable>(node.get())) {
165  // Give a unique name to the underscored variable
166  // TODO (azreika): need a more consistent way of handling internally generated variables in
167  // general
168  std::stringstream newVarName;
169  newVarName << "<underscore_" << underscoreCount++ << ">";
170  changed = true;
171  return mk<ast::Variable>(newVarName.str());
172  }
173 
174  node->apply(*this);
175  return node;
176  }
177  };
178 
179  // Store the names of all relations to be inlined
180  std::set<QualifiedName> inlinedRelations;
181  for (Relation* rel : program.getRelations()) {
182  if (rel->hasQualifier(RelationQualifier::INLINE)) {
183  inlinedRelations.insert(rel->getQualifiedName());
184  }
185  }
186 
187  // Apply the renaming procedure to the entire program
188  M update(inlinedRelations, false);
189  program.apply(update);
190  return update.changed;
191 }
192 
193 /**
194  * Checks if a given clause contains an atom that should be inlined.
195  */
196 bool containsInlinedAtom(const Program& program, const Clause& clause) {
197  bool foundInlinedAtom = false;
198 
199  visitDepthFirst(clause, [&](const Atom& atom) {
200  Relation* rel = getRelation(program, atom.getQualifiedName());
201  if (rel->hasQualifier(RelationQualifier::INLINE)) {
202  foundInlinedAtom = true;
203  }
204  });
205 
206  return foundInlinedAtom;
207 }
208 
209 /**
210  * Reduces a vector of substitutions.
211  * Returns false only if matched argument pairs are found to be incompatible.
212  */
213 bool reduceSubstitution(std::vector<std::pair<Argument*, Argument*>>& sub) {
214  // Keep trying to reduce the substitutions until we reach a fixed point.
215  // Note that at this point no underscores ('_') or counters ('$') should appear.
216  bool done = false;
217  while (!done) {
218  done = true;
219 
220  // Try reducing each pair by one step
221  for (size_t i = 0; i < sub.size(); i++) {
222  auto currPair = sub[i];
223  Argument* lhs = currPair.first;
224  Argument* rhs = currPair.second;
225 
226  // Start trying to reduce the substitution
227  // Note: Can probably go further with this substitution reduction
228  if (*lhs == *rhs) {
229  // Get rid of redundant `x = x`
230  sub.erase(sub.begin() + i);
231  done = false;
232  } else if (isA<Constant>(lhs) && isA<Constant>(rhs)) {
233  // Both are constants but not equal (prev case => !=)
234  // Failed to unify!
235  return false;
236  } else if (isA<RecordInit>(lhs) && isA<RecordInit>(rhs)) {
237  // Note: we will not deal with the case where only one side is
238  // a record and the other is a variable, as variables can be records
239  // on a deeper level.
240  std::vector<Argument*> lhsArgs = static_cast<RecordInit*>(lhs)->getArguments();
241  std::vector<Argument*> rhsArgs = static_cast<RecordInit*>(rhs)->getArguments();
242 
243  if (lhsArgs.size() != rhsArgs.size()) {
244  // Records of unequal size can't be equated
245  return false;
246  }
247 
248  // Equate all corresponding arguments
249  for (size_t i = 0; i < lhsArgs.size(); i++) {
250  sub.push_back(std::make_pair(lhsArgs[i], rhsArgs[i]));
251  }
252 
253  // Get rid of the record equality
254  sub.erase(sub.begin() + i);
255  done = false;
256  } else if ((isA<RecordInit>(lhs) && isA<Constant>(rhs)) ||
257  (isA<Constant>(lhs) && isA<RecordInit>(rhs))) {
258  // A record =/= a constant
259  return false;
260  }
261  }
262  }
263 
264  return true;
265 }
266 
267 /**
268  * Returns the nullable vector of substitutions needed to unify the two given atoms.
269  * If unification is not successful, the returned vector is marked as invalid.
270  * Assumes that the atoms are both of the same relation.
271  */
272 NullableVector<std::pair<Argument*, Argument*>> unifyAtoms(Atom* first, Atom* second) {
273  std::vector<std::pair<Argument*, Argument*>> substitution;
274 
275  std::vector<Argument*> firstArgs = first->getArguments();
276  std::vector<Argument*> secondArgs = second->getArguments();
277 
278  // Create the initial unification equalities
279  for (size_t i = 0; i < firstArgs.size(); i++) {
280  substitution.push_back(std::make_pair(firstArgs[i], secondArgs[i]));
281  }
282 
283  // Reduce the substitutions
284  bool success = reduceSubstitution(substitution);
285  if (success) {
286  return NullableVector<std::pair<Argument*, Argument*>>(substitution);
287  } else {
288  // Failed to unify the two atoms
289  return NullableVector<std::pair<Argument*, Argument*>>();
290  }
291 }
292 
293 /**
294  * Inlines the given atom based on a given clause.
295  * Returns the vector of replacement literals and the necessary constraints.
296  * If unification is unsuccessful, the vector of literals is marked as invalid.
297  */
298 std::pair<NullableVector<Literal*>, std::vector<BinaryConstraint*>> inlineBodyLiterals(
299  Atom* atom, Clause* atomInlineClause) {
300  bool changed = false;
301  std::vector<Literal*> addedLits;
302  std::vector<BinaryConstraint*> constraints;
303 
304  // Rename the variables in the inlined clause to avoid conflicts when unifying multiple atoms
305  // - particularly when an inlined relation appears twice in a clause.
306  static int inlineCount = 0;
307 
308  // Make a temporary clone so we can rename variables without fear
309  auto atomClause = souffle::clone(atomInlineClause);
310 
311  struct VariableRenamer : public NodeMapper {
312  int varnum;
313  VariableRenamer(int varnum) : varnum(varnum) {}
314  Own<Node> operator()(Own<Node> node) const override {
315  if (auto* var = dynamic_cast<ast::Variable*>(node.get())) {
316  // Rename the variable
317  auto newVar = souffle::clone(var);
318  std::stringstream newName;
319  newName << "<inlined_" << var->getName() << "_" << varnum << ">";
320  newVar->setName(newName.str());
321  return newVar;
322  }
323  node->apply(*this);
324  return node;
325  }
326  };
327 
328  VariableRenamer update(inlineCount);
329  atomClause->apply(update);
330 
331  inlineCount++;
332 
333  // Get the constraints needed to unify the two atoms
334  NullableVector<std::pair<Argument*, Argument*>> res = unifyAtoms(atomClause->getHead(), atom);
335  if (res.isValid()) {
336  changed = true;
337  for (std::pair<Argument*, Argument*> pair : res.getVector()) {
338  // FIXME: float equiv (`FEQ`)
339  constraints.push_back(new BinaryConstraint(
340  BinaryConstraintOp::EQ, souffle::clone(pair.first), souffle::clone(pair.second)));
341  }
342 
343  // Add in the body of the current clause of the inlined atom
344  for (Literal* lit : atomClause->getBodyLiterals()) {
345  addedLits.push_back(lit->clone());
346  }
347  }
348 
349  if (changed) {
350  return std::make_pair(NullableVector<Literal*>(addedLits), constraints);
351  } else {
352  return std::make_pair(NullableVector<Literal*>(), constraints);
353  }
354 }
355 
356 /**
357  * Returns the negated version of a given literal
358  */
359 Literal* negateLiteral(Literal* lit) {
360  if (auto* atom = dynamic_cast<Atom*>(lit)) {
361  auto* neg = new Negation(souffle::clone(atom));
362  return neg;
363  } else if (auto* neg = dynamic_cast<Negation*>(lit)) {
364  Atom* atom = neg->getAtom()->clone();
365  return atom;
366  } else if (auto* cons = dynamic_cast<Constraint*>(lit)) {
367  Constraint* newCons = cons->clone();
368  negateConstraintInPlace(*newCons);
369  return newCons;
370  }
371 
372  fatal("unsupported literal type: %s", *lit);
373 }
374 
375 /**
376  * Return the negated version of a disjunction of conjunctions.
377  * E.g. (a1(x) ^ a2(x) ^ ...) v (b1(x) ^ b2(x) ^ ...) --into-> (!a1(x) ^ !b1(x)) v (!a2(x) ^ !b2(x)) v ...
378  */
379 std::vector<std::vector<Literal*>> combineNegatedLiterals(std::vector<std::vector<Literal*>> litGroups) {
380  std::vector<std::vector<Literal*>> negation;
381 
382  // Corner case: !() = ()
383  if (litGroups.empty()) {
384  return negation;
385  }
386 
387  std::vector<Literal*> litGroup = litGroups[0];
388  if (litGroups.size() == 1) {
389  // !(a1 ^ a2 ^ a3 ^ ...) --into-> !a1 v !a2 v !a3 v ...
390  for (auto& i : litGroup) {
391  std::vector<Literal*> newVec;
392  newVec.push_back(negateLiteral(i));
393  negation.push_back(newVec);
394  }
395 
396  // Done!
397  return negation;
398  }
399 
400  // Produce the negated versions of all disjunctions ignoring the first recursively
401  std::vector<std::vector<Literal*>> combinedRHS = combineNegatedLiterals(
402  std::vector<std::vector<Literal*>>(litGroups.begin() + 1, litGroups.end()));
403 
404  // We now just need to add the negation of a single literal from the untouched LHS
405  // to every single conjunction on the RHS to finalise creating every possible combination
406  for (Literal* lhsLit : litGroup) {
407  for (std::vector<Literal*> rhsVec : combinedRHS) {
408  std::vector<Literal*> newVec;
409  newVec.push_back(negateLiteral(lhsLit));
410 
411  for (Literal* lit : rhsVec) {
412  newVec.push_back(lit->clone());
413  }
414 
415  negation.push_back(newVec);
416  }
417  }
418 
419  for (std::vector<Literal*> rhsVec : combinedRHS) {
420  for (Literal* lit : rhsVec) {
421  delete lit;
422  }
423  }
424 
425  return negation;
426 }
427 
428 /**
429  * Forms the bodies that will replace the negation of a given inlined atom.
430  * E.g. a(x) <- (a11(x), a12(x)) ; (a21(x), a22(x)) => !a(x) <- (!a11(x), !a21(x)) ; (!a11(x), !a22(x)) ; ...
431  * Essentially, produce every combination (m_1 ^ m_2 ^ ...) where m_i is the negation of a literal in the
432  * ith rule of a.
433  */
434 std::vector<std::vector<Literal*>> formNegatedLiterals(Program& program, Atom* atom) {
435  // Constraints added to unify atoms should not be negated and should be added to
436  // all the final rule combinations produced, and so should be stored separately.
437  std::vector<std::vector<Literal*>> addedBodyLiterals;
438  std::vector<std::vector<BinaryConstraint*>> addedConstraints;
439 
440  // Go through every possible clause associated with the given atom
441  for (Clause* inClause : getClauses(program, *getRelation(program, atom->getQualifiedName()))) {
442  // Form the replacement clause by inlining based on the current clause
443  std::pair<NullableVector<Literal*>, std::vector<BinaryConstraint*>> inlineResult =
444  inlineBodyLiterals(atom, inClause);
445  NullableVector<Literal*> replacementBodyLiterals = inlineResult.first;
446  std::vector<BinaryConstraint*> currConstraints = inlineResult.second;
447 
448  if (!replacementBodyLiterals.isValid()) {
449  // Failed to unify, so just move on
450  continue;
451  }
452 
453  addedBodyLiterals.push_back(replacementBodyLiterals.getVector());
454  addedConstraints.push_back(currConstraints);
455  }
456 
457  // We now have a list of bodies needed to inline the given atom.
458  // We want to inline the negated version, however, which is done using De Morgan's Law.
459  std::vector<std::vector<Literal*>> negatedAddedBodyLiterals = combineNegatedLiterals(addedBodyLiterals);
460 
461  // Add in the necessary constraints to all the body literals
462  for (auto& negatedAddedBodyLiteral : negatedAddedBodyLiterals) {
463  for (std::vector<BinaryConstraint*> constraintGroup : addedConstraints) {
464  for (BinaryConstraint* constraint : constraintGroup) {
465  negatedAddedBodyLiteral.push_back(constraint->clone());
466  }
467  }
468  }
469 
470  // Free up the old body literals and constraints
471  for (std::vector<Literal*> litGroup : addedBodyLiterals) {
472  for (Literal* lit : litGroup) {
473  delete lit;
474  }
475  }
476  for (std::vector<BinaryConstraint*> consGroup : addedConstraints) {
477  for (Constraint* cons : consGroup) {
478  delete cons;
479  }
480  }
481 
482  return negatedAddedBodyLiterals;
483 }
484 
485 /**
486  * Renames all variables in a given argument uniquely.
487  */
488 void renameVariables(Argument* arg) {
489  static int varCount = 0;
490  varCount++;
491 
492  struct M : public NodeMapper {
493  int varnum;
494  M(int varnum) : varnum(varnum) {}
495  Own<Node> operator()(Own<Node> node) const override {
496  if (auto* var = dynamic_cast<ast::Variable*>(node.get())) {
497  auto newVar = souffle::clone(var);
498  std::stringstream newName;
499  newName << var->getName() << "-v" << varnum;
500  newVar->setName(newName.str());
501  return newVar;
502  }
503  node->apply(*this);
504  return node;
505  }
506  };
507 
508  M update(varCount);
509  arg->apply(update);
510 }
511 
512 // Performs a given binary op on a list of aggregators recursively.
513 // E.g. ( <aggr1, aggr2, aggr3, ...>, o > = (aggr1 o (aggr2 o (agg3 o (...))))
514 // TODO (azreika): remove aggregator support
515 Argument* combineAggregators(std::vector<Aggregator*> aggrs, std::string fun) {
516  // Due to variable scoping issues with aggregators, we rename all variables uniquely in the
517  // added aggregator
518  renameVariables(aggrs[0]);
519 
520  if (aggrs.size() == 1) {
521  return aggrs[0];
522  }
523 
524  Argument* rhs = combineAggregators(std::vector<Aggregator*>(aggrs.begin() + 1, aggrs.end()), fun);
525 
526  Argument* result = new IntrinsicFunctor(std::move(fun), Own<Argument>(aggrs[0]), Own<Argument>(rhs));
527 
528  return result;
529 }
530 
/**
 * Returns a vector of arguments that should replace the given argument after one step of inlining.
 * Note: This function is currently generalised to perform any required inlining within aggregators
 * as well, making it simple to extend to this later on if desired (and the semantic check is removed).
 *
 * Aggregators are the only arguments here whose bodies are inlined directly (via getInlinedLiteral).
 * Functors, type casts and records are searched recursively, left to right. At most one
 * sub-argument is rewritten per call.
 *
 * @param program the program, used when inlining literals inside aggregator bodies
 * @param arg the argument to inspect; not modified
 * @return a valid vector of newly allocated replacement arguments (owned by the caller) if a step of
 *         inlining was performed; otherwise an invalid vector
 */
// TODO (azreika): rewrite this method, removing aggregators
NullableVector<Argument*> getInlinedArgument(Program& program, const Argument* arg) {
    // Set once one step of inlining has been performed somewhere inside `arg`
    bool changed = false;
    // Replacement arguments produced by that step
    std::vector<Argument*> versions;

    // Each argument has to be handled differently - essentially, want to go down to
    // nested aggregators, and inline their bodies if needed.
    if (const auto* aggr = dynamic_cast<const Aggregator*>(arg)) {
        // First try inlining the target expression if necessary
        if (aggr->getTargetExpression() != nullptr) {
            NullableVector<Argument*> argumentVersions =
                    getInlinedArgument(program, aggr->getTargetExpression());

            if (argumentVersions.isValid()) {
                // An element in the target expression can be inlined!
                changed = true;

                // Create a new aggregator per version of the target expression
                // (each new aggregator takes ownership of its target and gets a cloned body)
                for (Argument* newArg : argumentVersions.getVector()) {
                    auto* newAggr = new Aggregator(aggr->getBaseOperator(), Own<Argument>(newArg));
                    VecOwn<Literal> newBody;
                    for (Literal* lit : aggr->getBodyLiterals()) {
                        newBody.push_back(souffle::clone(lit));
                    }
                    newAggr->setBody(std::move(newBody));
                    versions.push_back(newAggr);
                }
            }
        }

        // Try inlining body arguments if the target expression has not been changed.
        // (At this point we only handle one step of inlining at a time)
        if (!changed) {
            std::vector<Literal*> bodyLiterals = aggr->getBodyLiterals();
            for (size_t i = 0; i < bodyLiterals.size(); i++) {
                Literal* currLit = bodyLiterals[i];

                NullableVector<std::vector<Literal*>> literalVersions = getInlinedLiteral(program, currLit);

                if (literalVersions.isValid()) {
                    // Literal can be inlined!
                    changed = true;

                    AggregateOp op = aggr->getBaseOperator();

                    // Create an aggregator (with the same operation) for each possible body
                    std::vector<Aggregator*> aggrVersions;
                    for (std::vector<Literal*> inlineVersions : literalVersions.getVector()) {
                        Own<Argument> target;
                        if (aggr->getTargetExpression() != nullptr) {
                            target = souffle::clone(aggr->getTargetExpression());
                        }
                        auto* newAggr = new Aggregator(aggr->getBaseOperator(), std::move(target));

                        VecOwn<Literal> newBody;
                        // Add in everything except the current literal being replaced
                        for (size_t j = 0; j < bodyLiterals.size(); j++) {
                            if (i != j) {
                                newBody.push_back(souffle::clone(bodyLiterals[j]));
                            }
                        }

                        // Add in everything new that replaces that literal
                        // (ownership of the replacement literals moves into the new body)
                        for (Literal* addedLit : inlineVersions) {
                            newBody.push_back(Own<Literal>(addedLit));
                        }

                        newAggr->setBody(std::move(newBody));
                        aggrVersions.push_back(newAggr);
                    }

                    // Utility lambda: get functor used to tie aggregators together.
                    auto aggregateToFunctor = [](AggregateOp op) {
                        switch (op) {
                            case AggregateOp::MIN:
                            case AggregateOp::FMIN:
                            case AggregateOp::UMIN: return "min";
                            case AggregateOp::MAX:
                            case AggregateOp::FMAX:
                            case AggregateOp::UMAX: return "max";
                            case AggregateOp::SUM:
                            case AggregateOp::FSUM:
                            case AggregateOp::USUM:
                            case AggregateOp::COUNT: return "+";
                            case AggregateOp::MEAN: fatal("no translation");
                        }
                        // NOTE(review): no statement is visible after the switch; if control could
                        // reach this point the lambda would fall off its end (undefined behavior).
                        // Confirm an unreachable marker or return follows the switch.

                    };
                    // Create the actual overall aggregator that ties the replacement aggregators together.
                    // example: min x : { a(x) }. <=> min ( min x : { a1(x) }, min x : { a2(x) }, ... )
                    // NOTE(review): for MEAN, `changed` is already true but nothing is pushed to
                    // `versions`, and the aggregators in `aggrVersions` are never freed. Presumably
                    // unreachable while the semantic check forbids inlining inside aggregators - confirm.
                    if (op != AggregateOp::MEAN) {
                        versions.push_back(combineAggregators(aggrVersions, aggregateToFunctor(op)));
                    }
                }

                // Only perform one stage of inlining at a time.
                if (changed) {
                    break;
                }
            }
        }
    } else if (const auto* functor = dynamic_cast<const Functor*>(arg)) {
        // Index of the functor argument currently being tried
        size_t i = 0;
        for (auto funArg : functor->getArguments()) {
            // TODO (azreika): use unique pointers
            // try inlining each argument from left to right
            NullableVector<Argument*> argumentVersions = getInlinedArgument(program, funArg);
            if (argumentVersions.isValid()) {
                changed = true;
                for (Argument* newArgVersion : argumentVersions.getVector()) {
                    // same functor but with new argument version
                    // (argsCopy takes ownership of newArgVersion and of the clones)
                    VecOwn<Argument> argsCopy;
                    size_t j = 0;
                    for (auto& functorArg : functor->getArguments()) {
                        if (j == i) {
                            argsCopy.emplace_back(newArgVersion);
                        } else {
                            argsCopy.emplace_back(functorArg->clone());
                        }
                        ++j;
                    }
                    // Rebuild the same kind of functor around the new arguments
                    if (const auto* intrFunc = dynamic_cast<const IntrinsicFunctor*>(arg)) {
                        auto* newFunctor =
                                new IntrinsicFunctor(intrFunc->getBaseFunctionOp(), std::move(argsCopy));
                        newFunctor->setSrcLoc(functor->getSrcLoc());
                        versions.push_back(newFunctor);
                    } else if (const auto* userFunc = dynamic_cast<const UserDefinedFunctor*>(arg)) {
                        auto* newFunctor = new UserDefinedFunctor(userFunc->getName(), std::move(argsCopy));
                        newFunctor->setSrcLoc(userFunc->getSrcLoc());
                        versions.push_back(newFunctor);
                    }
                }
                // only one step at a time
                break;
            }
            ++i;
        }
    } else if (const auto* cast = dynamic_cast<const ast::TypeCast*>(arg)) {
        // Inline inside the cast value and re-wrap every version in the same cast
        NullableVector<Argument*> argumentVersions = getInlinedArgument(program, cast->getValue());
        if (argumentVersions.isValid()) {
            changed = true;
            for (Argument* newArg : argumentVersions.getVector()) {
                Argument* newTypeCast = new ast::TypeCast(Own<Argument>(newArg), cast->getType());
                versions.push_back(newTypeCast);
            }
        }
    } else if (const auto* record = dynamic_cast<const RecordInit*>(arg)) {
        // Try the record's elements left to right; the first inlinable one is replaced
        std::vector<Argument*> recordArguments = record->getArguments();
        for (size_t i = 0; i < recordArguments.size(); i++) {
            Argument* currentRecArg = recordArguments[i];
            NullableVector<Argument*> argumentVersions = getInlinedArgument(program, currentRecArg);
            if (argumentVersions.isValid()) {
                changed = true;
                for (Argument* newArgumentVersion : argumentVersions.getVector()) {
                    auto* newRecordArg = new RecordInit();
                    for (size_t j = 0; j < recordArguments.size(); j++) {
                        if (i == j) {
                            newRecordArg->addArgument(Own<Argument>(newArgumentVersion));
                        } else {
                            newRecordArg->addArgument(souffle::clone(recordArguments[j]));
                        }
                    }
                    versions.push_back(newRecordArg);
                }
            }

            // Only perform one stage of inlining at a time.
            if (changed) {
                break;
            }
        }
    }

    if (changed) {
        // Return a valid vector - replacements need to be made!
        return NullableVector<Argument*>(versions);
    } else {
        // Return an invalid vector - no inlining has occurred
        return NullableVector<Argument*>();
    }
}
718 
719 /**
720  * Returns a vector of atoms that should replace the given atom after one step of inlining.
721  * Assumes the relation the atom belongs to is not inlined itself.
722  */
723 NullableVector<Atom*> getInlinedAtom(Program& program, Atom& atom) {
724  bool changed = false;
725  std::vector<Atom*> versions;
726 
727  // Try to inline each of the atom's arguments
728  std::vector<Argument*> arguments = atom.getArguments();
729  for (size_t i = 0; i < arguments.size(); i++) {
730  Argument* arg = arguments[i];
731 
732  NullableVector<Argument*> argumentVersions = getInlinedArgument(program, arg);
733 
734  if (argumentVersions.isValid()) {
735  // Argument has replacements
736  changed = true;
737 
738  // Create a new atom per new version of the argument
739  for (Argument* newArgument : argumentVersions.getVector()) {
740  auto args = atom.getArguments();
741  VecOwn<Argument> newArgs;
742  for (size_t j = 0; j < args.size(); j++) {
743  if (j == i) {
744  newArgs.emplace_back(newArgument);
745  } else {
746  newArgs.emplace_back(args[j]->clone());
747  }
748  }
749  auto* newAtom = new Atom(atom.getQualifiedName(), std::move(newArgs), atom.getSrcLoc());
750  versions.push_back(newAtom);
751  }
752  }
753 
754  // Only perform one stage of inlining at a time.
755  if (changed) {
756  break;
757  }
758  }
759 
760  if (changed) {
761  // Return a valid vector - replacements need to be made!
762  return NullableVector<Atom*>(versions);
763  } else {
764  // Return an invalid vector - no replacements need to be made
765  return NullableVector<Atom*>();
766  }
767 }
768 
/**
 * Tries to perform a single step of inlining on the given literal.
 *
 * Returns a nullable vector of alternative bodies. The literal is to be replaced by each body in
 * turn, giving one resulting clause per body:
 *  - an atom of an inlined relation yields one body per clause of that relation whose head unifies
 *    with the atom. That body is the clause's renamed body literals followed by the unification
 *    constraints;
 *  - a negated atom that can be inlined yields a single `true` body if the positive atom produced
 *    no bodies, and otherwise the De Morgan expansion built by formNegatedLiterals;
 *  - an atom (not itself inlined) or a binary constraint with an inlinable sub-argument yields
 *    one single-literal body per replacement version.
 * If none of these apply, the returned vector is invalid, and no more inlining can occur on this
 * literal.
 *
 * All literals in the result are newly allocated and owned by the caller; `lit` is not modified.
 */
NullableVector<std::vector<Literal*>> getInlinedLiteral(Program& program, Literal* lit) {
    // Set when `lit` is replaced by the bodies in addedBodyLiterals
    bool inlined = false;
    // Set when `lit` is not inlined itself but `versions` holds rewritten copies of it
    bool changed = false;

    std::vector<std::vector<Literal*>> addedBodyLiterals;
    std::vector<Literal*> versions;

    if (auto* atom = dynamic_cast<Atom*>(lit)) {
        // Check if this atom is meant to be inlined
        Relation* rel = getRelation(program, atom->getQualifiedName());

        if (rel->hasQualifier(RelationQualifier::INLINE)) {
            // We found an atom in the clause that needs to be inlined!
            // The clause needs to be replaced
            inlined = true;

            // N new clauses should be formed, where N is the number of clauses
            // associated with the inlined relation
            for (Clause* inClause : getClauses(program, *rel)) {
                // Form the replacement clause
                std::pair<NullableVector<Literal*>, std::vector<BinaryConstraint*>> inlineResult =
                        inlineBodyLiterals(atom, inClause);
                NullableVector<Literal*> replacementBodyLiterals = inlineResult.first;
                std::vector<BinaryConstraint*> currConstraints = inlineResult.second;

                if (!replacementBodyLiterals.isValid()) {
                    // Failed to unify the atoms! We can skip this one...
                    continue;
                }

                // Unification successful - the returned vector of literals represents one possible body
                // replacement. We can add in the unification constraints as part of these literals.
                std::vector<Literal*> bodyResult = replacementBodyLiterals.getVector();

                for (BinaryConstraint* cons : currConstraints) {
                    bodyResult.push_back(cons);
                }

                addedBodyLiterals.push_back(bodyResult);
            }
        } else {
            // Not meant to be inlined, but a subargument may be
            NullableVector<Atom*> atomVersions = getInlinedAtom(program, *atom);
            if (atomVersions.isValid()) {
                // Subnode needs to be inlined, so we have a vector of replacement atoms
                changed = true;
                for (Atom* newAtom : atomVersions.getVector()) {
                    versions.push_back(newAtom);
                }
            }
        }
    } else if (auto neg = dynamic_cast<Negation*>(lit)) {
        // For negations, check the corresponding atom
        Atom* atom = neg->getAtom();
        // Only used to decide whether (and how) the negation can be inlined; freed below
        NullableVector<std::vector<Literal*>> atomVersions = getInlinedLiteral(program, atom);

        if (atomVersions.isValid()) {
            // The atom can be inlined
            inlined = true;

            if (atomVersions.getVector().empty()) {
                // No clauses associated with the atom, so just becomes a true literal
                addedBodyLiterals.push_back({new BooleanConstraint(true)});
            } else {
                // Suppose an atom a(x) is inlined and has the following rules:
                // - a(x) :- a11(x), a12(x).
                // - a(x) :- a21(x), a22(x).
                // Then, a(x) <- (a11(x) ^ a12(x)) v (a21(x) ^ a22(x))
                // => !a(x) <- (!a11(x) v !a12(x)) ^ (!a21(x) v !a22(x))
                // => !a(x) <- (!a11(x) ^ !a21(x)) v (!a11(x) ^ !a22(x)) v ...
                // Essentially, produce every combination (m_1 ^ m_2 ^ ...) where m_i is a
                // negated literal in the ith rule of a.
                // NOTE(review): this branch is also reached when `atom` is not inlined itself but
                // one of its arguments is. formNegatedLiterals would then still expand the clauses
                // of the atom's own relation. Confirm this case cannot occur.
                addedBodyLiterals = formNegatedLiterals(program, atom);
            }
        }
        if (atomVersions.isValid()) {
            for (const auto& curVec : atomVersions.getVector()) {
                for (auto* cur : curVec) {
                    delete cur;
                }
            }
        }
    } else if (auto* constraint = dynamic_cast<BinaryConstraint*>(lit)) {
        // Try the left-hand side first; only fall back to the right-hand side if it is unchanged
        NullableVector<Argument*> lhsVersions = getInlinedArgument(program, constraint->getLHS());
        if (lhsVersions.isValid()) {
            changed = true;
            for (Argument* newLhs : lhsVersions.getVector()) {
                Literal* newLit = new BinaryConstraint(constraint->getBaseOperator(), Own<Argument>(newLhs),
                        souffle::clone(constraint->getRHS()));
                versions.push_back(newLit);
            }
        } else {
            NullableVector<Argument*> rhsVersions = getInlinedArgument(program, constraint->getRHS());
            if (rhsVersions.isValid()) {
                changed = true;
                for (Argument* newRhs : rhsVersions.getVector()) {
                    Literal* newLit = new BinaryConstraint(constraint->getBaseOperator(),
                            souffle::clone(constraint->getLHS()), Own<Argument>(newRhs));
                    versions.push_back(newLit);
                }
            }
        }
    }

    if (changed) {
        // Not inlined directly but found replacement literals
        // Rewrite these as single-literal bodies
        for (Literal* version : versions) {
            std::vector<Literal*> newBody;
            newBody.push_back(version);
            addedBodyLiterals.push_back(newBody);
        }
        inlined = true;
    }

    if (inlined) {
        return NullableVector<std::vector<Literal*>>(addedBodyLiterals);
    } else {
        return NullableVector<std::vector<Literal*>>();
    }
}
900 
901 /**
902  * Returns a list of clauses that should replace the given clause after one step of inlining.
903  * If no inlining can occur, the list will only contain a clone of the original clause.
904  */
905 std::vector<Clause*> getInlinedClause(Program& program, const Clause& clause) {
906  bool changed = false;
907  std::vector<Clause*> versions;
908 
909  // Try to inline things contained in the arguments of the head first.
910  // E.g. `a(x, max y : { b(y) }) :- c(x).`, where b should be inlined.
911  Atom* head = clause.getHead();
912  NullableVector<Atom*> headVersions = getInlinedAtom(program, *head);
913  if (headVersions.isValid()) {
914  // The head atom can be inlined!
915  changed = true;
916 
917  // Produce the new clauses with the replacement head atoms
918  for (Atom* newHead : headVersions.getVector()) {
919  auto* newClause = new Clause();
920  newClause->setSrcLoc(clause.getSrcLoc());
921 
922  newClause->setHead(Own<Atom>(newHead));
923 
924  // The body will remain unchanged
925  for (Literal* lit : clause.getBodyLiterals()) {
926  newClause->addToBody(souffle::clone(lit));
927  }
928 
929  versions.push_back(newClause);
930  }
931  }
932 
933  // Only perform one stage of inlining at a time.
934  // If the head atoms did not need inlining, try inlining atoms nested in the body.
935  if (!changed) {
936  std::vector<Literal*> bodyLiterals = clause.getBodyLiterals();
937  for (size_t i = 0; i < bodyLiterals.size(); i++) {
938  Literal* currLit = bodyLiterals[i];
939 
940  // Three possible cases when trying to inline a literal:
941  // 1) The literal itself may be directly inlined. In this case, the atom can be replaced
942  // with multiple different bodies, as the inlined atom may have several rules.
943  // 2) Otherwise, the literal itself may not need to be inlined, but a subnode (e.g. an argument)
944  // may need to be inlined. In this case, an altered literal must replace the original.
945  // Again, several possible versions may exist, as the inlined relation may have several rules.
946  // 3) The literal does not depend on any inlined relations, and so does not need to be changed.
947  NullableVector<std::vector<Literal*>> litVersions = getInlinedLiteral(program, currLit);
948 
949  if (litVersions.isValid()) {
950  // Case 1 and 2: Inlining has occurred!
951  changed = true;
952 
953  // The literal may be replaced with several different bodies.
954  // Create a new clause for each possible version.
955  std::vector<std::vector<Literal*>> bodyVersions = litVersions.getVector();
956 
957  // Create the base clause with the current literal removed
958  auto baseClause = Own<Clause>(cloneHead(&clause));
959  for (Literal* oldLit : bodyLiterals) {
960  if (currLit != oldLit) {
961  baseClause->addToBody(souffle::clone(oldLit));
962  }
963  }
964 
965  for (std::vector<Literal*> body : bodyVersions) {
966  Clause* replacementClause = baseClause->clone();
967 
968  // Add in the current set of literals replacing the inlined literal
969  // In Case 2, each body contains exactly one literal
970  for (Literal* newLit : body) {
971  replacementClause->addToBody(Own<Literal>(newLit));
972  }
973 
974  versions.push_back(replacementClause);
975  }
976  }
977 
978  // Only replace at most one literal per iteration
979  if (changed) {
980  break;
981  }
982  }
983  }
984 
985  if (!changed) {
986  // Case 3: No inlining changes, so a clone of the original should be returned
987  std::vector<Clause*> ret;
988  ret.push_back(clause.clone());
989  return ret;
990  } else {
991  // Inlining changes, so return the replacement clauses.
992  return versions;
993  }
994 }
995 
996 bool InlineRelationsTransformer::transform(TranslationUnit& translationUnit) {
997  bool changed = false;
998  Program& program = translationUnit.getProgram();
999 
1000  // Replace constants in the head of inlined clauses with (constrained) variables.
1001  // This is done to simplify atom unification, particularly when negations are involved.
1002  changed |= normaliseInlinedHeads(program);
1003 
1004  // Remove underscores in inlined atoms in the program to avoid issues during atom unification
1005  changed |= nameInlinedUnderscores(program);
1006 
1007  // Keep trying to inline things until we reach a fixed point.
1008  // Since we know there are no cyclic dependencies between inlined relations, this will necessarily
1009  // terminate.
1010  bool clausesChanged = true;
1011  while (clausesChanged) {
1012  std::set<Clause*> clausesToDelete;
1013  clausesChanged = false;
1014 
1015  // Go through each relation in the program and check if we need to inline any of its clauses
1016  for (Relation* rel : program.getRelations()) {
1017  // Skip if the relation is going to be inlined
1018  if (rel->hasQualifier(RelationQualifier::INLINE)) {
1019  continue;
1020  }
1021 
1022  // Go through the relation's clauses and try inlining them
1023  for (Clause* clause : getClauses(program, *rel)) {
1024  if (containsInlinedAtom(program, *clause)) {
1025  // Generate the inlined versions of this clause - the clause will be replaced by these
1026  std::vector<Clause*> newClauses = getInlinedClause(program, *clause);
1027 
1028  // Replace the clause with these equivalent versions
1029  clausesToDelete.insert(clause);
1030  for (Clause* replacementClause : newClauses) {
1031  program.addClause(Own<Clause>(replacementClause));
1032  }
1033 
1034  // We've changed the program this iteration
1035  clausesChanged = true;
1036  changed = true;
1037  }
1038  }
1039  }
1040 
1041  // Delete all clauses that were replaced
1042  for (const Clause* clause : clausesToDelete) {
1043  program.removeClause(clause);
1044  changed = true;
1045  }
1046  }
1047 
1048  return changed;
1049 }
1050 
1051 } // namespace souffle::ast::transform
souffle::ast::cloneHead
Clause * cloneHead(const Clause *clause)
Returns a clause which contains the head of the given clause.
Definition: Utils.cpp:254
souffle::AggregateOp::MIN
@ MIN
BinaryConstraintOps.h
souffle::ast::transform::normaliseInlinedHeads
bool normaliseInlinedHeads(Program &program)
Replace constants in the head of inlined clauses with (constrained) variables.
Definition: InlineRelations.cpp:93
souffle::AggregateOp::USUM
@ USUM
souffle::ast::transform::NullableVector
Definition: InlineRelations.cpp:69
UNREACHABLE_BAD_CASE_ANALYSIS
#define UNREACHABLE_BAD_CASE_ANALYSIS
Definition: MiscUtil.h:206
TranslationUnit.h
Functor.h
souffle::ast::transform::renameVariables
void renameVariables(Argument *arg)
Renames all variables in a given argument uniquely.
Definition: InlineRelations.cpp:494
souffle::ast::BooleanConstraint
Boolean constraint class.
Definition: BooleanConstraint.h:45
UnnamedVariable.h
souffle::ast::transform::reduceSubstitution
bool reduceSubstitution(std::vector< std::pair< Argument *, Argument * >> &sub)
Reduces a vector of substitutions.
Definition: InlineRelations.cpp:219
souffle::AggregateOp::FSUM
@ FSUM
InlineRelations.h
souffle::ast::transform::NullableVector::valid
bool valid
Definition: InlineRelations.cpp:78
AggregateOp.h
souffle::ast::analysis::sub
std::shared_ptr< Constraint< Var > > sub(const Var &a, const Var &b, const std::string &symbol="⊑")
A generic factory for constraints of the form.
Definition: ConstraintSystem.h:228
souffle::ast::transform::nameInlinedUnderscores
bool nameInlinedUnderscores(Program &program)
Removes all underscores in all atoms of inlined relations.
Definition: InlineRelations.cpp:146
souffle::AggregateOp
AggregateOp
Types of aggregation functions.
Definition: AggregateOp.h:34
souffle::ast::NodeMapper
An abstract class for manipulating AST Nodes by substitution.
Definition: NodeMapper.h:36
souffle::ast::transform::InlineRelationsTransformer::transform
bool transform(TranslationUnit &translationUnit) override
Definition: InlineRelations.cpp:1002
souffle::Own
std::unique_ptr< A > Own
Definition: ContainerUtil.h:42
souffle::BinaryConstraintOp::EQ
@ EQ
MiscUtil.h
souffle::ast::Relation
Defines a relation with a name, attributes, qualifiers, and internal representation.
Definition: Relation.h:49
souffle::ast::Clause
Intermediate representation of a horn clause.
Definition: Clause.h:51
Constraint.h
BooleanConstraint.h
souffle::ast::analysis::Variable
A variable to be utilized within constraints to be handled by the constraint solver.
Definition: ConstraintSystem.h:41
Relation.h
IntrinsicFunctor.h
rhs
Own< Argument > rhs
Definition: ResolveAliases.cpp:185
souffle::ast::Atom
An atom class.
Definition: Atom.h:51
j
var j
Definition: htmlJsChartistMin.h:15
Utils.h
NodeMapper.h
Constant.h
souffle::ast::Argument
An abstract class for arguments.
Definition: Argument.h:33
souffle::ast::analysis::Constraint
A generic base class for constraints on variables.
Definition: ConstraintSystem.h:44
UserDefinedFunctor.h
lhs
Own< Argument > lhs
Definition: ResolveAliases.cpp:184
souffle::AggregateOp::UMIN
@ UMIN
Argument.h
souffle::ast::transform::containsInlinedAtom
bool containsInlinedAtom(const Program &program, const Clause &clause)
Checks if a given clause contains an atom that should be inlined.
Definition: InlineRelations.cpp:202
souffle::ast::Program
The program class consists of relations, clauses and types.
Definition: Program.h:52
TypeCast.h
souffle::clone
auto clone(const std::vector< A * > &xs)
Definition: ContainerUtil.h:172
souffle::ast::transform::combineNegatedLiterals
std::vector< std::vector< Literal * > > combineNegatedLiterals(std::vector< std::vector< Literal * >> litGroups)
Return the negated version of a disjunction of conjunctions.
Definition: InlineRelations.cpp:385
i
size_t i
Definition: json11.h:663
souffle::ast::getRelation
Relation * getRelation(const Program &program, const QualifiedName &name)
Returns the relation with the given name in the program.
Definition: Utils.cpp:101
souffle::ast::transform::getInlinedClause
std::vector< Clause * > getInlinedClause(Program &program, const Clause &clause)
Returns a list of clauses that should replace the given clause after one step of inlining.
Definition: InlineRelations.cpp:911
souffle::ast::transform::getInlinedAtom
NullableVector< Atom * > getInlinedAtom(Program &program, Atom &atom)
Returns a vector of atoms that should replace the given atom after one step of inlining.
Definition: InlineRelations.cpp:729
souffle::ast::transform::NullableVector::vector
std::vector< T > vector
Definition: InlineRelations.cpp:77
RelationTag.h
souffle::AggregateOp::MAX
@ MAX
souffle::ast::Negation
Negation of an atom (a negated atom).
Definition: Negation.h:50
souffle::ast::IntrinsicFunctor
Intrinsic functor class for built-in functors.
Definition: IntrinsicFunctor.h:47
Atom.h
souffle::AggregateOp::UMAX
@ UMAX
souffle::AggregateOp::FMIN
@ FMIN
Literal.h
PolymorphicObjects.h
souffle::ast::transform::negateLiteral
Literal * negateLiteral(Literal *lit)
Returns the negated version of a given literal.
Definition: InlineRelations.cpp:365
souffle::ast::transform::getInlinedLiteral
NullableVector< std::vector< Literal * > > getInlinedLiteral(Program &, Literal *)
Tries to perform a single step of inlining on the given literal.
Definition: InlineRelations.cpp:785
souffle::ast::Program::addClause
void addClause(Own< Clause > clause)
Add a clause.
Definition: Program.cpp:62
souffle::AggregateOp::FMAX
@ FMAX
souffle::ast::Literal
Defines an abstract class for literals in a horn clause.
Definition: Literal.h:36
souffle::ast::transform
Definition: Program.h:45
souffle::ast::Program::getRelations
std::vector< Relation * > getRelations() const
Return relations.
Definition: Program.h:60
Negation.h
souffle::ast::transform::NullableVector::NullableVector
NullableVector()=default
souffle::ast::transform::NullableVector::isValid
bool isValid() const
Definition: InlineRelations.cpp:84
souffle::ast::getClauses
std::vector< Clause * > getClauses(const Program &program, const QualifiedName &relationName)
Returns a vector of clauses in the program describing the relation with the given name.
Definition: Utils.cpp:77
souffle::ast::Functor
Abstract functor class.
Definition: Functor.h:36
Node.h
souffle::ast::BinaryConstraint
Binary constraint class.
Definition: BinaryConstraint.h:53
souffle::AggregateOp::COUNT
@ COUNT
souffle::AggregateOp::SUM
@ SUM
Aggregator.h
souffle::ast::Aggregator
Defines the aggregator class.
Definition: Aggregator.h:53
souffle::AggregateOp::MEAN
@ MEAN
QualifiedName.h
std
Definition: Brie.h:3053
Program.h
souffle::ast::transform::NullableVector::getVector
const std::vector< T > & getVector() const
Definition: InlineRelations.cpp:88
souffle::ast::transform::unifyAtoms
NullableVector< std::pair< Argument *, Argument * > > unifyAtoms(Atom *first, Atom *second)
Returns the nullable vector of substitutions needed to unify the two given atoms.
Definition: InlineRelations.cpp:278
souffle::RelationQualifier::INLINE
@ INLINE
souffle::fatal
void fatal(const char *format, const Args &... args)
Definition: MiscUtil.h:198
souffle::ast::negateConstraintInPlace
void negateConstraintInPlace(Constraint &constraint)
Negate an ast constraint.
Definition: Utils.cpp:297
Visitor.h
Clause.h
BinaryConstraint.h
souffle::ast::transform::formNegatedLiterals
std::vector< std::vector< Literal * > > formNegatedLiterals(Program &program, Atom *atom)
Forms the bodies that will replace the negation of a given inlined atom.
Definition: InlineRelations.cpp:440
souffle::ast::visitDepthFirst
void visitDepthFirst(const Node &root, Visitor< R, Ps... > &visitor, Args &... args)
A utility function visiting all nodes within the ast rooted by the given node recursively in a depth-...
Definition: Visitor.h:273
souffle::ast::transform::inlineBodyLiterals
std::pair< NullableVector< Literal * >, std::vector< BinaryConstraint * > > inlineBodyLiterals(Atom *atom, Clause *atomInlineClause)
Inlines the given atom based on a given clause.
Definition: InlineRelations.cpp:304
rel
void rel(size_t limit, bool showLimit=true)
Definition: Tui.h:1086
souffle::ast::Program::removeClause
bool removeClause(const Clause *clause)
Remove a clause.
Definition: Program.cpp:68
souffle::VecOwn
std::vector< Own< A > > VecOwn
Definition: ContainerUtil.h:45
RecordInit.h
souffle::ast::RecordInit
Defines a record initialization class.
Definition: RecordInit.h:42
NumericConstant.h
souffle::ast::transform::combineAggregators
Argument * combineAggregators(std::vector< Aggregator * > aggrs, std::string fun)
Definition: InlineRelations.cpp:521
Variable.h
souffle::ast::transform::getInlinedArgument
NullableVector< Argument * > getInlinedArgument(Program &program, const Argument *arg)
Returns a vector of arguments that should replace the given argument after one step of inlining.
Definition: InlineRelations.cpp:543