40 #include <type_traits>
46 using namespace stream_write_qualified_char_as_number;
48 class ExplainProvenanceImpl :
public ExplainProvenance {
/** Creates the explainer for `prog`; the program reference is forwarded to the ExplainProvenance base. */
52 ExplainProvenanceImpl(SouffleProgram& prog) : ExplainProvenance(prog) {
/**
 * Fills the `info` and `rules` lookup tables by scanning every relation of the program.
 *
 * Both tables are keyed by (relation name cut off at ".@info", rule number).
 */
56 void setup()
override {
58 for (
auto&
rel : prog.getAllRelations()) {
59 std::string name =
rel->getName();
// True for names that do NOT contain "@info"; presumably such relations are skipped -- TODO confirm
62 if (name.find(
"@info") == std::string::npos) {
// Gathered from the relation's columns and stored in `info` below
68 std::vector<std::string> bodyLiterals;
// Visits columns 1 .. arity-2, i.e. every column except the first and the last one
73 for (
size_t i = 1;
i + 1 <
rel->getArity();
i++) {
76 bodyLiterals.push_back(bodyLit);
// Key: relation name up to (not including) ".@info", paired with the rule number
82 info.insert({std::make_pair(name.substr(0, name.find(
".@info")), ruleNum), bodyLiterals});
// Same key, mapped to the rule text
83 rules.insert({std::make_pair(name.substr(0, name.find(
".@info")), ruleNum),
rule});
/**
 * Builds the proof tree for one tuple of `relName`.
 *
 * @param relName    relation the tuple belongs to
 * @param tuple      tuple values; taken by value because rule/level numbers are appended locally
 * @param ruleNum    rule number; together with relName it selects the body literals in `info`
 * @param levelNum   level number; appended to the tuple handed to the subproof subroutine
 * @param depthLimit remaining depth; at <= 1 the tuple is stored in `subproofs` and a
 *                   "subproof" leaf is produced instead of expanding further
 */
89 std::string relName, std::vector<RamDomain> tuple,
int ruleNum,
int levelNum,
size_t depthLimit) {
// Printable argument list of the tuple, joined with ", "
90 std::stringstream joinedArgs;
91 joinedArgs <<
join(decodeArguments(relName, tuple),
", ");
92 auto joinedArgsStr = joinedArgs.str();
// Leaf of the form relName(args); the condition guarding this return is not among the visible lines
96 return mk<LeafNode>(relName +
"(" + joinedArgsStr +
")");
99 assert(info.find(std::make_pair(relName, ruleNum)) != info.end() &&
"invalid rule for tuple");
// Depth budget used up: remember the tuple (extended by rule and level number) as a numbered subproof
102 if (depthLimit <= 1) {
103 tuple.push_back(ruleNum);
104 tuple.push_back(levelNum);
// Reuse the index of an identical, previously stored subproof if there is one
108 auto it = std::find(subproofs.begin(), subproofs.end(),
tuple);
109 if (it != subproofs.end()) {
110 idx = it - subproofs.begin();
// Append the tuple; its index is the new last position
112 subproofs.push_back(tuple);
113 idx = subproofs.size() - 1;
116 return mk<LeafNode>(
"subproof " + relName +
"(" + std::to_string(idx) +
")");
// The level number is appended before the tuple is passed to the subproof subroutine
119 tuple.push_back(levelNum);
// Inner node with the tuple text as first label and the rule tag (R<ruleNum>) as second
122 mk<InnerNode>(relName +
"(" + joinedArgsStr +
")",
"(R" + std::to_string(ruleNum) +
")");
// Receives the output of the subroutine
125 std::vector<RamDomain> ret;
// Subroutine name is <relName>_<ruleNum>_subproof; `tuple` is its input
128 prog.executeSubroutine(relName +
"_" + std::to_string(ruleNum) +
"_subproof", tuple, ret);
// Read position inside `ret`
131 size_t tupleCurInd = 0;
132 auto bodyRelations = info[std::make_pair(relName, ruleNum)];
// The first entry is skipped (begin() + 1); presumably it describes the rule head -- TODO confirm
135 for (
auto it = bodyRelations.begin() + 1; it < bodyRelations.end(); it++) {
136 std::string bodyLiteral = *it;
// The part before the first ',' is the relation or constraint name
138 std::string bodyRel =
splitString(bodyLiteral,
',')[0];
141 assert(bodyRel.size() > 0 &&
"body of a relation should have positive length");
142 bool isConstraint =
contains(constraintList, bodyRel);
// A leading '!' (unless the literal is the != operator) marks a negated atom; strip it to get the relation name
145 auto bodyRelAtomName = bodyRel;
146 if (bodyRel[0] ==
'!' && bodyRel !=
"!=") {
147 bodyRelAtomName = bodyRel.substr(1);
152 size_t auxiliaryArity;
// Arity data is taken from the relation named by the (un-negated) atom
159 arity = prog.getRelation(bodyRelAtomName)->getArity();
160 auxiliaryArity = prog.getRelation(bodyRelAtomName)->getAuxiliaryArity();
// End of this literal's slice of `ret`
162 auto tupleEnd = tupleCurInd + arity;
165 std::vector<RamDomain> subproofTuple;
// Copy only the non-auxiliary columns of the slice
167 for (; tupleCurInd < tupleEnd - auxiliaryArity; tupleCurInd++) {
168 subproofTuple.push_back(ret[tupleCurInd]);
// The two values right after those columns are read as rule number and level number
171 int subproofRuleNum = ret[tupleCurInd];
172 int subproofLevelNum = ret[tupleCurInd + 1];
// Negated atom: rendered as a leaf and counted as one node
177 if (bodyRel[0] ==
'!' && bodyRel !=
"!=") {
178 std::stringstream joinedTuple;
179 joinedTuple <<
join(decodeArguments(bodyRelAtomName, subproofTuple),
", ");
180 auto joinedTupleStr = joinedTuple.str();
181 internalNode->add_child(mk<LeafNode>(bodyRel +
"(" + joinedTupleStr +
")"));
182 internalNode->setSize(internalNode->getSize() + 1);
184 }
// Constraint: rendered as a leaf as well
else if (isConstraint) {
185 std::stringstream joinedConstraint;
// Infix form built from the two raw operand values
190 joinedConstraint << subproofTuple[0] <<
" " << bodyRel <<
" " << subproofTuple[1];
// Call-style form with both operands resolved through the symbol table and quoted;
// the condition choosing between the two forms is not among the visible lines
192 joinedConstraint << bodyRel <<
"(\"" << symTable.resolve(subproofTuple[0]) <<
"\", \""
193 << symTable.resolve(subproofTuple[1]) <<
"\")";
196 internalNode->add_child(mk<LeafNode>(joinedConstraint.str()));
197 internalNode->setSize(internalNode->getSize() + 1);
// Ordinary atom: explained recursively with one level less of depth budget
201 explain(bodyRel, subproofTuple, subproofRuleNum, subproofLevelNum, depthLimit - 1);
202 internalNode->setSize(internalNode->getSize() + child->getSize());
203 internalNode->add_child(std::move(child));
// Continue with the next literal's slice of `ret`
206 tupleCurInd = tupleEnd;
/**
 * Explains a tuple whose arguments are given as strings.
 *
 * The arguments are converted with argsToNums(), rule and level number are looked up with
 * findTuple(), and the work is then delegated to the explain() overload that takes both numbers.
 *
 * @param relName    relation of the tuple
 * @param args       tuple arguments as strings
 * @param depthLimit passed through unchanged
 * @return the proof tree, or a single leaf carrying an error text
 */
212 Own<TreeNode>
explain(std::string relName, std::vector<std::string> args,
size_t depthLimit)
override {
213 auto tuple = argsToNums(relName, args);
// The condition leading to this return is not among the visible lines
215 return mk<LeafNode>(
"Relation not found");
218 std::tuple<int, int> tupleInfo = findTuple(relName,
tuple);
220 int ruleNum = std::get<0>(tupleInfo);
221 int levelNum = std::get<1>(tupleInfo);
// findTuple() yields (-1, -1) when it has nothing to report
223 if (ruleNum < 0 || levelNum == -1) {
224 return mk<LeafNode>(
"Tuple not found");
227 return explain(relName,
tuple, ruleNum, levelNum, depthLimit);
/**
 * Explains a subproof that an earlier explain() call stored in `subproofs`.
 *
 * @param relName     relation the stored tuple belongs to
 * @param subproofNum index into `subproofs`
 * @param depthLimit  passed through to explain()
 * @return the proof tree, or a "Subproof not found" leaf if the index is too large
 */
230 Own<TreeNode> explainSubproof(std::string relName,
RamDomain subproofNum,
size_t depthLimit)
override {
// NOTE(review): only the upper bound is checked here; no check for a negative
// subproofNum is visible -- confirm callers never pass one
231 if (subproofNum >= (
int)subproofs.size()) {
232 return mk<LeafNode>(
"Subproof not found");
// Works on a copy, so the stored entry stays intact when the copy is truncated below
235 auto tup = subproofs[subproofNum];
237 auto rel = prog.getRelation(relName);
// Rule number is read from index arity - auxiliaryArity ...
240 ruleNum = tup[
rel->getArity() -
rel->getAuxiliaryArity()];
// ... and the level number from the index right after it
243 levelNum = tup[
rel->getArity() -
rel->getAuxiliaryArity() + 1];
// Drop everything from that index on, leaving only the non-auxiliary columns
245 tup.erase(tup.begin() +
rel->getArity() -
rel->getAuxiliaryArity(), tup.end());
247 return explain(relName, tup, ruleNum, levelNum, depthLimit);
/**
 * Collects the variables of rule `ruleNum` of `relName` that occur in the body but not in the head.
 *
 * Single-element marker results visible in this block:
 *  - @              findTuple() returned something other than (-1, -1)
 *  - @fact          `info` holds at most one entry for the rule
 *  - @non_matching  one head variable would have to take two different argument values
 *
 * @param relName relation name
 * @param args    arguments of the queried tuple, as strings
 * @param ruleNum rule number
 */
250 std::vector<std::string> explainNegationGetVariables(
251 std::string relName, std::vector<std::string> args,
size_t ruleNum)
override {
252 std::vector<std::string> variables;
255 std::tuple<int, int> foundTuple = findTuple(relName, argsToNums(relName, args));
256 if (std::get<0>(foundTuple) != -1 || std::get<1>(foundTuple) != -1) {
258 return std::vector<std::string>({
"@"});
// std::map::operator[] creates an empty entry for an unknown key; such a rule
// then falls into the size check below
262 auto atoms = info[std::make_pair(relName, ruleNum)];
266 if (atoms.size() <= 1) {
267 return std::vector<std::string>({
"@fact"});
// A variable is anything that is not a number, does not start with a double quote, and is not the underscore
273 auto isVariable = [&](std::string arg) {
274 return !(
isNumber(arg.c_str()) || arg[0] ==
'\"' || arg ==
"_");
// Head variable -> argument value it is bound to (`headVariables` is defined on lines not shown here)
279 std::map<std::string, std::string> headVariableMapping;
280 for (
size_t i = 0;
i < headVariables.size();
i++) {
281 if (!isVariable(headVariables[
i])) {
// First occurrence of the variable: record its binding
285 if (headVariableMapping.find(headVariables[
i]) == headVariableMapping.end()) {
286 headVariableMapping[headVariables[
i]] = args[
i];
// The recorded binding must agree with the argument at this position
288 if (headVariableMapping[headVariables[
i]] != args[
i]) {
289 return std::vector<std::string>({
"@non_matching"});
295 std::vector<std::string> uniqueBodyVariables;
// Walk the entries after the first one
296 for (
auto it = atoms.begin() + 1; it < atoms.end(); it++) {
// Element 0 of the representation is skipped; the remaining elements are the arguments
301 for (
auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
302 if (!isVariable(*atomIt)) {
// Keep each variable once, and only if the head does not already use it
306 if (!
contains(uniqueBodyVariables, *atomIt) && !
contains(headVariables, *atomIt)) {
307 uniqueBodyVariables.push_back(*atomIt);
312 return uniqueBodyVariables;
/**
 * Builds a tree with one leaf per body literal of rule `ruleNum`, instantiated for `tuple`.
 *
 * The variable values are encoded as numbers, the subroutine <relName>_<ruleNum>_negation_subproof
 * is run on them, and each body literal is rendered as a leaf label.
 *
 * @param relName       relation name
 * @param ruleNum       rule number
 * @param tuple         head tuple, as strings
 * @param bodyVariables values for body variables, by name; head variables are written into it here
 */
315 Own<TreeNode> explainNegation(std::string relName,
size_t ruleNum,
const std::vector<std::string>& tuple,
316 std::map<std::string, std::string>& bodyVariables)
override {
// Variables in the order their values are passed to the subroutine
318 std::vector<std::string> uniqueVariables;
// Variable name -> type character
321 std::map<std::string, char> variableTypes;
324 auto atoms = info[std::make_pair(relName, ruleNum)];
// Head variables come first (`headVariables` is defined on lines not shown here)
329 uniqueVariables.insert(uniqueVariables.end(), headVariables.begin(), headVariables.end());
// A variable is anything that is not a number, does not start with a double quote, and is not the underscore
331 auto isVariable = [&](std::string arg) {
332 return !(
isNumber(arg.c_str()) || arg[0] ==
'\"' || arg ==
"_");
// Walk the entries after the first one
336 for (
auto it = atoms.begin() + 1; it < atoms.end(); it++) {
// Element 0 of the representation is the literal's name, so start at element 1
341 for (
auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
// Only arguments not seen before
342 if (!
contains(uniqueVariables, *atomIt) && !
contains(headVariables, *atomIt)) {
344 if (!isVariable(*atomIt)) {
348 uniqueVariables.push_back(*atomIt);
// Not a constraint: the type comes from the relation's attribute at the argument's position
350 if (!
contains(constraintList, atomRepresentation[0])) {
352 auto currentRel = prog.getRelation(atomRepresentation[0]);
// NOTE(review): the second operand is a non-null pointer and therefore always true;
// it does not change the outcome of the check
353 assert(currentRel !=
nullptr &&
354 (
"relation " + atomRepresentation[0] +
" doesn't exist").c_str());
355 variableTypes[*atomIt] =
356 *currentRel->getAttrType(atomIt - atomRepresentation.begin() - 1);
357 }
// Arguments whose text contains agg_ are given type 'i'
else if (atomIt->find(
"agg_") != std::string::npos) {
358 variableTypes[*atomIt] =
'i';
// Input values for the subroutine
364 std::vector<RamDomain> args;
366 size_t varCounter = 0;
// The head tuple's values go first; the counter skips as many variables as the tuple has entries
372 auto tupleNums = argsToNums(relName, tuple);
373 args.insert(args.end(), tupleNums.begin(), tupleNums.end());
374 varCounter += tuple.size();
// Remaining variables take their value from `bodyVariables`
376 while (varCounter < uniqueVariables.size()) {
377 auto var = uniqueVariables[varCounter];
378 auto varValue = bodyVariables[var];
// Type 's': the value is looked up in the symbol table
379 if (variableTypes[var] ==
's') {
// Surrounding double quotes are removed before the lookup
380 if (varValue.size() >= 2 && varValue[0] ==
'"' && varValue[varValue.size() - 1] ==
'"') {
381 auto originalStr = varValue.substr(1, varValue.size() - 2);
382 args.push_back(symTable.lookup(originalStr));
385 args.push_back(symTable.lookup(varValue));
// Value parsed as an integer with std::stoi
388 args.push_back(std::stoi(varValue));
395 std::vector<RamDomain> ret;
398 prog.executeSubroutine(relName +
"_" + std::to_string(ruleNum) +
"_negation_subproof", args, ret);
// One returned value is expected per entry of `atoms` after the first
401 assert(ret.size() == atoms.size() - 1);
404 std::stringstream joinedArgsStr;
405 joinedArgsStr <<
join(tuple,
",");
// Root node: tuple text as first label, rule tag (R<ruleNum>) as second
406 auto internalNode = mk<InnerNode>(
407 relName +
"(" + joinedArgsStr.str() +
")",
"(R" + std::to_string(ruleNum) +
")");
// Bind every head variable to the tuple entry at the same position
410 for (
size_t i = 0;
i < headVariables.size();
i++) {
411 bodyVariables[headVariables[
i]] = tuple[
i];
// Index into `atoms`, starting at 1
416 int literalCounter = 1;
419 bool atomExists =
true;
420 if (returnCounter == 0) {
425 auto atomRepresentation =
splitString(atoms[literalCounter],
',');
426 std::string bodyRel = atomRepresentation[0];
429 bool isConstraint =
contains(constraintList, bodyRel);
// A leading '!' (unless the literal is the != operator) marks a negated atom
432 auto bodyRelAtomName = bodyRel;
433 if (bodyRel[0] ==
'!' && bodyRel !=
"!=") {
434 bodyRelAtomName = bodyRel.substr(1);
438 std::stringstream childLabel;
// Name plus exactly two operands
441 assert(atomRepresentation.size() == 3 &&
"not a binary constraint");
// Infix form with both operands substituted from `bodyVariables`
443 childLabel << bodyVariables[atomRepresentation[1]] <<
" " << bodyRel <<
" "
444 << bodyVariables[atomRepresentation[2]];
// Call-style form: name, then the argument list
446 childLabel << bodyRel <<
"(";
447 for (
size_t i = 1;
i < atomRepresentation.size();
i++) {
// Non-variables are printed as written, variables are replaced by their value
449 if (!isVariable(atomRepresentation[
i])) {
450 childLabel << atomRepresentation[
i];
452 childLabel << bodyVariables[atomRepresentation[
i]];
// True for every argument except the last one
454 if (
i < atomRepresentation.size() - 1) {
// Check-mark suffix; the condition under which it is appended is not among the visible lines
463 childLabel <<
" ✓";
468 internalNode->add_child(mk<LeafNode>(childLabel.str()));
469 internalNode->setSize(internalNode->getSize() + 1);
/**
 * Looks up the entry of `rules` stored under (relName, ruleNum).
 *
 * @return "Rule not found" when there is no such entry; the return for the found case
 *         is on lines not shown here
 */
477 std::string getRule(std::string relName,
size_t ruleNum)
override {
478 auto key = make_pair(relName, ruleNum);
480 auto rule = rules.find(key);
481 if (
rule == rules.end()) {
482 return "Rule not found";
/**
 * Gathers the rule texts of every entry in `rules` whose key names relation `relName`.
 *
 * @param relName relation name to match against the first component of each key
 */
488 std::vector<std::string> getRules(
const std::string& relName)
override {
489 std::vector<std::string> relRules;
// `rules` is scanned in full; the relation name is the first half of each key
491 for (
auto&
rule : rules) {
492 if (
rule.first.first == relName) {
493 relRules.push_back(
rule.second);
/**
 * Runs explain() on tuples of `relName`, printing the tree size and the elapsed time per tuple
 * to std::cout and writing a total into the string stream `ss`.
 *
 * @param relName relation to measure
 * @return "No relation found\n" if the relation does not exist; the regular return value is
 *         on lines not shown here
 */
500 std::string measureRelation(std::string relName)
override {
501 auto rel = prog.getRelation(relName);
503 if (
rel ==
nullptr) {
504 return "No relation found\n";
// NOTE(review): this is 0 whenever size < 10, and `skip` is the divisor of the modulo
// below -- confirm that a guard exists on the lines not shown here
508 int skip =
size / 10;
514 std::stringstream
ss;
// True for tuples whose running number is not a multiple of `skip`
523 if (numTuples % skip != 0) {
// Numeric form of the current tuple, one entry per primary column
528 std::vector<RamDomain> currentTuple;
529 for (arity_type
i = 0;
i <
rel->getPrimaryArity();
i++) {
// The first character of the attribute type selects the conversion: 's', 'f' or 'u'
531 if (*
rel->getAttrType(
i) ==
's') {
// Uses lookupExisting() on the symbol table
534 n = symTable.lookupExisting(s);
535 }
else if (*
rel->getAttrType(
i) ==
'f') {
539 }
else if (*
rel->getAttrType(
i) ==
'u') {
547 currentTuple.push_back(
n);
// Explains the tuple with a depth limit of 10000 and reports the size of the resulting tree
556 std::cout <<
"Tuples expanded: "
557 <<
explain(relName, currentTuple, ruleNum, levelNum, 10000)->getSize();
// Elapsed time for this tuple as a duration with a double count of seconds
564 std::chrono::duration_cast<std::chrono::duration<double>>(tupleEnd - tupleStart);
566 std::cout <<
", Time: " << tupleDuration.count() <<
"\n";
// Elapsed time between before_time and after_time
570 auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(after_time - before_time);
572 ss <<
"total: " << proc <<
" ";
/**
 * Writes the contents of `rules` to `os` as the JSON member "rules", an array with one
 * object per rule.
 *
 * @param os output stream to write to
 */
578 void printRulesJSON(std::ostream& os)
override {
579 os <<
"\"rules\": [\n";
581 for (
auto const& cur : rules) {
// cur.first.second is the rule number part of the key; it is emitted in the form (R<number>)
587 os <<
"\t{ \"rule-number\": \"(R" << cur.first.second <<
")\", \"rule\": \""
/**
 * Processes a query consisting of one or more (relation name, argument list) pairs.
 *
 * Arguments that look like identifiers are treated as variables and grouped by name;
 * other arguments are converted to constants. Results and error messages are printed to std::cout.
 *
 * @param rels the queried relations, each with its arguments as strings
 */
593 void queryProcess(
const std::vector<std::pair<std::string, std::vector<std::string>>>& rels)
override {
// All three patterns use the POSIX extended grammar.
// Variable: a letter or underscore followed by letters, digits or underscores
594 std::regex varRegex(
"[a-zA-Z_][a-zA-Z_0-9]*", std::regex_constants::extended);
// Symbol: text in double quotes; capture group 1 is the text without the quotes
595 std::regex symbolRegex(
"\"([^\"]*)\"", std::regex_constants::extended);
// Number: one or more decimal digits
596 std::regex numberRegex(
"[0-9]+", std::regex_constants::extended);
598 std::smatch argsMatcher;
// Variable name -> its Equivalence (the positions at which the variable occurs)
601 std::map<std::string, Equivalence> nameToEquivalence;
// Relations handed to findQuerySolution() at the end
607 std::vector<Relation*> varRels;
613 for (
const auto&
rel : rels) {
// Constant arguments of the current relation, in numeric form
616 std::vector<RamDomain> constTuple;
619 std::cout <<
"Relation <" <<
rel.first <<
"> does not exist" << std::endl;
// The number of arguments must equal the relation's primary arity
623 if (
relation->getPrimaryArity() !=
rel.second.size()) {
624 std::cout <<
"<" +
rel.first <<
"> has arity of " <<
relation->getPrimaryArity() << std::endl;
629 bool containVar =
false;
630 for (
size_t j = 0;
j <
rel.second.size(); ++
j) {
// Whole argument matches the variable pattern
632 if (std::regex_match(
rel.second[
j], argsMatcher, varRegex)) {
634 auto nameToEquivalenceIter = nameToEquivalence.find(argsMatcher[0]);
// First occurrence of this variable name: create its entry
637 if (nameToEquivalenceIter == nameToEquivalence.end()) {
638 nameToEquivalence.insert(
640 std::make_pair(idx,
j))});
// Known variable: record the additional position (idx, j)
642 nameToEquivalenceIter->second.push_back(std::make_pair(idx,
j));
// The argument is expected to be a quoted symbol here
650 if (!std::regex_match(
rel.second[
j], argsMatcher, symbolRegex)) {
651 std::cout << argsMatcher.str(0) <<
" does not match type defined in relation"
// The symbol text without quotes (capture group 1) is looked up in the symbol table
655 rd = prog.getSymbolTable().lookup(argsMatcher[1]);
// The next three messages report further mismatches; their conditions are on lines not shown here
659 std::cout <<
rel.second[
j] <<
" does not match type defined in relation"
667 std::cout <<
rel.second[
j] <<
" does not match type defined in relation"
675 std::cout <<
rel.second[
j] <<
" does not match type defined in relation"
// Remember the constant together with its position (idx, j)
684 constConstraints.
push_back(std::make_pair(std::make_pair(idx,
j), rd));
686 constTuple.push_back(rd);
692 bool tupleExist = containsTuple(
relation, constTuple);
702 std::cout <<
"false." << std::endl;
703 std::cout <<
"Tuple " <<
rel.first <<
"(";
// Prints all arguments but the last, each followed by a comma and a space.
// NOTE(review): size() - 1 is unsigned, so an empty argument list would wrap around here
// and make back() below invalid -- confirm empty lists cannot reach this point
704 for (
size_t l = 0;
l <
rel.second.size() - 1; ++
l) {
705 std::cout <<
rel.second[
l] <<
", ";
707 std::cout <<
rel.second.back() <<
") does not exist" << std::endl;
// Nothing left to solve for
718 if (varRels.size() == 0) {
719 std::cout <<
"true." << std::endl;
724 findQuerySolution(varRels, nameToEquivalence, constConstraints);
// (relation name, rule number) -> list of strings describing the rule's literals
728 std::map<std::pair<std::string, size_t>, std::vector<std::string>> info;
// (relation name, rule number) -> rule text
729 std::map<std::pair<std::string, size_t>, std::string> rules;
// Tuples stored by explain() when the depth limit is reached; the position in this vector is the subproof number
730 std::vector<std::vector<RamDomain>> subproofs;
// Names that identify a body literal as a constraint rather than a relation
731 std::vector<std::string> constraintList = {
732 "=",
"!=",
"<",
"<=",
">=",
">",
"match",
"contains",
"not_match",
"not_contains"};
/**
 * Looks for `tup` in relation `relName`.
 *
 * @param relName relation to search
 * @param tup     tuple values in numeric form (taken by value)
 * @return (ruleNum, levelNum) on the successful path, otherwise (-1, -1)
 */
734 std::tuple<int, int> findTuple(
const std::string& relName, std::vector<RamDomain> tup) {
735 auto rel = prog.getRelation(relName);
// Early failure result; the guarding condition is on a line not shown here
738 return std::make_tuple(-1, -1);
// Numeric form of the tuple currently being inspected
744 std::vector<RamDomain> currentTuple;
// The first character of the attribute type selects the conversion: 's', 'f' or 'u'
748 if (*
rel->getAttrType(
i) ==
's') {
// Uses lookupExisting() on the symbol table
751 n = symTable.lookupExisting(s);
752 }
else if (*
rel->getAttrType(
i) ==
'f') {
756 }
else if (*
rel->getAttrType(
i) ==
'u') {
764 currentTuple.push_back(
n);
779 return std::make_tuple(ruleNum, levelNum);
// Fallback result
784 return std::make_tuple(-1, -1);
/**
 * Enumerates combinations of tuples from `varRels`, one tuple per relation, and prints the
 * variable bindings of every combination that passes all checks.
 *
 * @param varRels           relations to draw tuples from
 * @param nameToEquivalence variable name -> Equivalence; each one is verified against a combination
 * @param constConstraints  constant constraints; verified against a combination as well
 */
794 void findQuerySolution(
const std::vector<Relation*>& varRels,
795 const std::map<std::string, Equivalence>& nameToEquivalence,
796 const ConstConstraint& constConstraints) {
// Iterators into the relations of `varRels`; together they select the current combination
798 std::vector<Relation::iterator> varRelationIterators;
803 size_t solutionCount = 0;
// Holds the text of a solution until it is printed
804 std::stringstream solution;
808 bool isSolution =
true;
// The tuples the iterators currently point at
811 std::vector<tuple> element;
812 for (
auto it : varRelationIterators) {
813 element.push_back(*it);
// Each Equivalence is verified against the combination
816 for (
auto var : nameToEquivalence) {
817 if (!var.second.verify(element)) {
// Result of verifying the constant constraints
824 isSolution = constConstraints.verify(element);
// Print the buffered solution text, then empty the buffer
828 std::cout << solution.str();
829 solution.str(std::string());
// Format one binding of the form symbol = value per variable
832 for (
auto&& var : nameToEquivalence) {
// The value is read at the variable's first recorded position (tuple index, column index)
833 auto idx = var.second.getFirstIdx();
834 auto raw = element[idx.first][idx.second];
836 solution << var.second.getSymbol() <<
" = ";
// The raw value is converted according to the variable's type character
837 switch (var.second.getType()) {
838 case 'i': solution << ramBitCast<RamSigned>(raw);
break;
839 case 'f': solution << ramBitCast<RamFloat>(raw);
break;
840 case 'u': solution << ramBitCast<RamUnsigned>(raw);
break;
841 case 's': solution << prog.getSymbolTable().resolve(raw);
break;
842 default:
fatal(
"invalid type: `%c`", var.second.getType());
// Comma and space while more bindings follow, a single space after the last one
845 auto sep = ++c < nameToEquivalence.size() ?
", " :
" ";
// From the second solution on, read lines from std::cin
851 if (1 < solutionCount) {
852 for (std::string input; getline(std::cin, input);) {
// A line consisting of ; leaves the input loop
853 if (input ==
";")
break;
// A line consisting of . ends the whole function
854 if (input ==
".")
return;
856 std::cout <<
"use ; to find next solution, use . to break from current query\n";
// Walk the iterators from last to first; `i` starts at the index of the last relation
862 size_t i = varRels.size() - 1;
863 bool terminate =
true;
864 for (
auto it = varRelationIterators.rbegin(); it != varRelationIterators.rend(); ++it) {
// Advance this iterator and test whether it is still inside its relation
865 if ((++(*it)) != varRels[
i]->end()) {
// Rewind this iterator to the start of its relation
869 (*it) = varRels[
i]->begin();
876 if (solutionCount == 0) {
877 std::cout <<
"false." << std::endl;
// Prints whatever is still buffered, followed by a dot
880 std::cout << solution.str() <<
"." << std::endl;
888 bool containsTuple(Relation*
relation,
const std::vector<RamDomain>& constTuple) {
889 bool tupleExist =
false;
892 for (
size_t j = 0;
j < constTuple.size(); ++
j) {
893 if (constTuple[
j] != (*it)[
j]) {