51 std::set<std::string> allVariablesInAggregate;
55 std::set<std::string> localVariables;
56 for (
const std::string& name : allVariablesInAggregate) {
57 if (injectedVariables.find(name) == injectedVariables.end() &&
58 witnessVariables.find(name) == witnessVariables.end()) {
59 localVariables.insert(name);
62 return localVariables;
70 const TranslationUnit& tu,
const Clause& clause,
const Aggregator& aggregate) {
74 struct M :
public NodeMapper {
76 mutable std::set<std::string> aggregatorVariables;
78 const std::set<std::string>& getAggregatorVariables() {
79 return aggregatorVariables;
82 std::unique_ptr<Node> operator()(std::unique_ptr<Node> node)
const override {
83 static int numReplaced = 0;
84 if (
dynamic_cast<Aggregator*
>(node.get()) !=
nullptr) {
86 std::stringstream newVariableName;
87 newVariableName <<
"+aggr_var_" << numReplaced++;
90 aggregatorVariables.insert(newVariableName.str());
92 return std::make_unique<Variable>(newVariableName.str());
99 auto aggregatorlessClause = std::make_unique<Clause>();
100 aggregatorlessClause->setHead(std::make_unique<Atom>(
"*"));
101 for (Literal* lit : clause.getBodyLiterals()) {
105 auto negatedHead = std::make_unique<Negation>(
souffle::clone(clause.getHead()));
106 aggregatorlessClause->addToBody(std::move(negatedHead));
110 aggregatorlessClause->apply(update);
111 auto groundingAtom = std::make_unique<Atom>(
"+grounding_atom");
112 for (std::string variableName : update.getAggregatorVariables()) {
113 groundingAtom->addArgument(std::make_unique<Variable>(variableName));
115 aggregatorlessClause->addToBody(std::move(groundingAtom));
118 auto aggregateSubclause = std::make_unique<Clause>();
119 aggregateSubclause->setHead(mk<Atom>(
"*"));
120 for (
const auto& lit : aggregate.getBodyLiterals()) {
124 std::set<std::string> witnessVariables;
129 if (
const auto* variable =
dynamic_cast<const Variable*
>(argPair.first)) {
130 bool variableIsGrounded = argPair.second;
131 if (!variableIsGrounded) {
134 for (
const auto& aggArgPair : isGroundedInAggregateSubclause) {
135 if (
const auto* var =
dynamic_cast<const Variable*
>(aggArgPair.first)) {
136 bool aggVariableIsGrounded = aggArgPair.second;
137 if (var->getName() == variable->getName() && aggVariableIsGrounded) {
138 witnessVariables.insert(variable->getName());
149 for (
const std::string& injected : injectedVariables) {
150 witnessVariables.erase(injected);
153 return witnessVariables;
159 std::map<std::string, int> variableOccurrences;
160 visitDepthFirst(clause, [&](
const Variable& var) { variableOccurrences[var.getName()]++; });
161 visitDepthFirst(aggregate, [&](
const Variable& var) { variableOccurrences[var.getName()]--; });
162 std::set<std::string> variablesOutsideAggregate;
163 for (
auto const& pair : variableOccurrences) {
164 std::string v = pair.first;
165 int numOccurrences = pair.second;
166 if (numOccurrences > 0) {
167 variablesOutsideAggregate.insert(v);
170 return variablesOutsideAggregate;
174 std::set<std::string> variablesInClause;
175 visitDepthFirst(clause, [&](
const Variable& v) { variablesInClause.insert(v.getName()); });
177 std::string candidate =
base;
178 while (variablesInClause.find(candidate) != variablesInClause.end()) {
186 auto candidate =
base;
187 while (
getRelation(program, candidate) !=
nullptr) {
263 std::set<std::string> variablesInTargetAggregate;
265 [&](
const Variable& variable) { variablesInTargetAggregate.insert(variable.getName()); });
267 std::set<Own<Aggregator>> ancestorAggregates;
271 if (agg == aggregate) {
278 struct ReplaceAggregatesWithVariables :
public NodeMapper {
280 mutable std::set<std::string> aggregatorVariables;
284 std::set<Own<Aggregator>> ancestors;
286 Own<Aggregator> targetAggregate;
288 const std::set<std::string>& getAggregatorVariables() {
289 return aggregatorVariables;
292 ReplaceAggregatesWithVariables(std::set<Own<Aggregator>> ancestors, Own<Aggregator> targetAggregate)
293 : ancestors(
std::move(ancestors)), targetAggregate(
std::move(targetAggregate)) {}
295 std::unique_ptr<Node> operator()(std::unique_ptr<Node> node)
const override {
296 static int numReplaced = 0;
297 if (
auto* aggregate =
dynamic_cast<Aggregator*
>(node.get())) {
301 bool isAncestor =
false;
302 for (
auto& ancestor : ancestors) {
303 if (*ancestor == *aggregate) {
308 if (!isAncestor || *aggregate == *targetAggregate) {
310 std::stringstream newVariableName;
311 newVariableName <<
"+aggr_var_" << numReplaced++;
318 aggregatorVariables.insert(newVariableName.str());
322 return mk<Variable>(newVariableName.str());
331 auto tweakedClause = mk<Clause>();
333 tweakedClause->setHead(mk<Atom>(
"*"));
341 ReplaceAggregatesWithVariables update(std::move(ancestorAggregates),
souffle::clone(&aggregate));
342 tweakedClause->apply(update);
344 auto groundingAtom = mk<Atom>(
"+grounding_atom");
345 for (std::string variableName : update.getAggregatorVariables()) {
346 groundingAtom->addArgument(mk<Variable>(variableName));
349 tweakedClause->addToBody(std::move(groundingAtom));
351 std::set<std::string> injectedVariables;
354 if (
const auto* variable =
dynamic_cast<const Variable*
>(argPair.first)) {
355 bool varIsGrounded = argPair.second;
356 if (varIsGrounded && variablesInTargetAggregate.find(variable->getName()) !=
357 variablesInTargetAggregate.end()) {
359 injectedVariables.insert(variable->getName());
366 [&](
const Variable& v) { injectedVariables.erase(v.getName()); });
369 return injectedVariables;