souffle  2.0.2-371-g6315b36
MaterializeAggregationQueries.cpp
Go to the documentation of this file.
1 /*
2  * Souffle - A Datalog Compiler
3  * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved
4  * Licensed under the Universal Permissive License v 1.0 as shown at:
5  * - https://opensource.org/licenses/UPL
6  * - <souffle root>/licenses/SOUFFLE-UPL.txt
7  */
8 
9 /************************************************************************
10  *
11  * @file MaterializeAggregationQueries.cpp
12  *
13  ***********************************************************************/
14 
16 #include "AggregateOp.h"
17 #include "ast/Aggregator.h"
18 #include "ast/Argument.h"
19 #include "ast/Atom.h"
20 #include "ast/Attribute.h"
21 #include "ast/Clause.h"
22 #include "ast/Literal.h"
23 #include "ast/Node.h"
24 #include "ast/Program.h"
25 #include "ast/QualifiedName.h"
26 #include "ast/Relation.h"
27 #include "ast/TranslationUnit.h"
28 #include "ast/UnnamedVariable.h"
29 #include "ast/Variable.h"
30 #include "ast/analysis/Aggregate.h"
31 #include "ast/analysis/Ground.h"
32 #include "ast/analysis/Type.h"
35 #include "ast/utility/Utils.h"
36 #include "ast/utility/Visitor.h"
37 #include "souffle/TypeAttribute.h"
40 #include <algorithm>
41 #include <map>
42 #include <memory>
43 #include <set>
44 #include <utility>
45 #include <vector>
46 
47 namespace souffle::ast::transform {
48 
50  // I should not be fiddling with aggregates that are in the aggregate clause.
51  // We can short circuit if we find an aggregate node.
52  struct InstantiateUnnamedVariables : public NodeMapper {
53  mutable int count = 0;
54  Own<Node> operator()(Own<Node> node) const override {
55  if (isA<UnnamedVariable>(node.get())) {
56  return mk<Variable>("_" + toString(count++));
57  }
58  if (isA<Aggregator>(node.get())) {
59  // then DON'T recurse
60  return node;
61  }
62  node->apply(*this);
63  return node;
64  }
65  };
66 
67  InstantiateUnnamedVariables update;
68  for (const auto& lit : aggClause.getBodyLiterals()) {
69  lit->apply(update);
70  }
71 }
72 
74  const TranslationUnit& tu, const Clause& clause, const Aggregator& aggregate) {
75  /**
76  * The head atom should contain immediate local and injected variables.
77  * No witnesses! They have already been transformed away.
78  * This means that we exclude any inner aggregate local variables. But
79  * we do NOT exclude inner aggregate injected variables!! It's important
80  * that the injected variable ends up in this head so that we do not obfuscate
81  * the injected variable's relationship to the outer scope.
82  * It does not affect the aggregate value because adding an extra column
83  * for the injected variable, where that column will only have one value at a time,
84  * will essentially replicate the aggregate body relation for as many possible
85  * values of the injected variable that there are. The fact that the injected variable
86  * will take one value at a time is key.
87  **/
88  std::set<std::string> headArguments;
89  // find local variables of this aggregate and add them
90  for (const auto& localVarName : analysis::getLocalVariables(tu, clause, aggregate)) {
91  headArguments.insert(localVarName);
92  }
93  // find local variables of inner aggregate and remove them
94  visitDepthFirst(aggregate, [&](const Aggregator& innerAggregate) {
95  if (aggregate == innerAggregate) {
96  return;
97  }
98  for (const auto& innerLocalVariableName : analysis::getLocalVariables(tu, clause, innerAggregate)) {
99  headArguments.erase(innerLocalVariableName);
100  }
101  });
102  // find injected variables of this aggregate and add them
103  for (const auto& injectedVarName : analysis::getInjectedVariables(tu, clause, aggregate)) {
104  headArguments.insert(injectedVarName);
105  }
106  return headArguments;
107 }
108 
109 // TODO (Issue 1696): Deal with recursive parameters with an assert statement.
111  const TranslationUnit& translationUnit, Clause& aggClause, const Clause& originalClause,
112  const Aggregator& aggregate) {
113  /**
114  * Mask inner aggregates to make sure we don't consider them grounded and everything.
115  **/
116  struct NegateAggregateAtoms : public NodeMapper {
117  std::unique_ptr<Node> operator()(std::unique_ptr<Node> node) const override {
118  if (auto* aggregate = dynamic_cast<Aggregator*>(node.get())) {
119  /**
120  * Go through body literals. If the literal is an atom,
121  * then replace the atom with a negated version of the atom, so that
122  * injected parameters that occur in an inner aggregate don't "seem" grounded.
123  **/
124  std::vector<Own<Literal>> newBody;
125  for (const auto& lit : aggregate->getBodyLiterals()) {
126  if (auto* atom = dynamic_cast<Atom*>(lit)) {
127  newBody.push_back(mk<Negation>(souffle::clone(atom)));
128  }
129  }
130  aggregate->setBody(std::move(newBody));
131  }
132  node->apply(*this);
133  return node;
134  }
135  };
136 
137  auto aggClauseInnerAggregatesMasked = souffle::clone(&aggClause);
138  aggClauseInnerAggregatesMasked->setHead(mk<Atom>("*"));
139  NegateAggregateAtoms update;
140  aggClauseInnerAggregatesMasked->apply(update);
141 
142  // what is the set of injected variables? Those are the ones we need to ground.
143  std::set<std::string> injectedVariables =
144  analysis::getInjectedVariables(translationUnit, originalClause, aggregate);
145 
146  std::set<std::string> alreadyGrounded;
147  for (const auto& argPair : analysis::getGroundedTerms(translationUnit, *aggClauseInnerAggregatesMasked)) {
148  const auto* variable = dynamic_cast<const ast::Variable*>(argPair.first);
149  bool variableIsGrounded = argPair.second;
150  if (variable == nullptr || variableIsGrounded) {
151  continue;
152  }
153  // If it's not an injected variable, we don't need to ground it
154  if (injectedVariables.find(variable->getName()) == injectedVariables.end()) {
155  continue;
156  }
157 
158  std::string ungroundedVariableName = variable->getName();
159  if (alreadyGrounded.find(ungroundedVariableName) != alreadyGrounded.end()) {
160  // may as well not bother with it because it has already
161  // been grounded in a previous iteration
162  continue;
163  }
164  // Try to find any atom in the rule where this ungrounded variable is mentioned
165  for (const auto& lit : originalClause.getBodyLiterals()) {
166  // -1. This may not be the same literal
167  bool originalAggregateFound = false;
168  visitDepthFirst(*lit, [&](const Aggregator& a) {
169  if (a == aggregate) {
170  originalAggregateFound = true;
171  return;
172  }
173  });
174  if (originalAggregateFound) {
175  continue;
176  }
177  // 0. Variable must not already have been grounded
178  if (alreadyGrounded.find(ungroundedVariableName) != alreadyGrounded.end()) {
179  continue;
180  }
181  // 1. Variable must occur in this literal
182  bool variableOccursInLit = false;
183  visitDepthFirst(*lit, [&](const Variable& var) {
184  if (var.getName() == ungroundedVariableName) {
185  variableOccursInLit = true;
186  }
187  });
188  if (!variableOccursInLit) {
189  continue;
190  }
191  // 2. Variable must be grounded by this literal.
192  auto singleLiteralClause = mk<Clause>();
193  singleLiteralClause->addToBody(souffle::clone(lit));
194  bool variableGroundedByLiteral = false;
195  for (const auto& argPair : analysis::getGroundedTerms(translationUnit, *singleLiteralClause)) {
196  const auto* var = dynamic_cast<const ast::Variable*>(argPair.first);
197  if (var == nullptr) {
198  continue;
199  }
200  bool isGrounded = argPair.second;
201  if (var->getName() == ungroundedVariableName && isGrounded) {
202  variableGroundedByLiteral = true;
203  }
204  }
205  if (!variableGroundedByLiteral) {
206  continue;
207  }
208  // 3. if it's an atom:
209  // the relation must be of a lower stratum for us to be able to add it. (not implemented)
210  // sanitise the atom by removing any unnecessary arguments that aren't constants
211  // or basically just any other variables
212  if (const auto* atom = dynamic_cast<const Atom*>(lit)) {
213  // Right now we only allow things to be grounded by atoms.
214  // This is limiting but the case of it being grounded by
215  // something else becomes complicated VERY quickly.
216  // It may involve pulling in a cascading series of literals like
217  // x = y, y = 4. It just seems very painful.
218  // remove other unnecessary bloating arguments and replace with an underscore
219  VecOwn<Argument> arguments;
220  for (auto arg : atom->getArguments()) {
221  if (auto* var = dynamic_cast<ast::Variable*>(arg)) {
222  if (var->getName() == ungroundedVariableName) {
223  arguments.emplace_back(arg->clone());
224  continue;
225  }
226  }
227  arguments.emplace_back(new UnnamedVariable());
228  }
229 
230  auto groundingAtom =
231  mk<Atom>(atom->getQualifiedName(), std::move(arguments), atom->getSrcLoc());
232  aggClause.addToBody(souffle::clone(groundingAtom));
233  alreadyGrounded.insert(ungroundedVariableName);
234  }
235  }
236  assert(alreadyGrounded.find(ungroundedVariableName) != alreadyGrounded.end() &&
237  "Error: Unable to ground parameter in materialisation-requiring aggregate body");
238  // after this loop, we should have added at least one thing to provide a grounding.
239  // If not, we should error out. The program will not be able to run.
240  // We have an ungrounded variable that we cannot ground once the aggregate body is
241  // outlined.
242  }
243 }
244 
246  TranslationUnit& translationUnit) {
247  bool changed = false;
248  Program& program = translationUnit.getProgram();
249  /**
250  * GENERAL PROCEDURE FOR MATERIALISING THE BODY OF AN AGGREGATE:
251  * NB:
252  * * Only bodies with more than one atom or an inner aggregate need to be materialised.
253  * * Ignore inner aggregates (they will be unwound in subsequent applications of this transformer)
254  *
255  * * Copy aggregate body literals into a new clause
256  * * Pull in grounding atoms
257  *
258  * * Set up the head: This will include any local and injected variables in the body.
259  * * Instantiate unnamed variables in count operation (idk why but it's fine)
260  *
261  **/
262  std::set<const Aggregator*> innerAggregates;
263  visitDepthFirst(program, [&](const Aggregator& agg) {
264  visitDepthFirst(agg, [&](const Aggregator& innerAgg) {
265  if (agg != innerAgg) {
266  innerAggregates.insert(&innerAgg);
267  }
268  });
269  });
270 
271  visitDepthFirst(program, [&](const Clause& clause) {
272  visitDepthFirst(clause, [&](const Aggregator& agg) {
273  if (!needsMaterializedRelation(agg)) {
274  return;
275  }
276  // only materialise bottom level aggregates
277  if (innerAggregates.find(&agg) != innerAggregates.end()) {
278  return;
279  }
280  // begin materialisation process
281  auto aggregateBodyRelationName = analysis::findUniqueRelationName(program, "__agg_subclause");
282  auto aggClause = mk<Clause>();
283  // quickly copy in all the literals from the aggregate body
284  for (const auto& lit : agg.getBodyLiterals()) {
285  aggClause->addToBody(souffle::clone(lit));
286  }
287  if (agg.getBaseOperator() == AggregateOp::COUNT) {
288  instantiateUnnamedVariables(*aggClause);
289  }
290  // pull in any necessary grounding atoms
291  groundInjectedParameters(translationUnit, *aggClause, clause, agg);
292  // the head must contain all injected/local variables, but not variables
293  // local to any inner aggregates. So we'll just take a set minus here.
294  // auto aggClauseHead = mk<Atom>(aggregateBodyRelationName);
295  auto* aggClauseHead = new Atom(aggregateBodyRelationName);
296  std::set<std::string> headArguments = distinguishHeadArguments(translationUnit, clause, agg);
297  // insert the head arguments into the head atom
298  for (const auto& variableName : headArguments) {
299  aggClauseHead->addArgument(mk<Variable>(variableName));
300  }
301  aggClause->setHead(Own<Atom>(aggClauseHead));
302  // add them to the relation as well (need to do a bit of type analysis to make this work)
303  auto aggRel = mk<Relation>(aggregateBodyRelationName);
304  std::map<const Argument*, analysis::TypeSet> argTypes =
305  analysis::TypeAnalysis::analyseTypes(translationUnit, *aggClause);
306 
307  for (const auto& cur : aggClauseHead->getArguments()) {
308  // cur will point us to a particular argument
309  // that is found in the aggClause
310  aggRel->addAttribute(mk<Attribute>(toString(*cur),
311  (analysis::isOfKind(argTypes[cur], TypeAttribute::Signed)) ? "number" : "symbol"));
312  }
313  // Set up the aggregate body atom that will represent the materialised relation we just created
314  // and slip in place of the unrestricted literal(s) body.
315  // Now it's time to update the aggregate body atom. We can now
316  // replace the complex body (with literals) with a body with just the single atom referring
317  // to the new relation we just created.
318  // all local variables will be replaced by an underscore
319  // so we should just quickly fetch the set of local variables for this aggregate.
320  auto localVariables = analysis::getLocalVariables(translationUnit, clause, agg);
321  if (agg.getTargetExpression() != nullptr) {
322  const auto* targetExpressionVariable =
323  dynamic_cast<const Variable*>(agg.getTargetExpression());
324  localVariables.erase(targetExpressionVariable->getName());
325  }
326  VecOwn<Argument> args;
327  for (auto arg : aggClauseHead->getArguments()) {
328  if (auto* var = dynamic_cast<ast::Variable*>(arg)) {
329  // replace local variable by underscore if local, only injected or
330  // target variables will appear
331  if (localVariables.find(var->getName()) != localVariables.end()) {
332  args.emplace_back(new UnnamedVariable());
333  continue;
334  }
335  }
336  args.emplace_back(arg->clone());
337  }
338  auto aggAtom =
339  mk<Atom>(aggClauseHead->getQualifiedName(), std::move(args), aggClauseHead->getSrcLoc());
340 
341  VecOwn<Literal> newBody;
342  newBody.push_back(std::move(aggAtom));
343  const_cast<Aggregator&>(agg).setBody(std::move(newBody));
344  // Now we can just add these new things (relation and its single clause) to the program
345  program.addClause(std::move(aggClause));
346  program.addRelation(std::move(aggRel));
347  changed = true;
348  });
349  });
350  return changed;
351 }
352 
354  // everything with more than 1 atom => materialize
355  int countAtoms = 0;
356  const Atom* atom = nullptr;
357  for (const auto& literal : agg.getBodyLiterals()) {
358  const Atom* currentAtom = dynamic_cast<const Atom*>(literal);
359  if (currentAtom != nullptr) {
360  ++countAtoms;
361  atom = currentAtom;
362  }
363  }
364 
365  if (countAtoms > 1) {
366  return true;
367  }
368 
369  bool seenInnerAggregate = false;
370  // If we have an aggregate within this aggregate => materialize
371  visitDepthFirst(agg, [&](const Aggregator& innerAgg) {
372  if (agg != innerAgg) {
373  seenInnerAggregate = true;
374  }
375  });
376 
377  if (seenInnerAggregate) {
378  return true;
379  }
380 
381  // If the same variable occurs several times => materialize
382  bool duplicates = false;
383  std::set<std::string> vars;
384  if (atom != nullptr) {
385  visitDepthFirst(*atom, [&](const ast::Variable& var) {
386  duplicates = duplicates || !vars.insert(var.getName()).second;
387  });
388  }
389 
390  // If there are duplicates a materialization is required
391  // for all others the materialization can be skipped
392  return duplicates;
393 }
394 
395 } // namespace souffle::ast::transform
souffle::ast::analysis::isOfKind
bool isOfKind(const Type &type, TypeAttribute kind)
Check if the type is of a kind corresponding to the TypeAttribute.
Definition: TypeSystem.cpp:189
souffle::ast::transform::MaterializeAggregationQueriesTransformer::groundInjectedParameters
static void groundInjectedParameters(const TranslationUnit &translationUnit, Clause &aggClause, const Clause &originalClause, const Aggregator &aggregate)
Modify the aggClause by adding in grounding literals for every variable that appears in the clause un...
Definition: MaterializeAggregationQueries.cpp:114
TranslationUnit.h
souffle::ast::analysis::getLocalVariables
std::set< std::string > getLocalVariables(const TranslationUnit &tu, const Clause &clause, const Aggregator &aggregate)
Computes the set of local variables in an aggregate expression.
Definition: Aggregate.cpp:55
UnnamedVariable.h
AggregateOp.h
LambdaNodeMapper.h
souffle::ast::analysis::TypeAnalysis::analyseTypes
static std::map< const Argument *, TypeSet > analyseTypes(const TranslationUnit &tu, const Clause &clause, std::ostream *logs=nullptr)
Analyse the given clause and computes for each contained argument a set of potential types.
Definition: Type.cpp:131
souffle::ast::NodeMapper
An abstract class for manipulating AST Nodes by substitution.
Definition: NodeMapper.h:36
souffle::Own
std::unique_ptr< A > Own
Definition: ContainerUtil.h:42
MiscUtil.h
souffle::ast::Clause
Intermediate representation of a horn clause.
Definition: Clause.h:51
souffle::ast::analysis::Variable
A variable to be utilized within constraints to be handled by the constraint solver.
Definition: ConstraintSystem.h:41
Relation.h
Aggregate.h
souffle::ast::transform::MaterializeAggregationQueriesTransformer::needsMaterializedRelation
static bool needsMaterializedRelation(const Aggregator &agg)
A test determining whether the body of a given aggregation needs to be 'outlined' into an independent...
Definition: MaterializeAggregationQueries.cpp:357
souffle::ast::Atom
An atom class.
Definition: Atom.h:51
MaterializeAggregationQueries.h
souffle::ast::analysis::getGroundedTerms
std::map< const Argument *, bool > getGroundedTerms(const TranslationUnit &tu, const Clause &clause)
Analyse the given clause and computes for each contained argument whether it is a grounded value or n...
Definition: Ground.cpp:278
Utils.h
souffle::ast::transform::MaterializeAggregationQueriesTransformer::instantiateUnnamedVariables
static void instantiateUnnamedVariables(Clause &aggClause)
Whatever variables have been left unnamed have significance for a count aggregate.
Definition: MaterializeAggregationQueries.cpp:53
souffle::TypeAttribute::Signed
@ Signed
Attribute.h
souffle::toString
const std::string & toString(const std::string &str)
A generic function converting strings into strings (trivial case).
Definition: StringUtil.h:234
Argument.h
Ground.h
souffle::clone
auto clone(const std::vector< A * > &xs)
Definition: ContainerUtil.h:172
StringUtil.h
Atom.h
Literal.h
souffle::test::count
int count(const C &c)
Definition: table_test.cpp:40
souffle::ast::transform
Definition: Program.h:45
souffle::ast::transform::MaterializeAggregationQueriesTransformer::materializeAggregationQueries
static bool materializeAggregationQueries(TranslationUnit &translationUnit)
Creates artificial relations for bodies of aggregation functions consisting of more than a single ato...
Definition: MaterializeAggregationQueries.cpp:249
souffle::ast::analysis::findUniqueRelationName
std::string findUniqueRelationName(const Program &program, std::string base)
Find a new relation name.
Definition: Aggregate.cpp:190
Node.h
souffle::AggregateOp::COUNT
@ COUNT
Aggregator.h
souffle::ast::Aggregator::getBodyLiterals
std::vector< Literal * > getBodyLiterals() const
Return body literals.
Definition: Aggregator.h:71
souffle::ast::Aggregator
Defines the aggregator class.
Definition: Aggregator.h:53
QualifiedName.h
Program.h
souffle::ast::Aggregator::getBaseOperator
AggregateOp getBaseOperator() const
Return the (base type) operator of the aggregator.
Definition: Aggregator.h:61
souffle::ast::Aggregator::getTargetExpression
const Argument * getTargetExpression() const
Return target expression.
Definition: Aggregator.h:66
Visitor.h
Clause.h
Type.h
souffle::ast::visitDepthFirst
void visitDepthFirst(const Node &root, Visitor< R, Ps... > &visitor, Args &... args)
A utility function visiting all nodes within the ast rooted by the given node recursively in a depth-...
Definition: Visitor.h:273
souffle::ast::transform::MaterializeAggregationQueriesTransformer::distinguishHeadArguments
static std::set< std::string > distinguishHeadArguments(const TranslationUnit &tu, const Clause &clause, const Aggregator &aggregate)
When we materialise an aggregate subclause, it's a good question which variables belong in the head o...
Definition: MaterializeAggregationQueries.cpp:77
TypeAttribute.h
Variable.h
TypeSystem.h
souffle::ast::analysis::getInjectedVariables
std::set< std::string > getInjectedVariables(const TranslationUnit &tu, const Clause &clause, const Aggregator &aggregate)
Given an aggregate and a clause, we find all the variables that have been injected into the aggregate...
Definition: Aggregate.cpp:205